Commit 0063a668 authored by chenzk (v1.0)
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
import torch
import torch.nn as nn
import torch.nn.functional as F
from functools import partial
from typing import Callable
import collections
from torch import Tensor
from itertools import repeat
from cotracker.models.core.model_utils import bilinear_sampler
# From PyTorch internals
def _ntuple(n):
def parse(x):
if isinstance(x, collections.abc.Iterable) and not isinstance(x, str):
return tuple(x)
return tuple(repeat(x, n))
return parse
def exists(val):
return val is not None
def default(val, d):
return val if exists(val) else d
to_2tuple = _ntuple(2)
class Mlp(nn.Module):
"""MLP as used in Vision Transformer, MLP-Mixer and related networks"""
def __init__(
self,
in_features,
hidden_features=None,
out_features=None,
act_layer=nn.GELU,
norm_layer=None,
bias=True,
drop=0.0,
use_conv=False,
):
super().__init__()
out_features = out_features or in_features
hidden_features = hidden_features or in_features
bias = to_2tuple(bias)
drop_probs = to_2tuple(drop)
linear_layer = partial(nn.Conv2d, kernel_size=1) if use_conv else nn.Linear
self.fc1 = linear_layer(in_features, hidden_features, bias=bias[0])
self.act = act_layer()
self.drop1 = nn.Dropout(drop_probs[0])
self.norm = (
norm_layer(hidden_features) if norm_layer is not None else nn.Identity()
)
self.fc2 = linear_layer(hidden_features, out_features, bias=bias[1])
self.drop2 = nn.Dropout(drop_probs[1])
def forward(self, x):
x = self.fc1(x)
x = self.act(x)
x = self.drop1(x)
x = self.fc2(x)
x = self.drop2(x)
return x
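# Illustrative usage sketch (not part of the original module; the helper name is
# hypothetical): Mlp keeps the token width, expanding to hidden_features only in between.
def _mlp_shape_example():
    mlp = Mlp(in_features=384, hidden_features=1536)
    tokens = torch.randn(2, 16, 384)  # (batch, tokens, channels)
    out = mlp(tokens)
    assert out.shape == (2, 16, 384)  # same width as the input
    return out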
class ResidualBlock(nn.Module):
def __init__(self, in_planes, planes, norm_fn="group", stride=1):
super(ResidualBlock, self).__init__()
self.conv1 = nn.Conv2d(
in_planes,
planes,
kernel_size=3,
padding=1,
stride=stride,
padding_mode="zeros",
)
self.conv2 = nn.Conv2d(
planes, planes, kernel_size=3, padding=1, padding_mode="zeros"
)
self.relu = nn.ReLU(inplace=True)
num_groups = planes // 8
if norm_fn == "group":
self.norm1 = nn.GroupNorm(num_groups=num_groups, num_channels=planes)
self.norm2 = nn.GroupNorm(num_groups=num_groups, num_channels=planes)
if not stride == 1:
self.norm3 = nn.GroupNorm(num_groups=num_groups, num_channels=planes)
elif norm_fn == "batch":
self.norm1 = nn.BatchNorm2d(planes)
self.norm2 = nn.BatchNorm2d(planes)
if not stride == 1:
self.norm3 = nn.BatchNorm2d(planes)
elif norm_fn == "instance":
self.norm1 = nn.InstanceNorm2d(planes)
self.norm2 = nn.InstanceNorm2d(planes)
if not stride == 1:
self.norm3 = nn.InstanceNorm2d(planes)
elif norm_fn == "none":
self.norm1 = nn.Sequential()
self.norm2 = nn.Sequential()
if not stride == 1:
self.norm3 = nn.Sequential()
if stride == 1:
self.downsample = None
else:
self.downsample = nn.Sequential(
nn.Conv2d(in_planes, planes, kernel_size=1, stride=stride), self.norm3
)
def forward(self, x):
y = x
y = self.relu(self.norm1(self.conv1(y)))
y = self.relu(self.norm2(self.conv2(y)))
if self.downsample is not None:
x = self.downsample(x)
return self.relu(x + y)
class BasicEncoder(nn.Module):
def __init__(self, input_dim=3, output_dim=128, stride=4):
super(BasicEncoder, self).__init__()
self.stride = stride
self.norm_fn = "instance"
self.in_planes = output_dim // 2
self.norm1 = nn.InstanceNorm2d(self.in_planes)
self.norm2 = nn.InstanceNorm2d(output_dim * 2)
self.conv1 = nn.Conv2d(
input_dim,
self.in_planes,
kernel_size=7,
stride=2,
padding=3,
padding_mode="zeros",
)
self.relu1 = nn.ReLU(inplace=True)
self.layer1 = self._make_layer(output_dim // 2, stride=1)
self.layer2 = self._make_layer(output_dim // 4 * 3, stride=2)
self.layer3 = self._make_layer(output_dim, stride=2)
self.layer4 = self._make_layer(output_dim, stride=2)
self.conv2 = nn.Conv2d(
output_dim * 3 + output_dim // 4,
output_dim * 2,
kernel_size=3,
padding=1,
padding_mode="zeros",
)
self.relu2 = nn.ReLU(inplace=True)
self.conv3 = nn.Conv2d(output_dim * 2, output_dim, kernel_size=1)
for m in self.modules():
if isinstance(m, nn.Conv2d):
nn.init.kaiming_normal_(m.weight, mode="fan_out", nonlinearity="relu")
elif isinstance(m, (nn.InstanceNorm2d)):
if m.weight is not None:
nn.init.constant_(m.weight, 1)
if m.bias is not None:
nn.init.constant_(m.bias, 0)
def _make_layer(self, dim, stride=1):
layer1 = ResidualBlock(self.in_planes, dim, self.norm_fn, stride=stride)
layer2 = ResidualBlock(dim, dim, self.norm_fn, stride=1)
layers = (layer1, layer2)
self.in_planes = dim
return nn.Sequential(*layers)
def forward(self, x):
_, _, H, W = x.shape
x = self.conv1(x)
x = self.norm1(x)
x = self.relu1(x)
a = self.layer1(x)
b = self.layer2(a)
c = self.layer3(b)
d = self.layer4(c)
        def _bilinear_interpolate(x):
return F.interpolate(
x,
(H // self.stride, W // self.stride),
mode="bilinear",
align_corners=True,
)
        a = _bilinear_interpolate(a)
        b = _bilinear_interpolate(b)
        c = _bilinear_interpolate(c)
        d = _bilinear_interpolate(d)
x = self.conv2(torch.cat([a, b, c, d], dim=1))
x = self.norm2(x)
x = self.relu2(x)
x = self.conv3(x)
return x
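# Illustrative shape sketch (not part of the original module): with the default
# stride of 4, BasicEncoder maps RGB frames to output_dim-channel feature maps
# at 1/4 of the input resolution, fusing all four residual stages.
def _basic_encoder_shape_example():
    fnet = BasicEncoder(input_dim=3, output_dim=128, stride=4)
    frames = torch.randn(2, 3, 64, 96)  # (batch * time, channels, H, W)
    fmaps = fnet(frames)
    assert fmaps.shape == (2, 128, 64 // 4, 96 // 4)
    return fmaps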
class EfficientCorrBlock:
def __init__(
self,
fmaps,
num_levels=4,
radius=4,
padding_mode="zeros",
):
B, S, C, H, W = fmaps.shape
self.padding_mode = padding_mode
self.num_levels = num_levels
self.radius = radius
self.fmaps_pyramid = []
self.fmaps_pyramid.append(fmaps)
for i in range(self.num_levels - 1):
fmaps_ = fmaps.reshape(B * S, C, H, W)
fmaps_ = F.avg_pool2d(fmaps_, 2, stride=2)
_, _, H, W = fmaps_.shape
fmaps = fmaps_.reshape(B, S, C, H, W)
self.fmaps_pyramid.append(fmaps)
def sample(self, coords, target):
r = self.radius
device = coords.device
B, S, N, D = coords.shape
assert D == 2
target = target.permute(0, 1, 3, 2).unsqueeze(-1)
out_pyramid = []
for i in range(self.num_levels):
pyramid = self.fmaps_pyramid[i]
C, H, W = pyramid.shape[2:]
centroid_lvl = (
torch.cat(
[torch.zeros_like(coords[..., :1], device=device), coords], dim=-1
).reshape(B * S, N, 1, 1, 3)
/ 2**i
)
dx = torch.linspace(-r, r, 2 * r + 1, device=device)
dy = torch.linspace(-r, r, 2 * r + 1, device=device)
xgrid, ygrid = torch.meshgrid(dy, dx, indexing="ij")
zgrid = torch.zeros_like(xgrid, device=device)
delta = torch.stack([zgrid, xgrid, ygrid], axis=-1)
delta_lvl = delta.view(1, 1, 2 * r + 1, 2 * r + 1, 3)
coords_lvl = centroid_lvl + delta_lvl
pyramid_sample = bilinear_sampler(
pyramid.reshape(B * S, C, 1, H, W), coords_lvl
)
corr = torch.sum(target * pyramid_sample.reshape(B, S, C, N, -1), dim=2)
corr = corr / torch.sqrt(torch.tensor(C).float())
out_pyramid.append(corr)
out = torch.cat(out_pyramid, dim=-1) # B, S, N, LRR*2
out = out.permute(0, 2, 1, 3).contiguous().view(B * N, S, -1).float()
return out
class CorrBlock:
def __init__(
self,
fmaps,
num_levels=4,
radius=4,
multiple_track_feats=False,
padding_mode="zeros",
):
B, S, C, H, W = fmaps.shape
self.S, self.C, self.H, self.W = S, C, H, W
self.padding_mode = padding_mode
self.num_levels = num_levels
self.radius = radius
self.fmaps_pyramid = []
self.multiple_track_feats = multiple_track_feats
self.fmaps_pyramid.append(fmaps)
for i in range(self.num_levels - 1):
fmaps_ = fmaps.reshape(B * S, C, H, W)
fmaps_ = F.avg_pool2d(fmaps_, 2, stride=2)
_, _, H, W = fmaps_.shape
fmaps = fmaps_.reshape(B, S, C, H, W)
self.fmaps_pyramid.append(fmaps)
def sample(self, coords):
r = self.radius
B, S, N, D = coords.shape
assert D == 2
H, W = self.H, self.W
out_pyramid = []
for i in range(self.num_levels):
corrs = self.corrs_pyramid[i] # B, S, N, H, W
*_, H, W = corrs.shape
dx = torch.linspace(-r, r, 2 * r + 1)
dy = torch.linspace(-r, r, 2 * r + 1)
delta = torch.stack(torch.meshgrid(dy, dx, indexing="ij"), axis=-1).to(
coords.device
)
centroid_lvl = coords.reshape(B * S * N, 1, 1, 2) / 2**i
delta_lvl = delta.view(1, 2 * r + 1, 2 * r + 1, 2)
coords_lvl = centroid_lvl + delta_lvl
corrs = bilinear_sampler(
corrs.reshape(B * S * N, 1, H, W),
coords_lvl,
padding_mode=self.padding_mode,
)
corrs = corrs.view(B, S, N, -1)
out_pyramid.append(corrs)
out = torch.cat(out_pyramid, dim=-1) # B, S, N, LRR*2
out = out.permute(0, 2, 1, 3).contiguous().view(B * N, S, -1).float()
return out
def corr(self, targets):
B, S, N, C = targets.shape
if self.multiple_track_feats:
targets_split = targets.split(C // self.num_levels, dim=-1)
B, S, N, C = targets_split[0].shape
assert C == self.C
assert S == self.S
fmap1 = targets
self.corrs_pyramid = []
for i, fmaps in enumerate(self.fmaps_pyramid):
*_, H, W = fmaps.shape
fmap2s = fmaps.view(B, S, C, H * W) # B S C H W -> B S C (H W)
if self.multiple_track_feats:
fmap1 = targets_split[i]
corrs = torch.matmul(fmap1, fmap2s)
corrs = corrs.view(B, S, N, H, W) # B S N (H W) -> B S N H W
corrs = corrs / torch.sqrt(torch.tensor(C).float())
self.corrs_pyramid.append(corrs)
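# Illustrative usage sketch (not part of the original module): CorrBlock is
# stateful -- corr() builds per-level correlation volumes against the track
# features, then sample() reads a (2r+1)x(2r+1) patch around each coordinate
# at every pyramid level and flattens the result per track.
def _corr_block_example():
    B, S, C, H, W, N, radius = 1, 8, 128, 24, 32, 5, 3
    fmaps = torch.randn(B, S, C, H, W)
    block = CorrBlock(fmaps, num_levels=4, radius=radius)
    track_feat = torch.randn(B, S, N, C)
    coords = torch.rand(B, S, N, 2) * torch.tensor([W - 1.0, H - 1.0])
    block.corr(track_feat)
    fcorrs = block.sample(coords)
    assert fcorrs.shape == (B * N, S, 4 * (2 * radius + 1) ** 2)
    return fcorrs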
class Attention(nn.Module):
def __init__(
self, query_dim, context_dim=None, num_heads=8, dim_head=48, qkv_bias=False
):
super().__init__()
inner_dim = dim_head * num_heads
context_dim = default(context_dim, query_dim)
self.scale = dim_head**-0.5
self.heads = num_heads
self.to_q = nn.Linear(query_dim, inner_dim, bias=qkv_bias)
self.to_kv = nn.Linear(context_dim, inner_dim * 2, bias=qkv_bias)
self.to_out = nn.Linear(inner_dim, query_dim)
def forward(self, x, context=None, attn_bias=None):
B, N1, C = x.shape
h = self.heads
q = self.to_q(x).reshape(B, N1, h, C // h).permute(0, 2, 1, 3)
context = default(context, x)
k, v = self.to_kv(context).chunk(2, dim=-1)
N2 = context.shape[1]
k = k.reshape(B, N2, h, C // h).permute(0, 2, 1, 3)
v = v.reshape(B, N2, h, C // h).permute(0, 2, 1, 3)
sim = (q @ k.transpose(-2, -1)) * self.scale
if attn_bias is not None:
sim = sim + attn_bias
attn = sim.softmax(dim=-1)
x = (attn @ v).transpose(1, 2).reshape(B, N1, C)
return self.to_out(x)
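# Illustrative shape sketch (not part of the original module): with num_heads=8
# and dim_head=48 the inner dimension equals a 384-wide query, so self-attention
# preserves the token shape.
def _attention_shape_example():
    attn = Attention(query_dim=384, num_heads=8, dim_head=48)
    x = torch.randn(2, 10, 384)  # (batch, tokens, channels)
    out = attn(x)
    assert out.shape == (2, 10, 384)
    return out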
class AttnBlock(nn.Module):
def __init__(
self,
hidden_size,
num_heads,
attn_class: Callable[..., nn.Module] = Attention,
mlp_ratio=4.0,
**block_kwargs
):
super().__init__()
self.norm1 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
self.attn = attn_class(
hidden_size, num_heads=num_heads, qkv_bias=True, **block_kwargs
)
self.norm2 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
mlp_hidden_dim = int(hidden_size * mlp_ratio)
approx_gelu = lambda: nn.GELU(approximate="tanh")
self.mlp = Mlp(
in_features=hidden_size,
hidden_features=mlp_hidden_dim,
act_layer=approx_gelu,
drop=0,
)
def forward(self, x, mask=None):
attn_bias = mask
if mask is not None:
mask = (
(mask[:, None] * mask[:, :, None])
.unsqueeze(1)
.expand(-1, self.attn.num_heads, -1, -1)
)
max_neg_value = -torch.finfo(x.dtype).max
attn_bias = (~mask) * max_neg_value
x = x + self.attn(self.norm1(x), attn_bias=attn_bias)
x = x + self.mlp(self.norm2(x))
return x
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
import torch
import torch.nn as nn
import torch.nn.functional as F
from cotracker.models.core.model_utils import sample_features4d, sample_features5d
from cotracker.models.core.embeddings import (
get_2d_embedding,
get_1d_sincos_pos_embed_from_grid,
get_2d_sincos_pos_embed,
)
from cotracker.models.core.cotracker.blocks import (
Mlp,
BasicEncoder,
AttnBlock,
CorrBlock,
Attention,
)
torch.manual_seed(0)
class CoTracker2(nn.Module):
def __init__(
self,
window_len=8,
stride=4,
add_space_attn=True,
num_virtual_tracks=64,
model_resolution=(384, 512),
):
super(CoTracker2, self).__init__()
self.window_len = window_len
self.stride = stride
self.hidden_dim = 256
self.latent_dim = 128
self.add_space_attn = add_space_attn
self.fnet = BasicEncoder(output_dim=self.latent_dim)
self.num_virtual_tracks = num_virtual_tracks
self.model_resolution = model_resolution
self.input_dim = 456
self.updateformer = EfficientUpdateFormer(
space_depth=6,
time_depth=6,
input_dim=self.input_dim,
hidden_size=384,
output_dim=self.latent_dim + 2,
mlp_ratio=4.0,
add_space_attn=add_space_attn,
num_virtual_tracks=num_virtual_tracks,
)
time_grid = torch.linspace(0, window_len - 1, window_len).reshape(
1, window_len, 1
)
self.register_buffer(
"time_emb", get_1d_sincos_pos_embed_from_grid(self.input_dim, time_grid[0])
)
self.register_buffer(
"pos_emb",
get_2d_sincos_pos_embed(
embed_dim=self.input_dim,
grid_size=(
model_resolution[0] // stride,
model_resolution[1] // stride,
),
),
)
self.norm = nn.GroupNorm(1, self.latent_dim)
self.track_feat_updater = nn.Sequential(
nn.Linear(self.latent_dim, self.latent_dim),
nn.GELU(),
)
self.vis_predictor = nn.Sequential(
nn.Linear(self.latent_dim, 1),
)
def forward_window(
self,
fmaps,
coords,
track_feat=None,
vis=None,
track_mask=None,
attention_mask=None,
iters=4,
):
# B = batch size
        # S = number of frames in the window
# N = number of tracks
# C = channels of a point feature vector
# E = positional embedding size
# LRR = local receptive field radius
# D = dimension of the transformer input tokens
# track_feat = B S N C
# vis = B S N 1
# track_mask = B S N 1
# attention_mask = B S N
B, S_init, N, __ = track_mask.shape
B, S, *_ = fmaps.shape
track_mask = F.pad(track_mask, (0, 0, 0, 0, 0, S - S_init), "constant")
track_mask_vis = (
torch.cat([track_mask, vis], dim=-1)
.permute(0, 2, 1, 3)
.reshape(B * N, S, 2)
)
corr_block = CorrBlock(
fmaps,
num_levels=4,
radius=3,
padding_mode="border",
)
sampled_pos_emb = (
sample_features4d(self.pos_emb.repeat(B, 1, 1, 1), coords[:, 0])
.reshape(B * N, self.input_dim)
.unsqueeze(1)
) # B E N -> (B N) 1 E
coord_preds = []
for __ in range(iters):
coords = coords.detach() # B S N 2
corr_block.corr(track_feat)
# Sample correlation features around each point
fcorrs = corr_block.sample(coords) # (B N) S LRR
# Get the flow embeddings
flows = (coords - coords[:, 0:1]).permute(0, 2, 1, 3).reshape(B * N, S, 2)
flow_emb = get_2d_embedding(flows, 64, cat_coords=True) # N S E
track_feat_ = track_feat.permute(0, 2, 1, 3).reshape(
B * N, S, self.latent_dim
)
transformer_input = torch.cat(
[flow_emb, fcorrs, track_feat_, track_mask_vis], dim=2
)
x = transformer_input + sampled_pos_emb + self.time_emb
x = x.view(B, N, S, -1) # (B N) S D -> B N S D
delta = self.updateformer(
x,
attention_mask.reshape(B * S, N), # B S N -> (B S) N
)
delta_coords = delta[..., :2].permute(0, 2, 1, 3)
coords = coords + delta_coords
coord_preds.append(coords * self.stride)
delta_feats_ = delta[..., 2:].reshape(B * N * S, self.latent_dim)
track_feat_ = track_feat.permute(0, 2, 1, 3).reshape(
B * N * S, self.latent_dim
)
track_feat_ = self.track_feat_updater(self.norm(delta_feats_)) + track_feat_
track_feat = track_feat_.reshape(B, N, S, self.latent_dim).permute(
0, 2, 1, 3
) # (B N S) C -> B S N C
vis_pred = self.vis_predictor(track_feat).reshape(B, S, N)
return coord_preds, vis_pred
def get_track_feat(self, fmaps, queried_frames, queried_coords):
sample_frames = queried_frames[:, None, :, None]
sample_coords = torch.cat(
[
sample_frames,
queried_coords[:, None],
],
dim=-1,
)
sample_track_feats = sample_features5d(fmaps, sample_coords)
return sample_track_feats
def init_video_online_processing(self):
self.online_ind = 0
self.online_track_feat = None
self.online_coords_predicted = None
self.online_vis_predicted = None
def forward(self, video, queries, iters=4, is_train=False, is_online=False):
"""Predict tracks
Args:
            video (FloatTensor[B, T, 3, H, W]): input videos.
queries (FloatTensor[B, N, 3]): point queries.
iters (int, optional): number of updates. Defaults to 4.
is_train (bool, optional): enables training mode. Defaults to False.
is_online (bool, optional): enables online mode. Defaults to False. Before enabling, call model.init_video_online_processing().
Returns:
- coords_predicted (FloatTensor[B, T, N, 2]):
- vis_predicted (FloatTensor[B, T, N]):
- train_data: `None` if `is_train` is false, otherwise:
- all_vis_predictions (List[FloatTensor[B, S, N, 1]]):
- all_coords_predictions (List[FloatTensor[B, S, N, 2]]):
- mask (BoolTensor[B, T, N]):
"""
B, T, C, H, W = video.shape
B, N, __ = queries.shape
S = self.window_len
device = queries.device
# B = batch size
# S = number of frames in the window of the padded video
# S_trimmed = actual number of frames in the window
# N = number of tracks
# C = color channels (3 for RGB)
# E = positional embedding size
# LRR = local receptive field radius
# D = dimension of the transformer input tokens
# video = B T C H W
# queries = B N 3
# coords_init = B S N 2
# vis_init = B S N 1
assert S >= 2 # A tracker needs at least two frames to track something
if is_online:
assert T <= S, "Online mode: video chunk must be <= window size."
assert (
self.online_ind is not None
), "Call model.init_video_online_processing() first."
assert not is_train, "Training not supported in online mode."
step = S // 2 # How much the sliding window moves at every step
video = 2 * (video / 255.0) - 1.0
# The first channel is the frame number
# The rest are the coordinates of points we want to track
queried_frames = queries[:, :, 0].long()
queried_coords = queries[..., 1:]
queried_coords = queried_coords / self.stride
# We store our predictions here
coords_predicted = torch.zeros((B, T, N, 2), device=device)
vis_predicted = torch.zeros((B, T, N), device=device)
if is_online:
if self.online_coords_predicted is None:
# Init online predictions with zeros
self.online_coords_predicted = coords_predicted
self.online_vis_predicted = vis_predicted
else:
# Pad online predictions with zeros for the current window
pad = min(step, T - step)
coords_predicted = F.pad(
self.online_coords_predicted, (0, 0, 0, 0, 0, pad), "constant"
)
vis_predicted = F.pad(
self.online_vis_predicted, (0, 0, 0, pad), "constant"
)
all_coords_predictions, all_vis_predictions = [], []
# Pad the video so that an integer number of sliding windows fit into it
# TODO: we may drop this requirement because the transformer should not care
# TODO: pad the features instead of the video
pad = (
S - T if is_online else (S - T % S) % S
) # We don't want to pad if T % S == 0
video = video.reshape(B, 1, T, C * H * W)
video_pad = video[:, :, -1:].repeat(1, 1, pad, 1)
video = torch.cat([video, video_pad], dim=2)
# Compute convolutional features for the video or for the current chunk in case of online mode
fmaps = self.fnet(video.reshape(-1, C, H, W)).reshape(
B, -1, self.latent_dim, H // self.stride, W // self.stride
)
# We compute track features
track_feat = self.get_track_feat(
fmaps,
queried_frames - self.online_ind if is_online else queried_frames,
queried_coords,
).repeat(1, S, 1, 1)
if is_online:
# We update track features for the current window
sample_frames = queried_frames[:, None, :, None] # B 1 N 1
left = 0 if self.online_ind == 0 else self.online_ind + step
right = self.online_ind + S
sample_mask = (sample_frames >= left) & (sample_frames < right)
if self.online_track_feat is None:
self.online_track_feat = torch.zeros_like(track_feat, device=device)
self.online_track_feat += track_feat * sample_mask
track_feat = self.online_track_feat.clone()
# We process ((num_windows - 1) * step + S) frames in total, so there are
# (ceil((T - S) / step) + 1) windows
num_windows = (T - S + step - 1) // step + 1
# We process only the current video chunk in the online mode
indices = [self.online_ind] if is_online else range(0, step * num_windows, step)
coords_init = queried_coords.reshape(B, 1, N, 2).expand(B, S, N, 2).float()
vis_init = torch.ones((B, S, N, 1), device=device).float() * 10
for ind in indices:
# We copy over coords and vis for tracks that are queried
# by the end of the previous window, which is ind + overlap
if ind > 0:
overlap = S - step
copy_over = (queried_frames < ind + overlap)[
:, None, :, None
] # B 1 N 1
coords_prev = torch.nn.functional.pad(
coords_predicted[:, ind : ind + overlap] / self.stride,
(0, 0, 0, 0, 0, step),
"replicate",
) # B S N 2
vis_prev = torch.nn.functional.pad(
vis_predicted[:, ind : ind + overlap, :, None].clone(),
(0, 0, 0, 0, 0, step),
"replicate",
) # B S N 1
coords_init = torch.where(
copy_over.expand_as(coords_init), coords_prev, coords_init
)
vis_init = torch.where(
copy_over.expand_as(vis_init), vis_prev, vis_init
)
# The attention mask is 1 for the spatio-temporal points within
# a track which is updated in the current window
attention_mask = (
(queried_frames < ind + S).reshape(B, 1, N).repeat(1, S, 1)
) # B S N
# The track mask is 1 for the spatio-temporal points that actually
        # need updating: only after being queried, and not if contained
# in a previous window
track_mask = (
queried_frames[:, None, :, None]
<= torch.arange(ind, ind + S, device=device)[None, :, None, None]
).contiguous() # B S N 1
if ind > 0:
track_mask[:, :overlap, :, :] = False
# Predict the coordinates and visibility for the current window
coords, vis = self.forward_window(
fmaps=fmaps if is_online else fmaps[:, ind : ind + S],
coords=coords_init,
track_feat=attention_mask.unsqueeze(-1) * track_feat,
vis=vis_init,
track_mask=track_mask,
attention_mask=attention_mask,
iters=iters,
)
S_trimmed = (
T if is_online else min(T - ind, S)
) # accounts for last window duration
coords_predicted[:, ind : ind + S] = coords[-1][:, :S_trimmed]
vis_predicted[:, ind : ind + S] = vis[:, :S_trimmed]
if is_train:
all_coords_predictions.append(
[coord[:, :S_trimmed] for coord in coords]
)
all_vis_predictions.append(torch.sigmoid(vis[:, :S_trimmed]))
if is_online:
self.online_ind += step
self.online_coords_predicted = coords_predicted
self.online_vis_predicted = vis_predicted
vis_predicted = torch.sigmoid(vis_predicted)
if is_train:
mask = (
queried_frames[:, None]
<= torch.arange(0, T, device=device)[None, :, None]
)
train_data = (all_coords_predictions, all_vis_predictions, mask)
else:
train_data = None
return coords_predicted, vis_predicted, train_data
class EfficientUpdateFormer(nn.Module):
"""
Transformer model that updates track estimates.
"""
def __init__(
self,
space_depth=6,
time_depth=6,
input_dim=320,
hidden_size=384,
num_heads=8,
output_dim=130,
mlp_ratio=4.0,
num_virtual_tracks=64,
add_space_attn=True,
linear_layer_for_vis_conf=False,
):
super().__init__()
self.out_channels = 2
self.num_heads = num_heads
self.hidden_size = hidden_size
self.input_transform = torch.nn.Linear(input_dim, hidden_size, bias=True)
if linear_layer_for_vis_conf:
self.flow_head = torch.nn.Linear(hidden_size, output_dim - 2, bias=True)
self.vis_conf_head = torch.nn.Linear(hidden_size, 2, bias=True)
else:
self.flow_head = torch.nn.Linear(hidden_size, output_dim, bias=True)
self.num_virtual_tracks = num_virtual_tracks
self.virual_tracks = nn.Parameter(
torch.randn(1, num_virtual_tracks, 1, hidden_size)
)
self.add_space_attn = add_space_attn
self.linear_layer_for_vis_conf = linear_layer_for_vis_conf
self.time_blocks = nn.ModuleList(
[
AttnBlock(
hidden_size,
num_heads,
mlp_ratio=mlp_ratio,
attn_class=Attention,
)
for _ in range(time_depth)
]
)
if add_space_attn:
self.space_virtual_blocks = nn.ModuleList(
[
AttnBlock(
hidden_size,
num_heads,
mlp_ratio=mlp_ratio,
attn_class=Attention,
)
for _ in range(space_depth)
]
)
self.space_point2virtual_blocks = nn.ModuleList(
[
CrossAttnBlock(
hidden_size, hidden_size, num_heads, mlp_ratio=mlp_ratio
)
for _ in range(space_depth)
]
)
self.space_virtual2point_blocks = nn.ModuleList(
[
CrossAttnBlock(
hidden_size, hidden_size, num_heads, mlp_ratio=mlp_ratio
)
for _ in range(space_depth)
]
)
assert len(self.time_blocks) >= len(self.space_virtual2point_blocks)
self.initialize_weights()
def initialize_weights(self):
def _basic_init(module):
if isinstance(module, nn.Linear):
torch.nn.init.xavier_uniform_(module.weight)
if module.bias is not None:
nn.init.constant_(module.bias, 0)
torch.nn.init.trunc_normal_(self.flow_head.weight, std=0.001)
if self.linear_layer_for_vis_conf:
torch.nn.init.trunc_normal_(self.vis_conf_head.weight, std=0.001)
def _trunc_init(module):
"""ViT weight initialization, original timm impl (for reproducibility)"""
if isinstance(module, nn.Linear):
torch.nn.init.trunc_normal_(module.weight, std=0.02)
if module.bias is not None:
nn.init.zeros_(module.bias)
self.apply(_basic_init)
def forward(self, input_tensor, mask=None, add_space_attn=True):
tokens = self.input_transform(input_tensor)
B, _, T, _ = tokens.shape
virtual_tokens = self.virual_tracks.repeat(B, 1, T, 1)
tokens = torch.cat([tokens, virtual_tokens], dim=1)
_, N, _, _ = tokens.shape
j = 0
layers = []
for i in range(len(self.time_blocks)):
time_tokens = tokens.contiguous().view(B * N, T, -1) # B N T C -> (B N) T C
time_tokens = self.time_blocks[i](time_tokens)
tokens = time_tokens.view(B, N, T, -1) # (B N) T C -> B N T C
if (
add_space_attn
and hasattr(self, "space_virtual_blocks")
and (i % (len(self.time_blocks) // len(self.space_virtual_blocks)) == 0)
):
space_tokens = (
tokens.permute(0, 2, 1, 3).contiguous().view(B * T, N, -1)
) # B N T C -> (B T) N C
point_tokens = space_tokens[:, : N - self.num_virtual_tracks]
virtual_tokens = space_tokens[:, N - self.num_virtual_tracks :]
virtual_tokens = self.space_virtual2point_blocks[j](
virtual_tokens, point_tokens, mask=mask
)
virtual_tokens = self.space_virtual_blocks[j](virtual_tokens)
point_tokens = self.space_point2virtual_blocks[j](
point_tokens, virtual_tokens, mask=mask
)
space_tokens = torch.cat([point_tokens, virtual_tokens], dim=1)
tokens = space_tokens.view(B, T, N, -1).permute(
0, 2, 1, 3
) # (B T) N C -> B N T C
j += 1
tokens = tokens[:, : N - self.num_virtual_tracks]
flow = self.flow_head(tokens)
if self.linear_layer_for_vis_conf:
vis_conf = self.vis_conf_head(tokens)
flow = torch.cat([flow, vis_conf], dim=-1)
return flow
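# Illustrative shape sketch (not part of the original module): the update
# transformer consumes per-track, per-frame tokens of width input_dim and
# returns per-track updates of width output_dim (for CoTracker2: 2 coordinate
# deltas plus latent-feature deltas). Virtual tracks are internal only.
def _updateformer_shape_example():
    former = EfficientUpdateFormer(
        space_depth=6,
        time_depth=6,
        input_dim=320,
        hidden_size=384,
        output_dim=130,
        num_virtual_tracks=64,
        add_space_attn=True,
    )
    x = torch.randn(1, 5, 8, 320)  # (batch, tracks, frames, input_dim)
    delta = former(x)
    assert delta.shape == (1, 5, 8, 130)  # virtual tracks are stripped
    return delta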
class CrossAttnBlock(nn.Module):
def __init__(
self, hidden_size, context_dim, num_heads=1, mlp_ratio=4.0, **block_kwargs
):
super().__init__()
self.norm1 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
self.norm_context = nn.LayerNorm(hidden_size)
self.cross_attn = Attention(
hidden_size,
context_dim=context_dim,
num_heads=num_heads,
qkv_bias=True,
**block_kwargs
)
self.norm2 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
mlp_hidden_dim = int(hidden_size * mlp_ratio)
approx_gelu = lambda: nn.GELU(approximate="tanh")
self.mlp = Mlp(
in_features=hidden_size,
hidden_features=mlp_hidden_dim,
act_layer=approx_gelu,
drop=0,
)
def forward(self, x, context, mask=None):
attn_bias = None
if mask is not None:
if mask.shape[1] == x.shape[1]:
mask = mask[:, None, :, None].expand(
-1, self.cross_attn.heads, -1, context.shape[1]
)
else:
mask = mask[:, None, None].expand(
-1, self.cross_attn.heads, x.shape[1], -1
)
max_neg_value = -torch.finfo(x.dtype).max
attn_bias = (~mask) * max_neg_value
x = x + self.cross_attn(
self.norm1(x), context=self.norm_context(context), attn_bias=attn_bias
)
x = x + self.mlp(self.norm2(x))
return x
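# Illustrative end-to-end sketch (not part of the original module); shapes follow
# the forward() docstring above. Queries are (frame_index, x, y) triplets in
# pixel coordinates of the input video.
def _cotracker2_inference_example():
    model = CoTracker2(window_len=8, stride=4, model_resolution=(384, 512))
    video = torch.randint(0, 255, (1, 8, 3, 384, 512)).float()  # B T C H W
    queries = torch.tensor([[[0.0, 100.0, 150.0], [2.0, 300.0, 200.0]]])  # B N 3
    coords, vis, _ = model(video, queries, iters=2)
    assert coords.shape == (1, 8, 2, 2) and vis.shape == (1, 8, 2)
    return coords, vis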
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
import torch
import torch.nn as nn
import torch.nn.functional as F
from cotracker.models.core.cotracker.cotracker3_online import CoTrackerThreeBase, posenc
torch.manual_seed(0)
class CoTrackerThreeOffline(CoTrackerThreeBase):
def __init__(self, **args):
super(CoTrackerThreeOffline, self).__init__(**args)
def forward(
self,
video,
queries,
iters=4,
is_train=False,
add_space_attn=True,
fmaps_chunk_size=200,
):
"""Predict tracks
Args:
            video (FloatTensor[B, T, 3, H, W]): input videos.
queries (FloatTensor[B, N, 3]): point queries.
iters (int, optional): number of updates. Defaults to 4.
is_train (bool, optional): enables training mode. Defaults to False.
Returns:
- coords_predicted (FloatTensor[B, T, N, 2]):
- vis_predicted (FloatTensor[B, T, N]):
- train_data: `None` if `is_train` is false, otherwise:
- all_vis_predictions (List[FloatTensor[B, S, N, 1]]):
- all_coords_predictions (List[FloatTensor[B, S, N, 2]]):
- mask (BoolTensor[B, T, N]):
"""
B, T, C, H, W = video.shape
device = queries.device
assert H % self.stride == 0 and W % self.stride == 0
B, N, __ = queries.shape
# B = batch size
# S_trimmed = actual number of frames in the window
# N = number of tracks
# C = color channels (3 for RGB)
# E = positional embedding size
# LRR = local receptive field radius
# D = dimension of the transformer input tokens
# video = B T C H W
# queries = B N 3
# coords_init = B T N 2
# vis_init = B T N 1
        assert T >= 1  # Need at least one frame
video = 2 * (video / 255.0) - 1.0
dtype = video.dtype
queried_frames = queries[:, :, 0].long()
queried_coords = queries[..., 1:3]
queried_coords = queried_coords / self.stride
# We store our predictions here
all_coords_predictions, all_vis_predictions, all_confidence_predictions = (
[],
[],
[],
)
C_ = C
H4, W4 = H // self.stride, W // self.stride
# Compute convolutional features for the video or for the current chunk in case of online mode
if T > fmaps_chunk_size:
fmaps = []
for t in range(0, T, fmaps_chunk_size):
video_chunk = video[:, t : t + fmaps_chunk_size]
fmaps_chunk = self.fnet(video_chunk.reshape(-1, C_, H, W))
T_chunk = video_chunk.shape[1]
C_chunk, H_chunk, W_chunk = fmaps_chunk.shape[1:]
fmaps.append(fmaps_chunk.reshape(B, T_chunk, C_chunk, H_chunk, W_chunk))
fmaps = torch.cat(fmaps, dim=1).reshape(-1, C_chunk, H_chunk, W_chunk)
else:
fmaps = self.fnet(video.reshape(-1, C_, H, W))
fmaps = fmaps.permute(0, 2, 3, 1)
fmaps = fmaps / torch.sqrt(
torch.maximum(
torch.sum(torch.square(fmaps), axis=-1, keepdims=True),
torch.tensor(1e-12, device=fmaps.device),
)
)
fmaps = fmaps.permute(0, 3, 1, 2).reshape(
B, -1, self.latent_dim, H // self.stride, W // self.stride
)
fmaps = fmaps.to(dtype)
# We compute track features
fmaps_pyramid = []
track_feat_pyramid = []
track_feat_support_pyramid = []
fmaps_pyramid.append(fmaps)
for i in range(self.corr_levels - 1):
fmaps_ = fmaps.reshape(
B * T, self.latent_dim, fmaps.shape[-2], fmaps.shape[-1]
)
fmaps_ = F.avg_pool2d(fmaps_, 2, stride=2)
fmaps = fmaps_.reshape(
B, T, self.latent_dim, fmaps_.shape[-2], fmaps_.shape[-1]
)
fmaps_pyramid.append(fmaps)
for i in range(self.corr_levels):
track_feat, track_feat_support = self.get_track_feat(
fmaps_pyramid[i],
queried_frames,
queried_coords / 2**i,
support_radius=self.corr_radius,
)
track_feat_pyramid.append(track_feat.repeat(1, T, 1, 1))
track_feat_support_pyramid.append(track_feat_support.unsqueeze(1))
D_coords = 2
coord_preds, vis_preds, confidence_preds = [], [], []
vis = torch.zeros((B, T, N), device=device).float()
confidence = torch.zeros((B, T, N), device=device).float()
coords = queried_coords.reshape(B, 1, N, 2).expand(B, T, N, 2).float()
r = 2 * self.corr_radius + 1
for it in range(iters):
coords = coords.detach() # B T N 2
coords_init = coords.view(B * T, N, 2)
corr_embs = []
corr_feats = []
for i in range(self.corr_levels):
corr_feat = self.get_correlation_feat(
fmaps_pyramid[i], coords_init / 2**i
)
track_feat_support = (
track_feat_support_pyramid[i]
.view(B, 1, r, r, N, self.latent_dim)
.squeeze(1)
.permute(0, 3, 1, 2, 4)
)
corr_volume = torch.einsum(
"btnhwc,bnijc->btnhwij", corr_feat, track_feat_support
)
corr_emb = self.corr_mlp(corr_volume.reshape(B * T * N, r * r * r * r))
corr_embs.append(corr_emb)
corr_embs = torch.cat(corr_embs, dim=-1)
corr_embs = corr_embs.view(B, T, N, corr_embs.shape[-1])
transformer_input = [vis[..., None], confidence[..., None], corr_embs]
rel_coords_forward = coords[:, :-1] - coords[:, 1:]
rel_coords_backward = coords[:, 1:] - coords[:, :-1]
rel_coords_forward = torch.nn.functional.pad(
rel_coords_forward, (0, 0, 0, 0, 0, 1)
)
rel_coords_backward = torch.nn.functional.pad(
rel_coords_backward, (0, 0, 0, 0, 1, 0)
)
scale = (
torch.tensor(
[self.model_resolution[1], self.model_resolution[0]],
device=coords.device,
)
/ self.stride
)
rel_coords_forward = rel_coords_forward / scale
rel_coords_backward = rel_coords_backward / scale
rel_pos_emb_input = posenc(
torch.cat([rel_coords_forward, rel_coords_backward], dim=-1),
min_deg=0,
max_deg=10,
            )  # batch, num_frames, num_points, 84
transformer_input.append(rel_pos_emb_input)
x = (
torch.cat(transformer_input, dim=-1)
.permute(0, 2, 1, 3)
.reshape(B * N, T, -1)
)
x = x + self.interpolate_time_embed(x, T)
x = x.view(B, N, T, -1) # (B N) T D -> B N T D
delta = self.updateformer(
x,
add_space_attn=add_space_attn,
)
delta_coords = delta[..., :D_coords].permute(0, 2, 1, 3)
delta_vis = delta[..., D_coords].permute(0, 2, 1)
delta_confidence = delta[..., D_coords + 1].permute(0, 2, 1)
vis = vis + delta_vis
confidence = confidence + delta_confidence
coords = coords + delta_coords
coords_append = coords.clone()
coords_append[..., :2] = coords_append[..., :2] * float(self.stride)
coord_preds.append(coords_append)
vis_preds.append(torch.sigmoid(vis))
confidence_preds.append(torch.sigmoid(confidence))
if is_train:
all_coords_predictions.append([coord[..., :2] for coord in coord_preds])
all_vis_predictions.append(vis_preds)
all_confidence_predictions.append(confidence_preds)
if is_train:
train_data = (
all_coords_predictions,
all_vis_predictions,
all_confidence_predictions,
torch.ones_like(vis_preds[-1], device=vis_preds[-1].device),
)
else:
train_data = None
return coord_preds[-1][..., :2], vis_preds[-1], confidence_preds[-1], train_data
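# Illustrative end-to-end sketch (not part of the original module): the offline
# model takes (frame_index, x, y) queries in pixel coordinates and returns
# per-frame coordinates, visibility and confidence for every query.
def _cotracker3_offline_example():
    model = CoTrackerThreeOffline(
        window_len=8, stride=4, model_resolution=(384, 512)
    )
    video = torch.randint(0, 255, (1, 8, 3, 384, 512)).float()  # B T C H W
    queries = torch.tensor([[[0.0, 120.0, 96.0], [3.0, 256.0, 200.0]]])  # B N 3
    coords, vis, conf, _ = model(video, queries, iters=2)
    assert coords.shape == (1, 8, 2, 2) and vis.shape == (1, 8, 2)
    return coords, vis, conf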
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
import torch
import torch.nn as nn
import torch.nn.functional as F
from cotracker.models.core.model_utils import sample_features5d, bilinear_sampler
from cotracker.models.core.embeddings import get_1d_sincos_pos_embed_from_grid
from cotracker.models.core.cotracker.blocks import Mlp, BasicEncoder
from cotracker.models.core.cotracker.cotracker import EfficientUpdateFormer
torch.manual_seed(0)
def posenc(x, min_deg, max_deg):
"""Cat x with a positional encoding of x with scales 2^[min_deg, max_deg-1].
Instead of computing [sin(x), cos(x)], we use the trig identity
cos(x) = sin(x + pi/2) and do one vectorized call to sin([x, x+pi/2]).
Args:
x: torch.Tensor, variables to be encoded. Note that x should be in [-pi, pi].
min_deg: int, the minimum (inclusive) degree of the encoding.
max_deg: int, the maximum (exclusive) degree of the encoding.
Returns:
encoded: torch.Tensor, encoded variables.
"""
if min_deg == max_deg:
return x
scales = torch.tensor(
[2**i for i in range(min_deg, max_deg)], dtype=x.dtype, device=x.device
)
xb = (x[..., None, :] * scales[:, None]).reshape(list(x.shape[:-1]) + [-1])
four_feat = torch.sin(torch.cat([xb, xb + 0.5 * torch.pi], dim=-1))
return torch.cat([x] + [four_feat], dim=-1)
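# Worked example (illustrative, not part of the original module): with min_deg=0
# and max_deg=10 there are 10 frequency scales, each adding a sin and a cos per
# input channel, so d channels become d + 2 * 10 * d. The 4-channel
# forward/backward relative coordinates used by the update transformer therefore
# become 84 channels.
def _posenc_width_example():
    x = torch.zeros(1, 8, 5, 4)  # (..., forward dx dy + backward dx dy)
    enc = posenc(x, min_deg=0, max_deg=10)
    assert enc.shape == (1, 8, 5, 4 + 2 * 10 * 4)  # 84 channels
    return enc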
class CoTrackerThreeBase(nn.Module):
def __init__(
self,
window_len=8,
stride=4,
corr_radius=3,
corr_levels=4,
num_virtual_tracks=64,
model_resolution=(384, 512),
add_space_attn=True,
linear_layer_for_vis_conf=True,
):
super(CoTrackerThreeBase, self).__init__()
self.window_len = window_len
self.stride = stride
self.corr_radius = corr_radius
self.corr_levels = corr_levels
self.hidden_dim = 256
self.latent_dim = 128
self.linear_layer_for_vis_conf = linear_layer_for_vis_conf
self.fnet = BasicEncoder(input_dim=3, output_dim=self.latent_dim, stride=stride)
highres_dim = 128
lowres_dim = 256
self.num_virtual_tracks = num_virtual_tracks
self.model_resolution = model_resolution
self.input_dim = 1110
self.updateformer = EfficientUpdateFormer(
space_depth=3,
time_depth=3,
input_dim=self.input_dim,
hidden_size=384,
output_dim=4,
mlp_ratio=4.0,
num_virtual_tracks=num_virtual_tracks,
add_space_attn=add_space_attn,
linear_layer_for_vis_conf=linear_layer_for_vis_conf,
)
self.corr_mlp = Mlp(in_features=49 * 49, hidden_features=384, out_features=256)
time_grid = torch.linspace(0, window_len - 1, window_len).reshape(
1, window_len, 1
)
self.register_buffer(
"time_emb", get_1d_sincos_pos_embed_from_grid(self.input_dim, time_grid[0])
)
def get_support_points(self, coords, r, reshape_back=True):
B, _, N, _ = coords.shape
device = coords.device
centroid_lvl = coords.reshape(B, N, 1, 1, 3)
dx = torch.linspace(-r, r, 2 * r + 1, device=device)
dy = torch.linspace(-r, r, 2 * r + 1, device=device)
xgrid, ygrid = torch.meshgrid(dy, dx, indexing="ij")
zgrid = torch.zeros_like(xgrid, device=device)
delta = torch.stack([zgrid, xgrid, ygrid], axis=-1)
delta_lvl = delta.view(1, 1, 2 * r + 1, 2 * r + 1, 3)
coords_lvl = centroid_lvl + delta_lvl
if reshape_back:
return coords_lvl.reshape(B, N, (2 * r + 1) ** 2, 3).permute(0, 2, 1, 3)
else:
return coords_lvl
def get_track_feat(self, fmaps, queried_frames, queried_coords, support_radius=0):
sample_frames = queried_frames[:, None, :, None]
sample_coords = torch.cat(
[
sample_frames,
queried_coords[:, None],
],
dim=-1,
)
support_points = self.get_support_points(sample_coords, support_radius)
support_track_feats = sample_features5d(fmaps, support_points)
return (
support_track_feats[:, None, support_track_feats.shape[1] // 2],
support_track_feats,
)
def get_correlation_feat(self, fmaps, queried_coords):
B, T, D, H_, W_ = fmaps.shape
N = queried_coords.shape[1]
r = self.corr_radius
sample_coords = torch.cat(
[torch.zeros_like(queried_coords[..., :1]), queried_coords], dim=-1
)[:, None]
support_points = self.get_support_points(sample_coords, r, reshape_back=False)
correlation_feat = bilinear_sampler(
fmaps.reshape(B * T, D, 1, H_, W_), support_points
)
return correlation_feat.view(B, T, D, N, (2 * r + 1), (2 * r + 1)).permute(
0, 1, 3, 4, 5, 2
)
def interpolate_time_embed(self, x, t):
previous_dtype = x.dtype
T = self.time_emb.shape[1]
if t == T:
return self.time_emb
time_emb = self.time_emb.float()
time_emb = F.interpolate(
time_emb.permute(0, 2, 1), size=t, mode="linear"
).permute(0, 2, 1)
return time_emb.to(previous_dtype)
class CoTrackerThreeOnline(CoTrackerThreeBase):
def __init__(self, **args):
super(CoTrackerThreeOnline, self).__init__(**args)
def init_video_online_processing(self):
self.online_ind = 0
self.online_track_feat = [None] * self.corr_levels
self.online_track_support = [None] * self.corr_levels
self.online_coords_predicted = None
self.online_vis_predicted = None
self.online_conf_predicted = None
def forward_window(
self,
fmaps_pyramid,
coords,
track_feat_support_pyramid,
vis=None,
conf=None,
attention_mask=None,
iters=4,
add_space_attn=False,
):
B, S, *_ = fmaps_pyramid[0].shape
N = coords.shape[2]
r = 2 * self.corr_radius + 1
coord_preds, vis_preds, conf_preds = [], [], []
for it in range(iters):
coords = coords.detach() # B T N 2
coords_init = coords.view(B * S, N, 2)
corr_embs = []
corr_feats = []
for i in range(self.corr_levels):
corr_feat = self.get_correlation_feat(
fmaps_pyramid[i], coords_init / 2**i
)
track_feat_support = (
track_feat_support_pyramid[i]
.view(B, 1, r, r, N, self.latent_dim)
.squeeze(1)
.permute(0, 3, 1, 2, 4)
)
corr_volume = torch.einsum(
"btnhwc,bnijc->btnhwij", corr_feat, track_feat_support
)
corr_emb = self.corr_mlp(corr_volume.reshape(B * S * N, r * r * r * r))
corr_embs.append(corr_emb)
corr_embs = torch.cat(corr_embs, dim=-1)
corr_embs = corr_embs.view(B, S, N, corr_embs.shape[-1])
transformer_input = [vis, conf, corr_embs]
rel_coords_forward = coords[:, :-1] - coords[:, 1:]
rel_coords_backward = coords[:, 1:] - coords[:, :-1]
rel_coords_forward = torch.nn.functional.pad(
rel_coords_forward, (0, 0, 0, 0, 0, 1)
)
rel_coords_backward = torch.nn.functional.pad(
rel_coords_backward, (0, 0, 0, 0, 1, 0)
)
scale = (
torch.tensor(
[self.model_resolution[1], self.model_resolution[0]],
device=coords.device,
)
/ self.stride
)
rel_coords_forward = rel_coords_forward / scale
rel_coords_backward = rel_coords_backward / scale
rel_pos_emb_input = posenc(
torch.cat([rel_coords_forward, rel_coords_backward], dim=-1),
min_deg=0,
max_deg=10,
            )  # batch, num_frames, num_points, 84
transformer_input.append(rel_pos_emb_input)
x = (
torch.cat(transformer_input, dim=-1)
.permute(0, 2, 1, 3)
.reshape(B * N, S, -1)
)
x = x + self.interpolate_time_embed(x, S)
x = x.view(B, N, S, -1) # (B N) T D -> B N T D
delta = self.updateformer(x, add_space_attn=add_space_attn)
delta_coords = delta[..., :2].permute(0, 2, 1, 3)
delta_vis = delta[..., 2:3].permute(0, 2, 1, 3)
delta_conf = delta[..., 3:].permute(0, 2, 1, 3)
vis = vis + delta_vis
conf = conf + delta_conf
coords = coords + delta_coords
coord_preds.append(coords[..., :2] * float(self.stride))
vis_preds.append(vis[..., 0])
conf_preds.append(conf[..., 0])
return coord_preds, vis_preds, conf_preds
def forward(
self,
video,
queries,
iters=4,
is_train=False,
add_space_attn=True,
fmaps_chunk_size=200,
is_online=False,
):
"""Predict tracks
Args:
            video (FloatTensor[B, T, 3, H, W]): input videos.
queries (FloatTensor[B, N, 3]): point queries.
iters (int, optional): number of updates. Defaults to 4.
is_train (bool, optional): enables training mode. Defaults to False.
Returns:
- coords_predicted (FloatTensor[B, T, N, 2]):
- vis_predicted (FloatTensor[B, T, N]):
- train_data: `None` if `is_train` is false, otherwise:
- all_vis_predictions (List[FloatTensor[B, S, N, 1]]):
- all_coords_predictions (List[FloatTensor[B, S, N, 2]]):
- mask (BoolTensor[B, T, N]):
"""
B, T, C, H, W = video.shape
device = queries.device
assert H % self.stride == 0 and W % self.stride == 0
B, N, __ = queries.shape
# B = batch size
# S_trimmed = actual number of frames in the window
# N = number of tracks
# C = color channels (3 for RGB)
# E = positional embedding size
# LRR = local receptive field radius
# D = dimension of the transformer input tokens
# video = B T C H W
# queries = B N 3
# coords_init = B T N 2
# vis_init = B T N 1
S = self.window_len
assert S >= 2 # A tracker needs at least two frames to track something
if is_online:
assert T <= S, "Online mode: video chunk must be <= window size."
assert (
self.online_ind is not None
), "Call model.init_video_online_processing() first."
assert not is_train, "Training not supported in online mode."
step = S // 2 # How much the sliding window moves at every step
video = 2 * (video / 255.0) - 1.0
pad = (
S - T if is_online else (S - T % S) % S
) # We don't want to pad if T % S == 0
video = video.reshape(B, 1, T, C * H * W)
if pad > 0:
padding_tensor = video[:, :, -1:, :].expand(B, 1, pad, C * H * W)
video = torch.cat([video, padding_tensor], dim=2)
video = video.reshape(B, -1, C, H, W)
T_pad = video.shape[1]
# The first channel is the frame number
# The rest are the coordinates of points we want to track
dtype = video.dtype
queried_frames = queries[:, :, 0].long()
queried_coords = queries[..., 1:3]
queried_coords = queried_coords / self.stride
# We store our predictions here
coords_predicted = torch.zeros((B, T, N, 2), device=device)
vis_predicted = torch.zeros((B, T, N), device=device)
conf_predicted = torch.zeros((B, T, N), device=device)
if is_online:
if self.online_coords_predicted is None:
# Init online predictions with zeros
self.online_coords_predicted = coords_predicted
self.online_vis_predicted = vis_predicted
self.online_conf_predicted = conf_predicted
else:
# Pad online predictions with zeros for the current window
pad = min(step, T - step)
coords_predicted = F.pad(
self.online_coords_predicted, (0, 0, 0, 0, 0, pad), "constant"
)
vis_predicted = F.pad(
self.online_vis_predicted, (0, 0, 0, pad), "constant"
)
conf_predicted = F.pad(
self.online_conf_predicted, (0, 0, 0, pad), "constant"
)
# We store our predictions here
all_coords_predictions, all_vis_predictions, all_confidence_predictions = (
[],
[],
[],
)
C_ = C
H4, W4 = H // self.stride, W // self.stride
# Compute convolutional features for the video or for the current chunk in case of online mode
if (not is_train) and (T > fmaps_chunk_size):
fmaps = []
for t in range(0, T, fmaps_chunk_size):
video_chunk = video[:, t : t + fmaps_chunk_size]
fmaps_chunk = self.fnet(video_chunk.reshape(-1, C_, H, W))
T_chunk = video_chunk.shape[1]
C_chunk, H_chunk, W_chunk = fmaps_chunk.shape[1:]
fmaps.append(fmaps_chunk.reshape(B, T_chunk, C_chunk, H_chunk, W_chunk))
fmaps = torch.cat(fmaps, dim=1).reshape(-1, C_chunk, H_chunk, W_chunk)
else:
fmaps = self.fnet(video.reshape(-1, C_, H, W))
fmaps = fmaps.permute(0, 2, 3, 1)
fmaps = fmaps / torch.sqrt(
torch.maximum(
torch.sum(torch.square(fmaps), axis=-1, keepdims=True),
torch.tensor(1e-12, device=fmaps.device),
)
)
fmaps = fmaps.permute(0, 3, 1, 2).reshape(
B, -1, self.latent_dim, H // self.stride, W // self.stride
)
fmaps = fmaps.to(dtype)
# We compute track features
fmaps_pyramid = []
track_feat_pyramid = []
track_feat_support_pyramid = []
fmaps_pyramid.append(fmaps)
for i in range(self.corr_levels - 1):
fmaps_ = fmaps.reshape(
B * T_pad, self.latent_dim, fmaps.shape[-2], fmaps.shape[-1]
)
fmaps_ = F.avg_pool2d(fmaps_, 2, stride=2)
fmaps = fmaps_.reshape(
B, T_pad, self.latent_dim, fmaps_.shape[-2], fmaps_.shape[-1]
)
fmaps_pyramid.append(fmaps)
if is_online:
sample_frames = queried_frames[:, None, :, None] # B 1 N 1
left = 0 if self.online_ind == 0 else self.online_ind + step
right = self.online_ind + S
sample_mask = (sample_frames >= left) & (sample_frames < right)
for i in range(self.corr_levels):
track_feat, track_feat_support = self.get_track_feat(
fmaps_pyramid[i],
queried_frames - self.online_ind if is_online else queried_frames,
queried_coords / 2**i,
support_radius=self.corr_radius,
)
if is_online:
if self.online_track_feat[i] is None:
self.online_track_feat[i] = torch.zeros_like(
track_feat, device=device
)
self.online_track_support[i] = torch.zeros_like(
track_feat_support, device=device
)
self.online_track_feat[i] += track_feat * sample_mask
self.online_track_support[i] += track_feat_support * sample_mask
track_feat_pyramid.append(
self.online_track_feat[i].repeat(1, T_pad, 1, 1)
)
track_feat_support_pyramid.append(
self.online_track_support[i].unsqueeze(1)
)
else:
track_feat_pyramid.append(track_feat.repeat(1, T_pad, 1, 1))
track_feat_support_pyramid.append(track_feat_support.unsqueeze(1))
D_coords = 2
coord_preds, vis_preds, confidence_preds = [], [], []
vis_init = torch.zeros((B, S, N, 1), device=device).float()
conf_init = torch.zeros((B, S, N, 1), device=device).float()
coords_init = queried_coords.reshape(B, 1, N, 2).expand(B, S, N, 2).float()
num_windows = (T - S + step - 1) // step + 1
# We process only the current video chunk in the online mode
indices = [self.online_ind] if is_online else range(0, step * num_windows, step)
for ind in indices:
if ind > 0:
overlap = S - step
copy_over = (queried_frames < ind + overlap)[
:, None, :, None
] # B 1 N 1
coords_prev = coords_predicted[:, ind : ind + overlap] / self.stride
padding_tensor = coords_prev[:, -1:, :, :].expand(-1, step, -1, -1)
coords_prev = torch.cat([coords_prev, padding_tensor], dim=1)
vis_prev = vis_predicted[:, ind : ind + overlap, :, None].clone()
padding_tensor = vis_prev[:, -1:, :, :].expand(-1, step, -1, -1)
vis_prev = torch.cat([vis_prev, padding_tensor], dim=1)
conf_prev = conf_predicted[:, ind : ind + overlap, :, None].clone()
padding_tensor = conf_prev[:, -1:, :, :].expand(-1, step, -1, -1)
conf_prev = torch.cat([conf_prev, padding_tensor], dim=1)
coords_init = torch.where(
copy_over.expand_as(coords_init), coords_prev, coords_init
)
vis_init = torch.where(
copy_over.expand_as(vis_init), vis_prev, vis_init
)
conf_init = torch.where(
copy_over.expand_as(conf_init), conf_prev, conf_init
)
attention_mask = (queried_frames < ind + S).reshape(B, 1, N) # B S N
coords, viss, confs = self.forward_window(
fmaps_pyramid=(
fmaps_pyramid
if is_online
else [fmap[:, ind : ind + S] for fmap in fmaps_pyramid]
),
coords=coords_init,
track_feat_support_pyramid=[
attention_mask[:, None, :, :, None] * tfeat
for tfeat in track_feat_support_pyramid
],
vis=vis_init,
conf=conf_init,
attention_mask=attention_mask.repeat(1, S, 1),
iters=iters,
add_space_attn=add_space_attn,
)
S_trimmed = (
T if is_online else min(T - ind, S)
) # accounts for last window duration
coords_predicted[:, ind : ind + S] = coords[-1][:, :S_trimmed]
vis_predicted[:, ind : ind + S] = viss[-1][:, :S_trimmed]
conf_predicted[:, ind : ind + S] = confs[-1][:, :S_trimmed]
if is_train:
all_coords_predictions.append(
[coord[:, :S_trimmed] for coord in coords]
)
all_vis_predictions.append(
[torch.sigmoid(vis[:, :S_trimmed]) for vis in viss]
)
all_confidence_predictions.append(
[torch.sigmoid(conf[:, :S_trimmed]) for conf in confs]
)
if is_online:
self.online_ind += step
self.online_coords_predicted = coords_predicted
self.online_vis_predicted = vis_predicted
self.online_conf_predicted = conf_predicted
vis_predicted = torch.sigmoid(vis_predicted)
conf_predicted = torch.sigmoid(conf_predicted)
if is_train:
valid_mask = (
queried_frames[:, None]
<= torch.arange(0, T, device=device)[None, :, None]
)
train_data = (
all_coords_predictions,
all_vis_predictions,
all_confidence_predictions,
valid_mask,
)
else:
train_data = None
return coords_predicted, vis_predicted, conf_predicted, train_data
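# Illustrative online-mode sketch (not part of the original module): the model is
# driven with overlapping window_len-frame chunks of a longer video, stepping by
# window_len // 2; predictions accumulate inside the module between calls. The
# driving loop below is a minimal sketch, not the official predictor wrapper.
def _cotracker3_online_example():
    model = CoTrackerThreeOnline(window_len=8, stride=4, model_resolution=(384, 512))
    video = torch.randint(0, 255, (1, 16, 3, 384, 512)).float()  # B T C H W
    queries = torch.tensor([[[0.0, 120.0, 96.0], [1.0, 256.0, 200.0]]])  # B N 3
    model.init_video_online_processing()
    step = model.window_len // 2
    for start in range(0, video.shape[1] - model.window_len + 1, step):
        chunk = video[:, start : start + model.window_len]
        coords, vis, conf, _ = model(chunk, queries, iters=2, is_online=True)
    return coords, vis, conf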
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
import torch
import torch.nn.functional as F
from cotracker.models.core.model_utils import reduce_masked_mean
import torch.nn as nn
from typing import List
def sequence_loss(
flow_preds,
flow_gt,
valids,
vis=None,
gamma=0.8,
add_huber_loss=False,
loss_only_for_visible=False,
):
"""Loss function defined over sequence of flow predictions"""
total_flow_loss = 0.0
for j in range(len(flow_gt)):
B, S, N, D = flow_gt[j].shape
B, S2, N = valids[j].shape
assert S == S2
n_predictions = len(flow_preds[j])
flow_loss = 0.0
for i in range(n_predictions):
i_weight = gamma ** (n_predictions - i - 1)
flow_pred = flow_preds[j][i]
if add_huber_loss:
i_loss = huber_loss(flow_pred, flow_gt[j], delta=6.0)
else:
i_loss = (flow_pred - flow_gt[j]).abs() # B, S, N, 2
i_loss = torch.mean(i_loss, dim=3) # B, S, N
valid_ = valids[j].clone()
if loss_only_for_visible:
valid_ = valid_ * vis[j]
flow_loss += i_weight * reduce_masked_mean(i_loss, valid_)
flow_loss = flow_loss / n_predictions
total_flow_loss += flow_loss
return total_flow_loss / len(flow_gt)
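# Illustrative usage sketch (not part of the original module): flow_preds is a
# list over windows, each holding the per-refinement-iteration predictions;
# later iterations receive exponentially larger weight via gamma.
def _sequence_loss_example():
    B, S, N = 1, 8, 4
    flow_gt = [torch.zeros(B, S, N, 2)]
    valids = [torch.ones(B, S, N)]
    flow_preds = [[torch.ones(B, S, N, 2), 0.5 * torch.ones(B, S, N, 2)]]
    loss = sequence_loss(flow_preds, flow_gt, valids, gamma=0.8)
    return loss  # scalar: (0.8 * 1.0 + 1.0 * 0.5) / 2 for the single window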
def huber_loss(x, y, delta=1.0):
"""Calculate element-wise Huber loss between x and y"""
diff = x - y
abs_diff = diff.abs()
flag = (abs_diff <= delta).float()
return flag * 0.5 * diff**2 + (1 - flag) * delta * (abs_diff - 0.5 * delta)
def sequence_BCE_loss(vis_preds, vis_gts):
total_bce_loss = 0.0
for j in range(len(vis_preds)):
n_predictions = len(vis_preds[j])
bce_loss = 0.0
for i in range(n_predictions):
vis_loss = F.binary_cross_entropy(vis_preds[j][i], vis_gts[j])
bce_loss += vis_loss
bce_loss = bce_loss / n_predictions
total_bce_loss += bce_loss
return total_bce_loss / len(vis_preds)
def sequence_prob_loss(
tracks: torch.Tensor,
confidence: torch.Tensor,
target_points: torch.Tensor,
visibility: torch.Tensor,
expected_dist_thresh: float = 12.0,
):
"""Loss for classifying if a point is within pixel threshold of its target."""
# Points with an error larger than 12 pixels are likely to be useless; marking
# them as occluded will actually improve Jaccard metrics and give
# qualitatively better results.
total_logprob_loss = 0.0
for j in range(len(tracks)):
n_predictions = len(tracks[j])
logprob_loss = 0.0
for i in range(n_predictions):
err = torch.sum((tracks[j][i].detach() - target_points[j]) ** 2, dim=-1)
valid = (err <= expected_dist_thresh**2).float()
logprob = F.binary_cross_entropy(confidence[j][i], valid, reduction="none")
logprob *= visibility[j]
logprob = torch.mean(logprob, dim=[1, 2])
logprob_loss += logprob
logprob_loss = logprob_loss / n_predictions
total_logprob_loss += logprob_loss
return total_logprob_loss / len(tracks)
def masked_mean(data: torch.Tensor, mask: torch.Tensor | None, dim: List[int]):
if mask is None:
return data.mean(dim=dim, keepdim=True)
mask = mask.float()
mask_sum = torch.sum(mask, dim=dim, keepdim=True)
mask_mean = torch.sum(data * mask, dim=dim, keepdim=True) / torch.clamp(
mask_sum, min=1.0
)
return mask_mean
def masked_mean_var(data: torch.Tensor, mask: torch.Tensor, dim: List[int]):
if mask is None:
return data.mean(dim=dim, keepdim=True), data.var(dim=dim, keepdim=True)
mask = mask.float()
mask_sum = torch.sum(mask, dim=dim, keepdim=True)
mask_mean = torch.sum(data * mask, dim=dim, keepdim=True) / torch.clamp(
mask_sum, min=1.0
)
mask_var = torch.sum(
mask * (data - mask_mean) ** 2, dim=dim, keepdim=True
) / torch.clamp(mask_sum, min=1.0)
return mask_mean.squeeze(dim), mask_var.squeeze(dim)
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
from typing import Tuple, Union
import torch
def get_2d_sincos_pos_embed(
embed_dim: int, grid_size: Union[int, Tuple[int, int]]
) -> torch.Tensor:
"""
This function initializes a grid and generates a 2D positional embedding using sine and cosine functions.
It is a wrapper of get_2d_sincos_pos_embed_from_grid.
Args:
- embed_dim: The embedding dimension.
- grid_size: The grid size.
Returns:
- pos_embed: The generated 2D positional embedding.
"""
if isinstance(grid_size, tuple):
grid_size_h, grid_size_w = grid_size
else:
grid_size_h = grid_size_w = grid_size
grid_h = torch.arange(grid_size_h, dtype=torch.float)
grid_w = torch.arange(grid_size_w, dtype=torch.float)
grid = torch.meshgrid(grid_w, grid_h, indexing="xy")
grid = torch.stack(grid, dim=0)
grid = grid.reshape([2, 1, grid_size_h, grid_size_w])
pos_embed = get_2d_sincos_pos_embed_from_grid(embed_dim, grid)
return pos_embed.reshape(1, grid_size_h, grid_size_w, -1).permute(0, 3, 1, 2)
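# Illustrative shape sketch (not part of the original module): the embedding is
# returned channels-first, (1, embed_dim, H, W), so it can be registered as a
# buffer and sampled like a feature map (as CoTracker2 does with pos_emb).
def _pos_embed_shape_example():
    pos_embed = get_2d_sincos_pos_embed(embed_dim=456, grid_size=(96, 128))
    assert pos_embed.shape == (1, 456, 96, 128)
    return pos_embed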
def get_2d_sincos_pos_embed_from_grid(
embed_dim: int, grid: torch.Tensor
) -> torch.Tensor:
"""
This function generates a 2D positional embedding from a given grid using sine and cosine functions.
Args:
- embed_dim: The embedding dimension.
- grid: The grid to generate the embedding from.
Returns:
- emb: The generated 2D positional embedding.
"""
assert embed_dim % 2 == 0
# use half of dimensions to encode grid_h
emb_h = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[0]) # (H*W, D/2)
emb_w = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[1]) # (H*W, D/2)
emb = torch.cat([emb_h, emb_w], dim=2) # (H*W, D)
return emb
def get_1d_sincos_pos_embed_from_grid(
embed_dim: int, pos: torch.Tensor
) -> torch.Tensor:
"""
This function generates a 1D positional embedding from a given grid using sine and cosine functions.
Args:
- embed_dim: The embedding dimension.
- pos: The position to generate the embedding from.
Returns:
- emb: The generated 1D positional embedding.
"""
assert embed_dim % 2 == 0
omega = torch.arange(embed_dim // 2, dtype=torch.double)
omega /= embed_dim / 2.0
omega = 1.0 / 10000**omega # (D/2,)
pos = pos.reshape(-1) # (M,)
out = torch.einsum("m,d->md", pos, omega) # (M, D/2), outer product
emb_sin = torch.sin(out) # (M, D/2)
emb_cos = torch.cos(out) # (M, D/2)
emb = torch.cat([emb_sin, emb_cos], dim=1) # (M, D)
return emb[None].float()
def get_2d_embedding(xy: torch.Tensor, C: int, cat_coords: bool = True) -> torch.Tensor:
"""
This function generates a 2D positional embedding from given coordinates using sine and cosine functions.
Args:
- xy: The coordinates to generate the embedding from.
- C: The size of the embedding.
- cat_coords: A flag to indicate whether to concatenate the original coordinates to the embedding.
Returns:
- pe: The generated 2D positional embedding.
"""
B, N, D = xy.shape
assert D == 2
x = xy[:, :, 0:1]
y = xy[:, :, 1:2]
div_term = (
torch.arange(0, C, 2, device=xy.device, dtype=torch.float32) * (1000.0 / C)
).reshape(1, 1, int(C / 2))
pe_x = torch.zeros(B, N, C, device=xy.device, dtype=torch.float32)
pe_y = torch.zeros(B, N, C, device=xy.device, dtype=torch.float32)
pe_x[:, :, 0::2] = torch.sin(x * div_term)
pe_x[:, :, 1::2] = torch.cos(x * div_term)
pe_y[:, :, 0::2] = torch.sin(y * div_term)
pe_y[:, :, 1::2] = torch.cos(y * div_term)
    pe = torch.cat([pe_x, pe_y], dim=2)  # (B, N, C*2)
if cat_coords:
        pe = torch.cat([xy, pe], dim=2)  # (B, N, C*2+2)
return pe
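# Illustrative width sketch (not part of the original module): C sin/cos
# channels per axis plus the two raw coordinates give 2 * C + 2 channels, the
# 130-wide flow embedding used by CoTracker2 when C=64.
def _get_2d_embedding_width_example():
    xy = torch.randn(2, 10, 2)  # (batch, points, (x, y))
    pe = get_2d_embedding(xy, C=64, cat_coords=True)
    assert pe.shape == (2, 10, 2 * 64 + 2)
    return pe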
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
import numpy as np
import random
import torch
import torch.nn.functional as F
from typing import Optional, Tuple
EPS = 1e-6
def smart_cat(tensor1, tensor2, dim):
if tensor1 is None:
return tensor2
return torch.cat([tensor1, tensor2], dim=dim)
def get_uniformly_sampled_pts(
size: int,
num_frames: int,
extent: Tuple[float, ...],
device: Optional[torch.device] = torch.device("cpu"),
):
time_points = torch.randint(low=0, high=num_frames, size=(size, 1), device=device)
space_points = torch.rand(size, 2, device=device) * torch.tensor(
[extent[1], extent[0]], device=device
)
points = torch.cat((time_points, space_points), dim=1)
return points[None]
def get_superpoint_sampled_pts(
video,
size: int,
num_frames: int,
extent: Tuple[float, ...],
device: Optional[torch.device] = torch.device("cpu"),
):
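    # NOTE: assumes a SuperPoint keypoint extractor class is importable in this module
    # (it is provided by an external package and is not imported here).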
extractor = SuperPoint(max_num_keypoints=48).eval().cuda()
points = list()
for _ in range(8):
frame_num = random.randint(0, int(num_frames * 0.25))
key_points = extractor.extract(
video[0, frame_num, :, :, :] / 255.0, resize=None
)["keypoints"]
frame_tensor = torch.full((1, key_points.shape[1], 1), frame_num).cuda()
points.append(torch.cat([frame_tensor.cuda(), key_points], dim=2))
return torch.cat(points, dim=1)[:, :size, :]
def get_sift_sampled_pts(
video,
size: int,
num_frames: int,
extent: Tuple[float, ...],
device: Optional[torch.device] = torch.device("cpu"),
num_sampled_frames: int = 8,
sampling_length_percent: float = 0.25,
):
import cv2
# assert size == 384, "hardcoded for experiment"
sift = cv2.SIFT_create(nfeatures=size // num_sampled_frames)
points = list()
for _ in range(num_sampled_frames):
frame_num = random.randint(0, int(num_frames * sampling_length_percent))
key_points, _ = sift.detectAndCompute(
video[0, frame_num, :, :, :]
.cpu()
.permute(1, 2, 0)
.numpy()
.astype(np.uint8),
None,
)
for kp in key_points:
points.append([frame_num, int(kp.pt[0]), int(kp.pt[1])])
return torch.tensor(points[:size], device=device)[None]
def get_points_on_a_grid(
size: int,
extent: Tuple[float, ...],
center: Optional[Tuple[float, ...]] = None,
device: Optional[torch.device] = torch.device("cpu"),
):
r"""Get a grid of points covering a rectangular region
`get_points_on_a_grid(size, extent)` generates a :attr:`size` by
    :attr:`size` grid of points distributed to cover a rectangular area
specified by `extent`.
    The `extent` is a pair of integers :math:`(H,W)` specifying the height
and width of the rectangle.
Optionally, the :attr:`center` can be specified as a pair :math:`(c_y,c_x)`
specifying the vertical and horizontal center coordinates. The center
defaults to the middle of the extent.
Points are distributed uniformly within the rectangle leaving a margin
:math:`m=W/64` from the border.
It returns a :math:`(1, \text{size} \times \text{size}, 2)` tensor of
points :math:`P_{ij}=(x_i, y_i)` where
.. math::
P_{ij} = \left(
c_x + m -\frac{W}{2} + \frac{W - 2m}{\text{size} - 1}\, j,~
c_y + m -\frac{H}{2} + \frac{H - 2m}{\text{size} - 1}\, i
\right)
Points are returned in row-major order.
Args:
size (int): grid size.
        extent (tuple): height and width of the grid extent.
center (tuple, optional): grid center.
device (str, optional): Defaults to `"cpu"`.
Returns:
Tensor: grid.
"""
if size == 1:
return torch.tensor([extent[1] / 2, extent[0] / 2], device=device)[None, None]
if center is None:
center = [extent[0] / 2, extent[1] / 2]
margin = extent[1] / 64
range_y = (margin - extent[0] / 2 + center[0], extent[0] / 2 + center[0] - margin)
range_x = (margin - extent[1] / 2 + center[1], extent[1] / 2 + center[1] - margin)
grid_y, grid_x = torch.meshgrid(
torch.linspace(*range_y, size, device=device),
torch.linspace(*range_x, size, device=device),
indexing="ij",
)
return torch.stack([grid_x, grid_y], dim=-1).reshape(1, -1, 2)
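def _example_points_on_a_grid():
    # Illustrative sketch (hypothetical helper, assumed sizes): a 5x5 grid of (x, y)
    # points covering a 256x320 frame, returned in row-major order.
    pts = get_points_on_a_grid(5, (256, 320))
    assert pts.shape == (1, 25, 2)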
def reduce_masked_mean(input, mask, dim=None, keepdim=False):
r"""Masked mean
`reduce_masked_mean(x, mask)` computes the mean of a tensor :attr:`input`
over a mask :attr:`mask`, returning
.. math::
\text{output} =
\frac
{\sum_{i=1}^N \text{input}_i \cdot \text{mask}_i}
{\epsilon + \sum_{i=1}^N \text{mask}_i}
where :math:`N` is the number of elements in :attr:`input` and
:attr:`mask`, and :math:`\epsilon` is a small constant to avoid
division by zero.
    `reduce_masked_mean(x, mask, dim)` computes the mean of a tensor
:attr:`input` over a mask :attr:`mask` along a dimension :attr:`dim`.
Optionally, the dimension can be kept in the output by setting
:attr:`keepdim` to `True`. Tensor :attr:`mask` must be broadcastable to
the same dimension as :attr:`input`.
The interface is similar to `torch.mean()`.
Args:
        input (Tensor): input tensor.
mask (Tensor): mask.
dim (int, optional): Dimension to sum over. Defaults to None.
keepdim (bool, optional): Keep the summed dimension. Defaults to False.
Returns:
Tensor: mean tensor.
"""
mask = mask.expand_as(input)
prod = input * mask
if dim is None:
numer = torch.sum(prod)
denom = torch.sum(mask)
else:
numer = torch.sum(prod, dim=dim, keepdim=keepdim)
denom = torch.sum(mask, dim=dim, keepdim=keepdim)
mean = numer / (EPS + denom)
return mean
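def _example_reduce_masked_mean():
    # Illustrative sketch (hypothetical helper): the mean is taken over valid entries
    # only, so the masked-out 100.0 does not contribute.
    x = torch.tensor([[1.0, 2.0, 100.0]])
    mask = torch.tensor([[1.0, 1.0, 0.0]])
    m = reduce_masked_mean(x, mask)
    assert torch.allclose(m, torch.tensor(1.5), atol=1e-4)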
def bilinear_sampler(input, coords, align_corners=True, padding_mode="border"):
r"""Sample a tensor using bilinear interpolation
`bilinear_sampler(input, coords)` samples a tensor :attr:`input` at
coordinates :attr:`coords` using bilinear interpolation. It is the same
as `torch.nn.functional.grid_sample()` but with a different coordinate
convention.
The input tensor is assumed to be of shape :math:`(B, C, H, W)`, where
:math:`B` is the batch size, :math:`C` is the number of channels,
:math:`H` is the height of the image, and :math:`W` is the width of the
image. The tensor :attr:`coords` of shape :math:`(B, H_o, W_o, 2)` is
interpreted as an array of 2D point coordinates :math:`(x_i,y_i)`.
Alternatively, the input tensor can be of size :math:`(B, C, T, H, W)`,
in which case sample points are triplets :math:`(t_i,x_i,y_i)`. Note
that in this case the order of the components is slightly different
from `grid_sample()`, which would expect :math:`(x_i,y_i,t_i)`.
    If `align_corners` is `True`, the coordinate :math:`x` is assumed to be
    in the range :math:`[0,W-1]`, with 0 corresponding to the center of the
    left-most image pixel and :math:`W-1` to the center of the right-most
    pixel.
    If `align_corners` is `False`, the coordinate :math:`x` is assumed to
    be in the range :math:`[0,W]`, with 0 corresponding to the left edge of
    the left-most pixel and :math:`W` to the right edge of the right-most
    pixel.
Similar conventions apply to the :math:`y` for the range
:math:`[0,H-1]` and :math:`[0,H]` and to :math:`t` for the range
:math:`[0,T-1]` and :math:`[0,T]`.
Args:
input (Tensor): batch of input images.
coords (Tensor): batch of coordinates.
align_corners (bool, optional): Coordinate convention. Defaults to `True`.
padding_mode (str, optional): Padding mode. Defaults to `"border"`.
Returns:
Tensor: sampled points.
"""
sizes = input.shape[2:]
assert len(sizes) in [2, 3]
if len(sizes) == 3:
# t x y -> x y t to match dimensions T H W in grid_sample
coords = coords[..., [1, 2, 0]]
if align_corners:
coords = coords * torch.tensor(
[2 / max(size - 1, 1) for size in reversed(sizes)], device=coords.device
)
else:
coords = coords * torch.tensor(
[2 / size for size in reversed(sizes)], device=coords.device
)
coords -= 1
return F.grid_sample(
input, coords, align_corners=align_corners, padding_mode=padding_mode
)
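def _example_bilinear_sampler():
    # Illustrative sketch (hypothetical helper, assumed shapes): sample a (B, C, H, W)
    # feature map at a pixel coordinate given in (x, y) order. The point (x=1, y=2)
    # falls exactly on a pixel center, so the sampled value equals feat[0, 0, 2, 1].
    feat = torch.arange(16.0).reshape(1, 1, 4, 4)
    coords = torch.tensor([[[[1.0, 2.0]]]])  # (B, H_o=1, W_o=1, 2)
    out = bilinear_sampler(feat, coords)  # (1, 1, 1, 1)
    assert torch.allclose(out, torch.full_like(out, 9.0), atol=1e-4)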
def sample_features4d(input, coords):
r"""Sample spatial features
`sample_features4d(input, coords)` samples the spatial features
:attr:`input` represented by a 4D tensor :math:`(B, C, H, W)`.
The field is sampled at coordinates :attr:`coords` using bilinear
interpolation. :attr:`coords` is assumed to be of shape :math:`(B, R,
    2)`, where each sample has the format :math:`(x_i, y_i)`. This uses the
same convention as :func:`bilinear_sampler` with `align_corners=True`.
The output tensor has one feature per point, and has shape :math:`(B,
R, C)`.
Args:
input (Tensor): spatial features.
coords (Tensor): points.
Returns:
Tensor: sampled features.
"""
B, _, _, _ = input.shape
# B R 2 -> B R 1 2
coords = coords.unsqueeze(2)
# B C R 1
feats = bilinear_sampler(input, coords)
return feats.permute(0, 2, 1, 3).view(
B, -1, feats.shape[1] * feats.shape[3]
) # B C R 1 -> B R C
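def _example_sample_features4d():
    # Illustrative sketch (hypothetical helper, assumed shapes): one feature vector per
    # query point from a (B, C, H, W) feature map.
    feat = torch.randn(2, 128, 32, 32)
    coords = torch.rand(2, 10, 2) * 31  # (x, y) pixel coordinates
    sampled = sample_features4d(feat, coords)
    assert sampled.shape == (2, 10, 128)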
def sample_features5d(input, coords):
r"""Sample spatio-temporal features
`sample_features5d(input, coords)` works in the same way as
:func:`sample_features4d` but for spatio-temporal features and points:
:attr:`input` is a 5D tensor :math:`(B, T, C, H, W)`, :attr:`coords` is
a :math:`(B, R1, R2, 3)` tensor of spatio-temporal point :math:`(t_i,
x_i, y_i)`. The output tensor has shape :math:`(B, R1, R2, C)`.
Args:
input (Tensor): spatio-temporal features.
coords (Tensor): spatio-temporal points.
Returns:
Tensor: sampled features.
"""
B, T, _, _, _ = input.shape
# B T C H W -> B C T H W
input = input.permute(0, 2, 1, 3, 4)
# B R1 R2 3 -> B R1 R2 1 3
coords = coords.unsqueeze(3)
# B C R1 R2 1
feats = bilinear_sampler(input, coords)
return feats.permute(0, 2, 3, 1, 4).view(
B, feats.shape[2], feats.shape[3], feats.shape[1]
) # B C R1 R2 1 -> B R1 R2 C
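def _example_sample_features5d():
    # Illustrative sketch (hypothetical helper, assumed shapes): sample spatio-temporal
    # features at (t, x, y) points.
    feat = torch.randn(1, 8, 128, 32, 32)  # (B, T, C, H, W)
    coords = torch.rand(1, 4, 16, 3)       # (B, R1, R2, 3) as (t, x, y)
    coords[..., 0] *= 7
    coords[..., 1:] *= 31
    sampled = sample_features5d(feat, coords)
    assert sampled.shape == (1, 4, 16, 128)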
def get_grid(
height,
width,
shape=None,
dtype="torch",
device="cpu",
align_corners=True,
normalize=True,
):
H, W = height, width
S = shape if shape else []
if align_corners:
x = torch.linspace(0, 1, W, device=device)
y = torch.linspace(0, 1, H, device=device)
if not normalize:
x = x * (W - 1)
y = y * (H - 1)
else:
x = torch.linspace(0.5 / W, 1.0 - 0.5 / W, W, device=device)
y = torch.linspace(0.5 / H, 1.0 - 0.5 / H, H, device=device)
if not normalize:
x = x * W
y = y * H
x_view, y_view, exp = [1 for _ in S] + [1, -1], [1 for _ in S] + [-1, 1], S + [H, W]
x = x.view(*x_view).expand(*exp)
y = y.view(*y_view).expand(*exp)
grid = torch.stack([x, y], dim=-1)
if dtype == "numpy":
grid = grid.numpy()
return grid
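def _example_get_grid():
    # Illustrative sketch (hypothetical helper): a normalized (H, W, 2) grid of (x, y)
    # coordinates for an assumed 4x6 image.
    grid = get_grid(4, 6)
    assert grid.shape == (4, 6, 2)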
def round_to_multiple_of_4(n):
return round(n / 4) * 4
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
import torch
import torch.nn.functional as F
from typing import Tuple
from cotracker.models.core.cotracker.cotracker3_offline import CoTrackerThreeOffline
from cotracker.models.core.model_utils import (
get_points_on_a_grid,
get_uniformly_sampled_pts,
get_sift_sampled_pts,
)
import numpy as np
import sys
from torchvision.transforms import Compose
from tqdm import tqdm
from cotracker.models.core.model_utils import bilinear_sampler
class EvaluationPredictor(torch.nn.Module):
def __init__(
self,
cotracker_model: CoTrackerThreeOffline,
interp_shape: Tuple[int, int] = (384, 512),
grid_size: int = 5,
local_grid_size: int = 8,
single_point: bool = True,
sift_size: int = 0,
num_uniformly_sampled_pts: int = 0,
n_iters: int = 6,
local_extent: int = 50,
) -> None:
super(EvaluationPredictor, self).__init__()
self.grid_size = grid_size
self.local_grid_size = local_grid_size
self.sift_size = sift_size
self.single_point = single_point
self.interp_shape = interp_shape
self.n_iters = n_iters
self.num_uniformly_sampled_pts = num_uniformly_sampled_pts
self.model = cotracker_model
self.local_extent = local_extent
self.model.eval()
def forward(self, video, queries):
queries = queries.clone()
B, T, C, H, W = video.shape
B, N, D = queries.shape
assert D == 3
assert B == 1
interp_shape = self.interp_shape
video = video.reshape(B * T, C, H, W)
video = F.interpolate(
video, tuple(interp_shape), mode="bilinear", align_corners=True
)
video = video.reshape(B, T, 3, interp_shape[0], interp_shape[1])
device = video.device
queries[:, :, 1] *= (interp_shape[1] - 1) / (W - 1)
queries[:, :, 2] *= (interp_shape[0] - 1) / (H - 1)
if self.single_point:
traj_e = torch.zeros((B, T, N, 2), device=device)
vis_e = torch.zeros((B, T, N), device=device)
conf_e = torch.zeros((B, T, N), device=device)
for pind in range((N)):
query = queries[:, pind : pind + 1]
t = query[0, 0, 0].long()
start_ind = 0
traj_e_pind, vis_e_pind, conf_e_pind = self._process_one_point(
video[:,start_ind:], query
)
traj_e[:, start_ind:, pind : pind + 1] = traj_e_pind[:, :, :1]
vis_e[:, start_ind:, pind : pind + 1] = vis_e_pind[:, :, :1]
conf_e[:, start_ind:, pind : pind + 1] = conf_e_pind[:, :, :1]
else:
if self.grid_size > 0:
xy = get_points_on_a_grid(self.grid_size, video.shape[3:])
xy = torch.cat([torch.zeros_like(xy[:, :, :1]), xy], dim=2).to(
device
) #
queries = torch.cat([queries, xy], dim=1) #
if self.num_uniformly_sampled_pts > 0:
xy = get_uniformly_sampled_pts(
self.num_uniformly_sampled_pts,
video.shape[1],
video.shape[3:],
device=device,
)
queries = torch.cat([queries, xy], dim=1) #
sift_size = self.sift_size
if sift_size > 0:
xy = get_sift_sampled_pts(video, sift_size, T, [H, W], device=device)
if xy.shape[1] == sift_size:
queries = torch.cat([queries, xy], dim=1) #
else:
sift_size = 0
preds = self.model(video=video, queries=queries, iters=self.n_iters)
traj_e, vis_e = preds[0], preds[1]
conf_e = None
if len(preds) > 3:
conf_e = preds[2]
if (
sift_size > 0
or self.grid_size > 0
or self.num_uniformly_sampled_pts > 0
):
traj_e = traj_e[
:,
:,
: -self.grid_size**2 - sift_size - self.num_uniformly_sampled_pts,
]
vis_e = vis_e[
:,
:,
: -self.grid_size**2 - sift_size - self.num_uniformly_sampled_pts,
]
if conf_e is not None:
conf_e = conf_e[
:,
:,
: -self.grid_size**2
- sift_size
- self.num_uniformly_sampled_pts,
]
traj_e[:, :, :, 0] *= (W - 1) / float(interp_shape[1] - 1)
traj_e[:, :, :, 1] *= (H - 1) / float(interp_shape[0] - 1)
if conf_e is not None:
vis_e = vis_e * conf_e
return traj_e, vis_e
def _process_one_point(self, video, query):
t = query[0, 0, 0].long()
B, T, C, H, W = video.shape
device = query.device
if self.local_grid_size > 0:
xy_target = get_points_on_a_grid(
self.local_grid_size,
(self.local_extent, self.local_extent),
[query[0, 0, 2].item(), query[0, 0, 1].item()],
)
xy_target = torch.cat(
[torch.zeros_like(xy_target[:, :, :1]), xy_target], dim=2
).to(
device
) #
query = torch.cat([query, xy_target], dim=1) #
if self.grid_size > 0:
xy = get_points_on_a_grid(self.grid_size, video.shape[3:])
xy = torch.cat([torch.zeros_like(xy[:, :, :1]), xy], dim=2).to(device) #
query = torch.cat([query, xy], dim=1) #
sift_size = self.sift_size
if sift_size > 0:
xy = get_sift_sampled_pts(video, sift_size, T, [H, W], device=device)
sift_size = xy.shape[1]
if sift_size > 0:
query = torch.cat([query, xy], dim=1) #
num_uniformly_sampled_pts = self.sift_size - sift_size
if num_uniformly_sampled_pts > 0:
xy2 = get_uniformly_sampled_pts(
num_uniformly_sampled_pts,
video.shape[1],
video.shape[3:],
device=device,
)
query = torch.cat([query, xy2], dim=1) #
if self.num_uniformly_sampled_pts > 0:
xy = get_uniformly_sampled_pts(
self.num_uniformly_sampled_pts,
video.shape[1],
video.shape[3:],
device=device,
)
query = torch.cat([query, xy], dim=1) #
traj_e_pind, vis_e_pind, conf_e_pind, __ = self.model(
video=video, queries=query, iters=self.n_iters
)
return traj_e_pind[..., :2], vis_e_pind, conf_e_pind
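def _example_evaluation_predictor(model: CoTrackerThreeOffline):
    # Illustrative sketch (hypothetical helper, assumed shapes): `model` is a trained
    # CoTrackerThreeOffline instance loaded elsewhere; the query is a single (t, x, y)
    # point and a 5x5 support grid is added internally.
    predictor = EvaluationPredictor(
        model, grid_size=5, local_grid_size=0, single_point=False, n_iters=6
    )
    video = torch.randint(0, 255, (1, 24, 3, 256, 320)).float()
    queries = torch.tensor([[[0.0, 100.0, 80.0]]])  # (B=1, N=1, 3)
    traj, vis = predictor(video, queries)
    return traj, vis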
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
import torch
import torch.nn.functional as F
from cotracker.models.core.model_utils import smart_cat, get_points_on_a_grid
from cotracker.models.build_cotracker import build_cotracker
class CoTrackerPredictor(torch.nn.Module):
def __init__(
self,
checkpoint="./checkpoints/scaled_offline.pth",
offline=True,
v2=False,
window_len=60,
):
super().__init__()
self.v2 = v2
self.support_grid_size = 6
model = build_cotracker(
checkpoint,
v2=v2,
offline=offline,
window_len=window_len,
)
self.interp_shape = model.model_resolution
self.model = model
self.model.eval()
@torch.no_grad()
def forward(
self,
video, # (B, T, 3, H, W)
# input prompt types:
# - None. Dense tracks are computed in this case. You can adjust *query_frame* to compute tracks starting from a specific frame.
# *backward_tracking=True* will compute tracks in both directions.
# - queries. Queried points of shape (B, N, 3) in format (t, x, y) for frame index and pixel coordinates.
        # - grid_size. Grid of N*N points from the first frame. If segm_mask is provided, tracks are computed only inside the mask.
# You can adjust *query_frame* and *backward_tracking* for the regular grid in the same way as for dense tracks.
queries: torch.Tensor = None,
segm_mask: torch.Tensor = None, # Segmentation mask of shape (B, 1, H, W)
grid_size: int = 0,
grid_query_frame: int = 0, # only for dense and regular grid tracks
backward_tracking: bool = False,
):
if queries is None and grid_size == 0:
tracks, visibilities = self._compute_dense_tracks(
video,
grid_query_frame=grid_query_frame,
backward_tracking=backward_tracking,
)
else:
tracks, visibilities = self._compute_sparse_tracks(
video,
queries,
segm_mask,
grid_size,
add_support_grid=(grid_size == 0 or segm_mask is not None),
grid_query_frame=grid_query_frame,
backward_tracking=backward_tracking,
)
return tracks, visibilities
def _compute_dense_tracks(
self, video, grid_query_frame, grid_size=80, backward_tracking=False
):
*_, H, W = video.shape
grid_step = W // grid_size
grid_width = W // grid_step
grid_height = H // grid_step
tracks = visibilities = None
grid_pts = torch.zeros((video.shape[0], grid_width * grid_height, 3)).to(video.device)
grid_pts[:, :, 0] = grid_query_frame
for offset in range(grid_step * grid_step):
print(f"step {offset} / {grid_step * grid_step}")
ox = offset % grid_step
oy = offset // grid_step
grid_pts[:, :, 1] = (
torch.arange(grid_width).repeat(grid_height) * grid_step + ox
)
grid_pts[:, :, 2] = (
torch.arange(grid_height).repeat_interleave(grid_width) * grid_step + oy
)
tracks_step, visibilities_step = self._compute_sparse_tracks(
video=video,
queries=grid_pts,
backward_tracking=backward_tracking,
)
tracks = smart_cat(tracks, tracks_step, dim=2)
visibilities = smart_cat(visibilities, visibilities_step, dim=2)
return tracks, visibilities
def _compute_sparse_tracks(
self,
video,
queries,
segm_mask=None,
grid_size=0,
add_support_grid=False,
grid_query_frame=0,
backward_tracking=False,
):
B, T, C, H, W = video.shape
video = video.reshape(B * T, C, H, W)
video = F.interpolate(
video, tuple(self.interp_shape), mode="bilinear", align_corners=True
)
video = video.reshape(B, T, 3, self.interp_shape[0], self.interp_shape[1])
if queries is not None:
B, N, D = queries.shape
assert D == 3
queries = queries.clone()
queries[:, :, 1:] *= queries.new_tensor(
[
(self.interp_shape[1] - 1) / (W - 1),
(self.interp_shape[0] - 1) / (H - 1),
]
)
elif grid_size > 0:
grid_pts = get_points_on_a_grid(
grid_size, self.interp_shape, device=video.device
)
if segm_mask is not None:
segm_mask = F.interpolate(
segm_mask, tuple(self.interp_shape), mode="nearest"
)
point_mask = segm_mask[0, 0][
(grid_pts[0, :, 1]).round().long().cpu(),
(grid_pts[0, :, 0]).round().long().cpu(),
].bool()
grid_pts = grid_pts[:, point_mask]
queries = torch.cat(
[torch.ones_like(grid_pts[:, :, :1]) * grid_query_frame, grid_pts],
dim=2,
).repeat(B, 1, 1)
if add_support_grid:
grid_pts = get_points_on_a_grid(
self.support_grid_size, self.interp_shape, device=video.device
)
grid_pts = torch.cat(
[torch.zeros_like(grid_pts[:, :, :1]), grid_pts], dim=2
)
grid_pts = grid_pts.repeat(B, 1, 1)
queries = torch.cat([queries, grid_pts], dim=1)
tracks, visibilities, *_ = self.model.forward(
video=video, queries=queries, iters=6
)
if backward_tracking:
tracks, visibilities = self._compute_backward_tracks(
video, queries, tracks, visibilities
)
if add_support_grid:
queries[:, -self.support_grid_size**2 :, 0] = T - 1
if add_support_grid:
tracks = tracks[:, :, : -self.support_grid_size**2]
visibilities = visibilities[:, :, : -self.support_grid_size**2]
thr = 0.9
visibilities = visibilities > thr
# correct query-point predictions
# see https://github.com/facebookresearch/co-tracker/issues/28
# TODO: batchify
for i in range(len(queries)):
queries_t = queries[i, : tracks.size(2), 0].to(torch.int64)
arange = torch.arange(0, len(queries_t))
# overwrite the predictions with the query points
tracks[i, queries_t, arange] = queries[i, : tracks.size(2), 1:]
# correct visibilities, the query points should be visible
visibilities[i, queries_t, arange] = True
tracks *= tracks.new_tensor(
[(W - 1) / (self.interp_shape[1] - 1), (H - 1) / (self.interp_shape[0] - 1)]
)
return tracks, visibilities
def _compute_backward_tracks(self, video, queries, tracks, visibilities):
inv_video = video.flip(1).clone()
inv_queries = queries.clone()
inv_queries[:, :, 0] = inv_video.shape[1] - inv_queries[:, :, 0] - 1
inv_tracks, inv_visibilities, *_ = self.model(
video=inv_video, queries=inv_queries, iters=6
)
inv_tracks = inv_tracks.flip(1)
inv_visibilities = inv_visibilities.flip(1)
arange = torch.arange(video.shape[1], device=queries.device)[None, :, None]
mask = (arange < queries[:, None, :, 0]).unsqueeze(-1).repeat(1, 1, 1, 2)
tracks[mask] = inv_tracks[mask]
visibilities[mask[:, :, :, 0]] = inv_visibilities[mask[:, :, :, 0]]
return tracks, visibilities
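def _example_cotracker_predictor():
    # Illustrative sketch (hypothetical helper; checkpoint path and clip shape are
    # assumptions): track a 10x10 grid of points through a short clip with the
    # offline model.
    predictor = CoTrackerPredictor(
        checkpoint="./checkpoints/scaled_offline.pth", offline=True
    )
    video = torch.randint(0, 255, (1, 24, 3, 256, 320)).float()  # (B, T, 3, H, W)
    tracks, visibility = predictor(video, grid_size=10)
    # tracks: (1, 24, 100, 2) pixel coordinates; visibility: (1, 24, 100) booleans
    return tracks, visibility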
class CoTrackerOnlinePredictor(torch.nn.Module):
def __init__(
self,
checkpoint="./checkpoints/scaled_online.pth",
offline=False,
v2=False,
window_len=16,
):
super().__init__()
self.v2 = v2
self.support_grid_size = 6
model = build_cotracker(checkpoint, v2=v2, offline=False, window_len=window_len)
self.interp_shape = model.model_resolution
self.step = model.window_len // 2
self.model = model
self.model.eval()
@torch.no_grad()
def forward(
self,
video_chunk,
is_first_step: bool = False,
queries: torch.Tensor = None,
grid_size: int = 5,
grid_query_frame: int = 0,
add_support_grid=False,
):
B, T, C, H, W = video_chunk.shape
# Initialize online video processing and save queried points
# This needs to be done before processing *each new video*
if is_first_step:
self.model.init_video_online_processing()
if queries is not None:
B, N, D = queries.shape
self.N = N
assert D == 3
queries = queries.clone()
queries[:, :, 1:] *= queries.new_tensor(
[
(self.interp_shape[1] - 1) / (W - 1),
(self.interp_shape[0] - 1) / (H - 1),
]
)
if add_support_grid:
grid_pts = get_points_on_a_grid(
self.support_grid_size, self.interp_shape, device=video_chunk.device
)
grid_pts = torch.cat(
[torch.zeros_like(grid_pts[:, :, :1]), grid_pts], dim=2
)
queries = torch.cat([queries, grid_pts], dim=1)
elif grid_size > 0:
grid_pts = get_points_on_a_grid(
grid_size, self.interp_shape, device=video_chunk.device
)
self.N = grid_size**2
queries = torch.cat(
[torch.ones_like(grid_pts[:, :, :1]) * grid_query_frame, grid_pts],
dim=2,
)
self.queries = queries
return (None, None)
video_chunk = video_chunk.reshape(B * T, C, H, W)
video_chunk = F.interpolate(
video_chunk, tuple(self.interp_shape), mode="bilinear", align_corners=True
)
video_chunk = video_chunk.reshape(
B, T, 3, self.interp_shape[0], self.interp_shape[1]
)
if self.v2:
tracks, visibilities, __ = self.model(
video=video_chunk, queries=self.queries, iters=6, is_online=True
)
else:
tracks, visibilities, confidence, __ = self.model(
video=video_chunk, queries=self.queries, iters=6, is_online=True
)
if add_support_grid:
            tracks = tracks[:, :, : self.N]
            visibilities = visibilities[:, :, : self.N]
            if not self.v2:
                confidence = confidence[:, :, : self.N]
if not self.v2:
visibilities = visibilities * confidence
thr = 0.6
return (
tracks
* tracks.new_tensor(
[
(W - 1) / (self.interp_shape[1] - 1),
(H - 1) / (self.interp_shape[0] - 1),
]
),
visibilities > thr,
)
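def _example_cotracker_online_predictor():
    # Illustrative sketch (hypothetical helper; checkpoint path and clip shape are
    # assumptions): stream a clip through the online predictor in overlapping windows
    # of length 2 * predictor.step.
    predictor = CoTrackerOnlinePredictor(checkpoint="./checkpoints/scaled_online.pth")
    video = torch.randint(0, 255, (1, 48, 3, 256, 320)).float()
    # The first call only registers the queries (here a 5x5 grid) and resets the online state.
    predictor(video_chunk=video, is_first_step=True, grid_size=5)
    tracks = visibility = None
    for ind in range(0, video.shape[1] - predictor.step, predictor.step):
        tracks, visibility = predictor(
            video_chunk=video[:, ind : ind + predictor.step * 2]
        )
    return tracks, visibility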
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
import os
import sys
import torch
import logging
import signal
import socket
from torch.utils.data import ConcatDataset
from cotracker.datasets.utils import collate_fn, collate_fn_train
from torch.utils.tensorboard import SummaryWriter
from cotracker.datasets.dr_dataset import DynamicReplicaDataset
from cotracker.models.evaluation_predictor import EvaluationPredictor
# define the handler function
# for training on a slurm cluster
def sig_handler(signum, frame):
print("caught signal", signum)
print(socket.gethostname(), "USR1 signal caught.")
# do other stuff to cleanup here
print("requeuing job " + os.environ["SLURM_JOB_ID"])
os.system("scontrol requeue " + os.environ["SLURM_JOB_ID"])
sys.exit(-1)
def term_handler(signum, frame):
print("bypassing sigterm", flush=True)
def get_eval_dataloader(dataset_root, ds_name):
from cotracker.datasets.tap_vid_datasets import TapVidDataset
collate_fn_local = collate_fn
if ds_name == "dynamic_replica":
from cotracker.datasets.dr_dataset import DynamicReplicaDataset
eval_dataset = DynamicReplicaDataset(
root=os.path.join(dataset_root, "dynamic_replica"),
sample_len=300,
only_first_n_samples=1,
rgbd_input=False,
)
elif ds_name == "tapvid_davis_first":
data_root = os.path.join(dataset_root, "tapvid/tapvid_davis/tapvid_davis.pkl")
eval_dataset = TapVidDataset(
dataset_type="davis", data_root=data_root, queried_first=True
)
elif ds_name == "tapvid_davis_strided":
data_root = os.path.join(dataset_root, "tapvid/tapvid_davis/tapvid_davis.pkl")
eval_dataset = TapVidDataset(
dataset_type="davis", data_root=data_root, queried_first=False
)
elif ds_name == "tapvid_kinetics_first":
eval_dataset = TapVidDataset(
dataset_type="kinetics",
data_root=os.path.join(dataset_root, "tapvid", "tapvid_kinetics"),
)
elif ds_name == "tapvid_stacking":
eval_dataset = TapVidDataset(
dataset_type="stacking",
data_root=os.path.join(
dataset_root, "tapvid", "tapvid_rgb_stacking", "tapvid_rgb_stacking.pkl"
),
)
elif ds_name == "tapvid_robotap":
eval_dataset = TapVidDataset(
dataset_type="robotap",
data_root=os.path.join(dataset_root, "tapvid", "tapvid_robotap"),
)
elif ds_name == "kubric":
from cotracker.datasets.kubric_movif_dataset import KubricMovifDataset
eval_dataset = KubricMovifDataset(
data_root=os.path.join(
args.dataset_root, "kubric/kubric_movi_f_120_frames_dense/movi_f"
),
traj_per_sample=1024,
use_augs=False,
split="valid",
sample_vis_1st_frame=True,
)
collate_fn_local = collate_fn_train
eval_dataloader_dr = torch.utils.data.DataLoader(
eval_dataset,
batch_size=1,
shuffle=False,
num_workers=1,
collate_fn=collate_fn_local,
)
return eval_dataloader_dr
def get_train_dataset(args):
dataset = None
if "kubric" in args.train_datasets:
from cotracker.datasets import kubric_movif_dataset
kubric = kubric_movif_dataset.KubricMovifDataset(
data_root=os.path.join(
args.dataset_root, "kubric/kubric_movi_f_120_frames_dense/movi_f"
),
crop_size=args.crop_size,
seq_len=args.sequence_len,
traj_per_sample=args.traj_per_sample,
sample_vis_last_frame=args.query_sampling_method is not None
and ("random" in args.query_sampling_method),
use_augs=not args.dont_use_augs,
random_seq_len=args.random_seq_len,
random_frame_rate=args.random_frame_rate,
random_number_traj=args.random_number_traj,
)
if dataset is None:
dataset = ConcatDataset(4 * [kubric])
else:
dataset = ConcatDataset(4 * [kubric] + [dataset])
print("add kubric to train", len(dataset))
if "dr" in args.train_datasets:
dr = DynamicReplicaDataset(
root=os.path.join(args.dataset_root, "dynamic_replica"),
sample_len=args.sequence_len,
split="train",
traj_per_sample=args.traj_per_sample,
crop_size=args.crop_size,
)
if dataset is None:
dataset = dr
else:
dataset = ConcatDataset([dr] + [dataset])
return dataset
def run_test_eval(evaluator, model, dataloaders, writer, step, query_random=False):
model.eval()
for ds_name, dataloader in dataloaders:
visualize_every = 1
grid_size = 5
num_uniformly_sampled_pts = 0
if ds_name == "dynamic_replica":
visualize_every = 8
grid_size = 0
elif ds_name == "kubric":
visualize_every = 5
grid_size = 0
elif "davis" in ds_name or "tapvid_stacking" in ds_name:
visualize_every = 5
elif "robotap" in ds_name:
visualize_every = 20
elif "kinetics" in ds_name:
visualize_every = 50
if query_random:
grid_size = 0
num_uniformly_sampled_pts = 100
predictor = EvaluationPredictor(
model.module.module,
grid_size=grid_size,
local_grid_size=0,
single_point=False,
num_uniformly_sampled_pts=num_uniformly_sampled_pts,
n_iters=6,
)
if torch.cuda.is_available():
predictor.model = predictor.model.cuda()
metrics = evaluator.evaluate_sequence(
model=predictor,
test_dataloader=dataloader,
dataset_name=ds_name,
train_mode=True,
writer=writer,
step=step,
visualize_every=visualize_every,
)
if ds_name == "dynamic_replica" or ds_name == "kubric":
metrics = {
f"{ds_name}_avg_{k}": v
for k, v in metrics["avg"].items()
if not ("1" in k or "2" in k or "4" in k or "8" in k)
}
if "tapvid" in ds_name:
metrics = {
f"{ds_name}_avg_OA": metrics["avg"]["occlusion_accuracy"],
f"{ds_name}_avg_delta": metrics["avg"]["average_pts_within_thresh"],
f"{ds_name}_avg_Jaccard": metrics["avg"]["average_jaccard"],
}
writer.add_scalars(f"Eval_{ds_name}", metrics, step)
class Logger:
SUM_FREQ = 100
def __init__(self, model, scheduler, ckpt_path):
self.model = model
self.scheduler = scheduler
self.ckpt_path = ckpt_path
self.total_steps = 0
self.running_loss = {}
self.writer = SummaryWriter(log_dir=os.path.join(ckpt_path, "runs"))
def _print_training_status(self):
metrics_data = [
self.running_loss[k] / Logger.SUM_FREQ
for k in sorted(self.running_loss.keys())
]
training_str = "[{:6d}] ".format(self.total_steps + 1)
metrics_str = ("{:10.4f}, " * len(metrics_data)).format(*metrics_data)
# print the training status
logging.info(
f"Training Metrics ({self.total_steps}): {training_str + metrics_str}"
)
if self.writer is None:
self.writer = SummaryWriter(log_dir=os.path.join(self.ckpt_path, "runs"))
for k in self.running_loss:
self.writer.add_scalar(
k, self.running_loss[k] / Logger.SUM_FREQ, self.total_steps
)
self.running_loss[k] = 0.0
def push(self, metrics, task):
self.total_steps += 1
for key in metrics:
task_key = str(key) + "_" + task
if task_key not in self.running_loss:
self.running_loss[task_key] = 0.0
self.running_loss[task_key] += metrics[key]
if self.total_steps % Logger.SUM_FREQ == Logger.SUM_FREQ - 1:
self._print_training_status()
self.running_loss = {}
def write_dict(self, results):
if self.writer is None:
self.writer = SummaryWriter(log_dir=os.path.join(self.ckpt_path, "runs"))
for key in results:
self.writer.add_scalar(key, results[key], self.total_steps)
def close(self):
self.writer.close()
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
import os
import numpy as np
import imageio
import torch
from matplotlib import cm
import torch.nn.functional as F
import torchvision.transforms as transforms
import matplotlib.pyplot as plt
from PIL import Image, ImageDraw
def read_video_from_path(path):
try:
reader = imageio.get_reader(path)
except Exception as e:
print("Error opening video file: ", e)
return None
frames = []
for i, im in enumerate(reader):
frames.append(np.array(im))
return np.stack(frames)
def draw_circle(rgb, coord, radius, color=(255, 0, 0), visible=True, color_alpha=None):
# Create a draw object
draw = ImageDraw.Draw(rgb)
# Calculate the bounding box of the circle
left_up_point = (coord[0] - radius, coord[1] - radius)
right_down_point = (coord[0] + radius, coord[1] + radius)
# Draw the circle
color = tuple(list(color) + [color_alpha if color_alpha is not None else 255])
draw.ellipse(
[left_up_point, right_down_point],
fill=tuple(color) if visible else None,
outline=tuple(color),
)
return rgb
def draw_line(rgb, coord_y, coord_x, color, linewidth):
draw = ImageDraw.Draw(rgb)
draw.line(
(coord_y[0], coord_y[1], coord_x[0], coord_x[1]),
fill=tuple(color),
width=linewidth,
)
return rgb
def add_weighted(rgb, alpha, original, beta, gamma):
return (rgb * alpha + original * beta + gamma).astype("uint8")
class Visualizer:
def __init__(
self,
save_dir: str = "./results",
grayscale: bool = False,
pad_value: int = 0,
fps: int = 10,
mode: str = "rainbow", # 'cool', 'optical_flow'
linewidth: int = 2,
show_first_frame: int = 10,
tracks_leave_trace: int = 0, # -1 for infinite
):
self.mode = mode
self.save_dir = save_dir
if mode == "rainbow":
self.color_map = cm.get_cmap("gist_rainbow")
elif mode == "cool":
self.color_map = cm.get_cmap(mode)
self.show_first_frame = show_first_frame
self.grayscale = grayscale
self.tracks_leave_trace = tracks_leave_trace
self.pad_value = pad_value
self.linewidth = linewidth
self.fps = fps
def visualize(
self,
video: torch.Tensor, # (B,T,C,H,W)
tracks: torch.Tensor, # (B,T,N,2)
visibility: torch.Tensor = None, # (B, T, N, 1) bool
gt_tracks: torch.Tensor = None, # (B,T,N,2)
segm_mask: torch.Tensor = None, # (B,1,H,W)
filename: str = "video",
writer=None, # tensorboard Summary Writer, used for visualization during training
step: int = 0,
query_frame=0,
save_video: bool = True,
compensate_for_camera_motion: bool = False,
opacity: float = 1.0,
):
if compensate_for_camera_motion:
assert segm_mask is not None
if segm_mask is not None:
coords = tracks[0, query_frame].round().long()
segm_mask = segm_mask[0, query_frame][coords[:, 1], coords[:, 0]].long()
video = F.pad(
video,
(self.pad_value, self.pad_value, self.pad_value, self.pad_value),
"constant",
255,
)
color_alpha = int(opacity * 255)
tracks = tracks + self.pad_value
if self.grayscale:
transform = transforms.Grayscale()
video = transform(video)
video = video.repeat(1, 1, 3, 1, 1)
res_video = self.draw_tracks_on_video(
video=video,
tracks=tracks,
visibility=visibility,
segm_mask=segm_mask,
gt_tracks=gt_tracks,
query_frame=query_frame,
compensate_for_camera_motion=compensate_for_camera_motion,
color_alpha=color_alpha,
)
if save_video:
self.save_video(res_video, filename=filename, writer=writer, step=step)
return res_video
def save_video(self, video, filename, writer=None, step=0):
if writer is not None:
writer.add_video(
filename,
video.to(torch.uint8),
global_step=step,
fps=self.fps,
)
else:
os.makedirs(self.save_dir, exist_ok=True)
wide_list = list(video.unbind(1))
wide_list = [wide[0].permute(1, 2, 0).cpu().numpy() for wide in wide_list]
# Prepare the video file path
save_path = os.path.join(self.save_dir, f"{filename}.mp4")
# Create a writer object
video_writer = imageio.get_writer(save_path, fps=self.fps)
# Write frames to the video file
for frame in wide_list[2:-1]:
video_writer.append_data(frame)
video_writer.close()
print(f"Video saved to {save_path}")
def draw_tracks_on_video(
self,
video: torch.Tensor,
tracks: torch.Tensor,
visibility: torch.Tensor = None,
segm_mask: torch.Tensor = None,
gt_tracks=None,
query_frame=0,
compensate_for_camera_motion=False,
color_alpha: int = 255,
):
B, T, C, H, W = video.shape
_, _, N, D = tracks.shape
assert D == 2
assert C == 3
video = video[0].permute(0, 2, 3, 1).byte().detach().cpu().numpy() # S, H, W, C
tracks = tracks[0].long().detach().cpu().numpy() # S, N, 2
if gt_tracks is not None:
gt_tracks = gt_tracks[0].detach().cpu().numpy()
res_video = []
# process input video
for rgb in video:
res_video.append(rgb.copy())
vector_colors = np.zeros((T, N, 3))
if self.mode == "optical_flow":
import flow_vis
vector_colors = flow_vis.flow_to_color(tracks - tracks[query_frame][None])
elif segm_mask is None:
if self.mode == "rainbow":
y_min, y_max = (
tracks[query_frame, :, 1].min(),
tracks[query_frame, :, 1].max(),
)
norm = plt.Normalize(y_min, y_max)
for n in range(N):
if isinstance(query_frame, torch.Tensor):
query_frame_ = query_frame[n]
else:
query_frame_ = query_frame
color = self.color_map(norm(tracks[query_frame_, n, 1]))
color = np.array(color[:3])[None] * 255
vector_colors[:, n] = np.repeat(color, T, axis=0)
else:
# color changes with time
for t in range(T):
color = np.array(self.color_map(t / T)[:3])[None] * 255
vector_colors[t] = np.repeat(color, N, axis=0)
else:
if self.mode == "rainbow":
vector_colors[:, segm_mask <= 0, :] = 255
y_min, y_max = (
tracks[0, segm_mask > 0, 1].min(),
tracks[0, segm_mask > 0, 1].max(),
)
norm = plt.Normalize(y_min, y_max)
for n in range(N):
if segm_mask[n] > 0:
color = self.color_map(norm(tracks[0, n, 1]))
color = np.array(color[:3])[None] * 255
vector_colors[:, n] = np.repeat(color, T, axis=0)
else:
# color changes with segm class
segm_mask = segm_mask.cpu()
color = np.zeros((segm_mask.shape[0], 3), dtype=np.float32)
color[segm_mask > 0] = np.array(self.color_map(1.0)[:3]) * 255.0
color[segm_mask <= 0] = np.array(self.color_map(0.0)[:3]) * 255.0
vector_colors = np.repeat(color[None], T, axis=0)
# draw tracks
if self.tracks_leave_trace != 0:
for t in range(query_frame + 1, T):
first_ind = (
max(0, t - self.tracks_leave_trace)
if self.tracks_leave_trace >= 0
else 0
)
curr_tracks = tracks[first_ind : t + 1]
curr_colors = vector_colors[first_ind : t + 1]
if compensate_for_camera_motion:
diff = (
tracks[first_ind : t + 1, segm_mask <= 0]
- tracks[t : t + 1, segm_mask <= 0]
).mean(1)[:, None]
curr_tracks = curr_tracks - diff
curr_tracks = curr_tracks[:, segm_mask > 0]
curr_colors = curr_colors[:, segm_mask > 0]
res_video[t] = self._draw_pred_tracks(
res_video[t],
curr_tracks,
curr_colors,
)
if gt_tracks is not None:
res_video[t] = self._draw_gt_tracks(
res_video[t], gt_tracks[first_ind : t + 1]
)
# draw points
for t in range(T):
img = Image.fromarray(np.uint8(res_video[t]))
for i in range(N):
coord = (tracks[t, i, 0], tracks[t, i, 1])
                visible = True
                if visibility is not None:
                    visible = visibility[0, t, i]
if coord[0] != 0 and coord[1] != 0:
if not compensate_for_camera_motion or (
compensate_for_camera_motion and segm_mask[i] > 0
):
img = draw_circle(
img,
coord=coord,
radius=int(self.linewidth * 2),
color=vector_colors[t, i].astype(int),
                            visible=visible,
color_alpha=color_alpha,
)
res_video[t] = np.array(img)
# construct the final rgb sequence
if self.show_first_frame > 0:
res_video = [res_video[0]] * self.show_first_frame + res_video[1:]
return torch.from_numpy(np.stack(res_video)).permute(0, 3, 1, 2)[None].byte()
def _draw_pred_tracks(
self,
rgb: np.ndarray, # H x W x 3
tracks: np.ndarray, # T x 2
vector_colors: np.ndarray,
alpha: float = 0.5,
):
T, N, _ = tracks.shape
rgb = Image.fromarray(np.uint8(rgb))
for s in range(T - 1):
vector_color = vector_colors[s]
original = rgb.copy()
alpha = (s / T) ** 2
for i in range(N):
coord_y = (int(tracks[s, i, 0]), int(tracks[s, i, 1]))
coord_x = (int(tracks[s + 1, i, 0]), int(tracks[s + 1, i, 1]))
if coord_y[0] != 0 and coord_y[1] != 0:
rgb = draw_line(
rgb,
coord_y,
coord_x,
vector_color[i].astype(int),
self.linewidth,
)
if self.tracks_leave_trace > 0:
rgb = Image.fromarray(
np.uint8(
add_weighted(
np.array(rgb), alpha, np.array(original), 1 - alpha, 0
)
)
)
rgb = np.array(rgb)
return rgb
def _draw_gt_tracks(
self,
rgb: np.ndarray, # H x W x 3,
gt_tracks: np.ndarray, # T x 2
):
T, N, _ = gt_tracks.shape
color = np.array((211, 0, 0))
rgb = Image.fromarray(np.uint8(rgb))
for t in range(T):
for i in range(N):
                gt_track = gt_tracks[t][i]
                # draw a red cross
                if gt_track[0] > 0 and gt_track[1] > 0:
                    length = self.linewidth * 3
                    coord_y = (int(gt_track[0]) + length, int(gt_track[1]) + length)
                    coord_x = (int(gt_track[0]) - length, int(gt_track[1]) - length)
                    rgb = draw_line(
                        rgb,
                        coord_y,
                        coord_x,
                        color,
                        self.linewidth,
                    )
                    coord_y = (int(gt_track[0]) - length, int(gt_track[1]) + length)
                    coord_x = (int(gt_track[0]) + length, int(gt_track[1]) - length)
rgb = draw_line(
rgb,
coord_y,
coord_x,
color,
self.linewidth,
)
rgb = np.array(rgb)
return rgb
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
__version__ = "3.0.0"
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
import os
import torch
import argparse
import numpy as np
from PIL import Image
from cotracker.utils.visualizer import Visualizer, read_video_from_path
from cotracker.predictor import CoTrackerPredictor
DEFAULT_DEVICE = (
"cuda" if torch.cuda.is_available() else "mps" if torch.backends.mps.is_available() else "cpu"
)
# if DEFAULT_DEVICE == "mps":
# os.environ["PYTORCH_ENABLE_MPS_FALLBACK"] = "1"
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument(
"--video_path",
default="./assets/apple.mp4",
help="path to a video",
)
parser.add_argument(
"--mask_path",
default="./assets/apple_mask.png",
help="path to a segmentation mask",
)
parser.add_argument(
"--checkpoint",
# default="./checkpoints/cotracker.pth",
default=None,
help="CoTracker model parameters",
)
parser.add_argument("--grid_size", type=int, default=10, help="Regular grid size")
parser.add_argument(
"--grid_query_frame",
type=int,
default=0,
help="Compute dense and grid tracks starting from this frame",
)
parser.add_argument(
"--backward_tracking",
action="store_true",
help="Compute tracks in both directions, not only forward",
)
parser.add_argument(
"--use_v2_model",
action="store_true",
help="Pass it if you wish to use CoTracker2, CoTracker++ is the default now",
)
parser.add_argument(
"--offline",
action="store_true",
help="Pass it if you would like to use the offline model, in case of online don't pass it",
)
args = parser.parse_args()
# load the input video frame by frame
video = read_video_from_path(args.video_path)
video = torch.from_numpy(video).permute(0, 3, 1, 2)[None].float()
segm_mask = np.array(Image.open(os.path.join(args.mask_path)))
segm_mask = torch.from_numpy(segm_mask)[None, None]
if args.checkpoint is not None:
if args.use_v2_model:
model = CoTrackerPredictor(checkpoint=args.checkpoint, v2=args.use_v2_model)
else:
if args.offline:
window_len = 60
else:
window_len = 16
model = CoTrackerPredictor(
checkpoint=args.checkpoint,
v2=args.use_v2_model,
offline=args.offline,
window_len=window_len,
)
else:
model = torch.hub.load("facebookresearch/co-tracker", "cotracker3_offline")
model = model.to(DEFAULT_DEVICE)
video = video.to(DEFAULT_DEVICE)
pred_tracks, pred_visibility = model(
video,
grid_size=args.grid_size,
grid_query_frame=args.grid_query_frame,
backward_tracking=args.backward_tracking,
# segm_mask=segm_mask
)
print("computed")
# save a video with predicted tracks
seq_name = args.video_path.split("/")[-1]
vis = Visualizer(save_dir="./saved_videos", pad_value=120, linewidth=3)
vis.visualize(
video,
pred_tracks,
pred_visibility,
query_frame=0 if args.backward_tracking else args.grid_query_frame,
)
SPHINXOPTS ?=
SPHINXBUILD ?= sphinx-build
SOURCEDIR = source
BUILDDIR = _build
O = -a
help:
@$(SPHINXBUILD) -M help "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O)
.PHONY: help Makefile
%: Makefile
@$(SPHINXBUILD) -M $@ "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O)
Models
======
CoTracker models:
.. currentmodule:: cotracker.models
Model Utils
-----------
.. automodule:: cotracker.models.core.model_utils
:members:
:undoc-members:
:show-inheritance: