"vscode:/vscode.git/clone" did not exist on "74788b487c3c3aa48715dfeb700871c61a63836d"
Commit 0063a668 authored by chenzk

v1.0

Metadata-Version: 2.2
Name: cotracker
Version: 3.0
License-File: LICENSE.md
Provides-Extra: all
Requires-Dist: matplotlib; extra == "all"
Provides-Extra: dev
Requires-Dist: flake8; extra == "dev"
Requires-Dist: black; extra == "dev"
Dynamic: provides-extra
CODE_OF_CONDUCT.md
CONTRIBUTING.md
LICENSE.md
README.md
demo.py
hubconf.py
launch_training_kubric_offline.sh
launch_training_kubric_online.sh
launch_training_scaling_offline.sh
launch_training_scaling_online.sh
online_demo.py
setup.py
train_on_kubric.py
train_on_real_data.py
assets/apple.mp4
assets/apple_mask.png
assets/bmx-bumps.gif
assets/teaser.png
cotracker/__init__.py
cotracker/predictor.py
cotracker/version.py
cotracker.egg-info/PKG-INFO
cotracker.egg-info/SOURCES.txt
cotracker.egg-info/dependency_links.txt
cotracker.egg-info/requires.txt
cotracker.egg-info/top_level.txt
cotracker/datasets/__init__.py
cotracker/datasets/dataclass_utils.py
cotracker/datasets/dr_dataset.py
cotracker/datasets/kubric_movif_dataset.py
cotracker/datasets/real_dataset.py
cotracker/datasets/tap_vid_datasets.py
cotracker/datasets/utils.py
cotracker/evaluation/__init__.py
cotracker/evaluation/evaluate.py
cotracker/evaluation/configs/eval_dynamic_replica.yaml
cotracker/evaluation/configs/eval_tapvid_davis_first.yaml
cotracker/evaluation/configs/eval_tapvid_davis_strided.yaml
cotracker/evaluation/configs/eval_tapvid_kinetics_first.yaml
cotracker/evaluation/configs/eval_tapvid_robotap_first.yaml
cotracker/evaluation/configs/eval_tapvid_stacking_first.yaml
cotracker/evaluation/configs/eval_tapvid_stacking_strided.yaml
cotracker/evaluation/core/__init__.py
cotracker/evaluation/core/eval_utils.py
cotracker/evaluation/core/evaluator.py
cotracker/models/__init__.py
cotracker/models/bootstap_predictor.py
cotracker/models/build_cotracker.py
cotracker/models/evaluation_predictor.py
cotracker/models/core/__init__.py
cotracker/models/core/embeddings.py
cotracker/models/core/model_utils.py
cotracker/models/core/cotracker/__init__.py
cotracker/models/core/cotracker/blocks.py
cotracker/models/core/cotracker/cotracker.py
cotracker/models/core/cotracker/cotracker3_offline.py
cotracker/models/core/cotracker/cotracker3_online.py
cotracker/models/core/cotracker/losses.py
cotracker/utils/__init__.py
cotracker/utils/train_utils.py
cotracker/utils/visualizer.py
docs/Makefile
docs/source/conf.py
docs/source/index.rst
docs/source/references.bib
docs/source/apis/models.rst
docs/source/apis/utils.rst
gradio_demo/app.py
gradio_demo/requirements.txt
gradio_demo/videos/apple.mp4
gradio_demo/videos/backpack.mp4
gradio_demo/videos/bear.mp4
gradio_demo/videos/cat.mp4
gradio_demo/videos/paragliding-launch.mp4
gradio_demo/videos/paragliding.mp4
gradio_demo/videos/pillow.mp4
gradio_demo/videos/teddy.mp4
notebooks/demo.ipynb
tests/test_bilinear_sample.py
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
import json
import dataclasses
import numpy as np
from dataclasses import Field, MISSING
from typing import IO, TypeVar, Type, get_args, get_origin, Union, Any, Tuple
_X = TypeVar("_X")
def load_dataclass(f: IO, cls: Type[_X], binary: bool = False) -> _X:
"""
Loads a @dataclass or a collection hierarchy that includes dataclasses
from json, recursively.
Call it like load_dataclass(f, typing.List[DynamicReplicaFrameAnnotation]).
Raises KeyError if the json has keys that do not map to the dataclass fields.
Args:
f: Either a path to a file, or a file opened for reading.
cls: The class of the loaded dataclass.
binary: Set to True if `f` was opened in binary mode.
"""
if binary:
asdict = json.loads(f.read().decode("utf8"))
else:
asdict = json.load(f)
# in the list case, run a faster "vectorized" version
cls = get_args(cls)[0]
res = list(_dataclass_list_from_dict_list(asdict, cls))
return res
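# Usage sketch (illustrative; the file path is hypothetical): DynamicReplicaDataset
# below calls this helper on a gzipped json of frame annotations, roughly like so:
#
#     import gzip
#     from typing import List
#     with gzip.open("frame_annotations_valid.jgz", "rt", encoding="utf8") as f:
#         annots = load_dataclass(f, List[DynamicReplicaFrameAnnotation])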
def _resolve_optional(type_: Any) -> Tuple[bool, Any]:
"""Check whether `type_` is equivalent to `typing.Optional[T]` for some T."""
if get_origin(type_) is Union:
args = get_args(type_)
if len(args) == 2 and args[1] == type(None): # noqa E721
return True, args[0]
if type_ is Any:
return True, Any
return False, type_
def _unwrap_type(tp):
# strips Optional wrapper, if any
if get_origin(tp) is Union:
args = get_args(tp)
if len(args) == 2 and any(a is type(None) for a in args): # noqa: E721
# this is typing.Optional
return args[0] if args[1] is type(None) else args[1] # noqa: E721
return tp
def _get_dataclass_field_default(field: Field) -> Any:
if field.default_factory is not MISSING:
# pyre-fixme[29]: `Union[dataclasses._MISSING_TYPE,
# dataclasses._DefaultFactory[typing.Any]]` is not a function.
return field.default_factory()
elif field.default is not MISSING:
return field.default
else:
return None
def _dataclass_list_from_dict_list(dlist, typeannot):
"""
Vectorised version of `_dataclass_from_dict`.
The output should be equivalent to
`[_dataclass_from_dict(d, typeannot) for d in dlist]`.
Args:
dlist: list of objects to convert.
typeannot: type of each of those objects.
Returns:
iterator or list over converted objects of the same length as `dlist`.
Raises:
ValueError: it assumes the objects have None's in consistent places across
objects, otherwise it would ignore some values. This generally holds for
auto-generated annotations, but otherwise use `_dataclass_from_dict`.
"""
cls = get_origin(typeannot) or typeannot
if typeannot is Any:
return dlist
if all(obj is None for obj in dlist): # 1st recursion base: all None nodes
return dlist
if any(obj is None for obj in dlist):
# filter out Nones and recurse on the resulting list
idx_notnone = [(i, obj) for i, obj in enumerate(dlist) if obj is not None]
idx, notnone = zip(*idx_notnone)
converted = _dataclass_list_from_dict_list(notnone, typeannot)
res = [None] * len(dlist)
for i, obj in zip(idx, converted):
res[i] = obj
return res
is_optional, contained_type = _resolve_optional(typeannot)
if is_optional:
return _dataclass_list_from_dict_list(dlist, contained_type)
# otherwise, we dispatch by the type of the provided annotation to convert to
if issubclass(cls, tuple) and hasattr(cls, "_fields"): # namedtuple
# For namedtuple, call the function recursively on the lists of corresponding keys
types = cls.__annotations__.values()
dlist_T = zip(*dlist)
res_T = [
_dataclass_list_from_dict_list(key_list, tp)
for key_list, tp in zip(dlist_T, types)
]
return [cls(*converted_as_tuple) for converted_as_tuple in zip(*res_T)]
elif issubclass(cls, (list, tuple)):
# For list/tuple, call the function recursively on the lists of corresponding positions
types = get_args(typeannot)
if len(types) == 1: # probably List; replicate for all items
types = types * len(dlist[0])
dlist_T = zip(*dlist)
res_T = (
_dataclass_list_from_dict_list(pos_list, tp)
for pos_list, tp in zip(dlist_T, types)
)
if issubclass(cls, tuple):
return list(zip(*res_T))
else:
return [cls(converted_as_tuple) for converted_as_tuple in zip(*res_T)]
elif issubclass(cls, dict):
# For a dictionary, call the function recursively on the concatenated keys and values
key_t, val_t = get_args(typeannot)
all_keys_res = _dataclass_list_from_dict_list(
[k for obj in dlist for k in obj.keys()], key_t
)
all_vals_res = _dataclass_list_from_dict_list(
[k for obj in dlist for k in obj.values()], val_t
)
indices = np.cumsum([len(obj) for obj in dlist])
assert indices[-1] == len(all_keys_res)
keys = np.split(list(all_keys_res), indices[:-1])
all_vals_res_iter = iter(all_vals_res)
return [cls(zip(k, all_vals_res_iter)) for k in keys]
elif not dataclasses.is_dataclass(typeannot):
return dlist
# dataclass node: 2nd recursion base; call the function recursively on the lists
# of the corresponding fields
assert dataclasses.is_dataclass(cls)
fieldtypes = {
f.name: (_unwrap_type(f.type), _get_dataclass_field_default(f))
for f in dataclasses.fields(typeannot)
}
# NOTE the default object is shared here
key_lists = (
_dataclass_list_from_dict_list([obj.get(k, default) for obj in dlist], type_)
for k, (type_, default) in fieldtypes.items()
)
transposed = zip(*key_lists)
return [cls(*vals_as_tuple) for vals_as_tuple in transposed]
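# Minimal worked example (illustrative) of the vectorized conversion above:
#
#     @dataclasses.dataclass
#     class Point:
#         x: float
#         y: float
#
#     _dataclass_list_from_dict_list([{"x": 0.0, "y": 1.0}, {"x": 2.0, "y": 3.0}], Point)
#     # -> [Point(x=0.0, y=1.0), Point(x=2.0, y=3.0)]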
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
import os
import gzip
import torch
import numpy as np
import torch.utils.data as data
from collections import defaultdict
from dataclasses import dataclass
from typing import List, Optional, Any, Dict, Tuple
from cotracker.datasets.utils import CoTrackerData
from cotracker.datasets.dataclass_utils import load_dataclass
@dataclass
class ImageAnnotation:
# path to jpg file, relative w.r.t. dataset_root
path: str
# H x W
size: Tuple[int, int]
@dataclass
class DynamicReplicaFrameAnnotation:
"""A dataclass used to load annotations from json."""
# can be used to join with `SequenceAnnotation`
sequence_name: str
# 0-based, continuous frame number within sequence
frame_number: int
# timestamp in seconds from the video start
frame_timestamp: float
image: ImageAnnotation
meta: Optional[Dict[str, Any]] = None
camera_name: Optional[str] = None
trajectories: Optional[Dict[str, Any]] = None
class DynamicReplicaDataset(data.Dataset):
def __init__(
self,
root,
split="valid",
traj_per_sample=256,
crop_size=None,
sample_len=-1,
only_first_n_samples=-1,
rgbd_input=False,
):
super(DynamicReplicaDataset, self).__init__()
self.root = root
self.sample_len = sample_len
self.split = split
self.traj_per_sample = traj_per_sample
self.rgbd_input = rgbd_input
self.crop_size = crop_size
frame_annotations_file = f"frame_annotations_{split}.jgz"
self.sample_list = []
with gzip.open(
os.path.join(root, split, frame_annotations_file), "rt", encoding="utf8"
) as zipfile:
frame_annots_list = load_dataclass(
zipfile, List[DynamicReplicaFrameAnnotation]
)
seq_annot = defaultdict(list)
for frame_annot in frame_annots_list:
if frame_annot.camera_name == "left":
seq_annot[frame_annot.sequence_name].append(frame_annot)
for seq_name in seq_annot.keys():
seq_len = len(seq_annot[seq_name])
step = self.sample_len if self.sample_len > 0 else seq_len
counter = 0
for ref_idx in range(0, seq_len, step):
sample = seq_annot[seq_name][ref_idx : ref_idx + step]
self.sample_list.append(sample)
counter += 1
if only_first_n_samples > 0 and counter >= only_first_n_samples:
break
def __len__(self):
return len(self.sample_list)
def crop(self, rgbs, trajs):
T, N, _ = trajs.shape
S = len(rgbs)
H, W = rgbs[0].shape[:2]
assert S == T
H_new = H
W_new = W
# simple center crop
y0 = 0 if self.crop_size[0] >= H_new else (H_new - self.crop_size[0]) // 2
x0 = 0 if self.crop_size[1] >= W_new else (W_new - self.crop_size[1]) // 2
rgbs = [
rgb[y0 : y0 + self.crop_size[0], x0 : x0 + self.crop_size[1]]
for rgb in rgbs
]
trajs[:, :, 0] -= x0
trajs[:, :, 1] -= y0
return rgbs, trajs
def __getitem__(self, index):
sample = self.sample_list[index]
T = len(sample)
rgbs, visibilities, traj_2d = [], [], []
H, W = sample[0].image.size
image_size = (H, W)
for i in range(T):
traj_path = os.path.join(
self.root, self.split, sample[i].trajectories["path"]
)
traj = torch.load(traj_path)
visibilities.append(traj["verts_inds_vis"].numpy())
rgbs.append(traj["img"].numpy())
traj_2d.append(traj["traj_2d"].numpy()[..., :2])
traj_2d = np.stack(traj_2d)
visibility = np.stack(visibilities)
T, N, D = traj_2d.shape
# subsample trajectories for augmentations
visible_inds_sampled = torch.randperm(N)[: self.traj_per_sample]
traj_2d = traj_2d[:, visible_inds_sampled]
visibility = visibility[:, visible_inds_sampled]
if self.crop_size is not None:
rgbs, traj_2d = self.crop(rgbs, traj_2d)
H, W, _ = rgbs[0].shape
image_size = self.crop_size
visibility[traj_2d[:, :, 0] > image_size[1] - 1] = False
visibility[traj_2d[:, :, 0] < 0] = False
visibility[traj_2d[:, :, 1] > image_size[0] - 1] = False
visibility[traj_2d[:, :, 1] < 0] = False
# filter out points that are visible for fewer than 10 frames
visible_inds_resampled = visibility.sum(0) > 10
traj_2d = torch.from_numpy(traj_2d[:, visible_inds_resampled])
visibility = torch.from_numpy(visibility[:, visible_inds_resampled])
rgbs = np.stack(rgbs, 0)
video = torch.from_numpy(rgbs).reshape(T, H, W, 3).permute(0, 3, 1, 2).float()
return CoTrackerData(
video=video,
trajectory=traj_2d,
visibility=visibility,
valid=torch.ones(T, N),
seq_name=sample[0].sequence_name,
)
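# Usage sketch (illustrative; the root path is hypothetical and must follow the
# Dynamic Replica layout with frame_annotations_<split>.jgz under each split):
#
#     dataset = DynamicReplicaDataset(
#         root="/path/to/dynamic_replica", split="valid", sample_len=24
#     )
#     sample = dataset[0]  # CoTrackerData with video, trajectory, visibility, valid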
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
import os
import torch
import cv2
import imageio
import numpy as np
from cotracker.datasets.utils import CoTrackerData
from torchvision.transforms import ColorJitter, GaussianBlur
from PIL import Image
from cotracker.models.core.model_utils import smart_cat
class CoTrackerDataset(torch.utils.data.Dataset):
def __init__(
self,
data_root,
crop_size=(384, 512),
seq_len=24,
traj_per_sample=768,
sample_vis_last_frame=False,
use_augs=False,
):
super(CoTrackerDataset, self).__init__()
np.random.seed(0)
torch.manual_seed(0)
self.data_root = data_root
self.seq_len = seq_len
self.traj_per_sample = traj_per_sample
self.sample_vis_last_frame = sample_vis_last_frame
self.use_augs = use_augs
self.crop_size = crop_size
# photometric augmentation
self.photo_aug = ColorJitter(
brightness=0.2, contrast=0.2, saturation=0.2, hue=0.25 / 3.14
)
self.blur_aug = GaussianBlur(11, sigma=(0.1, 2.0))
self.blur_aug_prob = 0.25
self.color_aug_prob = 0.25
# occlusion augmentation
self.eraser_aug_prob = 0.5
self.eraser_bounds = [2, 100]
self.eraser_max = 10
# occlusion augmentation
self.replace_aug_prob = 0.5
self.replace_bounds = [2, 100]
self.replace_max = 10
# spatial augmentations
self.pad_bounds = [0, 100]
self.crop_size = crop_size
self.resize_lim = [0.25, 2.0] # sample resizes from here
self.resize_delta = 0.2
self.max_crop_offset = 50
self.do_flip = True
self.h_flip_prob = 0.5
self.v_flip_prob = 0.5
def getitem_helper(self, index):
raise NotImplementedError
def __getitem__(self, index):
gotit = False
sample, gotit = self.getitem_helper(index)
if not gotit:
print("warning: sampling failed")
# fake sample, so we can still collate
sample = CoTrackerData(
video=torch.zeros(
(self.seq_len, 3, self.crop_size[0], self.crop_size[1])
),
trajectory=torch.zeros((self.seq_len, self.traj_per_sample, 2)),
visibility=torch.zeros((self.seq_len, self.traj_per_sample)),
valid=torch.zeros((self.seq_len, self.traj_per_sample)),
# dataset_name="kubric",
)
return sample, gotit
def add_photometric_augs(self, rgbs, trajs, visibles, eraser=True, replace=True):
T, N, _ = trajs.shape
S = len(rgbs)
H, W = rgbs[0].shape[:2]
assert S == T
if eraser:
############ eraser transform (per image after the first) ############
rgbs = [rgb.astype(np.float32) for rgb in rgbs]
for i in range(1, S):
if np.random.rand() < self.eraser_aug_prob:
for _ in range(
np.random.randint(1, self.eraser_max + 1)
): # number of times to occlude
xc = np.random.randint(0, W)
yc = np.random.randint(0, H)
dx = np.random.randint(
self.eraser_bounds[0], self.eraser_bounds[1]
)
dy = np.random.randint(
self.eraser_bounds[0], self.eraser_bounds[1]
)
x0 = np.clip(xc - dx / 2, 0, W - 1).round().astype(np.int32)
x1 = np.clip(xc + dx / 2, 0, W - 1).round().astype(np.int32)
y0 = np.clip(yc - dy / 2, 0, H - 1).round().astype(np.int32)
y1 = np.clip(yc + dy / 2, 0, H - 1).round().astype(np.int32)
mean_color = np.mean(
rgbs[i][y0:y1, x0:x1, :].reshape(-1, 3), axis=0
)
rgbs[i][y0:y1, x0:x1, :] = mean_color
occ_inds = np.logical_and(
np.logical_and(trajs[i, :, 0] >= x0, trajs[i, :, 0] < x1),
np.logical_and(trajs[i, :, 1] >= y0, trajs[i, :, 1] < y1),
)
visibles[i, occ_inds] = 0
rgbs = [rgb.astype(np.uint8) for rgb in rgbs]
if replace:
rgbs_alt = [
np.array(self.photo_aug(Image.fromarray(rgb)), dtype=np.uint8)
for rgb in rgbs
]
rgbs_alt = [
np.array(self.photo_aug(Image.fromarray(rgb)), dtype=np.uint8)
for rgb in rgbs_alt
]
############ replace transform (per image after the first) ############
rgbs = [rgb.astype(np.float32) for rgb in rgbs]
rgbs_alt = [rgb.astype(np.float32) for rgb in rgbs_alt]
for i in range(1, S):
if np.random.rand() < self.replace_aug_prob:
for _ in range(
np.random.randint(1, self.replace_max + 1)
): # number of times to occlude
xc = np.random.randint(0, W)
yc = np.random.randint(0, H)
dx = np.random.randint(
self.replace_bounds[0], self.replace_bounds[1]
)
dy = np.random.randint(
self.replace_bounds[0], self.replace_bounds[1]
)
x0 = np.clip(xc - dx / 2, 0, W - 1).round().astype(np.int32)
x1 = np.clip(xc + dx / 2, 0, W - 1).round().astype(np.int32)
y0 = np.clip(yc - dy / 2, 0, H - 1).round().astype(np.int32)
y1 = np.clip(yc + dy / 2, 0, H - 1).round().astype(np.int32)
wid = x1 - x0
hei = y1 - y0
y00 = np.random.randint(0, H - hei)
x00 = np.random.randint(0, W - wid)
fr = np.random.randint(0, S)
rep = rgbs_alt[fr][y00 : y00 + hei, x00 : x00 + wid, :]
rgbs[i][y0:y1, x0:x1, :] = rep
occ_inds = np.logical_and(
np.logical_and(trajs[i, :, 0] >= x0, trajs[i, :, 0] < x1),
np.logical_and(trajs[i, :, 1] >= y0, trajs[i, :, 1] < y1),
)
visibles[i, occ_inds] = 0
rgbs = [rgb.astype(np.uint8) for rgb in rgbs]
############ photometric augmentation ############
if np.random.rand() < self.color_aug_prob:
# random per-frame amount of aug
rgbs = [
np.array(self.photo_aug(Image.fromarray(rgb)), dtype=np.uint8)
for rgb in rgbs
]
if np.random.rand() < self.blur_aug_prob:
# random per-frame amount of blur
rgbs = [
np.array(self.blur_aug(Image.fromarray(rgb)), dtype=np.uint8)
for rgb in rgbs
]
return rgbs, trajs, visibles
def add_spatial_augs(self, rgbs, trajs, visibles, crop_size):
T, N, __ = trajs.shape
S = len(rgbs)
H, W = rgbs[0].shape[:2]
assert S == T
rgbs = [rgb.astype(np.float32) for rgb in rgbs]
############ spatial transform ############
# padding
pad_x0 = np.random.randint(self.pad_bounds[0], self.pad_bounds[1])
pad_x1 = np.random.randint(self.pad_bounds[0], self.pad_bounds[1])
pad_y0 = np.random.randint(self.pad_bounds[0], self.pad_bounds[1])
pad_y1 = np.random.randint(self.pad_bounds[0], self.pad_bounds[1])
rgbs = [
np.pad(rgb, ((pad_y0, pad_y1), (pad_x0, pad_x1), (0, 0))) for rgb in rgbs
]
trajs[:, :, 0] += pad_x0
trajs[:, :, 1] += pad_y0
H, W = rgbs[0].shape[:2]
# scaling + stretching
scale = np.random.uniform(self.resize_lim[0], self.resize_lim[1])
scale_x = scale
scale_y = scale
H_new = H
W_new = W
scale_delta_x = 0.0
scale_delta_y = 0.0
rgbs_scaled = []
for s in range(S):
if s == 1:
scale_delta_x = np.random.uniform(-self.resize_delta, self.resize_delta)
scale_delta_y = np.random.uniform(-self.resize_delta, self.resize_delta)
elif s > 1:
scale_delta_x = (
scale_delta_x * 0.8
+ np.random.uniform(-self.resize_delta, self.resize_delta) * 0.2
)
scale_delta_y = (
scale_delta_y * 0.8
+ np.random.uniform(-self.resize_delta, self.resize_delta) * 0.2
)
scale_x = scale_x + scale_delta_x
scale_y = scale_y + scale_delta_y
# bring h/w closer
scale_xy = (scale_x + scale_y) * 0.5
scale_x = scale_x * 0.5 + scale_xy * 0.5
scale_y = scale_y * 0.5 + scale_xy * 0.5
# don't get too crazy
scale_x = np.clip(scale_x, 0.2, 2.0)
scale_y = np.clip(scale_y, 0.2, 2.0)
H_new = int(H * scale_y)
W_new = int(W * scale_x)
# make it at least slightly bigger than the crop area,
# so that the random cropping can add diversity
H_new = np.clip(H_new, crop_size[0] + 10, None)
W_new = np.clip(W_new, crop_size[1] + 10, None)
# recompute scale in case we clipped
scale_x = (W_new - 1) / float(W - 1)
scale_y = (H_new - 1) / float(H - 1)
rgbs_scaled.append(
cv2.resize(rgbs[s], (W_new, H_new), interpolation=cv2.INTER_LINEAR)
)
trajs[s, :, 0] *= scale_x
trajs[s, :, 1] *= scale_y
rgbs = rgbs_scaled
ok_inds = visibles[0, :] > 0
vis_trajs = trajs[:, ok_inds] # S,?,2
if vis_trajs.shape[1] > 0:
mid_x = np.mean(vis_trajs[0, :, 0])
mid_y = np.mean(vis_trajs[0, :, 1])
else:
mid_y = crop_size[0]
mid_x = crop_size[1]
x0 = int(mid_x - crop_size[1] // 2)
y0 = int(mid_y - crop_size[0] // 2)
offset_x = 0
offset_y = 0
for s in range(S):
# on each frame, shift a bit more
if s == 1:
offset_x = np.random.randint(
-self.max_crop_offset, self.max_crop_offset
)
offset_y = np.random.randint(
-self.max_crop_offset, self.max_crop_offset
)
elif s > 1:
offset_x = int(
offset_x * 0.8
+ np.random.randint(-self.max_crop_offset, self.max_crop_offset + 1)
* 0.2
)
offset_y = int(
offset_y * 0.8
+ np.random.randint(-self.max_crop_offset, self.max_crop_offset + 1)
* 0.2
)
x0 = x0 + offset_x
y0 = y0 + offset_y
H_new, W_new = rgbs[s].shape[:2]
if H_new == crop_size[0]:
y0 = 0
else:
y0 = min(max(0, y0), H_new - crop_size[0] - 1)
if W_new == crop_size[1]:
x0 = 0
else:
x0 = min(max(0, x0), W_new - crop_size[1] - 1)
rgbs[s] = rgbs[s][y0 : y0 + crop_size[0], x0 : x0 + crop_size[1]]
trajs[s, :, 0] -= x0
trajs[s, :, 1] -= y0
H_new = crop_size[0]
W_new = crop_size[1]
# flip
h_flipped = False
v_flipped = False
if self.do_flip:
# h flip
if np.random.rand() < self.h_flip_prob:
h_flipped = True
rgbs = [rgb[:, ::-1] for rgb in rgbs]
# v flip
if np.random.rand() < self.v_flip_prob:
v_flipped = True
rgbs = [rgb[::-1] for rgb in rgbs]
if h_flipped:
trajs[:, :, 0] = W_new - trajs[:, :, 0]
if v_flipped:
trajs[:, :, 1] = H_new - trajs[:, :, 1]
return np.stack(rgbs), trajs
def crop(self, rgbs, trajs, crop_size):
T, N, _ = trajs.shape
S = len(rgbs)
H, W = rgbs[0].shape[:2]
assert S == T
############ spatial transform ############
H_new = H
W_new = W
# crop: vertical center, random horizontal offset
y0 = 0 if crop_size[0] >= H_new else (H_new - crop_size[0]) // 2
x0 = 0 if crop_size[1] >= W_new else np.random.randint(0, W_new - crop_size[1])
rgbs = [rgb[y0 : y0 + crop_size[0], x0 : x0 + crop_size[1]] for rgb in rgbs]
trajs[:, :, 0] -= x0
trajs[:, :, 1] -= y0
return np.stack(rgbs), trajs
class KubricMovifDataset(CoTrackerDataset):
def __init__(
self,
data_root,
crop_size=(384, 512),
seq_len=24,
traj_per_sample=768,
sample_vis_last_frame=False,
use_augs=False,
random_seq_len=False,
random_frame_rate=False,
random_number_traj=False,
split="train",
):
super(KubricMovifDataset, self).__init__(
data_root=data_root,
crop_size=crop_size,
seq_len=seq_len,
traj_per_sample=traj_per_sample,
sample_vis_last_frame=sample_vis_last_frame,
use_augs=use_augs,
)
self.random_seq_len = random_seq_len
self.random_frame_rate = random_frame_rate
self.random_number_traj = random_number_traj
self.pad_bounds = [0, 25]
self.resize_lim = [0.75, 1.25] # sample resizes from here
self.resize_delta = 0.05
self.max_crop_offset = 15
self.split = split
self.seq_names = [
fname
for fname in os.listdir(data_root)
if os.path.isdir(os.path.join(data_root, fname))
]
if self.split == "valid":
self.seq_names = self.seq_names[:30]
assert not use_augs
print("found %d unique videos in %s" % (len(self.seq_names), self.data_root))
def getitem_helper(self, index):
gotit = True
seq_name = self.seq_names[index]
npy_path = os.path.join(self.data_root, seq_name, seq_name + ".npy")
rgb_path = os.path.join(self.data_root, seq_name, "frames")
img_paths = sorted(os.listdir(rgb_path))
rgbs = []
for i, img_path in enumerate(img_paths):
rgbs.append(imageio.v2.imread(os.path.join(rgb_path, img_path)))
rgbs = np.stack(rgbs)
annot_dict = np.load(npy_path, allow_pickle=True).item()
traj_2d = annot_dict["coords"]
visibility = annot_dict["visibility"]
frame_rate = 1
final_num_traj = self.traj_per_sample
crop_size = self.crop_size
# random crop
min_num_traj = 1
assert self.traj_per_sample >= min_num_traj
if self.random_seq_len and self.random_number_traj:
final_num_traj = np.random.randint(min_num_traj, self.traj_per_sample)
alpha = final_num_traj / float(self.traj_per_sample)
seq_len = int(alpha * 10 + (1 - alpha) * self.seq_len)
seq_len = np.random.randint(seq_len - 2, seq_len + 2)
if self.random_frame_rate:
frame_rate = np.random.randint(1, int((120 / seq_len)) + 1)
elif self.random_number_traj:
final_num_traj = np.random.randint(min_num_traj, self.traj_per_sample)
alpha = final_num_traj / float(self.traj_per_sample)
seq_len = 8 * int(alpha * 2 + (1 - alpha) * self.seq_len // 8)
# seq_len = np.random.randint(seq_len , seq_len + 2)
if self.random_frame_rate:
frame_rate = np.random.randint(1, int((120 / seq_len)) + 1)
elif self.random_seq_len:
seq_len = np.random.randint(int(self.seq_len / 2), self.seq_len)
if self.random_frame_rate:
frame_rate = np.random.randint(1, int((120 / seq_len)) + 1)
else:
seq_len = self.seq_len
if self.random_frame_rate:
frame_rate = np.random.randint(1, int((120 / seq_len)) + 1)
traj_2d = np.transpose(traj_2d, (1, 0, 2))
visibility = np.transpose(np.logical_not(visibility), (1, 0))
no_augs = False
if seq_len < len(rgbs):
if seq_len * frame_rate < len(rgbs):
start_ind = np.random.choice(len(rgbs) - (seq_len * frame_rate), 1)[0]
else:
start_ind = 0
rgbs = rgbs[start_ind : start_ind + seq_len * frame_rate : frame_rate]
traj_2d = traj_2d[start_ind : start_ind + seq_len * frame_rate : frame_rate]
visibility = visibility[
start_ind : start_ind + seq_len * frame_rate : frame_rate
]
assert seq_len <= len(rgbs)
if not no_augs:
if self.use_augs:
rgbs, traj_2d, visibility = self.add_photometric_augs(
rgbs, traj_2d, visibility, replace=False
)
rgbs, traj_2d = self.add_spatial_augs(
rgbs, traj_2d, visibility, crop_size
)
else:
rgbs, traj_2d = self.crop(rgbs, traj_2d, crop_size)
visibility[traj_2d[:, :, 0] > crop_size[1] - 1] = False
visibility[traj_2d[:, :, 0] < 0] = False
visibility[traj_2d[:, :, 1] > crop_size[0] - 1] = False
visibility[traj_2d[:, :, 1] < 0] = False
visibility = torch.from_numpy(visibility)
traj_2d = torch.from_numpy(traj_2d)
crop_tensor = torch.tensor(crop_size).flip(0)[None, None] / 2.0
close_pts_inds = torch.all(
torch.linalg.vector_norm(traj_2d[..., :2] - crop_tensor, dim=-1) < 1000.0,
dim=0,
)
traj_2d = traj_2d[:, close_pts_inds]
visibility = visibility[:, close_pts_inds]
visibile_pts_first_frame_inds = (visibility[0]).nonzero(as_tuple=False)[:, 0]
visibile_pts_mid_frame_inds = (visibility[seq_len // 2]).nonzero(
as_tuple=False
)[:, 0]
visibile_pts_inds = torch.cat(
(visibile_pts_first_frame_inds, visibile_pts_mid_frame_inds), dim=0
)
if self.sample_vis_last_frame:
visibile_pts_last_frame_inds = (visibility[seq_len - 1]).nonzero(
as_tuple=False
)[:, 0]
visibile_pts_inds = torch.cat(
(visibile_pts_inds, visibile_pts_last_frame_inds), dim=0
)
point_inds = torch.randperm(len(visibile_pts_inds))[: self.traj_per_sample]
if len(point_inds) < self.traj_per_sample:
gotit = False
visible_inds_sampled = visibile_pts_inds[point_inds]
trajs = traj_2d[:, visible_inds_sampled].float()
visibles = visibility[:, visible_inds_sampled]
valids = torch.ones_like(visibles)
trajs = trajs[:, :final_num_traj]
visibles = visibles[:, :final_num_traj]
valids = valids[:, :final_num_traj]
rgbs = torch.from_numpy(rgbs).permute(0, 3, 1, 2).float()
sample = CoTrackerData(
video=rgbs,
trajectory=trajs,
visibility=visibles,
valid=valids,
seq_name=seq_name,
)
return sample, gotit
def __len__(self):
return len(self.seq_names)
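# Usage sketch (illustrative; data_root is hypothetical and must follow the Kubric
# MOVi-f export layout with <seq>/<seq>.npy annotations and <seq>/frames/ images):
#
#     from torch.utils.data import DataLoader
#     from cotracker.datasets.utils import collate_fn_train
#
#     train_set = KubricMovifDataset(
#         data_root="/path/to/kubric_movi_f", seq_len=24, traj_per_sample=768, use_augs=True
#     )
#     loader = DataLoader(train_set, batch_size=1, collate_fn=collate_fn_train)
#     batch, gotit = next(iter(loader))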
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
import os
import torch
import json
import cv2
import math
import imageio
import numpy as np
from cotracker.datasets.utils import CoTrackerData
from torchvision.transforms import ColorJitter, GaussianBlur
from PIL import Image
from cotracker.models.core.model_utils import smart_cat
from torchvision.io import read_video
import torchvision
from cotracker.datasets.utils import collate_fn, collate_fn_train, dataclass_to_cuda_
import torchvision.transforms.functional as F
class RealDataset(torch.utils.data.Dataset):
def __init__(
self,
crop_size=(384, 512),
seq_len=24,
traj_per_sample=768,
random_frame_rate=False,
random_seq_len=False,
data_splits=[0],
random_resize=False,
limit_samples=10000,
):
super(RealDataset, self).__init__()
np.random.seed(0)
torch.manual_seed(0)
raise ValueError("This dataset wasn't released. You should collect your own dataset of real videos before training with this dataset class.")
stopwords = set(
[
"river",
"water",
"shore",
"lake",
"sea",
"ocean",
"silhouette",
"matte",
"online",
"virtual",
"meditation",
"artwork",
"drawing",
"animation",
"abstract",
"background",
"concept",
"cartoon",
"symbolic",
"painting",
"sketch",
"fireworks",
"fire",
"sky",
"darkness",
"timelapse",
"time-lapse",
"cgi",
"computer",
"computer-generated",
"drawing",
"draw",
"cgi",
"animate",
"cartoon",
"static",
"abstract",
"abstraction",
"3d",
"fandom",
"fantasy",
"graphics",
"cell",
"holographic",
"generated",
"generation" "telephoto",
"animated",
"disko",
"generate" "2d",
"3d",
"geometric",
"geometry",
"render",
"rendering",
"timelapse",
"slomo",
"slo",
"wallpaper",
"pattern",
"tile",
"generated",
"chroma",
"www",
"http",
"cannabis",
"loop",
"cycle",
"alpha",
"abstract",
"concept",
"digital",
"graphic",
"skies",
"fountain",
"train",
"rapid",
"fast",
"quick",
"vfx",
"effect",
]
)
def no_stopwords_in_key(key, stopwords):
for s in stopwords:
if s in key.split(","):
return False
return True
filelist_all = []
for part in data_splits:
filelist = np.load('YOUR FILELIST')
captions = np.load('YOUR CAPTIONS')
keywords = np.load('YOUR KEYWORDS')
filtered_seqs_motion = [
i
for i, key in enumerate(keywords)
if "motion" in key.split(",")
and (
"man" in key.split(",")
or "woman" in key.split(",")
or "animal" in key.split(",")
or "child" in key.split(",")
)
and no_stopwords_in_key(key, stopwords)
]
print("filtered_seqs_motion", len(filtered_seqs_motion))
filtered_seqs = filtered_seqs_motion
print(f"filtered_seqs {part}", len(filtered_seqs))
filelist_all = filelist_all + filelist[filtered_seqs].tolist()
if len(filelist_all) > limit_samples:
break
self.filelist = filelist_all[:limit_samples]
print(f"found {len(self.filelist)} unique videos")
self.traj_per_sample = traj_per_sample
self.crop_size = crop_size
self.seq_len = seq_len
self.random_frame_rate = random_frame_rate
self.random_resize = random_resize
self.random_seq_len = random_seq_len
def crop(self, rgbs):
S = len(rgbs)
H, W = rgbs.shape[2:]
H_new = H
W_new = W
# simple random crop
y0 = (
0
if self.crop_size[0] >= H_new
else np.random.randint(0, H_new - self.crop_size[0])
)
x0 = (
0
if self.crop_size[1] >= W_new
else np.random.randint(0, W_new - self.crop_size[1])
)
rgbs = [
rgb[:, y0 : y0 + self.crop_size[0], x0 : x0 + self.crop_size[1]]
for rgb in rgbs
]
return torch.stack(rgbs)
def __getitem__(self, index):
gotit = False
sample, gotit = self.getitem_helper(index)
if not gotit:
print("warning: sampling failed")
# fake sample, so we can still collate
sample = CoTrackerData(
video=torch.zeros(
(self.seq_len, 3, self.crop_size[0], self.crop_size[1])
),
trajectory=torch.ones(1, 1, 1, 2),
visibility=torch.ones(1, 1, 1),
valid=torch.ones(1, 1, 1),
)
return sample, gotit
def sample_h_w(self):
area = np.random.uniform(0.6, 1)
a1 = np.random.uniform(area, 1)
a2 = np.random.uniform(area, 1)
h = (a1 + a2) / 2.0
w = area / h
return h, w
def getitem_helper(self, index):
gotit = True
video_path = self.filelist[index]
rgbs, _, _ = read_video(str(video_path), output_format="TCHW", pts_unit="sec")
if rgbs.numel() == 0:
return None, False
seq_name = video_path
frame_rate = 1
if self.random_seq_len:
seq_len = np.random.randint(int(self.seq_len / 2), self.seq_len)
else:
seq_len = self.seq_len
while len(rgbs) < seq_len:
rgbs = torch.cat([rgbs, rgbs.flip(0)])
if seq_len < 8:
print("seq_len < 8, return NONE")
return None, False
if self.random_frame_rate:
max_frame_rate = min(4, int((len(rgbs) / seq_len)))
if max_frame_rate > 1:
frame_rate = np.random.randint(1, max_frame_rate)
if seq_len * frame_rate < len(rgbs):
start_ind = np.random.choice(len(rgbs) - (seq_len * frame_rate), 1)[0]
else:
start_ind = 0
rgbs = rgbs[start_ind : start_ind + seq_len * frame_rate : frame_rate]
assert seq_len <= len(rgbs)
if self.random_resize and np.random.rand() < 0.5:
video = []
rgbs = rgbs.permute(0, 2, 3, 1).numpy()
for i in range(len(rgbs)):
rgb = cv2.resize(
rgbs[i],
(self.crop_size[1], self.crop_size[0]),
interpolation=cv2.INTER_LINEAR,
)
video.append(rgb)
video = torch.tensor(np.stack(video)).permute(0, 3, 1, 2)
else:
video = self.crop(rgbs)
sample = CoTrackerData(
video=video,
trajectory=torch.ones(seq_len, self.traj_per_sample, 2),
visibility=torch.ones(seq_len, self.traj_per_sample),
valid=torch.ones(seq_len, self.traj_per_sample),
seq_name=seq_name,
)
return sample, gotit
def __len__(self):
return len(self.filelist)
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
import os
import io
import glob
import torch
import pickle
import numpy as np
import mediapy as media
import random
from PIL import Image
from typing import Mapping, Tuple, Union
from cotracker.datasets.utils import CoTrackerData
DatasetElement = Mapping[str, Mapping[str, Union[np.ndarray, str]]]
def resize_video(video: np.ndarray, output_size: Tuple[int, int]) -> np.ndarray:
"""Resize a video to output_size."""
# If you have a GPU, consider replacing this with a GPU-enabled resize op,
# such as a jitted jax.image.resize. It will make things faster.
return media.resize_video(video, output_size)
def sample_queries_first(
target_occluded: np.ndarray,
target_points: np.ndarray,
frames: np.ndarray,
) -> Mapping[str, np.ndarray]:
"""Package a set of frames and tracks for use in TAPNet evaluations.
Given a set of frames and tracks with no query points, use the first
visible point in each track as the query.
Args:
target_occluded: Boolean occlusion flag, of shape [n_tracks, n_frames],
where True indicates occluded.
target_points: Position, of shape [n_tracks, n_frames, 2], where each point
is [x,y] scaled between 0 and 1.
frames: Video tensor, of shape [n_frames, height, width, 3]. Scaled between
-1 and 1.
Returns:
A dict with the keys:
video: Video tensor of shape [1, n_frames, height, width, 3]
query_points: Query points of shape [1, n_queries, 3] where
each point is [t, y, x] scaled to the range [-1, 1]
target_points: Target points of shape [1, n_queries, n_frames, 2] where
each point is [x, y] scaled to the range [-1, 1]
"""
valid = np.sum(~target_occluded, axis=1) > 0
target_points = target_points[valid, :]
target_occluded = target_occluded[valid, :]
query_points = []
for i in range(target_points.shape[0]):
index = np.where(target_occluded[i] == 0)[0][0]
x, y = target_points[i, index, 0], target_points[i, index, 1]
query_points.append(np.array([index, y, x])) # [t, y, x]
query_points = np.stack(query_points, axis=0)
return {
"video": frames[np.newaxis, ...],
"query_points": query_points[np.newaxis, ...],
"target_points": target_points[np.newaxis, ...],
"occluded": target_occluded[np.newaxis, ...],
}
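# Minimal worked example (illustrative) for sample_queries_first: with one track
# that becomes visible at frame 1, the query is taken from that frame.
#
#     occ = np.array([[True, False, False]])                   # [n_tracks=1, n_frames=3]
#     pts = np.array([[[0.1, 0.2], [0.3, 0.4], [0.5, 0.6]]])   # [n_tracks, n_frames, 2]
#     frames = np.zeros((3, 4, 4, 3))
#     out = sample_queries_first(occ, pts, frames)
#     # out["query_points"][0, 0] == [1.0, 0.4, 0.3], i.e. [t, y, x] at the first visible frame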
def sample_queries_strided(
target_occluded: np.ndarray,
target_points: np.ndarray,
frames: np.ndarray,
query_stride: int = 5,
) -> Mapping[str, np.ndarray]:
"""Package a set of frames and tracks for use in TAPNet evaluations.
Given a set of frames and tracks with no query points, sample queries
strided every query_stride frames, ignoring points that are not visible
at the selected frames.
Args:
target_occluded: Boolean occlusion flag, of shape [n_tracks, n_frames],
where True indicates occluded.
target_points: Position, of shape [n_tracks, n_frames, 2], where each point
is [x,y] scaled between 0 and 1.
frames: Video tensor, of shape [n_frames, height, width, 3]. Scaled between
-1 and 1.
query_stride: When sampling query points, search for un-occluded points
every query_stride frames and convert each one into a query.
Returns:
A dict with the keys:
video: Video tensor of shape [1, n_frames, height, width, 3]. The video
has floats scaled to the range [-1, 1].
query_points: Query points of shape [1, n_queries, 3] where
each point is [t, y, x] scaled to the range [-1, 1].
target_points: Target points of shape [1, n_queries, n_frames, 2] where
each point is [x, y] scaled to the range [-1, 1].
trackgroup: Index of the original track that each query point was
sampled from. This is useful for visualization.
"""
tracks = []
occs = []
queries = []
trackgroups = []
total = 0
trackgroup = np.arange(target_occluded.shape[0])
for i in range(0, target_occluded.shape[1], query_stride):
mask = target_occluded[:, i] == 0
query = np.stack(
[
i * np.ones(target_occluded.shape[0:1]),
target_points[:, i, 1],
target_points[:, i, 0],
],
axis=-1,
)
queries.append(query[mask])
tracks.append(target_points[mask])
occs.append(target_occluded[mask])
trackgroups.append(trackgroup[mask])
total += np.array(np.sum(target_occluded[:, i] == 0))
return {
"video": frames[np.newaxis, ...],
"query_points": np.concatenate(queries, axis=0)[np.newaxis, ...],
"target_points": np.concatenate(tracks, axis=0)[np.newaxis, ...],
"occluded": np.concatenate(occs, axis=0)[np.newaxis, ...],
"trackgroup": np.concatenate(trackgroups, axis=0)[np.newaxis, ...],
}
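# Usage sketch (illustrative), continuing the small example above: with
# query_stride=1, every frame at which the track is visible contributes one
# [t, y, x] query, and "trackgroup" records the originating track index.
#
#     out = sample_queries_strided(occ, pts, frames, query_stride=1)
#     # out["query_points"].shape == (1, 2, 3)  -> queries at frames 1 and 2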
class TapVidDataset(torch.utils.data.Dataset):
def __init__(
self,
data_root,
dataset_type="davis",
resize_to=[256, 256],
queried_first=True,
fast_eval=False,
):
local_random = random.Random()
local_random.seed(42)
self.fast_eval = fast_eval
self.dataset_type = dataset_type
self.resize_to = resize_to
self.queried_first = queried_first
if self.dataset_type == "kinetics":
all_paths = glob.glob(os.path.join(data_root, "*_of_0010.pkl"))
points_dataset = []
for pickle_path in all_paths:
with open(pickle_path, "rb") as f:
data = pickle.load(f)
points_dataset = points_dataset + data
if fast_eval:
points_dataset = local_random.sample(points_dataset, 50)
self.points_dataset = points_dataset
elif self.dataset_type == "robotap":
all_paths = glob.glob(os.path.join(data_root, "robotap_split*.pkl"))
points_dataset = None
for pickle_path in all_paths:
with open(pickle_path, "rb") as f:
data = pickle.load(f)
if points_dataset is None:
points_dataset = dict(data)
else:
points_dataset.update(data)
if fast_eval:
points_dataset_keys = local_random.sample(
sorted(points_dataset.keys()), 50
)
points_dataset = {k: points_dataset[k] for k in points_dataset_keys}
self.points_dataset = points_dataset
self.video_names = list(self.points_dataset.keys())
else:
with open(data_root, "rb") as f:
self.points_dataset = pickle.load(f)
if self.dataset_type == "davis":
self.video_names = list(self.points_dataset.keys())
elif self.dataset_type == "stacking":
# print("self.points_dataset", self.points_dataset)
self.video_names = [i for i in range(len(self.points_dataset))]
print("found %d unique videos in %s" % (len(self.points_dataset), data_root))
def __getitem__(self, index):
if self.dataset_type == "davis" or self.dataset_type == "robotap":
video_name = self.video_names[index]
else:
video_name = index
video = self.points_dataset[video_name]
frames = video["video"]
if self.fast_eval and frames.shape[0] > 300:
return self.__getitem__((index + 1) % self.__len__())
if isinstance(frames[0], bytes):
# TAP-Vid frames are stored as JPEG bytes rather than `np.ndarray`s.
def decode(frame):
byteio = io.BytesIO(frame)
img = Image.open(byteio)
return np.array(img)
frames = np.array([decode(frame) for frame in frames])
target_points = self.points_dataset[video_name]["points"]
if self.resize_to is not None:
frames = resize_video(frames, self.resize_to)
target_points *= np.array(
[self.resize_to[1] - 1, self.resize_to[0] - 1]
) # 1 should be mapped to resize_to-1
else:
target_points *= np.array([frames.shape[2] - 1, frames.shape[1] - 1])
target_occ = self.points_dataset[video_name]["occluded"]
if self.queried_first:
converted = sample_queries_first(target_occ, target_points, frames)
else:
converted = sample_queries_strided(target_occ, target_points, frames)
assert converted["target_points"].shape[1] == converted["query_points"].shape[1]
trajs = (
torch.from_numpy(converted["target_points"])[0].permute(1, 0, 2).float()
) # T, N, D
rgbs = torch.from_numpy(frames).permute(0, 3, 1, 2).float()
visibles = torch.logical_not(torch.from_numpy(converted["occluded"]))[
0
].permute(
1, 0
) # T, N
query_points = torch.from_numpy(converted["query_points"])[0]  # N, 3
return CoTrackerData(
rgbs,
trajs,
visibles,
seq_name=str(video_name),
query_points=query_points,
)
def __len__(self):
return len(self.points_dataset)
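# Usage sketch (illustrative; the pickle path is hypothetical and should point to
# the TAP-Vid DAVIS pickle):
#
#     dataset = TapVidDataset(
#         data_root="/path/to/tapvid_davis/tapvid_davis.pkl",
#         dataset_type="davis",
#         queried_first=True,
#     )
#     sample = dataset[0]  # CoTrackerData with video, trajectory, visibility, query_points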
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
import torch
import dataclasses
import torch.nn.functional as F
from dataclasses import dataclass
from typing import Any, Optional, Dict
@dataclass(eq=False)
class CoTrackerData:
"""
Dataclass for storing video tracks data.
"""
video: torch.Tensor # B, S, C, H, W
trajectory: torch.Tensor # B, S, N, 2
visibility: torch.Tensor # B, S, N
# optional data
valid: Optional[torch.Tensor] = None # B, S, N
segmentation: Optional[torch.Tensor] = None # B, S, 1, H, W
seq_name: Optional[str] = None
query_points: Optional[torch.Tensor] = None # TapVID evaluation format
transforms: Optional[Dict[str, Any]] = None
aug_video: Optional[torch.Tensor] = None
def collate_fn(batch):
"""
Collate function for video tracks data.
"""
video = torch.stack([b.video for b in batch], dim=0)
trajectory = torch.stack([b.trajectory for b in batch], dim=0)
visibility = torch.stack([b.visibility for b in batch], dim=0)
query_points = segmentation = None
if batch[0].query_points is not None:
query_points = torch.stack([b.query_points for b in batch], dim=0)
if batch[0].segmentation is not None:
segmentation = torch.stack([b.segmentation for b in batch], dim=0)
seq_name = [b.seq_name for b in batch]
return CoTrackerData(
video=video,
trajectory=trajectory,
visibility=visibility,
segmentation=segmentation,
seq_name=seq_name,
query_points=query_points,
)
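# Usage sketch (illustrative; eval_dataset stands for any evaluation dataset above
# that yields CoTrackerData items directly):
#
#     from torch.utils.data import DataLoader
#     loader = DataLoader(eval_dataset, batch_size=1, collate_fn=collate_fn)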
def collate_fn_train(batch):
"""
Collate function for video tracks data during training.
"""
gotit = [gotit for _, gotit in batch]
video = torch.stack([b.video for b, _ in batch], dim=0)
trajectory = torch.stack([b.trajectory for b, _ in batch], dim=0)
visibility = torch.stack([b.visibility for b, _ in batch], dim=0)
valid = torch.stack([b.valid for b, _ in batch], dim=0)
seq_name = [b.seq_name for b, _ in batch]
query_points = transforms = aug_video = None
if batch[0][0].query_points is not None:
query_points = torch.stack([b.query_points for b, _ in batch], dim=0)
if batch[0][0].transforms is not None:
transforms = [b.transforms for b, _ in batch]
if batch[0][0].aug_video is not None:
aug_video = torch.stack([b.aug_video for b, _ in batch], dim=0)
return (
CoTrackerData(
video=video,
trajectory=trajectory,
visibility=visibility,
valid=valid,
seq_name=seq_name,
query_points=query_points,
aug_video=aug_video,
transforms=transforms,
),
gotit,
)
def try_to_cuda(t: Any) -> Any:
"""
Try to move the input variable `t` to a cuda device.
Args:
t: Input.
Returns:
t_cuda: `t` moved to a cuda device, if supported.
"""
try:
t = t.float().cuda()
except AttributeError:
pass
return t
def dataclass_to_cuda_(obj):
"""
Move all contents of a dataclass to cuda inplace if supported.
Args:
batch: Input dataclass.
Returns:
batch_cuda: `batch` moved to a cuda device, if supported.
"""
for f in dataclasses.fields(obj):
setattr(obj, f.name, try_to_cuda(getattr(obj, f.name)))
return obj
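# Usage sketch (illustrative, assumes a CUDA device is available): move a collated
# CoTrackerData batch to the GPU in place before running the model.
#
#     batch = dataclass_to_cuda_(batch)  # tensor fields become .float().cuda()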
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
defaults:
- default_config_eval
exp_dir: ./outputs/cotracker
dataset_name: dynamic_replica
defaults:
- default_config_eval
exp_dir: ./outputs/cotracker
dataset_name: tapvid_davis_first
defaults:
- default_config_eval
exp_dir: ./outputs/cotracker
dataset_name: tapvid_davis_strided
defaults:
- default_config_eval
exp_dir: ./outputs/cotracker
dataset_name: tapvid_kinetics_first