Unverified commit 673838f5, authored by YosuaMichael and committed by GitHub

Removing prototype-related things from release/0.14 branch (#6687)

* Remove tests related to prototype

* Remove torchvision/prototype dir

* Remove references/depth/stereo because it depends on prototype

* Remove prototype-related entries in mypy.ini

* Remove prototype-related entries in pytest.ini

* Clean setup.py of prototype references

* Clean CI of prototype references

* Remove unused expect file
import enum
import functools
import pathlib
from typing import Any, BinaryIO, cast, Dict, List, Optional, Tuple, Union
from xml.etree import ElementTree
from torchdata.datapipes.iter import Demultiplexer, Filter, IterDataPipe, IterKeyZipper, LineReader, Mapper
from torchvision.datasets import VOCDetection
from torchvision.prototype.datasets.utils import Dataset, HttpResource, OnlineResource
from torchvision.prototype.datasets.utils._internal import (
getitem,
hint_sharding,
hint_shuffling,
INFINITE_BUFFER_SIZE,
path_accessor,
path_comparator,
read_categories_file,
)
from torchvision.prototype.features import BoundingBox, EncodedImage, Label
from .._api import register_dataset, register_info
NAME = "voc"
@register_info(NAME)
def _info() -> Dict[str, Any]:
return dict(categories=read_categories_file(NAME))
@register_dataset(NAME)
class VOC(Dataset):
"""
- **homepage**: http://host.robots.ox.ac.uk/pascal/VOC/
"""
def __init__(
self,
root: Union[str, pathlib.Path],
*,
split: str = "train",
year: str = "2012",
task: str = "detection",
skip_integrity_check: bool = False,
) -> None:
self._year = self._verify_str_arg(year, "year", ("2007", "2008", "2009", "2010", "2011", "2012"))
if split == "test" and year != "2007":
raise ValueError("`split='test'` is only available for `year='2007'`")
else:
self._split = self._verify_str_arg(split, "split", ("train", "val", "trainval", "test"))
self._task = self._verify_str_arg(task, "task", ("detection", "segmentation"))
self._anns_folder = "Annotations" if task == "detection" else "SegmentationClass"
self._split_folder = "Main" if task == "detection" else "Segmentation"
self._categories = _info()["categories"]
super().__init__(root, skip_integrity_check=skip_integrity_check)
_TRAIN_VAL_ARCHIVES = {
"2007": ("VOCtrainval_06-Nov-2007.tar", "7d8cd951101b0957ddfd7a530bdc8a94f06121cfc1e511bb5937e973020c7508"),
"2008": ("VOCtrainval_14-Jul-2008.tar", "7f0ca53c1b5a838fbe946965fc106c6e86832183240af5c88e3f6c306318d42e"),
"2009": ("VOCtrainval_11-May-2009.tar", "11cbe1741fb5bdadbbca3c08e9ec62cd95c14884845527d50847bc2cf57e7fd6"),
"2010": ("VOCtrainval_03-May-2010.tar", "1af4189cbe44323ab212bff7afbc7d0f55a267cc191eb3aac911037887e5c7d4"),
"2011": ("VOCtrainval_25-May-2011.tar", "0a7f5f5d154f7290ec65ec3f78b72ef72c6d93ff6d79acd40dc222a9ee5248ba"),
"2012": ("VOCtrainval_11-May-2012.tar", "e14f763270cf193d0b5f74b169f44157a4b0c6efa708f4dd0ff78ee691763bcb"),
}
_TEST_ARCHIVES = {
"2007": ("VOCtest_06-Nov-2007.tar", "6836888e2e01dca84577a849d339fa4f73e1e4f135d312430c4856b5609b4892")
}
def _resources(self) -> List[OnlineResource]:
file_name, sha256 = (self._TEST_ARCHIVES if self._split == "test" else self._TRAIN_VAL_ARCHIVES)[self._year]
archive = HttpResource(f"http://host.robots.ox.ac.uk/pascal/VOC/voc{self._year}/{file_name}", sha256=sha256)
return [archive]
def _is_in_folder(self, data: Tuple[str, Any], *, name: str, depth: int = 1) -> bool:
path = pathlib.Path(data[0])
return name in path.parent.parts[-depth:]
class _Demux(enum.IntEnum):
SPLIT = 0
IMAGES = 1
ANNS = 2
def _classify_archive(self, data: Tuple[str, Any]) -> Optional[int]:
if self._is_in_folder(data, name="ImageSets", depth=2):
return self._Demux.SPLIT
elif self._is_in_folder(data, name="JPEGImages"):
return self._Demux.IMAGES
elif self._is_in_folder(data, name=self._anns_folder):
return self._Demux.ANNS
else:
return None
def _parse_detection_ann(self, buffer: BinaryIO) -> Dict[str, Any]:
return cast(Dict[str, Any], VOCDetection.parse_voc_xml(ElementTree.parse(buffer).getroot())["annotation"])
def _prepare_detection_ann(self, buffer: BinaryIO) -> Dict[str, Any]:
anns = self._parse_detection_ann(buffer)
instances = anns["object"]
return dict(
bounding_boxes=BoundingBox(
[
[int(instance["bndbox"][part]) for part in ("xmin", "ymin", "xmax", "ymax")]
for instance in instances
],
format="xyxy",
image_size=cast(Tuple[int, int], tuple(int(anns["size"][dim]) for dim in ("height", "width"))),
),
labels=Label(
[self._categories.index(instance["name"]) for instance in instances], categories=self._categories
),
)
def _prepare_segmentation_ann(self, buffer: BinaryIO) -> Dict[str, Any]:
return dict(segmentation=EncodedImage.from_file(buffer))
def _prepare_sample(
self,
data: Tuple[Tuple[Tuple[str, str], Tuple[str, BinaryIO]], Tuple[str, BinaryIO]],
) -> Dict[str, Any]:
split_and_image_data, ann_data = data
_, image_data = split_and_image_data
image_path, image_buffer = image_data
ann_path, ann_buffer = ann_data
return dict(
(self._prepare_detection_ann if self._task == "detection" else self._prepare_segmentation_ann)(ann_buffer),
image_path=image_path,
image=EncodedImage.from_file(image_buffer),
ann_path=ann_path,
)
def _datapipe(self, resource_dps: List[IterDataPipe]) -> IterDataPipe[Dict[str, Any]]:
archive_dp = resource_dps[0]
split_dp, images_dp, anns_dp = Demultiplexer(
archive_dp,
3,
self._classify_archive,
drop_none=True,
buffer_size=INFINITE_BUFFER_SIZE,
)
split_dp = Filter(split_dp, functools.partial(self._is_in_folder, name=self._split_folder))
split_dp = Filter(split_dp, path_comparator("name", f"{self._split}.txt"))
split_dp = LineReader(split_dp, decode=True)
split_dp = hint_shuffling(split_dp)
split_dp = hint_sharding(split_dp)
dp = split_dp
for level, data_dp in enumerate((images_dp, anns_dp)):
dp = IterKeyZipper(
dp,
data_dp,
key_fn=getitem(*[0] * level, 1),
ref_key_fn=path_accessor("stem"),
buffer_size=INFINITE_BUFFER_SIZE,
)
return Mapper(dp, self._prepare_sample)
def __len__(self) -> int:
return {
("train", "2007", "detection"): 2_501,
("train", "2007", "segmentation"): 209,
("train", "2008", "detection"): 2_111,
("train", "2008", "segmentation"): 511,
("train", "2009", "detection"): 3_473,
("train", "2009", "segmentation"): 749,
("train", "2010", "detection"): 4_998,
("train", "2010", "segmentation"): 964,
("train", "2011", "detection"): 5_717,
("train", "2011", "segmentation"): 1_112,
("train", "2012", "detection"): 5_717,
("train", "2012", "segmentation"): 1_464,
("val", "2007", "detection"): 2_510,
("val", "2007", "segmentation"): 213,
("val", "2008", "detection"): 2_221,
("val", "2008", "segmentation"): 512,
("val", "2009", "detection"): 3_581,
("val", "2009", "segmentation"): 750,
("val", "2010", "detection"): 5_105,
("val", "2010", "segmentation"): 964,
("val", "2011", "detection"): 5_823,
("val", "2011", "segmentation"): 1_111,
("val", "2012", "detection"): 5_823,
("val", "2012", "segmentation"): 1_449,
("trainval", "2007", "detection"): 5_011,
("trainval", "2007", "segmentation"): 422,
("trainval", "2008", "detection"): 4_332,
("trainval", "2008", "segmentation"): 1_023,
("trainval", "2009", "detection"): 7_054,
("trainval", "2009", "segmentation"): 1_499,
("trainval", "2010", "detection"): 10_103,
("trainval", "2010", "segmentation"): 1_928,
("trainval", "2011", "detection"): 11_540,
("trainval", "2011", "segmentation"): 2_223,
("trainval", "2012", "detection"): 11_540,
("trainval", "2012", "segmentation"): 2_913,
("test", "2007", "detection"): 4_952,
("test", "2007", "segmentation"): 210,
}[(self._split, self._year, self._task)]
def _filter_anns(self, data: Tuple[str, Any]) -> bool:
return self._classify_archive(data) == self._Demux.ANNS
def _generate_categories(self) -> List[str]:
self._task = "detection"
resources = self._resources()
archive_dp = resources[0].load(self._root)
dp = Filter(archive_dp, self._filter_anns)
dp = Mapper(dp, self._parse_detection_ann, input_col=1)
categories = sorted({instance["name"] for _, anns in dp for instance in anns["object"]})
# We add a background category to be used during segmentation
categories.insert(0, "__background__")
return categories
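# Usage sketch (hedged): how a registered prototype dataset such as the VOC class above was
# typically consumed. It assumes that `torchvision.prototype.datasets.load` forwards keyword
# arguments to the dataset constructor, as the benchmark script in this commit does.
def _example_load_voc():
    from torchvision.prototype import datasets
    voc = datasets.load("voc", split="train", year="2012", task="detection")
    sample = next(iter(voc))
    # for the detection task, `_prepare_sample` above yields these keys
    print(sample["image_path"], sample["bounding_boxes"], sample["labels"])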
import functools
import os
import os.path
import pathlib
from typing import Any, BinaryIO, Collection, Dict, List, Optional, Tuple, Union
from torchdata.datapipes.iter import FileLister, FileOpener, Filter, IterDataPipe, Mapper
from torchvision.prototype.datasets.utils._internal import hint_sharding, hint_shuffling
from torchvision.prototype.features import EncodedData, EncodedImage, Label
__all__ = ["from_data_folder", "from_image_folder"]
def _is_not_top_level_file(path: str, *, root: pathlib.Path) -> bool:
rel_path = pathlib.Path(path).relative_to(root)
return rel_path.is_dir() or rel_path.parent != pathlib.Path(".")
def _prepare_sample(
data: Tuple[str, BinaryIO],
*,
root: pathlib.Path,
categories: List[str],
) -> Dict[str, Any]:
path, buffer = data
category = pathlib.Path(path).relative_to(root).parts[0]
return dict(
path=path,
data=EncodedData.from_file(buffer),
label=Label.from_category(category, categories=categories),
)
def from_data_folder(
root: Union[str, pathlib.Path],
*,
valid_extensions: Optional[Collection[str]] = None,
recursive: bool = True,
) -> Tuple[IterDataPipe, List[str]]:
root = pathlib.Path(root).expanduser().resolve()
categories = sorted(entry.name for entry in os.scandir(root) if entry.is_dir())
masks: Union[List[str], str] = [f"*.{ext}" for ext in valid_extensions] if valid_extensions is not None else ""
dp = FileLister(str(root), recursive=recursive, masks=masks)
dp: IterDataPipe = Filter(dp, functools.partial(_is_not_top_level_file, root=root))
dp = hint_sharding(dp)
dp = hint_shuffling(dp)
dp = FileOpener(dp, mode="rb")
return Mapper(dp, functools.partial(_prepare_sample, root=root, categories=categories)), categories
def _data_to_image_key(sample: Dict[str, Any]) -> Dict[str, Any]:
sample["image"] = EncodedImage(sample.pop("data").data)
return sample
def from_image_folder(
root: Union[str, pathlib.Path],
*,
valid_extensions: Collection[str] = ("jpg", "jpeg", "png", "ppm", "bmp", "pgm", "tif", "tiff", "webp"),
**kwargs: Any,
) -> Tuple[IterDataPipe, List[str]]:
valid_extensions = [valid_extension for ext in valid_extensions for valid_extension in (ext.lower(), ext.upper())]
dp, categories = from_data_folder(root, valid_extensions=valid_extensions, **kwargs)
return Mapper(dp, _data_to_image_key), categories
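# Usage sketch (hedged): reading an ImageFolder-style layout, i.e. root/<category>/<image files>.
# The directory name "flowers" below is a hypothetical placeholder.
def _example_from_image_folder():
    dp, categories = from_image_folder("flowers")
    sample = next(iter(dp))
    # `sample["data"]` from `from_data_folder` is re-keyed to `sample["image"]` (an EncodedImage),
    # and `sample["label"]` is a Label indexed into `categories`
    print(categories, sample["path"], sample["label"])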
import os
from typing import Optional
import torchvision._internally_replaced_utils as _iru
def home(root: Optional[str] = None) -> str:
if root is not None:
_iru._HOME = root
return _iru._HOME
root = os.getenv("TORCHVISION_DATASETS_HOME")
if root is not None:
return root
return _iru._HOME
def use_sharded_dataset(use: Optional[bool] = None) -> bool:
if use is not None:
_iru._USE_SHARDED_DATASETS = use
return _iru._USE_SHARDED_DATASETS
use = os.getenv("TORCHVISION_SHARDED_DATASETS")
if use is not None:
return use == "1"
return _iru._USE_SHARDED_DATASETS
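# Usage sketch (hedged): `home()` doubles as getter and setter for the dataset root. The order of
# precedence follows the function bodies above; the path below is a hypothetical placeholder.
def _example_home():
    print(home())                     # default root, or TORCHVISION_DATASETS_HOME if that is set
    print(home("/data/torchvision"))  # explicit override that is also stored for later calls
    print(use_sharded_dataset())      # analogous getter/setter for the sharded-dataset flag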
# type: ignore
import argparse
import collections.abc
import contextlib
import inspect
import itertools
import os
import os.path
import pathlib
import shutil
import sys
import tempfile
import time
import unittest.mock
import warnings
import torch
from torch.utils.data import DataLoader
from torch.utils.data.dataloader_experimental import DataLoader2
from torchvision import datasets as legacy_datasets
from torchvision.datasets.utils import extract_archive
from torchvision.prototype import datasets as new_datasets
from torchvision.transforms import PILToTensor
def main(
name,
*,
variant=None,
legacy=True,
new=True,
start=True,
iteration=True,
num_starts=3,
num_samples=10_000,
temp_root=None,
num_workers=0,
):
benchmarks = [
benchmark
for benchmark in DATASET_BENCHMARKS
if benchmark.name == name and (variant is None or benchmark.variant == variant)
]
if not benchmarks:
msg = f"No DatasetBenchmark available for dataset '{name}'"
if variant is not None:
msg += f" and variant '{variant}'"
raise ValueError(msg)
for benchmark in benchmarks:
print("#" * 80)
print(f"{benchmark.name}" + (f" ({benchmark.variant})" if benchmark.variant is not None else ""))
if legacy and start:
print(
"legacy",
"cold_start",
Measurement.time(benchmark.legacy_cold_start(temp_root, num_workers=num_workers), number=num_starts),
)
print(
"legacy",
"warm_start",
Measurement.time(benchmark.legacy_warm_start(temp_root, num_workers=num_workers), number=num_starts),
)
if legacy and iteration:
print(
"legacy",
"iteration",
Measurement.iterations_per_time(
benchmark.legacy_iteration(temp_root, num_workers=num_workers, num_samples=num_samples)
),
)
if new and start:
print(
"new",
"cold_start",
Measurement.time(benchmark.new_cold_start(num_workers=num_workers), number=num_starts),
)
if new and iteration:
print(
"new",
"iteration",
Measurement.iterations_per_time(
benchmark.new_iteration(num_workers=num_workers, num_samples=num_samples)
),
)
class DatasetBenchmark:
def __init__(
self,
name: str,
*,
variant=None,
legacy_cls=None,
new_config=None,
legacy_config_map=None,
legacy_special_options_map=None,
prepare_legacy_root=None,
):
self.name = name
self.variant = variant
self.new_raw_dataset = new_datasets._api.find(name)
self.legacy_cls = legacy_cls or self._find_legacy_cls()
if new_config is None:
new_config = self.new_raw_dataset.default_config
elif isinstance(new_config, dict):
new_config = self.new_raw_dataset.info.make_config(**new_config)
self.new_config = new_config
self.legacy_config_map = legacy_config_map
self.legacy_special_options_map = legacy_special_options_map or self._legacy_special_options_map
self.prepare_legacy_root = prepare_legacy_root
def new_dataset(self, *, num_workers=0):
return DataLoader2(new_datasets.load(self.name, **self.new_config), num_workers=num_workers)
def new_cold_start(self, *, num_workers):
def fn(timer):
with timer:
dataset = self.new_dataset(num_workers=num_workers)
next(iter(dataset))
return fn
def new_iteration(self, *, num_samples, num_workers):
def fn(timer):
dataset = self.new_dataset(num_workers=num_workers)
num_sample = 0
with timer:
for _ in dataset:
num_sample += 1
if num_sample == num_samples:
break
return num_sample
return fn
def suppress_output(self):
@contextlib.contextmanager
def context_manager():
with open(os.devnull, "w") as devnull:
with contextlib.redirect_stdout(devnull), contextlib.redirect_stderr(devnull):
yield
return context_manager()
def legacy_dataset(self, root, *, num_workers=0, download=None):
legacy_config = self.legacy_config_map(self, root) if self.legacy_config_map else dict()
special_options = self.legacy_special_options_map(self)
if "download" in special_options and download is not None:
special_options["download"] = download
with self.suppress_output():
return DataLoader(
self.legacy_cls(legacy_config.pop("root", str(root)), **legacy_config, **special_options),
shuffle=True,
num_workers=num_workers,
)
@contextlib.contextmanager
def patch_download_and_integrity_checks(self):
patches = [
("download_url", dict()),
("download_file_from_google_drive", dict()),
("check_integrity", dict(new=lambda path, md5=None: os.path.isfile(path))),
]
dataset_module = sys.modules[self.legacy_cls.__module__]
utils_module = legacy_datasets.utils
with contextlib.ExitStack() as stack:
for name, patch_kwargs in patches:
patch_module = dataset_module if name in dir(dataset_module) else utils_module
stack.enter_context(unittest.mock.patch(f"{patch_module.__name__}.{name}", **patch_kwargs))
yield stack
def _find_resource_file_names(self):
info = self.new_raw_dataset.info
valid_options = info._valid_options
file_names = set()
for options in (
dict(zip(valid_options.keys(), values)) for values in itertools.product(*valid_options.values())
):
resources = self.new_raw_dataset.resources(info.make_config(**options))
file_names.update([resource.file_name for resource in resources])
return file_names
@contextlib.contextmanager
def legacy_root(self, temp_root):
new_root = pathlib.Path(new_datasets.home()) / self.name
legacy_root = pathlib.Path(tempfile.mkdtemp(dir=temp_root))
if os.stat(new_root).st_dev != os.stat(legacy_root).st_dev:
warnings.warn(
"The temporary root directory for the legacy dataset was created on a different storage device than "
"the raw data that is used by the new dataset. If the devices have different I/O stats, this will "
"distort the benchmark. You can use the '--temp-root' flag to relocate the root directory of the "
"temporary directories.",
RuntimeWarning,
)
try:
for file_name in self._find_resource_file_names():
(legacy_root / file_name).symlink_to(new_root / file_name)
if self.prepare_legacy_root:
self.prepare_legacy_root(self, legacy_root)
with self.patch_download_and_integrity_checks():
yield legacy_root
finally:
shutil.rmtree(legacy_root)
def legacy_cold_start(self, temp_root, *, num_workers):
def fn(timer):
with self.legacy_root(temp_root) as root:
with timer:
dataset = self.legacy_dataset(root, num_workers=num_workers)
next(iter(dataset))
return fn
def legacy_warm_start(self, temp_root, *, num_workers):
def fn(timer):
with self.legacy_root(temp_root) as root:
self.legacy_dataset(root, num_workers=num_workers)
with timer:
dataset = self.legacy_dataset(root, num_workers=num_workers, download=False)
next(iter(dataset))
return fn
def legacy_iteration(self, temp_root, *, num_samples, num_workers):
def fn(timer):
with self.legacy_root(temp_root) as root:
dataset = self.legacy_dataset(root, num_workers=num_workers)
with timer:
for num_sample, _ in enumerate(dataset, 1):
if num_sample == num_samples:
break
return num_sample
return fn
def _find_legacy_cls(self):
legacy_clss = {
name.lower(): dataset_class
for name, dataset_class in legacy_datasets.__dict__.items()
if isinstance(dataset_class, type) and issubclass(dataset_class, legacy_datasets.VisionDataset)
}
try:
return legacy_clss[self.name]
except KeyError as error:
raise RuntimeError(
f"Can't determine the legacy dataset class for '{self.name}' automatically. "
f"Please set the 'legacy_cls' keyword argument manually."
) from error
_SPECIAL_KWARGS = {
"transform",
"target_transform",
"transforms",
"download",
}
@staticmethod
def _legacy_special_options_map(benchmark):
available_parameters = set()
for cls in benchmark.legacy_cls.__mro__:
if cls is legacy_datasets.VisionDataset:
break
available_parameters.update(inspect.signature(cls.__init__).parameters)
available_special_kwargs = benchmark._SPECIAL_KWARGS.intersection(available_parameters)
special_options = dict()
if "download" in available_special_kwargs:
special_options["download"] = True
if "transform" in available_special_kwargs:
special_options["transform"] = PILToTensor()
if "target_transform" in available_special_kwargs:
special_options["target_transform"] = torch.tensor
elif "transforms" in available_special_kwargs:
special_options["transforms"] = JointTransform(PILToTensor(), PILToTensor())
return special_options
class Measurement:
@classmethod
def time(cls, fn, *, number):
results = Measurement._timeit(fn, number=number)
times = torch.tensor(tuple(zip(*results))[1])
return cls._format(times, unit="s")
@classmethod
def iterations_per_time(cls, fn):
num_samples, time = Measurement._timeit(fn, number=1)[0]
iterations_per_second = torch.tensor(num_samples) / torch.tensor(time)
return cls._format(iterations_per_second, unit="it/s")
class Timer:
def __init__(self):
self._start = None
self._stop = None
def __enter__(self):
self._start = time.perf_counter()
def __exit__(self, exc_type, exc_val, exc_tb):
self._stop = time.perf_counter()
@property
def delta(self):
if self._start is None:
raise RuntimeError()
elif self._stop is None:
raise RuntimeError()
return self._stop - self._start
@classmethod
def _timeit(cls, fn, number):
results = []
for _ in range(number):
timer = cls.Timer()
output = fn(timer)
results.append((output, timer.delta))
return results
@classmethod
def _format(cls, measurements, *, unit):
measurements = torch.as_tensor(measurements).to(torch.float64).flatten()
if measurements.numel() == 1:
# TODO format that into engineering format
return f"{float(measurements):.3f} {unit}"
mean, std = Measurement._compute_mean_and_std(measurements)
# TODO format that into engineering format
return f"{mean:.3f} ± {std:.3f} {unit}"
@classmethod
def _compute_mean_and_std(cls, t):
mean = float(t.mean())
std = float(t.std(0, unbiased=t.numel() > 1))
return mean, std
def no_split(benchmark, root):
legacy_config = dict(benchmark.new_config)
del legacy_config["split"]
return legacy_config
def bool_split(name="train"):
def legacy_config_map(benchmark, root):
legacy_config = dict(benchmark.new_config)
legacy_config[name] = legacy_config.pop("split") == "train"
return legacy_config
return legacy_config_map
def base_folder(rel_folder=None):
if rel_folder is None:
def rel_folder(benchmark):
return benchmark.name
elif not callable(rel_folder):
name = rel_folder
def rel_folder(_):
return name
def prepare_legacy_root(benchmark, root):
files = list(root.glob("*"))
folder = root / rel_folder(benchmark)
folder.mkdir(parents=True)
for file in files:
shutil.move(str(file), str(folder))
return folder
return prepare_legacy_root
class JointTransform:
def __init__(self, *transforms):
self.transforms = transforms
def __call__(self, *inputs):
        if len(inputs) == 1 and isinstance(inputs[0], collections.abc.Sequence):
            inputs = inputs[0]
        if len(inputs) != len(self.transforms):
            raise RuntimeError(
                f"The number of inputs does not match the number of transforms: {len(inputs)} != {len(self.transforms)}."
            )
return tuple(transform(input) for transform, input in zip(self.transforms, inputs))
def caltech101_legacy_config_map(benchmark, root):
legacy_config = no_split(benchmark, root)
# The new dataset always returns the category and annotation
legacy_config["target_type"] = ("category", "annotation")
return legacy_config
mnist_base_folder = base_folder(lambda benchmark: pathlib.Path(benchmark.legacy_cls.__name__) / "raw")
def mnist_legacy_config_map(benchmark, root):
return dict(train=benchmark.new_config.split == "train")
def emnist_prepare_legacy_root(benchmark, root):
folder = mnist_base_folder(benchmark, root)
shutil.move(str(folder / "emnist-gzip.zip"), str(folder / "gzip.zip"))
return folder
def emnist_legacy_config_map(benchmark, root):
legacy_config = mnist_legacy_config_map(benchmark, root)
legacy_config["split"] = benchmark.new_config.image_set.replace("_", "").lower()
return legacy_config
def qmnist_legacy_config_map(benchmark, root):
legacy_config = mnist_legacy_config_map(benchmark, root)
legacy_config["what"] = benchmark.new_config.split
# The new dataset always returns the full label
legacy_config["compat"] = False
return legacy_config
def coco_legacy_config_map(benchmark, root):
images, _ = benchmark.new_raw_dataset.resources(benchmark.new_config)
return dict(
root=str(root / pathlib.Path(images.file_name).stem),
annFile=str(
root / "annotations" / f"{benchmark.variant}_{benchmark.new_config.split}{benchmark.new_config.year}.json"
),
)
def coco_prepare_legacy_root(benchmark, root):
images, annotations = benchmark.new_raw_dataset.resources(benchmark.new_config)
extract_archive(str(root / images.file_name))
extract_archive(str(root / annotations.file_name))
DATASET_BENCHMARKS = [
DatasetBenchmark(
"caltech101",
legacy_config_map=caltech101_legacy_config_map,
prepare_legacy_root=base_folder(),
legacy_special_options_map=lambda config: dict(
download=True,
transform=PILToTensor(),
target_transform=JointTransform(torch.tensor, torch.tensor),
),
),
DatasetBenchmark(
"caltech256",
legacy_config_map=no_split,
prepare_legacy_root=base_folder(),
),
DatasetBenchmark(
"celeba",
prepare_legacy_root=base_folder(),
        legacy_config_map=lambda benchmark, root: dict(
split="valid" if benchmark.new_config.split == "val" else benchmark.new_config.split,
# The new dataset always returns all annotations
target_type=("attr", "identity", "bbox", "landmarks"),
),
),
DatasetBenchmark(
"cifar10",
legacy_config_map=bool_split(),
),
DatasetBenchmark(
"cifar100",
legacy_config_map=bool_split(),
),
DatasetBenchmark(
"emnist",
prepare_legacy_root=emnist_prepare_legacy_root,
legacy_config_map=emnist_legacy_config_map,
),
DatasetBenchmark(
"fashionmnist",
prepare_legacy_root=mnist_base_folder,
legacy_config_map=mnist_legacy_config_map,
),
DatasetBenchmark(
"kmnist",
prepare_legacy_root=mnist_base_folder,
legacy_config_map=mnist_legacy_config_map,
),
DatasetBenchmark(
"mnist",
prepare_legacy_root=mnist_base_folder,
legacy_config_map=mnist_legacy_config_map,
),
DatasetBenchmark(
"qmnist",
prepare_legacy_root=mnist_base_folder,
legacy_config_map=mnist_legacy_config_map,
),
DatasetBenchmark(
"sbd",
legacy_cls=legacy_datasets.SBDataset,
        legacy_config_map=lambda benchmark, root: dict(
image_set=benchmark.new_config.split,
mode="boundaries" if benchmark.new_config.boundaries else "segmentation",
),
legacy_special_options_map=lambda benchmark: dict(
download=True,
transforms=JointTransform(
PILToTensor(), torch.tensor if benchmark.new_config.boundaries else PILToTensor()
),
),
),
DatasetBenchmark("voc", legacy_cls=legacy_datasets.VOCDetection),
DatasetBenchmark("imagenet", legacy_cls=legacy_datasets.ImageNet),
DatasetBenchmark(
"coco",
variant="instances",
legacy_cls=legacy_datasets.CocoDetection,
new_config=dict(split="train", annotations="instances"),
legacy_config_map=coco_legacy_config_map,
prepare_legacy_root=coco_prepare_legacy_root,
legacy_special_options_map=lambda benchmark: dict(transform=PILToTensor(), target_transform=None),
),
DatasetBenchmark(
"coco",
variant="captions",
legacy_cls=legacy_datasets.CocoCaptions,
new_config=dict(split="train", annotations="captions"),
legacy_config_map=coco_legacy_config_map,
prepare_legacy_root=coco_prepare_legacy_root,
legacy_special_options_map=lambda benchmark: dict(transform=PILToTensor(), target_transform=None),
),
]
def parse_args(argv=None):
parser = argparse.ArgumentParser(
prog="torchvision.prototype.datasets.benchmark.py",
description="Utility to benchmark new datasets against their legacy variants.",
formatter_class=argparse.ArgumentDefaultsHelpFormatter,
)
parser.add_argument("name", help="Name of the dataset to benchmark.")
parser.add_argument(
"--variant", help="Variant of the dataset. If omitted all available variants will be benchmarked."
)
parser.add_argument(
"-n",
"--num-starts",
type=int,
default=3,
help="Number of warm and cold starts of each benchmark. Default to 3.",
)
parser.add_argument(
"-N",
"--num-samples",
type=int,
default=10_000,
help="Maximum number of samples to draw during iteration benchmarks. Defaults to 10_000.",
)
parser.add_argument(
"--nl",
"--no-legacy",
dest="legacy",
action="store_false",
help="Skip legacy benchmarks.",
)
parser.add_argument(
"--nn",
"--no-new",
dest="new",
action="store_false",
help="Skip new benchmarks.",
)
parser.add_argument(
"--ns",
"--no-start",
dest="start",
action="store_false",
help="Skip start benchmarks.",
)
parser.add_argument(
"--ni",
"--no-iteration",
dest="iteration",
action="store_false",
help="Skip iteration benchmarks.",
)
parser.add_argument(
"-t",
"--temp-root",
type=pathlib.Path,
help=(
"Root of the temporary legacy root directories. Use this if your system default temporary directory is on "
"another storage device as the raw data to avoid distortions due to differing I/O stats."
),
)
parser.add_argument(
"-j",
"--num-workers",
type=int,
default=0,
help=(
"Number of subprocesses used to load the data. Setting this to 0 (default) will load all data in the main "
"process and thus disable multi-processing."
),
)
return parser.parse_args(argv or sys.argv[1:])
if __name__ == "__main__":
args = parse_args()
try:
main(
args.name,
variant=args.variant,
legacy=args.legacy,
new=args.new,
start=args.start,
iteration=args.iteration,
num_starts=args.num_starts,
num_samples=args.num_samples,
temp_root=args.temp_root,
num_workers=args.num_workers,
)
except Exception as error:
msg = str(error)
print(msg or f"Unspecified {type(error)} was raised during execution.", file=sys.stderr)
sys.exit(1)
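# Usage sketch (hedged): besides the CLI entry point above, `main` can be driven programmatically.
# The keyword arguments mirror the command line flags one-to-one; "cifar10" is merely an example of
# a name registered in DATASET_BENCHMARKS, and the raw data must already be available locally.
def _example_run_benchmark():
    main("cifar10", legacy=False, iteration=False, num_starts=1, num_workers=0)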
# type: ignore
import argparse
import csv
import sys
from torchvision.prototype import datasets
from torchvision.prototype.datasets.utils._internal import BUILTIN_DIR
def main(*names, force=False):
for name in names:
path = BUILTIN_DIR / f"{name}.categories"
if path.exists() and not force:
continue
dataset = datasets.load(name)
try:
categories = dataset._generate_categories()
except NotImplementedError:
continue
with open(path, "w") as file:
writer = csv.writer(file, lineterminator="\n")
for category in categories:
writer.writerow((category,) if isinstance(category, str) else category)
def parse_args(argv=None):
parser = argparse.ArgumentParser(prog="torchvision.prototype.datasets.generate_category_files.py")
parser.add_argument(
"names",
nargs="*",
type=str,
help="Names of datasets to generate category files for. If omitted, all datasets will be used.",
)
parser.add_argument(
"-f",
"--force",
action="store_true",
help="Force regeneration of category files.",
)
args = parser.parse_args(argv or sys.argv[1:])
if not args.names:
args.names = datasets.list_datasets()
return args
if __name__ == "__main__":
args = parse_args()
try:
main(*args.names, force=args.force)
except Exception as error:
msg = str(error)
print(msg or f"Unspecified {type(error)} was raised during execution.", file=sys.stderr)
sys.exit(1)
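# Usage sketch (hedged): regenerating a single category file programmatically instead of via the
# CLI. `force=True` overwrites an existing `<name>.categories` file in BUILTIN_DIR.
def _example_generate_categories():
    main("voc", force=True)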
from . import _internal # usort: skip
from ._dataset import Dataset
from ._resource import GDriveResource, HttpResource, KaggleDownloadResource, ManualDownloadResource, OnlineResource
import abc
import importlib
import pathlib
from typing import Any, Collection, Dict, Iterator, List, Optional, Sequence, Union
from torch.utils.data import IterDataPipe
from torchvision.datasets.utils import verify_str_arg
from ._resource import OnlineResource
class Dataset(IterDataPipe[Dict[str, Any]], abc.ABC):
@staticmethod
def _verify_str_arg(
value: str,
arg: Optional[str] = None,
valid_values: Optional[Collection[str]] = None,
*,
custom_msg: Optional[str] = None,
) -> str:
return verify_str_arg(value, arg, valid_values, custom_msg=custom_msg)
def __init__(
self, root: Union[str, pathlib.Path], *, skip_integrity_check: bool = False, dependencies: Collection[str] = ()
) -> None:
for dependency in dependencies:
try:
importlib.import_module(dependency)
except ModuleNotFoundError:
raise ModuleNotFoundError(
f"{type(self).__name__}() depends on the third-party package '{dependency}'. "
f"Please install it, for example with `pip install {dependency}`."
) from None
self._root = pathlib.Path(root).expanduser().resolve()
resources = [
resource.load(self._root, skip_integrity_check=skip_integrity_check) for resource in self._resources()
]
self._dp = self._datapipe(resources)
def __iter__(self) -> Iterator[Dict[str, Any]]:
yield from self._dp
@abc.abstractmethod
def _resources(self) -> List[OnlineResource]:
pass
@abc.abstractmethod
def _datapipe(self, resource_dps: List[IterDataPipe]) -> IterDataPipe[Dict[str, Any]]:
pass
@abc.abstractmethod
def __len__(self) -> int:
pass
def _generate_categories(self) -> Sequence[Union[str, Sequence[str]]]:
raise NotImplementedError
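# Minimal subclass sketch (hedged): the three abstract members above are all a concrete dataset has
# to implement. The URL, sample layout, and length below are hypothetical placeholders.
class _ExampleDataset(Dataset):
    def _resources(self) -> List[OnlineResource]:
        from ._resource import HttpResource
        return [HttpResource("https://example.com/data.tar", preprocess="extract")]
    def _datapipe(self, resource_dps: List[IterDataPipe]) -> IterDataPipe[Dict[str, Any]]:
        from torchdata.datapipes.iter import Mapper
        # a real dataset would decode and join the archive members here
        return Mapper(resource_dps[0], lambda data: dict(path=data[0], file=data[1]))
    def __len__(self) -> int:
        return 1_000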
import csv
import functools
import pathlib
import pickle
from typing import Any, BinaryIO, Callable, Dict, IO, Iterator, List, Sequence, Sized, Tuple, TypeVar, Union
import torch
import torch.distributed as dist
import torch.utils.data
from torchdata.datapipes.iter import IoPathFileLister, IoPathFileOpener, IterDataPipe, ShardingFilter, Shuffler
from torchdata.datapipes.utils import StreamWrapper
from torchvision.prototype.utils._internal import fromfile
__all__ = [
"INFINITE_BUFFER_SIZE",
"BUILTIN_DIR",
"read_mat",
"MappingIterator",
"getitem",
"path_accessor",
"path_comparator",
"read_flo",
"hint_sharding",
"hint_shuffling",
]
K = TypeVar("K")
D = TypeVar("D")
# pseudo-infinite until a true infinite buffer is supported by all datapipes
INFINITE_BUFFER_SIZE = 1_000_000_000
BUILTIN_DIR = pathlib.Path(__file__).parent.parent / "_builtin"
def read_mat(buffer: BinaryIO, **kwargs: Any) -> Any:
try:
import scipy.io as sio
except ImportError as error:
raise ModuleNotFoundError("Package `scipy` is required to be installed to read .mat files.") from error
if isinstance(buffer, StreamWrapper):
buffer = buffer.file_obj
return sio.loadmat(buffer, **kwargs)
class MappingIterator(IterDataPipe[Union[Tuple[K, D], D]]):
def __init__(self, datapipe: IterDataPipe[Dict[K, D]], *, drop_key: bool = False) -> None:
self.datapipe = datapipe
self.drop_key = drop_key
def __iter__(self) -> Iterator[Union[Tuple[K, D], D]]:
for mapping in self.datapipe:
yield from iter(mapping.values() if self.drop_key else mapping.items())
def _getitem_closure(obj: Any, *, items: Sequence[Any]) -> Any:
for item in items:
obj = obj[item]
return obj
def getitem(*items: Any) -> Callable[[Any], Any]:
return functools.partial(_getitem_closure, items=items)
def _getattr_closure(obj: Any, *, attrs: Sequence[str]) -> Any:
for attr in attrs:
obj = getattr(obj, attr)
return obj
def _path_attribute_accessor(path: pathlib.Path, *, name: str) -> Any:
return _getattr_closure(path, attrs=name.split("."))
def _path_accessor_closure(data: Tuple[str, Any], *, getter: Callable[[pathlib.Path], D]) -> D:
return getter(pathlib.Path(data[0]))
def path_accessor(getter: Union[str, Callable[[pathlib.Path], D]]) -> Callable[[Tuple[str, Any]], D]:
if isinstance(getter, str):
getter = functools.partial(_path_attribute_accessor, name=getter)
return functools.partial(_path_accessor_closure, getter=getter)
def _path_comparator_closure(data: Tuple[str, Any], *, accessor: Callable[[Tuple[str, Any]], D], value: D) -> bool:
return accessor(data) == value
def path_comparator(getter: Union[str, Callable[[pathlib.Path], D]], value: D) -> Callable[[Tuple[str, Any]], bool]:
return functools.partial(_path_comparator_closure, accessor=path_accessor(getter), value=value)
class PicklerDataPipe(IterDataPipe):
def __init__(self, source_datapipe: IterDataPipe[Tuple[str, IO[bytes]]]) -> None:
self.source_datapipe = source_datapipe
def __iter__(self) -> Iterator[Any]:
for _, fobj in self.source_datapipe:
data = pickle.load(fobj)
            yield from data
class SharderDataPipe(torch.utils.data.datapipes.iter.grouping.ShardingFilterIterDataPipe):
def __init__(self, source_datapipe: IterDataPipe) -> None:
super().__init__(source_datapipe)
self.rank = 0
self.world_size = 1
if dist.is_available() and dist.is_initialized():
self.rank = dist.get_rank()
self.world_size = dist.get_world_size()
self.apply_sharding(self.world_size, self.rank)
def __iter__(self) -> Iterator[Any]:
num_workers = self.world_size
worker_id = self.rank
worker_info = torch.utils.data.get_worker_info()
if worker_info is not None:
worker_id = worker_id + worker_info.id * num_workers
num_workers *= worker_info.num_workers
self.apply_sharding(num_workers, worker_id)
yield from super().__iter__()
class TakerDataPipe(IterDataPipe):
def __init__(self, source_datapipe: IterDataPipe, num_take: int) -> None:
super().__init__()
self.source_datapipe = source_datapipe
self.num_take = num_take
self.world_size = 1
if dist.is_available() and dist.is_initialized():
self.world_size = dist.get_world_size()
def __iter__(self) -> Iterator[Any]:
num_workers = self.world_size
worker_info = torch.utils.data.get_worker_info()
if worker_info is not None:
num_workers *= worker_info.num_workers
# TODO: this is weird as it drops more elements than it should
num_take = self.num_take // num_workers
for i, data in enumerate(self.source_datapipe):
if i < num_take:
yield data
else:
break
def __len__(self) -> int:
num_take = self.num_take // self.world_size
if isinstance(self.source_datapipe, Sized):
if len(self.source_datapipe) < num_take:
num_take = len(self.source_datapipe)
# TODO: might be weird to not take `num_workers` into account
return num_take
def _make_sharded_datapipe(root: str, dataset_size: int) -> IterDataPipe[Dict[str, Any]]:
dp = IoPathFileLister(root=root)
dp = SharderDataPipe(dp)
dp = dp.shuffle(buffer_size=INFINITE_BUFFER_SIZE)
dp = IoPathFileOpener(dp, mode="rb")
dp = PicklerDataPipe(dp)
# dp = dp.cycle(2)
dp = TakerDataPipe(dp, dataset_size)
return dp
def read_flo(file: BinaryIO) -> torch.Tensor:
if file.read(4) != b"PIEH":
raise ValueError("Magic number incorrect. Invalid .flo file")
width, height = fromfile(file, dtype=torch.int32, byte_order="little", count=2)
flow = fromfile(file, dtype=torch.float32, byte_order="little", count=height * width * 2)
return flow.reshape((height, width, 2)).permute((2, 0, 1))
def hint_sharding(datapipe: IterDataPipe) -> ShardingFilter:
return ShardingFilter(datapipe)
def hint_shuffling(datapipe: IterDataPipe[D]) -> Shuffler[D]:
return Shuffler(datapipe, buffer_size=INFINITE_BUFFER_SIZE).set_shuffle(False)
def read_categories_file(name: str) -> List[Union[str, Sequence[str]]]:
path = BUILTIN_DIR / f"{name}.categories"
with open(path, newline="") as file:
rows = list(csv.reader(file))
rows = [row[0] if len(row) == 1 else row for row in rows]
return rows
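# Usage sketch (hedged): the helpers above build small picklable callables for datapipes. The path
# used below is a hypothetical (path, file handle) pair as produced by archive datapipes.
def _example_helpers():
    data = ("archive/train/images/0001.jpg", None)
    assert getitem(0)(data) == "archive/train/images/0001.jpg"
    assert path_accessor("stem")(data) == "0001"
    assert path_comparator("suffix", ".jpg")(data)
    assert getitem("a", "b", 2)({"a": {"b": [1, 2, 3]}}) == 3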
import abc
import hashlib
import itertools
import pathlib
from typing import Any, Callable, IO, NoReturn, Optional, Sequence, Set, Tuple, Union
from urllib.parse import urlparse
from torchdata.datapipes.iter import (
FileLister,
FileOpener,
IterableWrapper,
IterDataPipe,
RarArchiveLoader,
TarArchiveLoader,
ZipArchiveLoader,
)
from torchvision.datasets.utils import (
_decompress,
_detect_file_type,
_get_google_drive_file_id,
_get_redirect_url,
download_file_from_google_drive,
download_url,
extract_archive,
)
from typing_extensions import Literal
class OnlineResource(abc.ABC):
def __init__(
self,
*,
file_name: str,
sha256: Optional[str] = None,
preprocess: Optional[Union[Literal["decompress", "extract"], Callable[[pathlib.Path], None]]] = None,
) -> None:
self.file_name = file_name
self.sha256 = sha256
if isinstance(preprocess, str):
if preprocess == "decompress":
preprocess = self._decompress
elif preprocess == "extract":
preprocess = self._extract
else:
raise ValueError(
f"Only `'decompress'` or `'extract'` are valid if `preprocess` is passed as string,"
f"but got {preprocess} instead."
)
self._preprocess = preprocess
@staticmethod
def _extract(file: pathlib.Path) -> None:
extract_archive(str(file), to_path=str(file).replace("".join(file.suffixes), ""), remove_finished=False)
@staticmethod
def _decompress(file: pathlib.Path) -> None:
_decompress(str(file), remove_finished=True)
def _loader(self, path: pathlib.Path) -> IterDataPipe[Tuple[str, IO]]:
if path.is_dir():
return FileOpener(FileLister(str(path), recursive=True), mode="rb")
dp = FileOpener(IterableWrapper((str(path),)), mode="rb")
archive_loader = self._guess_archive_loader(path)
if archive_loader:
dp = archive_loader(dp)
return dp
_ARCHIVE_LOADERS = {
".tar": TarArchiveLoader,
".zip": ZipArchiveLoader,
".rar": RarArchiveLoader,
}
def _guess_archive_loader(
self, path: pathlib.Path
) -> Optional[Callable[[IterDataPipe[Tuple[str, IO]]], IterDataPipe[Tuple[str, IO]]]]:
try:
_, archive_type, _ = _detect_file_type(path.name)
except RuntimeError:
return None
return self._ARCHIVE_LOADERS.get(archive_type) # type: ignore[arg-type]
def load(
self, root: Union[str, pathlib.Path], *, skip_integrity_check: bool = False
) -> IterDataPipe[Tuple[str, IO]]:
root = pathlib.Path(root)
path = root / self.file_name
# Instead of the raw file, there might also be files with fewer suffixes after decompression or directories
# with no suffixes at all. `pathlib.Path().stem` will only give us the name with the last suffix removed, which
# is not sufficient for files with multiple suffixes, e.g. foo.tar.gz.
stem = path.name.replace("".join(path.suffixes), "")
def find_candidates() -> Set[pathlib.Path]:
# Although it looks like we could glob for f"{stem}*" to find the file candidates as well as the folder
# candidate simultaneously, that would also pick up other files that share the same prefix. For example, the
# test split of the stanford-cars dataset uses the files
# - cars_test.tgz
# - cars_test_annos_withlabels.mat
# Globbing for `"cars_test*"` picks up both.
candidates = {file for file in path.parent.glob(f"{stem}.*")}
folder_candidate = path.parent / stem
if folder_candidate.exists():
candidates.add(folder_candidate)
return candidates
candidates = find_candidates()
if not candidates:
self.download(root, skip_integrity_check=skip_integrity_check)
if self._preprocess is not None:
self._preprocess(path)
candidates = find_candidates()
# We use the path with the fewest suffixes. This gives us the
# extracted > decompressed > raw
# priority that we want for the best I/O performance.
return self._loader(min(candidates, key=lambda candidate: len(candidate.suffixes)))
@abc.abstractmethod
def _download(self, root: pathlib.Path) -> None:
pass
def download(self, root: Union[str, pathlib.Path], *, skip_integrity_check: bool = False) -> pathlib.Path:
root = pathlib.Path(root)
self._download(root)
path = root / self.file_name
if self.sha256 and not skip_integrity_check:
self._check_sha256(path)
return path
def _check_sha256(self, path: pathlib.Path, *, chunk_size: int = 1024 * 1024) -> None:
hash = hashlib.sha256()
with open(path, "rb") as file:
for chunk in iter(lambda: file.read(chunk_size), b""):
hash.update(chunk)
sha256 = hash.hexdigest()
if sha256 != self.sha256:
raise RuntimeError(
f"After the download, the SHA256 checksum of {path} didn't match the expected one: "
f"{sha256} != {self.sha256}"
)
class HttpResource(OnlineResource):
def __init__(
self, url: str, *, file_name: Optional[str] = None, mirrors: Sequence[str] = (), **kwargs: Any
) -> None:
super().__init__(file_name=file_name or pathlib.Path(urlparse(url).path).name, **kwargs)
self.url = url
self.mirrors = mirrors
self._resolved = False
def resolve(self) -> OnlineResource:
if self._resolved:
return self
redirect_url = _get_redirect_url(self.url)
if redirect_url == self.url:
self._resolved = True
return self
meta = {
attr.lstrip("_"): getattr(self, attr)
for attr in (
"file_name",
"sha256",
"_preprocess",
)
}
gdrive_id = _get_google_drive_file_id(redirect_url)
if gdrive_id:
return GDriveResource(gdrive_id, **meta)
http_resource = HttpResource(redirect_url, **meta)
http_resource._resolved = True
return http_resource
def _download(self, root: pathlib.Path) -> None:
if not self._resolved:
return self.resolve()._download(root)
for url in itertools.chain((self.url,), self.mirrors):
try:
download_url(url, str(root), filename=self.file_name, md5=None)
# TODO: make this more precise
except Exception:
continue
return
else:
# TODO: make this more informative
raise RuntimeError("Download failed!")
class GDriveResource(OnlineResource):
def __init__(self, id: str, **kwargs: Any) -> None:
super().__init__(**kwargs)
self.id = id
def _download(self, root: pathlib.Path) -> None:
download_file_from_google_drive(self.id, root=str(root), filename=self.file_name, md5=None)
class ManualDownloadResource(OnlineResource):
def __init__(self, instructions: str, **kwargs: Any) -> None:
super().__init__(**kwargs)
self.instructions = instructions
def _download(self, root: pathlib.Path) -> NoReturn:
raise RuntimeError(
f"The file {self.file_name} cannot be downloaded automatically. "
f"Please follow the instructions below and place it in {root}\n\n"
f"{self.instructions}"
)
class KaggleDownloadResource(ManualDownloadResource):
def __init__(self, challenge_url: str, *, file_name: str, **kwargs: Any) -> None:
instructions = "\n".join(
(
"1. Register and login at https://www.kaggle.com",
f"2. Navigate to {challenge_url}",
"3. Click 'Join Competition' and follow the instructions there",
"4. Navigate to the 'Data' tab",
f"5. Select {file_name} in the 'Data Explorer' and click the download button",
)
)
super().__init__(instructions, file_name=file_name, **kwargs)
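# Usage sketch (hedged): a resource is declared once and `load` handles downloading, checksum
# verification, optional preprocessing, and yielding (path, file handle) pairs. The URL below is a
# hypothetical placeholder and no checksum is pinned, so the integrity check is skipped.
def _example_resource(root: pathlib.Path) -> IterDataPipe[Tuple[str, IO]]:
    resource = HttpResource("https://example.com/images.tar", preprocess="extract")
    return resource.load(root, skip_integrity_check=True)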
from ._bounding_box import BoundingBox, BoundingBoxFormat
from ._encoded import EncodedData, EncodedImage, EncodedVideo
from ._feature import _Feature, FillType, FillTypeJIT, InputType, InputTypeJIT, is_simple_tensor
from ._image import (
ColorSpace,
Image,
ImageType,
ImageTypeJIT,
LegacyImageType,
LegacyImageTypeJIT,
TensorImageType,
TensorImageTypeJIT,
)
from ._label import Label, OneHotLabel
from ._mask import Mask
from __future__ import annotations
from typing import Any, List, Optional, Sequence, Tuple, Union
import torch
from torchvision._utils import StrEnum
from torchvision.transforms import InterpolationMode # TODO: this needs to be moved out of transforms
from ._feature import _Feature, FillTypeJIT
class BoundingBoxFormat(StrEnum):
XYXY = StrEnum.auto()
XYWH = StrEnum.auto()
CXCYWH = StrEnum.auto()
class BoundingBox(_Feature):
format: BoundingBoxFormat
image_size: Tuple[int, int]
def __new__(
cls,
data: Any,
*,
format: Union[BoundingBoxFormat, str],
image_size: Tuple[int, int],
dtype: Optional[torch.dtype] = None,
device: Optional[Union[torch.device, str, int]] = None,
requires_grad: bool = False,
) -> BoundingBox:
bounding_box = super().__new__(cls, data, dtype=dtype, device=device, requires_grad=requires_grad)
if isinstance(format, str):
format = BoundingBoxFormat.from_str(format.upper())
bounding_box.format = format
bounding_box.image_size = image_size
return bounding_box
def __repr__(self, *, tensor_contents: Any = None) -> str: # type: ignore[override]
return self._make_repr(format=self.format, image_size=self.image_size)
@classmethod
def new_like(
cls,
other: BoundingBox,
data: Any,
*,
format: Optional[Union[BoundingBoxFormat, str]] = None,
image_size: Optional[Tuple[int, int]] = None,
**kwargs: Any,
) -> BoundingBox:
return super().new_like(
other,
data,
format=format if format is not None else other.format,
image_size=image_size if image_size is not None else other.image_size,
**kwargs,
)
def to_format(self, format: Union[str, BoundingBoxFormat]) -> BoundingBox:
if isinstance(format, str):
format = BoundingBoxFormat.from_str(format.upper())
return BoundingBox.new_like(
self, self._F.convert_format_bounding_box(self, old_format=self.format, new_format=format), format=format
)
def horizontal_flip(self) -> BoundingBox:
output = self._F.horizontal_flip_bounding_box(self, format=self.format, image_size=self.image_size)
return BoundingBox.new_like(self, output)
def vertical_flip(self) -> BoundingBox:
output = self._F.vertical_flip_bounding_box(self, format=self.format, image_size=self.image_size)
return BoundingBox.new_like(self, output)
def resize( # type: ignore[override]
self,
size: List[int],
interpolation: InterpolationMode = InterpolationMode.BILINEAR,
max_size: Optional[int] = None,
antialias: bool = False,
) -> BoundingBox:
output, image_size = self._F.resize_bounding_box(self, image_size=self.image_size, size=size, max_size=max_size)
return BoundingBox.new_like(self, output, image_size=image_size)
def crop(self, top: int, left: int, height: int, width: int) -> BoundingBox:
output, image_size = self._F.crop_bounding_box(
self, self.format, top=top, left=left, height=height, width=width
)
return BoundingBox.new_like(self, output, image_size=image_size)
def center_crop(self, output_size: List[int]) -> BoundingBox:
output, image_size = self._F.center_crop_bounding_box(
self, format=self.format, image_size=self.image_size, output_size=output_size
)
return BoundingBox.new_like(self, output, image_size=image_size)
def resized_crop(
self,
top: int,
left: int,
height: int,
width: int,
size: List[int],
interpolation: InterpolationMode = InterpolationMode.BILINEAR,
antialias: bool = False,
) -> BoundingBox:
output, image_size = self._F.resized_crop_bounding_box(self, self.format, top, left, height, width, size=size)
return BoundingBox.new_like(self, output, image_size=image_size)
def pad(
self,
padding: Union[int, Sequence[int]],
fill: FillTypeJIT = None,
padding_mode: str = "constant",
) -> BoundingBox:
output, image_size = self._F.pad_bounding_box(
self, format=self.format, image_size=self.image_size, padding=padding, padding_mode=padding_mode
)
return BoundingBox.new_like(self, output, image_size=image_size)
def rotate(
self,
angle: float,
interpolation: InterpolationMode = InterpolationMode.NEAREST,
expand: bool = False,
fill: FillTypeJIT = None,
center: Optional[List[float]] = None,
) -> BoundingBox:
output, image_size = self._F.rotate_bounding_box(
self, format=self.format, image_size=self.image_size, angle=angle, expand=expand, center=center
)
return BoundingBox.new_like(self, output, image_size=image_size)
def affine(
self,
angle: Union[int, float],
translate: List[float],
scale: float,
shear: List[float],
interpolation: InterpolationMode = InterpolationMode.NEAREST,
fill: FillTypeJIT = None,
center: Optional[List[float]] = None,
) -> BoundingBox:
output = self._F.affine_bounding_box(
self,
self.format,
self.image_size,
angle,
translate=translate,
scale=scale,
shear=shear,
center=center,
)
return BoundingBox.new_like(self, output, dtype=output.dtype)
def perspective(
self,
perspective_coeffs: List[float],
interpolation: InterpolationMode = InterpolationMode.BILINEAR,
fill: FillTypeJIT = None,
) -> BoundingBox:
output = self._F.perspective_bounding_box(self, self.format, perspective_coeffs)
return BoundingBox.new_like(self, output, dtype=output.dtype)
def elastic(
self,
displacement: torch.Tensor,
interpolation: InterpolationMode = InterpolationMode.BILINEAR,
fill: FillTypeJIT = None,
) -> BoundingBox:
output = self._F.elastic_bounding_box(self, self.format, displacement)
return BoundingBox.new_like(self, output, dtype=output.dtype)
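# Usage sketch (hedged): a BoundingBox carries its format and image size as metadata, and
# `new_like` re-attaches that metadata to new data, as the transform methods above do internally.
def _example_bounding_box() -> None:
    box = BoundingBox([[10, 20, 30, 40]], format="xyxy", image_size=(100, 100))
    assert box.format is BoundingBoxFormat.XYXY and box.image_size == (100, 100)
    shifted = BoundingBox.new_like(box, box + 5)  # plain tensor math, metadata copied from `box`
    assert shifted.image_size == box.image_size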
from __future__ import annotations
import os
import sys
from typing import Any, BinaryIO, Optional, Tuple, Type, TypeVar, Union
import PIL.Image
import torch
from torchvision.prototype.utils._internal import fromfile, ReadOnlyTensorBuffer
from ._feature import _Feature
D = TypeVar("D", bound="EncodedData")
class EncodedData(_Feature):
def __new__(
cls,
data: Any,
*,
dtype: Optional[torch.dtype] = None,
device: Optional[Union[torch.device, str, int]] = None,
requires_grad: bool = False,
) -> EncodedData:
# TODO: warn / bail out if we encounter a tensor with shape other than (N,) or with dtype other than uint8?
return super().__new__(cls, data, dtype=dtype, device=device, requires_grad=requires_grad)
@classmethod
def from_file(cls: Type[D], file: BinaryIO, **kwargs: Any) -> D:
return cls(fromfile(file, dtype=torch.uint8, byte_order=sys.byteorder), **kwargs)
@classmethod
def from_path(cls: Type[D], path: Union[str, os.PathLike], **kwargs: Any) -> D:
with open(path, "rb") as file:
return cls.from_file(file, **kwargs)
class EncodedImage(EncodedData):
# TODO: Use @functools.cached_property if we can depend on Python 3.8
@property
def image_size(self) -> Tuple[int, int]:
if not hasattr(self, "_image_size"):
with PIL.Image.open(ReadOnlyTensorBuffer(self)) as image:
self._image_size = image.height, image.width
return self._image_size
class EncodedVideo(EncodedData):
pass
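# Usage sketch (hedged): EncodedData keeps the raw file bytes as a 1D uint8 tensor; EncodedImage
# additionally decodes height and width lazily via PIL. The file path below is hypothetical.
def _example_encoded_image() -> None:
    image = EncodedImage.from_path("example.jpg")
    height, width = image.image_size  # decoded on first access and then cached
    print(image.dtype, image.shape, (height, width))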
from __future__ import annotations
from types import ModuleType
from typing import Any, Callable, List, Mapping, Optional, Sequence, Tuple, Type, TypeVar, Union
import PIL.Image
import torch
from torch._C import DisableTorchFunction
from torchvision.transforms import InterpolationMode
F = TypeVar("F", bound="_Feature")
FillType = Union[int, float, Sequence[int], Sequence[float], None]
FillTypeJIT = Union[int, float, List[float], None]
def is_simple_tensor(inpt: Any) -> bool:
return isinstance(inpt, torch.Tensor) and not isinstance(inpt, _Feature)
class _Feature(torch.Tensor):
__F: Optional[ModuleType] = None
def __new__(
cls: Type[F],
data: Any,
*,
dtype: Optional[torch.dtype] = None,
device: Optional[Union[torch.device, str, int]] = None,
requires_grad: bool = False,
) -> F:
return (
torch.as_tensor( # type: ignore[return-value]
data,
dtype=dtype, # type: ignore[arg-type]
device=device, # type: ignore[arg-type]
)
.as_subclass(cls) # type: ignore[arg-type]
.requires_grad_(requires_grad)
)
@classmethod
def new_like(
cls: Type[F],
other: F,
data: Any,
*,
dtype: Optional[torch.dtype] = None,
device: Optional[Union[torch.device, str, int]] = None,
requires_grad: Optional[bool] = None,
**kwargs: Any,
) -> F:
return cls(
data,
dtype=dtype if dtype is not None else other.dtype,
device=device if device is not None else other.device,
requires_grad=requires_grad if requires_grad is not None else other.requires_grad,
**kwargs,
)
_NO_WRAPPING_EXCEPTIONS = {
torch.Tensor.clone: lambda cls, input, output: cls.new_like(input, output),
torch.Tensor.to: lambda cls, input, output: cls.new_like(
input, output, dtype=output.dtype, device=output.device
),
# We don't need to wrap the output of `Tensor.requires_grad_`, since it is an inplace operation and thus
# retains the type automatically
torch.Tensor.requires_grad_: lambda cls, input, output: output,
}
@classmethod
def __torch_function__(
cls,
func: Callable[..., torch.Tensor],
types: Tuple[Type[torch.Tensor], ...],
args: Sequence[Any] = (),
kwargs: Optional[Mapping[str, Any]] = None,
) -> torch.Tensor:
"""For general information about how the __torch_function__ protocol works,
see https://pytorch.org/docs/stable/notes/extending.html#extending-torch
TL;DR: Every time a PyTorch operator is called, it goes through the inputs and looks for the
``__torch_function__`` method. If one is found, it is invoked with the operator as ``func`` as well as the
``args`` and ``kwargs`` of the original call.
        The default behavior of :class:`~torch.Tensor` subclasses is to retain the custom tensor type. For the
        :class:`_Feature` use case, this has two downsides:
        1. Since some :class:`_Feature`'s require metadata to be constructed, the default wrapping, i.e.
        ``return cls(func(*args, **kwargs))``, will fail for them.
        2. For most operations, there is no way of knowing if the input type is still valid for the output.
        For these reasons, the automatic output wrapping is turned off for most operators. The only exceptions are
        listed in :attr:`~_Feature._NO_WRAPPING_EXCEPTIONS`.
"""
# Since super().__torch_function__ has no hook to prevent the coercing of the output into the input type, we
# need to reimplement the functionality.
if not all(issubclass(cls, t) for t in types):
return NotImplemented
with DisableTorchFunction():
output = func(*args, **kwargs or dict())
wrapper = cls._NO_WRAPPING_EXCEPTIONS.get(func)
# Apart from `func` needing to be an exception, we also require the primary operand, i.e. `args[0]`, to be
# an instance of the class that `__torch_function__` was invoked on. The __torch_function__ protocol will
# invoke this method on *all* types involved in the computation by walking the MRO upwards. For example,
# `torch.Tensor(...).to(features.Image(...))` will invoke `features.Image.__torch_function__` with
# `args = (torch.Tensor(), features.Image())` first. Without this guard, the original `torch.Tensor` would
# be wrapped into a `features.Image`.
if wrapper and isinstance(args[0], cls):
return wrapper(cls, args[0], output) # type: ignore[no-any-return]
# Inplace `func`'s, canonically identified with a trailing underscore in their name like `.add_(...)`,
# will retain the input type. Thus, we need to unwrap here.
if isinstance(output, cls):
return output.as_subclass(torch.Tensor) # type: ignore[arg-type]
return output
def _make_repr(self, **kwargs: Any) -> str:
# This is a poor man's implementation of the proposal in https://github.com/pytorch/pytorch/issues/76532.
# If that ever gets implemented, remove this in favor of the solution on the `torch.Tensor` class.
extra_repr = ", ".join(f"{key}={value}" for key, value in kwargs.items())
return f"{super().__repr__()[:-1]}, {extra_repr})"
@property
def _F(self) -> ModuleType:
# This implements a lazy import of the functional to get around the cyclic import. This import is deferred
# until the first time we need reference to the functional module and it's shared across all instances of
# the class. This approach avoids the DataLoader issue described at
# https://github.com/pytorch/vision/pull/6476#discussion_r953588621
if _Feature.__F is None:
from ..transforms import functional
_Feature.__F = functional
return _Feature.__F
def horizontal_flip(self) -> _Feature:
return self
def vertical_flip(self) -> _Feature:
return self
# TODO: We have to ignore override mypy error as there is torch.Tensor built-in deprecated op: Tensor.resize
# https://github.com/pytorch/pytorch/blob/e8727994eb7cdb2ab642749d6549bc497563aa06/torch/_tensor.py#L588-L593
def resize( # type: ignore[override]
self,
size: List[int],
interpolation: InterpolationMode = InterpolationMode.BILINEAR,
max_size: Optional[int] = None,
antialias: bool = False,
) -> _Feature:
return self
def crop(self, top: int, left: int, height: int, width: int) -> _Feature:
return self
def center_crop(self, output_size: List[int]) -> _Feature:
return self
def resized_crop(
self,
top: int,
left: int,
height: int,
width: int,
size: List[int],
interpolation: InterpolationMode = InterpolationMode.BILINEAR,
antialias: bool = False,
) -> _Feature:
return self
def pad(
self,
padding: Union[int, List[int]],
fill: FillTypeJIT = None,
padding_mode: str = "constant",
) -> _Feature:
return self
def rotate(
self,
angle: float,
interpolation: InterpolationMode = InterpolationMode.NEAREST,
expand: bool = False,
fill: FillTypeJIT = None,
center: Optional[List[float]] = None,
) -> _Feature:
return self
def affine(
self,
angle: Union[int, float],
translate: List[float],
scale: float,
shear: List[float],
interpolation: InterpolationMode = InterpolationMode.NEAREST,
fill: FillTypeJIT = None,
center: Optional[List[float]] = None,
) -> _Feature:
return self
def perspective(
self,
perspective_coeffs: List[float],
interpolation: InterpolationMode = InterpolationMode.BILINEAR,
fill: FillTypeJIT = None,
) -> _Feature:
return self
def elastic(
self,
displacement: torch.Tensor,
interpolation: InterpolationMode = InterpolationMode.BILINEAR,
fill: FillTypeJIT = None,
) -> _Feature:
return self
def adjust_brightness(self, brightness_factor: float) -> _Feature:
return self
def adjust_saturation(self, saturation_factor: float) -> _Feature:
return self
def adjust_contrast(self, contrast_factor: float) -> _Feature:
return self
def adjust_sharpness(self, sharpness_factor: float) -> _Feature:
return self
def adjust_hue(self, hue_factor: float) -> _Feature:
return self
def adjust_gamma(self, gamma: float, gain: float = 1) -> _Feature:
return self
def posterize(self, bits: int) -> _Feature:
return self
def solarize(self, threshold: float) -> _Feature:
return self
def autocontrast(self) -> _Feature:
return self
def equalize(self) -> _Feature:
return self
def invert(self) -> _Feature:
return self
def gaussian_blur(self, kernel_size: List[int], sigma: Optional[List[float]] = None) -> _Feature:
return self
InputType = Union[torch.Tensor, PIL.Image.Image, _Feature]
InputTypeJIT = torch.Tensor
from __future__ import annotations
import warnings
from typing import Any, cast, List, Optional, Tuple, Union
import PIL.Image
import torch
from torchvision._utils import StrEnum
from torchvision.transforms.functional import InterpolationMode, to_pil_image
from torchvision.utils import draw_bounding_boxes, make_grid
from ._bounding_box import BoundingBox
from ._feature import _Feature, FillTypeJIT
class ColorSpace(StrEnum):
OTHER = StrEnum.auto()
GRAY = StrEnum.auto()
GRAY_ALPHA = StrEnum.auto()
RGB = StrEnum.auto()
RGB_ALPHA = StrEnum.auto()
@classmethod
def from_pil_mode(cls, mode: str) -> ColorSpace:
if mode == "L":
return cls.GRAY
elif mode == "LA":
return cls.GRAY_ALPHA
elif mode == "RGB":
return cls.RGB
elif mode == "RGBA":
return cls.RGB_ALPHA
else:
return cls.OTHER
@staticmethod
def from_tensor_shape(shape: List[int]) -> ColorSpace:
return _from_tensor_shape(shape)
def _from_tensor_shape(shape: List[int]) -> ColorSpace:
# Needed as a standalone method for JIT
ndim = len(shape)
if ndim < 2:
return ColorSpace.OTHER
elif ndim == 2:
return ColorSpace.GRAY
num_channels = shape[-3]
if num_channels == 1:
return ColorSpace.GRAY
elif num_channels == 2:
return ColorSpace.GRAY_ALPHA
elif num_channels == 3:
return ColorSpace.RGB
elif num_channels == 4:
return ColorSpace.RGB_ALPHA
else:
return ColorSpace.OTHER
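# Illustrative mapping for the heuristic above (a sketch, not exhaustive):
#   _from_tensor_shape([32, 32])    -> ColorSpace.GRAY   (2-D input)
#   _from_tensor_shape([3, 32, 32]) -> ColorSpace.RGB    (3 channels)
#   _from_tensor_shape([5, 32, 32]) -> ColorSpace.OTHER  (unrecognised channel count)
# Only the number of dimensions and the channel dimension (`shape[-3]`) are inspected.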
class Image(_Feature):
color_space: ColorSpace
def __new__(
cls,
data: Any,
*,
color_space: Optional[Union[ColorSpace, str]] = None,
dtype: Optional[torch.dtype] = None,
device: Optional[Union[torch.device, str, int]] = None,
requires_grad: bool = False,
) -> Image:
data = torch.as_tensor(data, dtype=dtype, device=device) # type: ignore[arg-type]
if data.ndim < 2:
raise ValueError
elif data.ndim == 2:
data = data.unsqueeze(0)
image = super().__new__(cls, data, requires_grad=requires_grad)
if color_space is None:
color_space = ColorSpace.from_tensor_shape(image.shape) # type: ignore[arg-type]
if color_space == ColorSpace.OTHER:
warnings.warn("Unable to guess a specific color space. Consider passing it explicitly.")
elif isinstance(color_space, str):
color_space = ColorSpace.from_str(color_space.upper())
elif not isinstance(color_space, ColorSpace):
raise ValueError
image.color_space = color_space
return image
def __repr__(self, *, tensor_contents: Any = None) -> str: # type: ignore[override]
return self._make_repr(color_space=self.color_space)
@classmethod
def new_like(
cls, other: Image, data: Any, *, color_space: Optional[Union[ColorSpace, str]] = None, **kwargs: Any
) -> Image:
return super().new_like(
other, data, color_space=color_space if color_space is not None else other.color_space, **kwargs
)
@property
def image_size(self) -> Tuple[int, int]:
return cast(Tuple[int, int], tuple(self.shape[-2:]))
@property
def num_channels(self) -> int:
return self.shape[-3]
def to_color_space(self, color_space: Union[str, ColorSpace], copy: bool = True) -> Image:
if isinstance(color_space, str):
color_space = ColorSpace.from_str(color_space.upper())
return Image.new_like(
self,
self._F.convert_color_space_image_tensor(
self, old_color_space=self.color_space, new_color_space=color_space, copy=copy
),
color_space=color_space,
)
def show(self) -> None:
# TODO: this is useful for developing and debugging but we should remove or at least revisit this before we
# promote this out of the prototype state
to_pil_image(make_grid(self.view(-1, *self.shape[-3:]))).show()
def draw_bounding_box(self, bounding_box: BoundingBox, **kwargs: Any) -> Image:
# TODO: this is useful for developing and debugging but we should remove or at least revisit this before we
# promote this out of the prototype state
return Image.new_like(self, draw_bounding_boxes(self, bounding_box.to_format("xyxy").view(-1, 4), **kwargs))
def horizontal_flip(self) -> Image:
output = self._F.horizontal_flip_image_tensor(self)
return Image.new_like(self, output)
def vertical_flip(self) -> Image:
output = self._F.vertical_flip_image_tensor(self)
return Image.new_like(self, output)
def resize( # type: ignore[override]
self,
size: List[int],
interpolation: InterpolationMode = InterpolationMode.BILINEAR,
max_size: Optional[int] = None,
antialias: bool = False,
) -> Image:
output = self._F.resize_image_tensor(
self, size, interpolation=interpolation, max_size=max_size, antialias=antialias
)
return Image.new_like(self, output)
def crop(self, top: int, left: int, height: int, width: int) -> Image:
output = self._F.crop_image_tensor(self, top, left, height, width)
return Image.new_like(self, output)
def center_crop(self, output_size: List[int]) -> Image:
output = self._F.center_crop_image_tensor(self, output_size=output_size)
return Image.new_like(self, output)
def resized_crop(
self,
top: int,
left: int,
height: int,
width: int,
size: List[int],
interpolation: InterpolationMode = InterpolationMode.BILINEAR,
antialias: bool = False,
) -> Image:
output = self._F.resized_crop_image_tensor(
self, top, left, height, width, size=list(size), interpolation=interpolation, antialias=antialias
)
return Image.new_like(self, output)
def pad(
self,
padding: Union[int, List[int]],
fill: FillTypeJIT = None,
padding_mode: str = "constant",
) -> Image:
output = self._F.pad_image_tensor(self, padding, fill=fill, padding_mode=padding_mode)
return Image.new_like(self, output)
def rotate(
self,
angle: float,
interpolation: InterpolationMode = InterpolationMode.NEAREST,
expand: bool = False,
fill: FillTypeJIT = None,
center: Optional[List[float]] = None,
) -> Image:
output = self._F._geometry.rotate_image_tensor(
self, angle, interpolation=interpolation, expand=expand, fill=fill, center=center
)
return Image.new_like(self, output)
def affine(
self,
angle: Union[int, float],
translate: List[float],
scale: float,
shear: List[float],
interpolation: InterpolationMode = InterpolationMode.NEAREST,
fill: FillTypeJIT = None,
center: Optional[List[float]] = None,
) -> Image:
output = self._F._geometry.affine_image_tensor(
self,
angle,
translate=translate,
scale=scale,
shear=shear,
interpolation=interpolation,
fill=fill,
center=center,
)
return Image.new_like(self, output)
def perspective(
self,
perspective_coeffs: List[float],
interpolation: InterpolationMode = InterpolationMode.BILINEAR,
fill: FillTypeJIT = None,
) -> Image:
output = self._F._geometry.perspective_image_tensor(
self, perspective_coeffs, interpolation=interpolation, fill=fill
)
return Image.new_like(self, output)
def elastic(
self,
displacement: torch.Tensor,
interpolation: InterpolationMode = InterpolationMode.BILINEAR,
fill: FillTypeJIT = None,
) -> Image:
output = self._F._geometry.elastic_image_tensor(self, displacement, interpolation=interpolation, fill=fill)
return Image.new_like(self, output)
def adjust_brightness(self, brightness_factor: float) -> Image:
output = self._F.adjust_brightness_image_tensor(self, brightness_factor=brightness_factor)
return Image.new_like(self, output)
def adjust_saturation(self, saturation_factor: float) -> Image:
output = self._F.adjust_saturation_image_tensor(self, saturation_factor=saturation_factor)
return Image.new_like(self, output)
def adjust_contrast(self, contrast_factor: float) -> Image:
output = self._F.adjust_contrast_image_tensor(self, contrast_factor=contrast_factor)
return Image.new_like(self, output)
def adjust_sharpness(self, sharpness_factor: float) -> Image:
output = self._F.adjust_sharpness_image_tensor(self, sharpness_factor=sharpness_factor)
return Image.new_like(self, output)
def adjust_hue(self, hue_factor: float) -> Image:
output = self._F.adjust_hue_image_tensor(self, hue_factor=hue_factor)
return Image.new_like(self, output)
def adjust_gamma(self, gamma: float, gain: float = 1) -> Image:
output = self._F.adjust_gamma_image_tensor(self, gamma=gamma, gain=gain)
return Image.new_like(self, output)
def posterize(self, bits: int) -> Image:
output = self._F.posterize_image_tensor(self, bits=bits)
return Image.new_like(self, output)
def solarize(self, threshold: float) -> Image:
output = self._F.solarize_image_tensor(self, threshold=threshold)
return Image.new_like(self, output)
def autocontrast(self) -> Image:
output = self._F.autocontrast_image_tensor(self)
return Image.new_like(self, output)
def equalize(self) -> Image:
output = self._F.equalize_image_tensor(self)
return Image.new_like(self, output)
def invert(self) -> Image:
output = self._F.invert_image_tensor(self)
return Image.new_like(self, output)
def gaussian_blur(self, kernel_size: List[int], sigma: Optional[List[float]] = None) -> Image:
output = self._F.gaussian_blur_image_tensor(self, kernel_size=kernel_size, sigma=sigma)
return Image.new_like(self, output)
ImageType = Union[torch.Tensor, PIL.Image.Image, Image]
ImageTypeJIT = torch.Tensor
LegacyImageType = Union[torch.Tensor, PIL.Image.Image]
LegacyImageTypeJIT = torch.Tensor
TensorImageType = Union[torch.Tensor, Image]
TensorImageTypeJIT = torch.Tensor
from __future__ import annotations
from typing import Any, Optional, Sequence, Type, TypeVar, Union
import torch
from torch.utils._pytree import tree_map
from ._feature import _Feature
L = TypeVar("L", bound="_LabelBase")
class _LabelBase(_Feature):
categories: Optional[Sequence[str]]
def __new__(
cls: Type[L],
data: Any,
*,
categories: Optional[Sequence[str]] = None,
dtype: Optional[torch.dtype] = None,
device: Optional[Union[torch.device, str, int]] = None,
requires_grad: bool = False,
) -> L:
label_base = super().__new__(cls, data, dtype=dtype, device=device, requires_grad=requires_grad)
label_base.categories = categories
return label_base
@classmethod
def new_like(cls: Type[L], other: L, data: Any, *, categories: Optional[Sequence[str]] = None, **kwargs: Any) -> L:
return super().new_like(
other, data, categories=categories if categories is not None else other.categories, **kwargs
)
@classmethod
def from_category(
cls: Type[L],
category: str,
*,
categories: Sequence[str],
**kwargs: Any,
) -> L:
return cls(categories.index(category), categories=categories, **kwargs)
class Label(_LabelBase):
def to_categories(self) -> Any:
if self.categories is None:
raise RuntimeError("Label does not have categories")
return tree_map(lambda idx: self.categories[idx], self.tolist())
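# A small hedged example of the label round trip above:
#   Label.from_category("cat", categories=["dog", "cat"])    -> a Label wrapping the index 1
#   Label([0, 1], categories=["dog", "cat"]).to_categories() -> ["dog", "cat"]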
class OneHotLabel(_LabelBase):
def __new__(
cls,
data: Any,
*,
categories: Optional[Sequence[str]] = None,
dtype: Optional[torch.dtype] = None,
device: Optional[Union[torch.device, str, int]] = None,
requires_grad: bool = False,
) -> OneHotLabel:
one_hot_label = super().__new__(
cls, data, categories=categories, dtype=dtype, device=device, requires_grad=requires_grad
)
if categories is not None and len(categories) != one_hot_label.shape[-1]:
raise ValueError()
return one_hot_label
from __future__ import annotations
from typing import List, Optional, Union
import torch
from torchvision.transforms import InterpolationMode
from ._feature import _Feature, FillTypeJIT
class Mask(_Feature):
def horizontal_flip(self) -> Mask:
output = self._F.horizontal_flip_mask(self)
return Mask.new_like(self, output)
def vertical_flip(self) -> Mask:
output = self._F.vertical_flip_mask(self)
return Mask.new_like(self, output)
def resize( # type: ignore[override]
self,
size: List[int],
interpolation: InterpolationMode = InterpolationMode.NEAREST,
max_size: Optional[int] = None,
antialias: bool = False,
) -> Mask:
output = self._F.resize_mask(self, size, max_size=max_size)
return Mask.new_like(self, output)
def crop(self, top: int, left: int, height: int, width: int) -> Mask:
output = self._F.crop_mask(self, top, left, height, width)
return Mask.new_like(self, output)
def center_crop(self, output_size: List[int]) -> Mask:
output = self._F.center_crop_mask(self, output_size=output_size)
return Mask.new_like(self, output)
def resized_crop(
self,
top: int,
left: int,
height: int,
width: int,
size: List[int],
interpolation: InterpolationMode = InterpolationMode.NEAREST,
antialias: bool = False,
) -> Mask:
output = self._F.resized_crop_mask(self, top, left, height, width, size=size)
return Mask.new_like(self, output)
def pad(
self,
padding: Union[int, List[int]],
fill: FillTypeJIT = None,
padding_mode: str = "constant",
) -> Mask:
output = self._F.pad_mask(self, padding, padding_mode=padding_mode, fill=fill)
return Mask.new_like(self, output)
def rotate(
self,
angle: float,
interpolation: InterpolationMode = InterpolationMode.NEAREST,
expand: bool = False,
fill: FillTypeJIT = None,
center: Optional[List[float]] = None,
) -> Mask:
output = self._F.rotate_mask(self, angle, expand=expand, center=center, fill=fill)
return Mask.new_like(self, output)
def affine(
self,
angle: Union[int, float],
translate: List[float],
scale: float,
shear: List[float],
interpolation: InterpolationMode = InterpolationMode.NEAREST,
fill: FillTypeJIT = None,
center: Optional[List[float]] = None,
) -> Mask:
output = self._F.affine_mask(
self,
angle,
translate=translate,
scale=scale,
shear=shear,
fill=fill,
center=center,
)
return Mask.new_like(self, output)
def perspective(
self,
perspective_coeffs: List[float],
interpolation: InterpolationMode = InterpolationMode.NEAREST,
fill: FillTypeJIT = None,
) -> Mask:
output = self._F.perspective_mask(self, perspective_coeffs, fill=fill)
return Mask.new_like(self, output)
def elastic(
self,
displacement: torch.Tensor,
interpolation: InterpolationMode = InterpolationMode.NEAREST,
fill: FillTypeJIT = None,
) -> Mask:
output = self._F.elastic_mask(self, displacement, fill=fill)
return Mask.new_like(self, output, dtype=output.dtype)
from .raft_stereo import *
from .crestereo import *
import math
from functools import partial
from typing import Callable, Dict, Iterable, List, Optional, Tuple
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.models.optical_flow.raft as raft
from torch import Tensor
from torchvision.models._api import register_model, Weights, WeightsEnum
from torchvision.models._utils import handle_legacy_interface
from torchvision.models.optical_flow._utils import grid_sample, make_coords_grid, upsample_flow
from torchvision.ops import Conv2dNormActivation
from torchvision.prototype.transforms._presets import StereoMatching
__all__ = (
"CREStereo",
"CREStereo_Base_Weights",
"crestereo_base",
)
class ConvexMaskPredictor(nn.Module):
def __init__(
self,
*,
in_channels: int,
hidden_size: int,
upsample_factor: int,
multiplier: float = 0.25,
) -> None:
super().__init__()
self.mask_head = nn.Sequential(
Conv2dNormActivation(in_channels, hidden_size, norm_layer=None, kernel_size=3),
# https://arxiv.org/pdf/2003.12039.pdf (Annex section B) for the
# following convolution output size
nn.Conv2d(hidden_size, upsample_factor**2 * 9, 1, padding=0),
)
self.multiplier = multiplier
def forward(self, x: Tensor) -> Tensor:
x = self.mask_head(x) * self.multiplier
return x
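# Hedged note on the predictor above: the `upsample_factor**2 * 9` output channels follow the
# RAFT-style convex upsampling layout, i.e. one weight per 3x3 (= 9) coarse neighbour for every
# upsampled sub-pixel; the imported `upsample_flow` helper is assumed to consume this mask shape.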
def get_correlation(
left_feature: Tensor,
right_feature: Tensor,
window_size: Tuple[int, int] = (3, 3),
dilate: Tuple[int, int] = (1, 1),
) -> Tensor:
"""Function that computes a correlation product between the left and right features.
The correlation is computed in a sliding window fashion, namely the the left features are fixed
and for each ``(i, j)`` location we compute the correlation with a sliding window anchored in
``(i, j)`` from the right feature map. The sliding window selects pixels obtained in the range of the sliding
window; i.e ``(i - window_size // 2, i + window_size // 2)`` respectively ``(j - window_size // 2, j + window_size // 2)``.
"""
B, C, H, W = left_feature.shape
di_y, di_x = dilate[0], dilate[1]
pad_y, pad_x = window_size[0] // 2 * di_y, window_size[1] // 2 * di_x
right_padded = F.pad(right_feature, (pad_x, pad_x, pad_y, pad_y), mode="replicate")
# in order to vectorize the correlation computation over all pixel candidates
# we create multiple shifted right images which we stack on an extra dimension
right_padded = F.unfold(right_padded, kernel_size=(H, W), dilation=dilate)
# torch unfold returns a tensor of shape [B, flattened_values, n_selections]
right_padded = right_padded.permute(0, 2, 1)
    # we then reshape back into [B, n_views, C, H, W]
right_padded = right_padded.reshape(B, (window_size[0] * window_size[1]), C, H, W)
# we expand the left features for broadcasting
left_feature = left_feature.unsqueeze(1)
    # this will compute an element-wise product between [B, 1, C, H, W] * [B, n_views, C, H, W]
    # to obtain correlations over the pixel candidates we perform a mean on the C dimension
correlation = torch.mean(left_feature * right_padded, dim=2, keepdim=False)
# the final correlation tensor shape will be [B, n_views, H, W]
# where on the i-th position of the n_views dimension we will have
# the correlation value between the left pixel
# and the i-th candidate on the right feature map
return correlation
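# Shape sketch for `get_correlation` (illustrative values only): with
# `left_feature = right_feature = torch.rand(2, 8, 16, 16)` and the defaults
# `window_size=(3, 3)`, `dilate=(1, 1)`, `F.unfold` yields 3 * 3 = 9 shifted views, so the
# returned correlation volume has shape `[2, 9, 16, 16]`, one channel per candidate shift.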
def _check_window_specs(
search_window_1d: Tuple[int, int] = (1, 9),
search_dilate_1d: Tuple[int, int] = (1, 1),
search_window_2d: Tuple[int, int] = (3, 3),
search_dilate_2d: Tuple[int, int] = (1, 1),
) -> None:
if not np.prod(search_window_1d) == np.prod(search_window_2d):
raise ValueError(
f"The 1D and 2D windows should contain the same number of elements. "
f"1D shape: {search_window_1d} 2D shape: {search_window_2d}"
)
if not np.prod(search_window_1d) % 2 == 1:
raise ValueError(
f"Search windows should contain an odd number of elements in them."
f"Window of shape {search_window_1d} has {np.prod(search_window_1d)} elements."
)
if not any(size == 1 for size in search_window_1d):
raise ValueError(f"The 1D search window should have at least one size equal to 1. 1D shape: {search_window_1d}")
if any(size == 1 for size in search_window_2d):
raise ValueError(
f"The 2D search window should have all dimensions greater than 1. 2D shape: {search_window_2d}"
)
if any(dilate < 1 for dilate in search_dilate_1d):
raise ValueError(
f"The 1D search dilation should have all elements equal or greater than 1. 1D shape: {search_dilate_1d}"
)
if any(dilate < 1 for dilate in search_dilate_2d):
raise ValueError(
f"The 2D search dilation should have all elements equal greater than 1. 2D shape: {search_dilate_2d}"
)
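# Example of the checks above (hedged): the default pair passes because the 1D window (1, 9)
# and the 2D window (3, 3) both cover 9 candidates, 9 is odd, the 1D window has one axis of
# size 1 and the 2D window is strictly larger than 1 on both axes. A pair such as
# `search_window_1d=(1, 9)` with `search_window_2d=(3, 5)` would raise, since 9 != 15.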
class IterativeCorrelationLayer(nn.Module):
def __init__(
self,
groups: int = 4,
search_window_1d: Tuple[int, int] = (1, 9),
search_dilate_1d: Tuple[int, int] = (1, 1),
search_window_2d: Tuple[int, int] = (3, 3),
search_dilate_2d: Tuple[int, int] = (1, 1),
) -> None:
super().__init__()
_check_window_specs(
search_window_1d=search_window_1d,
search_dilate_1d=search_dilate_1d,
search_window_2d=search_window_2d,
search_dilate_2d=search_dilate_2d,
)
self.search_pixels = np.prod(search_window_1d)
self.groups = groups
        # two selection tables for dealing with the ``window_type`` argument in the forward function
self.patch_sizes = {
"2d": [search_window_2d for _ in range(self.groups)],
"1d": [search_window_1d for _ in range(self.groups)],
}
self.dilate_sizes = {
"2d": [search_dilate_2d for _ in range(self.groups)],
"1d": [search_dilate_1d for _ in range(self.groups)],
}
def forward(self, left_feature: Tensor, right_feature: Tensor, flow: Tensor, window_type: str = "1d") -> Tensor:
"""Function that computes 1 pass of non-offsetted Group-Wise correlation"""
coords = make_coords_grid(
left_feature.shape[0], left_feature.shape[2], left_feature.shape[3], device=str(left_feature.device)
)
# we offset the coordinate grid in the flow direction
coords = coords + flow
coords = coords.permute(0, 2, 3, 1)
        # resample right features according to the offset grid
right_feature = grid_sample(right_feature, coords, mode="bilinear", align_corners=True)
        # ``window_type`` decides on how many axes the candidate search is performed.
        # See section 3.1 ``Deformable search window`` & Figure 4 in the paper.
patch_size_list = self.patch_sizes[window_type]
dilate_size_list = self.dilate_sizes[window_type]
# chunking the left and right feature to perform group-wise correlation
        # mechanism similar to GroupNorm. See section 3.1 ``Group-wise correlation``.
left_groups = torch.chunk(left_feature, self.groups, dim=1)
right_groups = torch.chunk(right_feature, self.groups, dim=1)
correlations = []
# this boils down to rather than performing the correlation product
# over the entire C dimensions, we use subsets of C to get multiple correlation sets
for i in range(len(patch_size_list)):
correlation = get_correlation(left_groups[i], right_groups[i], patch_size_list[i], dilate_size_list[i])
correlations.append(correlation)
final_correlations = torch.cat(correlations, dim=1)
return final_correlations
class AttentionOffsetCorrelationLayer(nn.Module):
def __init__(
self,
groups: int = 4,
attention_module: Optional[nn.Module] = None,
search_window_1d: Tuple[int, int] = (1, 9),
search_dilate_1d: Tuple[int, int] = (1, 1),
search_window_2d: Tuple[int, int] = (3, 3),
search_dilate_2d: Tuple[int, int] = (1, 1),
) -> None:
super().__init__()
_check_window_specs(
search_window_1d=search_window_1d,
search_dilate_1d=search_dilate_1d,
search_window_2d=search_window_2d,
search_dilate_2d=search_dilate_2d,
)
# convert to python scalar
self.search_pixels = int(np.prod(search_window_1d))
self.groups = groups
        # two selection tables for dealing with the ``window_type`` argument in the forward function
self.patch_sizes = {
"2d": [search_window_2d for _ in range(self.groups)],
"1d": [search_window_1d for _ in range(self.groups)],
}
self.dilate_sizes = {
"2d": [search_dilate_2d for _ in range(self.groups)],
"1d": [search_dilate_1d for _ in range(self.groups)],
}
self.attention_module = attention_module
def forward(
self,
left_feature: Tensor,
right_feature: Tensor,
flow: Tensor,
extra_offset: Tensor,
window_type: str = "1d",
) -> Tensor:
"""Function that computes 1 pass of offsetted Group-Wise correlation
If the class was provided with an attention layer, the left and right feature maps
will be passed through a transformer first
"""
B, C, H, W = left_feature.shape
if self.attention_module is not None:
# prepare for transformer required input shapes
left_feature = left_feature.permute(0, 2, 3, 1).reshape(B, H * W, C)
right_feature = right_feature.permute(0, 2, 3, 1).reshape(B, H * W, C)
            # this can be either self attention or cross attention, hence the tuple return
left_feature, right_feature = self.attention_module(left_feature, right_feature)
left_feature = left_feature.reshape(B, H, W, C).permute(0, 3, 1, 2)
right_feature = right_feature.reshape(B, H, W, C).permute(0, 3, 1, 2)
left_groups = torch.chunk(left_feature, self.groups, dim=1)
right_groups = torch.chunk(right_feature, self.groups, dim=1)
num_search_candidates = self.search_pixels
# for each pixel (i, j) we have a number of search candidates
# thus, for each candidate we should have an X-axis and Y-axis offset value
extra_offset = extra_offset.reshape(B, num_search_candidates, 2, H, W).permute(0, 1, 3, 4, 2)
patch_size_list = self.patch_sizes[window_type]
dilate_size_list = self.dilate_sizes[window_type]
group_channels = C // self.groups
correlations = []
for i in range(len(patch_size_list)):
left_group, right_group = left_groups[i], right_groups[i]
patch_size, dilate = patch_size_list[i], dilate_size_list[i]
di_y, di_x = dilate
ps_y, ps_x = patch_size
# define the search based on the window patch shape
ry, rx = ps_y // 2 * di_y, ps_x // 2 * di_x
# base offsets for search (i.e. where to look on the search index)
x_grid, y_grid = torch.meshgrid(
torch.arange(-rx, rx + 1, di_x), torch.arange(-ry, ry + 1, di_y), indexing="xy"
)
x_grid, y_grid = x_grid.to(flow.device), y_grid.to(flow.device)
offsets = torch.stack((x_grid, y_grid))
offsets = offsets.reshape(2, -1).permute(1, 0)
for d in (0, 2, 3):
offsets = offsets.unsqueeze(d)
            # extra offsets for search (i.e. deformed search indices, similar in concept to deformable convolutions)
offsets = offsets + extra_offset
coords = (
make_coords_grid(
left_feature.shape[0], left_feature.shape[2], left_feature.shape[3], device=str(left_feature.device)
)
+ flow
)
coords = coords.permute(0, 2, 3, 1).unsqueeze(1)
coords = coords + offsets
coords = coords.reshape(B, -1, W, 2)
right_group = grid_sample(right_group, coords, mode="bilinear", align_corners=True)
# we do not need to perform any window shifting because the grid sample op
# will return a multi-view right based on the num_search_candidates dimension in the offsets
right_group = right_group.reshape(B, group_channels, -1, H, W)
left_group = left_group.reshape(B, group_channels, -1, H, W)
correlation = torch.mean(left_group * right_group, dim=1)
correlations.append(correlation)
final_correlation = torch.cat(correlations, dim=1)
return final_correlation
class AdaptiveGroupCorrelationLayer(nn.Module):
"""
Container for computing various correlation types between a left and right feature map.
This module does not contain any optimisable parameters, it's solely a collection of ops.
We wrap in a nn.Module for torch.jit.script compatibility
Adaptive Group Correlation operations from: https://openaccess.thecvf.com/content/CVPR2022/papers/Li_Practical_Stereo_Matching_via_Cascaded_Recurrent_Network_With_Adaptive_Correlation_CVPR_2022_paper.pdf
Canonical reference implementation: https://github.com/megvii-research/CREStereo/blob/master/nets/corr.py
"""
def __init__(
self,
iterative_correlation_layer: IterativeCorrelationLayer,
attention_offset_correlation_layer: AttentionOffsetCorrelationLayer,
) -> None:
super().__init__()
self.iterative_correlation_layer = iterative_correlation_layer
self.attention_offset_correlation_layer = attention_offset_correlation_layer
def forward(
self,
left_features: Tensor,
right_features: Tensor,
flow: torch.Tensor,
extra_offset: Optional[Tensor],
window_type: str = "1d",
iter_mode: bool = False,
) -> Tensor:
if iter_mode or extra_offset is None:
corr = self.iterative_correlation_layer(left_features, right_features, flow, window_type)
else:
corr = self.attention_offset_correlation_layer(
left_features, right_features, flow, extra_offset, window_type
) # type: ignore
return corr
def elu_feature_map(x: Tensor) -> Tensor:
"""Elu feature map operation from: https://arxiv.org/pdf/2006.16236.pdf"""
return F.elu(x) + 1
class LinearAttention(nn.Module):
"""
Linear attention operation from: https://arxiv.org/pdf/2006.16236.pdf
    Canonical implementation reference: https://github.com/idiap/fast-transformers/blob/master/fast_transformers/attention/linear_attention.py
LoFTR implementation reference: https://github.com/zju3dv/LoFTR/blob/2122156015b61fbb650e28b58a958e4d632b1058/src/loftr/loftr_module/linear_attention.py
"""
def __init__(self, eps: float = 1e-6, feature_map_fn: Callable[[Tensor], Tensor] = elu_feature_map) -> None:
super().__init__()
self.eps = eps
self.feature_map_fn = feature_map_fn
def forward(
self,
queries: Tensor,
keys: Tensor,
values: Tensor,
q_mask: Optional[Tensor] = None,
kv_mask: Optional[Tensor] = None,
) -> Tensor:
"""
Args:
queries (torch.Tensor): [N, S1, H, D]
keys (torch.Tensor): [N, S2, H, D]
values (torch.Tensor): [N, S2, H, D]
q_mask (torch.Tensor): [N, S1] (optional)
kv_mask (torch.Tensor): [N, S2] (optional)
Returns:
queried_values (torch.Tensor): [N, S1, H, D]
"""
queries = self.feature_map_fn(queries)
keys = self.feature_map_fn(keys)
if q_mask is not None:
queries = queries * q_mask[:, :, None, None]
if kv_mask is not None:
keys = keys * kv_mask[:, :, None, None]
values = values * kv_mask[:, :, None, None]
# mitigates fp16 overflows
values_length = values.shape[1]
values = values / values_length
kv = torch.einsum("NSHD, NSHV -> NHDV", keys, values)
z = 1 / (torch.einsum("NLHD, NHD -> NLH", queries, keys.sum(dim=1)) + self.eps)
# rescale at the end to account for fp16 mitigation
queried_values = torch.einsum("NLHD, NHDV, NLH -> NLHV", queries, kv, z) * values_length
return queried_values
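# Hedged summary of the einsums above, with phi(x) = elu(x) + 1:
#   out ~ phi(Q) @ (phi(K)^T @ V) / (phi(Q) @ sum_over_keys(phi(K)))
# The key/value length is contracted once into `kv`, which is what makes the cost linear in
# sequence length rather than quadratic, as in the `SoftmaxAttention` module below.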
class SoftmaxAttention(nn.Module):
"""
A simple softmax attention operation
LoFTR implementation reference: https://github.com/zju3dv/LoFTR/blob/2122156015b61fbb650e28b58a958e4d632b1058/src/loftr/loftr_module/linear_attention.py
"""
def __init__(self, dropout: float = 0.0) -> None:
super().__init__()
self.dropout = nn.Dropout(dropout) if dropout else nn.Identity()
def forward(
self,
queries: Tensor,
keys: Tensor,
values: Tensor,
q_mask: Optional[Tensor] = None,
kv_mask: Optional[Tensor] = None,
) -> Tensor:
"""
Computes classical softmax full-attention between all queries and keys.
Args:
queries (torch.Tensor): [N, S1, H, D]
keys (torch.Tensor): [N, S2, H, D]
values (torch.Tensor): [N, S2, H, D]
q_mask (torch.Tensor): [N, S1] (optional)
kv_mask (torch.Tensor): [N, S2] (optional)
Returns:
queried_values: [N, S1, H, D]
"""
        scale_factor = 1.0 / queries.shape[3] ** 0.5  # 1 / sqrt(D) scaling
queries = queries * scale_factor
qk = torch.einsum("NLHD, NSHD -> NLSH", queries, keys)
if kv_mask is not None and q_mask is not None:
qk.masked_fill_(~(q_mask[:, :, None, None] * kv_mask[:, None, :, None]), float("-inf"))
attention = torch.softmax(qk, dim=2)
attention = self.dropout(attention)
queried_values = torch.einsum("NLSH, NSHD -> NLHD", attention, values)
return queried_values
class PositionalEncodingSine(nn.Module):
"""
    Sinusoidal positional encodings
Using the scaling term from https://github.com/megvii-research/CREStereo/blob/master/nets/attention/position_encoding.py
Reference implementation from https://github.com/facebookresearch/detr/blob/8a144f83a287f4d3fece4acdf073f387c5af387d/models/position_encoding.py#L28-L48
"""
def __init__(self, dim_model: int, max_size: int = 256) -> None:
super().__init__()
self.dim_model = dim_model
self.max_size = max_size
# pre-registered for memory efficiency during forward pass
pe = self._make_pe_of_size(self.max_size)
self.register_buffer("pe", pe)
def _make_pe_of_size(self, size: int) -> Tensor:
pe = torch.zeros((self.dim_model, *(size, size)), dtype=torch.float32)
y_positions = torch.ones((size, size)).cumsum(0).float().unsqueeze(0)
x_positions = torch.ones((size, size)).cumsum(1).float().unsqueeze(0)
div_term = torch.exp(torch.arange(0.0, self.dim_model // 2, 2) * (-math.log(10000.0) / self.dim_model // 2))
div_term = div_term[:, None, None]
pe[0::4, :, :] = torch.sin(x_positions * div_term)
pe[1::4, :, :] = torch.cos(x_positions * div_term)
pe[2::4, :, :] = torch.sin(y_positions * div_term)
pe[3::4, :, :] = torch.cos(y_positions * div_term)
pe = pe.unsqueeze(0)
return pe
def forward(self, x: Tensor) -> Tensor:
"""
Args:
x: [B, C, H, W]
Returns:
x: [B, C, H, W]
"""
torch._assert(
len(x.shape) == 4,
f"PositionalEncodingSine requires a 4-D dimensional input. Provided tensor is of shape {x.shape}",
)
B, C, H, W = x.shape
return x + self.pe[:, :, :H, :W] # type: ignore
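# Usage sketch for the encoding above (illustrative): with `dim_model=256` the registered buffer
# `pe` has shape `[1, 256, max_size, max_size]` and is simply cropped to the input's `H, W` in
# `forward`, so feature maps are assumed not to exceed `max_size` at the resolution where the
# encoding is applied.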
class LocalFeatureEncoderLayer(nn.Module):
"""
LoFTR transformer module from: https://arxiv.org/pdf/2104.00680.pdf
    Canonical implementations at: https://github.com/zju3dv/LoFTR/blob/master/src/loftr/loftr_module/transformer.py
"""
def __init__(
self,
*,
dim_model: int,
num_heads: int,
attention_module: Callable[..., nn.Module] = LinearAttention,
) -> None:
super().__init__()
self.attention_op = attention_module()
if not isinstance(self.attention_op, (LinearAttention, SoftmaxAttention)):
raise ValueError(
f"attention_module must be an instance of LinearAttention or SoftmaxAttention. Got {type(self.attention_op)}"
)
self.dim_head = dim_model // num_heads
self.num_heads = num_heads
# multi-head attention
self.query_proj = nn.Linear(dim_model, dim_model, bias=False)
self.key_proj = nn.Linear(dim_model, dim_model, bias=False)
self.value_proj = nn.Linear(dim_model, dim_model, bias=False)
self.merge = nn.Linear(dim_model, dim_model, bias=False)
# feed forward network
self.ffn = nn.Sequential(
nn.Linear(dim_model * 2, dim_model * 2, bias=False),
nn.ReLU(),
nn.Linear(dim_model * 2, dim_model, bias=False),
)
# norm layers
self.attention_norm = nn.LayerNorm(dim_model)
self.ffn_norm = nn.LayerNorm(dim_model)
def forward(
self, x: Tensor, source: Tensor, x_mask: Optional[Tensor] = None, source_mask: Optional[Tensor] = None
) -> Tensor:
"""
Args:
x (torch.Tensor): [B, S1, D]
source (torch.Tensor): [B, S2, D]
x_mask (torch.Tensor): [B, S1] (optional)
source_mask (torch.Tensor): [B, S2] (optional)
"""
B, S, D = x.shape
queries, keys, values = x, source, source
queries = self.query_proj(queries).reshape(B, S, self.num_heads, self.dim_head)
keys = self.key_proj(keys).reshape(B, S, self.num_heads, self.dim_head)
values = self.value_proj(values).reshape(B, S, self.num_heads, self.dim_head)
# attention operation
message = self.attention_op(queries, keys, values, x_mask, source_mask)
        # concatenate the attention heads before passing them through the projection layer
message = self.merge(message.reshape(B, S, D))
message = self.attention_norm(message)
# ffn operation
message = self.ffn(torch.cat([x, message], dim=2))
message = self.ffn_norm(message)
return x + message
class LocalFeatureTransformer(nn.Module):
"""
LoFTR transformer module from: https://arxiv.org/pdf/2104.00680.pdf
    Canonical implementations at: https://github.com/zju3dv/LoFTR/blob/master/src/loftr/loftr_module/transformer.py
"""
def __init__(
self,
*,
dim_model: int,
num_heads: int,
attention_directions: List[str],
attention_module: Callable[..., nn.Module] = LinearAttention,
) -> None:
super(LocalFeatureTransformer, self).__init__()
self.attention_module = attention_module
self.attention_directions = attention_directions
for direction in attention_directions:
if direction not in ["self", "cross"]:
raise ValueError(
f"Attention direction {direction} unsupported. LocalFeatureTransformer accepts only ``attention_type`` in ``[self, cross]``."
)
self.layers = nn.ModuleList(
[
LocalFeatureEncoderLayer(dim_model=dim_model, num_heads=num_heads, attention_module=attention_module)
for _ in attention_directions
]
)
def forward(
self,
left_features: Tensor,
right_features: Tensor,
left_mask: Optional[Tensor] = None,
right_mask: Optional[Tensor] = None,
) -> Tuple[Tensor, Tensor]:
"""
Args:
left_features (torch.Tensor): [N, S1, D]
right_features (torch.Tensor): [N, S2, D]
left_mask (torch.Tensor): [N, S1] (optional)
right_mask (torch.Tensor): [N, S2] (optional)
Returns:
left_features (torch.Tensor): [N, S1, D]
right_features (torch.Tensor): [N, S2, D]
"""
torch._assert(
left_features.shape[2] == right_features.shape[2],
f"left_features and right_features should have the same embedding dimensions. left_features: {left_features.shape[2]} right_features: {right_features.shape[2]}",
)
for idx, layer in enumerate(self.layers):
attention_direction = self.attention_directions[idx]
if attention_direction == "self":
left_features = layer(left_features, left_features, left_mask, left_mask)
right_features = layer(right_features, right_features, right_mask, right_mask)
elif attention_direction == "cross":
left_features = layer(left_features, right_features, left_mask, right_mask)
right_features = layer(right_features, left_features, right_mask, left_mask)
return left_features, right_features
class PyramidDownsample(nn.Module):
"""
    A simple wrapper that returns an avg-pool feature pyramid based on the provided scales.
    The input itself is implicitly returned as the first element.
"""
def __init__(self, factors: Iterable[int]) -> None:
super().__init__()
self.factors = factors
def forward(self, x: torch.Tensor) -> List[Tensor]:
results = [x]
for factor in self.factors:
results.append(F.avg_pool2d(x, kernel_size=factor, stride=factor))
return results
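# Illustrative behaviour of the pyramid above: with `factors=(2, 4)` and an input of shape
# `[B, C, 64, 128]`, the returned list holds tensors of spatial sizes 64x128 (the input itself),
# 32x64 and 16x32, matching resolution keys such as "1 / 4", "1 / 8", "1 / 16" used by
# `CREStereo` below, assuming a base feature-encoder downsample factor of 4.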
class CREStereo(nn.Module):
"""
Implements CREStereo from the `"Practical Stereo Matching via Cascaded Recurrent Network
With Adaptive Correlation" <https://openaccess.thecvf.com/content/CVPR2022/papers/Li_Practical_Stereo_Matching_via_Cascaded_Recurrent_Network_With_Adaptive_Correlation_CVPR_2022_paper.pdf>`_ paper.
Args:
        feature_encoder (raft.FeatureEncoder): Raft-like Feature Encoder module that extracts low-level features from the inputs.
        update_block (raft.UpdateBlock): Raft-like Update Block which recursively refines a flow-map.
        flow_head (raft.FlowHead): Raft-like Flow Head which predicts a flow-map from some inputs.
        self_attn_block (LocalFeatureTransformer): A Local Feature Transformer that performs self attention on the two feature maps.
        cross_attn_block (LocalFeatureTransformer): A Local Feature Transformer that performs cross attention between the two feature maps
            used in the Adaptive Group Correlation module.
        feature_downsample_rates (List[int]): The downsample rates used to build a feature pyramid from the outputs of the `feature_encoder`. Default: [2, 4]
        correlation_groups (int): Into how many groups the features should be split when computing per-pixel correlation. Default: 4.
        search_window_1d (Tuple[int, int]): The alternate search window size in the x and y directions for the 1D case. Defaults to (1, 9).
        search_dilate_1d (Tuple[int, int]): The dilation used in `search_window_1d` when selecting pixels. Similar to `nn.Conv2d` dilation. Defaults to (1, 1).
        search_window_2d (Tuple[int, int]): The alternate search window size in the x and y directions for the 2D case. Defaults to (3, 3).
        search_dilate_2d (Tuple[int, int]): The dilation used in `search_window_2d` when selecting pixels. Similar to `nn.Conv2d` dilation. Defaults to (1, 1).
"""
def __init__(
self,
*,
feature_encoder: raft.FeatureEncoder,
update_block: raft.UpdateBlock,
flow_head: raft.FlowHead,
self_attn_block: LocalFeatureTransformer,
cross_attn_block: LocalFeatureTransformer,
feature_downsample_rates: Tuple[int, ...] = (2, 4),
correlation_groups: int = 4,
search_window_1d: Tuple[int, int] = (1, 9),
search_dilate_1d: Tuple[int, int] = (1, 1),
search_window_2d: Tuple[int, int] = (3, 3),
search_dilate_2d: Tuple[int, int] = (1, 1),
) -> None:
super().__init__()
self.output_channels = 2
self.feature_encoder = feature_encoder
self.update_block = update_block
self.flow_head = flow_head
self.self_attn_block = self_attn_block
# average pooling for the feature encoder outputs
self.downsampling_pyramid = PyramidDownsample(feature_downsample_rates)
self.downsampling_factors: List[int] = [feature_encoder.downsample_factor]
base_downsample_factor: int = self.downsampling_factors[0]
for rate in feature_downsample_rates:
self.downsampling_factors.append(base_downsample_factor * rate)
# output resolution tracking
self.resolutions: List[str] = [f"1 / {factor}" for factor in self.downsampling_factors]
self.search_pixels = int(np.prod(search_window_1d))
# flow convex upsampling mask predictor
self.mask_predictor = ConvexMaskPredictor(
in_channels=feature_encoder.output_dim // 2,
hidden_size=feature_encoder.output_dim,
upsample_factor=feature_encoder.downsample_factor,
multiplier=0.25,
)
        # offset convolution modules for offset-based feature selection
self.offset_convs = nn.ModuleDict()
self.correlation_layers = nn.ModuleDict()
offset_conv_layer = partial(
Conv2dNormActivation,
in_channels=feature_encoder.output_dim,
out_channels=self.search_pixels * 2,
norm_layer=None,
activation_layer=None,
)
        # populate the dicts in top-to-bottom order,
        # which is useful for iterating through a torch.jit.script module in the order of the network forward pass
        #
        # Ignore the largest resolution. We handle that separately because torch.jit.script
        # cannot access runtime-generated keys in ModuleDicts.
        # This way, we keep a generic way of processing all pyramid levels except
        # the final one.
iterative_correlation_layer = partial(
IterativeCorrelationLayer,
groups=correlation_groups,
search_window_1d=search_window_1d,
search_dilate_1d=search_dilate_1d,
search_window_2d=search_window_2d,
search_dilate_2d=search_dilate_2d,
)
attention_offset_correlation_layer = partial(
AttentionOffsetCorrelationLayer,
groups=correlation_groups,
search_window_1d=search_window_1d,
search_dilate_1d=search_dilate_1d,
search_window_2d=search_window_2d,
search_dilate_2d=search_dilate_2d,
)
for idx, resolution in enumerate(reversed(self.resolutions[1:])):
            # the largest resolution does not use offset convolutions for sampling grid coords
offset_conv = None if idx == len(self.resolutions) - 1 else offset_conv_layer()
if offset_conv:
self.offset_convs[resolution] = offset_conv
# only the lowest resolution uses the cross attention module when computing correlation scores
attention_module = cross_attn_block if idx == 0 else None
self.correlation_layers[resolution] = AdaptiveGroupCorrelationLayer(
iterative_correlation_layer=iterative_correlation_layer(),
attention_offset_correlation_layer=attention_offset_correlation_layer(
attention_module=attention_module
),
)
# correlation layer for the largest resolution
self.max_res_correlation_layer = AdaptiveGroupCorrelationLayer(
iterative_correlation_layer=iterative_correlation_layer(),
attention_offset_correlation_layer=attention_offset_correlation_layer(),
)
# simple 2D Postional Encodings
self.positional_encodings = PositionalEncodingSine(feature_encoder.output_dim)
def _get_window_type(self, iteration: int) -> str:
return "1d" if iteration % 2 == 0 else "2d"
def forward(
self, left_image: Tensor, right_image: Tensor, flow_init: Optional[Tensor], num_iters: int = 10
) -> List[Tensor]:
features = torch.cat([left_image, right_image], dim=0)
features = self.feature_encoder(features)
left_features, right_features = features.chunk(2, dim=0)
# update block network state and input context are derived from the left feature map
net, ctx = left_features.chunk(2, dim=1)
net = torch.tanh(net)
ctx = torch.relu(ctx)
        # these will output lists of tensors.
l_pyramid = self.downsampling_pyramid(left_features)
r_pyramid = self.downsampling_pyramid(right_features)
net_pyramid = self.downsampling_pyramid(net)
ctx_pyramid = self.downsampling_pyramid(ctx)
# we store in reversed order because we process the pyramid from top to bottom
l_pyramid: Dict[str, Tensor] = {res: l_pyramid[idx] for idx, res in enumerate(self.resolutions)}
r_pyramid: Dict[str, Tensor] = {res: r_pyramid[idx] for idx, res in enumerate(self.resolutions)}
net_pyramid: Dict[str, Tensor] = {res: net_pyramid[idx] for idx, res in enumerate(self.resolutions)}
ctx_pyramid: Dict[str, Tensor] = {res: ctx_pyramid[idx] for idx, res in enumerate(self.resolutions)}
# offsets for sampling pixel candidates in the correlation ops
offsets: Dict[str, Tensor] = {}
for resolution, offset_conv in self.offset_convs.items():
feature_map = l_pyramid[resolution]
offset = offset_conv(feature_map)
offsets[resolution] = (torch.sigmoid(offset) - 0.5) * 2.0
# the smallest resolution is prepared for passing through self attention
min_res = self.resolutions[-1]
max_res = self.resolutions[0]
B, C, MIN_H, MIN_W = l_pyramid[min_res].shape
# add positional encodings
l_pyramid[min_res] = self.positional_encodings(l_pyramid[min_res])
r_pyramid[min_res] = self.positional_encodings(r_pyramid[min_res])
# reshaping for transformer
l_pyramid[min_res] = l_pyramid[min_res].permute(0, 2, 3, 1).reshape(B, MIN_H * MIN_W, C)
r_pyramid[min_res] = r_pyramid[min_res].permute(0, 2, 3, 1).reshape(B, MIN_H * MIN_W, C)
# perform self attention
l_pyramid[min_res], r_pyramid[min_res] = self.self_attn_block(l_pyramid[min_res], r_pyramid[min_res])
# now we need to reshape back into [B, C, H, W] format
l_pyramid[min_res] = l_pyramid[min_res].reshape(B, MIN_H, MIN_W, C).permute(0, 3, 1, 2)
r_pyramid[min_res] = r_pyramid[min_res].reshape(B, MIN_H, MIN_W, C).permute(0, 3, 1, 2)
predictions: List[Tensor] = []
flow_estimates: Dict[str, Tensor] = {}
        # we added this because of torch.jit.script
        # also, the prediction prior is always going to have the
        # spatial size of the features output by the feature encoder
flow_pred_prior: Tensor = torch.empty(
size=(B, 2, left_features.shape[2], left_features.shape[3]),
dtype=l_pyramid[max_res].dtype,
device=l_pyramid[max_res].device,
)
if flow_init is not None:
scale = l_pyramid[max_res].shape[2] / flow_init.shape[2]
# in CREStereo implementation they multiply with -scale instead of scale
# this can be either a downsample or an upsample based on the cascaded inference
# configuration
# we use a -scale because the flow used inside the network is a negative flow
# from the right to the left, so we flip the flow direction
flow_estimates[max_res] = -scale * F.interpolate(
input=flow_init,
size=l_pyramid[max_res].shape[2:],
mode="bilinear",
align_corners=True,
)
# when not provided with a flow prior, we construct one using the lower resolution maps
else:
# initialize a zero flow with the smallest resolution
flow = torch.zeros(size=(B, 2, MIN_H, MIN_W), device=left_features.device, dtype=left_features.dtype)
# flows from coarse resolutions are refined similarly
# we always need to fetch the next pyramid feature map as well
# when updating coarse resolutions, therefore we create a reversed
# view which has its order synced with the ModuleDict keys iterator
coarse_resolutions: List[str] = self.resolutions[::-1] # using slicing because of torch.jit.script
fine_grained_resolution = max_res
# set the coarsest flow to the zero flow
flow_estimates[coarse_resolutions[0]] = flow
# the correlation layer ModuleDict will contain layers ordered from coarse to fine resolution
# i.e ["1 / 16", "1 / 8", "1 / 4"]
# the correlation layer ModuleDict has layers for all the resolutions except the fine one
# i.e {"1 / 16": Module, "1 / 8": Module}
            # for these resolutions we perform only half of the number of refinement iterations
for idx, (resolution, correlation_layer) in enumerate(self.correlation_layers.items()):
# compute the scale difference between the first pyramid scale and the current pyramid scale
scale_to_base = l_pyramid[fine_grained_resolution].shape[2] // l_pyramid[resolution].shape[2]
for it in range(num_iters // 2):
                    # set whether we want to search on the (X, Y) axes for correlation or just on the X axis
window_type = self._get_window_type(it)
                    # we consider this a prior, therefore we do not want to back-propagate through it
flow_estimates[resolution] = flow_estimates[resolution].detach()
correlations = correlation_layer(
l_pyramid[resolution], # left
r_pyramid[resolution], # right
flow_estimates[resolution],
offsets[resolution],
window_type,
)
# update the recurrent network state and the flow deltas
net_pyramid[resolution], delta_flow = self.update_block(
net_pyramid[resolution], ctx_pyramid[resolution], correlations, flow_estimates[resolution]
)
# the convex upsampling weights are computed w.r.t.
# the recurrent update state
up_mask = self.mask_predictor(net_pyramid[resolution])
flow_estimates[resolution] = flow_estimates[resolution] + delta_flow
# convex upsampling with the initial feature encoder downsampling rate
flow_pred_prior = upsample_flow(
flow_estimates[resolution], up_mask, factor=self.downsampling_factors[0]
)
# we then bilinear upsample to the final resolution
# we use a factor that's equivalent to the difference between
# the current downsample resolution and the base downsample resolution
#
# i.e. if a 1 / 16 flow is upsampled by 4 (base downsampling) we get a 1 / 4 flow.
# therefore we have to further upscale it by the difference between
# the current level 1 / 16 and the base level 1 / 4.
#
# we use a -scale because the flow used inside the network is a negative flow
# from the right to the left, so we flip the flow direction in order to get the
# left to right flow
flow_pred = -upsample_flow(flow_pred_prior, None, factor=scale_to_base)
predictions.append(flow_pred)
                # when constructing the next resolution prior, we resample w.r.t.
                # the scale of the next level in the pyramid
next_resolution = coarse_resolutions[idx + 1]
scale_to_next = l_pyramid[next_resolution].shape[2] / flow_pred_prior.shape[2]
                # we use flow_pred_prior because this is a more accurate estimation of the true flow
# due to the convex upsample, which resembles a learned super-resolution module.
# this is not necessarily an upsample, it can be a downsample, based on the provided configuration
flow_estimates[next_resolution] = -scale_to_next * F.interpolate(
input=flow_pred_prior,
size=l_pyramid[next_resolution].shape[2:],
mode="bilinear",
align_corners=True,
)
# finally we will be doing a full pass through the fine-grained resolution
# this coincides with the maximum resolution
        # we keep a separate loop here in order to avoid python control flow
        # deciding how many iterations we should do based on the current resolution
        # furthermore, if provided with an initial flow, there is no need to generate
        # a prior estimate when moving into the final refinement stage
for it in range(num_iters):
search_window_type = self._get_window_type(it)
flow_estimates[max_res] = flow_estimates[max_res].detach()
# we run the fine-grained resolution correlations in iterative mode
# this means that we are using the fixed window pixel selections
# instead of the deformed ones as with the previous steps
correlations = self.max_res_correlation_layer(
l_pyramid[max_res],
r_pyramid[max_res],
flow_estimates[max_res],
extra_offset=None,
window_type=search_window_type,
iter_mode=True,
)
net_pyramid[max_res], delta_flow = self.update_block(
net_pyramid[max_res], ctx_pyramid[max_res], correlations, flow_estimates[max_res]
)
up_mask = self.mask_predictor(net_pyramid[max_res])
flow_estimates[max_res] = flow_estimates[max_res] + delta_flow
# at the final resolution we simply do a convex upsample using the base downsample rate
flow_pred = -upsample_flow(flow_estimates[max_res], up_mask, factor=self.downsampling_factors[0])
predictions.append(flow_pred)
return predictions
def _crestereo(
*,
weights: Optional[WeightsEnum],
progress: bool,
# Feature Encoder
feature_encoder_layers: Tuple[int, int, int, int, int],
feature_encoder_strides: Tuple[int, int, int, int],
feature_encoder_block: Callable[..., nn.Module],
feature_encoder_norm_layer: Callable[..., nn.Module],
# Average Pooling Pyramid
feature_downsample_rates: Tuple[int, ...],
# Adaptive Correlation Layer
corr_groups: int,
corr_search_window_2d: Tuple[int, int],
corr_search_dilate_2d: Tuple[int, int],
corr_search_window_1d: Tuple[int, int],
corr_search_dilate_1d: Tuple[int, int],
# Flow head
flow_head_hidden_size: int,
# Recurrent block
recurrent_block_hidden_state_size: int,
recurrent_block_kernel_size: Tuple[Tuple[int, int], Tuple[int, int]],
recurrent_block_padding: Tuple[Tuple[int, int], Tuple[int, int]],
# Motion Encoder
motion_encoder_corr_layers: Tuple[int, int],
motion_encoder_flow_layers: Tuple[int, int],
motion_encoder_out_channels: int,
# Transformer Blocks
num_attention_heads: int,
num_self_attention_layers: int,
num_cross_attention_layers: int,
self_attention_module: Callable[..., nn.Module],
cross_attention_module: Callable[..., nn.Module],
**kwargs,
) -> CREStereo:
feature_encoder = kwargs.pop("feature_encoder", None) or raft.FeatureEncoder(
block=feature_encoder_block,
layers=feature_encoder_layers,
strides=feature_encoder_strides,
norm_layer=feature_encoder_norm_layer,
)
if feature_encoder.output_dim % corr_groups != 0:
raise ValueError(
f"Final ``feature_encoder_layers`` size should be divisible by ``corr_groups`` argument."
f"Feature encoder output size : {feature_encoder.output_dim}, Correlation groups: {corr_groups}."
)
motion_encoder = kwargs.pop("motion_encoder", None) or raft.MotionEncoder(
in_channels_corr=corr_groups * int(np.prod(corr_search_window_1d)),
corr_layers=motion_encoder_corr_layers,
flow_layers=motion_encoder_flow_layers,
out_channels=motion_encoder_out_channels,
)
out_channels_context = feature_encoder_layers[-1] - recurrent_block_hidden_state_size
recurrent_block = kwargs.pop("recurrent_block", None) or raft.RecurrentBlock(
input_size=motion_encoder.out_channels + out_channels_context,
hidden_size=recurrent_block_hidden_state_size,
kernel_size=recurrent_block_kernel_size,
padding=recurrent_block_padding,
)
flow_head = kwargs.pop("flow_head", None) or raft.FlowHead(
in_channels=out_channels_context, hidden_size=flow_head_hidden_size
)
update_block = raft.UpdateBlock(motion_encoder=motion_encoder, recurrent_block=recurrent_block, flow_head=flow_head)
self_attention_module = kwargs.pop("self_attention_module", None) or LinearAttention
self_attn_block = LocalFeatureTransformer(
dim_model=feature_encoder.output_dim,
num_heads=num_attention_heads,
attention_directions=["self"] * num_self_attention_layers,
attention_module=self_attention_module,
)
cross_attention_module = kwargs.pop("cross_attention_module", None) or LinearAttention
cross_attn_block = LocalFeatureTransformer(
dim_model=feature_encoder.output_dim,
num_heads=num_attention_heads,
attention_directions=["cross"] * num_cross_attention_layers,
attention_module=cross_attention_module,
)
model = CREStereo(
feature_encoder=feature_encoder,
update_block=update_block,
flow_head=flow_head,
self_attn_block=self_attn_block,
cross_attn_block=cross_attn_block,
feature_downsample_rates=feature_downsample_rates,
correlation_groups=corr_groups,
search_window_1d=corr_search_window_1d,
search_window_2d=corr_search_window_2d,
search_dilate_1d=corr_search_dilate_1d,
search_dilate_2d=corr_search_dilate_2d,
)
if weights is not None:
model.load_state_dict(weights.get_state_dict(progress=progress))
return model
_COMMON_META = {
"resize_size": (384, 512),
}
class CREStereo_Base_Weights(WeightsEnum):
"""The metrics reported here are as follows.
``mae`` is the "mean-average-error" and indicates how far (in pixels) the
predicted disparity is from its true value (equivalent to ``epe``). This is averaged over all pixels
of all images. ``1px``, ``3px``, ``5px`` and indicate the percentage of pixels that have a lower
error than that of the ground truth. ``relepe`` is the "relative-end-point-error" and is the
average ``epe`` divided by the average ground truth disparity. ``fl-all`` corresponds to the average of pixels whose epe
is either <3px, or whom's ``relepe`` is lower than 0.05 (therefore higher is better).
"""
MEGVII_V1 = Weights(
# Weights ported from https://github.com/megvii-research/CREStereo
url="https://download.pytorch.org/models/crestereo-756c8b0f.pth",
transforms=StereoMatching,
meta={
**_COMMON_META,
"num_params": 5432948,
"recipe": "https://github.com/megvii-research/CREStereo",
"_metrics": {
"Middlebury2014-train": {
# metrics for 10 refinement iterations and 1 cascade
"mae": 0.792,
"rmse": 2.765,
"1px": 0.905,
"3px": 0.958,
"5px": 0.97,
"relepe": 0.114,
"fl-all": 90.429,
"_detailed": {
# 1 is the number of cascades
1: {
                            # 2 is the number of refinement iterations
2: {
"mae": 1.704,
"rmse": 3.738,
"1px": 0.738,
"3px": 0.896,
"5px": 0.933,
"relepe": 0.157,
"fl-all": 76.464,
},
5: {
"mae": 0.956,
"rmse": 2.963,
"1px": 0.88,
"3px": 0.948,
"5px": 0.965,
"relepe": 0.124,
"fl-all": 88.186,
},
10: {
"mae": 0.792,
"rmse": 2.765,
"1px": 0.905,
"3px": 0.958,
"5px": 0.97,
"relepe": 0.114,
"fl-all": 90.429,
},
20: {
"mae": 0.749,
"rmse": 2.706,
"1px": 0.907,
"3px": 0.961,
"5px": 0.972,
"relepe": 0.113,
"fl-all": 90.807,
},
},
2: {
2: {
"mae": 1.702,
"rmse": 3.784,
"1px": 0.784,
"3px": 0.894,
"5px": 0.924,
"relepe": 0.172,
"fl-all": 80.313,
},
5: {
"mae": 0.932,
"rmse": 2.907,
"1px": 0.877,
"3px": 0.944,
"5px": 0.963,
"relepe": 0.125,
"fl-all": 87.979,
},
10: {
"mae": 0.773,
"rmse": 2.768,
"1px": 0.901,
"3px": 0.958,
"5px": 0.972,
"relepe": 0.117,
"fl-all": 90.43,
},
20: {
"mae": 0.854,
"rmse": 2.971,
"1px": 0.9,
"3px": 0.957,
"5px": 0.97,
"relepe": 0.122,
"fl-all": 90.269,
},
},
},
}
},
"_docs": """These weights were ported from the original paper. They
are trained on a dataset mixture of the author's choice.""",
},
)
CRESTEREO_ETH_MBL_V1 = Weights(
# Weights ported from https://github.com/megvii-research/CREStereo
url="https://download.pytorch.org/models/crestereo-8f0e0e9a.pth",
transforms=StereoMatching,
meta={
**_COMMON_META,
"num_params": 5432948,
"recipe": "https://github.com/pytorch/vision/tree/main/references/depth/stereo",
"_metrics": {
"Middlebury2014-train": {
# metrics for 10 refinement iterations and 1 cascade
"mae": 1.416,
"rmse": 3.53,
"1px": 0.777,
"3px": 0.896,
"5px": 0.933,
"relepe": 0.148,
"fl-all": 78.388,
"_detailed": {
# 1 is the number of cascades
1: {
# 2 is the number of refinement iterations
2: {
"mae": 2.363,
"rmse": 4.352,
"1px": 0.611,
"3px": 0.828,
"5px": 0.891,
"relepe": 0.176,
"fl-all": 64.511,
},
5: {
"mae": 1.618,
"rmse": 3.71,
"1px": 0.761,
"3px": 0.879,
"5px": 0.918,
"relepe": 0.154,
"fl-all": 77.128,
},
10: {
"mae": 1.416,
"rmse": 3.53,
"1px": 0.777,
"3px": 0.896,
"5px": 0.933,
"relepe": 0.148,
"fl-all": 78.388,
},
20: {
"mae": 1.448,
"rmse": 3.583,
"1px": 0.771,
"3px": 0.893,
"5px": 0.931,
"relepe": 0.145,
"fl-all": 77.7,
},
},
2: {
2: {
"mae": 1.972,
"rmse": 4.125,
"1px": 0.73,
"3px": 0.865,
"5px": 0.908,
"relepe": 0.169,
"fl-all": 74.396,
},
5: {
"mae": 1.403,
"rmse": 3.448,
"1px": 0.793,
"3px": 0.905,
"5px": 0.937,
"relepe": 0.151,
"fl-all": 80.186,
},
10: {
"mae": 1.312,
"rmse": 3.368,
"1px": 0.799,
"3px": 0.912,
"5px": 0.943,
"relepe": 0.148,
"fl-all": 80.379,
},
20: {
"mae": 1.376,
"rmse": 3.542,
"1px": 0.796,
"3px": 0.91,
"5px": 0.942,
"relepe": 0.149,
"fl-all": 80.054,
},
},
},
}
},
"_docs": """These weights were trained from scratch on
:class:`~torchvision.datasets._stereo_matching.CREStereo` +
:class:`~torchvision.datasets._stereo_matching.Middlebury2014Stereo` +
:class:`~torchvision.datasets._stereo_matching.ETH3DStereo`.""",
},
)
CRESTEREO_FINETUNE_MULTI_V1 = Weights(
        # Weights finetuned with the torchvision references/depth/stereo recipe
url="https://download.pytorch.org/models/crestereo-697c38f4.pth ",
transforms=StereoMatching,
meta={
**_COMMON_META,
"num_params": 5432948,
"recipe": "https://github.com/pytorch/vision/tree/main/references/depth/stereo",
"_metrics": {
"Middlebury2014-train": {
# metrics for 10 refinement iterations and 1 cascade
"mae": 1.038,
"rmse": 3.108,
"1px": 0.852,
"3px": 0.942,
"5px": 0.963,
"relepe": 0.129,
"fl-all": 85.522,
"_detailed": {
# 1 is the number of cascades
1: {
                            # 2 is the number of refinement iterations
2: {
"mae": 1.85,
"rmse": 3.797,
"1px": 0.673,
"3px": 0.862,
"5px": 0.917,
"relepe": 0.171,
"fl-all": 69.736,
},
5: {
"mae": 1.111,
"rmse": 3.166,
"1px": 0.838,
"3px": 0.93,
"5px": 0.957,
"relepe": 0.134,
"fl-all": 84.596,
},
10: {
"mae": 1.02,
"rmse": 3.073,
"1px": 0.854,
"3px": 0.938,
"5px": 0.96,
"relepe": 0.129,
"fl-all": 86.042,
},
20: {
"mae": 0.993,
"rmse": 3.059,
"1px": 0.855,
"3px": 0.942,
"5px": 0.967,
"relepe": 0.126,
"fl-all": 85.784,
},
},
2: {
2: {
"mae": 1.667,
"rmse": 3.867,
"1px": 0.78,
"3px": 0.891,
"5px": 0.922,
"relepe": 0.165,
"fl-all": 78.89,
},
5: {
"mae": 1.158,
"rmse": 3.278,
"1px": 0.843,
"3px": 0.926,
"5px": 0.955,
"relepe": 0.135,
"fl-all": 84.556,
},
10: {
"mae": 1.046,
"rmse": 3.13,
"1px": 0.85,
"3px": 0.934,
"5px": 0.96,
"relepe": 0.13,
"fl-all": 85.464,
},
20: {
"mae": 1.021,
"rmse": 3.102,
"1px": 0.85,
"3px": 0.935,
"5px": 0.963,
"relepe": 0.129,
"fl-all": 85.417,
},
},
},
},
},
"_docs": """These weights were finetuned on a mixture of
:class:`~torchvision.datasets._stereo_matching.CREStereo` +
:class:`~torchvision.datasets._stereo_matching.Middlebury2014Stereo` +
:class:`~torchvision.datasets._stereo_matching.ETH3DStereo` +
:class:`~torchvision.datasets._stereo_matching.InStereo2k` +
:class:`~torchvision.datasets._stereo_matching.CarlaStereo` +
:class:`~torchvision.datasets._stereo_matching.SintelStereo` +
            :class:`~torchvision.datasets._stereo_matching.FallingThingsStereo`.""",
},
)
DEFAULT = MEGVII_V1
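

# Editorial sketch (not part of the original torchvision code): one way the metrics reported in
# ``CREStereo_Base_Weights`` could be computed from a predicted and a ground-truth disparity map.
# The helper name ``_reference_disparity_metrics`` is hypothetical; the official numbers come from the
# references/depth/stereo evaluation scripts, which may additionally mask invalid pixels.
def _reference_disparity_metrics(pred, target):
    epe = (pred - target).abs()  # per-pixel end-point error, in pixels
    rel = epe / target.abs().clamp(min=1e-6)  # per-pixel error relative to the ground-truth disparity
    return {
        "mae": epe.mean().item(),
        "rmse": epe.pow(2).mean().sqrt().item(),
        # fraction of pixels whose error is below 1, 3 and 5 pixels
        "1px": (epe < 1).float().mean().item(),
        "3px": (epe < 3).float().mean().item(),
        "5px": (epe < 5).float().mean().item(),
        # average epe divided by the average ground-truth disparity
        "relepe": (epe.mean() / target.abs().mean().clamp(min=1e-6)).item(),
        # percentage of pixels whose epe is below 3px or whose relative error is below 0.05
        "fl-all": ((epe < 3) | (rel < 0.05)).float().mean().item() * 100,
    }

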
@register_model()
@handle_legacy_interface(weights=("pretrained", CREStereo_Base_Weights.MEGVII_V1))
def crestereo_base(*, weights: Optional[CREStereo_Base_Weights] = None, progress=True, **kwargs) -> CREStereo:
"""CREStereo model from
`Practical Stereo Matching via Cascaded Recurrent Network
With Adaptive Correlation <https://openaccess.thecvf.com/content/CVPR2022/papers/Li_Practical_Stereo_Matching_via_Cascaded_Recurrent_Network_With_Adaptive_Correlation_CVPR_2022_paper.pdf>`_.
Please see the example below for a tutorial on how to use this model.
Args:
weights(:class:`~torchvision.prototype.models.depth.stereo.CREStereo_Base_Weights`, optional): The
pretrained weights to use. See
:class:`~torchvision.prototype.models.depth.stereo.CREStereo_Base_Weights`
below for more details, and possible values. By default, no
pre-trained weights are used.
progress (bool): If True, displays a progress bar of the download to stderr. Default is True.
**kwargs: parameters passed to the ``torchvision.prototype.models.depth.stereo.raft_stereo.RaftStereo``
base class. Please refer to the `source code
<https://github.com/pytorch/vision/blob/main/torchvision/models/optical_flow/crestereo.py>`_
for more details about this class.
.. autoclass:: torchvision.prototype.models.depth.stereo.CREStereo_Base_Weights
:members:
"""
return _crestereo(
weights=weights,
progress=progress,
# Feature encoder
feature_encoder_layers=(64, 64, 96, 128, 256),
feature_encoder_strides=(2, 1, 2, 1),
feature_encoder_block=partial(raft.ResidualBlock, always_project=True),
feature_encoder_norm_layer=nn.InstanceNorm2d,
# Average pooling pyramid
feature_downsample_rates=(2, 4),
# Motion encoder
motion_encoder_corr_layers=(256, 192),
motion_encoder_flow_layers=(128, 64),
motion_encoder_out_channels=128,
# Recurrent block
recurrent_block_hidden_state_size=128,
recurrent_block_kernel_size=((1, 5), (5, 1)),
recurrent_block_padding=((0, 2), (2, 0)),
# Flow head
flow_head_hidden_size=256,
# Transformer blocks
num_attention_heads=8,
num_self_attention_layers=1,
num_cross_attention_layers=1,
self_attention_module=LinearAttention,
cross_attention_module=LinearAttention,
# Adaptive Correlation layer
corr_groups=4,
corr_search_window_2d=(3, 3),
corr_search_dilate_2d=(1, 1),
corr_search_window_1d=(1, 9),
corr_search_dilate_1d=(1, 1),
)
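

# Editorial usage sketch (not part of the original file). It assumes the model's forward call takes the
# left and right images and returns a list of per-iteration disparity predictions, which matches the
# surrounding builder code but may differ in detail from the actual prototype API.
if __name__ == "__main__":
    import torch

    weights = CREStereo_Base_Weights.DEFAULT
    model = crestereo_base(weights=weights).eval()

    # Random stereo pair shaped (batch, channels, height, width); real inputs would normally be
    # preprocessed with the preset returned by ``weights.transforms()``.
    left = torch.rand(1, 3, 384, 512)
    right = torch.rand(1, 3, 384, 512)

    with torch.inference_mode():
        predictions = model(left, right)
    final_prediction = predictions[-1]  # the most refined estimate from the last update iteration

    # The nested ``_detailed`` metadata is keyed by the number of cascades, then by the number of
    # refinement iterations used at evaluation time.
    detailed = weights.meta["_metrics"]["Middlebury2014-train"]["_detailed"]
    print(final_prediction.shape, detailed[1][10]["mae"])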