Unverified commit fbb4cc54, authored by Philip Meier and committed by GitHub

remove torchvision.prototype module and related tests / CI from release branch (#7983)

parent a90e5846
import pathlib
from typing import Any, Dict, List, Tuple, Union
import torch
from torchdata.datapipes.iter import CSVParser, IterDataPipe, Mapper
from torchvision.prototype.datasets.utils import Dataset, HttpResource, OnlineResource
from torchvision.prototype.datasets.utils._internal import hint_sharding, hint_shuffling
from torchvision.prototype.tv_tensors import OneHotLabel
from torchvision.tv_tensors import Image
from .._api import register_dataset, register_info
NAME = "semeion"
@register_info(NAME)
def _info() -> Dict[str, Any]:
return dict(categories=[str(i) for i in range(10)])
@register_dataset(NAME)
class SEMEION(Dataset):
"""Semeion dataset
homepage="https://archive.ics.uci.edu/ml/datasets/Semeion+Handwritten+Digit",
"""
def __init__(self, root: Union[str, pathlib.Path], *, skip_integrity_check: bool = False) -> None:
self._categories = _info()["categories"]
super().__init__(root, skip_integrity_check=skip_integrity_check)
def _resources(self) -> List[OnlineResource]:
data = HttpResource(
"http://archive.ics.uci.edu/ml/machine-learning-databases/semeion/semeion.data",
sha256="f43228ae3da5ea6a3c95069d53450b86166770e3b719dcc333182128fe08d4b1",
)
return [data]
def _prepare_sample(self, data: Tuple[str, ...]) -> Dict[str, Any]:
image_data, label_data = data[:256], data[256:-1]
return dict(
image=Image(torch.tensor([float(pixel) for pixel in image_data], dtype=torch.float).reshape(16, 16)),
label=OneHotLabel([int(label) for label in label_data], categories=self._categories),
)
def _datapipe(self, resource_dps: List[IterDataPipe]) -> IterDataPipe[Dict[str, Any]]:
dp = resource_dps[0]
dp = CSVParser(dp, delimiter=" ")
dp = hint_shuffling(dp)
dp = hint_sharding(dp)
return Mapper(dp, self._prepare_sample)
def __len__(self) -> int:
return 1_593
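# Minimal usage sketch for the dataset registered above, assuming the public
# prototype entry point `datasets.load` (the same helper the benchmark script
# further below relies on). Each sample is the dict built by `_prepare_sample`.
from torchvision.prototype import datasets

semeion = datasets.load("semeion")  # IterDataPipe yielding dict samples
sample = next(iter(semeion))
print(sample["image"].shape, sample["label"])  # 16x16 image, one-hot label over the 10 digits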
AM General Hummer SUV 2000
Acura RL Sedan 2012
Acura TL Sedan 2012
Acura TL Type-S 2008
Acura TSX Sedan 2012
Acura Integra Type R 2001
Acura ZDX Hatchback 2012
Aston Martin V8 Vantage Convertible 2012
Aston Martin V8 Vantage Coupe 2012
Aston Martin Virage Convertible 2012
Aston Martin Virage Coupe 2012
Audi RS 4 Convertible 2008
Audi A5 Coupe 2012
Audi TTS Coupe 2012
Audi R8 Coupe 2012
Audi V8 Sedan 1994
Audi 100 Sedan 1994
Audi 100 Wagon 1994
Audi TT Hatchback 2011
Audi S6 Sedan 2011
Audi S5 Convertible 2012
Audi S5 Coupe 2012
Audi S4 Sedan 2012
Audi S4 Sedan 2007
Audi TT RS Coupe 2012
BMW ActiveHybrid 5 Sedan 2012
BMW 1 Series Convertible 2012
BMW 1 Series Coupe 2012
BMW 3 Series Sedan 2012
BMW 3 Series Wagon 2012
BMW 6 Series Convertible 2007
BMW X5 SUV 2007
BMW X6 SUV 2012
BMW M3 Coupe 2012
BMW M5 Sedan 2010
BMW M6 Convertible 2010
BMW X3 SUV 2012
BMW Z4 Convertible 2012
Bentley Continental Supersports Conv. Convertible 2012
Bentley Arnage Sedan 2009
Bentley Mulsanne Sedan 2011
Bentley Continental GT Coupe 2012
Bentley Continental GT Coupe 2007
Bentley Continental Flying Spur Sedan 2007
Bugatti Veyron 16.4 Convertible 2009
Bugatti Veyron 16.4 Coupe 2009
Buick Regal GS 2012
Buick Rainier SUV 2007
Buick Verano Sedan 2012
Buick Enclave SUV 2012
Cadillac CTS-V Sedan 2012
Cadillac SRX SUV 2012
Cadillac Escalade EXT Crew Cab 2007
Chevrolet Silverado 1500 Hybrid Crew Cab 2012
Chevrolet Corvette Convertible 2012
Chevrolet Corvette ZR1 2012
Chevrolet Corvette Ron Fellows Edition Z06 2007
Chevrolet Traverse SUV 2012
Chevrolet Camaro Convertible 2012
Chevrolet HHR SS 2010
Chevrolet Impala Sedan 2007
Chevrolet Tahoe Hybrid SUV 2012
Chevrolet Sonic Sedan 2012
Chevrolet Express Cargo Van 2007
Chevrolet Avalanche Crew Cab 2012
Chevrolet Cobalt SS 2010
Chevrolet Malibu Hybrid Sedan 2010
Chevrolet TrailBlazer SS 2009
Chevrolet Silverado 2500HD Regular Cab 2012
Chevrolet Silverado 1500 Classic Extended Cab 2007
Chevrolet Express Van 2007
Chevrolet Monte Carlo Coupe 2007
Chevrolet Malibu Sedan 2007
Chevrolet Silverado 1500 Extended Cab 2012
Chevrolet Silverado 1500 Regular Cab 2012
Chrysler Aspen SUV 2009
Chrysler Sebring Convertible 2010
Chrysler Town and Country Minivan 2012
Chrysler 300 SRT-8 2010
Chrysler Crossfire Convertible 2008
Chrysler PT Cruiser Convertible 2008
Daewoo Nubira Wagon 2002
Dodge Caliber Wagon 2012
Dodge Caliber Wagon 2007
Dodge Caravan Minivan 1997
Dodge Ram Pickup 3500 Crew Cab 2010
Dodge Ram Pickup 3500 Quad Cab 2009
Dodge Sprinter Cargo Van 2009
Dodge Journey SUV 2012
Dodge Dakota Crew Cab 2010
Dodge Dakota Club Cab 2007
Dodge Magnum Wagon 2008
Dodge Challenger SRT8 2011
Dodge Durango SUV 2012
Dodge Durango SUV 2007
Dodge Charger Sedan 2012
Dodge Charger SRT-8 2009
Eagle Talon Hatchback 1998
FIAT 500 Abarth 2012
FIAT 500 Convertible 2012
Ferrari FF Coupe 2012
Ferrari California Convertible 2012
Ferrari 458 Italia Convertible 2012
Ferrari 458 Italia Coupe 2012
Fisker Karma Sedan 2012
Ford F-450 Super Duty Crew Cab 2012
Ford Mustang Convertible 2007
Ford Freestar Minivan 2007
Ford Expedition EL SUV 2009
Ford Edge SUV 2012
Ford Ranger SuperCab 2011
Ford GT Coupe 2006
Ford F-150 Regular Cab 2012
Ford F-150 Regular Cab 2007
Ford Focus Sedan 2007
Ford E-Series Wagon Van 2012
Ford Fiesta Sedan 2012
GMC Terrain SUV 2012
GMC Savana Van 2012
GMC Yukon Hybrid SUV 2012
GMC Acadia SUV 2012
GMC Canyon Extended Cab 2012
Geo Metro Convertible 1993
HUMMER H3T Crew Cab 2010
HUMMER H2 SUT Crew Cab 2009
Honda Odyssey Minivan 2012
Honda Odyssey Minivan 2007
Honda Accord Coupe 2012
Honda Accord Sedan 2012
Hyundai Veloster Hatchback 2012
Hyundai Santa Fe SUV 2012
Hyundai Tucson SUV 2012
Hyundai Veracruz SUV 2012
Hyundai Sonata Hybrid Sedan 2012
Hyundai Elantra Sedan 2007
Hyundai Accent Sedan 2012
Hyundai Genesis Sedan 2012
Hyundai Sonata Sedan 2012
Hyundai Elantra Touring Hatchback 2012
Hyundai Azera Sedan 2012
Infiniti G Coupe IPL 2012
Infiniti QX56 SUV 2011
Isuzu Ascender SUV 2008
Jaguar XK XKR 2012
Jeep Patriot SUV 2012
Jeep Wrangler SUV 2012
Jeep Liberty SUV 2012
Jeep Grand Cherokee SUV 2012
Jeep Compass SUV 2012
Lamborghini Reventon Coupe 2008
Lamborghini Aventador Coupe 2012
Lamborghini Gallardo LP 570-4 Superleggera 2012
Lamborghini Diablo Coupe 2001
Land Rover Range Rover SUV 2012
Land Rover LR2 SUV 2012
Lincoln Town Car Sedan 2011
MINI Cooper Roadster Convertible 2012
Maybach Landaulet Convertible 2012
Mazda Tribute SUV 2011
McLaren MP4-12C Coupe 2012
Mercedes-Benz 300-Class Convertible 1993
Mercedes-Benz C-Class Sedan 2012
Mercedes-Benz SL-Class Coupe 2009
Mercedes-Benz E-Class Sedan 2012
Mercedes-Benz S-Class Sedan 2012
Mercedes-Benz Sprinter Van 2012
Mitsubishi Lancer Sedan 2012
Nissan Leaf Hatchback 2012
Nissan NV Passenger Van 2012
Nissan Juke Hatchback 2012
Nissan 240SX Coupe 1998
Plymouth Neon Coupe 1999
Porsche Panamera Sedan 2012
Ram C/V Cargo Van Minivan 2012
Rolls-Royce Phantom Drophead Coupe Convertible 2012
Rolls-Royce Ghost Sedan 2012
Rolls-Royce Phantom Sedan 2012
Scion xD Hatchback 2012
Spyker C8 Convertible 2009
Spyker C8 Coupe 2009
Suzuki Aerio Sedan 2007
Suzuki Kizashi Sedan 2012
Suzuki SX4 Hatchback 2012
Suzuki SX4 Sedan 2012
Tesla Model S Sedan 2012
Toyota Sequoia SUV 2012
Toyota Camry Sedan 2012
Toyota Corolla Sedan 2012
Toyota 4Runner SUV 2012
Volkswagen Golf Hatchback 2012
Volkswagen Golf Hatchback 1991
Volkswagen Beetle Hatchback 2012
Volvo C30 Hatchback 2012
Volvo 240 Sedan 1993
Volvo XC90 SUV 2007
smart fortwo Convertible 2012
import pathlib
from typing import Any, BinaryIO, Dict, Iterator, List, Tuple, Union
from torchdata.datapipes.iter import Filter, IterDataPipe, Mapper, Zipper
from torchvision.prototype.datasets.utils import Dataset, EncodedImage, HttpResource, OnlineResource
from torchvision.prototype.datasets.utils._internal import (
hint_sharding,
hint_shuffling,
path_comparator,
read_categories_file,
read_mat,
)
from torchvision.prototype.tv_tensors import Label
from torchvision.tv_tensors import BoundingBoxes
from .._api import register_dataset, register_info
class StanfordCarsLabelReader(IterDataPipe[Tuple[int, int, int, int, int, str]]):
def __init__(self, datapipe: IterDataPipe[Dict[str, Any]]) -> None:
self.datapipe = datapipe
def __iter__(self) -> Iterator[Tuple[int, int, int, int, int, str]]:
for _, file in self.datapipe:
data = read_mat(file, squeeze_me=True)
for ann in data["annotations"]:
yield tuple(ann) # type: ignore[misc]
NAME = "stanford-cars"
@register_info(NAME)
def _info() -> Dict[str, Any]:
return dict(categories=read_categories_file(NAME))
@register_dataset(NAME)
class StanfordCars(Dataset):
"""Stanford Cars dataset.
homepage="https://ai.stanford.edu/~jkrause/cars/car_dataset.html",
dependencies=scipy
"""
def __init__(
self,
root: Union[str, pathlib.Path],
*,
split: str = "train",
skip_integrity_check: bool = False,
) -> None:
self._split = self._verify_str_arg(split, "split", {"train", "test"})
self._categories = _info()["categories"]
super().__init__(root, skip_integrity_check=skip_integrity_check, dependencies=("scipy",))
_URL_ROOT = "https://ai.stanford.edu/~jkrause/"
_URLS = {
"train": f"{_URL_ROOT}car196/cars_train.tgz",
"test": f"{_URL_ROOT}car196/cars_test.tgz",
"cars_test_annos_withlabels": f"{_URL_ROOT}car196/cars_test_annos_withlabels.mat",
"car_devkit": f"{_URL_ROOT}cars/car_devkit.tgz",
}
_CHECKSUM = {
"train": "b97deb463af7d58b6bfaa18b2a4de9829f0f79e8ce663dfa9261bf7810e9accd",
"test": "bffea656d6f425cba3c91c6d83336e4c5f86c6cffd8975b0f375d3a10da8e243",
"cars_test_annos_withlabels": "790f75be8ea34eeded134cc559332baf23e30e91367e9ddca97d26ed9b895f05",
"car_devkit": "512b227b30e2f0a8aab9e09485786ab4479582073a144998da74d64b801fd288",
}
def _resources(self) -> List[OnlineResource]:
resources: List[OnlineResource] = [HttpResource(self._URLS[self._split], sha256=self._CHECKSUM[self._split])]
if self._split == "train":
resources.append(HttpResource(url=self._URLS["car_devkit"], sha256=self._CHECKSUM["car_devkit"]))
else:
resources.append(
HttpResource(
self._URLS["cars_test_annos_withlabels"], sha256=self._CHECKSUM["cars_test_annos_withlabels"]
)
)
return resources
def _prepare_sample(self, data: Tuple[Tuple[str, BinaryIO], Tuple[int, int, int, int, int, str]]) -> Dict[str, Any]:
image, target = data
path, buffer = image
image = EncodedImage.from_file(buffer)
return dict(
path=path,
image=image,
label=Label(target[4] - 1, categories=self._categories),
bounding_boxes=BoundingBoxes(target[:4], format="xyxy", spatial_size=image.spatial_size),
)
def _datapipe(self, resource_dps: List[IterDataPipe]) -> IterDataPipe[Dict[str, Any]]:
images_dp, targets_dp = resource_dps
if self._split == "train":
targets_dp = Filter(targets_dp, path_comparator("name", "cars_train_annos.mat"))
targets_dp = StanfordCarsLabelReader(targets_dp)
dp = Zipper(images_dp, targets_dp)
dp = hint_shuffling(dp)
dp = hint_sharding(dp)
return Mapper(dp, self._prepare_sample)
def _generate_categories(self) -> List[str]:
resources = self._resources()
devkit_dp = resources[1].load(self._root)
meta_dp = Filter(devkit_dp, path_comparator("name", "cars_meta.mat"))
_, meta_file = next(iter(meta_dp))
return list(read_mat(meta_file, squeeze_me=True)["class_names"])
def __len__(self) -> int:
return 8_144 if self._split == "train" else 8_041
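# Usage sketch for the Stanford Cars dataset above; the `split` keyword is
# assumed to be forwarded to the constructor by `datasets.load`, consistent
# with how the benchmark script below builds datasets. Requires scipy.
from torchvision.prototype import datasets

cars_test = datasets.load("stanford-cars", split="test")
sample = next(iter(cars_test))
print(sample["path"], sample["label"], sample["bounding_boxes"])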
import pathlib
from typing import Any, BinaryIO, Dict, List, Tuple, Union
import numpy as np
from torchdata.datapipes.iter import IterDataPipe, Mapper, UnBatcher
from torchvision.prototype.datasets.utils import Dataset, HttpResource, OnlineResource
from torchvision.prototype.datasets.utils._internal import hint_sharding, hint_shuffling, read_mat
from torchvision.prototype.tv_tensors import Label
from torchvision.tv_tensors import Image
from .._api import register_dataset, register_info
NAME = "svhn"
@register_info(NAME)
def _info() -> Dict[str, Any]:
return dict(categories=[str(c) for c in range(10)])
@register_dataset(NAME)
class SVHN(Dataset):
"""SVHN Dataset.
homepage="http://ufldl.stanford.edu/housenumbers/",
dependencies = scipy
"""
def __init__(
self,
root: Union[str, pathlib.Path],
*,
split: str = "train",
skip_integrity_check: bool = False,
) -> None:
self._split = self._verify_str_arg(split, "split", {"train", "test", "extra"})
self._categories = _info()["categories"]
super().__init__(root, skip_integrity_check=skip_integrity_check, dependencies=("scipy",))
_CHECKSUMS = {
"train": "435e94d69a87fde4fd4d7f3dd208dfc32cb6ae8af2240d066de1df7508d083b8",
"test": "cdce80dfb2a2c4c6160906d0bd7c68ec5a99d7ca4831afa54f09182025b6a75b",
"extra": "a133a4beb38a00fcdda90c9489e0c04f900b660ce8a316a5e854838379a71eb3",
}
def _resources(self) -> List[OnlineResource]:
data = HttpResource(
f"http://ufldl.stanford.edu/housenumbers/{self._split}_32x32.mat",
sha256=self._CHECKSUMS[self._split],
)
return [data]
def _read_images_and_labels(self, data: Tuple[str, BinaryIO]) -> List[Tuple[np.ndarray, np.ndarray]]:
_, buffer = data
content = read_mat(buffer)
return list(
zip(
content["X"].transpose((3, 0, 1, 2)),
content["y"].squeeze(),
)
)
def _prepare_sample(self, data: Tuple[np.ndarray, np.ndarray]) -> Dict[str, Any]:
image_array, label_array = data
return dict(
image=Image(image_array.transpose((2, 0, 1))),
label=Label(int(label_array) % 10, categories=self._categories),
)
def _datapipe(self, resource_dps: List[IterDataPipe]) -> IterDataPipe[Dict[str, Any]]:
dp = resource_dps[0]
dp = Mapper(dp, self._read_images_and_labels)
dp = UnBatcher(dp)
dp = hint_shuffling(dp)
dp = hint_sharding(dp)
return Mapper(dp, self._prepare_sample)
def __len__(self) -> int:
return {
"train": 73_257,
"test": 26_032,
"extra": 531_131,
}[self._split]
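# Usage sketch for SVHN above. The "extra" split is by far the largest
# (531,131 samples, see `__len__`); label 0 is stored as 10 in the .mat files,
# which the `% 10` in `_prepare_sample` maps back to category "0".
from torchvision.prototype import datasets

svhn_extra = datasets.load("svhn", split="extra")
sample = next(iter(svhn_extra))
print(sample["image"].shape, sample["label"])  # 3x32x32 uint8 image, Label in 0-9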
import pathlib
from typing import Any, Dict, List, Union
import torch
from torchdata.datapipes.iter import Decompressor, IterDataPipe, LineReader, Mapper
from torchvision.prototype.datasets.utils import Dataset, HttpResource, OnlineResource
from torchvision.prototype.datasets.utils._internal import hint_sharding, hint_shuffling
from torchvision.prototype.tv_tensors import Label
from torchvision.tv_tensors import Image
from .._api import register_dataset, register_info
NAME = "usps"
@register_info(NAME)
def _info() -> Dict[str, Any]:
return dict(categories=[str(c) for c in range(10)])
@register_dataset(NAME)
class USPS(Dataset):
"""USPS Dataset
homepage="https://www.csie.ntu.edu.tw/~cjlin/libsvmtools/datasets/multiclass.html#usps",
"""
def __init__(
self,
root: Union[str, pathlib.Path],
*,
split: str = "train",
skip_integrity_check: bool = False,
) -> None:
self._split = self._verify_str_arg(split, "split", {"train", "test"})
self._categories = _info()["categories"]
super().__init__(root, skip_integrity_check=skip_integrity_check)
_URL = "https://www.csie.ntu.edu.tw/~cjlin/libsvmtools/datasets/multiclass"
_RESOURCES = {
"train": HttpResource(
f"{_URL}/usps.bz2", sha256="3771e9dd6ba685185f89867b6e249233dd74652389f263963b3b741e994b034f"
),
"test": HttpResource(
f"{_URL}/usps.t.bz2", sha256="a9c0164e797d60142a50604917f0baa604f326e9a689698763793fa5d12ffc4e"
),
}
def _resources(self) -> List[OnlineResource]:
return [USPS._RESOURCES[self._split]]
def _prepare_sample(self, line: str) -> Dict[str, Any]:
label, *values = line.strip().split(" ")
values = [float(value.split(":")[1]) for value in values]
pixels = torch.tensor(values).add_(1).div_(2)
return dict(
image=Image(pixels.reshape(16, 16)),
label=Label(int(label) - 1, categories=self._categories),
)
def _datapipe(self, resource_dps: List[IterDataPipe]) -> IterDataPipe[Dict[str, Any]]:
dp = Decompressor(resource_dps[0])
dp = LineReader(dp, decode=True, return_path=False)
dp = hint_shuffling(dp)
dp = hint_sharding(dp)
return Mapper(dp, self._prepare_sample)
def __len__(self) -> int:
return 7_291 if self._split == "train" else 2_007
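# Small sketch of the libsvm-style line format parsed by `_prepare_sample`
# above: "<label> 1:<v1> 2:<v2> ...", with pixel values in [-1, 1] that
# `.add_(1).div_(2)` rescales to [0, 1]. The line below is synthetic.
import torch

line = "7 " + " ".join(f"{i + 1}:{value}" for i, value in enumerate([-1.0] * 256))
label, *values = line.strip().split(" ")
pixels = torch.tensor([float(value.split(":")[1]) for value in values]).add_(1).div_(2)
print(pixels.reshape(16, 16).shape, int(label) - 1)  # torch.Size([16, 16]) 6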
__background__
aeroplane
bicycle
bird
boat
bottle
bus
car
cat
chair
cow
diningtable
dog
horse
motorbike
person
pottedplant
sheep
sofa
train
tvmonitor
import enum
import functools
import pathlib
from typing import Any, BinaryIO, cast, Dict, List, Optional, Tuple, Union
from xml.etree import ElementTree
from torchdata.datapipes.iter import Demultiplexer, Filter, IterDataPipe, IterKeyZipper, LineReader, Mapper
from torchvision.datasets import VOCDetection
from torchvision.prototype.datasets.utils import Dataset, EncodedImage, HttpResource, OnlineResource
from torchvision.prototype.datasets.utils._internal import (
getitem,
hint_sharding,
hint_shuffling,
INFINITE_BUFFER_SIZE,
path_accessor,
path_comparator,
read_categories_file,
)
from torchvision.prototype.tv_tensors import Label
from torchvision.tv_tensors import BoundingBoxes
from .._api import register_dataset, register_info
NAME = "voc"
@register_info(NAME)
def _info() -> Dict[str, Any]:
return dict(categories=read_categories_file(NAME))
@register_dataset(NAME)
class VOC(Dataset):
"""
- **homepage**: http://host.robots.ox.ac.uk/pascal/VOC/
"""
def __init__(
self,
root: Union[str, pathlib.Path],
*,
split: str = "train",
year: str = "2012",
task: str = "detection",
skip_integrity_check: bool = False,
) -> None:
self._year = self._verify_str_arg(year, "year", ("2007", "2008", "2009", "2010", "2011", "2012"))
if split == "test" and year != "2007":
raise ValueError("`split='test'` is only available for `year='2007'`")
else:
self._split = self._verify_str_arg(split, "split", ("train", "val", "trainval", "test"))
self._task = self._verify_str_arg(task, "task", ("detection", "segmentation"))
self._anns_folder = "Annotations" if task == "detection" else "SegmentationClass"
self._split_folder = "Main" if task == "detection" else "Segmentation"
self._categories = _info()["categories"]
super().__init__(root, skip_integrity_check=skip_integrity_check)
_TRAIN_VAL_ARCHIVES = {
"2007": ("VOCtrainval_06-Nov-2007.tar", "7d8cd951101b0957ddfd7a530bdc8a94f06121cfc1e511bb5937e973020c7508"),
"2008": ("VOCtrainval_14-Jul-2008.tar", "7f0ca53c1b5a838fbe946965fc106c6e86832183240af5c88e3f6c306318d42e"),
"2009": ("VOCtrainval_11-May-2009.tar", "11cbe1741fb5bdadbbca3c08e9ec62cd95c14884845527d50847bc2cf57e7fd6"),
"2010": ("VOCtrainval_03-May-2010.tar", "1af4189cbe44323ab212bff7afbc7d0f55a267cc191eb3aac911037887e5c7d4"),
"2011": ("VOCtrainval_25-May-2011.tar", "0a7f5f5d154f7290ec65ec3f78b72ef72c6d93ff6d79acd40dc222a9ee5248ba"),
"2012": ("VOCtrainval_11-May-2012.tar", "e14f763270cf193d0b5f74b169f44157a4b0c6efa708f4dd0ff78ee691763bcb"),
}
_TEST_ARCHIVES = {
"2007": ("VOCtest_06-Nov-2007.tar", "6836888e2e01dca84577a849d339fa4f73e1e4f135d312430c4856b5609b4892")
}
def _resources(self) -> List[OnlineResource]:
file_name, sha256 = (self._TEST_ARCHIVES if self._split == "test" else self._TRAIN_VAL_ARCHIVES)[self._year]
archive = HttpResource(f"http://host.robots.ox.ac.uk/pascal/VOC/voc{self._year}/{file_name}", sha256=sha256)
return [archive]
def _is_in_folder(self, data: Tuple[str, Any], *, name: str, depth: int = 1) -> bool:
path = pathlib.Path(data[0])
return name in path.parent.parts[-depth:]
class _Demux(enum.IntEnum):
SPLIT = 0
IMAGES = 1
ANNS = 2
def _classify_archive(self, data: Tuple[str, Any]) -> Optional[int]:
if self._is_in_folder(data, name="ImageSets", depth=2):
return self._Demux.SPLIT
elif self._is_in_folder(data, name="JPEGImages"):
return self._Demux.IMAGES
elif self._is_in_folder(data, name=self._anns_folder):
return self._Demux.ANNS
else:
return None
def _parse_detection_ann(self, buffer: BinaryIO) -> Dict[str, Any]:
ann = cast(Dict[str, Any], VOCDetection.parse_voc_xml(ElementTree.parse(buffer).getroot())["annotation"])
buffer.close()
return ann
def _prepare_detection_ann(self, buffer: BinaryIO) -> Dict[str, Any]:
anns = self._parse_detection_ann(buffer)
instances = anns["object"]
return dict(
bounding_boxes=BoundingBoxes(
[
[int(instance["bndbox"][part]) for part in ("xmin", "ymin", "xmax", "ymax")]
for instance in instances
],
format="xyxy",
spatial_size=cast(Tuple[int, int], tuple(int(anns["size"][dim]) for dim in ("height", "width"))),
),
labels=Label(
[self._categories.index(instance["name"]) for instance in instances], categories=self._categories
),
)
def _prepare_segmentation_ann(self, buffer: BinaryIO) -> Dict[str, Any]:
return dict(segmentation=EncodedImage.from_file(buffer))
def _prepare_sample(
self,
data: Tuple[Tuple[Tuple[str, str], Tuple[str, BinaryIO]], Tuple[str, BinaryIO]],
) -> Dict[str, Any]:
split_and_image_data, ann_data = data
_, image_data = split_and_image_data
image_path, image_buffer = image_data
ann_path, ann_buffer = ann_data
return dict(
(self._prepare_detection_ann if self._task == "detection" else self._prepare_segmentation_ann)(ann_buffer),
image_path=image_path,
image=EncodedImage.from_file(image_buffer),
ann_path=ann_path,
)
def _datapipe(self, resource_dps: List[IterDataPipe]) -> IterDataPipe[Dict[str, Any]]:
archive_dp = resource_dps[0]
split_dp, images_dp, anns_dp = Demultiplexer(
archive_dp,
3,
self._classify_archive,
drop_none=True,
buffer_size=INFINITE_BUFFER_SIZE,
)
split_dp = Filter(split_dp, functools.partial(self._is_in_folder, name=self._split_folder))
split_dp = Filter(split_dp, path_comparator("name", f"{self._split}.txt"))
split_dp = LineReader(split_dp, decode=True)
split_dp = hint_shuffling(split_dp)
split_dp = hint_sharding(split_dp)
dp = split_dp
for level, data_dp in enumerate((images_dp, anns_dp)):
dp = IterKeyZipper(
dp,
data_dp,
key_fn=getitem(*[0] * level, 1),
ref_key_fn=path_accessor("stem"),
buffer_size=INFINITE_BUFFER_SIZE,
)
return Mapper(dp, self._prepare_sample)
def __len__(self) -> int:
return {
("train", "2007", "detection"): 2_501,
("train", "2007", "segmentation"): 209,
("train", "2008", "detection"): 2_111,
("train", "2008", "segmentation"): 511,
("train", "2009", "detection"): 3_473,
("train", "2009", "segmentation"): 749,
("train", "2010", "detection"): 4_998,
("train", "2010", "segmentation"): 964,
("train", "2011", "detection"): 5_717,
("train", "2011", "segmentation"): 1_112,
("train", "2012", "detection"): 5_717,
("train", "2012", "segmentation"): 1_464,
("val", "2007", "detection"): 2_510,
("val", "2007", "segmentation"): 213,
("val", "2008", "detection"): 2_221,
("val", "2008", "segmentation"): 512,
("val", "2009", "detection"): 3_581,
("val", "2009", "segmentation"): 750,
("val", "2010", "detection"): 5_105,
("val", "2010", "segmentation"): 964,
("val", "2011", "detection"): 5_823,
("val", "2011", "segmentation"): 1_111,
("val", "2012", "detection"): 5_823,
("val", "2012", "segmentation"): 1_449,
("trainval", "2007", "detection"): 5_011,
("trainval", "2007", "segmentation"): 422,
("trainval", "2008", "detection"): 4_332,
("trainval", "2008", "segmentation"): 1_023,
("trainval", "2009", "detection"): 7_054,
("trainval", "2009", "segmentation"): 1_499,
("trainval", "2010", "detection"): 10_103,
("trainval", "2010", "segmentation"): 1_928,
("trainval", "2011", "detection"): 11_540,
("trainval", "2011", "segmentation"): 2_223,
("trainval", "2012", "detection"): 11_540,
("trainval", "2012", "segmentation"): 2_913,
("test", "2007", "detection"): 4_952,
("test", "2007", "segmentation"): 210,
}[(self._split, self._year, self._task)]
def _filter_anns(self, data: Tuple[str, Any]) -> bool:
return self._classify_archive(data) == self._Demux.ANNS
def _generate_categories(self) -> List[str]:
self._task = "detection"
resources = self._resources()
archive_dp = resources[0].load(self._root)
dp = Filter(archive_dp, self._filter_anns)
dp = Mapper(dp, self._parse_detection_ann, input_col=1)
categories = sorted({instance["name"] for _, anns in dp for instance in anns["object"]})
# We add a background category to be used during segmentation
categories.insert(0, "__background__")
return categories
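# Usage sketch for the VOC dataset above. Detection samples carry
# `bounding_boxes` and `labels`, while segmentation samples carry an encoded
# `segmentation` mask instead (see the two `_prepare_*_ann` helpers).
from torchvision.prototype import datasets

voc_det = datasets.load("voc", year="2012", split="train", task="detection")
sample = next(iter(voc_det))
print(sample["image_path"], sample["labels"], sample["bounding_boxes"])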
import functools
import os
import os.path
import pathlib
from typing import Any, BinaryIO, Collection, Dict, List, Optional, Tuple, Union
from torchdata.datapipes.iter import FileLister, FileOpener, Filter, IterDataPipe, Mapper
from torchvision.prototype.datasets.utils import EncodedData, EncodedImage
from torchvision.prototype.datasets.utils._internal import hint_sharding, hint_shuffling
from torchvision.prototype.tv_tensors import Label
__all__ = ["from_data_folder", "from_image_folder"]
def _is_not_top_level_file(path: str, *, root: pathlib.Path) -> bool:
rel_path = pathlib.Path(path).relative_to(root)
return rel_path.is_dir() or rel_path.parent != pathlib.Path(".")
def _prepare_sample(
data: Tuple[str, BinaryIO],
*,
root: pathlib.Path,
categories: List[str],
) -> Dict[str, Any]:
path, buffer = data
category = pathlib.Path(path).relative_to(root).parts[0]
return dict(
path=path,
data=EncodedData.from_file(buffer),
label=Label.from_category(category, categories=categories),
)
def from_data_folder(
root: Union[str, pathlib.Path],
*,
valid_extensions: Optional[Collection[str]] = None,
recursive: bool = True,
) -> Tuple[IterDataPipe, List[str]]:
root = pathlib.Path(root).expanduser().resolve()
categories = sorted(entry.name for entry in os.scandir(root) if entry.is_dir())
masks: Union[List[str], str] = [f"*.{ext}" for ext in valid_extensions] if valid_extensions is not None else ""
dp = FileLister(str(root), recursive=recursive, masks=masks)
dp: IterDataPipe = Filter(dp, functools.partial(_is_not_top_level_file, root=root))
dp = hint_sharding(dp)
dp = hint_shuffling(dp)
dp = FileOpener(dp, mode="rb")
return Mapper(dp, functools.partial(_prepare_sample, root=root, categories=categories)), categories
def _data_to_image_key(sample: Dict[str, Any]) -> Dict[str, Any]:
sample["image"] = EncodedImage(sample.pop("data").data)
return sample
def from_image_folder(
root: Union[str, pathlib.Path],
*,
valid_extensions: Collection[str] = ("jpg", "jpeg", "png", "ppm", "bmp", "pgm", "tif", "tiff", "webp"),
**kwargs: Any,
) -> Tuple[IterDataPipe, List[str]]:
valid_extensions = [valid_extension for ext in valid_extensions for valid_extension in (ext.lower(), ext.upper())]
dp, categories = from_data_folder(root, valid_extensions=valid_extensions, **kwargs)
return Mapper(dp, _data_to_image_key), categories
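# Sketch for the folder helpers above, assuming they are re-exported at
# `torchvision.prototype.datasets` and an ImageFolder-style layout
# <root>/<category>/<image files>; "/path/to/images" is a placeholder.
from torchvision.prototype import datasets

dp, categories = datasets.from_image_folder("/path/to/images")
sample = next(iter(dp))
print(sample["path"], categories[int(sample["label"])])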
import os
from typing import Optional
import torchvision._internally_replaced_utils as _iru
def home(root: Optional[str] = None) -> str:
if root is not None:
_iru._HOME = root
return _iru._HOME
root = os.getenv("TORCHVISION_DATASETS_HOME")
if root is not None:
return root
return _iru._HOME
def use_sharded_dataset(use: Optional[bool] = None) -> bool:
if use is not None:
_iru._USE_SHARDED_DATASETS = use
return _iru._USE_SHARDED_DATASETS
use = os.getenv("TORCHVISION_SHARDED_DATASETS")
if use is not None:
return use == "1"
return _iru._USE_SHARDED_DATASETS
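# Sketch of how `home()` above resolves the dataset root: an explicit `root`
# argument wins, then the TORCHVISION_DATASETS_HOME environment variable, then
# torchvision's internal default. "/data/torchvision" is a placeholder path.
import os

from torchvision.prototype import datasets

os.environ["TORCHVISION_DATASETS_HOME"] = "/data/torchvision"
print(datasets.home())  # -> "/data/torchvision"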
# type: ignore
import argparse
import collections.abc
import contextlib
import inspect
import itertools
import os
import os.path
import pathlib
import shutil
import sys
import tempfile
import time
import unittest.mock
import warnings
import torch
from torch.utils.data import DataLoader
from torch.utils.data.dataloader_experimental import DataLoader2
from torchvision import datasets as legacy_datasets
from torchvision.datasets.utils import extract_archive
from torchvision.prototype import datasets as new_datasets
from torchvision.transforms import PILToTensor
def main(
name,
*,
variant=None,
legacy=True,
new=True,
start=True,
iteration=True,
num_starts=3,
num_samples=10_000,
temp_root=None,
num_workers=0,
):
benchmarks = [
benchmark
for benchmark in DATASET_BENCHMARKS
if benchmark.name == name and (variant is None or benchmark.variant == variant)
]
if not benchmarks:
msg = f"No DatasetBenchmark available for dataset '{name}'"
if variant is not None:
msg += f" and variant '{variant}'"
raise ValueError(msg)
for benchmark in benchmarks:
print("#" * 80)
print(f"{benchmark.name}" + (f" ({benchmark.variant})" if benchmark.variant is not None else ""))
if legacy and start:
print(
"legacy",
"cold_start",
Measurement.time(benchmark.legacy_cold_start(temp_root, num_workers=num_workers), number=num_starts),
)
print(
"legacy",
"warm_start",
Measurement.time(benchmark.legacy_warm_start(temp_root, num_workers=num_workers), number=num_starts),
)
if legacy and iteration:
print(
"legacy",
"iteration",
Measurement.iterations_per_time(
benchmark.legacy_iteration(temp_root, num_workers=num_workers, num_samples=num_samples)
),
)
if new and start:
print(
"new",
"cold_start",
Measurement.time(benchmark.new_cold_start(num_workers=num_workers), number=num_starts),
)
if new and iteration:
print(
"new",
"iteration",
Measurement.iterations_per_time(
benchmark.new_iteration(num_workers=num_workers, num_samples=num_samples)
),
)
class DatasetBenchmark:
def __init__(
self,
name: str,
*,
variant=None,
legacy_cls=None,
new_config=None,
legacy_config_map=None,
legacy_special_options_map=None,
prepare_legacy_root=None,
):
self.name = name
self.variant = variant
self.new_raw_dataset = new_datasets._api.find(name)
self.legacy_cls = legacy_cls or self._find_legacy_cls()
if new_config is None:
new_config = self.new_raw_dataset.default_config
elif isinstance(new_config, dict):
new_config = self.new_raw_dataset.info.make_config(**new_config)
self.new_config = new_config
self.legacy_config_map = legacy_config_map
self.legacy_special_options_map = legacy_special_options_map or self._legacy_special_options_map
self.prepare_legacy_root = prepare_legacy_root
def new_dataset(self, *, num_workers=0):
return DataLoader2(new_datasets.load(self.name, **self.new_config), num_workers=num_workers)
def new_cold_start(self, *, num_workers):
def fn(timer):
with timer:
dataset = self.new_dataset(num_workers=num_workers)
next(iter(dataset))
return fn
def new_iteration(self, *, num_samples, num_workers):
def fn(timer):
dataset = self.new_dataset(num_workers=num_workers)
num_sample = 0
with timer:
for _ in dataset:
num_sample += 1
if num_sample == num_samples:
break
return num_sample
return fn
def suppress_output(self):
@contextlib.contextmanager
def context_manager():
with open(os.devnull, "w") as devnull:
with contextlib.redirect_stdout(devnull), contextlib.redirect_stderr(devnull):
yield
return context_manager()
def legacy_dataset(self, root, *, num_workers=0, download=None):
legacy_config = self.legacy_config_map(self, root) if self.legacy_config_map else dict()
special_options = self.legacy_special_options_map(self)
if "download" in special_options and download is not None:
special_options["download"] = download
with self.suppress_output():
return DataLoader(
self.legacy_cls(legacy_config.pop("root", str(root)), **legacy_config, **special_options),
shuffle=True,
num_workers=num_workers,
)
@contextlib.contextmanager
def patch_download_and_integrity_checks(self):
patches = [
("download_url", dict()),
("download_file_from_google_drive", dict()),
("check_integrity", dict(new=lambda path, md5=None: os.path.isfile(path))),
]
dataset_module = sys.modules[self.legacy_cls.__module__]
utils_module = legacy_datasets.utils
with contextlib.ExitStack() as stack:
for name, patch_kwargs in patches:
patch_module = dataset_module if name in dir(dataset_module) else utils_module
stack.enter_context(unittest.mock.patch(f"{patch_module.__name__}.{name}", **patch_kwargs))
yield stack
def _find_resource_file_names(self):
info = self.new_raw_dataset.info
valid_options = info._valid_options
file_names = set()
for options in (
dict(zip(valid_options.keys(), values)) for values in itertools.product(*valid_options.values())
):
resources = self.new_raw_dataset.resources(info.make_config(**options))
file_names.update([resource.file_name for resource in resources])
return file_names
@contextlib.contextmanager
def legacy_root(self, temp_root):
new_root = pathlib.Path(new_datasets.home()) / self.name
legacy_root = pathlib.Path(tempfile.mkdtemp(dir=temp_root))
if os.stat(new_root).st_dev != os.stat(legacy_root).st_dev:
warnings.warn(
"The temporary root directory for the legacy dataset was created on a different storage device than "
"the raw data that is used by the new dataset. If the devices have different I/O stats, this will "
"distort the benchmark. You can use the '--temp-root' flag to relocate the root directory of the "
"temporary directories.",
RuntimeWarning,
)
try:
for file_name in self._find_resource_file_names():
(legacy_root / file_name).symlink_to(new_root / file_name)
if self.prepare_legacy_root:
self.prepare_legacy_root(self, legacy_root)
with self.patch_download_and_integrity_checks():
yield legacy_root
finally:
shutil.rmtree(legacy_root)
def legacy_cold_start(self, temp_root, *, num_workers):
def fn(timer):
with self.legacy_root(temp_root) as root:
with timer:
dataset = self.legacy_dataset(root, num_workers=num_workers)
next(iter(dataset))
return fn
def legacy_warm_start(self, temp_root, *, num_workers):
def fn(timer):
with self.legacy_root(temp_root) as root:
self.legacy_dataset(root, num_workers=num_workers)
with timer:
dataset = self.legacy_dataset(root, num_workers=num_workers, download=False)
next(iter(dataset))
return fn
def legacy_iteration(self, temp_root, *, num_samples, num_workers):
def fn(timer):
with self.legacy_root(temp_root) as root:
dataset = self.legacy_dataset(root, num_workers=num_workers)
with timer:
for num_sample, _ in enumerate(dataset, 1):
if num_sample == num_samples:
break
return num_sample
return fn
def _find_legacy_cls(self):
legacy_clss = {
name.lower(): dataset_class
for name, dataset_class in legacy_datasets.__dict__.items()
if isinstance(dataset_class, type) and issubclass(dataset_class, legacy_datasets.VisionDataset)
}
try:
return legacy_clss[self.name]
except KeyError as error:
raise RuntimeError(
f"Can't determine the legacy dataset class for '{self.name}' automatically. "
f"Please set the 'legacy_cls' keyword argument manually."
) from error
_SPECIAL_KWARGS = {
"transform",
"target_transform",
"transforms",
"download",
}
@staticmethod
def _legacy_special_options_map(benchmark):
available_parameters = set()
for cls in benchmark.legacy_cls.__mro__:
if cls is legacy_datasets.VisionDataset:
break
available_parameters.update(inspect.signature(cls.__init__).parameters)
available_special_kwargs = benchmark._SPECIAL_KWARGS.intersection(available_parameters)
special_options = dict()
if "download" in available_special_kwargs:
special_options["download"] = True
if "transform" in available_special_kwargs:
special_options["transform"] = PILToTensor()
if "target_transform" in available_special_kwargs:
special_options["target_transform"] = torch.tensor
elif "transforms" in available_special_kwargs:
special_options["transforms"] = JointTransform(PILToTensor(), PILToTensor())
return special_options
class Measurement:
@classmethod
def time(cls, fn, *, number):
results = Measurement._timeit(fn, number=number)
times = torch.tensor(tuple(zip(*results))[1])
return cls._format(times, unit="s")
@classmethod
def iterations_per_time(cls, fn):
num_samples, time = Measurement._timeit(fn, number=1)[0]
iterations_per_second = torch.tensor(num_samples) / torch.tensor(time)
return cls._format(iterations_per_second, unit="it/s")
class Timer:
def __init__(self):
self._start = None
self._stop = None
def __enter__(self):
self._start = time.perf_counter()
def __exit__(self, exc_type, exc_val, exc_tb):
self._stop = time.perf_counter()
@property
def delta(self):
if self._start is None:
raise RuntimeError()
elif self._stop is None:
raise RuntimeError()
return self._stop - self._start
@classmethod
def _timeit(cls, fn, number):
results = []
for _ in range(number):
timer = cls.Timer()
output = fn(timer)
results.append((output, timer.delta))
return results
@classmethod
def _format(cls, measurements, *, unit):
measurements = torch.as_tensor(measurements).to(torch.float64).flatten()
if measurements.numel() == 1:
# TODO format that into engineering format
return f"{float(measurements):.3f} {unit}"
mean, std = Measurement._compute_mean_and_std(measurements)
# TODO format that into engineering format
return f"{mean:.3f} ± {std:.3f} {unit}"
@classmethod
def _compute_mean_and_std(cls, t):
mean = float(t.mean())
std = float(t.std(0, unbiased=t.numel() > 1))
return mean, std
def no_split(benchmark, root):
legacy_config = dict(benchmark.new_config)
del legacy_config["split"]
return legacy_config
def bool_split(name="train"):
def legacy_config_map(benchmark, root):
legacy_config = dict(benchmark.new_config)
legacy_config[name] = legacy_config.pop("split") == "train"
return legacy_config
return legacy_config_map
def base_folder(rel_folder=None):
if rel_folder is None:
def rel_folder(benchmark):
return benchmark.name
elif not callable(rel_folder):
name = rel_folder
def rel_folder(_):
return name
def prepare_legacy_root(benchmark, root):
files = list(root.glob("*"))
folder = root / rel_folder(benchmark)
folder.mkdir(parents=True)
for file in files:
shutil.move(str(file), str(folder))
return folder
return prepare_legacy_root
class JointTransform:
def __init__(self, *transforms):
self.transforms = transforms
def __call__(self, *inputs):
if len(inputs) == 1 and isinstance(inputs, collections.abc.Sequence):
inputs = inputs[0]
if len(inputs) != len(self.transforms):
raise RuntimeError(
f"The number of inputs and transforms mismatches: {len(inputs)} != {len(self.transforms)}."
)
return tuple(transform(input) for transform, input in zip(self.transforms, inputs))
def caltech101_legacy_config_map(benchmark, root):
legacy_config = no_split(benchmark, root)
# The new dataset always returns the category and annotation
legacy_config["target_type"] = ("category", "annotation")
return legacy_config
mnist_base_folder = base_folder(lambda benchmark: pathlib.Path(benchmark.legacy_cls.__name__) / "raw")
def mnist_legacy_config_map(benchmark, root):
return dict(train=benchmark.new_config.split == "train")
def emnist_prepare_legacy_root(benchmark, root):
folder = mnist_base_folder(benchmark, root)
shutil.move(str(folder / "emnist-gzip.zip"), str(folder / "gzip.zip"))
return folder
def emnist_legacy_config_map(benchmark, root):
legacy_config = mnist_legacy_config_map(benchmark, root)
legacy_config["split"] = benchmark.new_config.image_set.replace("_", "").lower()
return legacy_config
def qmnist_legacy_config_map(benchmark, root):
legacy_config = mnist_legacy_config_map(benchmark, root)
legacy_config["what"] = benchmark.new_config.split
# The new dataset always returns the full label
legacy_config["compat"] = False
return legacy_config
def coco_legacy_config_map(benchmark, root):
images, _ = benchmark.new_raw_dataset.resources(benchmark.new_config)
return dict(
root=str(root / pathlib.Path(images.file_name).stem),
annFile=str(
root / "annotations" / f"{benchmark.variant}_{benchmark.new_config.split}{benchmark.new_config.year}.json"
),
)
def coco_prepare_legacy_root(benchmark, root):
images, annotations = benchmark.new_raw_dataset.resources(benchmark.new_config)
extract_archive(str(root / images.file_name))
extract_archive(str(root / annotations.file_name))
DATASET_BENCHMARKS = [
DatasetBenchmark(
"caltech101",
legacy_config_map=caltech101_legacy_config_map,
prepare_legacy_root=base_folder(),
legacy_special_options_map=lambda config: dict(
download=True,
transform=PILToTensor(),
target_transform=JointTransform(torch.tensor, torch.tensor),
),
),
DatasetBenchmark(
"caltech256",
legacy_config_map=no_split,
prepare_legacy_root=base_folder(),
),
DatasetBenchmark(
"celeba",
prepare_legacy_root=base_folder(),
legacy_config_map=lambda benchmark: dict(
split="valid" if benchmark.new_config.split == "val" else benchmark.new_config.split,
# The new dataset always returns all annotations
target_type=("attr", "identity", "bbox", "landmarks"),
),
),
DatasetBenchmark(
"cifar10",
legacy_config_map=bool_split(),
),
DatasetBenchmark(
"cifar100",
legacy_config_map=bool_split(),
),
DatasetBenchmark(
"emnist",
prepare_legacy_root=emnist_prepare_legacy_root,
legacy_config_map=emnist_legacy_config_map,
),
DatasetBenchmark(
"fashionmnist",
prepare_legacy_root=mnist_base_folder,
legacy_config_map=mnist_legacy_config_map,
),
DatasetBenchmark(
"kmnist",
prepare_legacy_root=mnist_base_folder,
legacy_config_map=mnist_legacy_config_map,
),
DatasetBenchmark(
"mnist",
prepare_legacy_root=mnist_base_folder,
legacy_config_map=mnist_legacy_config_map,
),
DatasetBenchmark(
"qmnist",
prepare_legacy_root=mnist_base_folder,
legacy_config_map=mnist_legacy_config_map,
),
DatasetBenchmark(
"sbd",
legacy_cls=legacy_datasets.SBDataset,
legacy_config_map=lambda benchmark: dict(
image_set=benchmark.new_config.split,
mode="boundaries" if benchmark.new_config.boundaries else "segmentation",
),
legacy_special_options_map=lambda benchmark: dict(
download=True,
transforms=JointTransform(
PILToTensor(), torch.tensor if benchmark.new_config.boundaries else PILToTensor()
),
),
),
DatasetBenchmark("voc", legacy_cls=legacy_datasets.VOCDetection),
DatasetBenchmark("imagenet", legacy_cls=legacy_datasets.ImageNet),
DatasetBenchmark(
"coco",
variant="instances",
legacy_cls=legacy_datasets.CocoDetection,
new_config=dict(split="train", annotations="instances"),
legacy_config_map=coco_legacy_config_map,
prepare_legacy_root=coco_prepare_legacy_root,
legacy_special_options_map=lambda benchmark: dict(transform=PILToTensor(), target_transform=None),
),
DatasetBenchmark(
"coco",
variant="captions",
legacy_cls=legacy_datasets.CocoCaptions,
new_config=dict(split="train", annotations="captions"),
legacy_config_map=coco_legacy_config_map,
prepare_legacy_root=coco_prepare_legacy_root,
legacy_special_options_map=lambda benchmark: dict(transform=PILToTensor(), target_transform=None),
),
]
def parse_args(argv=None):
parser = argparse.ArgumentParser(
prog="torchvision.prototype.datasets.benchmark.py",
description="Utility to benchmark new datasets against their legacy variants.",
formatter_class=argparse.ArgumentDefaultsHelpFormatter,
)
parser.add_argument("name", help="Name of the dataset to benchmark.")
parser.add_argument(
"--variant", help="Variant of the dataset. If omitted all available variants will be benchmarked."
)
parser.add_argument(
"-n",
"--num-starts",
type=int,
default=3,
help="Number of warm and cold starts of each benchmark. Default to 3.",
)
parser.add_argument(
"-N",
"--num-samples",
type=int,
default=10_000,
help="Maximum number of samples to draw during iteration benchmarks. Defaults to 10_000.",
)
parser.add_argument(
"--nl",
"--no-legacy",
dest="legacy",
action="store_false",
help="Skip legacy benchmarks.",
)
parser.add_argument(
"--nn",
"--no-new",
dest="new",
action="store_false",
help="Skip new benchmarks.",
)
parser.add_argument(
"--ns",
"--no-start",
dest="start",
action="store_false",
help="Skip start benchmarks.",
)
parser.add_argument(
"--ni",
"--no-iteration",
dest="iteration",
action="store_false",
help="Skip iteration benchmarks.",
)
parser.add_argument(
"-t",
"--temp-root",
type=pathlib.Path,
help=(
"Root of the temporary legacy root directories. Use this if your system default temporary directory is on "
"another storage device as the raw data to avoid distortions due to differing I/O stats."
),
)
parser.add_argument(
"-j",
"--num-workers",
type=int,
default=0,
help=(
"Number of subprocesses used to load the data. Setting this to 0 (default) will load all data in the main "
"process and thus disable multi-processing."
),
)
return parser.parse_args(argv or sys.argv[1:])
if __name__ == "__main__":
args = parse_args()
try:
main(
args.name,
variant=args.variant,
legacy=args.legacy,
new=args.new,
start=args.start,
iteration=args.iteration,
num_starts=args.num_starts,
num_samples=args.num_samples,
temp_root=args.temp_root,
num_workers=args.num_workers,
)
except Exception as error:
msg = str(error)
print(msg or f"Unspecified {type(error)} was raised during execution.", file=sys.stderr)
sys.exit(1)
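# Example invocations of the benchmark script above, using the flags defined in
# `parse_args`; the file name used to launch it is an assumption:
#
#   python benchmark.py cifar10 --num-samples 1000 --temp-root /tmp/benchmark
#   python benchmark.py coco --variant instances --nl   # skip the legacy runs
#   python benchmark.py mnist -j 4 --ni                 # skip the iteration runs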
# type: ignore
import argparse
import csv
import sys
from torchvision.prototype import datasets
from torchvision.prototype.datasets.utils._internal import BUILTIN_DIR
def main(*names, force=False):
for name in names:
path = BUILTIN_DIR / f"{name}.categories"
if path.exists() and not force:
continue
dataset = datasets.load(name)
try:
categories = dataset._generate_categories()
except NotImplementedError:
continue
with open(path, "w") as file:
writer = csv.writer(file, lineterminator="\n")
for category in categories:
writer.writerow((category,) if isinstance(category, str) else category)
def parse_args(argv=None):
parser = argparse.ArgumentParser(prog="torchvision.prototype.datasets.generate_category_files.py")
parser.add_argument(
"names",
nargs="*",
type=str,
help="Names of datasets to generate category files for. If omitted, all datasets will be used.",
)
parser.add_argument(
"-f",
"--force",
action="store_true",
help="Force regeneration of category files.",
)
args = parser.parse_args(argv or sys.argv[1:])
if not args.names:
args.names = datasets.list_datasets()
return args
if __name__ == "__main__":
args = parse_args()
try:
main(*args.names, force=args.force)
except Exception as error:
msg = str(error)
print(msg or f"Unspecified {type(error)} was raised during execution.", file=sys.stderr)
sys.exit(1)
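# Example invocations of the category-file generator above; datasets whose
# `_generate_categories` is not implemented are skipped, and the file name used
# to launch the script is an assumption:
#
#   python generate_category_files.py                  # all registered datasets
#   python generate_category_files.py voc stanford-cars --force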
from . import _internal # usort: skip
from ._dataset import Dataset
from ._encoded import EncodedData, EncodedImage
from ._resource import GDriveResource, HttpResource, KaggleDownloadResource, ManualDownloadResource, OnlineResource
import abc
import importlib
import pathlib
from typing import Any, Collection, Dict, Iterator, List, Optional, Sequence, Union
from torchdata.datapipes.iter import IterDataPipe
from torchvision.datasets.utils import verify_str_arg
from ._resource import OnlineResource
class Dataset(IterDataPipe[Dict[str, Any]], abc.ABC):
@staticmethod
def _verify_str_arg(
value: str,
arg: Optional[str] = None,
valid_values: Optional[Collection[str]] = None,
*,
custom_msg: Optional[str] = None,
) -> str:
return verify_str_arg(value, arg, valid_values, custom_msg=custom_msg)
def __init__(
self, root: Union[str, pathlib.Path], *, skip_integrity_check: bool = False, dependencies: Collection[str] = ()
) -> None:
for dependency in dependencies:
try:
importlib.import_module(dependency)
except ModuleNotFoundError:
raise ModuleNotFoundError(
f"{type(self).__name__}() depends on the third-party package '{dependency}'. "
f"Please install it, for example with `pip install {dependency}`."
) from None
self._root = pathlib.Path(root).expanduser().resolve()
resources = [
resource.load(self._root, skip_integrity_check=skip_integrity_check) for resource in self._resources()
]
self._dp = self._datapipe(resources)
def __iter__(self) -> Iterator[Dict[str, Any]]:
yield from self._dp
@abc.abstractmethod
def _resources(self) -> List[OnlineResource]:
pass
@abc.abstractmethod
def _datapipe(self, resource_dps: List[IterDataPipe]) -> IterDataPipe[Dict[str, Any]]:
pass
@abc.abstractmethod
def __len__(self) -> int:
pass
def _generate_categories(self) -> Sequence[Union[str, Sequence[str]]]:
raise NotImplementedError
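# Minimal sketch of a Dataset subclass built on the abstract interface above;
# the URL, checksum-less resource, and sample layout are placeholders rather
# than a real dataset.
from typing import Any, Dict, List

from torchdata.datapipes.iter import IterDataPipe, Mapper
from torchvision.prototype.datasets.utils import Dataset, HttpResource, OnlineResource


class TinyDataset(Dataset):
    def _resources(self) -> List[OnlineResource]:
        return [HttpResource("https://example.com/tiny.tar")]

    def _datapipe(self, resource_dps: List[IterDataPipe]) -> IterDataPipe[Dict[str, Any]]:
        # resource_dps[0] yields (path, file handle) tuples from the extracted archive
        return Mapper(resource_dps[0], lambda data: dict(path=data[0]))

    def __len__(self) -> int:
        return 1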
from __future__ import annotations
import os
import sys
from typing import Any, BinaryIO, Optional, Tuple, Type, TypeVar, Union
import PIL.Image
import torch
from torchvision.prototype.utils._internal import fromfile, ReadOnlyTensorBuffer
from torchvision.tv_tensors._tv_tensor import TVTensor
D = TypeVar("D", bound="EncodedData")
class EncodedData(TVTensor):
@classmethod
def _wrap(cls: Type[D], tensor: torch.Tensor) -> D:
return tensor.as_subclass(cls)
def __new__(
cls,
data: Any,
*,
dtype: Optional[torch.dtype] = None,
device: Optional[Union[torch.device, str, int]] = None,
requires_grad: bool = False,
) -> EncodedData:
tensor = cls._to_tensor(data, dtype=dtype, device=device, requires_grad=requires_grad)
# TODO: warn / bail out if we encounter a tensor with shape other than (N,) or with dtype other than uint8?
return cls._wrap(tensor)
@classmethod
def wrap_like(cls: Type[D], other: D, tensor: torch.Tensor) -> D:
return cls._wrap(tensor)
@classmethod
def from_file(cls: Type[D], file: BinaryIO, **kwargs: Any) -> D:
encoded_data = cls(fromfile(file, dtype=torch.uint8, byte_order=sys.byteorder), **kwargs)
file.close()
return encoded_data
@classmethod
def from_path(cls: Type[D], path: Union[str, os.PathLike], **kwargs: Any) -> D:
with open(path, "rb") as file:
return cls.from_file(file, **kwargs)
class EncodedImage(EncodedData):
# TODO: Use @functools.cached_property if we can depend on Python 3.8
@property
def spatial_size(self) -> Tuple[int, int]:
if not hasattr(self, "_spatial_size"):
with PIL.Image.open(ReadOnlyTensorBuffer(self)) as image:
self._spatial_size = image.height, image.width
return self._spatial_size
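# Sketch of the encoded-data helpers above; "/path/to/image.jpg" is a
# placeholder. The bytes stay a flat uint8 tensor and are only decoded
# (lazily, via PIL) the first time `spatial_size` is accessed.
from torchvision.prototype.datasets.utils import EncodedImage

img = EncodedImage.from_path("/path/to/image.jpg")
print(img.dtype, img.shape)  # torch.uint8, (number_of_bytes,)
print(img.spatial_size)      # (height, width)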
import csv
import functools
import pathlib
import pickle
from typing import Any, BinaryIO, Callable, Dict, IO, Iterator, List, Sequence, Sized, Tuple, TypeVar, Union
import torch
import torch.distributed as dist
import torch.utils.data
from torchdata.datapipes.iter import IoPathFileLister, IoPathFileOpener, IterDataPipe, ShardingFilter, Shuffler
from torchvision.prototype.utils._internal import fromfile
__all__ = [
"INFINITE_BUFFER_SIZE",
"BUILTIN_DIR",
"read_mat",
"MappingIterator",
"getitem",
"path_accessor",
"path_comparator",
"read_flo",
"hint_sharding",
"hint_shuffling",
]
K = TypeVar("K")
D = TypeVar("D")
# pseudo-infinite until a true infinite buffer is supported by all datapipes
INFINITE_BUFFER_SIZE = 1_000_000_000
BUILTIN_DIR = pathlib.Path(__file__).parent.parent / "_builtin"
def read_mat(buffer: BinaryIO, **kwargs: Any) -> Any:
try:
import scipy.io as sio
except ImportError as error:
raise ModuleNotFoundError("Package `scipy` is required to be installed to read .mat files.") from error
data = sio.loadmat(buffer, **kwargs)
buffer.close()
return data
class MappingIterator(IterDataPipe[Union[Tuple[K, D], D]]):
def __init__(self, datapipe: IterDataPipe[Dict[K, D]], *, drop_key: bool = False) -> None:
self.datapipe = datapipe
self.drop_key = drop_key
def __iter__(self) -> Iterator[Union[Tuple[K, D], D]]:
for mapping in self.datapipe:
yield from iter(mapping.values() if self.drop_key else mapping.items())
def _getitem_closure(obj: Any, *, items: Sequence[Any]) -> Any:
for item in items:
obj = obj[item]
return obj
def getitem(*items: Any) -> Callable[[Any], Any]:
return functools.partial(_getitem_closure, items=items)
def _getattr_closure(obj: Any, *, attrs: Sequence[str]) -> Any:
for attr in attrs:
obj = getattr(obj, attr)
return obj
def _path_attribute_accessor(path: pathlib.Path, *, name: str) -> Any:
return _getattr_closure(path, attrs=name.split("."))
def _path_accessor_closure(data: Tuple[str, Any], *, getter: Callable[[pathlib.Path], D]) -> D:
return getter(pathlib.Path(data[0]))
def path_accessor(getter: Union[str, Callable[[pathlib.Path], D]]) -> Callable[[Tuple[str, Any]], D]:
if isinstance(getter, str):
getter = functools.partial(_path_attribute_accessor, name=getter)
return functools.partial(_path_accessor_closure, getter=getter)
def _path_comparator_closure(data: Tuple[str, Any], *, accessor: Callable[[Tuple[str, Any]], D], value: D) -> bool:
return accessor(data) == value
def path_comparator(getter: Union[str, Callable[[pathlib.Path], D]], value: D) -> Callable[[Tuple[str, Any]], bool]:
return functools.partial(_path_comparator_closure, accessor=path_accessor(getter), value=value)
class PicklerDataPipe(IterDataPipe):
def __init__(self, source_datapipe: IterDataPipe[Tuple[str, IO[bytes]]]) -> None:
self.source_datapipe = source_datapipe
def __iter__(self) -> Iterator[Any]:
for _, fobj in self.source_datapipe:
data = pickle.load(fobj)
for _, d in enumerate(data):
yield d
class SharderDataPipe(ShardingFilter):
def __init__(self, source_datapipe: IterDataPipe) -> None:
super().__init__(source_datapipe)
self.rank = 0
self.world_size = 1
if dist.is_available() and dist.is_initialized():
self.rank = dist.get_rank()
self.world_size = dist.get_world_size()
self.apply_sharding(self.world_size, self.rank)
def __iter__(self) -> Iterator[Any]:
num_workers = self.world_size
worker_id = self.rank
worker_info = torch.utils.data.get_worker_info()
if worker_info is not None:
worker_id = worker_id + worker_info.id * num_workers
num_workers *= worker_info.num_workers
self.apply_sharding(num_workers, worker_id)
yield from super().__iter__()
class TakerDataPipe(IterDataPipe):
def __init__(self, source_datapipe: IterDataPipe, num_take: int) -> None:
super().__init__()
self.source_datapipe = source_datapipe
self.num_take = num_take
self.world_size = 1
if dist.is_available() and dist.is_initialized():
self.world_size = dist.get_world_size()
def __iter__(self) -> Iterator[Any]:
num_workers = self.world_size
worker_info = torch.utils.data.get_worker_info()
if worker_info is not None:
num_workers *= worker_info.num_workers
# TODO: this is weird as it drops more elements than it should
num_take = self.num_take // num_workers
for i, data in enumerate(self.source_datapipe):
if i < num_take:
yield data
else:
break
def __len__(self) -> int:
num_take = self.num_take // self.world_size
if isinstance(self.source_datapipe, Sized):
if len(self.source_datapipe) < num_take:
num_take = len(self.source_datapipe)
# TODO: might be weird to not take `num_workers` into account
return num_take
def _make_sharded_datapipe(root: str, dataset_size: int) -> IterDataPipe[Dict[str, Any]]:
dp = IoPathFileLister(root=root)
dp = SharderDataPipe(dp)
dp = dp.shuffle(buffer_size=INFINITE_BUFFER_SIZE)
dp = IoPathFileOpener(dp, mode="rb")
dp = PicklerDataPipe(dp)
# dp = dp.cycle(2)
dp = TakerDataPipe(dp, dataset_size)
return dp
def read_flo(file: BinaryIO) -> torch.Tensor:
if file.read(4) != b"PIEH":
raise ValueError("Magic number incorrect. Invalid .flo file")
width, height = fromfile(file, dtype=torch.int32, byte_order="little", count=2)
flow = fromfile(file, dtype=torch.float32, byte_order="little", count=height * width * 2)
return flow.reshape((height, width, 2)).permute((2, 0, 1))
def hint_sharding(datapipe: IterDataPipe) -> ShardingFilter:
return ShardingFilter(datapipe)
def hint_shuffling(datapipe: IterDataPipe[D]) -> Shuffler[D]:
return Shuffler(datapipe, buffer_size=INFINITE_BUFFER_SIZE).set_shuffle(False)
def read_categories_file(name: str) -> List[Union[str, Sequence[str]]]:
path = BUILTIN_DIR / f"{name}.categories"
with open(path, newline="") as file:
rows = list(csv.reader(file))
rows = [row[0] if len(row) == 1 else row for row in rows]
return rows
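# Small sketch of the path helpers above on a synthetic (path, payload) tuple,
# the shape most datapipes in this module yield:
from torchvision.prototype.datasets.utils._internal import getitem, path_accessor, path_comparator

data = ("/root/images/cat_001.jpg", b"...")
print(getitem(0)(data))                         # "/root/images/cat_001.jpg"
print(path_accessor("stem")(data))              # "cat_001"
print(path_comparator("suffix", ".jpg")(data))  # True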
import abc
import hashlib
import itertools
import pathlib
from typing import Any, Callable, IO, Literal, NoReturn, Optional, Sequence, Set, Tuple, Union
from urllib.parse import urlparse
from torchdata.datapipes.iter import (
FileLister,
FileOpener,
IterableWrapper,
IterDataPipe,
RarArchiveLoader,
TarArchiveLoader,
ZipArchiveLoader,
)
from torchvision.datasets.utils import (
_decompress,
_detect_file_type,
_get_google_drive_file_id,
_get_redirect_url,
download_file_from_google_drive,
download_url,
extract_archive,
)
class OnlineResource(abc.ABC):
def __init__(
self,
*,
file_name: str,
sha256: Optional[str] = None,
preprocess: Optional[Union[Literal["decompress", "extract"], Callable[[pathlib.Path], None]]] = None,
) -> None:
self.file_name = file_name
self.sha256 = sha256
if isinstance(preprocess, str):
if preprocess == "decompress":
preprocess = self._decompress
elif preprocess == "extract":
preprocess = self._extract
else:
raise ValueError(
f"Only `'decompress'` or `'extract'` are valid if `preprocess` is passed as string,"
f"but got {preprocess} instead."
)
self._preprocess = preprocess
@staticmethod
def _extract(file: pathlib.Path) -> None:
extract_archive(str(file), to_path=str(file).replace("".join(file.suffixes), ""), remove_finished=False)
@staticmethod
def _decompress(file: pathlib.Path) -> None:
_decompress(str(file), remove_finished=True)
def _loader(self, path: pathlib.Path) -> IterDataPipe[Tuple[str, IO]]:
if path.is_dir():
return FileOpener(FileLister(str(path), recursive=True), mode="rb")
dp = FileOpener(IterableWrapper((str(path),)), mode="rb")
archive_loader = self._guess_archive_loader(path)
if archive_loader:
dp = archive_loader(dp)
return dp
_ARCHIVE_LOADERS = {
".tar": TarArchiveLoader,
".zip": ZipArchiveLoader,
".rar": RarArchiveLoader,
}
def _guess_archive_loader(
self, path: pathlib.Path
) -> Optional[Callable[[IterDataPipe[Tuple[str, IO]]], IterDataPipe[Tuple[str, IO]]]]:
try:
_, archive_type, _ = _detect_file_type(path.name)
except RuntimeError:
return None
return self._ARCHIVE_LOADERS.get(archive_type) # type: ignore[arg-type]
def load(
self, root: Union[str, pathlib.Path], *, skip_integrity_check: bool = False
) -> IterDataPipe[Tuple[str, IO]]:
root = pathlib.Path(root)
path = root / self.file_name
# Instead of the raw file, there might also be files with fewer suffixes after decompression or directories
# with no suffixes at all. `pathlib.Path().stem` will only give us the name with the last suffix removed, which
# is not sufficient for files with multiple suffixes, e.g. foo.tar.gz.
stem = path.name.replace("".join(path.suffixes), "")
def find_candidates() -> Set[pathlib.Path]:
# Although it looks like we could glob for f"{stem}*" to find the file candidates as well as the folder
# candidate simultaneously, that would also pick up other files that share the same prefix. For example, the
# test split of the stanford-cars dataset uses the files
# - cars_test.tgz
# - cars_test_annos_withlabels.mat
# Globbing for `"cars_test*"` picks up both.
candidates = {file for file in path.parent.glob(f"{stem}.*")}
folder_candidate = path.parent / stem
if folder_candidate.exists():
candidates.add(folder_candidate)
return candidates
candidates = find_candidates()
if not candidates:
self.download(root, skip_integrity_check=skip_integrity_check)
if self._preprocess is not None:
self._preprocess(path)
candidates = find_candidates()
# We use the path with the fewest suffixes. This gives us the
# extracted > decompressed > raw
# priority that we want for the best I/O performance.
return self._loader(min(candidates, key=lambda candidate: len(candidate.suffixes)))
@abc.abstractmethod
def _download(self, root: pathlib.Path) -> None:
pass
def download(self, root: Union[str, pathlib.Path], *, skip_integrity_check: bool = False) -> pathlib.Path:
root = pathlib.Path(root)
self._download(root)
path = root / self.file_name
if self.sha256 and not skip_integrity_check:
self._check_sha256(path)
return path
def _check_sha256(self, path: pathlib.Path, *, chunk_size: int = 1024 * 1024) -> None:
hash = hashlib.sha256()
with open(path, "rb") as file:
while chunk := file.read(chunk_size):
hash.update(chunk)
sha256 = hash.hexdigest()
if sha256 != self.sha256:
raise RuntimeError(
f"After the download, the SHA256 checksum of {path} didn't match the expected one: "
f"{sha256} != {self.sha256}"
)
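# Usage sketch (illustrative): a concrete subclass such as ``HttpResource`` below
# downloads the file on the first ``load`` call, verifies the SHA256 checksum,
# optionally decompresses/extracts it, and yields ``(path, file_handle)`` pairs.
# The URL and checksum here are placeholders, not a real artifact.
#
#   resource = HttpResource(
#       "https://example.com/data.tar.gz",
#       sha256="<expected hex digest>",
#       preprocess="extract",
#   )
#   for path, handle in resource.load("/path/to/root"):
#       ...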
class HttpResource(OnlineResource):
def __init__(
self, url: str, *, file_name: Optional[str] = None, mirrors: Sequence[str] = (), **kwargs: Any
) -> None:
super().__init__(file_name=file_name or pathlib.Path(urlparse(url).path).name, **kwargs)
self.url = url
self.mirrors = mirrors
self._resolved = False
def resolve(self) -> OnlineResource:
if self._resolved:
return self
redirect_url = _get_redirect_url(self.url)
if redirect_url == self.url:
self._resolved = True
return self
meta = {
attr.lstrip("_"): getattr(self, attr)
for attr in (
"file_name",
"sha256",
"_preprocess",
)
}
gdrive_id = _get_google_drive_file_id(redirect_url)
if gdrive_id:
return GDriveResource(gdrive_id, **meta)
http_resource = HttpResource(redirect_url, **meta)
http_resource._resolved = True
return http_resource
def _download(self, root: pathlib.Path) -> None:
if not self._resolved:
return self.resolve()._download(root)
for url in itertools.chain((self.url,), self.mirrors):
try:
download_url(url, str(root), filename=self.file_name, md5=None)
# TODO: make this more precise
except Exception:
continue
return
else:
# TODO: make this more informative
raise RuntimeError("Download failed!")
class GDriveResource(OnlineResource):
def __init__(self, id: str, **kwargs: Any) -> None:
super().__init__(**kwargs)
self.id = id
def _download(self, root: pathlib.Path) -> None:
download_file_from_google_drive(self.id, root=str(root), filename=self.file_name, md5=None)
class ManualDownloadResource(OnlineResource):
def __init__(self, instructions: str, **kwargs: Any) -> None:
super().__init__(**kwargs)
self.instructions = instructions
def _download(self, root: pathlib.Path) -> NoReturn:
raise RuntimeError(
f"The file {self.file_name} cannot be downloaded automatically. "
f"Please follow the instructions below and place it in {root}\n\n"
f"{self.instructions}"
)
class KaggleDownloadResource(ManualDownloadResource):
def __init__(self, challenge_url: str, *, file_name: str, **kwargs: Any) -> None:
instructions = "\n".join(
(
"1. Register and login at https://www.kaggle.com",
f"2. Navigate to {challenge_url}",
"3. Click 'Join Competition' and follow the instructions there",
"4. Navigate to the 'Data' tab",
f"5. Select {file_name} in the 'Data Explorer' and click the download button",
)
)
super().__init__(instructions, file_name=file_name, **kwargs)
from .raft_stereo import *
from .crestereo import *
import math
from functools import partial
from typing import Callable, Dict, Iterable, List, Optional, Tuple
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.models.optical_flow.raft as raft
from torch import Tensor
from torchvision.models._api import register_model, Weights, WeightsEnum
from torchvision.models._utils import handle_legacy_interface
from torchvision.models.optical_flow._utils import grid_sample, make_coords_grid, upsample_flow
from torchvision.ops import Conv2dNormActivation
from torchvision.prototype.transforms._presets import StereoMatching
__all__ = (
"CREStereo",
"CREStereo_Base_Weights",
"crestereo_base",
)
class ConvexMaskPredictor(nn.Module):
def __init__(
self,
*,
in_channels: int,
hidden_size: int,
upsample_factor: int,
multiplier: float = 0.25,
) -> None:
super().__init__()
self.mask_head = nn.Sequential(
Conv2dNormActivation(in_channels, hidden_size, norm_layer=None, kernel_size=3),
# https://arxiv.org/pdf/2003.12039.pdf (Annex section B) for the
# following convolution output size
nn.Conv2d(hidden_size, upsample_factor**2 * 9, 1, padding=0),
)
self.multiplier = multiplier
def forward(self, x: Tensor) -> Tensor:
x = self.mask_head(x) * self.multiplier
return x
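# Shape sketch (illustrative): the predictor keeps the spatial size of its input and
# outputs ``upsample_factor**2 * 9`` channels, i.e. one 3x3 convex-combination weight
# set per upsampled pixel. Sizes below are arbitrary assumptions.
#
#   predictor = ConvexMaskPredictor(in_channels=128, hidden_size=256, upsample_factor=4)
#   mask = predictor(torch.rand(1, 128, 32, 64))  # -> [1, 144, 32, 64] (144 = 4**2 * 9)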
def get_correlation(
left_feature: Tensor,
right_feature: Tensor,
window_size: Tuple[int, int] = (3, 3),
dilate: Tuple[int, int] = (1, 1),
) -> Tensor:
"""Function that computes a correlation product between the left and right features.
The correlation is computed in a sliding window fashion, namely the left features are fixed
and for each ``(i, j)`` location we compute the correlation with a sliding window anchored at
``(i, j)`` in the right feature map. The sliding window selects pixels in the ranges
``(i - window_size // 2, i + window_size // 2)`` and ``(j - window_size // 2, j + window_size // 2)`` respectively.
"""
B, C, H, W = left_feature.shape
di_y, di_x = dilate[0], dilate[1]
pad_y, pad_x = window_size[0] // 2 * di_y, window_size[1] // 2 * di_x
right_padded = F.pad(right_feature, (pad_x, pad_x, pad_y, pad_y), mode="replicate")
# in order to vectorize the correlation computation over all pixel candidates
# we create multiple shifted right images which we stack on an extra dimension
right_padded = F.unfold(right_padded, kernel_size=(H, W), dilation=dilate)
# torch unfold returns a tensor of shape [B, flattened_values, n_selections]
right_padded = right_padded.permute(0, 2, 1)
# we then reshape back into [B, n_views, C, H, W]
right_padded = right_padded.reshape(B, (window_size[0] * window_size[1]), C, H, W)
# we expand the left features for broadcasting
left_feature = left_feature.unsqueeze(1)
# this computes an element-wise product between [B, 1, C, H, W] and [B, n_views, C, H, W]
# to obtain correlations over the pixel candidates we perform a mean on the C dimension
correlation = torch.mean(left_feature * right_padded, dim=2, keepdim=False)
# the final correlation tensor shape will be [B, n_views, H, W]
# where on the i-th position of the n_views dimension we will have
# the correlation value between the left pixel
# and the i-th candidate on the right feature map
return correlation
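# Shape sketch (illustrative): correlating two ``[B, C, H, W]`` feature maps with a
# ``window_size`` of ``(wy, wx)`` yields ``[B, wy * wx, H, W]`` scores, one per
# candidate position inside the window. Sizes below are arbitrary assumptions.
#
#   left = torch.rand(2, 64, 32, 64)
#   right = torch.rand(2, 64, 32, 64)
#   corr = get_correlation(left, right, window_size=(3, 3))  # -> [2, 9, 32, 64]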
def _check_window_specs(
search_window_1d: Tuple[int, int] = (1, 9),
search_dilate_1d: Tuple[int, int] = (1, 1),
search_window_2d: Tuple[int, int] = (3, 3),
search_dilate_2d: Tuple[int, int] = (1, 1),
) -> None:
if not np.prod(search_window_1d) == np.prod(search_window_2d):
raise ValueError(
f"The 1D and 2D windows should contain the same number of elements. "
f"1D shape: {search_window_1d} 2D shape: {search_window_2d}"
)
if not np.prod(search_window_1d) % 2 == 1:
raise ValueError(
f"Search windows should contain an odd number of elements in them."
f"Window of shape {search_window_1d} has {np.prod(search_window_1d)} elements."
)
if not any(size == 1 for size in search_window_1d):
raise ValueError(f"The 1D search window should have at least one size equal to 1. 1D shape: {search_window_1d}")
if any(size == 1 for size in search_window_2d):
raise ValueError(
f"The 2D search window should have all dimensions greater than 1. 2D shape: {search_window_2d}"
)
if any(dilate < 1 for dilate in search_dilate_1d):
raise ValueError(
f"The 1D search dilation should have all elements equal or greater than 1. 1D shape: {search_dilate_1d}"
)
if any(dilate < 1 for dilate in search_dilate_2d):
raise ValueError(
f"The 2D search dilation should have all elements equal greater than 1. 2D shape: {search_dilate_2d}"
)
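# Usage sketch (illustrative): the default windows satisfy every constraint above,
# since (1, 9) and (3, 3) both cover 9 candidates, 9 is odd, the 1D window has a
# size-1 axis and the 2D window has none.
#
#   _check_window_specs()                             # passes silently
#   _check_window_specs(search_window_1d=(1, 7))      # raises ValueError (7 != 9 elements)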
class IterativeCorrelationLayer(nn.Module):
def __init__(
self,
groups: int = 4,
search_window_1d: Tuple[int, int] = (1, 9),
search_dilate_1d: Tuple[int, int] = (1, 1),
search_window_2d: Tuple[int, int] = (3, 3),
search_dilate_2d: Tuple[int, int] = (1, 1),
) -> None:
super().__init__()
_check_window_specs(
search_window_1d=search_window_1d,
search_dilate_1d=search_dilate_1d,
search_window_2d=search_window_2d,
search_dilate_2d=search_dilate_2d,
)
self.search_pixels = np.prod(search_window_1d)
self.groups = groups
# two selection tables for dealing with the small_patch argument in the forward function
self.patch_sizes = {
"2d": [search_window_2d for _ in range(self.groups)],
"1d": [search_window_1d for _ in range(self.groups)],
}
self.dilate_sizes = {
"2d": [search_dilate_2d for _ in range(self.groups)],
"1d": [search_dilate_1d for _ in range(self.groups)],
}
def forward(self, left_feature: Tensor, right_feature: Tensor, flow: Tensor, window_type: str = "1d") -> Tensor:
"""Function that computes 1 pass of non-offsetted Group-Wise correlation"""
coords = make_coords_grid(
left_feature.shape[0], left_feature.shape[2], left_feature.shape[3], device=str(left_feature.device)
)
# we offset the coordinate grid in the flow direction
coords = coords + flow
coords = coords.permute(0, 2, 3, 1)
# resample right features according to off-setted grid
right_feature = grid_sample(right_feature, coords, mode="bilinear", align_corners=True)
# window_type decides on how many axes we perform the candidate search.
# See section 3.1 ``Deformable search window`` & Figure 4 in the paper.
patch_size_list = self.patch_sizes[window_type]
dilate_size_list = self.dilate_sizes[window_type]
# chunking the left and right feature to perform group-wise correlation
# mechanism similar to GroupNorm. See section 3.1 ``Group-wise correlation``.
left_groups = torch.chunk(left_feature, self.groups, dim=1)
right_groups = torch.chunk(right_feature, self.groups, dim=1)
correlations = []
# rather than performing the correlation product over the entire C dimension,
# we use subsets of C to get multiple correlation sets
for i in range(len(patch_size_list)):
correlation = get_correlation(left_groups[i], right_groups[i], patch_size_list[i], dilate_size_list[i])
correlations.append(correlation)
final_correlations = torch.cat(correlations, dim=1)
return final_correlations
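# Shape sketch (illustrative): with the default 4 groups and a (1, 9) search window,
# each group contributes 9 correlation channels, so ``[B, C, H, W]`` feature maps and
# a ``[B, 2, H, W]`` flow produce ``[B, 36, H, W]`` correlations. Sizes are assumptions.
#
#   layer = IterativeCorrelationLayer()
#   left, right = torch.rand(1, 64, 32, 64), torch.rand(1, 64, 32, 64)
#   flow = torch.zeros(1, 2, 32, 64)
#   corr = layer(left, right, flow, window_type="1d")  # -> [1, 36, 32, 64]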
class AttentionOffsetCorrelationLayer(nn.Module):
def __init__(
self,
groups: int = 4,
attention_module: Optional[nn.Module] = None,
search_window_1d: Tuple[int, int] = (1, 9),
search_dilate_1d: Tuple[int, int] = (1, 1),
search_window_2d: Tuple[int, int] = (3, 3),
search_dilate_2d: Tuple[int, int] = (1, 1),
) -> None:
super().__init__()
_check_window_specs(
search_window_1d=search_window_1d,
search_dilate_1d=search_dilate_1d,
search_window_2d=search_window_2d,
search_dilate_2d=search_dilate_2d,
)
# convert to python scalar
self.search_pixels = int(np.prod(search_window_1d))
self.groups = groups
# two selection tables for dealing with the small_patch argument in the forward function
self.patch_sizes = {
"2d": [search_window_2d for _ in range(self.groups)],
"1d": [search_window_1d for _ in range(self.groups)],
}
self.dilate_sizes = {
"2d": [search_dilate_2d for _ in range(self.groups)],
"1d": [search_dilate_1d for _ in range(self.groups)],
}
self.attention_module = attention_module
def forward(
self,
left_feature: Tensor,
right_feature: Tensor,
flow: Tensor,
extra_offset: Tensor,
window_type: str = "1d",
) -> Tensor:
"""Function that computes 1 pass of offsetted Group-Wise correlation
If the class was provided with an attention layer, the left and right feature maps
will be passed through a transformer first
"""
B, C, H, W = left_feature.shape
if self.attention_module is not None:
# prepare for transformer required input shapes
left_feature = left_feature.permute(0, 2, 3, 1).reshape(B, H * W, C)
right_feature = right_feature.permute(0, 2, 3, 1).reshape(B, H * W, C)
# this can be either self attention or cross attention, hence the tuple return
left_feature, right_feature = self.attention_module(left_feature, right_feature)
left_feature = left_feature.reshape(B, H, W, C).permute(0, 3, 1, 2)
right_feature = right_feature.reshape(B, H, W, C).permute(0, 3, 1, 2)
left_groups = torch.chunk(left_feature, self.groups, dim=1)
right_groups = torch.chunk(right_feature, self.groups, dim=1)
num_search_candidates = self.search_pixels
# for each pixel (i, j) we have a number of search candidates
# thus, for each candidate we should have an X-axis and Y-axis offset value
extra_offset = extra_offset.reshape(B, num_search_candidates, 2, H, W).permute(0, 1, 3, 4, 2)
patch_size_list = self.patch_sizes[window_type]
dilate_size_list = self.dilate_sizes[window_type]
group_channels = C // self.groups
correlations = []
for i in range(len(patch_size_list)):
left_group, right_group = left_groups[i], right_groups[i]
patch_size, dilate = patch_size_list[i], dilate_size_list[i]
di_y, di_x = dilate
ps_y, ps_x = patch_size
# define the search based on the window patch shape
ry, rx = ps_y // 2 * di_y, ps_x // 2 * di_x
# base offsets for search (i.e. where to look on the search index)
x_grid, y_grid = torch.meshgrid(
torch.arange(-rx, rx + 1, di_x), torch.arange(-ry, ry + 1, di_y), indexing="xy"
)
x_grid, y_grid = x_grid.to(flow.device), y_grid.to(flow.device)
offsets = torch.stack((x_grid, y_grid))
offsets = offsets.reshape(2, -1).permute(1, 0)
for d in (0, 2, 3):
offsets = offsets.unsqueeze(d)
# extra offsets for search (i.e. deformed search indices, similar in concept to deformable convolutions)
offsets = offsets + extra_offset
coords = (
make_coords_grid(
left_feature.shape[0], left_feature.shape[2], left_feature.shape[3], device=str(left_feature.device)
)
+ flow
)
coords = coords.permute(0, 2, 3, 1).unsqueeze(1)
coords = coords + offsets
coords = coords.reshape(B, -1, W, 2)
right_group = grid_sample(right_group, coords, mode="bilinear", align_corners=True)
# we do not need to perform any window shifting because the grid sample op
# will return a multi-view right based on the num_search_candidates dimension in the offsets
right_group = right_group.reshape(B, group_channels, -1, H, W)
left_group = left_group.reshape(B, group_channels, -1, H, W)
correlation = torch.mean(left_group * right_group, dim=1)
correlations.append(correlation)
final_correlation = torch.cat(correlations, dim=1)
return final_correlation
class AdaptiveGroupCorrelationLayer(nn.Module):
"""
Container for computing various correlation types between a left and right feature map.
This module does not contain any optimisable parameters; it is solely a collection of ops.
We wrap it in a nn.Module for torch.jit.script compatibility.
Adaptive Group Correlation operations from: https://openaccess.thecvf.com/content/CVPR2022/papers/Li_Practical_Stereo_Matching_via_Cascaded_Recurrent_Network_With_Adaptive_Correlation_CVPR_2022_paper.pdf
Canonical reference implementation: https://github.com/megvii-research/CREStereo/blob/master/nets/corr.py
"""
def __init__(
self,
iterative_correlation_layer: IterativeCorrelationLayer,
attention_offset_correlation_layer: AttentionOffsetCorrelationLayer,
) -> None:
super().__init__()
self.iterative_correlation_layer = iterative_correlation_layer
self.attention_offset_correlation_layer = attention_offset_correlation_layer
def forward(
self,
left_features: Tensor,
right_features: Tensor,
flow: torch.Tensor,
extra_offset: Optional[Tensor],
window_type: str = "1d",
iter_mode: bool = False,
) -> Tensor:
if iter_mode or extra_offset is None:
corr = self.iterative_correlation_layer(left_features, right_features, flow, window_type)
else:
corr = self.attention_offset_correlation_layer(
left_features, right_features, flow, extra_offset, window_type
) # type: ignore
return corr
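# Dispatch sketch (illustrative): ``iter_mode=True`` or a missing ``extra_offset``
# routes through the fixed-window iterative correlation; otherwise the deformable,
# offset-based path is used. Sizes below are arbitrary assumptions.
#
#   agcl = AdaptiveGroupCorrelationLayer(
#       iterative_correlation_layer=IterativeCorrelationLayer(),
#       attention_offset_correlation_layer=AttentionOffsetCorrelationLayer(),
#   )
#   left, right = torch.rand(1, 64, 32, 64), torch.rand(1, 64, 32, 64)
#   flow = torch.zeros(1, 2, 32, 64)
#   corr = agcl(left, right, flow, extra_offset=None, iter_mode=True)  # -> [1, 36, 32, 64]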
def elu_feature_map(x: Tensor) -> Tensor:
"""Elu feature map operation from: https://arxiv.org/pdf/2006.16236.pdf"""
return F.elu(x) + 1
class LinearAttention(nn.Module):
"""
Linear attention operation from: https://arxiv.org/pdf/2006.16236.pdf
Canonical implementation reference: https://github.com/idiap/fast-transformers/blob/master/fast_transformers/attention/linear_attention.py
LoFTR implementation reference: https://github.com/zju3dv/LoFTR/blob/2122156015b61fbb650e28b58a958e4d632b1058/src/loftr/loftr_module/linear_attention.py
"""
def __init__(self, eps: float = 1e-6, feature_map_fn: Callable[[Tensor], Tensor] = elu_feature_map) -> None:
super().__init__()
self.eps = eps
self.feature_map_fn = feature_map_fn
def forward(
self,
queries: Tensor,
keys: Tensor,
values: Tensor,
q_mask: Optional[Tensor] = None,
kv_mask: Optional[Tensor] = None,
) -> Tensor:
"""
Args:
queries (torch.Tensor): [N, S1, H, D]
keys (torch.Tensor): [N, S2, H, D]
values (torch.Tensor): [N, S2, H, D]
q_mask (torch.Tensor): [N, S1] (optional)
kv_mask (torch.Tensor): [N, S2] (optional)
Returns:
queried_values (torch.Tensor): [N, S1, H, D]
"""
queries = self.feature_map_fn(queries)
keys = self.feature_map_fn(keys)
if q_mask is not None:
queries = queries * q_mask[:, :, None, None]
if kv_mask is not None:
keys = keys * kv_mask[:, :, None, None]
values = values * kv_mask[:, :, None, None]
# mitigates fp16 overflows
values_length = values.shape[1]
values = values / values_length
kv = torch.einsum("NSHD, NSHV -> NHDV", keys, values)
z = 1 / (torch.einsum("NLHD, NHD -> NLH", queries, keys.sum(dim=1)) + self.eps)
# rescale at the end to account for fp16 mitigation
queried_values = torch.einsum("NLHD, NHDV, NLH -> NLHV", queries, kv, z) * values_length
return queried_values
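# Shape sketch (illustrative): linear attention avoids materialising the full S1 x S2
# attention matrix; queries of shape [N, S1, H, D] attend over keys/values of shape
# [N, S2, H, D] and return [N, S1, H, D]. Sizes below are arbitrary assumptions.
#
#   attn = LinearAttention()
#   q, k, v = torch.rand(2, 100, 8, 32), torch.rand(2, 150, 8, 32), torch.rand(2, 150, 8, 32)
#   out = attn(q, k, v)  # -> [2, 100, 8, 32]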
class SoftmaxAttention(nn.Module):
"""
A simple softmax attention operation
LoFTR implementation reference: https://github.com/zju3dv/LoFTR/blob/2122156015b61fbb650e28b58a958e4d632b1058/src/loftr/loftr_module/linear_attention.py
"""
def __init__(self, dropout: float = 0.0) -> None:
super().__init__()
self.dropout = nn.Dropout(dropout) if dropout else nn.Identity()
def forward(
self,
queries: Tensor,
keys: Tensor,
values: Tensor,
q_mask: Optional[Tensor] = None,
kv_mask: Optional[Tensor] = None,
) -> Tensor:
"""
Computes classical softmax full-attention between all queries and keys.
Args:
queries (torch.Tensor): [N, S1, H, D]
keys (torch.Tensor): [N, S2, H, D]
values (torch.Tensor): [N, S2, H, D]
q_mask (torch.Tensor): [N, S1] (optional)
kv_mask (torch.Tensor): [N, S2] (optional)
Returns:
queried_values: [N, S1, H, D]
"""
scale_factor = 1.0 / queries.shape[3] ** 0.5  # 1 / sqrt(D) scaling
queries = queries * scale_factor
qk = torch.einsum("NLHD, NSHD -> NLSH", queries, keys)
if kv_mask is not None and q_mask is not None:
qk.masked_fill_(~(q_mask[:, :, None, None] * kv_mask[:, None, :, None]), float("-inf"))
attention = torch.softmax(qk, dim=2)
attention = self.dropout(attention)
queried_values = torch.einsum("NLSH, NSHD -> NLHD", attention, values)
return queried_values
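# Shape sketch (illustrative): the softmax variant materialises the full [N, S1, S2, H]
# attention matrix, trading memory for exact attention compared to ``LinearAttention``;
# the input/output shapes are the same as in the sketch above.
#
#   attn = SoftmaxAttention(dropout=0.0)
#   out = attn(torch.rand(2, 100, 8, 32), torch.rand(2, 150, 8, 32), torch.rand(2, 150, 8, 32))
#   # -> [2, 100, 8, 32]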
class PositionalEncodingSine(nn.Module):
"""
Sinusoidal positional encodings
Using the scaling term from https://github.com/megvii-research/CREStereo/blob/master/nets/attention/position_encoding.py
Reference implementation from https://github.com/facebookresearch/detr/blob/8a144f83a287f4d3fece4acdf073f387c5af387d/models/position_encoding.py#L28-L48
"""
def __init__(self, dim_model: int, max_size: int = 256) -> None:
super().__init__()
self.dim_model = dim_model
self.max_size = max_size
# pre-registered for memory efficiency during forward pass
pe = self._make_pe_of_size(self.max_size)
self.register_buffer("pe", pe)
def _make_pe_of_size(self, size: int) -> Tensor:
pe = torch.zeros((self.dim_model, *(size, size)), dtype=torch.float32)
y_positions = torch.ones((size, size)).cumsum(0).float().unsqueeze(0)
x_positions = torch.ones((size, size)).cumsum(1).float().unsqueeze(0)
div_term = torch.exp(torch.arange(0.0, self.dim_model // 2, 2) * (-math.log(10000.0) / (self.dim_model // 2)))
div_term = div_term[:, None, None]
pe[0::4, :, :] = torch.sin(x_positions * div_term)
pe[1::4, :, :] = torch.cos(x_positions * div_term)
pe[2::4, :, :] = torch.sin(y_positions * div_term)
pe[3::4, :, :] = torch.cos(y_positions * div_term)
pe = pe.unsqueeze(0)
return pe
def forward(self, x: Tensor) -> Tensor:
"""
Args:
x: [B, C, H, W]
Returns:
x: [B, C, H, W]
"""
torch._assert(
len(x.shape) == 4,
f"PositionalEncodingSine requires a 4-D dimensional input. Provided tensor is of shape {x.shape}",
)
B, C, H, W = x.shape
return x + self.pe[:, :, :H, :W] # type: ignore
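# Usage sketch (illustrative): the module pre-computes a [1, dim_model, max_size, max_size]
# sine/cosine grid and adds its top-left [H, W] crop to the input, so C must equal
# ``dim_model`` and H, W must not exceed ``max_size``. Sizes below are assumptions.
#
#   pos_enc = PositionalEncodingSine(dim_model=256, max_size=256)
#   out = pos_enc(torch.rand(1, 256, 32, 64))  # -> [1, 256, 32, 64]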
class LocalFeatureEncoderLayer(nn.Module):
"""
LoFTR transformer module from: https://arxiv.org/pdf/2104.00680.pdf
Canonical implementations at: https://github.com/zju3dv/LoFTR/blob/master/src/loftr/loftr_module/transformer.py
"""
def __init__(
self,
*,
dim_model: int,
num_heads: int,
attention_module: Callable[..., nn.Module] = LinearAttention,
) -> None:
super().__init__()
self.attention_op = attention_module()
if not isinstance(self.attention_op, (LinearAttention, SoftmaxAttention)):
raise ValueError(
f"attention_module must be an instance of LinearAttention or SoftmaxAttention. Got {type(self.attention_op)}"
)
self.dim_head = dim_model // num_heads
self.num_heads = num_heads
# multi-head attention
self.query_proj = nn.Linear(dim_model, dim_model, bias=False)
self.key_proj = nn.Linear(dim_model, dim_model, bias=False)
self.value_proj = nn.Linear(dim_model, dim_model, bias=False)
self.merge = nn.Linear(dim_model, dim_model, bias=False)
# feed forward network
self.ffn = nn.Sequential(
nn.Linear(dim_model * 2, dim_model * 2, bias=False),
nn.ReLU(),
nn.Linear(dim_model * 2, dim_model, bias=False),
)
# norm layers
self.attention_norm = nn.LayerNorm(dim_model)
self.ffn_norm = nn.LayerNorm(dim_model)
def forward(
self, x: Tensor, source: Tensor, x_mask: Optional[Tensor] = None, source_mask: Optional[Tensor] = None
) -> Tensor:
"""
Args:
x (torch.Tensor): [B, S1, D]
source (torch.Tensor): [B, S2, D]
x_mask (torch.Tensor): [B, S1] (optional)
source_mask (torch.Tensor): [B, S2] (optional)
"""
B, S, D = x.shape
queries, keys, values = x, source, source
queries = self.query_proj(queries).reshape(B, S, self.num_heads, self.dim_head)
keys = self.key_proj(keys).reshape(B, S, self.num_heads, self.dim_head)
values = self.value_proj(values).reshape(B, S, self.num_heads, self.dim_head)
# attention operation
message = self.attention_op(queries, keys, values, x_mask, source_mask)
# concatenating attention heads together before passing through projection layer
message = self.merge(message.reshape(B, S, D))
message = self.attention_norm(message)
# ffn operation
message = self.ffn(torch.cat([x, message], dim=2))
message = self.ffn_norm(message)
return x + message
class LocalFeatureTransformer(nn.Module):
"""
LoFTR transformer module from: https://arxiv.org/pdf/2104.00680.pdf
Canonical implementations at: https://github.com/zju3dv/LoFTR/blob/master/src/loftr/loftr_module/transformer.py
"""
def __init__(
self,
*,
dim_model: int,
num_heads: int,
attention_directions: List[str],
attention_module: Callable[..., nn.Module] = LinearAttention,
) -> None:
super(LocalFeatureTransformer, self).__init__()
self.attention_module = attention_module
self.attention_directions = attention_directions
for direction in attention_directions:
if direction not in ["self", "cross"]:
raise ValueError(
f"Attention direction {direction} unsupported. LocalFeatureTransformer accepts only ``attention_type`` in ``[self, cross]``."
)
self.layers = nn.ModuleList(
[
LocalFeatureEncoderLayer(dim_model=dim_model, num_heads=num_heads, attention_module=attention_module)
for _ in attention_directions
]
)
def forward(
self,
left_features: Tensor,
right_features: Tensor,
left_mask: Optional[Tensor] = None,
right_mask: Optional[Tensor] = None,
) -> Tuple[Tensor, Tensor]:
"""
Args:
left_features (torch.Tensor): [N, S1, D]
right_features (torch.Tensor): [N, S2, D]
left_mask (torch.Tensor): [N, S1] (optional)
right_mask (torch.Tensor): [N, S2] (optional)
Returns:
left_features (torch.Tensor): [N, S1, D]
right_features (torch.Tensor): [N, S2, D]
"""
torch._assert(
left_features.shape[2] == right_features.shape[2],
f"left_features and right_features should have the same embedding dimensions. left_features: {left_features.shape[2]} right_features: {right_features.shape[2]}",
)
for idx, layer in enumerate(self.layers):
attention_direction = self.attention_directions[idx]
if attention_direction == "self":
left_features = layer(left_features, left_features, left_mask, left_mask)
right_features = layer(right_features, right_features, right_mask, right_mask)
elif attention_direction == "cross":
left_features = layer(left_features, right_features, left_mask, right_mask)
right_features = layer(right_features, left_features, right_mask, left_mask)
return left_features, right_features
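# Usage sketch (illustrative): a transformer with one self- and one cross-attention
# layer. Both feature maps are flattened to equal-length [N, S, D] sequences (as done
# in ``CREStereo.forward``) and keep their shape. Sizes below are assumptions.
#
#   transformer = LocalFeatureTransformer(
#       dim_model=256, num_heads=8, attention_directions=["self", "cross"]
#   )
#   left, right = torch.rand(1, 32 * 64, 256), torch.rand(1, 32 * 64, 256)
#   left, right = transformer(left, right)  # shapes unchanged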
class PyramidDownsample(nn.Module):
"""
A simple wrapper that returns an average-pooling feature pyramid based on the provided scales.
Implicitly returns the input as well.
"""
def __init__(self, factors: Iterable[int]) -> None:
super().__init__()
self.factors = factors
def forward(self, x: torch.Tensor) -> List[Tensor]:
results = [x]
for factor in self.factors:
results.append(F.avg_pool2d(x, kernel_size=factor, stride=factor))
return results
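# Shape sketch (illustrative): with factors (2, 4) the module returns the input plus two
# average-pooled copies, i.e. a 3-level pyramid. Sizes below are assumptions.
#
#   pyramid = PyramidDownsample(factors=(2, 4))
#   levels = pyramid(torch.rand(1, 256, 64, 128))
#   # shapes: [1, 256, 64, 128], [1, 256, 32, 64], [1, 256, 16, 32]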
class CREStereo(nn.Module):
"""
Implements CREStereo from the `"Practical Stereo Matching via Cascaded Recurrent Network
With Adaptive Correlation" <https://openaccess.thecvf.com/content/CVPR2022/papers/Li_Practical_Stereo_Matching_via_Cascaded_Recurrent_Network_With_Adaptive_Correlation_CVPR_2022_paper.pdf>`_ paper.
Args:
feature_encoder (raft.FeatureEncoder): Raft-like Feature Encoder module that extracts low-level features from the inputs.
update_block (raft.UpdateBlock): Raft-like Update Block which recursively refines a flow-map.
flow_head (raft.FlowHead): Raft-like Flow Head which predicts a flow-map from its inputs.
self_attn_block (LocalFeatureTransformer): A Local Feature Transformer that performs self attention on the two feature maps.
cross_attn_block (LocalFeatureTransformer): A Local Feature Transformer that performs cross attention between the two feature maps
used in the Adaptive Group Correlation module.
feature_downsample_rates (List[int]): The downsample rates used to build a feature pyramid from the outputs of the `feature_encoder`. Default: [2, 4]
correlation_groups (int): The number of groups the features are split into when computing per-pixel correlation. Default: 4.
search_window_1d (Tuple[int, int]): The alternate search window size in the x and y directions for the 1D case. Defaults to (1, 9).
search_dilate_1d (Tuple[int, int]): The dilation used in the `search_window_1d` when selecting pixels. Similar to `nn.Conv2d` dilate. Defaults to (1, 1).
search_window_2d (Tuple[int, int]): The alternate search window size in the x and y directions for the 2D case. Defaults to (3, 3).
search_dilate_2d (Tuple[int, int]): The dilation used in the `search_window_2d` when selecting pixels. Similar to `nn.Conv2d` dilate. Defaults to (1, 1).
"""
def __init__(
self,
*,
feature_encoder: raft.FeatureEncoder,
update_block: raft.UpdateBlock,
flow_head: raft.FlowHead,
self_attn_block: LocalFeatureTransformer,
cross_attn_block: LocalFeatureTransformer,
feature_downsample_rates: Tuple[int, ...] = (2, 4),
correlation_groups: int = 4,
search_window_1d: Tuple[int, int] = (1, 9),
search_dilate_1d: Tuple[int, int] = (1, 1),
search_window_2d: Tuple[int, int] = (3, 3),
search_dilate_2d: Tuple[int, int] = (1, 1),
) -> None:
super().__init__()
self.output_channels = 2
self.feature_encoder = feature_encoder
self.update_block = update_block
self.flow_head = flow_head
self.self_attn_block = self_attn_block
# average pooling for the feature encoder outputs
self.downsampling_pyramid = PyramidDownsample(feature_downsample_rates)
self.downsampling_factors: List[int] = [feature_encoder.downsample_factor]
base_downsample_factor: int = self.downsampling_factors[0]
for rate in feature_downsample_rates:
self.downsampling_factors.append(base_downsample_factor * rate)
# output resolution tracking
self.resolutions: List[str] = [f"1 / {factor}" for factor in self.downsampling_factors]
self.search_pixels = int(np.prod(search_window_1d))
# flow convex upsampling mask predictor
self.mask_predictor = ConvexMaskPredictor(
in_channels=feature_encoder.output_dim // 2,
hidden_size=feature_encoder.output_dim,
upsample_factor=feature_encoder.downsample_factor,
multiplier=0.25,
)
# offset modules for offset-based feature selection
self.offset_convs = nn.ModuleDict()
self.correlation_layers = nn.ModuleDict()
offset_conv_layer = partial(
Conv2dNormActivation,
in_channels=feature_encoder.output_dim,
out_channels=self.search_pixels * 2,
norm_layer=None,
activation_layer=None,
)
# populate the dicts in top to bottom order
# useful for iterating through torch.jit.script module given the network forward pass
#
# Ignore the largest resolution. We handle that separately due to torch.jit.script
# not being able to access runtime-generated keys in ModuleDicts.
# This way, we can keep a generic way of processing all pyramid levels except
# the final one
iterative_correlation_layer = partial(
IterativeCorrelationLayer,
groups=correlation_groups,
search_window_1d=search_window_1d,
search_dilate_1d=search_dilate_1d,
search_window_2d=search_window_2d,
search_dilate_2d=search_dilate_2d,
)
attention_offset_correlation_layer = partial(
AttentionOffsetCorrelationLayer,
groups=correlation_groups,
search_window_1d=search_window_1d,
search_dilate_1d=search_dilate_1d,
search_window_2d=search_window_2d,
search_dilate_2d=search_dilate_2d,
)
for idx, resolution in enumerate(reversed(self.resolutions[1:])):
# the largest resolution does not use offset convolutions for sampling grid coords
offset_conv = None if idx == len(self.resolutions) - 1 else offset_conv_layer()
if offset_conv:
self.offset_convs[resolution] = offset_conv
# only the lowest resolution uses the cross attention module when computing correlation scores
attention_module = cross_attn_block if idx == 0 else None
self.correlation_layers[resolution] = AdaptiveGroupCorrelationLayer(
iterative_correlation_layer=iterative_correlation_layer(),
attention_offset_correlation_layer=attention_offset_correlation_layer(
attention_module=attention_module
),
)
# correlation layer for the largest resolution
self.max_res_correlation_layer = AdaptiveGroupCorrelationLayer(
iterative_correlation_layer=iterative_correlation_layer(),
attention_offset_correlation_layer=attention_offset_correlation_layer(),
)
# simple 2D Postional Encodings
self.positional_encodings = PositionalEncodingSine(feature_encoder.output_dim)
def _get_window_type(self, iteration: int) -> str:
return "1d" if iteration % 2 == 0 else "2d"
def forward(
self, left_image: Tensor, right_image: Tensor, flow_init: Optional[Tensor] = None, num_iters: int = 10
) -> List[Tensor]:
features = torch.cat([left_image, right_image], dim=0)
features = self.feature_encoder(features)
left_features, right_features = features.chunk(2, dim=0)
# update block network state and input context are derived from the left feature map
net, ctx = left_features.chunk(2, dim=1)
net = torch.tanh(net)
ctx = torch.relu(ctx)
# these calls each output a list of tensors.
l_pyramid = self.downsampling_pyramid(left_features)
r_pyramid = self.downsampling_pyramid(right_features)
net_pyramid = self.downsampling_pyramid(net)
ctx_pyramid = self.downsampling_pyramid(ctx)
# we store in reversed order because we process the pyramid from top to bottom
l_pyramid = {res: l_pyramid[idx] for idx, res in enumerate(self.resolutions)}
r_pyramid = {res: r_pyramid[idx] for idx, res in enumerate(self.resolutions)}
net_pyramid = {res: net_pyramid[idx] for idx, res in enumerate(self.resolutions)}
ctx_pyramid = {res: ctx_pyramid[idx] for idx, res in enumerate(self.resolutions)}
# offsets for sampling pixel candidates in the correlation ops
offsets: Dict[str, Tensor] = {}
for resolution, offset_conv in self.offset_convs.items():
feature_map = l_pyramid[resolution]
offset = offset_conv(feature_map)
offsets[resolution] = (torch.sigmoid(offset) - 0.5) * 2.0
# the smallest resolution is prepared for passing through self attention
min_res = self.resolutions[-1]
max_res = self.resolutions[0]
B, C, MIN_H, MIN_W = l_pyramid[min_res].shape
# add positional encodings
l_pyramid[min_res] = self.positional_encodings(l_pyramid[min_res])
r_pyramid[min_res] = self.positional_encodings(r_pyramid[min_res])
# reshaping for transformer
l_pyramid[min_res] = l_pyramid[min_res].permute(0, 2, 3, 1).reshape(B, MIN_H * MIN_W, C)
r_pyramid[min_res] = r_pyramid[min_res].permute(0, 2, 3, 1).reshape(B, MIN_H * MIN_W, C)
# perform self attention
l_pyramid[min_res], r_pyramid[min_res] = self.self_attn_block(l_pyramid[min_res], r_pyramid[min_res])
# now we need to reshape back into [B, C, H, W] format
l_pyramid[min_res] = l_pyramid[min_res].reshape(B, MIN_H, MIN_W, C).permute(0, 3, 1, 2)
r_pyramid[min_res] = r_pyramid[min_res].reshape(B, MIN_H, MIN_W, C).permute(0, 3, 1, 2)
predictions: List[Tensor] = []
flow_estimates: Dict[str, Tensor] = {}
# we added this because of torch.jit.script
# also, the prediction prior is always going to have the
# spatial size of the features output by the feature encoder
flow_pred_prior: Tensor = torch.empty(
size=(B, 2, left_features.shape[2], left_features.shape[3]),
dtype=l_pyramid[max_res].dtype,
device=l_pyramid[max_res].device,
)
if flow_init is not None:
scale = l_pyramid[max_res].shape[2] / flow_init.shape[2]
# in the CREStereo implementation they multiply by -scale instead of scale
# this can be either a downsample or an upsample based on the cascaded inference
# configuration
# we use a -scale because the flow used inside the network is a negative flow
# from the right to the left, so we flip the flow direction
flow_estimates[max_res] = -scale * F.interpolate(
input=flow_init,
size=l_pyramid[max_res].shape[2:],
mode="bilinear",
align_corners=True,
)
# when not provided with a flow prior, we construct one using the lower resolution maps
else:
# initialize a zero flow with the smallest resolution
flow = torch.zeros(size=(B, 2, MIN_H, MIN_W), device=left_features.device, dtype=left_features.dtype)
# flows from coarse resolutions are refined similarly
# we always need to fetch the next pyramid feature map as well
# when updating coarse resolutions, therefore we create a reversed
# view which has its order synced with the ModuleDict keys iterator
coarse_resolutions: List[str] = self.resolutions[::-1] # using slicing because of torch.jit.script
fine_grained_resolution = max_res
# set the coarsest flow to the zero flow
flow_estimates[coarse_resolutions[0]] = flow
# the correlation layer ModuleDict will contain layers ordered from coarse to fine resolution
# i.e ["1 / 16", "1 / 8", "1 / 4"]
# the correlation layer ModuleDict has layers for all the resolutions except the fine one
# i.e {"1 / 16": Module, "1 / 8": Module}
# for these resolutions we perform only half of the number of refinement iterations
for idx, (resolution, correlation_layer) in enumerate(self.correlation_layers.items()):
# compute the scale difference between the first pyramid scale and the current pyramid scale
scale_to_base = l_pyramid[fine_grained_resolution].shape[2] // l_pyramid[resolution].shape[2]
for it in range(num_iters // 2):
# set whether we want to search on (X, Y) axes for correlation or just on X axis
window_type = self._get_window_type(it)
# we consider this a prior, therefore we do not want to back-propagate through it
flow_estimates[resolution] = flow_estimates[resolution].detach()
correlations = correlation_layer(
l_pyramid[resolution], # left
r_pyramid[resolution], # right
flow_estimates[resolution],
offsets[resolution],
window_type,
)
# update the recurrent network state and the flow deltas
net_pyramid[resolution], delta_flow = self.update_block(
net_pyramid[resolution], ctx_pyramid[resolution], correlations, flow_estimates[resolution]
)
# the convex upsampling weights are computed w.r.t.
# the recurrent update state
up_mask = self.mask_predictor(net_pyramid[resolution])
flow_estimates[resolution] = flow_estimates[resolution] + delta_flow
# convex upsampling with the initial feature encoder downsampling rate
flow_pred_prior = upsample_flow(
flow_estimates[resolution], up_mask, factor=self.downsampling_factors[0]
)
# we then bilinear upsample to the final resolution
# we use a factor that's equivalent to the difference between
# the current downsample resolution and the base downsample resolution
#
# i.e. if a 1 / 16 flow is upsampled by 4 (base downsampling) we get a 1 / 4 flow.
# therefore we have to further upscale it by the difference between
# the current level 1 / 16 and the base level 1 / 4.
#
# we use a -scale because the flow used inside the network is a negative flow
# from the right to the left, so we flip the flow direction in order to get the
# left to right flow
flow_pred = -upsample_flow(flow_pred_prior, None, factor=scale_to_base)
predictions.append(flow_pred)
# when constructing the next resolution prior, we resample w.r.t.
# the scale of the next level in the pyramid
next_resolution = coarse_resolutions[idx + 1]
scale_to_next = l_pyramid[next_resolution].shape[2] / flow_pred_prior.shape[2]
# we use flow_pred_prior because this is a more accurate estimation of the true flow
# due to the convex upsample, which resembles a learned super-resolution module.
# this is not necessarily an upsample, it can be a downsample, based on the provided configuration
flow_estimates[next_resolution] = -scale_to_next * F.interpolate(
input=flow_pred_prior,
size=l_pyramid[next_resolution].shape[2:],
mode="bilinear",
align_corners=True,
)
# finally we will be doing a full pass through the fine-grained resolution
# this coincides with the maximum resolution
# we keep a separate loop here in order to avoid python control flow
# deciding how many iterations we should run based on the current resolution
# furthermore, if provided with an initial flow, there is no need to generate
# a prior estimate when moving into the final refinement stage
for it in range(num_iters):
search_window_type = self._get_window_type(it)
flow_estimates[max_res] = flow_estimates[max_res].detach()
# we run the fine-grained resolution correlations in iterative mode
# this means that we are using the fixed window pixel selections
# instead of the deformed ones as with the previous steps
correlations = self.max_res_correlation_layer(
l_pyramid[max_res],
r_pyramid[max_res],
flow_estimates[max_res],
extra_offset=None,
window_type=search_window_type,
iter_mode=True,
)
net_pyramid[max_res], delta_flow = self.update_block(
net_pyramid[max_res], ctx_pyramid[max_res], correlations, flow_estimates[max_res]
)
up_mask = self.mask_predictor(net_pyramid[max_res])
flow_estimates[max_res] = flow_estimates[max_res] + delta_flow
# at the final resolution we simply do a convex upsample using the base downsample rate
flow_pred = -upsample_flow(flow_estimates[max_res], up_mask, factor=self.downsampling_factors[0])
predictions.append(flow_pred)
return predictions
def _crestereo(
*,
weights: Optional[WeightsEnum],
progress: bool,
# Feature Encoder
feature_encoder_layers: Tuple[int, int, int, int, int],
feature_encoder_strides: Tuple[int, int, int, int],
feature_encoder_block: Callable[..., nn.Module],
feature_encoder_norm_layer: Callable[..., nn.Module],
# Average Pooling Pyramid
feature_downsample_rates: Tuple[int, ...],
# Adaptive Correlation Layer
corr_groups: int,
corr_search_window_2d: Tuple[int, int],
corr_search_dilate_2d: Tuple[int, int],
corr_search_window_1d: Tuple[int, int],
corr_search_dilate_1d: Tuple[int, int],
# Flow head
flow_head_hidden_size: int,
# Recurrent block
recurrent_block_hidden_state_size: int,
recurrent_block_kernel_size: Tuple[Tuple[int, int], Tuple[int, int]],
recurrent_block_padding: Tuple[Tuple[int, int], Tuple[int, int]],
# Motion Encoder
motion_encoder_corr_layers: Tuple[int, int],
motion_encoder_flow_layers: Tuple[int, int],
motion_encoder_out_channels: int,
# Transformer Blocks
num_attention_heads: int,
num_self_attention_layers: int,
num_cross_attention_layers: int,
self_attention_module: Callable[..., nn.Module],
cross_attention_module: Callable[..., nn.Module],
**kwargs,
) -> CREStereo:
feature_encoder = kwargs.pop("feature_encoder", None) or raft.FeatureEncoder(
block=feature_encoder_block,
layers=feature_encoder_layers,
strides=feature_encoder_strides,
norm_layer=feature_encoder_norm_layer,
)
if feature_encoder.output_dim % corr_groups != 0:
raise ValueError(
f"Final ``feature_encoder_layers`` size should be divisible by ``corr_groups`` argument."
f"Feature encoder output size : {feature_encoder.output_dim}, Correlation groups: {corr_groups}."
)
motion_encoder = kwargs.pop("motion_encoder", None) or raft.MotionEncoder(
in_channels_corr=corr_groups * int(np.prod(corr_search_window_1d)),
corr_layers=motion_encoder_corr_layers,
flow_layers=motion_encoder_flow_layers,
out_channels=motion_encoder_out_channels,
)
out_channels_context = feature_encoder_layers[-1] - recurrent_block_hidden_state_size
recurrent_block = kwargs.pop("recurrent_block", None) or raft.RecurrentBlock(
input_size=motion_encoder.out_channels + out_channels_context,
hidden_size=recurrent_block_hidden_state_size,
kernel_size=recurrent_block_kernel_size,
padding=recurrent_block_padding,
)
flow_head = kwargs.pop("flow_head", None) or raft.FlowHead(
in_channels=out_channels_context, hidden_size=flow_head_hidden_size
)
update_block = raft.UpdateBlock(motion_encoder=motion_encoder, recurrent_block=recurrent_block, flow_head=flow_head)
self_attention_module = kwargs.pop("self_attention_module", None) or LinearAttention
self_attn_block = LocalFeatureTransformer(
dim_model=feature_encoder.output_dim,
num_heads=num_attention_heads,
attention_directions=["self"] * num_self_attention_layers,
attention_module=self_attention_module,
)
cross_attention_module = kwargs.pop("cross_attention_module", None) or LinearAttention
cross_attn_block = LocalFeatureTransformer(
dim_model=feature_encoder.output_dim,
num_heads=num_attention_heads,
attention_directions=["cross"] * num_cross_attention_layers,
attention_module=cross_attention_module,
)
model = CREStereo(
feature_encoder=feature_encoder,
update_block=update_block,
flow_head=flow_head,
self_attn_block=self_attn_block,
cross_attn_block=cross_attn_block,
feature_downsample_rates=feature_downsample_rates,
correlation_groups=corr_groups,
search_window_1d=corr_search_window_1d,
search_window_2d=corr_search_window_2d,
search_dilate_1d=corr_search_dilate_1d,
search_dilate_2d=corr_search_dilate_2d,
)
if weights is not None:
model.load_state_dict(weights.get_state_dict(progress=progress, check_hash=True))
return model
_COMMON_META = {
"resize_size": (384, 512),
}
class CREStereo_Base_Weights(WeightsEnum):
"""The metrics reported here are as follows.
``mae`` is the "mean-average-error" and indicates how far (in pixels) the
predicted disparity is from its true value (equivalent to ``epe``). This is averaged over all pixels
of all images. ``1px``, ``3px`` and ``5px`` indicate the percentage of pixels whose error against the
ground truth is lower than 1, 3 and 5 pixels respectively. ``relepe`` is the "relative-end-point-error" and is the
average ``epe`` divided by the average ground truth disparity. ``fl-all`` corresponds to the percentage of pixels whose epe
is either <3px, or whose ``relepe`` is lower than 0.05 (therefore higher is better).
"""
MEGVII_V1 = Weights(
# Weights ported from https://github.com/megvii-research/CREStereo
url="https://download.pytorch.org/models/crestereo-756c8b0f.pth",
transforms=StereoMatching,
meta={
**_COMMON_META,
"num_params": 5432948,
"recipe": "https://github.com/megvii-research/CREStereo",
"_metrics": {
"Middlebury2014-train": {
# metrics for 10 refinement iterations and 1 cascade
"mae": 0.792,
"rmse": 2.765,
"1px": 0.905,
"3px": 0.958,
"5px": 0.97,
"relepe": 0.114,
"fl-all": 90.429,
"_detailed": {
# 1 is the number of cascades
1: {
# 2 is the number of refinement iterations
2: {
"mae": 1.704,
"rmse": 3.738,
"1px": 0.738,
"3px": 0.896,
"5px": 0.933,
"relepe": 0.157,
"fl-all": 76.464,
},
5: {
"mae": 0.956,
"rmse": 2.963,
"1px": 0.88,
"3px": 0.948,
"5px": 0.965,
"relepe": 0.124,
"fl-all": 88.186,
},
10: {
"mae": 0.792,
"rmse": 2.765,
"1px": 0.905,
"3px": 0.958,
"5px": 0.97,
"relepe": 0.114,
"fl-all": 90.429,
},
20: {
"mae": 0.749,
"rmse": 2.706,
"1px": 0.907,
"3px": 0.961,
"5px": 0.972,
"relepe": 0.113,
"fl-all": 90.807,
},
},
2: {
2: {
"mae": 1.702,
"rmse": 3.784,
"1px": 0.784,
"3px": 0.894,
"5px": 0.924,
"relepe": 0.172,
"fl-all": 80.313,
},
5: {
"mae": 0.932,
"rmse": 2.907,
"1px": 0.877,
"3px": 0.944,
"5px": 0.963,
"relepe": 0.125,
"fl-all": 87.979,
},
10: {
"mae": 0.773,
"rmse": 2.768,
"1px": 0.901,
"3px": 0.958,
"5px": 0.972,
"relepe": 0.117,
"fl-all": 90.43,
},
20: {
"mae": 0.854,
"rmse": 2.971,
"1px": 0.9,
"3px": 0.957,
"5px": 0.97,
"relepe": 0.122,
"fl-all": 90.269,
},
},
},
}
},
"_docs": """These weights were ported from the original paper. They
are trained on a dataset mixture of the author's choice.""",
},
)
CRESTEREO_ETH_MBL_V1 = Weights(
# Weights ported from https://github.com/megvii-research/CREStereo
url="https://download.pytorch.org/models/crestereo-8f0e0e9a.pth",
transforms=StereoMatching,
meta={
**_COMMON_META,
"num_params": 5432948,
"recipe": "https://github.com/pytorch/vision/tree/main/references/depth/stereo",
"_metrics": {
"Middlebury2014-train": {
# metrics for 10 refinement iterations and 1 cascade
"mae": 1.416,
"rmse": 3.53,
"1px": 0.777,
"3px": 0.896,
"5px": 0.933,
"relepe": 0.148,
"fl-all": 78.388,
"_detailed": {
# 1 is the number of cascades
1: {
# 2 is the number of refinement iterations
2: {
"mae": 2.363,
"rmse": 4.352,
"1px": 0.611,
"3px": 0.828,
"5px": 0.891,
"relepe": 0.176,
"fl-all": 64.511,
},
5: {
"mae": 1.618,
"rmse": 3.71,
"1px": 0.761,
"3px": 0.879,
"5px": 0.918,
"relepe": 0.154,
"fl-all": 77.128,
},
10: {
"mae": 1.416,
"rmse": 3.53,
"1px": 0.777,
"3px": 0.896,
"5px": 0.933,
"relepe": 0.148,
"fl-all": 78.388,
},
20: {
"mae": 1.448,
"rmse": 3.583,
"1px": 0.771,
"3px": 0.893,
"5px": 0.931,
"relepe": 0.145,
"fl-all": 77.7,
},
},
2: {
2: {
"mae": 1.972,
"rmse": 4.125,
"1px": 0.73,
"3px": 0.865,
"5px": 0.908,
"relepe": 0.169,
"fl-all": 74.396,
},
5: {
"mae": 1.403,
"rmse": 3.448,
"1px": 0.793,
"3px": 0.905,
"5px": 0.937,
"relepe": 0.151,
"fl-all": 80.186,
},
10: {
"mae": 1.312,
"rmse": 3.368,
"1px": 0.799,
"3px": 0.912,
"5px": 0.943,
"relepe": 0.148,
"fl-all": 80.379,
},
20: {
"mae": 1.376,
"rmse": 3.542,
"1px": 0.796,
"3px": 0.91,
"5px": 0.942,
"relepe": 0.149,
"fl-all": 80.054,
},
},
},
}
},
"_docs": """These weights were trained from scratch on
:class:`~torchvision.datasets._stereo_matching.CREStereo` +
:class:`~torchvision.datasets._stereo_matching.Middlebury2014Stereo` +
:class:`~torchvision.datasets._stereo_matching.ETH3DStereo`.""",
},
)
CRESTEREO_FINETUNE_MULTI_V1 = Weights(
# Weights ported from https://github.com/megvii-research/CREStereo
url="https://download.pytorch.org/models/crestereo-697c38f4.pth ",
transforms=StereoMatching,
meta={
**_COMMON_META,
"num_params": 5432948,
"recipe": "https://github.com/pytorch/vision/tree/main/references/depth/stereo",
"_metrics": {
"Middlebury2014-train": {
# metrics for 10 refinement iterations and 1 cascade
"mae": 1.038,
"rmse": 3.108,
"1px": 0.852,
"3px": 0.942,
"5px": 0.963,
"relepe": 0.129,
"fl-all": 85.522,
"_detailed": {
# 1 is the number of cascades
1: {
# 2 is the number of refinement iterations
2: {
"mae": 1.85,
"rmse": 3.797,
"1px": 0.673,
"3px": 0.862,
"5px": 0.917,
"relepe": 0.171,
"fl-all": 69.736,
},
5: {
"mae": 1.111,
"rmse": 3.166,
"1px": 0.838,
"3px": 0.93,
"5px": 0.957,
"relepe": 0.134,
"fl-all": 84.596,
},
10: {
"mae": 1.02,
"rmse": 3.073,
"1px": 0.854,
"3px": 0.938,
"5px": 0.96,
"relepe": 0.129,
"fl-all": 86.042,
},
20: {
"mae": 0.993,
"rmse": 3.059,
"1px": 0.855,
"3px": 0.942,
"5px": 0.967,
"relepe": 0.126,
"fl-all": 85.784,
},
},
2: {
2: {
"mae": 1.667,
"rmse": 3.867,
"1px": 0.78,
"3px": 0.891,
"5px": 0.922,
"relepe": 0.165,
"fl-all": 78.89,
},
5: {
"mae": 1.158,
"rmse": 3.278,
"1px": 0.843,
"3px": 0.926,
"5px": 0.955,
"relepe": 0.135,
"fl-all": 84.556,
},
10: {
"mae": 1.046,
"rmse": 3.13,
"1px": 0.85,
"3px": 0.934,
"5px": 0.96,
"relepe": 0.13,
"fl-all": 85.464,
},
20: {
"mae": 1.021,
"rmse": 3.102,
"1px": 0.85,
"3px": 0.935,
"5px": 0.963,
"relepe": 0.129,
"fl-all": 85.417,
},
},
},
},
},
"_docs": """These weights were finetuned on a mixture of
:class:`~torchvision.datasets._stereo_matching.CREStereo` +
:class:`~torchvision.datasets._stereo_matching.Middlebury2014Stereo` +
:class:`~torchvision.datasets._stereo_matching.ETH3DStereo` +
:class:`~torchvision.datasets._stereo_matching.InStereo2k` +
:class:`~torchvision.datasets._stereo_matching.CarlaStereo` +
:class:`~torchvision.datasets._stereo_matching.SintelStereo` +
:class:`~torchvision.datasets._stereo_matching.FallingThingsStereo`.""",
},
)
DEFAULT = MEGVII_V1
@register_model()
@handle_legacy_interface(weights=("pretrained", CREStereo_Base_Weights.MEGVII_V1))
def crestereo_base(*, weights: Optional[CREStereo_Base_Weights] = None, progress=True, **kwargs) -> CREStereo:
"""CREStereo model from
`Practical Stereo Matching via Cascaded Recurrent Network
With Adaptive Correlation <https://openaccess.thecvf.com/content/CVPR2022/papers/Li_Practical_Stereo_Matching_via_Cascaded_Recurrent_Network_With_Adaptive_Correlation_CVPR_2022_paper.pdf>`_.
Please see the example below for a tutorial on how to use this model.
Args:
weights(:class:`~torchvision.prototype.models.depth.stereo.CREStereo_Base_Weights`, optional): The
pretrained weights to use. See
:class:`~torchvision.prototype.models.depth.stereo.CREStereo_Base_Weights`
below for more details, and possible values. By default, no
pre-trained weights are used.
progress (bool): If True, displays a progress bar of the download to stderr. Default is True.
**kwargs: parameters passed to the ``torchvision.prototype.models.depth.stereo.crestereo.CREStereo``
base class. Please refer to the `source code
<https://github.com/pytorch/vision/blob/main/torchvision/models/optical_flow/crestereo.py>`_
for more details about this class.
.. autoclass:: torchvision.prototype.models.depth.stereo.CREStereo_Base_Weights
:members:
"""
weights = CREStereo_Base_Weights.verify(weights)
return _crestereo(
weights=weights,
progress=progress,
# Feature encoder
feature_encoder_layers=(64, 64, 96, 128, 256),
feature_encoder_strides=(2, 1, 2, 1),
feature_encoder_block=partial(raft.ResidualBlock, always_project=True),
feature_encoder_norm_layer=nn.InstanceNorm2d,
# Average pooling pyramid
feature_downsample_rates=(2, 4),
# Motion encoder
motion_encoder_corr_layers=(256, 192),
motion_encoder_flow_layers=(128, 64),
motion_encoder_out_channels=128,
# Recurrent block
recurrent_block_hidden_state_size=128,
recurrent_block_kernel_size=((1, 5), (5, 1)),
recurrent_block_padding=((0, 2), (2, 0)),
# Flow head
flow_head_hidden_size=256,
# Transformer blocks
num_attention_heads=8,
num_self_attention_layers=1,
num_cross_attention_layers=1,
self_attention_module=LinearAttention,
cross_attention_module=LinearAttention,
# Adaptive Correlation layer
corr_groups=4,
corr_search_window_2d=(3, 3),
corr_search_dilate_2d=(1, 1),
corr_search_window_1d=(1, 9),
corr_search_dilate_1d=(1, 1),
)
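# Usage sketch (illustrative): building the model with the ported Megvii weights and running
# cascaded refinement on a stereo pair. The input size is an assumption; the model returns one
# disparity/flow map per refinement stage, the last entry being the final prediction.
#
#   model = crestereo_base(weights=CREStereo_Base_Weights.MEGVII_V1)
#   model.eval()
#   left, right = torch.rand(1, 3, 384, 512), torch.rand(1, 3, 384, 512)
#   with torch.no_grad():
#       predictions = model(left, right, flow_init=None, num_iters=10)
#   disparity = predictions[-1]  # [1, 2, 384, 512], i.e. full input resolution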