"docs/vscode:/vscode.git/clone" did not exist on "ea131ebc77bd91ecf91be8669c8702293b2390a3"
Unverified commit 48b1edff authored by Nicolas Hug, committed by GitHub

Remove prototype area for 0.19 (#8491)

parent f44f20cf
import io
import pathlib
from collections import namedtuple
from typing import Any, Dict, Iterator, List, Optional, Tuple, Union
from torchdata.datapipes.iter import IterDataPipe, Mapper, Zipper
from torchvision.prototype.datasets.utils import Dataset, GDriveResource, OnlineResource
from torchvision.prototype.datasets.utils._internal import hint_sharding, hint_shuffling
from torchvision.prototype.tv_tensors import Label
from torchvision.tv_tensors import Image
from .._api import register_dataset, register_info
NAME = "pcam"
class PCAMH5Reader(IterDataPipe[Tuple[str, io.IOBase]]):
def __init__(
self,
datapipe: IterDataPipe[Tuple[str, io.IOBase]],
key: Optional[str] = None, # Note: this key thing might be very specific to the PCAM dataset
) -> None:
self.datapipe = datapipe
self.key = key
def __iter__(self) -> Iterator[Tuple[str, io.IOBase]]:
import h5py
for _, handle in self.datapipe:
try:
with h5py.File(handle) as data:
if self.key is not None:
data = data[self.key]
yield from data
finally:
handle.close()
_Resource = namedtuple("_Resource", ("file_name", "gdrive_id", "sha256"))
@register_info(NAME)
def _info() -> Dict[str, Any]:
return dict(categories=["0", "1"])
@register_dataset(NAME)
class PCAM(Dataset):
# TODO write proper docstring
"""PCAM Dataset
homepage="https://github.com/basveeling/pcam"
"""
def __init__(
self, root: Union[str, pathlib.Path], split: str = "train", *, skip_integrity_check: bool = False
) -> None:
self._split = self._verify_str_arg(split, "split", {"train", "val", "test"})
self._categories = _info()["categories"]
super().__init__(root, skip_integrity_check=skip_integrity_check, dependencies=("h5py",))
_RESOURCES = {
"train": (
_Resource( # Images
file_name="camelyonpatch_level_2_split_train_x.h5.gz",
gdrive_id="1Ka0XfEMiwgCYPdTI-vv6eUElOBnKFKQ2",
sha256="d619e741468a7ab35c7e4a75e6821b7e7e6c9411705d45708f2a0efc8960656c",
),
_Resource( # Targets
file_name="camelyonpatch_level_2_split_train_y.h5.gz",
gdrive_id="1269yhu3pZDP8UYFQs-NYs3FPwuK-nGSG",
sha256="b74126d2c01b20d3661f9b46765d29cf4e4fba6faba29c8e0d09d406331ab75a",
),
),
"test": (
_Resource( # Images
file_name="camelyonpatch_level_2_split_test_x.h5.gz",
gdrive_id="1qV65ZqZvWzuIVthK8eVDhIwrbnsJdbg_",
sha256="79174c2201ad521602a5888be8f36ee10875f37403dd3f2086caf2182ef87245",
),
_Resource( # Targets
file_name="camelyonpatch_level_2_split_test_y.h5.gz",
gdrive_id="17BHrSrwWKjYsOgTMmoqrIjDy6Fa2o_gP",
sha256="0a522005fccc8bbd04c5a117bfaf81d8da2676f03a29d7499f71d0a0bd6068ef",
),
),
"val": (
_Resource( # Images
file_name="camelyonpatch_level_2_split_valid_x.h5.gz",
gdrive_id="1hgshYGWK8V-eGRy8LToWJJgDU_rXWVJ3",
sha256="f82ee1670d027b4ec388048d9eabc2186b77c009655dae76d624c0ecb053ccb2",
),
_Resource( # Targets
file_name="camelyonpatch_level_2_split_valid_y.h5.gz",
gdrive_id="1bH8ZRbhSVAhScTS0p9-ZzGnX91cHT3uO",
sha256="ce1ae30f08feb468447971cfd0472e7becd0ad96d877c64120c72571439ae48c",
),
),
}
def _resources(self) -> List[OnlineResource]:
return [ # = [images resource, targets resource]
GDriveResource(file_name=file_name, id=gdrive_id, sha256=sha256, preprocess="decompress")
for file_name, gdrive_id, sha256 in self._RESOURCES[self._split]
]
def _prepare_sample(self, data: Tuple[Any, Any]) -> Dict[str, Any]:
image, target = data # They're both numpy arrays at this point
return {
"image": Image(image.transpose(2, 0, 1)),
"label": Label(target.item(), categories=self._categories),
}
def _datapipe(self, resource_dps: List[IterDataPipe]) -> IterDataPipe[Dict[str, Any]]:
images_dp, targets_dp = resource_dps
images_dp = PCAMH5Reader(images_dp, key="x")
targets_dp = PCAMH5Reader(targets_dp, key="y")
dp = Zipper(images_dp, targets_dp)
dp = hint_shuffling(dp)
dp = hint_sharding(dp)
return Mapper(dp, self._prepare_sample)
def __len__(self) -> int:
return 262_144 if self._split == "train" else 32_768
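A minimal usage sketch for the PCAM builtin above; hedged, since it assumes the prototype registry is importable and the GDrive resources can be fetched. `datasets.load` is the registry entry point used elsewhere in this diff, and extra keyword arguments such as `split` are forwarded to the dataset constructor.

from torchvision.prototype import datasets

# Yields sample dicts with "image" (a 3x96x96 Image) and "label" (a Label),
# as built by PCAM._prepare_sample above.
dp = datasets.load("pcam", split="val")
sample = next(iter(dp))
print(sample["image"].shape, sample["label"])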
aeroplane
bicycle
bird
boat
bottle
bus
car
cat
chair
cow
diningtable
dog
horse
motorbike
person
pottedplant
sheep
sofa
train
tvmonitor
import pathlib
import re
from typing import Any, BinaryIO, cast, Dict, List, Optional, Tuple, Union
import numpy as np
import torch
from torchdata.datapipes.iter import Demultiplexer, Filter, IterDataPipe, IterKeyZipper, LineReader, Mapper
from torchvision.prototype.datasets.utils import Dataset, EncodedImage, HttpResource, OnlineResource
from torchvision.prototype.datasets.utils._internal import (
getitem,
hint_sharding,
hint_shuffling,
INFINITE_BUFFER_SIZE,
path_accessor,
path_comparator,
read_categories_file,
read_mat,
)
from .._api import register_dataset, register_info
NAME = "sbd"
@register_info(NAME)
def _info() -> Dict[str, Any]:
return dict(categories=read_categories_file(NAME))
@register_dataset(NAME)
class SBD(Dataset):
"""
- **homepage**: http://home.bharathh.info/pubs/codes/SBD/download.html
- **dependencies**:
        - `scipy <https://scipy.org>`_
"""
def __init__(
self,
root: Union[str, pathlib.Path],
*,
split: str = "train",
skip_integrity_check: bool = False,
) -> None:
self._split = self._verify_str_arg(split, "split", ("train", "val", "train_noval"))
self._categories = _info()["categories"]
super().__init__(root, dependencies=("scipy",), skip_integrity_check=skip_integrity_check)
def _resources(self) -> List[OnlineResource]:
resources = [
HttpResource(
"https://www2.eecs.berkeley.edu/Research/Projects/CS/vision/grouping/semantic_contours/benchmark.tgz",
sha256="6a5a2918d5c73ce032fdeba876574d150d9d04113ab87540a1304cbcc715be53",
)
]
if self._split == "train_noval":
resources.append(
HttpResource(
"http://home.bharathh.info/pubs/codes/SBD/train_noval.txt",
sha256="0b2068f7a359d2907431803e1cd63bf6162da37d7d503b589d3b08c6fd0c2432",
)
)
return resources # type: ignore[return-value]
def _classify_archive(self, data: Tuple[str, Any]) -> Optional[int]:
path = pathlib.Path(data[0])
parent, grandparent, *_ = path.parents
if grandparent.name == "dataset":
if parent.name == "img":
return 0
elif parent.name == "cls":
return 1
if parent.name == "dataset" and self._split != "train_noval":
return 2
return None
def _prepare_sample(self, data: Tuple[Tuple[Any, Tuple[str, BinaryIO]], Tuple[str, BinaryIO]]) -> Dict[str, Any]:
split_and_image_data, ann_data = data
_, image_data = split_and_image_data
image_path, image_buffer = image_data
ann_path, ann_buffer = ann_data
anns = read_mat(ann_buffer, squeeze_me=True)["GTcls"]
return dict(
image_path=image_path,
image=EncodedImage.from_file(image_buffer),
ann_path=ann_path,
# the boundaries are stored in sparse CSC format, which is not supported by PyTorch
boundaries=torch.as_tensor(
np.stack([raw_boundary.toarray() for raw_boundary in anns["Boundaries"].item()])
),
segmentation=torch.as_tensor(anns["Segmentation"].item()),
)
def _datapipe(self, resource_dps: List[IterDataPipe]) -> IterDataPipe[Dict[str, Any]]:
if self._split == "train_noval":
archive_dp, split_dp = resource_dps
images_dp, anns_dp = Demultiplexer(
archive_dp,
2,
self._classify_archive,
buffer_size=INFINITE_BUFFER_SIZE,
drop_none=True,
)
else:
archive_dp = resource_dps[0]
images_dp, anns_dp, split_dp = Demultiplexer(
archive_dp,
3,
self._classify_archive,
buffer_size=INFINITE_BUFFER_SIZE,
drop_none=True,
)
split_dp = Filter(split_dp, path_comparator("name", f"{self._split}.txt"))
split_dp = LineReader(split_dp, decode=True)
split_dp = hint_shuffling(split_dp)
split_dp = hint_sharding(split_dp)
dp = split_dp
for level, data_dp in enumerate((images_dp, anns_dp)):
dp = IterKeyZipper(
dp,
data_dp,
key_fn=getitem(*[0] * level, 1),
ref_key_fn=path_accessor("stem"),
buffer_size=INFINITE_BUFFER_SIZE,
)
return Mapper(dp, self._prepare_sample)
def __len__(self) -> int:
return {
"train": 8_498,
"val": 2_857,
"train_noval": 5_623,
}[self._split]
def _generate_categories(self) -> Tuple[str, ...]:
resources = self._resources()
dp = resources[0].load(self._root)
dp = Filter(dp, path_comparator("name", "category_names.m"))
dp = LineReader(dp)
dp = Mapper(dp, bytes.decode, input_col=1)
lines = tuple(zip(*iter(dp)))[1]
pattern = re.compile(r"\s*'(?P<category>\w+)';\s*%(?P<label>\d+)")
categories_and_labels = cast(
List[Tuple[str, ...]],
[
pattern.match(line).groups() # type: ignore[union-attr]
                # the first and last lines contain no information
for line in lines[1:-1]
],
)
categories_and_labels.sort(key=lambda category_and_label: int(category_and_label[1]))
categories, _ = zip(*categories_and_labels)
return categories
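`SBD._generate_categories` above parses the category names out of MATLAB source with a regex; here is a quick check against a line in the format the pattern expects (the exact layout of category_names.m is an assumption):

import re

pattern = re.compile(r"\s*'(?P<category>\w+)';\s*%(?P<label>\d+)")
print(pattern.match("    'aeroplane';  %1").groups())  # ('aeroplane', '1')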
import pathlib
from typing import Any, Dict, List, Tuple, Union
import torch
from torchdata.datapipes.iter import CSVParser, IterDataPipe, Mapper
from torchvision.prototype.datasets.utils import Dataset, HttpResource, OnlineResource
from torchvision.prototype.datasets.utils._internal import hint_sharding, hint_shuffling
from torchvision.prototype.tv_tensors import OneHotLabel
from torchvision.tv_tensors import Image
from .._api import register_dataset, register_info
NAME = "semeion"
@register_info(NAME)
def _info() -> Dict[str, Any]:
return dict(categories=[str(i) for i in range(10)])
@register_dataset(NAME)
class SEMEION(Dataset):
"""Semeion dataset
homepage="https://archive.ics.uci.edu/ml/datasets/Semeion+Handwritten+Digit",
"""
def __init__(self, root: Union[str, pathlib.Path], *, skip_integrity_check: bool = False) -> None:
self._categories = _info()["categories"]
super().__init__(root, skip_integrity_check=skip_integrity_check)
def _resources(self) -> List[OnlineResource]:
data = HttpResource(
"http://archive.ics.uci.edu/ml/machine-learning-databases/semeion/semeion.data",
sha256="f43228ae3da5ea6a3c95069d53450b86166770e3b719dcc333182128fe08d4b1",
)
return [data]
def _prepare_sample(self, data: Tuple[str, ...]) -> Dict[str, Any]:
image_data, label_data = data[:256], data[256:-1]
return dict(
image=Image(torch.tensor([float(pixel) for pixel in image_data], dtype=torch.float).reshape(16, 16)),
label=OneHotLabel([int(label) for label in label_data], categories=self._categories),
)
def _datapipe(self, resource_dps: List[IterDataPipe]) -> IterDataPipe[Dict[str, Any]]:
dp = resource_dps[0]
dp = CSVParser(dp, delimiter=" ")
dp = hint_shuffling(dp)
dp = hint_sharding(dp)
return Mapper(dp, self._prepare_sample)
def __len__(self) -> int:
return 1_593
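For context on the slicing in `SEMEION._prepare_sample` above: each row of semeion.data carries 256 pixel values followed by a 10-element one-hot label, and the trailing delimiter yields one empty field that the `[256:-1]` slice drops (the empty trailing field is inferred from that slice). A synthetic row:

# 256 pixels, a one-hot label for digit 4, and the empty trailing field.
row = ["1.0"] * 256 + ["0", "0", "0", "0", "1", "0", "0", "0", "0", "0", ""]
image_data, label_data = row[:256], row[256:-1]
assert len(image_data) == 256 and len(label_data) == 10
print([int(label) for label in label_data].index(1))  # 4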
AM General Hummer SUV 2000
Acura RL Sedan 2012
Acura TL Sedan 2012
Acura TL Type-S 2008
Acura TSX Sedan 2012
Acura Integra Type R 2001
Acura ZDX Hatchback 2012
Aston Martin V8 Vantage Convertible 2012
Aston Martin V8 Vantage Coupe 2012
Aston Martin Virage Convertible 2012
Aston Martin Virage Coupe 2012
Audi RS 4 Convertible 2008
Audi A5 Coupe 2012
Audi TTS Coupe 2012
Audi R8 Coupe 2012
Audi V8 Sedan 1994
Audi 100 Sedan 1994
Audi 100 Wagon 1994
Audi TT Hatchback 2011
Audi S6 Sedan 2011
Audi S5 Convertible 2012
Audi S5 Coupe 2012
Audi S4 Sedan 2012
Audi S4 Sedan 2007
Audi TT RS Coupe 2012
BMW ActiveHybrid 5 Sedan 2012
BMW 1 Series Convertible 2012
BMW 1 Series Coupe 2012
BMW 3 Series Sedan 2012
BMW 3 Series Wagon 2012
BMW 6 Series Convertible 2007
BMW X5 SUV 2007
BMW X6 SUV 2012
BMW M3 Coupe 2012
BMW M5 Sedan 2010
BMW M6 Convertible 2010
BMW X3 SUV 2012
BMW Z4 Convertible 2012
Bentley Continental Supersports Conv. Convertible 2012
Bentley Arnage Sedan 2009
Bentley Mulsanne Sedan 2011
Bentley Continental GT Coupe 2012
Bentley Continental GT Coupe 2007
Bentley Continental Flying Spur Sedan 2007
Bugatti Veyron 16.4 Convertible 2009
Bugatti Veyron 16.4 Coupe 2009
Buick Regal GS 2012
Buick Rainier SUV 2007
Buick Verano Sedan 2012
Buick Enclave SUV 2012
Cadillac CTS-V Sedan 2012
Cadillac SRX SUV 2012
Cadillac Escalade EXT Crew Cab 2007
Chevrolet Silverado 1500 Hybrid Crew Cab 2012
Chevrolet Corvette Convertible 2012
Chevrolet Corvette ZR1 2012
Chevrolet Corvette Ron Fellows Edition Z06 2007
Chevrolet Traverse SUV 2012
Chevrolet Camaro Convertible 2012
Chevrolet HHR SS 2010
Chevrolet Impala Sedan 2007
Chevrolet Tahoe Hybrid SUV 2012
Chevrolet Sonic Sedan 2012
Chevrolet Express Cargo Van 2007
Chevrolet Avalanche Crew Cab 2012
Chevrolet Cobalt SS 2010
Chevrolet Malibu Hybrid Sedan 2010
Chevrolet TrailBlazer SS 2009
Chevrolet Silverado 2500HD Regular Cab 2012
Chevrolet Silverado 1500 Classic Extended Cab 2007
Chevrolet Express Van 2007
Chevrolet Monte Carlo Coupe 2007
Chevrolet Malibu Sedan 2007
Chevrolet Silverado 1500 Extended Cab 2012
Chevrolet Silverado 1500 Regular Cab 2012
Chrysler Aspen SUV 2009
Chrysler Sebring Convertible 2010
Chrysler Town and Country Minivan 2012
Chrysler 300 SRT-8 2010
Chrysler Crossfire Convertible 2008
Chrysler PT Cruiser Convertible 2008
Daewoo Nubira Wagon 2002
Dodge Caliber Wagon 2012
Dodge Caliber Wagon 2007
Dodge Caravan Minivan 1997
Dodge Ram Pickup 3500 Crew Cab 2010
Dodge Ram Pickup 3500 Quad Cab 2009
Dodge Sprinter Cargo Van 2009
Dodge Journey SUV 2012
Dodge Dakota Crew Cab 2010
Dodge Dakota Club Cab 2007
Dodge Magnum Wagon 2008
Dodge Challenger SRT8 2011
Dodge Durango SUV 2012
Dodge Durango SUV 2007
Dodge Charger Sedan 2012
Dodge Charger SRT-8 2009
Eagle Talon Hatchback 1998
FIAT 500 Abarth 2012
FIAT 500 Convertible 2012
Ferrari FF Coupe 2012
Ferrari California Convertible 2012
Ferrari 458 Italia Convertible 2012
Ferrari 458 Italia Coupe 2012
Fisker Karma Sedan 2012
Ford F-450 Super Duty Crew Cab 2012
Ford Mustang Convertible 2007
Ford Freestar Minivan 2007
Ford Expedition EL SUV 2009
Ford Edge SUV 2012
Ford Ranger SuperCab 2011
Ford GT Coupe 2006
Ford F-150 Regular Cab 2012
Ford F-150 Regular Cab 2007
Ford Focus Sedan 2007
Ford E-Series Wagon Van 2012
Ford Fiesta Sedan 2012
GMC Terrain SUV 2012
GMC Savana Van 2012
GMC Yukon Hybrid SUV 2012
GMC Acadia SUV 2012
GMC Canyon Extended Cab 2012
Geo Metro Convertible 1993
HUMMER H3T Crew Cab 2010
HUMMER H2 SUT Crew Cab 2009
Honda Odyssey Minivan 2012
Honda Odyssey Minivan 2007
Honda Accord Coupe 2012
Honda Accord Sedan 2012
Hyundai Veloster Hatchback 2012
Hyundai Santa Fe SUV 2012
Hyundai Tucson SUV 2012
Hyundai Veracruz SUV 2012
Hyundai Sonata Hybrid Sedan 2012
Hyundai Elantra Sedan 2007
Hyundai Accent Sedan 2012
Hyundai Genesis Sedan 2012
Hyundai Sonata Sedan 2012
Hyundai Elantra Touring Hatchback 2012
Hyundai Azera Sedan 2012
Infiniti G Coupe IPL 2012
Infiniti QX56 SUV 2011
Isuzu Ascender SUV 2008
Jaguar XK XKR 2012
Jeep Patriot SUV 2012
Jeep Wrangler SUV 2012
Jeep Liberty SUV 2012
Jeep Grand Cherokee SUV 2012
Jeep Compass SUV 2012
Lamborghini Reventon Coupe 2008
Lamborghini Aventador Coupe 2012
Lamborghini Gallardo LP 570-4 Superleggera 2012
Lamborghini Diablo Coupe 2001
Land Rover Range Rover SUV 2012
Land Rover LR2 SUV 2012
Lincoln Town Car Sedan 2011
MINI Cooper Roadster Convertible 2012
Maybach Landaulet Convertible 2012
Mazda Tribute SUV 2011
McLaren MP4-12C Coupe 2012
Mercedes-Benz 300-Class Convertible 1993
Mercedes-Benz C-Class Sedan 2012
Mercedes-Benz SL-Class Coupe 2009
Mercedes-Benz E-Class Sedan 2012
Mercedes-Benz S-Class Sedan 2012
Mercedes-Benz Sprinter Van 2012
Mitsubishi Lancer Sedan 2012
Nissan Leaf Hatchback 2012
Nissan NV Passenger Van 2012
Nissan Juke Hatchback 2012
Nissan 240SX Coupe 1998
Plymouth Neon Coupe 1999
Porsche Panamera Sedan 2012
Ram C/V Cargo Van Minivan 2012
Rolls-Royce Phantom Drophead Coupe Convertible 2012
Rolls-Royce Ghost Sedan 2012
Rolls-Royce Phantom Sedan 2012
Scion xD Hatchback 2012
Spyker C8 Convertible 2009
Spyker C8 Coupe 2009
Suzuki Aerio Sedan 2007
Suzuki Kizashi Sedan 2012
Suzuki SX4 Hatchback 2012
Suzuki SX4 Sedan 2012
Tesla Model S Sedan 2012
Toyota Sequoia SUV 2012
Toyota Camry Sedan 2012
Toyota Corolla Sedan 2012
Toyota 4Runner SUV 2012
Volkswagen Golf Hatchback 2012
Volkswagen Golf Hatchback 1991
Volkswagen Beetle Hatchback 2012
Volvo C30 Hatchback 2012
Volvo 240 Sedan 1993
Volvo XC90 SUV 2007
smart fortwo Convertible 2012
import pathlib
from typing import Any, BinaryIO, Dict, Iterator, List, Tuple, Union
from torchdata.datapipes.iter import Filter, IterDataPipe, Mapper, Zipper
from torchvision.prototype.datasets.utils import Dataset, EncodedImage, HttpResource, OnlineResource
from torchvision.prototype.datasets.utils._internal import (
hint_sharding,
hint_shuffling,
path_comparator,
read_categories_file,
read_mat,
)
from torchvision.prototype.tv_tensors import Label
from torchvision.tv_tensors import BoundingBoxes
from .._api import register_dataset, register_info
class StanfordCarsLabelReader(IterDataPipe[Tuple[int, int, int, int, int, str]]):
def __init__(self, datapipe: IterDataPipe[Dict[str, Any]]) -> None:
self.datapipe = datapipe
def __iter__(self) -> Iterator[Tuple[int, int, int, int, int, str]]:
for _, file in self.datapipe:
data = read_mat(file, squeeze_me=True)
for ann in data["annotations"]:
yield tuple(ann) # type: ignore[misc]
NAME = "stanford-cars"
@register_info(NAME)
def _info() -> Dict[str, Any]:
return dict(categories=read_categories_file(NAME))
@register_dataset(NAME)
class StanfordCars(Dataset):
"""Stanford Cars dataset.
homepage="https://ai.stanford.edu/~jkrause/cars/car_dataset.html",
dependencies=scipy
"""
def __init__(
self,
root: Union[str, pathlib.Path],
*,
split: str = "train",
skip_integrity_check: bool = False,
) -> None:
self._split = self._verify_str_arg(split, "split", {"train", "test"})
self._categories = _info()["categories"]
super().__init__(root, skip_integrity_check=skip_integrity_check, dependencies=("scipy",))
_URL_ROOT = "https://ai.stanford.edu/~jkrause/"
_URLS = {
"train": f"{_URL_ROOT}car196/cars_train.tgz",
"test": f"{_URL_ROOT}car196/cars_test.tgz",
"cars_test_annos_withlabels": f"{_URL_ROOT}car196/cars_test_annos_withlabels.mat",
"car_devkit": f"{_URL_ROOT}cars/car_devkit.tgz",
}
_CHECKSUM = {
"train": "b97deb463af7d58b6bfaa18b2a4de9829f0f79e8ce663dfa9261bf7810e9accd",
"test": "bffea656d6f425cba3c91c6d83336e4c5f86c6cffd8975b0f375d3a10da8e243",
"cars_test_annos_withlabels": "790f75be8ea34eeded134cc559332baf23e30e91367e9ddca97d26ed9b895f05",
"car_devkit": "512b227b30e2f0a8aab9e09485786ab4479582073a144998da74d64b801fd288",
}
def _resources(self) -> List[OnlineResource]:
resources: List[OnlineResource] = [HttpResource(self._URLS[self._split], sha256=self._CHECKSUM[self._split])]
if self._split == "train":
resources.append(HttpResource(url=self._URLS["car_devkit"], sha256=self._CHECKSUM["car_devkit"]))
else:
resources.append(
HttpResource(
self._URLS["cars_test_annos_withlabels"], sha256=self._CHECKSUM["cars_test_annos_withlabels"]
)
)
return resources
def _prepare_sample(self, data: Tuple[Tuple[str, BinaryIO], Tuple[int, int, int, int, int, str]]) -> Dict[str, Any]:
image, target = data
path, buffer = image
image = EncodedImage.from_file(buffer)
return dict(
path=path,
image=image,
label=Label(target[4] - 1, categories=self._categories),
bounding_boxes=BoundingBoxes(target[:4], format="xyxy", spatial_size=image.spatial_size),
)
def _datapipe(self, resource_dps: List[IterDataPipe]) -> IterDataPipe[Dict[str, Any]]:
images_dp, targets_dp = resource_dps
if self._split == "train":
targets_dp = Filter(targets_dp, path_comparator("name", "cars_train_annos.mat"))
targets_dp = StanfordCarsLabelReader(targets_dp)
dp = Zipper(images_dp, targets_dp)
dp = hint_shuffling(dp)
dp = hint_sharding(dp)
return Mapper(dp, self._prepare_sample)
def _generate_categories(self) -> List[str]:
resources = self._resources()
devkit_dp = resources[1].load(self._root)
meta_dp = Filter(devkit_dp, path_comparator("name", "cars_meta.mat"))
_, meta_file = next(iter(meta_dp))
return list(read_mat(meta_file, squeeze_me=True)["class_names"])
def __len__(self) -> int:
return 8_144 if self._split == "train" else 8_041
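For reference on the indexing in `StanfordCars._prepare_sample` above: each annotation tuple yielded by `StanfordCarsLabelReader` is `(x1, y1, x2, y2, class, file_name)` (order inferred from the type hint and the consuming code), with 1-based class labels in the devkit .mat file. A sketch with made-up values:

ann = (39, 116, 569, 375, 14, "00001.jpg")  # illustrative values only
bbox, label = ann[:4], ann[4] - 1           # xyxy box, zero-based label
print(bbox, label)                          # (39, 116, 569, 375) 13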
import pathlib
from typing import Any, BinaryIO, Dict, List, Tuple, Union
import numpy as np
from torchdata.datapipes.iter import IterDataPipe, Mapper, UnBatcher
from torchvision.prototype.datasets.utils import Dataset, HttpResource, OnlineResource
from torchvision.prototype.datasets.utils._internal import hint_sharding, hint_shuffling, read_mat
from torchvision.prototype.tv_tensors import Label
from torchvision.tv_tensors import Image
from .._api import register_dataset, register_info
NAME = "svhn"
@register_info(NAME)
def _info() -> Dict[str, Any]:
return dict(categories=[str(c) for c in range(10)])
@register_dataset(NAME)
class SVHN(Dataset):
"""SVHN Dataset.
homepage="http://ufldl.stanford.edu/housenumbers/",
dependencies = scipy
"""
def __init__(
self,
root: Union[str, pathlib.Path],
*,
split: str = "train",
skip_integrity_check: bool = False,
) -> None:
self._split = self._verify_str_arg(split, "split", {"train", "test", "extra"})
self._categories = _info()["categories"]
super().__init__(root, skip_integrity_check=skip_integrity_check, dependencies=("scipy",))
_CHECKSUMS = {
"train": "435e94d69a87fde4fd4d7f3dd208dfc32cb6ae8af2240d066de1df7508d083b8",
"test": "cdce80dfb2a2c4c6160906d0bd7c68ec5a99d7ca4831afa54f09182025b6a75b",
"extra": "a133a4beb38a00fcdda90c9489e0c04f900b660ce8a316a5e854838379a71eb3",
}
def _resources(self) -> List[OnlineResource]:
data = HttpResource(
f"http://ufldl.stanford.edu/housenumbers/{self._split}_32x32.mat",
sha256=self._CHECKSUMS[self._split],
)
return [data]
def _read_images_and_labels(self, data: Tuple[str, BinaryIO]) -> List[Tuple[np.ndarray, np.ndarray]]:
_, buffer = data
content = read_mat(buffer)
return list(
zip(
content["X"].transpose((3, 0, 1, 2)),
content["y"].squeeze(),
)
)
def _prepare_sample(self, data: Tuple[np.ndarray, np.ndarray]) -> Dict[str, Any]:
image_array, label_array = data
return dict(
image=Image(image_array.transpose((2, 0, 1))),
label=Label(int(label_array) % 10, categories=self._categories),
)
def _datapipe(self, resource_dps: List[IterDataPipe]) -> IterDataPipe[Dict[str, Any]]:
dp = resource_dps[0]
dp = Mapper(dp, self._read_images_and_labels)
dp = UnBatcher(dp)
dp = hint_shuffling(dp)
dp = hint_sharding(dp)
return Mapper(dp, self._prepare_sample)
def __len__(self) -> int:
return {
"train": 73_257,
"test": 26_032,
"extra": 531_131,
}[self._split]
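The `% 10` in `SVHN._prepare_sample` above is needed because the SVHN .mat files label the digit zero as class 10; the modulo folds it back onto index 0 and leaves 1-9 untouched:

print([label % 10 for label in range(1, 11)])  # [1, 2, 3, 4, 5, 6, 7, 8, 9, 0]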
import pathlib
from typing import Any, Dict, List, Union
import torch
from torchdata.datapipes.iter import Decompressor, IterDataPipe, LineReader, Mapper
from torchvision.prototype.datasets.utils import Dataset, HttpResource, OnlineResource
from torchvision.prototype.datasets.utils._internal import hint_sharding, hint_shuffling
from torchvision.prototype.tv_tensors import Label
from torchvision.tv_tensors import Image
from .._api import register_dataset, register_info
NAME = "usps"
@register_info(NAME)
def _info() -> Dict[str, Any]:
return dict(categories=[str(c) for c in range(10)])
@register_dataset(NAME)
class USPS(Dataset):
"""USPS Dataset
homepage="https://www.csie.ntu.edu.tw/~cjlin/libsvmtools/datasets/multiclass.html#usps",
"""
def __init__(
self,
root: Union[str, pathlib.Path],
*,
split: str = "train",
skip_integrity_check: bool = False,
) -> None:
self._split = self._verify_str_arg(split, "split", {"train", "test"})
self._categories = _info()["categories"]
super().__init__(root, skip_integrity_check=skip_integrity_check)
_URL = "https://www.csie.ntu.edu.tw/~cjlin/libsvmtools/datasets/multiclass"
_RESOURCES = {
"train": HttpResource(
f"{_URL}/usps.bz2", sha256="3771e9dd6ba685185f89867b6e249233dd74652389f263963b3b741e994b034f"
),
"test": HttpResource(
f"{_URL}/usps.t.bz2", sha256="a9c0164e797d60142a50604917f0baa604f326e9a689698763793fa5d12ffc4e"
),
}
def _resources(self) -> List[OnlineResource]:
return [USPS._RESOURCES[self._split]]
def _prepare_sample(self, line: str) -> Dict[str, Any]:
label, *values = line.strip().split(" ")
values = [float(value.split(":")[1]) for value in values]
pixels = torch.tensor(values).add_(1).div_(2)
return dict(
image=Image(pixels.reshape(16, 16)),
label=Label(int(label) - 1, categories=self._categories),
)
def _datapipe(self, resource_dps: List[IterDataPipe]) -> IterDataPipe[Dict[str, Any]]:
dp = Decompressor(resource_dps[0])
dp = LineReader(dp, decode=True, return_path=False)
dp = hint_shuffling(dp)
dp = hint_sharding(dp)
return Mapper(dp, self._prepare_sample)
def __len__(self) -> int:
return 7_291 if self._split == "train" else 2_007
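The USPS files are in the LIBSVM text format: a 1-based label followed by index:value pairs with values in [-1, 1], which `USPS._prepare_sample` above rescales to [0, 1] and shifts to a zero-based label. A sketch on a shortened, made-up line (real rows carry 256 features):

import torch

line = "6 1:-1.0 2:-0.5 3:0.0 4:1.0"
label, *values = line.strip().split(" ")
values = [float(value.split(":")[1]) for value in values]
print(torch.tensor(values).add_(1).div_(2), int(label) - 1)
# tensor([0.0000, 0.2500, 0.5000, 1.0000]) 5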
__background__
aeroplane
bicycle
bird
boat
bottle
bus
car
cat
chair
cow
diningtable
dog
horse
motorbike
person
pottedplant
sheep
sofa
train
tvmonitor
import enum
import functools
import pathlib
from typing import Any, BinaryIO, cast, Dict, List, Optional, Tuple, Union
from xml.etree import ElementTree
from torchdata.datapipes.iter import Demultiplexer, Filter, IterDataPipe, IterKeyZipper, LineReader, Mapper
from torchvision.datasets import VOCDetection
from torchvision.prototype.datasets.utils import Dataset, EncodedImage, HttpResource, OnlineResource
from torchvision.prototype.datasets.utils._internal import (
getitem,
hint_sharding,
hint_shuffling,
INFINITE_BUFFER_SIZE,
path_accessor,
path_comparator,
read_categories_file,
)
from torchvision.prototype.tv_tensors import Label
from torchvision.tv_tensors import BoundingBoxes
from .._api import register_dataset, register_info
NAME = "voc"
@register_info(NAME)
def _info() -> Dict[str, Any]:
return dict(categories=read_categories_file(NAME))
@register_dataset(NAME)
class VOC(Dataset):
"""
- **homepage**: http://host.robots.ox.ac.uk/pascal/VOC/
"""
def __init__(
self,
root: Union[str, pathlib.Path],
*,
split: str = "train",
year: str = "2012",
task: str = "detection",
skip_integrity_check: bool = False,
) -> None:
self._year = self._verify_str_arg(year, "year", ("2007", "2008", "2009", "2010", "2011", "2012"))
if split == "test" and year != "2007":
raise ValueError("`split='test'` is only available for `year='2007'`")
else:
self._split = self._verify_str_arg(split, "split", ("train", "val", "trainval", "test"))
self._task = self._verify_str_arg(task, "task", ("detection", "segmentation"))
self._anns_folder = "Annotations" if task == "detection" else "SegmentationClass"
self._split_folder = "Main" if task == "detection" else "Segmentation"
self._categories = _info()["categories"]
super().__init__(root, skip_integrity_check=skip_integrity_check)
_TRAIN_VAL_ARCHIVES = {
"2007": ("VOCtrainval_06-Nov-2007.tar", "7d8cd951101b0957ddfd7a530bdc8a94f06121cfc1e511bb5937e973020c7508"),
"2008": ("VOCtrainval_14-Jul-2008.tar", "7f0ca53c1b5a838fbe946965fc106c6e86832183240af5c88e3f6c306318d42e"),
"2009": ("VOCtrainval_11-May-2009.tar", "11cbe1741fb5bdadbbca3c08e9ec62cd95c14884845527d50847bc2cf57e7fd6"),
"2010": ("VOCtrainval_03-May-2010.tar", "1af4189cbe44323ab212bff7afbc7d0f55a267cc191eb3aac911037887e5c7d4"),
"2011": ("VOCtrainval_25-May-2011.tar", "0a7f5f5d154f7290ec65ec3f78b72ef72c6d93ff6d79acd40dc222a9ee5248ba"),
"2012": ("VOCtrainval_11-May-2012.tar", "e14f763270cf193d0b5f74b169f44157a4b0c6efa708f4dd0ff78ee691763bcb"),
}
_TEST_ARCHIVES = {
"2007": ("VOCtest_06-Nov-2007.tar", "6836888e2e01dca84577a849d339fa4f73e1e4f135d312430c4856b5609b4892")
}
def _resources(self) -> List[OnlineResource]:
file_name, sha256 = (self._TEST_ARCHIVES if self._split == "test" else self._TRAIN_VAL_ARCHIVES)[self._year]
archive = HttpResource(f"http://host.robots.ox.ac.uk/pascal/VOC/voc{self._year}/{file_name}", sha256=sha256)
return [archive]
def _is_in_folder(self, data: Tuple[str, Any], *, name: str, depth: int = 1) -> bool:
path = pathlib.Path(data[0])
return name in path.parent.parts[-depth:]
class _Demux(enum.IntEnum):
SPLIT = 0
IMAGES = 1
ANNS = 2
def _classify_archive(self, data: Tuple[str, Any]) -> Optional[int]:
if self._is_in_folder(data, name="ImageSets", depth=2):
return self._Demux.SPLIT
elif self._is_in_folder(data, name="JPEGImages"):
return self._Demux.IMAGES
elif self._is_in_folder(data, name=self._anns_folder):
return self._Demux.ANNS
else:
return None
def _parse_detection_ann(self, buffer: BinaryIO) -> Dict[str, Any]:
ann = cast(Dict[str, Any], VOCDetection.parse_voc_xml(ElementTree.parse(buffer).getroot())["annotation"])
buffer.close()
return ann
def _prepare_detection_ann(self, buffer: BinaryIO) -> Dict[str, Any]:
anns = self._parse_detection_ann(buffer)
instances = anns["object"]
return dict(
bounding_boxes=BoundingBoxes(
[
[int(instance["bndbox"][part]) for part in ("xmin", "ymin", "xmax", "ymax")]
for instance in instances
],
format="xyxy",
spatial_size=cast(Tuple[int, int], tuple(int(anns["size"][dim]) for dim in ("height", "width"))),
),
labels=Label(
[self._categories.index(instance["name"]) for instance in instances], categories=self._categories
),
)
def _prepare_segmentation_ann(self, buffer: BinaryIO) -> Dict[str, Any]:
return dict(segmentation=EncodedImage.from_file(buffer))
def _prepare_sample(
self,
data: Tuple[Tuple[Tuple[str, str], Tuple[str, BinaryIO]], Tuple[str, BinaryIO]],
) -> Dict[str, Any]:
split_and_image_data, ann_data = data
_, image_data = split_and_image_data
image_path, image_buffer = image_data
ann_path, ann_buffer = ann_data
return dict(
(self._prepare_detection_ann if self._task == "detection" else self._prepare_segmentation_ann)(ann_buffer),
image_path=image_path,
image=EncodedImage.from_file(image_buffer),
ann_path=ann_path,
)
def _datapipe(self, resource_dps: List[IterDataPipe]) -> IterDataPipe[Dict[str, Any]]:
archive_dp = resource_dps[0]
split_dp, images_dp, anns_dp = Demultiplexer(
archive_dp,
3,
self._classify_archive,
drop_none=True,
buffer_size=INFINITE_BUFFER_SIZE,
)
split_dp = Filter(split_dp, functools.partial(self._is_in_folder, name=self._split_folder))
split_dp = Filter(split_dp, path_comparator("name", f"{self._split}.txt"))
split_dp = LineReader(split_dp, decode=True)
split_dp = hint_shuffling(split_dp)
split_dp = hint_sharding(split_dp)
dp = split_dp
for level, data_dp in enumerate((images_dp, anns_dp)):
dp = IterKeyZipper(
dp,
data_dp,
key_fn=getitem(*[0] * level, 1),
ref_key_fn=path_accessor("stem"),
buffer_size=INFINITE_BUFFER_SIZE,
)
return Mapper(dp, self._prepare_sample)
def __len__(self) -> int:
return {
("train", "2007", "detection"): 2_501,
("train", "2007", "segmentation"): 209,
("train", "2008", "detection"): 2_111,
("train", "2008", "segmentation"): 511,
("train", "2009", "detection"): 3_473,
("train", "2009", "segmentation"): 749,
("train", "2010", "detection"): 4_998,
("train", "2010", "segmentation"): 964,
("train", "2011", "detection"): 5_717,
("train", "2011", "segmentation"): 1_112,
("train", "2012", "detection"): 5_717,
("train", "2012", "segmentation"): 1_464,
("val", "2007", "detection"): 2_510,
("val", "2007", "segmentation"): 213,
("val", "2008", "detection"): 2_221,
("val", "2008", "segmentation"): 512,
("val", "2009", "detection"): 3_581,
("val", "2009", "segmentation"): 750,
("val", "2010", "detection"): 5_105,
("val", "2010", "segmentation"): 964,
("val", "2011", "detection"): 5_823,
("val", "2011", "segmentation"): 1_111,
("val", "2012", "detection"): 5_823,
("val", "2012", "segmentation"): 1_449,
("trainval", "2007", "detection"): 5_011,
("trainval", "2007", "segmentation"): 422,
("trainval", "2008", "detection"): 4_332,
("trainval", "2008", "segmentation"): 1_023,
("trainval", "2009", "detection"): 7_054,
("trainval", "2009", "segmentation"): 1_499,
("trainval", "2010", "detection"): 10_103,
("trainval", "2010", "segmentation"): 1_928,
("trainval", "2011", "detection"): 11_540,
("trainval", "2011", "segmentation"): 2_223,
("trainval", "2012", "detection"): 11_540,
("trainval", "2012", "segmentation"): 2_913,
("test", "2007", "detection"): 4_952,
("test", "2007", "segmentation"): 210,
}[(self._split, self._year, self._task)]
def _filter_anns(self, data: Tuple[str, Any]) -> bool:
return self._classify_archive(data) == self._Demux.ANNS
def _generate_categories(self) -> List[str]:
self._task = "detection"
resources = self._resources()
archive_dp = resources[0].load(self._root)
dp = Filter(archive_dp, self._filter_anns)
dp = Mapper(dp, self._parse_detection_ann, input_col=1)
categories = sorted({instance["name"] for _, anns in dp for instance in anns["object"]})
# We add a background category to be used during segmentation
categories.insert(0, "__background__")
return categories
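A usage sketch for the VOC builtin above; hedged, since it assumes the archive can be downloaded from the VOC host. Keyword arguments are forwarded by `datasets.load` to `VOC.__init__`:

from torchvision.prototype import datasets

dp = datasets.load("voc", split="val", year="2012", task="segmentation")
sample = next(iter(dp))
print(sorted(sample.keys()))  # ['ann_path', 'image', 'image_path', 'segmentation']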
import functools
import os
import os.path
import pathlib
from typing import Any, BinaryIO, Collection, Dict, List, Optional, Tuple, Union
from torchdata.datapipes.iter import FileLister, FileOpener, Filter, IterDataPipe, Mapper
from torchvision.prototype.datasets.utils import EncodedData, EncodedImage
from torchvision.prototype.datasets.utils._internal import hint_sharding, hint_shuffling
from torchvision.prototype.tv_tensors import Label
__all__ = ["from_data_folder", "from_image_folder"]
def _is_not_top_level_file(path: str, *, root: pathlib.Path) -> bool:
rel_path = pathlib.Path(path).relative_to(root)
return rel_path.is_dir() or rel_path.parent != pathlib.Path(".")
def _prepare_sample(
data: Tuple[str, BinaryIO],
*,
root: pathlib.Path,
categories: List[str],
) -> Dict[str, Any]:
path, buffer = data
category = pathlib.Path(path).relative_to(root).parts[0]
return dict(
path=path,
data=EncodedData.from_file(buffer),
label=Label.from_category(category, categories=categories),
)
def from_data_folder(
root: Union[str, pathlib.Path],
*,
valid_extensions: Optional[Collection[str]] = None,
recursive: bool = True,
) -> Tuple[IterDataPipe, List[str]]:
root = pathlib.Path(root).expanduser().resolve()
categories = sorted(entry.name for entry in os.scandir(root) if entry.is_dir())
masks: Union[List[str], str] = [f"*.{ext}" for ext in valid_extensions] if valid_extensions is not None else ""
dp = FileLister(str(root), recursive=recursive, masks=masks)
dp: IterDataPipe = Filter(dp, functools.partial(_is_not_top_level_file, root=root))
dp = hint_sharding(dp)
dp = hint_shuffling(dp)
dp = FileOpener(dp, mode="rb")
return Mapper(dp, functools.partial(_prepare_sample, root=root, categories=categories)), categories
def _data_to_image_key(sample: Dict[str, Any]) -> Dict[str, Any]:
sample["image"] = EncodedImage(sample.pop("data").data)
return sample
def from_image_folder(
root: Union[str, pathlib.Path],
*,
valid_extensions: Collection[str] = ("jpg", "jpeg", "png", "ppm", "bmp", "pgm", "tif", "tiff", "webp"),
**kwargs: Any,
) -> Tuple[IterDataPipe, List[str]]:
valid_extensions = [valid_extension for ext in valid_extensions for valid_extension in (ext.lower(), ext.upper())]
dp, categories = from_data_folder(root, valid_extensions=valid_extensions, **kwargs)
return Mapper(dp, _data_to_image_key), categories
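A usage sketch for the folder helpers above; hedged, since it assumes the two functions are re-exported at the package level and that "/data/pets" (a placeholder path) holds an ImageFolder-style layout of <root>/<category>/<image>:

from torchvision.prototype.datasets import from_image_folder

dp, categories = from_image_folder("/data/pets")
sample = next(iter(dp))
print(categories, sample["path"], sample["label"])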
import os
from typing import Optional
import torchvision._internally_replaced_utils as _iru
def home(root: Optional[str] = None) -> str:
if root is not None:
_iru._HOME = root
return _iru._HOME
root = os.getenv("TORCHVISION_DATASETS_HOME")
if root is not None:
return root
return _iru._HOME
def use_sharded_dataset(use: Optional[bool] = None) -> bool:
if use is not None:
_iru._USE_SHARDED_DATASETS = use
return _iru._USE_SHARDED_DATASETS
use = os.getenv("TORCHVISION_SHARDED_DATASETS")
if use is not None:
return use == "1"
return _iru._USE_SHARDED_DATASETS
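Both helpers above fall back to environment variables, so the datasets home can be configured either programmatically or through the shell; the benchmark script further down in this diff relies on `datasets.home()` to locate the raw data. "/data/torchvision" and "/other" are placeholder paths:

import os

from torchvision.prototype import datasets

os.environ["TORCHVISION_DATASETS_HOME"] = "/data/torchvision"
print(datasets.home())          # '/data/torchvision' (from the environment variable)
print(datasets.home("/other"))  # '/other' (stores the new default and returns it)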
# type: ignore
import argparse
import collections.abc
import contextlib
import inspect
import itertools
import os
import os.path
import pathlib
import shutil
import sys
import tempfile
import time
import unittest.mock
import warnings
import torch
from torch.utils.data import DataLoader
from torch.utils.data.dataloader_experimental import DataLoader2
from torchvision import datasets as legacy_datasets
from torchvision.datasets.utils import extract_archive
from torchvision.prototype import datasets as new_datasets
from torchvision.transforms import PILToTensor
def main(
name,
*,
variant=None,
legacy=True,
new=True,
start=True,
iteration=True,
num_starts=3,
num_samples=10_000,
temp_root=None,
num_workers=0,
):
benchmarks = [
benchmark
for benchmark in DATASET_BENCHMARKS
if benchmark.name == name and (variant is None or benchmark.variant == variant)
]
if not benchmarks:
msg = f"No DatasetBenchmark available for dataset '{name}'"
if variant is not None:
msg += f" and variant '{variant}'"
raise ValueError(msg)
for benchmark in benchmarks:
print("#" * 80)
print(f"{benchmark.name}" + (f" ({benchmark.variant})" if benchmark.variant is not None else ""))
if legacy and start:
print(
"legacy",
"cold_start",
Measurement.time(benchmark.legacy_cold_start(temp_root, num_workers=num_workers), number=num_starts),
)
print(
"legacy",
"warm_start",
Measurement.time(benchmark.legacy_warm_start(temp_root, num_workers=num_workers), number=num_starts),
)
if legacy and iteration:
print(
"legacy",
"iteration",
Measurement.iterations_per_time(
benchmark.legacy_iteration(temp_root, num_workers=num_workers, num_samples=num_samples)
),
)
if new and start:
print(
"new",
"cold_start",
Measurement.time(benchmark.new_cold_start(num_workers=num_workers), number=num_starts),
)
if new and iteration:
print(
"new",
"iteration",
Measurement.iterations_per_time(
benchmark.new_iteration(num_workers=num_workers, num_samples=num_samples)
),
)
class DatasetBenchmark:
def __init__(
self,
name: str,
*,
variant=None,
legacy_cls=None,
new_config=None,
legacy_config_map=None,
legacy_special_options_map=None,
prepare_legacy_root=None,
):
self.name = name
self.variant = variant
self.new_raw_dataset = new_datasets._api.find(name)
self.legacy_cls = legacy_cls or self._find_legacy_cls()
if new_config is None:
new_config = self.new_raw_dataset.default_config
elif isinstance(new_config, dict):
new_config = self.new_raw_dataset.info.make_config(**new_config)
self.new_config = new_config
self.legacy_config_map = legacy_config_map
self.legacy_special_options_map = legacy_special_options_map or self._legacy_special_options_map
self.prepare_legacy_root = prepare_legacy_root
def new_dataset(self, *, num_workers=0):
return DataLoader2(new_datasets.load(self.name, **self.new_config), num_workers=num_workers)
def new_cold_start(self, *, num_workers):
def fn(timer):
with timer:
dataset = self.new_dataset(num_workers=num_workers)
next(iter(dataset))
return fn
def new_iteration(self, *, num_samples, num_workers):
def fn(timer):
dataset = self.new_dataset(num_workers=num_workers)
num_sample = 0
with timer:
for _ in dataset:
num_sample += 1
if num_sample == num_samples:
break
return num_sample
return fn
def suppress_output(self):
@contextlib.contextmanager
def context_manager():
with open(os.devnull, "w") as devnull:
with contextlib.redirect_stdout(devnull), contextlib.redirect_stderr(devnull):
yield
return context_manager()
def legacy_dataset(self, root, *, num_workers=0, download=None):
legacy_config = self.legacy_config_map(self, root) if self.legacy_config_map else dict()
special_options = self.legacy_special_options_map(self)
if "download" in special_options and download is not None:
special_options["download"] = download
with self.suppress_output():
return DataLoader(
self.legacy_cls(legacy_config.pop("root", str(root)), **legacy_config, **special_options),
shuffle=True,
num_workers=num_workers,
)
@contextlib.contextmanager
def patch_download_and_integrity_checks(self):
patches = [
("download_url", dict()),
("download_file_from_google_drive", dict()),
("check_integrity", dict(new=lambda path, md5=None: os.path.isfile(path))),
]
dataset_module = sys.modules[self.legacy_cls.__module__]
utils_module = legacy_datasets.utils
with contextlib.ExitStack() as stack:
for name, patch_kwargs in patches:
patch_module = dataset_module if name in dir(dataset_module) else utils_module
stack.enter_context(unittest.mock.patch(f"{patch_module.__name__}.{name}", **patch_kwargs))
yield stack
def _find_resource_file_names(self):
info = self.new_raw_dataset.info
valid_options = info._valid_options
file_names = set()
for options in (
dict(zip(valid_options.keys(), values)) for values in itertools.product(*valid_options.values())
):
resources = self.new_raw_dataset.resources(info.make_config(**options))
file_names.update([resource.file_name for resource in resources])
return file_names
@contextlib.contextmanager
def legacy_root(self, temp_root):
new_root = pathlib.Path(new_datasets.home()) / self.name
legacy_root = pathlib.Path(tempfile.mkdtemp(dir=temp_root))
if os.stat(new_root).st_dev != os.stat(legacy_root).st_dev:
warnings.warn(
"The temporary root directory for the legacy dataset was created on a different storage device than "
"the raw data that is used by the new dataset. If the devices have different I/O stats, this will "
"distort the benchmark. You can use the '--temp-root' flag to relocate the root directory of the "
"temporary directories.",
RuntimeWarning,
)
try:
for file_name in self._find_resource_file_names():
(legacy_root / file_name).symlink_to(new_root / file_name)
if self.prepare_legacy_root:
self.prepare_legacy_root(self, legacy_root)
with self.patch_download_and_integrity_checks():
yield legacy_root
finally:
shutil.rmtree(legacy_root)
def legacy_cold_start(self, temp_root, *, num_workers):
def fn(timer):
with self.legacy_root(temp_root) as root:
with timer:
dataset = self.legacy_dataset(root, num_workers=num_workers)
next(iter(dataset))
return fn
def legacy_warm_start(self, temp_root, *, num_workers):
def fn(timer):
with self.legacy_root(temp_root) as root:
self.legacy_dataset(root, num_workers=num_workers)
with timer:
dataset = self.legacy_dataset(root, num_workers=num_workers, download=False)
next(iter(dataset))
return fn
def legacy_iteration(self, temp_root, *, num_samples, num_workers):
def fn(timer):
with self.legacy_root(temp_root) as root:
dataset = self.legacy_dataset(root, num_workers=num_workers)
with timer:
for num_sample, _ in enumerate(dataset, 1):
if num_sample == num_samples:
break
return num_sample
return fn
def _find_legacy_cls(self):
legacy_clss = {
name.lower(): dataset_class
for name, dataset_class in legacy_datasets.__dict__.items()
if isinstance(dataset_class, type) and issubclass(dataset_class, legacy_datasets.VisionDataset)
}
try:
return legacy_clss[self.name]
except KeyError as error:
raise RuntimeError(
f"Can't determine the legacy dataset class for '{self.name}' automatically. "
f"Please set the 'legacy_cls' keyword argument manually."
) from error
_SPECIAL_KWARGS = {
"transform",
"target_transform",
"transforms",
"download",
}
@staticmethod
def _legacy_special_options_map(benchmark):
available_parameters = set()
for cls in benchmark.legacy_cls.__mro__:
if cls is legacy_datasets.VisionDataset:
break
available_parameters.update(inspect.signature(cls.__init__).parameters)
available_special_kwargs = benchmark._SPECIAL_KWARGS.intersection(available_parameters)
special_options = dict()
if "download" in available_special_kwargs:
special_options["download"] = True
if "transform" in available_special_kwargs:
special_options["transform"] = PILToTensor()
if "target_transform" in available_special_kwargs:
special_options["target_transform"] = torch.tensor
elif "transforms" in available_special_kwargs:
special_options["transforms"] = JointTransform(PILToTensor(), PILToTensor())
return special_options
class Measurement:
@classmethod
def time(cls, fn, *, number):
results = Measurement._timeit(fn, number=number)
times = torch.tensor(tuple(zip(*results))[1])
return cls._format(times, unit="s")
@classmethod
def iterations_per_time(cls, fn):
num_samples, time = Measurement._timeit(fn, number=1)[0]
iterations_per_second = torch.tensor(num_samples) / torch.tensor(time)
return cls._format(iterations_per_second, unit="it/s")
class Timer:
def __init__(self):
self._start = None
self._stop = None
def __enter__(self):
self._start = time.perf_counter()
def __exit__(self, exc_type, exc_val, exc_tb):
self._stop = time.perf_counter()
@property
def delta(self):
if self._start is None:
raise RuntimeError()
elif self._stop is None:
raise RuntimeError()
return self._stop - self._start
@classmethod
def _timeit(cls, fn, number):
results = []
for _ in range(number):
timer = cls.Timer()
output = fn(timer)
results.append((output, timer.delta))
return results
@classmethod
def _format(cls, measurements, *, unit):
measurements = torch.as_tensor(measurements).to(torch.float64).flatten()
if measurements.numel() == 1:
            # TODO: format this in engineering notation
return f"{float(measurements):.3f} {unit}"
mean, std = Measurement._compute_mean_and_std(measurements)
        # TODO: format this in engineering notation
return f"{mean:.3f} ± {std:.3f} {unit}"
@classmethod
def _compute_mean_and_std(cls, t):
mean = float(t.mean())
std = float(t.std(0, unbiased=t.numel() > 1))
return mean, std
def no_split(benchmark, root):
legacy_config = dict(benchmark.new_config)
del legacy_config["split"]
return legacy_config
def bool_split(name="train"):
def legacy_config_map(benchmark, root):
legacy_config = dict(benchmark.new_config)
legacy_config[name] = legacy_config.pop("split") == "train"
return legacy_config
return legacy_config_map
def base_folder(rel_folder=None):
if rel_folder is None:
def rel_folder(benchmark):
return benchmark.name
elif not callable(rel_folder):
name = rel_folder
def rel_folder(_):
return name
def prepare_legacy_root(benchmark, root):
files = list(root.glob("*"))
folder = root / rel_folder(benchmark)
folder.mkdir(parents=True)
for file in files:
shutil.move(str(file), str(folder))
return folder
return prepare_legacy_root
class JointTransform:
def __init__(self, *transforms):
self.transforms = transforms
def __call__(self, *inputs):
        if len(inputs) == 1 and isinstance(inputs[0], collections.abc.Sequence):
inputs = inputs[0]
if len(inputs) != len(self.transforms):
raise RuntimeError(
f"The number of inputs and transforms mismatches: {len(inputs)} != {len(self.transforms)}."
)
return tuple(transform(input) for transform, input in zip(self.transforms, inputs))
def caltech101_legacy_config_map(benchmark, root):
legacy_config = no_split(benchmark, root)
# The new dataset always returns the category and annotation
legacy_config["target_type"] = ("category", "annotation")
return legacy_config
mnist_base_folder = base_folder(lambda benchmark: pathlib.Path(benchmark.legacy_cls.__name__) / "raw")
def mnist_legacy_config_map(benchmark, root):
return dict(train=benchmark.new_config.split == "train")
def emnist_prepare_legacy_root(benchmark, root):
folder = mnist_base_folder(benchmark, root)
shutil.move(str(folder / "emnist-gzip.zip"), str(folder / "gzip.zip"))
return folder
def emnist_legacy_config_map(benchmark, root):
legacy_config = mnist_legacy_config_map(benchmark, root)
legacy_config["split"] = benchmark.new_config.image_set.replace("_", "").lower()
return legacy_config
def qmnist_legacy_config_map(benchmark, root):
legacy_config = mnist_legacy_config_map(benchmark, root)
legacy_config["what"] = benchmark.new_config.split
# The new dataset always returns the full label
legacy_config["compat"] = False
return legacy_config
def coco_legacy_config_map(benchmark, root):
images, _ = benchmark.new_raw_dataset.resources(benchmark.new_config)
return dict(
root=str(root / pathlib.Path(images.file_name).stem),
annFile=str(
root / "annotations" / f"{benchmark.variant}_{benchmark.new_config.split}{benchmark.new_config.year}.json"
),
)
def coco_prepare_legacy_root(benchmark, root):
images, annotations = benchmark.new_raw_dataset.resources(benchmark.new_config)
extract_archive(str(root / images.file_name))
extract_archive(str(root / annotations.file_name))
DATASET_BENCHMARKS = [
DatasetBenchmark(
"caltech101",
legacy_config_map=caltech101_legacy_config_map,
prepare_legacy_root=base_folder(),
        legacy_special_options_map=lambda benchmark: dict(
download=True,
transform=PILToTensor(),
target_transform=JointTransform(torch.tensor, torch.tensor),
),
),
DatasetBenchmark(
"caltech256",
legacy_config_map=no_split,
prepare_legacy_root=base_folder(),
),
DatasetBenchmark(
"celeba",
prepare_legacy_root=base_folder(),
        legacy_config_map=lambda benchmark, root: dict(
split="valid" if benchmark.new_config.split == "val" else benchmark.new_config.split,
# The new dataset always returns all annotations
target_type=("attr", "identity", "bbox", "landmarks"),
),
),
DatasetBenchmark(
"cifar10",
legacy_config_map=bool_split(),
),
DatasetBenchmark(
"cifar100",
legacy_config_map=bool_split(),
),
DatasetBenchmark(
"emnist",
prepare_legacy_root=emnist_prepare_legacy_root,
legacy_config_map=emnist_legacy_config_map,
),
DatasetBenchmark(
"fashionmnist",
prepare_legacy_root=mnist_base_folder,
legacy_config_map=mnist_legacy_config_map,
),
DatasetBenchmark(
"kmnist",
prepare_legacy_root=mnist_base_folder,
legacy_config_map=mnist_legacy_config_map,
),
DatasetBenchmark(
"mnist",
prepare_legacy_root=mnist_base_folder,
legacy_config_map=mnist_legacy_config_map,
),
DatasetBenchmark(
"qmnist",
prepare_legacy_root=mnist_base_folder,
        legacy_config_map=qmnist_legacy_config_map,
),
DatasetBenchmark(
"sbd",
legacy_cls=legacy_datasets.SBDataset,
        legacy_config_map=lambda benchmark, root: dict(
image_set=benchmark.new_config.split,
mode="boundaries" if benchmark.new_config.boundaries else "segmentation",
),
legacy_special_options_map=lambda benchmark: dict(
download=True,
transforms=JointTransform(
PILToTensor(), torch.tensor if benchmark.new_config.boundaries else PILToTensor()
),
),
),
DatasetBenchmark("voc", legacy_cls=legacy_datasets.VOCDetection),
DatasetBenchmark("imagenet", legacy_cls=legacy_datasets.ImageNet),
DatasetBenchmark(
"coco",
variant="instances",
legacy_cls=legacy_datasets.CocoDetection,
new_config=dict(split="train", annotations="instances"),
legacy_config_map=coco_legacy_config_map,
prepare_legacy_root=coco_prepare_legacy_root,
legacy_special_options_map=lambda benchmark: dict(transform=PILToTensor(), target_transform=None),
),
DatasetBenchmark(
"coco",
variant="captions",
legacy_cls=legacy_datasets.CocoCaptions,
new_config=dict(split="train", annotations="captions"),
legacy_config_map=coco_legacy_config_map,
prepare_legacy_root=coco_prepare_legacy_root,
legacy_special_options_map=lambda benchmark: dict(transform=PILToTensor(), target_transform=None),
),
]
def parse_args(argv=None):
parser = argparse.ArgumentParser(
prog="torchvision.prototype.datasets.benchmark.py",
description="Utility to benchmark new datasets against their legacy variants.",
formatter_class=argparse.ArgumentDefaultsHelpFormatter,
)
parser.add_argument("name", help="Name of the dataset to benchmark.")
parser.add_argument(
"--variant", help="Variant of the dataset. If omitted all available variants will be benchmarked."
)
parser.add_argument(
"-n",
"--num-starts",
type=int,
default=3,
help="Number of warm and cold starts of each benchmark. Default to 3.",
)
parser.add_argument(
"-N",
"--num-samples",
type=int,
default=10_000,
help="Maximum number of samples to draw during iteration benchmarks. Defaults to 10_000.",
)
parser.add_argument(
"--nl",
"--no-legacy",
dest="legacy",
action="store_false",
help="Skip legacy benchmarks.",
)
parser.add_argument(
"--nn",
"--no-new",
dest="new",
action="store_false",
help="Skip new benchmarks.",
)
parser.add_argument(
"--ns",
"--no-start",
dest="start",
action="store_false",
help="Skip start benchmarks.",
)
parser.add_argument(
"--ni",
"--no-iteration",
dest="iteration",
action="store_false",
help="Skip iteration benchmarks.",
)
parser.add_argument(
"-t",
"--temp-root",
type=pathlib.Path,
help=(
"Root of the temporary legacy root directories. Use this if your system default temporary directory is on "
"another storage device as the raw data to avoid distortions due to differing I/O stats."
),
)
parser.add_argument(
"-j",
"--num-workers",
type=int,
default=0,
help=(
"Number of subprocesses used to load the data. Setting this to 0 (default) will load all data in the main "
"process and thus disable multi-processing."
),
)
return parser.parse_args(argv or sys.argv[1:])
if __name__ == "__main__":
args = parse_args()
try:
main(
args.name,
variant=args.variant,
legacy=args.legacy,
new=args.new,
start=args.start,
iteration=args.iteration,
num_starts=args.num_starts,
num_samples=args.num_samples,
temp_root=args.temp_root,
num_workers=args.num_workers,
)
except Exception as error:
msg = str(error)
print(msg or f"Unspecified {type(error)} was raised during execution.", file=sys.stderr)
sys.exit(1)
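A hedged usage note for the benchmark script above; the flags map onto `parse_args`, and `main` can be called directly when the module is imported:

# CLI (run directly as a script):  python benchmark.py cifar10 --num-samples 1000 --nl
# Programmatic equivalent (assuming `main` from the script above is in scope):
main("cifar10", num_samples=1_000, legacy=False)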
# type: ignore
import argparse
import csv
import sys
from torchvision.prototype import datasets
from torchvision.prototype.datasets.utils._internal import BUILTIN_DIR
def main(*names, force=False):
for name in names:
path = BUILTIN_DIR / f"{name}.categories"
if path.exists() and not force:
continue
dataset = datasets.load(name)
try:
categories = dataset._generate_categories()
except NotImplementedError:
continue
with open(path, "w") as file:
writer = csv.writer(file, lineterminator="\n")
for category in categories:
writer.writerow((category,) if isinstance(category, str) else category)
def parse_args(argv=None):
parser = argparse.ArgumentParser(prog="torchvision.prototype.datasets.generate_category_files.py")
parser.add_argument(
"names",
nargs="*",
type=str,
help="Names of datasets to generate category files for. If omitted, all datasets will be used.",
)
parser.add_argument(
"-f",
"--force",
action="store_true",
help="Force regeneration of category files.",
)
args = parser.parse_args(argv or sys.argv[1:])
if not args.names:
args.names = datasets.list_datasets()
return args
if __name__ == "__main__":
args = parse_args()
try:
main(*args.names, force=args.force)
except Exception as error:
msg = str(error)
print(msg or f"Unspecified {type(error)} was raised during execution.", file=sys.stderr)
sys.exit(1)
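The category-file generator above is a thin wrapper around each builtin's `_generate_categories`; a hedged programmatic sketch (it assumes the SBD archive is available locally or can be downloaded, and that scipy is installed):

from torchvision.prototype import datasets

dataset = datasets.load("sbd")
print(dataset._generate_categories()[:5])  # ('aeroplane', 'bicycle', 'bird', 'boat', 'bottle')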
from . import _internal # usort: skip
from ._dataset import Dataset
from ._encoded import EncodedData, EncodedImage
from ._resource import GDriveResource, HttpResource, KaggleDownloadResource, ManualDownloadResource, OnlineResource
import abc
import importlib
import pathlib
from typing import Any, Collection, Dict, Iterator, List, Optional, Sequence, Union
from torchdata.datapipes.iter import IterDataPipe
from torchvision.datasets.utils import verify_str_arg
from ._resource import OnlineResource
class Dataset(IterDataPipe[Dict[str, Any]], abc.ABC):
@staticmethod
def _verify_str_arg(
value: str,
arg: Optional[str] = None,
valid_values: Optional[Collection[str]] = None,
*,
custom_msg: Optional[str] = None,
) -> str:
return verify_str_arg(value, arg, valid_values, custom_msg=custom_msg)
def __init__(
self, root: Union[str, pathlib.Path], *, skip_integrity_check: bool = False, dependencies: Collection[str] = ()
) -> None:
for dependency in dependencies:
try:
importlib.import_module(dependency)
except ModuleNotFoundError:
raise ModuleNotFoundError(
f"{type(self).__name__}() depends on the third-party package '{dependency}'. "
f"Please install it, for example with `pip install {dependency}`."
) from None
self._root = pathlib.Path(root).expanduser().resolve()
resources = [
resource.load(self._root, skip_integrity_check=skip_integrity_check) for resource in self._resources()
]
self._dp = self._datapipe(resources)
def __iter__(self) -> Iterator[Dict[str, Any]]:
yield from self._dp
@abc.abstractmethod
def _resources(self) -> List[OnlineResource]:
pass
@abc.abstractmethod
def _datapipe(self, resource_dps: List[IterDataPipe]) -> IterDataPipe[Dict[str, Any]]:
pass
@abc.abstractmethod
def __len__(self) -> int:
pass
def _generate_categories(self) -> Sequence[Union[str, Sequence[str]]]:
raise NotImplementedError
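A minimal, hypothetical subclass sketch showing the three abstract hooks a builtin has to provide; the URL, checksum, and sample layout below are placeholders, not a real dataset:

from typing import Any, Dict, List

from torchdata.datapipes.iter import IterDataPipe, Mapper
from torchvision.prototype.datasets.utils import Dataset, HttpResource, OnlineResource


class MyDataset(Dataset):
    def _resources(self) -> List[OnlineResource]:
        # Placeholder URL and sha256.
        return [HttpResource("https://example.com/data.tar", sha256="0" * 64)]

    def _datapipe(self, resource_dps: List[IterDataPipe]) -> IterDataPipe[Dict[str, Any]]:
        # Each element of the archive datapipe is a (path, file handle) tuple.
        return Mapper(resource_dps[0], lambda data: dict(path=data[0]))

    def __len__(self) -> int:
        return 1_000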
from __future__ import annotations
import os
import sys
from typing import Any, BinaryIO, Optional, Tuple, Type, TypeVar, Union
import PIL.Image
import torch
from torchvision.prototype.utils._internal import fromfile, ReadOnlyTensorBuffer
from torchvision.tv_tensors._tv_tensor import TVTensor
D = TypeVar("D", bound="EncodedData")
class EncodedData(TVTensor):
@classmethod
def _wrap(cls: Type[D], tensor: torch.Tensor) -> D:
return tensor.as_subclass(cls)
def __new__(
cls,
data: Any,
*,
dtype: Optional[torch.dtype] = None,
device: Optional[Union[torch.device, str, int]] = None,
requires_grad: bool = False,
) -> EncodedData:
tensor = cls._to_tensor(data, dtype=dtype, device=device, requires_grad=requires_grad)
# TODO: warn / bail out if we encounter a tensor with shape other than (N,) or with dtype other than uint8?
return cls._wrap(tensor)
@classmethod
def wrap_like(cls: Type[D], other: D, tensor: torch.Tensor) -> D:
return cls._wrap(tensor)
@classmethod
def from_file(cls: Type[D], file: BinaryIO, **kwargs: Any) -> D:
encoded_data = cls(fromfile(file, dtype=torch.uint8, byte_order=sys.byteorder), **kwargs)
file.close()
return encoded_data
@classmethod
def from_path(cls: Type[D], path: Union[str, os.PathLike], **kwargs: Any) -> D:
with open(path, "rb") as file:
return cls.from_file(file, **kwargs)
class EncodedImage(EncodedData):
# TODO: Use @functools.cached_property if we can depend on Python 3.8
@property
def spatial_size(self) -> Tuple[int, int]:
if not hasattr(self, "_spatial_size"):
with PIL.Image.open(ReadOnlyTensorBuffer(self)) as image:
self._spatial_size = image.height, image.width
return self._spatial_size
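A small usage sketch for the encoded-data helpers above; "/data/example.jpg" is a placeholder path:

from torchvision.prototype.datasets.utils import EncodedImage

img = EncodedImage.from_path("/data/example.jpg")
print(img.dtype, img.shape)  # torch.uint8, one-dimensional buffer of encoded bytes
print(img.spatial_size)      # (height, width), determined lazily via PIL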
import csv
import functools
import pathlib
import pickle
from typing import Any, BinaryIO, Callable, Dict, IO, Iterator, List, Sequence, Sized, Tuple, TypeVar, Union
import torch
import torch.distributed as dist
import torch.utils.data
from torchdata.datapipes.iter import IoPathFileLister, IoPathFileOpener, IterDataPipe, ShardingFilter, Shuffler
from torchvision.prototype.utils._internal import fromfile
__all__ = [
"INFINITE_BUFFER_SIZE",
"BUILTIN_DIR",
"read_mat",
"MappingIterator",
"getitem",
"path_accessor",
"path_comparator",
"read_flo",
"hint_sharding",
"hint_shuffling",
]
K = TypeVar("K")
D = TypeVar("D")
# pseudo-infinite until a true infinite buffer is supported by all datapipes
INFINITE_BUFFER_SIZE = 1_000_000_000
BUILTIN_DIR = pathlib.Path(__file__).parent.parent / "_builtin"
def read_mat(buffer: BinaryIO, **kwargs: Any) -> Any:
try:
import scipy.io as sio
except ImportError as error:
raise ModuleNotFoundError("Package `scipy` is required to be installed to read .mat files.") from error
data = sio.loadmat(buffer, **kwargs)
buffer.close()
return data
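# Hedged usage sketch (added; not part of the original file): loading a MATLAB
# .mat file through `read_mat`. The path is a placeholder and scipy has to be
# installed.
with open("annotations.mat", "rb") as buffer:
    annotations = read_mat(buffer, squeeze_me=True)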
class MappingIterator(IterDataPipe[Union[Tuple[K, D], D]]):
def __init__(self, datapipe: IterDataPipe[Dict[K, D]], *, drop_key: bool = False) -> None:
self.datapipe = datapipe
self.drop_key = drop_key
def __iter__(self) -> Iterator[Union[Tuple[K, D], D]]:
for mapping in self.datapipe:
yield from iter(mapping.values() if self.drop_key else mapping.items())
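# Hedged illustration (added; not part of the original file): `MappingIterator`
# flattens a datapipe of mappings into a stream of (key, value) pairs, or just
# the values when drop_key=True.
from torchdata.datapipes.iter import IterableWrapper

assert list(MappingIterator(IterableWrapper([{"a": 1, "b": 2}]))) == [("a", 1), ("b", 2)]
assert list(MappingIterator(IterableWrapper([{"a": 1, "b": 2}]), drop_key=True)) == [1, 2]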
def _getitem_closure(obj: Any, *, items: Sequence[Any]) -> Any:
for item in items:
obj = obj[item]
return obj
def getitem(*items: Any) -> Callable[[Any], Any]:
return functools.partial(_getitem_closure, items=items)
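# Hedged illustration (added; not part of the original file): `getitem` builds
# a picklable callable that indexes into nested containers, typically used as a
# `Mapper` function.
assert getitem("annotations", 0, "label")({"annotations": [{"label": 3}]}) == 3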
def _getattr_closure(obj: Any, *, attrs: Sequence[str]) -> Any:
for attr in attrs:
obj = getattr(obj, attr)
return obj
def _path_attribute_accessor(path: pathlib.Path, *, name: str) -> Any:
return _getattr_closure(path, attrs=name.split("."))
def _path_accessor_closure(data: Tuple[str, Any], *, getter: Callable[[pathlib.Path], D]) -> D:
return getter(pathlib.Path(data[0]))
def path_accessor(getter: Union[str, Callable[[pathlib.Path], D]]) -> Callable[[Tuple[str, Any]], D]:
if isinstance(getter, str):
getter = functools.partial(_path_attribute_accessor, name=getter)
return functools.partial(_path_accessor_closure, getter=getter)
def _path_comparator_closure(data: Tuple[str, Any], *, accessor: Callable[[Tuple[str, Any]], D], value: D) -> bool:
return accessor(data) == value
def path_comparator(getter: Union[str, Callable[[pathlib.Path], D]], value: D) -> Callable[[Tuple[str, Any]], bool]:
return functools.partial(_path_comparator_closure, accessor=path_accessor(getter), value=value)
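# Hedged illustration (added; not part of the original file): `path_accessor`
# and `path_comparator` operate on the (path, handle) tuples yielded by file
# datapipes; the handle is irrelevant here, so None stands in for it.
data = ("archive/images/cat_001.png", None)
assert path_accessor("suffix")(data) == ".png"
assert path_comparator("parent.name", "images")(data)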
class PicklerDataPipe(IterDataPipe):
def __init__(self, source_datapipe: IterDataPipe[Tuple[str, IO[bytes]]]) -> None:
self.source_datapipe = source_datapipe
def __iter__(self) -> Iterator[Any]:
for _, fobj in self.source_datapipe:
data = pickle.load(fobj)
            yield from data
class SharderDataPipe(ShardingFilter):
def __init__(self, source_datapipe: IterDataPipe) -> None:
super().__init__(source_datapipe)
self.rank = 0
self.world_size = 1
if dist.is_available() and dist.is_initialized():
self.rank = dist.get_rank()
self.world_size = dist.get_world_size()
self.apply_sharding(self.world_size, self.rank)
def __iter__(self) -> Iterator[Any]:
num_workers = self.world_size
worker_id = self.rank
worker_info = torch.utils.data.get_worker_info()
if worker_info is not None:
worker_id = worker_id + worker_info.id * num_workers
num_workers *= worker_info.num_workers
self.apply_sharding(num_workers, worker_id)
yield from super().__iter__()
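# Added note (not in the original): sharding combines the distributed rank with
# the dataloader worker id. For example, with 2 ranks and 3 workers per rank
# the effective shard count is 2 * 3 = 6, and rank 1 / local worker 2 reads
# shard index 1 + 2 * 2 = 5.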
class TakerDataPipe(IterDataPipe):
def __init__(self, source_datapipe: IterDataPipe, num_take: int) -> None:
super().__init__()
self.source_datapipe = source_datapipe
self.num_take = num_take
self.world_size = 1
if dist.is_available() and dist.is_initialized():
self.world_size = dist.get_world_size()
def __iter__(self) -> Iterator[Any]:
num_workers = self.world_size
worker_info = torch.utils.data.get_worker_info()
if worker_info is not None:
num_workers *= worker_info.num_workers
# TODO: this is weird as it drops more elements than it should
num_take = self.num_take // num_workers
for i, data in enumerate(self.source_datapipe):
if i < num_take:
yield data
else:
break
def __len__(self) -> int:
num_take = self.num_take // self.world_size
if isinstance(self.source_datapipe, Sized):
if len(self.source_datapipe) < num_take:
num_take = len(self.source_datapipe)
# TODO: might be weird to not take `num_workers` into account
return num_take
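# Added note (not in the original): as the TODOs above hint, `num_take` is split
# by integer division. With num_take=100, 2 ranks, and 4 dataloader workers per
# rank, every worker yields 100 // 8 = 12 samples, i.e. 96 in total instead of
# 100, while __len__ reports 100 // 2 = 50 per rank.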
def _make_sharded_datapipe(root: str, dataset_size: int) -> IterDataPipe[Dict[str, Any]]:
dp = IoPathFileLister(root=root)
dp = SharderDataPipe(dp)
dp = dp.shuffle(buffer_size=INFINITE_BUFFER_SIZE)
dp = IoPathFileOpener(dp, mode="rb")
dp = PicklerDataPipe(dp)
# dp = dp.cycle(2)
dp = TakerDataPipe(dp, dataset_size)
return dp
def read_flo(file: BinaryIO) -> torch.Tensor:
if file.read(4) != b"PIEH":
raise ValueError("Magic number incorrect. Invalid .flo file")
width, height = fromfile(file, dtype=torch.int32, byte_order="little", count=2)
flow = fromfile(file, dtype=torch.float32, byte_order="little", count=height * width * 2)
return flow.reshape((height, width, 2)).permute((2, 0, 1))
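# Hedged usage sketch (added; not part of the original file): reading a
# Middlebury-style .flo optical flow file into a (2, height, width) float32
# tensor. The path is a placeholder.
with open("frame_0001.flo", "rb") as file:
    flow = read_flo(file)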
def hint_sharding(datapipe: IterDataPipe) -> ShardingFilter:
return ShardingFilter(datapipe)
def hint_shuffling(datapipe: IterDataPipe[D]) -> Shuffler[D]:
return Shuffler(datapipe, buffer_size=INFINITE_BUFFER_SIZE).set_shuffle(False)
def read_categories_file(name: str) -> List[Union[str, Sequence[str]]]:
path = BUILTIN_DIR / f"{name}.categories"
with open(path, newline="") as file:
rows = list(csv.reader(file))
rows = [row[0] if len(row) == 1 else row for row in rows]
return rows
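# Hedged illustration (added; not part of the original file): the
# "<name>.categories" files are plain CSV. A hypothetical file whose rows are
# the single columns "airplane" and "automobile" would be returned as
# ["airplane", "automobile"], while a two-column row such as "n01440764,tench"
# would be kept as the list ["n01440764", "tench"].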
import abc
import hashlib
import itertools
import pathlib
from typing import Any, Callable, IO, Literal, NoReturn, Optional, Sequence, Set, Tuple, Union
from urllib.parse import urlparse
from torchdata.datapipes.iter import (
FileLister,
FileOpener,
IterableWrapper,
IterDataPipe,
RarArchiveLoader,
TarArchiveLoader,
ZipArchiveLoader,
)
from torchvision.datasets.utils import (
_decompress,
_detect_file_type,
_get_google_drive_file_id,
_get_redirect_url,
download_file_from_google_drive,
download_url,
extract_archive,
)
class OnlineResource(abc.ABC):
def __init__(
self,
*,
file_name: str,
sha256: Optional[str] = None,
preprocess: Optional[Union[Literal["decompress", "extract"], Callable[[pathlib.Path], None]]] = None,
) -> None:
self.file_name = file_name
self.sha256 = sha256
if isinstance(preprocess, str):
if preprocess == "decompress":
preprocess = self._decompress
elif preprocess == "extract":
preprocess = self._extract
else:
                raise ValueError(
                    f"Only `'decompress'` or `'extract'` are valid if `preprocess` is passed as a string, "
                    f"but got {preprocess} instead."
                )
self._preprocess = preprocess
@staticmethod
def _extract(file: pathlib.Path) -> None:
extract_archive(str(file), to_path=str(file).replace("".join(file.suffixes), ""), remove_finished=False)
@staticmethod
def _decompress(file: pathlib.Path) -> None:
_decompress(str(file), remove_finished=True)
def _loader(self, path: pathlib.Path) -> IterDataPipe[Tuple[str, IO]]:
if path.is_dir():
return FileOpener(FileLister(str(path), recursive=True), mode="rb")
dp = FileOpener(IterableWrapper((str(path),)), mode="rb")
archive_loader = self._guess_archive_loader(path)
if archive_loader:
dp = archive_loader(dp)
return dp
_ARCHIVE_LOADERS = {
".tar": TarArchiveLoader,
".zip": ZipArchiveLoader,
".rar": RarArchiveLoader,
}
def _guess_archive_loader(
self, path: pathlib.Path
) -> Optional[Callable[[IterDataPipe[Tuple[str, IO]]], IterDataPipe[Tuple[str, IO]]]]:
try:
_, archive_type, _ = _detect_file_type(path.name)
except RuntimeError:
return None
return self._ARCHIVE_LOADERS.get(archive_type) # type: ignore[arg-type]
def load(
self, root: Union[str, pathlib.Path], *, skip_integrity_check: bool = False
) -> IterDataPipe[Tuple[str, IO]]:
root = pathlib.Path(root)
path = root / self.file_name
# Instead of the raw file, there might also be files with fewer suffixes after decompression or directories
# with no suffixes at all. `pathlib.Path().stem` will only give us the name with the last suffix removed, which
# is not sufficient for files with multiple suffixes, e.g. foo.tar.gz.
stem = path.name.replace("".join(path.suffixes), "")
def find_candidates() -> Set[pathlib.Path]:
# Although it looks like we could glob for f"{stem}*" to find the file candidates as well as the folder
# candidate simultaneously, that would also pick up other files that share the same prefix. For example, the
# test split of the stanford-cars dataset uses the files
# - cars_test.tgz
# - cars_test_annos_withlabels.mat
# Globbing for `"cars_test*"` picks up both.
candidates = {file for file in path.parent.glob(f"{stem}.*")}
folder_candidate = path.parent / stem
if folder_candidate.exists():
candidates.add(folder_candidate)
return candidates
candidates = find_candidates()
if not candidates:
self.download(root, skip_integrity_check=skip_integrity_check)
if self._preprocess is not None:
self._preprocess(path)
candidates = find_candidates()
# We use the path with the fewest suffixes. This gives us the
# extracted > decompressed > raw
# priority that we want for the best I/O performance.
return self._loader(min(candidates, key=lambda candidate: len(candidate.suffixes)))
@abc.abstractmethod
def _download(self, root: pathlib.Path) -> None:
pass
def download(self, root: Union[str, pathlib.Path], *, skip_integrity_check: bool = False) -> pathlib.Path:
root = pathlib.Path(root)
self._download(root)
path = root / self.file_name
if self.sha256 and not skip_integrity_check:
self._check_sha256(path)
return path
def _check_sha256(self, path: pathlib.Path, *, chunk_size: int = 1024 * 1024) -> None:
hash = hashlib.sha256()
with open(path, "rb") as file:
while chunk := file.read(chunk_size):
hash.update(chunk)
sha256 = hash.hexdigest()
if sha256 != self.sha256:
raise RuntimeError(
f"After the download, the SHA256 checksum of {path} didn't match the expected one: "
f"{sha256} != {self.sha256}"
)
class HttpResource(OnlineResource):
def __init__(
self, url: str, *, file_name: Optional[str] = None, mirrors: Sequence[str] = (), **kwargs: Any
) -> None:
super().__init__(file_name=file_name or pathlib.Path(urlparse(url).path).name, **kwargs)
self.url = url
self.mirrors = mirrors
self._resolved = False
def resolve(self) -> OnlineResource:
if self._resolved:
return self
redirect_url = _get_redirect_url(self.url)
if redirect_url == self.url:
self._resolved = True
return self
meta = {
attr.lstrip("_"): getattr(self, attr)
for attr in (
"file_name",
"sha256",
"_preprocess",
)
}
gdrive_id = _get_google_drive_file_id(redirect_url)
if gdrive_id:
return GDriveResource(gdrive_id, **meta)
http_resource = HttpResource(redirect_url, **meta)
http_resource._resolved = True
return http_resource
def _download(self, root: pathlib.Path) -> None:
if not self._resolved:
return self.resolve()._download(root)
for url in itertools.chain((self.url,), self.mirrors):
try:
download_url(url, str(root), filename=self.file_name, md5=None)
# TODO: make this more precise
except Exception:
continue
return
else:
# TODO: make this more informative
raise RuntimeError("Download failed!")
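# --- Hedged usage sketch (added; not part of the original file) -------------
# Downloading, optionally verifying, extracting, and loading a resource. The
# URL and checksum are placeholders; `load` yields (path, file handle) pairs
# from the candidate with the fewest suffixes, i.e. the extracted folder if it
# exists.
resource = HttpResource(
    "https://example.com/annotations.tar.gz",
    sha256="0" * 64,  # placeholder checksum
    preprocess="extract",
)
for path, handle in resource.load("datasets/example", skip_integrity_check=True):
    ...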
class GDriveResource(OnlineResource):
def __init__(self, id: str, **kwargs: Any) -> None:
super().__init__(**kwargs)
self.id = id
def _download(self, root: pathlib.Path) -> None:
download_file_from_google_drive(self.id, root=str(root), filename=self.file_name, md5=None)
class ManualDownloadResource(OnlineResource):
def __init__(self, instructions: str, **kwargs: Any) -> None:
super().__init__(**kwargs)
self.instructions = instructions
def _download(self, root: pathlib.Path) -> NoReturn:
raise RuntimeError(
f"The file {self.file_name} cannot be downloaded automatically. "
f"Please follow the instructions below and place it in {root}\n\n"
f"{self.instructions}"
)
class KaggleDownloadResource(ManualDownloadResource):
def __init__(self, challenge_url: str, *, file_name: str, **kwargs: Any) -> None:
instructions = "\n".join(
(
"1. Register and login at https://www.kaggle.com",
f"2. Navigate to {challenge_url}",
"3. Click 'Join Competition' and follow the instructions there",
"4. Navigate to the 'Data' tab",
f"5. Select {file_name} in the 'Data Explorer' and click the download button",
)
)
super().__init__(instructions, file_name=file_name, **kwargs)
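# --- Hedged usage sketch (added; not part of the original file) -------------
# Kaggle-hosted files cannot be fetched automatically: loading a missing
# `KaggleDownloadResource` raises a RuntimeError that prints the manual
# download steps assembled above. The challenge URL and file name are
# placeholders.
resource = KaggleDownloadResource(
    "https://www.kaggle.com/c/some-challenge",
    file_name="train.zip",
    preprocess="extract",
)
# resource.load("datasets/some-challenge")  # raises RuntimeError while train.zip is missing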