ModelZoo / Magma_pytorch · Commits

Commit 0063a668, authored May 13, 2025 by chenzk
v1.0

Changes: 352. Showing 20 changed files with 3730 additions and 0 deletions (+3730, -0).
facebookresearch/co-tracker/cotracker/models/core/cotracker/__pycache__/cotracker3_offline.cpython-310.pyc  (+0, -0)
facebookresearch/co-tracker/cotracker/models/core/cotracker/__pycache__/cotracker3_online.cpython-310.pyc  (+0, -0)
facebookresearch/co-tracker/cotracker/models/core/cotracker/blocks.py  (+438, -0)
facebookresearch/co-tracker/cotracker/models/core/cotracker/cotracker.py  (+577, -0)
facebookresearch/co-tracker/cotracker/models/core/cotracker/cotracker3_offline.py  (+233, -0)
facebookresearch/co-tracker/cotracker/models/core/cotracker/cotracker3_online.py  (+541, -0)
facebookresearch/co-tracker/cotracker/models/core/cotracker/losses.py  (+118, -0)
facebookresearch/co-tracker/cotracker/models/core/embeddings.py  (+120, -0)
facebookresearch/co-tracker/cotracker/models/core/model_utils.py  (+426, -0)
facebookresearch/co-tracker/cotracker/models/evaluation_predictor.py  (+199, -0)
facebookresearch/co-tracker/cotracker/predictor.py  (+309, -0)
facebookresearch/co-tracker/cotracker/utils/__init__.py  (+5, -0)
facebookresearch/co-tracker/cotracker/utils/__pycache__/__init__.cpython-310.pyc  (+0, -0)
facebookresearch/co-tracker/cotracker/utils/__pycache__/visualizer.cpython-310.pyc  (+0, -0)
facebookresearch/co-tracker/cotracker/utils/train_utils.py  (+255, -0)
facebookresearch/co-tracker/cotracker/utils/visualizer.py  (+363, -0)
facebookresearch/co-tracker/cotracker/version.py  (+8, -0)
facebookresearch/co-tracker/demo.py  (+109, -0)
facebookresearch/co-tracker/docs/Makefile  (+14, -0)
facebookresearch/co-tracker/docs/source/apis/models.rst  (+15, -0)
facebookresearch/co-tracker/cotracker/models/core/cotracker/__pycache__/cotracker3_offline.cpython-310.pyc  0 → 100644
Binary file added.
facebookresearch/co-tracker/cotracker/models/core/cotracker/__pycache__/cotracker3_online.cpython-310.pyc  0 → 100644
Binary file added.
facebookresearch/co-tracker/cotracker/models/core/cotracker/blocks.py  0 → 100644
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.

# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

import torch
import torch.nn as nn
import torch.nn.functional as F
from functools import partial
from typing import Callable
import collections
from torch import Tensor
from itertools import repeat

from cotracker.models.core.model_utils import bilinear_sampler


# From PyTorch internals
def _ntuple(n):
    def parse(x):
        if isinstance(x, collections.abc.Iterable) and not isinstance(x, str):
            return tuple(x)
        return tuple(repeat(x, n))

    return parse


def exists(val):
    return val is not None


def default(val, d):
    return val if exists(val) else d


to_2tuple = _ntuple(2)
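
A quick illustration of the helper (not part of the commit): to_2tuple duplicates scalars into pairs but passes non-string iterables through unchanged, which is how the single bias and drop arguments of Mlp below become per-layer pairs.

    to_2tuple(True)        # (True, True)
    to_2tuple(0.1)         # (0.1, 0.1)
    to_2tuple((0.1, 0.2))  # (0.1, 0.2): iterables pass through unchanged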
class Mlp(nn.Module):
    """MLP as used in Vision Transformer, MLP-Mixer and related networks"""

    def __init__(
        self,
        in_features,
        hidden_features=None,
        out_features=None,
        act_layer=nn.GELU,
        norm_layer=None,
        bias=True,
        drop=0.0,
        use_conv=False,
    ):
        super().__init__()
        out_features = out_features or in_features
        hidden_features = hidden_features or in_features
        bias = to_2tuple(bias)
        drop_probs = to_2tuple(drop)
        linear_layer = partial(nn.Conv2d, kernel_size=1) if use_conv else nn.Linear

        self.fc1 = linear_layer(in_features, hidden_features, bias=bias[0])
        self.act = act_layer()
        self.drop1 = nn.Dropout(drop_probs[0])
        self.norm = (
            norm_layer(hidden_features) if norm_layer is not None else nn.Identity()
        )
        self.fc2 = linear_layer(hidden_features, out_features, bias=bias[1])
        self.drop2 = nn.Dropout(drop_probs[1])

    def forward(self, x):
        x = self.fc1(x)
        x = self.act(x)
        x = self.drop1(x)
        x = self.fc2(x)
        x = self.drop2(x)
        return x
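
A minimal usage sketch for Mlp (illustrative shapes, not part of the commit); out_features defaults to in_features, so the block is shape-preserving on its trailing dimension:

    mlp = Mlp(in_features=384, hidden_features=1536, drop=0.1)
    tokens = torch.randn(2, 196, 384)  # (batch, tokens, channels)
    out = mlp(tokens)                  # -> (2, 196, 384)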
class ResidualBlock(nn.Module):
    def __init__(self, in_planes, planes, norm_fn="group", stride=1):
        super(ResidualBlock, self).__init__()

        self.conv1 = nn.Conv2d(
            in_planes,
            planes,
            kernel_size=3,
            padding=1,
            stride=stride,
            padding_mode="zeros",
        )
        self.conv2 = nn.Conv2d(
            planes, planes, kernel_size=3, padding=1, padding_mode="zeros"
        )
        self.relu = nn.ReLU(inplace=True)

        num_groups = planes // 8

        if norm_fn == "group":
            self.norm1 = nn.GroupNorm(num_groups=num_groups, num_channels=planes)
            self.norm2 = nn.GroupNorm(num_groups=num_groups, num_channels=planes)
            if not stride == 1:
                self.norm3 = nn.GroupNorm(num_groups=num_groups, num_channels=planes)

        elif norm_fn == "batch":
            self.norm1 = nn.BatchNorm2d(planes)
            self.norm2 = nn.BatchNorm2d(planes)
            if not stride == 1:
                self.norm3 = nn.BatchNorm2d(planes)

        elif norm_fn == "instance":
            self.norm1 = nn.InstanceNorm2d(planes)
            self.norm2 = nn.InstanceNorm2d(planes)
            if not stride == 1:
                self.norm3 = nn.InstanceNorm2d(planes)

        elif norm_fn == "none":
            self.norm1 = nn.Sequential()
            self.norm2 = nn.Sequential()
            if not stride == 1:
                self.norm3 = nn.Sequential()

        if stride == 1:
            self.downsample = None
        else:
            self.downsample = nn.Sequential(
                nn.Conv2d(in_planes, planes, kernel_size=1, stride=stride), self.norm3
            )

    def forward(self, x):
        y = x
        y = self.relu(self.norm1(self.conv1(y)))
        y = self.relu(self.norm2(self.conv2(y)))

        if self.downsample is not None:
            x = self.downsample(x)

        return self.relu(x + y)
class BasicEncoder(nn.Module):
    def __init__(self, input_dim=3, output_dim=128, stride=4):
        super(BasicEncoder, self).__init__()
        self.stride = stride
        self.norm_fn = "instance"
        self.in_planes = output_dim // 2

        self.norm1 = nn.InstanceNorm2d(self.in_planes)
        self.norm2 = nn.InstanceNorm2d(output_dim * 2)

        self.conv1 = nn.Conv2d(
            input_dim,
            self.in_planes,
            kernel_size=7,
            stride=2,
            padding=3,
            padding_mode="zeros",
        )
        self.relu1 = nn.ReLU(inplace=True)
        self.layer1 = self._make_layer(output_dim // 2, stride=1)
        self.layer2 = self._make_layer(output_dim // 4 * 3, stride=2)
        self.layer3 = self._make_layer(output_dim, stride=2)
        self.layer4 = self._make_layer(output_dim, stride=2)

        self.conv2 = nn.Conv2d(
            output_dim * 3 + output_dim // 4,
            output_dim * 2,
            kernel_size=3,
            padding=1,
            padding_mode="zeros",
        )
        self.relu2 = nn.ReLU(inplace=True)
        self.conv3 = nn.Conv2d(output_dim * 2, output_dim, kernel_size=1)

        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode="fan_out", nonlinearity="relu")
            elif isinstance(m, (nn.InstanceNorm2d)):
                if m.weight is not None:
                    nn.init.constant_(m.weight, 1)
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)

    def _make_layer(self, dim, stride=1):
        layer1 = ResidualBlock(self.in_planes, dim, self.norm_fn, stride=stride)
        layer2 = ResidualBlock(dim, dim, self.norm_fn, stride=1)
        layers = (layer1, layer2)

        self.in_planes = dim
        return nn.Sequential(*layers)

    def forward(self, x):
        _, _, H, W = x.shape

        x = self.conv1(x)
        x = self.norm1(x)
        x = self.relu1(x)

        a = self.layer1(x)
        b = self.layer2(a)
        c = self.layer3(b)
        d = self.layer4(c)

        def _bilinear_intepolate(x):
            return F.interpolate(
                x,
                (H // self.stride, W // self.stride),
                mode="bilinear",
                align_corners=True,
            )

        a = _bilinear_intepolate(a)
        b = _bilinear_intepolate(b)
        c = _bilinear_intepolate(c)
        d = _bilinear_intepolate(d)

        x = self.conv2(torch.cat([a, b, c, d], dim=1))
        x = self.norm2(x)
        x = self.relu2(x)
        x = self.conv3(x)
        return x
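
A shape sketch for the encoder (illustrative, not part of the commit): every stage is bilinearly resampled to the common H // stride, W // stride grid before fusion, so the output resolution depends only on stride:

    encoder = BasicEncoder(input_dim=3, output_dim=128, stride=4)
    frames = torch.randn(8, 3, 384, 512)  # (B*S, C, H, W)
    feats = encoder(frames)               # -> (8, 128, 96, 128), i.e. (B*S, output_dim, H//4, W//4)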
class EfficientCorrBlock:
    def __init__(
        self,
        fmaps,
        num_levels=4,
        radius=4,
        padding_mode="zeros",
    ):
        B, S, C, H, W = fmaps.shape
        self.padding_mode = padding_mode
        self.num_levels = num_levels
        self.radius = radius
        self.fmaps_pyramid = []
        self.fmaps_pyramid.append(fmaps)
        for i in range(self.num_levels - 1):
            fmaps_ = fmaps.reshape(B * S, C, H, W)
            fmaps_ = F.avg_pool2d(fmaps_, 2, stride=2)
            _, _, H, W = fmaps_.shape
            fmaps = fmaps_.reshape(B, S, C, H, W)
            self.fmaps_pyramid.append(fmaps)

    def sample(self, coords, target):
        r = self.radius
        device = coords.device
        B, S, N, D = coords.shape
        assert D == 2
        target = target.permute(0, 1, 3, 2).unsqueeze(-1)

        out_pyramid = []
        for i in range(self.num_levels):
            pyramid = self.fmaps_pyramid[i]
            C, H, W = pyramid.shape[2:]
            centroid_lvl = (
                torch.cat(
                    [torch.zeros_like(coords[..., :1], device=device), coords],
                    dim=-1,
                ).reshape(B * S, N, 1, 1, 3)
                / 2**i
            )

            dx = torch.linspace(-r, r, 2 * r + 1, device=device)
            dy = torch.linspace(-r, r, 2 * r + 1, device=device)
            xgrid, ygrid = torch.meshgrid(dy, dx, indexing="ij")
            zgrid = torch.zeros_like(xgrid, device=device)
            delta = torch.stack([zgrid, xgrid, ygrid], axis=-1)
            delta_lvl = delta.view(1, 1, 2 * r + 1, 2 * r + 1, 3)
            coords_lvl = centroid_lvl + delta_lvl
            pyramid_sample = bilinear_sampler(
                pyramid.reshape(B * S, C, 1, H, W), coords_lvl
            )

            corr = torch.sum(target * pyramid_sample.reshape(B, S, C, N, -1), dim=2)
            corr = corr / torch.sqrt(torch.tensor(C).float())
            out_pyramid.append(corr)

        out = torch.cat(out_pyramid, dim=-1)  # B, S, N, LRR*2
        out = out.permute(0, 2, 1, 3).contiguous().view(B * N, S, -1).float()
        return out
class CorrBlock:
    def __init__(
        self,
        fmaps,
        num_levels=4,
        radius=4,
        multiple_track_feats=False,
        padding_mode="zeros",
    ):
        B, S, C, H, W = fmaps.shape
        self.S, self.C, self.H, self.W = S, C, H, W
        self.padding_mode = padding_mode
        self.num_levels = num_levels
        self.radius = radius
        self.fmaps_pyramid = []
        self.multiple_track_feats = multiple_track_feats

        self.fmaps_pyramid.append(fmaps)
        for i in range(self.num_levels - 1):
            fmaps_ = fmaps.reshape(B * S, C, H, W)
            fmaps_ = F.avg_pool2d(fmaps_, 2, stride=2)
            _, _, H, W = fmaps_.shape
            fmaps = fmaps_.reshape(B, S, C, H, W)
            self.fmaps_pyramid.append(fmaps)

    def sample(self, coords):
        r = self.radius
        B, S, N, D = coords.shape
        assert D == 2

        H, W = self.H, self.W
        out_pyramid = []
        for i in range(self.num_levels):
            corrs = self.corrs_pyramid[i]  # B, S, N, H, W
            *_, H, W = corrs.shape

            dx = torch.linspace(-r, r, 2 * r + 1)
            dy = torch.linspace(-r, r, 2 * r + 1)
            delta = torch.stack(torch.meshgrid(dy, dx, indexing="ij"), axis=-1).to(
                coords.device
            )

            centroid_lvl = coords.reshape(B * S * N, 1, 1, 2) / 2**i
            delta_lvl = delta.view(1, 2 * r + 1, 2 * r + 1, 2)
            coords_lvl = centroid_lvl + delta_lvl

            corrs = bilinear_sampler(
                corrs.reshape(B * S * N, 1, H, W),
                coords_lvl,
                padding_mode=self.padding_mode,
            )
            corrs = corrs.view(B, S, N, -1)
            out_pyramid.append(corrs)

        out = torch.cat(out_pyramid, dim=-1)  # B, S, N, LRR*2
        out = out.permute(0, 2, 1, 3).contiguous().view(B * N, S, -1).float()
        return out

    def corr(self, targets):
        B, S, N, C = targets.shape
        if self.multiple_track_feats:
            targets_split = targets.split(C // self.num_levels, dim=-1)
            B, S, N, C = targets_split[0].shape

        assert C == self.C
        assert S == self.S

        fmap1 = targets

        self.corrs_pyramid = []
        for i, fmaps in enumerate(self.fmaps_pyramid):
            *_, H, W = fmaps.shape
            fmap2s = fmaps.view(B, S, C, H * W)  # B S C H W -> B S C (H W)
            if self.multiple_track_feats:
                fmap1 = targets_split[i]
            corrs = torch.matmul(fmap1, fmap2s)
            corrs = corrs.view(B, S, N, H, W)  # B S N (H W) -> B S N H W
            corrs = corrs / torch.sqrt(torch.tensor(C).float())
            self.corrs_pyramid.append(corrs)
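
CorrBlock is used in two phases: corr(targets) precomputes per-level correlation volumes between track features and the feature pyramid, after which sample(coords) reads a (2*radius+1)^2 patch around each track at each level. A hedged sketch with assumed sizes (not part of the commit):

    fmaps = torch.randn(1, 8, 128, 96, 128)  # B S C H W feature maps
    block = CorrBlock(fmaps, num_levels=4, radius=3)
    track_feat = torch.randn(1, 8, 16, 128)  # features for N=16 tracks
    block.corr(track_feat)                   # fills block.corrs_pyramid
    coords = torch.rand(1, 8, 16, 2) * 64    # B S N 2 track positions
    fcorrs = block.sample(coords)            # -> (B*N, S, 4*49) = (16, 8, 196)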
class Attention(nn.Module):
    def __init__(
        self, query_dim, context_dim=None, num_heads=8, dim_head=48, qkv_bias=False
    ):
        super().__init__()
        inner_dim = dim_head * num_heads
        context_dim = default(context_dim, query_dim)
        self.scale = dim_head**-0.5
        self.heads = num_heads

        self.to_q = nn.Linear(query_dim, inner_dim, bias=qkv_bias)
        self.to_kv = nn.Linear(context_dim, inner_dim * 2, bias=qkv_bias)
        self.to_out = nn.Linear(inner_dim, query_dim)

    def forward(self, x, context=None, attn_bias=None):
        B, N1, C = x.shape
        h = self.heads

        q = self.to_q(x).reshape(B, N1, h, C // h).permute(0, 2, 1, 3)
        context = default(context, x)
        k, v = self.to_kv(context).chunk(2, dim=-1)

        N2 = context.shape[1]
        k = k.reshape(B, N2, h, C // h).permute(0, 2, 1, 3)
        v = v.reshape(B, N2, h, C // h).permute(0, 2, 1, 3)

        sim = (q @ k.transpose(-2, -1)) * self.scale

        if attn_bias is not None:
            sim = sim + attn_bias
        attn = sim.softmax(dim=-1)

        x = (attn @ v).transpose(1, 2).reshape(B, N1, C)
        return self.to_out(x)
class AttnBlock(nn.Module):
    def __init__(
        self,
        hidden_size,
        num_heads,
        attn_class: Callable[..., nn.Module] = Attention,
        mlp_ratio=4.0,
        **block_kwargs
    ):
        super().__init__()
        self.norm1 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
        self.attn = attn_class(
            hidden_size, num_heads=num_heads, qkv_bias=True, **block_kwargs
        )

        self.norm2 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
        mlp_hidden_dim = int(hidden_size * mlp_ratio)
        approx_gelu = lambda: nn.GELU(approximate="tanh")
        self.mlp = Mlp(
            in_features=hidden_size,
            hidden_features=mlp_hidden_dim,
            act_layer=approx_gelu,
            drop=0,
        )

    def forward(self, x, mask=None):
        attn_bias = mask
        if mask is not None:
            mask = (
                (mask[:, None] * mask[:, :, None])
                .unsqueeze(1)
                .expand(-1, self.attn.heads, -1, -1)  # Attention stores the head count as .heads
            )
            max_neg_value = -torch.finfo(x.dtype).max
            attn_bias = (~mask) * max_neg_value
        x = x + self.attn(self.norm1(x), attn_bias=attn_bias)
        x = x + self.mlp(self.norm2(x))
        return x
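
A usage sketch for the pre-norm attention block (illustrative, not part of the commit); without a mask it is a plain transformer layer, and with a boolean (B, N) mask it converts invalid token pairs into a large negative attention bias:

    block = AttnBlock(hidden_size=384, num_heads=8)
    x = torch.randn(4, 64, 384)                 # (batch, tokens, hidden)
    mask = torch.ones(4, 64, dtype=torch.bool)  # all tokens valid
    y = block(x, mask=mask)                     # -> (4, 64, 384)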
facebookresearch/co-tracker/cotracker/models/core/cotracker/cotracker.py  0 → 100644
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.

# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

import torch
import torch.nn as nn
import torch.nn.functional as F

from cotracker.models.core.model_utils import sample_features4d, sample_features5d
from cotracker.models.core.embeddings import (
    get_2d_embedding,
    get_1d_sincos_pos_embed_from_grid,
    get_2d_sincos_pos_embed,
)

from cotracker.models.core.cotracker.blocks import (
    Mlp,
    BasicEncoder,
    AttnBlock,
    CorrBlock,
    Attention,
)

torch.manual_seed(0)
class CoTracker2(nn.Module):
    def __init__(
        self,
        window_len=8,
        stride=4,
        add_space_attn=True,
        num_virtual_tracks=64,
        model_resolution=(384, 512),
    ):
        super(CoTracker2, self).__init__()
        self.window_len = window_len
        self.stride = stride
        self.hidden_dim = 256
        self.latent_dim = 128
        self.add_space_attn = add_space_attn
        self.fnet = BasicEncoder(output_dim=self.latent_dim)
        self.num_virtual_tracks = num_virtual_tracks
        self.model_resolution = model_resolution
        self.input_dim = 456
        self.updateformer = EfficientUpdateFormer(
            space_depth=6,
            time_depth=6,
            input_dim=self.input_dim,
            hidden_size=384,
            output_dim=self.latent_dim + 2,
            mlp_ratio=4.0,
            add_space_attn=add_space_attn,
            num_virtual_tracks=num_virtual_tracks,
        )

        time_grid = torch.linspace(0, window_len - 1, window_len).reshape(
            1, window_len, 1
        )

        self.register_buffer(
            "time_emb", get_1d_sincos_pos_embed_from_grid(self.input_dim, time_grid[0])
        )

        self.register_buffer(
            "pos_emb",
            get_2d_sincos_pos_embed(
                embed_dim=self.input_dim,
                grid_size=(
                    model_resolution[0] // stride,
                    model_resolution[1] // stride,
                ),
            ),
        )
        self.norm = nn.GroupNorm(1, self.latent_dim)
        self.track_feat_updater = nn.Sequential(
            nn.Linear(self.latent_dim, self.latent_dim),
            nn.GELU(),
        )
        self.vis_predictor = nn.Sequential(
            nn.Linear(self.latent_dim, 1),
        )
    def forward_window(
        self,
        fmaps,
        coords,
        track_feat=None,
        vis=None,
        track_mask=None,
        attention_mask=None,
        iters=4,
    ):
        # B = batch size
        # S = number of frames in the window
        # N = number of tracks
        # C = channels of a point feature vector
        # E = positional embedding size
        # LRR = local receptive field radius
        # D = dimension of the transformer input tokens

        # track_feat = B S N C
        # vis = B S N 1
        # track_mask = B S N 1
        # attention_mask = B S N

        B, S_init, N, __ = track_mask.shape
        B, S, *_ = fmaps.shape

        track_mask = F.pad(track_mask, (0, 0, 0, 0, 0, S - S_init), "constant")
        track_mask_vis = (
            torch.cat([track_mask, vis], dim=-1)
            .permute(0, 2, 1, 3)
            .reshape(B * N, S, 2)
        )

        corr_block = CorrBlock(
            fmaps,
            num_levels=4,
            radius=3,
            padding_mode="border",
        )

        sampled_pos_emb = (
            sample_features4d(self.pos_emb.repeat(B, 1, 1, 1), coords[:, 0])
            .reshape(B * N, self.input_dim)
            .unsqueeze(1)
        )  # B E N -> (B N) 1 E

        coord_preds = []
        for __ in range(iters):
            coords = coords.detach()  # B S N 2
            corr_block.corr(track_feat)

            # Sample correlation features around each point
            fcorrs = corr_block.sample(coords)  # (B N) S LRR

            # Get the flow embeddings
            flows = (coords - coords[:, 0:1]).permute(0, 2, 1, 3).reshape(B * N, S, 2)
            flow_emb = get_2d_embedding(flows, 64, cat_coords=True)  # N S E

            track_feat_ = track_feat.permute(0, 2, 1, 3).reshape(
                B * N, S, self.latent_dim
            )

            transformer_input = torch.cat(
                [flow_emb, fcorrs, track_feat_, track_mask_vis], dim=2
            )
            x = transformer_input + sampled_pos_emb + self.time_emb
            x = x.view(B, N, S, -1)  # (B N) S D -> B N S D

            delta = self.updateformer(
                x,
                attention_mask.reshape(B * S, N),  # B S N -> (B S) N
            )

            delta_coords = delta[..., :2].permute(0, 2, 1, 3)
            coords = coords + delta_coords
            coord_preds.append(coords * self.stride)

            delta_feats_ = delta[..., 2:].reshape(B * N * S, self.latent_dim)
            track_feat_ = track_feat.permute(0, 2, 1, 3).reshape(
                B * N * S, self.latent_dim
            )
            track_feat_ = self.track_feat_updater(self.norm(delta_feats_)) + track_feat_
            track_feat = track_feat_.reshape(B, N, S, self.latent_dim).permute(
                0, 2, 1, 3
            )  # (B N S) C -> B S N C

        vis_pred = self.vis_predictor(track_feat).reshape(B, S, N)
        return coord_preds, vis_pred
    def get_track_feat(self, fmaps, queried_frames, queried_coords):
        sample_frames = queried_frames[:, None, :, None]
        sample_coords = torch.cat(
            [
                sample_frames,
                queried_coords[:, None],
            ],
            dim=-1,
        )
        sample_track_feats = sample_features5d(fmaps, sample_coords)
        return sample_track_feats

    def init_video_online_processing(self):
        self.online_ind = 0
        self.online_track_feat = None
        self.online_coords_predicted = None
        self.online_vis_predicted = None
    def forward(self, video, queries, iters=4, is_train=False, is_online=False):
        """Predict tracks

        Args:
            video (FloatTensor[B, T, 3, H, W]): input videos.
            queries (FloatTensor[B, N, 3]): point queries.
            iters (int, optional): number of updates. Defaults to 4.
            is_train (bool, optional): enables training mode. Defaults to False.
            is_online (bool, optional): enables online mode. Defaults to False. Before enabling, call model.init_video_online_processing().

        Returns:
            - coords_predicted (FloatTensor[B, T, N, 2]):
            - vis_predicted (FloatTensor[B, T, N]):
            - train_data: `None` if `is_train` is false, otherwise:
                - all_vis_predictions (List[FloatTensor[B, S, N, 1]]):
                - all_coords_predictions (List[FloatTensor[B, S, N, 2]]):
                - mask (BoolTensor[B, T, N]):
        """
        B, T, C, H, W = video.shape
        B, N, __ = queries.shape
        S = self.window_len
        device = queries.device

        # B = batch size
        # S = number of frames in the window of the padded video
        # S_trimmed = actual number of frames in the window
        # N = number of tracks
        # C = color channels (3 for RGB)
        # E = positional embedding size
        # LRR = local receptive field radius
        # D = dimension of the transformer input tokens

        # video = B T C H W
        # queries = B N 3
        # coords_init = B S N 2
        # vis_init = B S N 1

        assert S >= 2  # A tracker needs at least two frames to track something
        if is_online:
            assert T <= S, "Online mode: video chunk must be <= window size."
            assert (
                self.online_ind is not None
            ), "Call model.init_video_online_processing() first."
            assert not is_train, "Training not supported in online mode."
        step = S // 2  # How much the sliding window moves at every step
        video = 2 * (video / 255.0) - 1.0

        # The first channel is the frame number
        # The rest are the coordinates of points we want to track
        queried_frames = queries[:, :, 0].long()

        queried_coords = queries[..., 1:]
        queried_coords = queried_coords / self.stride

        # We store our predictions here
        coords_predicted = torch.zeros((B, T, N, 2), device=device)
        vis_predicted = torch.zeros((B, T, N), device=device)
        if is_online:
            if self.online_coords_predicted is None:
                # Init online predictions with zeros
                self.online_coords_predicted = coords_predicted
                self.online_vis_predicted = vis_predicted
            else:
                # Pad online predictions with zeros for the current window
                pad = min(step, T - step)
                coords_predicted = F.pad(
                    self.online_coords_predicted, (0, 0, 0, 0, 0, pad), "constant"
                )
                vis_predicted = F.pad(
                    self.online_vis_predicted, (0, 0, 0, pad), "constant"
                )
        all_coords_predictions, all_vis_predictions = [], []

        # Pad the video so that an integer number of sliding windows fit into it
        # TODO: we may drop this requirement because the transformer should not care
        # TODO: pad the features instead of the video
        pad = (
            S - T if is_online else (S - T % S) % S
        )  # We don't want to pad if T % S == 0
        video = video.reshape(B, 1, T, C * H * W)
        video_pad = video[:, :, -1:].repeat(1, 1, pad, 1)
        video = torch.cat([video, video_pad], dim=2)

        # Compute convolutional features for the video or for the current chunk in case of online mode
        fmaps = self.fnet(video.reshape(-1, C, H, W)).reshape(
            B, -1, self.latent_dim, H // self.stride, W // self.stride
        )

        # We compute track features
        track_feat = self.get_track_feat(
            fmaps,
            queried_frames - self.online_ind if is_online else queried_frames,
            queried_coords,
        ).repeat(1, S, 1, 1)
        if is_online:
            # We update track features for the current window
            sample_frames = queried_frames[:, None, :, None]  # B 1 N 1
            left = 0 if self.online_ind == 0 else self.online_ind + step
            right = self.online_ind + S
            sample_mask = (sample_frames >= left) & (sample_frames < right)
            if self.online_track_feat is None:
                self.online_track_feat = torch.zeros_like(track_feat, device=device)
            self.online_track_feat += track_feat * sample_mask
            track_feat = self.online_track_feat.clone()
        # We process ((num_windows - 1) * step + S) frames in total, so there are
        # (ceil((T - S) / step) + 1) windows
        num_windows = (T - S + step - 1) // step + 1
        # We process only the current video chunk in the online mode
        indices = [self.online_ind] if is_online else range(0, step * num_windows, step)

        coords_init = queried_coords.reshape(B, 1, N, 2).expand(B, S, N, 2).float()
        vis_init = torch.ones((B, S, N, 1), device=device).float() * 10

        for ind in indices:
            # We copy over coords and vis for tracks that are queried
            # by the end of the previous window, which is ind + overlap
            if ind > 0:
                overlap = S - step
                copy_over = (queried_frames < ind + overlap)[
                    :, None, :, None
                ]  # B 1 N 1
                coords_prev = torch.nn.functional.pad(
                    coords_predicted[:, ind : ind + overlap] / self.stride,
                    (0, 0, 0, 0, 0, step),
                    "replicate",
                )  # B S N 2
                vis_prev = torch.nn.functional.pad(
                    vis_predicted[:, ind : ind + overlap, :, None].clone(),
                    (0, 0, 0, 0, 0, step),
                    "replicate",
                )  # B S N 1
                coords_init = torch.where(
                    copy_over.expand_as(coords_init), coords_prev, coords_init
                )
                vis_init = torch.where(
                    copy_over.expand_as(vis_init), vis_prev, vis_init
                )

            # The attention mask is 1 for the spatio-temporal points within
            # a track which is updated in the current window
            attention_mask = (
                (queried_frames < ind + S).reshape(B, 1, N).repeat(1, S, 1)
            )  # B S N

            # The track mask is 1 for the spatio-temporal points that actually
            # need updating: only after being queried, and not if contained
            # in a previous window
            track_mask = (
                queried_frames[:, None, :, None]
                <= torch.arange(ind, ind + S, device=device)[None, :, None, None]
            ).contiguous()  # B S N 1

            if ind > 0:
                track_mask[:, :overlap, :, :] = False

            # Predict the coordinates and visibility for the current window
            coords, vis = self.forward_window(
                fmaps=fmaps if is_online else fmaps[:, ind : ind + S],
                coords=coords_init,
                track_feat=attention_mask.unsqueeze(-1) * track_feat,
                vis=vis_init,
                track_mask=track_mask,
                attention_mask=attention_mask,
                iters=iters,
            )

            S_trimmed = (
                T if is_online else min(T - ind, S)
            )  # accounts for last window duration
            coords_predicted[:, ind : ind + S] = coords[-1][:, :S_trimmed]
            vis_predicted[:, ind : ind + S] = vis[:, :S_trimmed]
            if is_train:
                all_coords_predictions.append(
                    [coord[:, :S_trimmed] for coord in coords]
                )
                all_vis_predictions.append(torch.sigmoid(vis[:, :S_trimmed]))

        if is_online:
            self.online_ind += step
            self.online_coords_predicted = coords_predicted
            self.online_vis_predicted = vis_predicted
        vis_predicted = torch.sigmoid(vis_predicted)

        if is_train:
            mask = (
                queried_frames[:, None]
                <= torch.arange(0, T, device=device)[None, :, None]
            )
            train_data = (all_coords_predictions, all_vis_predictions, mask)
        else:
            train_data = None

        return coords_predicted, vis_predicted, train_data
class EfficientUpdateFormer(nn.Module):
    """
    Transformer model that updates track estimates.
    """

    def __init__(
        self,
        space_depth=6,
        time_depth=6,
        input_dim=320,
        hidden_size=384,
        num_heads=8,
        output_dim=130,
        mlp_ratio=4.0,
        num_virtual_tracks=64,
        add_space_attn=True,
        linear_layer_for_vis_conf=False,
    ):
        super().__init__()
        self.out_channels = 2
        self.num_heads = num_heads
        self.hidden_size = hidden_size
        self.input_transform = torch.nn.Linear(input_dim, hidden_size, bias=True)
        if linear_layer_for_vis_conf:
            self.flow_head = torch.nn.Linear(hidden_size, output_dim - 2, bias=True)
            self.vis_conf_head = torch.nn.Linear(hidden_size, 2, bias=True)
        else:
            self.flow_head = torch.nn.Linear(hidden_size, output_dim, bias=True)
        self.num_virtual_tracks = num_virtual_tracks
        self.virual_tracks = nn.Parameter(
            torch.randn(1, num_virtual_tracks, 1, hidden_size)
        )
        self.add_space_attn = add_space_attn
        self.linear_layer_for_vis_conf = linear_layer_for_vis_conf
        self.time_blocks = nn.ModuleList(
            [
                AttnBlock(
                    hidden_size,
                    num_heads,
                    mlp_ratio=mlp_ratio,
                    attn_class=Attention,
                )
                for _ in range(time_depth)
            ]
        )

        if add_space_attn:
            self.space_virtual_blocks = nn.ModuleList(
                [
                    AttnBlock(
                        hidden_size,
                        num_heads,
                        mlp_ratio=mlp_ratio,
                        attn_class=Attention,
                    )
                    for _ in range(space_depth)
                ]
            )
            self.space_point2virtual_blocks = nn.ModuleList(
                [
                    CrossAttnBlock(
                        hidden_size, hidden_size, num_heads, mlp_ratio=mlp_ratio
                    )
                    for _ in range(space_depth)
                ]
            )
            self.space_virtual2point_blocks = nn.ModuleList(
                [
                    CrossAttnBlock(
                        hidden_size, hidden_size, num_heads, mlp_ratio=mlp_ratio
                    )
                    for _ in range(space_depth)
                ]
            )
            assert len(self.time_blocks) >= len(self.space_virtual2point_blocks)
        self.initialize_weights()

    def initialize_weights(self):
        def _basic_init(module):
            if isinstance(module, nn.Linear):
                torch.nn.init.xavier_uniform_(module.weight)
                if module.bias is not None:
                    nn.init.constant_(module.bias, 0)
            torch.nn.init.trunc_normal_(self.flow_head.weight, std=0.001)
            if self.linear_layer_for_vis_conf:
                torch.nn.init.trunc_normal_(self.vis_conf_head.weight, std=0.001)

        def _trunc_init(module):
            """ViT weight initialization, original timm impl (for reproducibility)"""
            if isinstance(module, nn.Linear):
                torch.nn.init.trunc_normal_(module.weight, std=0.02)
                if module.bias is not None:
                    nn.init.zeros_(module.bias)

        self.apply(_basic_init)

    def forward(self, input_tensor, mask=None, add_space_attn=True):
        tokens = self.input_transform(input_tensor)

        B, _, T, _ = tokens.shape
        virtual_tokens = self.virual_tracks.repeat(B, 1, T, 1)
        tokens = torch.cat([tokens, virtual_tokens], dim=1)

        _, N, _, _ = tokens.shape
        j = 0
        layers = []
        for i in range(len(self.time_blocks)):
            time_tokens = tokens.contiguous().view(B * N, T, -1)  # B N T C -> (B N) T C
            time_tokens = self.time_blocks[i](time_tokens)

            tokens = time_tokens.view(B, N, T, -1)  # (B N) T C -> B N T C
            if (
                add_space_attn
                and hasattr(self, "space_virtual_blocks")
                and (i % (len(self.time_blocks) // len(self.space_virtual_blocks)) == 0)
            ):
                space_tokens = (
                    tokens.permute(0, 2, 1, 3).contiguous().view(B * T, N, -1)
                )  # B N T C -> (B T) N C

                point_tokens = space_tokens[:, : N - self.num_virtual_tracks]
                virtual_tokens = space_tokens[:, N - self.num_virtual_tracks :]

                virtual_tokens = self.space_virtual2point_blocks[j](
                    virtual_tokens, point_tokens, mask=mask
                )
                virtual_tokens = self.space_virtual_blocks[j](virtual_tokens)
                point_tokens = self.space_point2virtual_blocks[j](
                    point_tokens, virtual_tokens, mask=mask
                )

                space_tokens = torch.cat([point_tokens, virtual_tokens], dim=1)
                tokens = space_tokens.view(B, T, N, -1).permute(
                    0, 2, 1, 3
                )  # (B T) N C -> B N T C
                j += 1
        tokens = tokens[:, : N - self.num_virtual_tracks]

        flow = self.flow_head(tokens)
        if self.linear_layer_for_vis_conf:
            vis_conf = self.vis_conf_head(tokens)
            flow = torch.cat([flow, vis_conf], dim=-1)
        return flow
class CrossAttnBlock(nn.Module):
    def __init__(
        self, hidden_size, context_dim, num_heads=1, mlp_ratio=4.0, **block_kwargs
    ):
        super().__init__()
        self.norm1 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
        self.norm_context = nn.LayerNorm(hidden_size)
        self.cross_attn = Attention(
            hidden_size,
            context_dim=context_dim,
            num_heads=num_heads,
            qkv_bias=True,
            **block_kwargs
        )

        self.norm2 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
        mlp_hidden_dim = int(hidden_size * mlp_ratio)
        approx_gelu = lambda: nn.GELU(approximate="tanh")
        self.mlp = Mlp(
            in_features=hidden_size,
            hidden_features=mlp_hidden_dim,
            act_layer=approx_gelu,
            drop=0,
        )

    def forward(self, x, context, mask=None):
        attn_bias = None
        if mask is not None:
            if mask.shape[1] == x.shape[1]:
                mask = mask[:, None, :, None].expand(
                    -1, self.cross_attn.heads, -1, context.shape[1]
                )
            else:
                mask = mask[:, None, None].expand(
                    -1, self.cross_attn.heads, x.shape[1], -1
                )

            max_neg_value = -torch.finfo(x.dtype).max
            attn_bias = (~mask) * max_neg_value
        x = x + self.cross_attn(
            self.norm1(x), context=self.norm_context(context), attn_bias=attn_bias
        )
        x = x + self.mlp(self.norm2(x))
        return x
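
An end-to-end sketch of the offline CoTracker2 call (all sizes illustrative, assuming the full cotracker package is importable; not part of the commit). Queries are (frame, x, y) triples in pixel coordinates, and the video is expected raw in [0, 255]:

    model = CoTracker2(window_len=8, stride=4)
    video = torch.randint(0, 255, (1, 24, 3, 384, 512)).float()   # B T C H W
    queries = torch.tensor([[[0.0, 100.0, 150.0],    # track from frame 0 at (100, 150)
                             [4.0, 200.0, 120.0]]])  # track from frame 4 at (200, 120)
    coords, vis, _ = model(video, queries, iters=4)
    # coords: (1, 24, 2, 2) pixel tracks; vis: (1, 24, 2) visibility in [0, 1]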
facebookresearch/co-tracker/cotracker/models/core/cotracker/cotracker3_offline.py  0 → 100644
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.

# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

import torch
import torch.nn as nn
import torch.nn.functional as F

from cotracker.models.core.cotracker.cotracker3_online import CoTrackerThreeBase, posenc

torch.manual_seed(0)


class CoTrackerThreeOffline(CoTrackerThreeBase):
    def __init__(self, **args):
        super(CoTrackerThreeOffline, self).__init__(**args)

    def forward(
        self,
        video,
        queries,
        iters=4,
        is_train=False,
        add_space_attn=True,
        fmaps_chunk_size=200,
    ):
        """Predict tracks

        Args:
            video (FloatTensor[B, T, 3, H, W]): input videos.
            queries (FloatTensor[B, N, 3]): point queries.
            iters (int, optional): number of updates. Defaults to 4.
            is_train (bool, optional): enables training mode. Defaults to False.
        Returns:
            - coords_predicted (FloatTensor[B, T, N, 2]):
            - vis_predicted (FloatTensor[B, T, N]):
            - train_data: `None` if `is_train` is false, otherwise:
                - all_vis_predictions (List[FloatTensor[B, S, N, 1]]):
                - all_coords_predictions (List[FloatTensor[B, S, N, 2]]):
                - mask (BoolTensor[B, T, N]):
        """
        B, T, C, H, W = video.shape
        device = queries.device
        assert H % self.stride == 0 and W % self.stride == 0

        B, N, __ = queries.shape
        # B = batch size
        # S_trimmed = actual number of frames in the window
        # N = number of tracks
        # C = color channels (3 for RGB)
        # E = positional embedding size
        # LRR = local receptive field radius
        # D = dimension of the transformer input tokens

        # video = B T C H W
        # queries = B N 3
        # coords_init = B T N 2
        # vis_init = B T N 1

        assert T >= 1  # A tracker needs at least two frames to track something

        video = 2 * (video / 255.0) - 1.0
        dtype = video.dtype

        queried_frames = queries[:, :, 0].long()

        queried_coords = queries[..., 1:3]
        queried_coords = queried_coords / self.stride

        # We store our predictions here
        all_coords_predictions, all_vis_predictions, all_confidence_predictions = (
            [],
            [],
            [],
        )
        C_ = C
        H4, W4 = H // self.stride, W // self.stride
        # Compute convolutional features for the video or for the current chunk in case of online mode
        if T > fmaps_chunk_size:
            fmaps = []
            for t in range(0, T, fmaps_chunk_size):
                video_chunk = video[:, t : t + fmaps_chunk_size]
                fmaps_chunk = self.fnet(video_chunk.reshape(-1, C_, H, W))
                T_chunk = video_chunk.shape[1]
                C_chunk, H_chunk, W_chunk = fmaps_chunk.shape[1:]
                fmaps.append(fmaps_chunk.reshape(B, T_chunk, C_chunk, H_chunk, W_chunk))
            fmaps = torch.cat(fmaps, dim=1).reshape(-1, C_chunk, H_chunk, W_chunk)
        else:
            fmaps = self.fnet(video.reshape(-1, C_, H, W))
        fmaps = fmaps.permute(0, 2, 3, 1)
        fmaps = fmaps / torch.sqrt(
            torch.maximum(
                torch.sum(torch.square(fmaps), axis=-1, keepdims=True),
                torch.tensor(1e-12, device=fmaps.device),
            )
        )
        fmaps = fmaps.permute(0, 3, 1, 2).reshape(
            B, -1, self.latent_dim, H // self.stride, W // self.stride
        )
        fmaps = fmaps.to(dtype)

        # We compute track features
        fmaps_pyramid = []
        track_feat_pyramid = []
        track_feat_support_pyramid = []
        fmaps_pyramid.append(fmaps)
        for i in range(self.corr_levels - 1):
            fmaps_ = fmaps.reshape(
                B * T, self.latent_dim, fmaps.shape[-2], fmaps.shape[-1]
            )
            fmaps_ = F.avg_pool2d(fmaps_, 2, stride=2)
            fmaps = fmaps_.reshape(
                B, T, self.latent_dim, fmaps_.shape[-2], fmaps_.shape[-1]
            )
            fmaps_pyramid.append(fmaps)

        for i in range(self.corr_levels):
            track_feat, track_feat_support = self.get_track_feat(
                fmaps_pyramid[i],
                queried_frames,
                queried_coords / 2**i,
                support_radius=self.corr_radius,
            )
            track_feat_pyramid.append(track_feat.repeat(1, T, 1, 1))
            track_feat_support_pyramid.append(track_feat_support.unsqueeze(1))

        D_coords = 2

        coord_preds, vis_preds, confidence_preds = [], [], []

        vis = torch.zeros((B, T, N), device=device).float()
        confidence = torch.zeros((B, T, N), device=device).float()
        coords = queried_coords.reshape(B, 1, N, 2).expand(B, T, N, 2).float()

        r = 2 * self.corr_radius + 1

        for it in range(iters):
            coords = coords.detach()  # B T N 2
            coords_init = coords.view(B * T, N, 2)
            corr_embs = []
            corr_feats = []
            for i in range(self.corr_levels):
                corr_feat = self.get_correlation_feat(
                    fmaps_pyramid[i], coords_init / 2**i
                )
                track_feat_support = (
                    track_feat_support_pyramid[i]
                    .view(B, 1, r, r, N, self.latent_dim)
                    .squeeze(1)
                    .permute(0, 3, 1, 2, 4)
                )
                corr_volume = torch.einsum(
                    "btnhwc,bnijc->btnhwij", corr_feat, track_feat_support
                )
                corr_emb = self.corr_mlp(corr_volume.reshape(B * T * N, r * r * r * r))
                corr_embs.append(corr_emb)

            corr_embs = torch.cat(corr_embs, dim=-1)
            corr_embs = corr_embs.view(B, T, N, corr_embs.shape[-1])

            transformer_input = [vis[..., None], confidence[..., None], corr_embs]

            rel_coords_forward = coords[:, :-1] - coords[:, 1:]
            rel_coords_backward = coords[:, 1:] - coords[:, :-1]

            rel_coords_forward = torch.nn.functional.pad(
                rel_coords_forward, (0, 0, 0, 0, 0, 1)
            )
            rel_coords_backward = torch.nn.functional.pad(
                rel_coords_backward, (0, 0, 0, 0, 1, 0)
            )

            scale = (
                torch.tensor(
                    [self.model_resolution[1], self.model_resolution[0]],
                    device=coords.device,
                )
                / self.stride
            )
            rel_coords_forward = rel_coords_forward / scale
            rel_coords_backward = rel_coords_backward / scale

            rel_pos_emb_input = posenc(
                torch.cat([rel_coords_forward, rel_coords_backward], dim=-1),
                min_deg=0,
                max_deg=10,
            )  # batch, num_points, num_frames, 84
            transformer_input.append(rel_pos_emb_input)

            x = (
                torch.cat(transformer_input, dim=-1)
                .permute(0, 2, 1, 3)
                .reshape(B * N, T, -1)
            )
            x = x + self.interpolate_time_embed(x, T)
            x = x.view(B, N, T, -1)  # (B N) T D -> B N T D

            delta = self.updateformer(
                x,
                add_space_attn=add_space_attn,
            )

            delta_coords = delta[..., :D_coords].permute(0, 2, 1, 3)
            delta_vis = delta[..., D_coords].permute(0, 2, 1)
            delta_confidence = delta[..., D_coords + 1].permute(0, 2, 1)

            vis = vis + delta_vis
            confidence = confidence + delta_confidence

            coords = coords + delta_coords
            coords_append = coords.clone()
            coords_append[..., :2] = coords_append[..., :2] * float(self.stride)
            coord_preds.append(coords_append)
            vis_preds.append(torch.sigmoid(vis))
            confidence_preds.append(torch.sigmoid(confidence))

        if is_train:
            all_coords_predictions.append([coord[..., :2] for coord in coord_preds])
            all_vis_predictions.append(vis_preds)
            all_confidence_predictions.append(confidence_preds)

        if is_train:
            train_data = (
                all_coords_predictions,
                all_vis_predictions,
                all_confidence_predictions,
                torch.ones_like(vis_preds[-1], device=vis_preds[-1].device),
            )
        else:
            train_data = None

        return coord_preds[-1][..., :2], vis_preds[-1], confidence_preds[-1], train_data
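
A hedged usage sketch for the offline CoTracker3 variant (illustrative sizes, assuming the surrounding package is importable; not part of the commit). Unlike CoTracker2 it also returns a per-point confidence, and the time embedding is interpolated so T need not match window_len:

    model = CoTrackerThreeOffline(window_len=16, stride=4, model_resolution=(384, 512))
    video = torch.randint(0, 255, (1, 30, 3, 384, 512)).float()  # B T C H W, raw [0, 255]
    queries = torch.tensor([[[0.0, 128.0, 96.0]]])               # (frame, x, y)
    coords, vis, conf, _ = model(video, queries, iters=4)
    # coords: (1, 30, 1, 2); vis, conf: (1, 30, 1), already passed through sigmoid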
facebookresearch/co-tracker/cotracker/models/core/cotracker/cotracker3_online.py  0 → 100644
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.

# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

import torch
import torch.nn as nn
import torch.nn.functional as F

from cotracker.models.core.model_utils import sample_features5d, bilinear_sampler
from cotracker.models.core.embeddings import get_1d_sincos_pos_embed_from_grid

from cotracker.models.core.cotracker.blocks import Mlp, BasicEncoder
from cotracker.models.core.cotracker.cotracker import EfficientUpdateFormer

torch.manual_seed(0)
def posenc(x, min_deg, max_deg):
    """Cat x with a positional encoding of x with scales 2^[min_deg, max_deg-1].

    Instead of computing [sin(x), cos(x)], we use the trig identity
    cos(x) = sin(x + pi/2) and do one vectorized call to sin([x, x+pi/2]).

    Args:
        x: torch.Tensor, variables to be encoded. Note that x should be in [-pi, pi].
        min_deg: int, the minimum (inclusive) degree of the encoding.
        max_deg: int, the maximum (exclusive) degree of the encoding.

    Returns:
        encoded: torch.Tensor, encoded variables.
    """
    if min_deg == max_deg:
        return x
    scales = torch.tensor(
        [2**i for i in range(min_deg, max_deg)], dtype=x.dtype, device=x.device
    )

    xb = (x[..., None, :] * scales[:, None]).reshape(list(x.shape[:-1]) + [-1])
    four_feat = torch.sin(torch.cat([xb, xb + 0.5 * torch.pi], dim=-1))
    return torch.cat([x] + [four_feat], dim=-1)
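
A worked size check for posenc (not part of the commit): with min_deg=0, max_deg=10, each input channel contributes 10 sines and 10 shifted sines (the cosines) plus the identity, so a 4-channel input becomes 4 + 4*2*10 = 84 channels, matching the "batch, num_points, num_frames, 84" comment at the call sites below:

    x = torch.randn(2, 16, 4)             # e.g. forward+backward relative coords
    enc = posenc(x, min_deg=0, max_deg=10)
    print(enc.shape)                      # torch.Size([2, 16, 84])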
class CoTrackerThreeBase(nn.Module):
    def __init__(
        self,
        window_len=8,
        stride=4,
        corr_radius=3,
        corr_levels=4,
        num_virtual_tracks=64,
        model_resolution=(384, 512),
        add_space_attn=True,
        linear_layer_for_vis_conf=True,
    ):
        super(CoTrackerThreeBase, self).__init__()
        self.window_len = window_len
        self.stride = stride
        self.corr_radius = corr_radius
        self.corr_levels = corr_levels
        self.hidden_dim = 256
        self.latent_dim = 128

        self.linear_layer_for_vis_conf = linear_layer_for_vis_conf
        self.fnet = BasicEncoder(input_dim=3, output_dim=self.latent_dim, stride=stride)

        highres_dim = 128
        lowres_dim = 256

        self.num_virtual_tracks = num_virtual_tracks
        self.model_resolution = model_resolution

        self.input_dim = 1110

        self.updateformer = EfficientUpdateFormer(
            space_depth=3,
            time_depth=3,
            input_dim=self.input_dim,
            hidden_size=384,
            output_dim=4,
            mlp_ratio=4.0,
            num_virtual_tracks=num_virtual_tracks,
            add_space_attn=add_space_attn,
            linear_layer_for_vis_conf=linear_layer_for_vis_conf,
        )
        self.corr_mlp = Mlp(in_features=49 * 49, hidden_features=384, out_features=256)

        time_grid = torch.linspace(0, window_len - 1, window_len).reshape(
            1, window_len, 1
        )

        self.register_buffer(
            "time_emb", get_1d_sincos_pos_embed_from_grid(self.input_dim, time_grid[0])
        )
    def get_support_points(self, coords, r, reshape_back=True):
        B, _, N, _ = coords.shape
        device = coords.device
        centroid_lvl = coords.reshape(B, N, 1, 1, 3)

        dx = torch.linspace(-r, r, 2 * r + 1, device=device)
        dy = torch.linspace(-r, r, 2 * r + 1, device=device)

        xgrid, ygrid = torch.meshgrid(dy, dx, indexing="ij")
        zgrid = torch.zeros_like(xgrid, device=device)
        delta = torch.stack([zgrid, xgrid, ygrid], axis=-1)
        delta_lvl = delta.view(1, 1, 2 * r + 1, 2 * r + 1, 3)
        coords_lvl = centroid_lvl + delta_lvl

        if reshape_back:
            return coords_lvl.reshape(B, N, (2 * r + 1) ** 2, 3).permute(0, 2, 1, 3)
        else:
            return coords_lvl

    def get_track_feat(self, fmaps, queried_frames, queried_coords, support_radius=0):
        sample_frames = queried_frames[:, None, :, None]
        sample_coords = torch.cat(
            [
                sample_frames,
                queried_coords[:, None],
            ],
            dim=-1,
        )
        support_points = self.get_support_points(sample_coords, support_radius)
        support_track_feats = sample_features5d(fmaps, support_points)
        return (
            support_track_feats[:, None, support_track_feats.shape[1] // 2],
            support_track_feats,
        )
    def get_correlation_feat(self, fmaps, queried_coords):
        B, T, D, H_, W_ = fmaps.shape
        N = queried_coords.shape[1]
        r = self.corr_radius
        sample_coords = torch.cat(
            [torch.zeros_like(queried_coords[..., :1]), queried_coords], dim=-1
        )[:, None]
        support_points = self.get_support_points(sample_coords, r, reshape_back=False)
        correlation_feat = bilinear_sampler(
            fmaps.reshape(B * T, D, 1, H_, W_), support_points
        )
        return correlation_feat.view(
            B, T, D, N, (2 * r + 1), (2 * r + 1)
        ).permute(0, 1, 3, 4, 5, 2)

    def interpolate_time_embed(self, x, t):
        previous_dtype = x.dtype
        T = self.time_emb.shape[1]

        if t == T:
            return self.time_emb

        time_emb = self.time_emb.float()
        time_emb = F.interpolate(
            time_emb.permute(0, 2, 1), size=t, mode="linear"
        ).permute(0, 2, 1)
        return time_emb.to(previous_dtype)
class CoTrackerThreeOnline(CoTrackerThreeBase):
    def __init__(self, **args):
        super(CoTrackerThreeOnline, self).__init__(**args)

    def init_video_online_processing(self):
        self.online_ind = 0
        self.online_track_feat = [None] * self.corr_levels
        self.online_track_support = [None] * self.corr_levels
        self.online_coords_predicted = None
        self.online_vis_predicted = None
        self.online_conf_predicted = None
    def forward_window(
        self,
        fmaps_pyramid,
        coords,
        track_feat_support_pyramid,
        vis=None,
        conf=None,
        attention_mask=None,
        iters=4,
        add_space_attn=False,
    ):
        B, S, *_ = fmaps_pyramid[0].shape
        N = coords.shape[2]
        r = 2 * self.corr_radius + 1

        coord_preds, vis_preds, conf_preds = [], [], []
        for it in range(iters):
            coords = coords.detach()  # B T N 2
            coords_init = coords.view(B * S, N, 2)
            corr_embs = []
            corr_feats = []
            for i in range(self.corr_levels):
                corr_feat = self.get_correlation_feat(
                    fmaps_pyramid[i], coords_init / 2**i
                )
                track_feat_support = (
                    track_feat_support_pyramid[i]
                    .view(B, 1, r, r, N, self.latent_dim)
                    .squeeze(1)
                    .permute(0, 3, 1, 2, 4)
                )
                corr_volume = torch.einsum(
                    "btnhwc,bnijc->btnhwij", corr_feat, track_feat_support
                )
                corr_emb = self.corr_mlp(corr_volume.reshape(B * S * N, r * r * r * r))
                corr_embs.append(corr_emb)

            corr_embs = torch.cat(corr_embs, dim=-1)
            corr_embs = corr_embs.view(B, S, N, corr_embs.shape[-1])

            transformer_input = [vis, conf, corr_embs]

            rel_coords_forward = coords[:, :-1] - coords[:, 1:]
            rel_coords_backward = coords[:, 1:] - coords[:, :-1]

            rel_coords_forward = torch.nn.functional.pad(
                rel_coords_forward, (0, 0, 0, 0, 0, 1)
            )
            rel_coords_backward = torch.nn.functional.pad(
                rel_coords_backward, (0, 0, 0, 0, 1, 0)
            )

            scale = (
                torch.tensor(
                    [self.model_resolution[1], self.model_resolution[0]],
                    device=coords.device,
                )
                / self.stride
            )
            rel_coords_forward = rel_coords_forward / scale
            rel_coords_backward = rel_coords_backward / scale

            rel_pos_emb_input = posenc(
                torch.cat([rel_coords_forward, rel_coords_backward], dim=-1),
                min_deg=0,
                max_deg=10,
            )  # batch, num_points, num_frames, 84
            transformer_input.append(rel_pos_emb_input)

            x = (
                torch.cat(transformer_input, dim=-1)
                .permute(0, 2, 1, 3)
                .reshape(B * N, S, -1)
            )
            x = x + self.interpolate_time_embed(x, S)
            x = x.view(B, N, S, -1)  # (B N) T D -> B N T D

            delta = self.updateformer(x, add_space_attn=add_space_attn)
            delta_coords = delta[..., :2].permute(0, 2, 1, 3)
            delta_vis = delta[..., 2:3].permute(0, 2, 1, 3)
            delta_conf = delta[..., 3:].permute(0, 2, 1, 3)

            vis = vis + delta_vis
            conf = conf + delta_conf

            coords = coords + delta_coords
            coord_preds.append(coords[..., :2] * float(self.stride))
            vis_preds.append(vis[..., 0])
            conf_preds.append(conf[..., 0])
        return coord_preds, vis_preds, conf_preds
def
forward
(
self
,
video
,
queries
,
iters
=
4
,
is_train
=
False
,
add_space_attn
=
True
,
fmaps_chunk_size
=
200
,
is_online
=
False
,
):
"""Predict tracks
Args:
video (FloatTensor[B, T, 3]): input videos.
queries (FloatTensor[B, N, 3]): point queries.
iters (int, optional): number of updates. Defaults to 4.
is_train (bool, optional): enables training mode. Defaults to False.
Returns:
- coords_predicted (FloatTensor[B, T, N, 2]):
- vis_predicted (FloatTensor[B, T, N]):
- train_data: `None` if `is_train` is false, otherwise:
- all_vis_predictions (List[FloatTensor[B, S, N, 1]]):
- all_coords_predictions (List[FloatTensor[B, S, N, 2]]):
- mask (BoolTensor[B, T, N]):
"""
B
,
T
,
C
,
H
,
W
=
video
.
shape
device
=
queries
.
device
assert
H
%
self
.
stride
==
0
and
W
%
self
.
stride
==
0
B
,
N
,
__
=
queries
.
shape
# B = batch size
# S_trimmed = actual number of frames in the window
# N = number of tracks
# C = color channels (3 for RGB)
# E = positional embedding size
# LRR = local receptive field radius
# D = dimension of the transformer input tokens
# video = B T C H W
# queries = B N 3
# coords_init = B T N 2
# vis_init = B T N 1
S
=
self
.
window_len
assert
S
>=
2
# A tracker needs at least two frames to track something
if
is_online
:
assert
T
<=
S
,
"Online mode: video chunk must be <= window size."
assert
(
self
.
online_ind
is
not
None
),
"Call model.init_video_online_processing() first."
assert
not
is_train
,
"Training not supported in online mode."
step
=
S
//
2
# How much the sliding window moves at every step
video
=
2
*
(
video
/
255.0
)
-
1.0
pad
=
(
S
-
T
if
is_online
else
(
S
-
T
%
S
)
%
S
)
# We don't want to pad if T % S == 0
video
=
video
.
reshape
(
B
,
1
,
T
,
C
*
H
*
W
)
if
pad
>
0
:
padding_tensor
=
video
[:,
:,
-
1
:,
:].
expand
(
B
,
1
,
pad
,
C
*
H
*
W
)
video
=
torch
.
cat
([
video
,
padding_tensor
],
dim
=
2
)
video
=
video
.
reshape
(
B
,
-
1
,
C
,
H
,
W
)
T_pad
=
video
.
shape
[
1
]
# The first channel is the frame number
# The rest are the coordinates of points we want to track
dtype
=
video
.
dtype
queried_frames
=
queries
[:,
:,
0
].
long
()
queried_coords
=
queries
[...,
1
:
3
]
queried_coords
=
queried_coords
/
self
.
stride
# We store our predictions here
coords_predicted
=
torch
.
zeros
((
B
,
T
,
N
,
2
),
device
=
device
)
vis_predicted
=
torch
.
zeros
((
B
,
T
,
N
),
device
=
device
)
conf_predicted
=
torch
.
zeros
((
B
,
T
,
N
),
device
=
device
)
if
is_online
:
if
self
.
online_coords_predicted
is
None
:
# Init online predictions with zeros
self
.
online_coords_predicted
=
coords_predicted
self
.
online_vis_predicted
=
vis_predicted
self
.
online_conf_predicted
=
conf_predicted
else
:
# Pad online predictions with zeros for the current window
pad
=
min
(
step
,
T
-
step
)
coords_predicted
=
F
.
pad
(
self
.
online_coords_predicted
,
(
0
,
0
,
0
,
0
,
0
,
pad
),
"constant"
)
vis_predicted
=
F
.
pad
(
self
.
online_vis_predicted
,
(
0
,
0
,
0
,
pad
),
"constant"
)
conf_predicted
=
F
.
pad
(
self
.
online_conf_predicted
,
(
0
,
0
,
0
,
pad
),
"constant"
)
# We store our predictions here
all_coords_predictions
,
all_vis_predictions
,
all_confidence_predictions
=
(
[],
[],
[],
)
C_
=
C
H4
,
W4
=
H
//
self
.
stride
,
W
//
self
.
stride
# Compute convolutional features for the video or for the current chunk in case of online mode
if
(
not
is_train
)
and
(
T
>
fmaps_chunk_size
):
fmaps
=
[]
for
t
in
range
(
0
,
T
,
fmaps_chunk_size
):
video_chunk
=
video
[:,
t
:
t
+
fmaps_chunk_size
]
fmaps_chunk
=
self
.
fnet
(
video_chunk
.
reshape
(
-
1
,
C_
,
H
,
W
))
T_chunk
=
video_chunk
.
shape
[
1
]
C_chunk
,
H_chunk
,
W_chunk
=
fmaps_chunk
.
shape
[
1
:]
fmaps
.
append
(
fmaps_chunk
.
reshape
(
B
,
T_chunk
,
C_chunk
,
H_chunk
,
W_chunk
))
fmaps
=
torch
.
cat
(
fmaps
,
dim
=
1
).
reshape
(
-
1
,
C_chunk
,
H_chunk
,
W_chunk
)
else
:
fmaps
=
self
.
fnet
(
video
.
reshape
(
-
1
,
C_
,
H
,
W
))
fmaps
=
fmaps
.
permute
(
0
,
2
,
3
,
1
)
fmaps
=
fmaps
/
torch
.
sqrt
(
torch
.
maximum
(
torch
.
sum
(
torch
.
square
(
fmaps
),
axis
=-
1
,
keepdims
=
True
),
torch
.
tensor
(
1e-12
,
device
=
fmaps
.
device
),
)
)
fmaps
=
fmaps
.
permute
(
0
,
3
,
1
,
2
).
reshape
(
B
,
-
1
,
self
.
latent_dim
,
H
//
self
.
stride
,
W
//
self
.
stride
)
fmaps
=
fmaps
.
to
(
dtype
)
        # We compute track features
        fmaps_pyramid = []
        track_feat_pyramid = []
        track_feat_support_pyramid = []
        fmaps_pyramid.append(fmaps)
        for i in range(self.corr_levels - 1):
            fmaps_ = fmaps.reshape(
                B * T_pad, self.latent_dim, fmaps.shape[-2], fmaps.shape[-1]
            )
            fmaps_ = F.avg_pool2d(fmaps_, 2, stride=2)
            fmaps = fmaps_.reshape(
                B, T_pad, self.latent_dim, fmaps_.shape[-2], fmaps_.shape[-1]
            )
            fmaps_pyramid.append(fmaps)

        if is_online:
            sample_frames = queried_frames[:, None, :, None]  # B 1 N 1
            left = 0 if self.online_ind == 0 else self.online_ind + step
            right = self.online_ind + S
            sample_mask = (sample_frames >= left) & (sample_frames < right)

        for i in range(self.corr_levels):
            track_feat, track_feat_support = self.get_track_feat(
                fmaps_pyramid[i],
                (queried_frames - self.online_ind if is_online else queried_frames),
                queried_coords / 2**i,
                support_radius=self.corr_radius,
            )
            if is_online:
                if self.online_track_feat[i] is None:
                    self.online_track_feat[i] = torch.zeros_like(
                        track_feat, device=device
                    )
                    self.online_track_support[i] = torch.zeros_like(
                        track_feat_support, device=device
                    )
                self.online_track_feat[i] += track_feat * sample_mask
                self.online_track_support[i] += track_feat_support * sample_mask
                track_feat_pyramid.append(
                    self.online_track_feat[i].repeat(1, T_pad, 1, 1)
                )
                track_feat_support_pyramid.append(
                    self.online_track_support[i].unsqueeze(1)
                )
            else:
                track_feat_pyramid.append(track_feat.repeat(1, T_pad, 1, 1))
                track_feat_support_pyramid.append(track_feat_support.unsqueeze(1))

        D_coords = 2
        coord_preds, vis_preds, confidence_preds = [], [], []

        vis_init = torch.zeros((B, S, N, 1), device=device).float()
        conf_init = torch.zeros((B, S, N, 1), device=device).float()
        coords_init = queried_coords.reshape(B, 1, N, 2).expand(B, S, N, 2).float()

        num_windows = (T - S + step - 1) // step + 1
        # We process only the current video chunk in the online mode
        indices = (
            [self.online_ind] if is_online else range(0, step * num_windows, step)
        )

        for ind in indices:
            if ind > 0:
                overlap = S - step
                copy_over = (queried_frames < ind + overlap)[
                    :, None, :, None
                ]  # B 1 N 1
                coords_prev = coords_predicted[:, ind : ind + overlap] / self.stride
                padding_tensor = coords_prev[:, -1:, :, :].expand(-1, step, -1, -1)
                coords_prev = torch.cat([coords_prev, padding_tensor], dim=1)
                vis_prev = vis_predicted[:, ind : ind + overlap, :, None].clone()
                padding_tensor = vis_prev[:, -1:, :, :].expand(-1, step, -1, -1)
                vis_prev = torch.cat([vis_prev, padding_tensor], dim=1)
                conf_prev = conf_predicted[:, ind : ind + overlap, :, None].clone()
                padding_tensor = conf_prev[:, -1:, :, :].expand(-1, step, -1, -1)
                conf_prev = torch.cat([conf_prev, padding_tensor], dim=1)
                coords_init = torch.where(
                    copy_over.expand_as(coords_init), coords_prev, coords_init
                )
                vis_init = torch.where(
                    copy_over.expand_as(vis_init), vis_prev, vis_init
                )
                conf_init = torch.where(
                    copy_over.expand_as(conf_init), conf_prev, conf_init
                )

            attention_mask = (queried_frames < ind + S).reshape(B, 1, N)  # B S N
            # import ipdb; ipdb.set_trace()
            coords, viss, confs = self.forward_window(
                fmaps_pyramid=(
                    fmaps_pyramid
                    if is_online
                    else [fmap[:, ind : ind + S] for fmap in fmaps_pyramid]
                ),
                coords=coords_init,
                track_feat_support_pyramid=[
                    attention_mask[:, None, :, :, None] * tfeat
                    for tfeat in track_feat_support_pyramid
                ],
                vis=vis_init,
                conf=conf_init,
                attention_mask=attention_mask.repeat(1, S, 1),
                iters=iters,
                add_space_attn=add_space_attn,
            )

            S_trimmed = (
                T if is_online else min(T - ind, S)
            )  # accounts for last window duration
            coords_predicted[:, ind : ind + S] = coords[-1][:, :S_trimmed]
            vis_predicted[:, ind : ind + S] = viss[-1][:, :S_trimmed]
            conf_predicted[:, ind : ind + S] = confs[-1][:, :S_trimmed]
            if is_train:
                all_coords_predictions.append(
                    [coord[:, :S_trimmed] for coord in coords]
                )
                all_vis_predictions.append(
                    [torch.sigmoid(vis[:, :S_trimmed]) for vis in viss]
                )
                all_confidence_predictions.append(
                    [torch.sigmoid(conf[:, :S_trimmed]) for conf in confs]
                )

        if is_online:
            self.online_ind += step
            self.online_coords_predicted = coords_predicted
            self.online_vis_predicted = vis_predicted
            self.online_conf_predicted = conf_predicted

        vis_predicted = torch.sigmoid(vis_predicted)
        conf_predicted = torch.sigmoid(conf_predicted)

        if is_train:
            valid_mask = (
                queried_frames[:, None]
                <= torch.arange(0, T, device=device)[None, :, None]
            )
            train_data = (
                all_coords_predictions,
                all_vis_predictions,
                all_confidence_predictions,
                valid_mask,
            )
        else:
            train_data = None

        return coords_predicted, vis_predicted, conf_predicted, train_data
facebookresearch/co-tracker/cotracker/models/core/cotracker/losses.py
0 → 100644
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.

# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

import torch
import torch.nn.functional as F
from cotracker.models.core.model_utils import reduce_masked_mean
import torch.nn as nn
from typing import List


def sequence_loss(
    flow_preds,
    flow_gt,
    valids,
    vis=None,
    gamma=0.8,
    add_huber_loss=False,
    loss_only_for_visible=False,
):
    """Loss function defined over sequence of flow predictions"""
    total_flow_loss = 0.0
    for j in range(len(flow_gt)):
        B, S, N, D = flow_gt[j].shape
        B, S2, N = valids[j].shape
        assert S == S2
        n_predictions = len(flow_preds[j])
        flow_loss = 0.0
        for i in range(n_predictions):
            i_weight = gamma ** (n_predictions - i - 1)
            flow_pred = flow_preds[j][i]
            if add_huber_loss:
                i_loss = huber_loss(flow_pred, flow_gt[j], delta=6.0)
            else:
                i_loss = (flow_pred - flow_gt[j]).abs()  # B, S, N, 2
            i_loss = torch.mean(i_loss, dim=3)  # B, S, N
            valid_ = valids[j].clone()
            if loss_only_for_visible:
                valid_ = valid_ * vis[j]
            flow_loss += i_weight * reduce_masked_mean(i_loss, valid_)
        flow_loss = flow_loss / n_predictions
        total_flow_loss += flow_loss
    return total_flow_loss / len(flow_gt)


def huber_loss(x, y, delta=1.0):
    """Calculate element-wise Huber loss between x and y"""
    diff = x - y
    abs_diff = diff.abs()
    flag = (abs_diff <= delta).float()
    return flag * 0.5 * diff**2 + (1 - flag) * delta * (abs_diff - 0.5 * delta)


def sequence_BCE_loss(vis_preds, vis_gts):
    total_bce_loss = 0.0
    for j in range(len(vis_preds)):
        n_predictions = len(vis_preds[j])
        bce_loss = 0.0
        for i in range(n_predictions):
            vis_loss = F.binary_cross_entropy(vis_preds[j][i], vis_gts[j])
            bce_loss += vis_loss
        bce_loss = bce_loss / n_predictions
        total_bce_loss += bce_loss
    return total_bce_loss / len(vis_preds)


def sequence_prob_loss(
    tracks: torch.Tensor,
    confidence: torch.Tensor,
    target_points: torch.Tensor,
    visibility: torch.Tensor,
    expected_dist_thresh: float = 12.0,
):
    """Loss for classifying if a point is within pixel threshold of its target."""
    # Points with an error larger than 12 pixels are likely to be useless; marking
    # them as occluded will actually improve Jaccard metrics and give
    # qualitatively better results.
    total_logprob_loss = 0.0
    for j in range(len(tracks)):
        n_predictions = len(tracks[j])
        logprob_loss = 0.0
        for i in range(n_predictions):
            err = torch.sum((tracks[j][i].detach() - target_points[j]) ** 2, dim=-1)
            valid = (err <= expected_dist_thresh**2).float()
            logprob = F.binary_cross_entropy(
                confidence[j][i], valid, reduction="none"
            )
            logprob *= visibility[j]
            logprob = torch.mean(logprob, dim=[1, 2])
            logprob_loss += logprob
        logprob_loss = logprob_loss / n_predictions
        total_logprob_loss += logprob_loss
    return total_logprob_loss / len(tracks)


def masked_mean(data: torch.Tensor, mask: torch.Tensor | None, dim: List[int]):
    if mask is None:
        return data.mean(dim=dim, keepdim=True)
    mask = mask.float()
    mask_sum = torch.sum(mask, dim=dim, keepdim=True)
    mask_mean = torch.sum(data * mask, dim=dim, keepdim=True) / torch.clamp(
        mask_sum, min=1.0
    )
    return mask_mean


def masked_mean_var(data: torch.Tensor, mask: torch.Tensor, dim: List[int]):
    if mask is None:
        return data.mean(dim=dim, keepdim=True), data.var(dim=dim, keepdim=True)
    mask = mask.float()
    mask_sum = torch.sum(mask, dim=dim, keepdim=True)
    mask_mean = torch.sum(data * mask, dim=dim, keepdim=True) / torch.clamp(
        mask_sum, min=1.0
    )
    mask_var = torch.sum(
        mask * (data - mask_mean) ** 2, dim=dim, keepdim=True
    ) / torch.clamp(mask_sum, min=1.0)
    return mask_mean.squeeze(dim), mask_var.squeeze(dim)
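
The nesting that sequence_loss expects is easy to get wrong: one outer list per sequence group, one inner list per refinement iteration, with later iterations weighted more heavily by gamma. A minimal sketch with synthetic shapes (my own, not from the file):

import torch
from cotracker.models.core.cotracker.losses import sequence_loss

B, S, N = 2, 8, 16
flow_gt = [torch.randn(B, S, N, 2)]
valids = [torch.ones(B, S, N)]
# three refinement iterations, getting progressively closer to the target
flow_preds = [[flow_gt[0] + torch.randn(B, S, N, 2) * s for s in (0.3, 0.2, 0.1)]]
loss = sequence_loss(flow_preds, flow_gt, valids, gamma=0.8, add_huber_loss=True)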
facebookresearch/co-tracker/cotracker/models/core/embeddings.py
0 → 100644
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.

# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

from typing import Tuple, Union
import torch


def get_2d_sincos_pos_embed(
    embed_dim: int, grid_size: Union[int, Tuple[int, int]]
) -> torch.Tensor:
    """
    This function initializes a grid and generates a 2D positional embedding using sine and cosine functions.
    It is a wrapper of get_2d_sincos_pos_embed_from_grid.
    Args:
    - embed_dim: The embedding dimension.
    - grid_size: The grid size.
    Returns:
    - pos_embed: The generated 2D positional embedding.
    """
    if isinstance(grid_size, tuple):
        grid_size_h, grid_size_w = grid_size
    else:
        grid_size_h = grid_size_w = grid_size
    grid_h = torch.arange(grid_size_h, dtype=torch.float)
    grid_w = torch.arange(grid_size_w, dtype=torch.float)
    grid = torch.meshgrid(grid_w, grid_h, indexing="xy")
    grid = torch.stack(grid, dim=0)
    grid = grid.reshape([2, 1, grid_size_h, grid_size_w])
    pos_embed = get_2d_sincos_pos_embed_from_grid(embed_dim, grid)
    return pos_embed.reshape(1, grid_size_h, grid_size_w, -1).permute(0, 3, 1, 2)


def get_2d_sincos_pos_embed_from_grid(
    embed_dim: int, grid: torch.Tensor
) -> torch.Tensor:
    """
    This function generates a 2D positional embedding from a given grid using sine and cosine functions.
    Args:
    - embed_dim: The embedding dimension.
    - grid: The grid to generate the embedding from.
    Returns:
    - emb: The generated 2D positional embedding.
    """
    assert embed_dim % 2 == 0

    # use half of dimensions to encode grid_h
    emb_h = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[0])  # (H*W, D/2)
    emb_w = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[1])  # (H*W, D/2)

    emb = torch.cat([emb_h, emb_w], dim=2)  # (H*W, D)
    return emb


def get_1d_sincos_pos_embed_from_grid(
    embed_dim: int, pos: torch.Tensor
) -> torch.Tensor:
    """
    This function generates a 1D positional embedding from a given grid using sine and cosine functions.
    Args:
    - embed_dim: The embedding dimension.
    - pos: The position to generate the embedding from.
    Returns:
    - emb: The generated 1D positional embedding.
    """
    assert embed_dim % 2 == 0
    omega = torch.arange(embed_dim // 2, dtype=torch.double)
    omega /= embed_dim / 2.0
    omega = 1.0 / 10000**omega  # (D/2,)

    pos = pos.reshape(-1)  # (M,)
    out = torch.einsum("m,d->md", pos, omega)  # (M, D/2), outer product

    emb_sin = torch.sin(out)  # (M, D/2)
    emb_cos = torch.cos(out)  # (M, D/2)

    emb = torch.cat([emb_sin, emb_cos], dim=1)  # (M, D)
    return emb[None].float()


def get_2d_embedding(xy: torch.Tensor, C: int, cat_coords: bool = True) -> torch.Tensor:
    """
    This function generates a 2D positional embedding from given coordinates using sine and cosine functions.
    Args:
    - xy: The coordinates to generate the embedding from.
    - C: The size of the embedding.
    - cat_coords: A flag to indicate whether to concatenate the original coordinates to the embedding.
    Returns:
    - pe: The generated 2D positional embedding.
    """
    B, N, D = xy.shape
    assert D == 2

    x = xy[:, :, 0:1]
    y = xy[:, :, 1:2]
    div_term = (
        torch.arange(0, C, 2, device=xy.device, dtype=torch.float32) * (1000.0 / C)
    ).reshape(1, 1, int(C / 2))

    pe_x = torch.zeros(B, N, C, device=xy.device, dtype=torch.float32)
    pe_y = torch.zeros(B, N, C, device=xy.device, dtype=torch.float32)

    pe_x[:, :, 0::2] = torch.sin(x * div_term)
    pe_x[:, :, 1::2] = torch.cos(x * div_term)

    pe_y[:, :, 0::2] = torch.sin(y * div_term)
    pe_y[:, :, 1::2] = torch.cos(y * div_term)

    pe = torch.cat([pe_x, pe_y], dim=2)  # (B, N, C*2)
    if cat_coords:
        pe = torch.cat([xy, pe], dim=2)  # (B, N, C*2+2)
    return pe
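
A quick shape check for these helpers (my own sketch, not part of the file above):

import torch
from cotracker.models.core.embeddings import get_2d_sincos_pos_embed, get_2d_embedding

emb = get_2d_sincos_pos_embed(embed_dim=64, grid_size=(24, 32))
print(emb.shape)  # torch.Size([1, 64, 24, 32])

pe = get_2d_embedding(torch.rand(2, 100, 2) * 32, C=32, cat_coords=True)
print(pe.shape)  # torch.Size([2, 100, 66]) -> 2*C sin/cos channels + 2 raw coords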
facebookresearch/co-tracker/cotracker/models/core/model_utils.py
0 → 100644
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.

# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

import numpy as np
import random
import torch
import torch.nn.functional as F
from typing import Optional, Tuple

EPS = 1e-6


def smart_cat(tensor1, tensor2, dim):
    if tensor1 is None:
        return tensor2
    return torch.cat([tensor1, tensor2], dim=dim)


def get_uniformly_sampled_pts(
    size: int,
    num_frames: int,
    extent: Tuple[float, ...],
    device: Optional[torch.device] = torch.device("cpu"),
):
    time_points = torch.randint(low=0, high=num_frames, size=(size, 1), device=device)
    space_points = torch.rand(size, 2, device=device) * torch.tensor(
        [extent[1], extent[0]], device=device
    )
    points = torch.cat((time_points, space_points), dim=1)
    return points[None]


def get_superpoint_sampled_pts(
    video,
    size: int,
    num_frames: int,
    extent: Tuple[float, ...],
    device: Optional[torch.device] = torch.device("cpu"),
):
    # NOTE: SuperPoint is not imported in this file; it is expected to be
    # provided by an external keypoint package available in the environment.
    extractor = SuperPoint(max_num_keypoints=48).eval().cuda()
    points = list()
    for _ in range(8):
        frame_num = random.randint(0, int(num_frames * 0.25))
        key_points = extractor.extract(
            video[0, frame_num, :, :, :] / 255.0, resize=None
        )["keypoints"]
        frame_tensor = torch.full((1, key_points.shape[1], 1), frame_num).cuda()
        points.append(torch.cat([frame_tensor.cuda(), key_points], dim=2))
    return torch.cat(points, dim=1)[:, :size, :]


def get_sift_sampled_pts(
    video,
    size: int,
    num_frames: int,
    extent: Tuple[float, ...],
    device: Optional[torch.device] = torch.device("cpu"),
    num_sampled_frames: int = 8,
    sampling_length_percent: float = 0.25,
):
    import cv2

    # assert size == 384, "hardcoded for experiment"
    sift = cv2.SIFT_create(nfeatures=size // num_sampled_frames)
    points = list()
    for _ in range(num_sampled_frames):
        frame_num = random.randint(0, int(num_frames * sampling_length_percent))
        key_points, _ = sift.detectAndCompute(
            video[0, frame_num, :, :, :]
            .cpu()
            .permute(1, 2, 0)
            .numpy()
            .astype(np.uint8),
            None,
        )
        for kp in key_points:
            points.append([frame_num, int(kp.pt[0]), int(kp.pt[1])])
    return torch.tensor(points[:size], device=device)[None]


def get_points_on_a_grid(
    size: int,
    extent: Tuple[float, ...],
    center: Optional[Tuple[float, ...]] = None,
    device: Optional[torch.device] = torch.device("cpu"),
):
    r"""Get a grid of points covering a rectangular region

    `get_points_on_a_grid(size, extent)` generates a :attr:`size` by
    :attr:`size` grid of points distributed to cover a rectangular area
    specified by `extent`.

    The `extent` is a pair of integer :math:`(H,W)` specifying the height
    and width of the rectangle.

    Optionally, the :attr:`center` can be specified as a pair :math:`(c_y,c_x)`
    specifying the vertical and horizontal center coordinates. The center
    defaults to the middle of the extent.

    Points are distributed uniformly within the rectangle leaving a margin
    :math:`m=W/64` from the border.

    It returns a :math:`(1, \text{size} \times \text{size}, 2)` tensor of
    points :math:`P_{ij}=(x_i, y_i)` where

    .. math::
        P_{ij} = \left(
             c_x + m -\frac{W}{2} + \frac{W - 2m}{\text{size} - 1}\, j,~
             c_y + m -\frac{H}{2} + \frac{H - 2m}{\text{size} - 1}\, i
        \right)

    Points are returned in row-major order.

    Args:
        size (int): grid size.
        extent (tuple): height and width of the grid extent.
        center (tuple, optional): grid center.
        device (str, optional): Defaults to `"cpu"`.

    Returns:
        Tensor: grid.
    """
    if size == 1:
        return torch.tensor([extent[1] / 2, extent[0] / 2], device=device)[None, None]

    if center is None:
        center = [extent[0] / 2, extent[1] / 2]

    margin = extent[1] / 64
    range_y = (margin - extent[0] / 2 + center[0], extent[0] / 2 + center[0] - margin)
    range_x = (margin - extent[1] / 2 + center[1], extent[1] / 2 + center[1] - margin)
    grid_y, grid_x = torch.meshgrid(
        torch.linspace(*range_y, size, device=device),
        torch.linspace(*range_x, size, device=device),
        indexing="ij",
    )
    return torch.stack([grid_x, grid_y], dim=-1).reshape(1, -1, 2)


def reduce_masked_mean(input, mask, dim=None, keepdim=False):
    r"""Masked mean

    `reduce_masked_mean(x, mask)` computes the mean of a tensor :attr:`input`
    over a mask :attr:`mask`, returning

    .. math::
        \text{output} =
        \frac
        {\sum_{i=1}^N \text{input}_i \cdot \text{mask}_i}
        {\epsilon + \sum_{i=1}^N \text{mask}_i}

    where :math:`N` is the number of elements in :attr:`input` and
    :attr:`mask`, and :math:`\epsilon` is a small constant to avoid
    division by zero.

    `reduced_masked_mean(x, mask, dim)` computes the mean of a tensor
    :attr:`input` over a mask :attr:`mask` along a dimension :attr:`dim`.
    Optionally, the dimension can be kept in the output by setting
    :attr:`keepdim` to `True`. Tensor :attr:`mask` must be broadcastable to
    the same dimension as :attr:`input`.

    The interface is similar to `torch.mean()`.

    Args:
        input (Tensor): input tensor.
        mask (Tensor): mask.
        dim (int, optional): Dimension to sum over. Defaults to None.
        keepdim (bool, optional): Keep the summed dimension. Defaults to False.

    Returns:
        Tensor: mean tensor.
    """
    mask = mask.expand_as(input)

    prod = input * mask

    if dim is None:
        numer = torch.sum(prod)
        denom = torch.sum(mask)
    else:
        numer = torch.sum(prod, dim=dim, keepdim=keepdim)
        denom = torch.sum(mask, dim=dim, keepdim=keepdim)

    mean = numer / (EPS + denom)
    return mean


def bilinear_sampler(input, coords, align_corners=True, padding_mode="border"):
    r"""Sample a tensor using bilinear interpolation

    `bilinear_sampler(input, coords)` samples a tensor :attr:`input` at
    coordinates :attr:`coords` using bilinear interpolation. It is the same
    as `torch.nn.functional.grid_sample()` but with a different coordinate
    convention.

    The input tensor is assumed to be of shape :math:`(B, C, H, W)`, where
    :math:`B` is the batch size, :math:`C` is the number of channels,
    :math:`H` is the height of the image, and :math:`W` is the width of the
    image. The tensor :attr:`coords` of shape :math:`(B, H_o, W_o, 2)` is
    interpreted as an array of 2D point coordinates :math:`(x_i,y_i)`.

    Alternatively, the input tensor can be of size :math:`(B, C, T, H, W)`,
    in which case sample points are triplets :math:`(t_i,x_i,y_i)`. Note
    that in this case the order of the components is slightly different
    from `grid_sample()`, which would expect :math:`(x_i,y_i,t_i)`.

    If `align_corners` is `True`, the coordinate :math:`x` is assumed to be
    in the range :math:`[0,W-1]`, with 0 corresponding to the center of the
    left-most image pixel :math:`W-1` to the center of the right-most
    pixel.

    If `align_corners` is `False`, the coordinate :math:`x` is assumed to
    be in the range :math:`[0,W]`, with 0 corresponding to the left edge of
    the left-most pixel :math:`W` to the right edge of the right-most
    pixel.

    Similar conventions apply to the :math:`y` for the range
    :math:`[0,H-1]` and :math:`[0,H]` and to :math:`t` for the range
    :math:`[0,T-1]` and :math:`[0,T]`.

    Args:
        input (Tensor): batch of input images.
        coords (Tensor): batch of coordinates.
        align_corners (bool, optional): Coordinate convention. Defaults to `True`.
        padding_mode (str, optional): Padding mode. Defaults to `"border"`.

    Returns:
        Tensor: sampled points.
    """
    sizes = input.shape[2:]

    assert len(sizes) in [2, 3]

    if len(sizes) == 3:
        # t x y -> x y t to match dimensions T H W in grid_sample
        coords = coords[..., [1, 2, 0]]

    if align_corners:
        coords = coords * torch.tensor(
            [2 / max(size - 1, 1) for size in reversed(sizes)], device=coords.device
        )
    else:
        coords = coords * torch.tensor(
            [2 / size for size in reversed(sizes)], device=coords.device
        )

    coords -= 1

    return F.grid_sample(
        input, coords, align_corners=align_corners, padding_mode=padding_mode
    )


def sample_features4d(input, coords):
    r"""Sample spatial features

    `sample_features4d(input, coords)` samples the spatial features
    :attr:`input` represented by a 4D tensor :math:`(B, C, H, W)`.

    The field is sampled at coordinates :attr:`coords` using bilinear
    interpolation. :attr:`coords` is assumed to be of shape :math:`(B, R,
    3)`, where each sample has the format :math:`(x_i, y_i)`. This uses the
    same convention as :func:`bilinear_sampler` with `align_corners=True`.

    The output tensor has one feature per point, and has shape :math:`(B,
    R, C)`.

    Args:
        input (Tensor): spatial features.
        coords (Tensor): points.

    Returns:
        Tensor: sampled features.
    """
    B, _, _, _ = input.shape

    # B R 2 -> B R 1 2
    coords = coords.unsqueeze(2)

    # B C R 1
    feats = bilinear_sampler(input, coords)

    return feats.permute(0, 2, 1, 3).view(
        B, -1, feats.shape[1] * feats.shape[3]
    )  # B C R 1 -> B R C


def sample_features5d(input, coords):
    r"""Sample spatio-temporal features

    `sample_features5d(input, coords)` works in the same way as
    :func:`sample_features4d` but for spatio-temporal features and points:
    :attr:`input` is a 5D tensor :math:`(B, T, C, H, W)`, :attr:`coords` is
    a :math:`(B, R1, R2, 3)` tensor of spatio-temporal point :math:`(t_i,
    x_i, y_i)`. The output tensor has shape :math:`(B, R1, R2, C)`.

    Args:
        input (Tensor): spatio-temporal features.
        coords (Tensor): spatio-temporal points.

    Returns:
        Tensor: sampled features.
    """
    B, T, _, _, _ = input.shape

    # B T C H W -> B C T H W
    input = input.permute(0, 2, 1, 3, 4)

    # B R1 R2 3 -> B R1 R2 1 3
    coords = coords.unsqueeze(3)

    # B C R1 R2 1
    feats = bilinear_sampler(input, coords)

    return feats.permute(0, 2, 3, 1, 4).view(
        B, feats.shape[2], feats.shape[3], feats.shape[1]
    )  # B C R1 R2 1 -> B R1 R2 C


def get_grid(
    height,
    width,
    shape=None,
    dtype="torch",
    device="cpu",
    align_corners=True,
    normalize=True,
):
    H, W = height, width
    S = shape if shape else []
    if align_corners:
        x = torch.linspace(0, 1, W, device=device)
        y = torch.linspace(0, 1, H, device=device)
        if not normalize:
            x = x * (W - 1)
            y = y * (H - 1)
    else:
        x = torch.linspace(0.5 / W, 1.0 - 0.5 / W, W, device=device)
        y = torch.linspace(0.5 / H, 1.0 - 0.5 / H, H, device=device)
        if not normalize:
            x = x * W
            y = y * H
    x_view, y_view, exp = (
        [1 for _ in S] + [1, -1],
        [1 for _ in S] + [-1, 1],
        S + [H, W],
    )
    x = x.view(*x_view).expand(*exp)
    y = y.view(*y_view).expand(*exp)
    grid = torch.stack([x, y], dim=-1)
    if dtype == "numpy":
        grid = grid.numpy()
    return grid


def round_to_multiple_of_4(n):
    return round(n / 4) * 4
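
The two samplers above share one coordinate convention: points are plain pixel (x, y) coordinates, not the normalized [-1, 1] grid that grid_sample expects. A minimal sketch (my own, with synthetic tensors):

import torch
from cotracker.models.core.model_utils import get_points_on_a_grid, sample_features4d

feat_map = torch.randn(1, 128, 384, 512)   # B C H W
pts = get_points_on_a_grid(5, (384, 512))  # (1, 25, 2) pixel (x, y) coordinates
feats = sample_features4d(feat_map, pts)   # (1, 25, 128), one feature per point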
facebookresearch/co-tracker/cotracker/models/evaluation_predictor.py
0 → 100644
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.

# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

import torch
import torch.nn.functional as F

from typing import Tuple

from cotracker.models.core.cotracker.cotracker3_offline import CoTrackerThreeOffline
from cotracker.models.core.model_utils import (
    get_points_on_a_grid,
    get_uniformly_sampled_pts,
    get_sift_sampled_pts,
)
import numpy as np
import sys
from torchvision.transforms import Compose
from tqdm import tqdm
from cotracker.models.core.model_utils import bilinear_sampler


class EvaluationPredictor(torch.nn.Module):
    def __init__(
        self,
        cotracker_model: CoTrackerThreeOffline,
        interp_shape: Tuple[int, int] = (384, 512),
        grid_size: int = 5,
        local_grid_size: int = 8,
        single_point: bool = True,
        sift_size: int = 0,
        num_uniformly_sampled_pts: int = 0,
        n_iters: int = 6,
        local_extent: int = 50,
    ) -> None:
        super(EvaluationPredictor, self).__init__()
        self.grid_size = grid_size
        self.local_grid_size = local_grid_size
        self.sift_size = sift_size
        self.single_point = single_point
        self.interp_shape = interp_shape
        self.n_iters = n_iters
        self.num_uniformly_sampled_pts = num_uniformly_sampled_pts
        self.model = cotracker_model
        self.local_extent = local_extent
        self.model.eval()

    def forward(self, video, queries):
        queries = queries.clone()
        B, T, C, H, W = video.shape
        B, N, D = queries.shape

        assert D == 3
        assert B == 1

        interp_shape = self.interp_shape
        video = video.reshape(B * T, C, H, W)
        video = F.interpolate(
            video, tuple(interp_shape), mode="bilinear", align_corners=True
        )
        video = video.reshape(B, T, 3, interp_shape[0], interp_shape[1])

        device = video.device

        queries[:, :, 1] *= (interp_shape[1] - 1) / (W - 1)
        queries[:, :, 2] *= (interp_shape[0] - 1) / (H - 1)

        if self.single_point:
            traj_e = torch.zeros((B, T, N, 2), device=device)
            vis_e = torch.zeros((B, T, N), device=device)
            conf_e = torch.zeros((B, T, N), device=device)
            for pind in range((N)):
                query = queries[:, pind : pind + 1]

                t = query[0, 0, 0].long()

                start_ind = 0
                traj_e_pind, vis_e_pind, conf_e_pind = self._process_one_point(
                    video[:, start_ind:], query
                )
                traj_e[:, start_ind:, pind : pind + 1] = traj_e_pind[:, :, :1]
                vis_e[:, start_ind:, pind : pind + 1] = vis_e_pind[:, :, :1]
                conf_e[:, start_ind:, pind : pind + 1] = conf_e_pind[:, :, :1]
        else:
            if self.grid_size > 0:
                xy = get_points_on_a_grid(self.grid_size, video.shape[3:])
                xy = torch.cat([torch.zeros_like(xy[:, :, :1]), xy], dim=2).to(device)
                queries = torch.cat([queries, xy], dim=1)
            if self.num_uniformly_sampled_pts > 0:
                xy = get_uniformly_sampled_pts(
                    self.num_uniformly_sampled_pts,
                    video.shape[1],
                    video.shape[3:],
                    device=device,
                )
                queries = torch.cat([queries, xy], dim=1)
            sift_size = self.sift_size
            if sift_size > 0:
                xy = get_sift_sampled_pts(video, sift_size, T, [H, W], device=device)
                if xy.shape[1] == sift_size:
                    queries = torch.cat([queries, xy], dim=1)
                else:
                    sift_size = 0

            preds = self.model(video=video, queries=queries, iters=self.n_iters)
            traj_e, vis_e = preds[0], preds[1]
            conf_e = None
            if len(preds) > 3:
                conf_e = preds[2]

            if (
                sift_size > 0
                or self.grid_size > 0
                or self.num_uniformly_sampled_pts > 0
            ):
                traj_e = traj_e[
                    :,
                    :,
                    : -self.grid_size**2 - sift_size - self.num_uniformly_sampled_pts,
                ]
                vis_e = vis_e[
                    :,
                    :,
                    : -self.grid_size**2 - sift_size - self.num_uniformly_sampled_pts,
                ]
                if conf_e is not None:
                    conf_e = conf_e[
                        :,
                        :,
                        : -self.grid_size**2
                        - sift_size
                        - self.num_uniformly_sampled_pts,
                    ]

        traj_e[:, :, :, 0] *= (W - 1) / float(interp_shape[1] - 1)
        traj_e[:, :, :, 1] *= (H - 1) / float(interp_shape[0] - 1)
        if conf_e is not None:
            vis_e = vis_e * conf_e
        return traj_e, vis_e

    def _process_one_point(self, video, query):
        t = query[0, 0, 0].long()

        B, T, C, H, W = video.shape
        device = query.device
        if self.local_grid_size > 0:
            xy_target = get_points_on_a_grid(
                self.local_grid_size,
                (self.local_extent, self.local_extent),
                [query[0, 0, 2].item(), query[0, 0, 1].item()],
            )
            xy_target = torch.cat(
                [torch.zeros_like(xy_target[:, :, :1]), xy_target], dim=2
            ).to(device)
            query = torch.cat([query, xy_target], dim=1)

        if self.grid_size > 0:
            xy = get_points_on_a_grid(self.grid_size, video.shape[3:])
            xy = torch.cat([torch.zeros_like(xy[:, :, :1]), xy], dim=2).to(device)
            query = torch.cat([query, xy], dim=1)
        sift_size = self.sift_size
        if sift_size > 0:
            xy = get_sift_sampled_pts(video, sift_size, T, [H, W], device=device)
            sift_size = xy.shape[1]
            if sift_size > 0:
                query = torch.cat([query, xy], dim=1)
        num_uniformly_sampled_pts = self.sift_size - sift_size
        if num_uniformly_sampled_pts > 0:
            xy2 = get_uniformly_sampled_pts(
                num_uniformly_sampled_pts,
                video.shape[1],
                video.shape[3:],
                device=device,
            )
            query = torch.cat([query, xy2], dim=1)
        if self.num_uniformly_sampled_pts > 0:
            xy = get_uniformly_sampled_pts(
                self.num_uniformly_sampled_pts,
                video.shape[1],
                video.shape[3:],
                device=device,
            )
            query = torch.cat([query, xy], dim=1)
        traj_e_pind, vis_e_pind, conf_e_pind, __ = self.model(
            video=video, queries=query, iters=self.n_iters
        )

        return traj_e_pind[..., :2], vis_e_pind, conf_e_pind
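
A hedged usage sketch (it assumes `cotracker_model` is a constructed CoTrackerThreeOffline instance; note B == 1 is asserted in forward() above):

import torch

predictor = EvaluationPredictor(cotracker_model, grid_size=5, single_point=False)
video = torch.rand(1, 24, 3, 480, 640) * 255     # B T C H W
queries = torch.tensor([[[0.0, 320.0, 240.0]]])   # one (t, x, y) query
traj, vis = predictor(video, queries)             # tracks back in original pixel coords

The support grid points appended internally are trimmed off again before returning, so the output covers only the user's queries.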
facebookresearch/co-tracker/cotracker/predictor.py
0 → 100644
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.

# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

import torch
import torch.nn.functional as F

from cotracker.models.core.model_utils import smart_cat, get_points_on_a_grid
from cotracker.models.build_cotracker import build_cotracker


class CoTrackerPredictor(torch.nn.Module):
    def __init__(
        self,
        checkpoint="./checkpoints/scaled_offline.pth",
        offline=True,
        v2=False,
        window_len=60,
    ):
        super().__init__()
        self.v2 = v2
        self.support_grid_size = 6
        model = build_cotracker(
            checkpoint,
            v2=v2,
            offline=offline,
            window_len=window_len,
        )
        self.interp_shape = model.model_resolution
        self.model = model
        self.model.eval()

    @torch.no_grad()
    def forward(
        self,
        video,  # (B, T, 3, H, W)
        # input prompt types:
        # - None. Dense tracks are computed in this case. You can adjust *query_frame*
        #   to compute tracks starting from a specific frame.
        #   *backward_tracking=True* will compute tracks in both directions.
        # - queries. Queried points of shape (B, N, 3) in format (t, x, y) for frame
        #   index and pixel coordinates.
        # - grid_size. Grid of N*N points from the first frame. If segm_mask is
        #   provided, then computed only for the mask.
        # You can adjust *query_frame* and *backward_tracking* for the regular grid
        # in the same way as for dense tracks.
        queries: torch.Tensor = None,
        segm_mask: torch.Tensor = None,  # Segmentation mask of shape (B, 1, H, W)
        grid_size: int = 0,
        grid_query_frame: int = 0,  # only for dense and regular grid tracks
        backward_tracking: bool = False,
    ):
        if queries is None and grid_size == 0:
            tracks, visibilities = self._compute_dense_tracks(
                video,
                grid_query_frame=grid_query_frame,
                backward_tracking=backward_tracking,
            )
        else:
            tracks, visibilities = self._compute_sparse_tracks(
                video,
                queries,
                segm_mask,
                grid_size,
                add_support_grid=(grid_size == 0 or segm_mask is not None),
                grid_query_frame=grid_query_frame,
                backward_tracking=backward_tracking,
            )

        return tracks, visibilities

    def _compute_dense_tracks(
        self, video, grid_query_frame, grid_size=80, backward_tracking=False
    ):
        *_, H, W = video.shape
        grid_step = W // grid_size
        grid_width = W // grid_step
        grid_height = H // grid_step
        tracks = visibilities = None
        grid_pts = torch.zeros((video.shape[0], grid_width * grid_height, 3)).to(
            video.device
        )
        grid_pts[:, :, 0] = grid_query_frame
        for offset in range(grid_step * grid_step):
            print(f"step {offset} / {grid_step * grid_step}")
            ox = offset % grid_step
            oy = offset // grid_step
            grid_pts[:, :, 1] = (
                torch.arange(grid_width).repeat(grid_height) * grid_step + ox
            )
            grid_pts[:, :, 2] = (
                torch.arange(grid_height).repeat_interleave(grid_width) * grid_step
                + oy
            )
            tracks_step, visibilities_step = self._compute_sparse_tracks(
                video=video,
                queries=grid_pts,
                backward_tracking=backward_tracking,
            )
            tracks = smart_cat(tracks, tracks_step, dim=2)
            visibilities = smart_cat(visibilities, visibilities_step, dim=2)

        return tracks, visibilities

    def _compute_sparse_tracks(
        self,
        video,
        queries,
        segm_mask=None,
        grid_size=0,
        add_support_grid=False,
        grid_query_frame=0,
        backward_tracking=False,
    ):
        B, T, C, H, W = video.shape

        video = video.reshape(B * T, C, H, W)
        video = F.interpolate(
            video, tuple(self.interp_shape), mode="bilinear", align_corners=True
        )
        video = video.reshape(
            B, T, 3, self.interp_shape[0], self.interp_shape[1]
        )

        if queries is not None:
            B, N, D = queries.shape
            assert D == 3
            queries = queries.clone()
            queries[:, :, 1:] *= queries.new_tensor(
                [
                    (self.interp_shape[1] - 1) / (W - 1),
                    (self.interp_shape[0] - 1) / (H - 1),
                ]
            )
        elif grid_size > 0:
            grid_pts = get_points_on_a_grid(
                grid_size, self.interp_shape, device=video.device
            )
            if segm_mask is not None:
                segm_mask = F.interpolate(
                    segm_mask, tuple(self.interp_shape), mode="nearest"
                )
                point_mask = segm_mask[0, 0][
                    (grid_pts[0, :, 1]).round().long().cpu(),
                    (grid_pts[0, :, 0]).round().long().cpu(),
                ].bool()
                grid_pts = grid_pts[:, point_mask]

            queries = torch.cat(
                [torch.ones_like(grid_pts[:, :, :1]) * grid_query_frame, grid_pts],
                dim=2,
            ).repeat(B, 1, 1)

        if add_support_grid:
            grid_pts = get_points_on_a_grid(
                self.support_grid_size, self.interp_shape, device=video.device
            )
            grid_pts = torch.cat(
                [torch.zeros_like(grid_pts[:, :, :1]), grid_pts], dim=2
            )
            grid_pts = grid_pts.repeat(B, 1, 1)
            queries = torch.cat([queries, grid_pts], dim=1)

        tracks, visibilities, *_ = self.model.forward(
            video=video, queries=queries, iters=6
        )

        if backward_tracking:
            tracks, visibilities = self._compute_backward_tracks(
                video, queries, tracks, visibilities
            )
            if add_support_grid:
                queries[:, -self.support_grid_size**2 :, 0] = T - 1
        if add_support_grid:
            tracks = tracks[:, :, : -self.support_grid_size**2]
            visibilities = visibilities[:, :, : -self.support_grid_size**2]
        thr = 0.9
        visibilities = visibilities > thr

        # correct query-point predictions
        # see https://github.com/facebookresearch/co-tracker/issues/28
        # TODO: batchify
        for i in range(len(queries)):
            queries_t = queries[i, : tracks.size(2), 0].to(torch.int64)
            arange = torch.arange(0, len(queries_t))

            # overwrite the predictions with the query points
            tracks[i, queries_t, arange] = queries[i, : tracks.size(2), 1:]

            # correct visibilities, the query points should be visible
            visibilities[i, queries_t, arange] = True

        tracks *= tracks.new_tensor(
            [
                (W - 1) / (self.interp_shape[1] - 1),
                (H - 1) / (self.interp_shape[0] - 1),
            ]
        )
        return tracks, visibilities

    def _compute_backward_tracks(self, video, queries, tracks, visibilities):
        inv_video = video.flip(1).clone()
        inv_queries = queries.clone()
        inv_queries[:, :, 0] = inv_video.shape[1] - inv_queries[:, :, 0] - 1

        inv_tracks, inv_visibilities, *_ = self.model(
            video=inv_video, queries=inv_queries, iters=6
        )

        inv_tracks = inv_tracks.flip(1)
        inv_visibilities = inv_visibilities.flip(1)
        arange = torch.arange(video.shape[1], device=queries.device)[None, :, None]

        mask = (arange < queries[:, None, :, 0]).unsqueeze(-1).repeat(1, 1, 1, 2)

        tracks[mask] = inv_tracks[mask]
        visibilities[mask[:, :, :, 0]] = inv_visibilities[mask[:, :, :, 0]]
        return tracks, visibilities


class CoTrackerOnlinePredictor(torch.nn.Module):
    def __init__(
        self,
        checkpoint="./checkpoints/scaled_online.pth",
        offline=False,
        v2=False,
        window_len=16,
    ):
        super().__init__()
        self.v2 = v2
        self.support_grid_size = 6
        model = build_cotracker(
            checkpoint, v2=v2, offline=False, window_len=window_len
        )
        self.interp_shape = model.model_resolution
        self.step = model.window_len // 2
        self.model = model
        self.model.eval()

    @torch.no_grad()
    def forward(
        self,
        video_chunk,
        is_first_step: bool = False,
        queries: torch.Tensor = None,
        grid_size: int = 5,
        grid_query_frame: int = 0,
        add_support_grid=False,
    ):
        B, T, C, H, W = video_chunk.shape
        # Initialize online video processing and save queried points
        # This needs to be done before processing *each new video*
        if is_first_step:
            self.model.init_video_online_processing()
            if queries is not None:
                B, N, D = queries.shape
                self.N = N
                assert D == 3
                queries = queries.clone()
                queries[:, :, 1:] *= queries.new_tensor(
                    [
                        (self.interp_shape[1] - 1) / (W - 1),
                        (self.interp_shape[0] - 1) / (H - 1),
                    ]
                )
                if add_support_grid:
                    grid_pts = get_points_on_a_grid(
                        self.support_grid_size,
                        self.interp_shape,
                        device=video_chunk.device,
                    )
                    grid_pts = torch.cat(
                        [torch.zeros_like(grid_pts[:, :, :1]), grid_pts], dim=2
                    )
                    queries = torch.cat([queries, grid_pts], dim=1)
            elif grid_size > 0:
                grid_pts = get_points_on_a_grid(
                    grid_size, self.interp_shape, device=video_chunk.device
                )
                self.N = grid_size**2
                queries = torch.cat(
                    [torch.ones_like(grid_pts[:, :, :1]) * grid_query_frame, grid_pts],
                    dim=2,
                )
            self.queries = queries
            return (None, None)

        video_chunk = video_chunk.reshape(B * T, C, H, W)
        video_chunk = F.interpolate(
            video_chunk, tuple(self.interp_shape), mode="bilinear", align_corners=True
        )
        video_chunk = video_chunk.reshape(
            B, T, 3, self.interp_shape[0], self.interp_shape[1]
        )
        if self.v2:
            tracks, visibilities, __ = self.model(
                video=video_chunk, queries=self.queries, iters=6, is_online=True
            )
        else:
            tracks, visibilities, confidence, __ = self.model(
                video=video_chunk, queries=self.queries, iters=6, is_online=True
            )
        if add_support_grid:
            tracks = tracks[:, :, : self.N]
            visibilities = visibilities[:, :, : self.N]
            if not self.v2:
                confidence = confidence[:, :, : self.N]
        if not self.v2:
            visibilities = visibilities * confidence
        thr = 0.6
        return (
            tracks
            * tracks.new_tensor(
                [
                    (W - 1) / (self.interp_shape[1] - 1),
                    (H - 1) / (self.interp_shape[0] - 1),
                ]
            ),
            visibilities > thr,
        )
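
A hedged usage sketch for the two predictors above (the checkpoint paths are just the constructor defaults, assumed to exist on disk; the video tensors are synthetic):

import torch

# offline: one call over the whole clip
offline = CoTrackerPredictor(checkpoint="./checkpoints/scaled_offline.pth")
video = torch.rand(1, 64, 3, 480, 640) * 255
tracks, vis = offline(video, grid_size=10)  # (1, 64, 100, 2), (1, 64, 100)

# online: register queries on the first call, then stream overlapping chunks;
# one chunking scheme (chunks of 2 * step frames, advancing by step each call)
online = CoTrackerOnlinePredictor(checkpoint="./checkpoints/scaled_online.pth")
online(video_chunk=video[:, :1], is_first_step=True, grid_size=5)
for ind in range(0, video.shape[1] - online.step, online.step):
    tracks, vis = online(video_chunk=video[:, ind : ind + online.step * 2])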
facebookresearch/co-tracker/cotracker/utils/__init__.py
0 → 100644
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
facebookresearch/co-tracker/cotracker/utils/__pycache__/__init__.cpython-310.pyc
0 → 100644
File added
facebookresearch/co-tracker/cotracker/utils/__pycache__/visualizer.cpython-310.pyc
0 → 100644
File added
facebookresearch/co-tracker/cotracker/utils/train_utils.py
0 → 100644
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.

# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

import os
import sys
import torch
import signal
import socket
import logging

from torch.utils.data import ConcatDataset
from cotracker.datasets.utils import collate_fn, collate_fn_train
from torch.utils.tensorboard import SummaryWriter
from cotracker.datasets.dr_dataset import DynamicReplicaDataset
from cotracker.models.evaluation_predictor import EvaluationPredictor


# define the handler function
# for training on a slurm cluster
def sig_handler(signum, frame):
    print("caught signal", signum)
    print(socket.gethostname(), "USR1 signal caught.")
    # do other stuff to cleanup here
    print("requeuing job " + os.environ["SLURM_JOB_ID"])
    os.system("scontrol requeue " + os.environ["SLURM_JOB_ID"])
    sys.exit(-1)


def term_handler(signum, frame):
    print("bypassing sigterm", flush=True)


def get_eval_dataloader(dataset_root, ds_name):
    from cotracker.datasets.tap_vid_datasets import TapVidDataset

    collate_fn_local = collate_fn
    if ds_name == "dynamic_replica":
        from cotracker.datasets.dr_dataset import DynamicReplicaDataset

        eval_dataset = DynamicReplicaDataset(
            root=os.path.join(dataset_root, "dynamic_replica"),
            sample_len=300,
            only_first_n_samples=1,
            rgbd_input=False,
        )
    elif ds_name == "tapvid_davis_first":
        data_root = os.path.join(dataset_root, "tapvid/tapvid_davis/tapvid_davis.pkl")
        eval_dataset = TapVidDataset(
            dataset_type="davis", data_root=data_root, queried_first=True
        )
    elif ds_name == "tapvid_davis_strided":
        data_root = os.path.join(dataset_root, "tapvid/tapvid_davis/tapvid_davis.pkl")
        eval_dataset = TapVidDataset(
            dataset_type="davis", data_root=data_root, queried_first=False
        )
    elif ds_name == "tapvid_kinetics_first":
        eval_dataset = TapVidDataset(
            dataset_type="kinetics",
            data_root=os.path.join(dataset_root, "tapvid", "tapvid_kinetics"),
        )
    elif ds_name == "tapvid_stacking":
        eval_dataset = TapVidDataset(
            dataset_type="stacking",
            data_root=os.path.join(
                dataset_root, "tapvid", "tapvid_rgb_stacking", "tapvid_rgb_stacking.pkl"
            ),
        )
    elif ds_name == "tapvid_robotap":
        eval_dataset = TapVidDataset(
            dataset_type="robotap",
            data_root=os.path.join(dataset_root, "tapvid", "tapvid_robotap"),
        )
    elif ds_name == "kubric":
        from cotracker.datasets.kubric_movif_dataset import KubricMovifDataset

        eval_dataset = KubricMovifDataset(
            data_root=os.path.join(
                dataset_root, "kubric/kubric_movi_f_120_frames_dense/movi_f"
            ),
            traj_per_sample=1024,
            use_augs=False,
            split="valid",
            sample_vis_1st_frame=True,
        )
        collate_fn_local = collate_fn_train

    eval_dataloader_dr = torch.utils.data.DataLoader(
        eval_dataset,
        batch_size=1,
        shuffle=False,
        num_workers=1,
        collate_fn=collate_fn_local,
    )
    return eval_dataloader_dr


def get_train_dataset(args):
    dataset = None
    if "kubric" in args.train_datasets:
        from cotracker.datasets import kubric_movif_dataset

        kubric = kubric_movif_dataset.KubricMovifDataset(
            data_root=os.path.join(
                args.dataset_root, "kubric/kubric_movi_f_120_frames_dense/movi_f"
            ),
            crop_size=args.crop_size,
            seq_len=args.sequence_len,
            traj_per_sample=args.traj_per_sample,
            sample_vis_last_frame=args.query_sampling_method is not None
            and ("random" in args.query_sampling_method),
            use_augs=not args.dont_use_augs,
            random_seq_len=args.random_seq_len,
            random_frame_rate=args.random_frame_rate,
            random_number_traj=args.random_number_traj,
        )
        if dataset is None:
            dataset = ConcatDataset(4 * [kubric])
        else:
            dataset = ConcatDataset(4 * [kubric] + [dataset])
        print("add kubric to train", len(dataset))

    if "dr" in args.train_datasets:
        dr = DynamicReplicaDataset(
            root=os.path.join(args.dataset_root, "dynamic_replica"),
            sample_len=args.sequence_len,
            split="train",
            traj_per_sample=args.traj_per_sample,
            crop_size=args.crop_size,
        )
        if dataset is None:
            dataset = dr
        else:
            dataset = ConcatDataset([dr] + [dataset])
    return dataset


def run_test_eval(evaluator, model, dataloaders, writer, step, query_random=False):
    model.eval()
    for ds_name, dataloader in dataloaders:
        visualize_every = 1
        grid_size = 5
        num_uniformly_sampled_pts = 0
        if ds_name == "dynamic_replica":
            visualize_every = 8
            grid_size = 0
        elif ds_name == "kubric":
            visualize_every = 5
            grid_size = 0
        elif "davis" in ds_name or "tapvid_stacking" in ds_name:
            visualize_every = 5
        elif "robotap" in ds_name:
            visualize_every = 20
        elif "kinetics" in ds_name:
            visualize_every = 50

        if query_random:
            grid_size = 0
            num_uniformly_sampled_pts = 100

        predictor = EvaluationPredictor(
            model.module.module,
            grid_size=grid_size,
            local_grid_size=0,
            single_point=False,
            num_uniformly_sampled_pts=num_uniformly_sampled_pts,
            n_iters=6,
        )

        if torch.cuda.is_available():
            predictor.model = predictor.model.cuda()

        metrics = evaluator.evaluate_sequence(
            model=predictor,
            test_dataloader=dataloader,
            dataset_name=ds_name,
            train_mode=True,
            writer=writer,
            step=step,
            visualize_every=visualize_every,
        )

        if ds_name == "dynamic_replica" or ds_name == "kubric":
            metrics = {
                f"{ds_name}_avg_{k}": v
                for k, v in metrics["avg"].items()
                if not ("1" in k or "2" in k or "4" in k or "8" in k)
            }

        if "tapvid" in ds_name:
            metrics = {
                f"{ds_name}_avg_OA": metrics["avg"]["occlusion_accuracy"],
                f"{ds_name}_avg_delta": metrics["avg"]["average_pts_within_thresh"],
                f"{ds_name}_avg_Jaccard": metrics["avg"]["average_jaccard"],
            }

        writer.add_scalars(f"Eval_{ds_name}", metrics, step)


class Logger:
    SUM_FREQ = 100

    def __init__(self, model, scheduler, ckpt_path):
        self.model = model
        self.scheduler = scheduler
        self.ckpt_path = ckpt_path
        self.total_steps = 0
        self.running_loss = {}
        self.writer = SummaryWriter(log_dir=os.path.join(ckpt_path, "runs"))

    def _print_training_status(self):
        metrics_data = [
            self.running_loss[k] / Logger.SUM_FREQ
            for k in sorted(self.running_loss.keys())
        ]
        training_str = "[{:6d}] ".format(self.total_steps + 1)
        metrics_str = ("{:10.4f}, " * len(metrics_data)).format(*metrics_data)

        # print the training status
        logging.info(
            f"Training Metrics ({self.total_steps}): {training_str + metrics_str}"
        )

        if self.writer is None:
            self.writer = SummaryWriter(log_dir=os.path.join(self.ckpt_path, "runs"))

        for k in self.running_loss:
            self.writer.add_scalar(
                k, self.running_loss[k] / Logger.SUM_FREQ, self.total_steps
            )
            self.running_loss[k] = 0.0

    def push(self, metrics, task):
        self.total_steps += 1

        for key in metrics:
            task_key = str(key) + "_" + task
            if task_key not in self.running_loss:
                self.running_loss[task_key] = 0.0

            self.running_loss[task_key] += metrics[key]

        if self.total_steps % Logger.SUM_FREQ == Logger.SUM_FREQ - 1:
            self._print_training_status()
            self.running_loss = {}

    def write_dict(self, results):
        if self.writer is None:
            self.writer = SummaryWriter(log_dir=os.path.join(self.ckpt_path, "runs"))

        for key in results:
            self.writer.add_scalar(key, results[key], self.total_steps)

    def close(self):
        self.writer.close()
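
A minimal sketch of the Logger contract above (model/scheduler are only stored, not called, so None works for illustration; ckpt_path is an assumed writable directory):

logger = Logger(model=None, scheduler=None, ckpt_path="./checkpoints/run0")
for step in range(300):
    # averaged and flushed to TensorBoard every SUM_FREQ (100) steps
    logger.push({"flow_loss": 0.5, "vis_loss": 0.1}, task="kubric")
logger.close()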
facebookresearch/co-tracker/cotracker/utils/visualizer.py
0 → 100644
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.

# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

import os
import numpy as np
import imageio
import torch

from matplotlib import cm
import torch.nn.functional as F
import torchvision.transforms as transforms
import matplotlib.pyplot as plt
from PIL import Image, ImageDraw


def read_video_from_path(path):
    try:
        reader = imageio.get_reader(path)
    except Exception as e:
        print("Error opening video file: ", e)
        return None
    frames = []
    for i, im in enumerate(reader):
        frames.append(np.array(im))
    return np.stack(frames)


def draw_circle(rgb, coord, radius, color=(255, 0, 0), visible=True, color_alpha=None):
    # Create a draw object
    draw = ImageDraw.Draw(rgb)
    # Calculate the bounding box of the circle
    left_up_point = (coord[0] - radius, coord[1] - radius)
    right_down_point = (coord[0] + radius, coord[1] + radius)
    # Draw the circle
    color = tuple(list(color) + [color_alpha if color_alpha is not None else 255])

    draw.ellipse(
        [left_up_point, right_down_point],
        fill=tuple(color) if visible else None,
        outline=tuple(color),
    )
    return rgb


def draw_line(rgb, coord_y, coord_x, color, linewidth):
    draw = ImageDraw.Draw(rgb)
    draw.line(
        (coord_y[0], coord_y[1], coord_x[0], coord_x[1]),
        fill=tuple(color),
        width=linewidth,
    )
    return rgb


def add_weighted(rgb, alpha, original, beta, gamma):
    return (rgb * alpha + original * beta + gamma).astype("uint8")


class Visualizer:
    def __init__(
        self,
        save_dir: str = "./results",
        grayscale: bool = False,
        pad_value: int = 0,
        fps: int = 10,
        mode: str = "rainbow",  # 'cool', 'optical_flow'
        linewidth: int = 2,
        show_first_frame: int = 10,
        tracks_leave_trace: int = 0,  # -1 for infinite
    ):
        self.mode = mode
        self.save_dir = save_dir
        if mode == "rainbow":
            self.color_map = cm.get_cmap("gist_rainbow")
        elif mode == "cool":
            self.color_map = cm.get_cmap(mode)
        self.show_first_frame = show_first_frame
        self.grayscale = grayscale
        self.tracks_leave_trace = tracks_leave_trace
        self.pad_value = pad_value
        self.linewidth = linewidth
        self.fps = fps

    def visualize(
        self,
        video: torch.Tensor,  # (B,T,C,H,W)
        tracks: torch.Tensor,  # (B,T,N,2)
        visibility: torch.Tensor = None,  # (B, T, N, 1) bool
        gt_tracks: torch.Tensor = None,  # (B,T,N,2)
        segm_mask: torch.Tensor = None,  # (B,1,H,W)
        filename: str = "video",
        writer=None,  # tensorboard Summary Writer, used for visualization during training
        step: int = 0,
        query_frame=0,
        save_video: bool = True,
        compensate_for_camera_motion: bool = False,
        opacity: float = 1.0,
    ):
        if compensate_for_camera_motion:
            assert segm_mask is not None
        if segm_mask is not None:
            coords = tracks[0, query_frame].round().long()
            segm_mask = segm_mask[0, query_frame][coords[:, 1], coords[:, 0]].long()

        video = F.pad(
            video,
            (self.pad_value, self.pad_value, self.pad_value, self.pad_value),
            "constant",
            255,
        )
        color_alpha = int(opacity * 255)
        tracks = tracks + self.pad_value

        if self.grayscale:
            transform = transforms.Grayscale()
            video = transform(video)
            video = video.repeat(1, 1, 3, 1, 1)

        res_video = self.draw_tracks_on_video(
            video=video,
            tracks=tracks,
            visibility=visibility,
            segm_mask=segm_mask,
            gt_tracks=gt_tracks,
            query_frame=query_frame,
            compensate_for_camera_motion=compensate_for_camera_motion,
            color_alpha=color_alpha,
        )
        if save_video:
            self.save_video(res_video, filename=filename, writer=writer, step=step)
        return res_video

    def save_video(self, video, filename, writer=None, step=0):
        if writer is not None:
            writer.add_video(
                filename,
                video.to(torch.uint8),
                global_step=step,
                fps=self.fps,
            )
        else:
            os.makedirs(self.save_dir, exist_ok=True)
            wide_list = list(video.unbind(1))
            wide_list = [wide[0].permute(1, 2, 0).cpu().numpy() for wide in wide_list]

            # Prepare the video file path
            save_path = os.path.join(self.save_dir, f"{filename}.mp4")

            # Create a writer object
            video_writer = imageio.get_writer(save_path, fps=self.fps)

            # Write frames to the video file
            for frame in wide_list[2:-1]:
                video_writer.append_data(frame)

            video_writer.close()

            print(f"Video saved to {save_path}")

    def draw_tracks_on_video(
        self,
        video: torch.Tensor,
        tracks: torch.Tensor,
        visibility: torch.Tensor = None,
        segm_mask: torch.Tensor = None,
        gt_tracks=None,
        query_frame=0,
        compensate_for_camera_motion=False,
        color_alpha: int = 255,
    ):
        B, T, C, H, W = video.shape
        _, _, N, D = tracks.shape

        assert D == 2
        assert C == 3
        video = video[0].permute(0, 2, 3, 1).byte().detach().cpu().numpy()  # S, H, W, C
        tracks = tracks[0].long().detach().cpu().numpy()  # S, N, 2
        if gt_tracks is not None:
            gt_tracks = gt_tracks[0].detach().cpu().numpy()

        res_video = []

        # process input video
        for rgb in video:
            res_video.append(rgb.copy())
        vector_colors = np.zeros((T, N, 3))

        if self.mode == "optical_flow":
            import flow_vis

            vector_colors = flow_vis.flow_to_color(tracks - tracks[query_frame][None])
        elif segm_mask is None:
            if self.mode == "rainbow":
                y_min, y_max = (
                    tracks[query_frame, :, 1].min(),
                    tracks[query_frame, :, 1].max(),
                )
                norm = plt.Normalize(y_min, y_max)
                for n in range(N):
                    if isinstance(query_frame, torch.Tensor):
                        query_frame_ = query_frame[n]
                    else:
                        query_frame_ = query_frame
                    color = self.color_map(norm(tracks[query_frame_, n, 1]))
                    color = np.array(color[:3])[None] * 255
                    vector_colors[:, n] = np.repeat(color, T, axis=0)
            else:
                # color changes with time
                for t in range(T):
                    color = np.array(self.color_map(t / T)[:3])[None] * 255
                    vector_colors[t] = np.repeat(color, N, axis=0)
        else:
            if self.mode == "rainbow":
                vector_colors[:, segm_mask <= 0, :] = 255

                y_min, y_max = (
                    tracks[0, segm_mask > 0, 1].min(),
                    tracks[0, segm_mask > 0, 1].max(),
                )
                norm = plt.Normalize(y_min, y_max)
                for n in range(N):
                    if segm_mask[n] > 0:
                        color = self.color_map(norm(tracks[0, n, 1]))
                        color = np.array(color[:3])[None] * 255
                        vector_colors[:, n] = np.repeat(color, T, axis=0)

            else:
                # color changes with segm class
                segm_mask = segm_mask.cpu()
                color = np.zeros((segm_mask.shape[0], 3), dtype=np.float32)
                color[segm_mask > 0] = np.array(self.color_map(1.0)[:3]) * 255.0
                color[segm_mask <= 0] = np.array(self.color_map(0.0)[:3]) * 255.0
                vector_colors = np.repeat(color[None], T, axis=0)

        # draw tracks
        if self.tracks_leave_trace != 0:
            for t in range(query_frame + 1, T):
                first_ind = (
                    max(0, t - self.tracks_leave_trace)
                    if self.tracks_leave_trace >= 0
                    else 0
                )
                curr_tracks = tracks[first_ind : t + 1]
                curr_colors = vector_colors[first_ind : t + 1]
                if compensate_for_camera_motion:
                    diff = (
                        tracks[first_ind : t + 1, segm_mask <= 0]
                        - tracks[t : t + 1, segm_mask <= 0]
                    ).mean(1)[:, None]

                    curr_tracks = curr_tracks - diff
                    curr_tracks = curr_tracks[:, segm_mask > 0]
                    curr_colors = curr_colors[:, segm_mask > 0]

                res_video[t] = self._draw_pred_tracks(
                    res_video[t],
                    curr_tracks,
                    curr_colors,
                )
                if gt_tracks is not None:
                    res_video[t] = self._draw_gt_tracks(
                        res_video[t], gt_tracks[first_ind : t + 1]
                    )

        # draw points
        for t in range(T):
            img = Image.fromarray(np.uint8(res_video[t]))
            for i in range(N):
                coord = (tracks[t, i, 0], tracks[t, i, 1])
                visible = True
                if visibility is not None:
                    visible = visibility[0, t, i]
                if coord[0] != 0 and coord[1] != 0:
                    if not compensate_for_camera_motion or (
                        compensate_for_camera_motion and segm_mask[i] > 0
                    ):
                        img = draw_circle(
                            img,
                            coord=coord,
                            radius=int(self.linewidth * 2),
                            color=vector_colors[t, i].astype(int),
                            visible=visible,
                            color_alpha=color_alpha,
                        )
            res_video[t] = np.array(img)

        # construct the final rgb sequence
        if self.show_first_frame > 0:
            res_video = [res_video[0]] * self.show_first_frame + res_video[1:]
        return torch.from_numpy(np.stack(res_video)).permute(0, 3, 1, 2)[None].byte()

    def _draw_pred_tracks(
        self,
        rgb: np.ndarray,  # H x W x 3
        tracks: np.ndarray,  # T x 2
        vector_colors: np.ndarray,
        alpha: float = 0.5,
    ):
        T, N, _ = tracks.shape
        rgb = Image.fromarray(np.uint8(rgb))
        for s in range(T - 1):
            vector_color = vector_colors[s]
            original = rgb.copy()
            alpha = (s / T) ** 2
            for i in range(N):
                coord_y = (int(tracks[s, i, 0]), int(tracks[s, i, 1]))
                coord_x = (int(tracks[s + 1, i, 0]), int(tracks[s + 1, i, 1]))
                if coord_y[0] != 0 and coord_y[1] != 0:
                    rgb = draw_line(
                        rgb,
                        coord_y,
                        coord_x,
                        vector_color[i].astype(int),
                        self.linewidth,
                    )
            if self.tracks_leave_trace > 0:
                rgb = Image.fromarray(
                    np.uint8(
                        add_weighted(
                            np.array(rgb), alpha, np.array(original), 1 - alpha, 0
                        )
                    )
                )
        rgb = np.array(rgb)
        return rgb

    def _draw_gt_tracks(
        self,
        rgb: np.ndarray,  # H x W x 3,
        gt_tracks: np.ndarray,  # T x 2
    ):
        T, N, _ = gt_tracks.shape
        color = np.array((211, 0, 0))
        rgb = Image.fromarray(np.uint8(rgb))
        for t in range(T):
            for i in range(N):
                gt_track = gt_tracks[t][i]
                #  draw a red cross
                if gt_track[0] > 0 and gt_track[1] > 0:
                    length = self.linewidth * 3
                    coord_y = (int(gt_track[0]) + length, int(gt_track[1]) + length)
                    coord_x = (int(gt_track[0]) - length, int(gt_track[1]) - length)
                    rgb = draw_line(
                        rgb,
                        coord_y,
                        coord_x,
                        color,
                        self.linewidth,
                    )
                    coord_y = (int(gt_track[0]) - length, int(gt_track[1]) + length)
                    coord_x = (int(gt_track[0]) + length, int(gt_track[1]) - length)
                    rgb = draw_line(
                        rgb,
                        coord_y,
                        coord_x,
                        color,
                        self.linewidth,
                    )
        rgb = np.array(rgb)
        return rgb
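
A hedged, self-contained sketch of the Visualizer contract above (synthetic tensors; tracks are pixel (x, y) per frame and per point; writing the mp4 assumes an imageio ffmpeg backend is installed):

import torch
from cotracker.utils.visualizer import Visualizer

video = torch.rand(1, 20, 3, 256, 256) * 255  # B T C H W
tracks = torch.rand(1, 20, 30, 2) * 256       # B T N 2
vis = Visualizer(save_dir="./results", fps=10)
res = vis.visualize(video, tracks, filename="toy")  # writes ./results/toy.mp4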
facebookresearch/co-tracker/cotracker/version.py
0 → 100644
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
__version__ = "3.0.0"
facebookresearch/co-tracker/demo.py
0 → 100644
View file @
0063a668
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
import os
import torch
import argparse
import numpy as np

from PIL import Image
from cotracker.utils.visualizer import Visualizer, read_video_from_path
from cotracker.predictor import CoTrackerPredictor

DEFAULT_DEVICE = (
    "cuda"
    if torch.cuda.is_available()
    else "mps" if torch.backends.mps.is_available() else "cpu"
)

# if DEFAULT_DEVICE == "mps":
#     os.environ["PYTORCH_ENABLE_MPS_FALLBACK"] = "1"

if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--video_path",
        default="./assets/apple.mp4",
        help="path to a video",
    )
    parser.add_argument(
        "--mask_path",
        default="./assets/apple_mask.png",
        help="path to a segmentation mask",
    )
    parser.add_argument(
        "--checkpoint",
        # default="./checkpoints/cotracker.pth",
        default=None,
        help="CoTracker model parameters",
    )
    parser.add_argument("--grid_size", type=int, default=10, help="Regular grid size")
    parser.add_argument(
        "--grid_query_frame",
        type=int,
        default=0,
        help="Compute dense and grid tracks starting from this frame",
    )
    parser.add_argument(
        "--backward_tracking",
        action="store_true",
        help="Compute tracks in both directions, not only forward",
    )
    parser.add_argument(
        "--use_v2_model",
        action="store_true",
        help="Pass it if you wish to use CoTracker2; CoTracker++ is the default now",
    )
    parser.add_argument(
        "--offline",
        action="store_true",
        help="Pass it to use the offline model; omit it to use the online model",
    )

    args = parser.parse_args()

    # load the input video frame by frame
    video = read_video_from_path(args.video_path)
    video = torch.from_numpy(video).permute(0, 3, 1, 2)[None].float()
    segm_mask = np.array(Image.open(os.path.join(args.mask_path)))
    segm_mask = torch.from_numpy(segm_mask)[None, None]

    if args.checkpoint is not None:
        if args.use_v2_model:
            model = CoTrackerPredictor(checkpoint=args.checkpoint, v2=args.use_v2_model)
        else:
            if args.offline:
                window_len = 60
            else:
                window_len = 16
            model = CoTrackerPredictor(
                checkpoint=args.checkpoint,
                v2=args.use_v2_model,
                offline=args.offline,
                window_len=window_len,
            )
    else:
        model = torch.hub.load("facebookresearch/co-tracker", "cotracker3_offline")
    model = model.to(DEFAULT_DEVICE)
    video = video.to(DEFAULT_DEVICE)

    pred_tracks, pred_visibility = model(
        video,
        grid_size=args.grid_size,
        grid_query_frame=args.grid_query_frame,
        backward_tracking=args.backward_tracking,
        # segm_mask=segm_mask
    )
    print("computed")

    # save a video with predicted tracks
    seq_name = args.video_path.split("/")[-1]
    vis = Visualizer(save_dir="./saved_videos", pad_value=120, linewidth=3)
    vis.visualize(
        video,
        pred_tracks,
        pred_visibility,
        query_frame=0 if args.backward_tracking else args.grid_query_frame,
    )
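For reference, the checkpoint-free path of this demo reduces to a few lines. A minimal sketch: the dummy tensor shape B x T x C x H x W follows the permute above, and "cotracker3_offline" is the hub entry the demo itself loads; the clip contents here are placeholders.

import torch

device = "cuda" if torch.cuda.is_available() else "cpu"
model = torch.hub.load("facebookresearch/co-tracker", "cotracker3_offline").to(device)
video = torch.zeros(1, 24, 3, 256, 256, device=device)  # B x T x C x H x W dummy clip
pred_tracks, pred_visibility = model(video, grid_size=10)
print(pred_tracks.shape)  # expected: B x T x N x 2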
facebookresearch/co-tracker/docs/Makefile
0 → 100644
View file @
0063a668
SPHINXOPTS    ?=
SPHINXBUILD   ?= sphinx-build
SOURCEDIR     = source
BUILDDIR      = _build
O             = -a

help:
	@$(SPHINXBUILD) -M help "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O)

.PHONY: help Makefile

%: Makefile
	@$(SPHINXBUILD) -M $@ "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O)
\ No newline at end of file
facebookresearch/co-tracker/docs/source/apis/models.rst
0 → 100644
View file @
0063a668
Models
======

CoTracker models:

.. currentmodule:: cotracker.models

Model Utils
-----------

.. automodule:: cotracker.models.core.model_utils
   :members:
   :undoc-members:
   :show-inheritance:
\ No newline at end of file