Unverified commit 48b1edff authored by Nicolas Hug, committed by GitHub

Remove prototype area for 0.19 (#8491)

parent f44f20cf
apple
aquarium_fish
baby
bear
beaver
bed
bee
beetle
bicycle
bottle
bowl
boy
bridge
bus
butterfly
camel
can
castle
caterpillar
cattle
chair
chimpanzee
clock
cloud
cockroach
couch
crab
crocodile
cup
dinosaur
dolphin
elephant
flatfish
forest
fox
girl
hamster
house
kangaroo
keyboard
lamp
lawn_mower
leopard
lion
lizard
lobster
man
maple_tree
motorcycle
mountain
mouse
mushroom
oak_tree
orange
orchid
otter
palm_tree
pear
pickup_truck
pine_tree
plain
plate
poppy
porcupine
possum
rabbit
raccoon
ray
road
rocket
rose
sea
seal
shark
shrew
skunk
skyscraper
snail
snake
spider
squirrel
streetcar
sunflower
sweet_pepper
table
tank
telephone
television
tiger
tractor
train
trout
tulip
turtle
wardrobe
whale
willow_tree
wolf
woman
worm
import pathlib
from typing import Any, BinaryIO, Dict, List, Optional, Tuple, Union
from torchdata.datapipes.iter import Demultiplexer, Filter, IterDataPipe, IterKeyZipper, JsonParser, Mapper, UnBatcher
from torchvision.prototype.datasets.utils import Dataset, EncodedImage, HttpResource, OnlineResource
from torchvision.prototype.datasets.utils._internal import (
getitem,
hint_sharding,
hint_shuffling,
INFINITE_BUFFER_SIZE,
path_accessor,
path_comparator,
)
from torchvision.prototype.tv_tensors import Label
from .._api import register_dataset, register_info
NAME = "clevr"
@register_info(NAME)
def _info() -> Dict[str, Any]:
return dict()
@register_dataset(NAME)
class CLEVR(Dataset):
"""
- **homepage**: https://cs.stanford.edu/people/jcjohns/clevr/
"""
def __init__(
self, root: Union[str, pathlib.Path], *, split: str = "train", skip_integrity_check: bool = False
) -> None:
self._split = self._verify_str_arg(split, "split", ("train", "val", "test"))
super().__init__(root, skip_integrity_check=skip_integrity_check)
def _resources(self) -> List[OnlineResource]:
archive = HttpResource(
"https://dl.fbaipublicfiles.com/clevr/CLEVR_v1.0.zip",
sha256="5cd61cf1096ed20944df93c9adb31e74d189b8459a94f54ba00090e5c59936d1",
)
return [archive]
def _classify_archive(self, data: Tuple[str, Any]) -> Optional[int]:
path = pathlib.Path(data[0])
if path.parents[1].name == "images":
return 0
elif path.parent.name == "scenes":
return 1
else:
return None
def _filter_scene_anns(self, data: Tuple[str, Any]) -> bool:
key, _ = data
return key == "scenes"
def _add_empty_anns(self, data: Tuple[str, BinaryIO]) -> Tuple[Tuple[str, BinaryIO], None]:
return data, None
def _prepare_sample(self, data: Tuple[Tuple[str, BinaryIO], Optional[Dict[str, Any]]]) -> Dict[str, Any]:
image_data, scenes_data = data
path, buffer = image_data
return dict(
path=path,
image=EncodedImage.from_file(buffer),
label=Label(len(scenes_data["objects"])) if scenes_data else None,
)
def _datapipe(self, resource_dps: List[IterDataPipe]) -> IterDataPipe[Dict[str, Any]]:
archive_dp = resource_dps[0]
images_dp, scenes_dp = Demultiplexer(
archive_dp,
2,
self._classify_archive,
drop_none=True,
buffer_size=INFINITE_BUFFER_SIZE,
)
images_dp = Filter(images_dp, path_comparator("parent.name", self._split))
images_dp = hint_shuffling(images_dp)
images_dp = hint_sharding(images_dp)
if self._split != "test":
scenes_dp = Filter(scenes_dp, path_comparator("name", f"CLEVR_{self._split}_scenes.json"))
scenes_dp = JsonParser(scenes_dp)
scenes_dp = Mapper(scenes_dp, getitem(1, "scenes"))
scenes_dp = UnBatcher(scenes_dp)
dp = IterKeyZipper(
images_dp,
scenes_dp,
key_fn=path_accessor("name"),
ref_key_fn=getitem("image_filename"),
buffer_size=INFINITE_BUFFER_SIZE,
)
else:
for _, file in scenes_dp:
file.close()
dp = Mapper(images_dp, self._add_empty_anns)
return Mapper(dp, self._prepare_sample)
def __len__(self) -> int:
return 70_000 if self._split == "train" else 15_000
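For reference, a minimal usage sketch of the prototype CLEVR dataset defined above. It assumes the prototype entry point torchvision.prototype.datasets.load(), which this PR removes together with these files; the sample field names follow _prepare_sample.

# Hedged sketch, not part of the removed file: load the registered "clevr" dataset
# through the prototype API (an assumption; removed by this PR) and inspect one sample.
from torchvision.prototype import datasets

clevr = datasets.load("clevr", split="val")
sample = next(iter(clevr))
print(sample["path"])   # path of the encoded image inside CLEVR_v1.0.zip
print(sample["label"])  # number of objects in the scene; None for the test split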
__background__,N/A
person,person
bicycle,vehicle
car,vehicle
motorcycle,vehicle
airplane,vehicle
bus,vehicle
train,vehicle
truck,vehicle
boat,vehicle
traffic light,outdoor
fire hydrant,outdoor
N/A,N/A
stop sign,outdoor
parking meter,outdoor
bench,outdoor
bird,animal
cat,animal
dog,animal
horse,animal
sheep,animal
cow,animal
elephant,animal
bear,animal
zebra,animal
giraffe,animal
N/A,N/A
backpack,accessory
umbrella,accessory
N/A,N/A
N/A,N/A
handbag,accessory
tie,accessory
suitcase,accessory
frisbee,sports
skis,sports
snowboard,sports
sports ball,sports
kite,sports
baseball bat,sports
baseball glove,sports
skateboard,sports
surfboard,sports
tennis racket,sports
bottle,kitchen
N/A,N/A
wine glass,kitchen
cup,kitchen
fork,kitchen
knife,kitchen
spoon,kitchen
bowl,kitchen
banana,food
apple,food
sandwich,food
orange,food
broccoli,food
carrot,food
hot dog,food
pizza,food
donut,food
cake,food
chair,furniture
couch,furniture
potted plant,furniture
bed,furniture
N/A,N/A
dining table,furniture
N/A,N/A
N/A,N/A
toilet,furniture
N/A,N/A
tv,electronic
laptop,electronic
mouse,electronic
remote,electronic
keyboard,electronic
cell phone,electronic
microwave,appliance
oven,appliance
toaster,appliance
sink,appliance
refrigerator,appliance
N/A,N/A
book,indoor
clock,indoor
vase,indoor
scissors,indoor
teddy bear,indoor
hair drier,indoor
toothbrush,indoor
import pathlib
import re
from collections import defaultdict, OrderedDict
from typing import Any, BinaryIO, cast, Dict, List, Optional, Tuple, Union
import torch
from torchdata.datapipes.iter import (
Demultiplexer,
Filter,
Grouper,
IterDataPipe,
IterKeyZipper,
JsonParser,
Mapper,
UnBatcher,
)
from torchvision.prototype.datasets.utils import Dataset, EncodedImage, HttpResource, OnlineResource
from torchvision.prototype.datasets.utils._internal import (
getitem,
hint_sharding,
hint_shuffling,
INFINITE_BUFFER_SIZE,
MappingIterator,
path_accessor,
read_categories_file,
)
from torchvision.prototype.tv_tensors import Label
from torchvision.tv_tensors import BoundingBoxes, Mask
from .._api import register_dataset, register_info
NAME = "coco"
@register_info(NAME)
def _info() -> Dict[str, Any]:
categories, super_categories = zip(*read_categories_file(NAME))
return dict(categories=categories, super_categories=super_categories)
@register_dataset(NAME)
class Coco(Dataset):
"""
- **homepage**: https://cocodataset.org/
- **dependencies**:
- `pycocotools <https://github.com/cocodataset/cocoapi>`_
"""
def __init__(
self,
root: Union[str, pathlib.Path],
*,
split: str = "train",
year: str = "2017",
annotations: Optional[str] = "instances",
skip_integrity_check: bool = False,
) -> None:
self._split = self._verify_str_arg(split, "split", {"train", "val"})
self._year = self._verify_str_arg(year, "year", {"2017", "2014"})
self._annotations = (
self._verify_str_arg(annotations, "annotations", self._ANN_DECODERS.keys())
if annotations is not None
else None
)
info = _info()
categories, super_categories = info["categories"], info["super_categories"]
self._categories = categories
self._category_to_super_category = dict(zip(categories, super_categories))
super().__init__(root, dependencies=("pycocotools",), skip_integrity_check=skip_integrity_check)
_IMAGE_URL_BASE = "http://images.cocodataset.org/zips"
_IMAGES_CHECKSUMS = {
("2014", "train"): "ede4087e640bddba550e090eae701092534b554b42b05ac33f0300b984b31775",
("2014", "val"): "fe9be816052049c34717e077d9e34aa60814a55679f804cd043e3cbee3b9fde0",
("2017", "train"): "69a8bb58ea5f8f99d24875f21416de2e9ded3178e903f1f7603e283b9e06d929",
("2017", "val"): "4f7e2ccb2866ec5041993c9cf2a952bbed69647b115d0f74da7ce8f4bef82f05",
}
_META_URL_BASE = "http://images.cocodataset.org/annotations"
_META_CHECKSUMS = {
"2014": "031296bbc80c45a1d1f76bf9a90ead27e94e99ec629208449507a4917a3bf009",
"2017": "113a836d90195ee1f884e704da6304dfaaecff1f023f49b6ca93c4aaae470268",
}
def _resources(self) -> List[OnlineResource]:
images = HttpResource(
f"{self._IMAGE_URL_BASE}/{self._split}{self._year}.zip",
sha256=self._IMAGES_CHECKSUMS[(self._year, self._split)],
)
meta = HttpResource(
f"{self._META_URL_BASE}/annotations_trainval{self._year}.zip",
sha256=self._META_CHECKSUMS[self._year],
)
return [images, meta]
def _segmentation_to_mask(
self, segmentation: Any, *, is_crowd: bool, spatial_size: Tuple[int, int]
) -> torch.Tensor:
from pycocotools import mask
if is_crowd:
segmentation = mask.frPyObjects(segmentation, *spatial_size)
else:
segmentation = mask.merge(mask.frPyObjects(segmentation, *spatial_size))
return torch.from_numpy(mask.decode(segmentation)).to(torch.bool)
def _decode_instances_anns(self, anns: List[Dict[str, Any]], image_meta: Dict[str, Any]) -> Dict[str, Any]:
spatial_size = (image_meta["height"], image_meta["width"])
labels = [ann["category_id"] for ann in anns]
return dict(
segmentations=Mask(
torch.stack(
[
self._segmentation_to_mask(
ann["segmentation"], is_crowd=ann["iscrowd"], spatial_size=spatial_size
)
for ann in anns
]
)
),
areas=torch.as_tensor([ann["area"] for ann in anns]),
crowds=torch.as_tensor([ann["iscrowd"] for ann in anns], dtype=torch.bool),
bounding_boxes=BoundingBoxes(
[ann["bbox"] for ann in anns],
format="xywh",
spatial_size=spatial_size,
),
labels=Label(labels, categories=self._categories),
super_categories=[self._category_to_super_category[self._categories[label]] for label in labels],
ann_ids=[ann["id"] for ann in anns],
)
def _decode_captions_ann(self, anns: List[Dict[str, Any]], image_meta: Dict[str, Any]) -> Dict[str, Any]:
return dict(
captions=[ann["caption"] for ann in anns],
ann_ids=[ann["id"] for ann in anns],
)
_ANN_DECODERS = OrderedDict(
[
("instances", _decode_instances_anns),
("captions", _decode_captions_ann),
]
)
_META_FILE_PATTERN = re.compile(
rf"(?P<annotations>({'|'.join(_ANN_DECODERS.keys())}))_(?P<split>[a-zA-Z]+)(?P<year>\d+)[.]json"
)
def _filter_meta_files(self, data: Tuple[str, Any]) -> bool:
match = self._META_FILE_PATTERN.match(pathlib.Path(data[0]).name)
return bool(
match
and match["split"] == self._split
and match["year"] == self._year
and match["annotations"] == self._annotations
)
def _classify_meta(self, data: Tuple[str, Any]) -> Optional[int]:
key, _ = data
if key == "images":
return 0
elif key == "annotations":
return 1
else:
return None
def _prepare_image(self, data: Tuple[str, BinaryIO]) -> Dict[str, Any]:
path, buffer = data
return dict(
path=path,
image=EncodedImage.from_file(buffer),
)
def _prepare_sample(
self,
data: Tuple[Tuple[List[Dict[str, Any]], Dict[str, Any]], Tuple[str, BinaryIO]],
) -> Dict[str, Any]:
ann_data, image_data = data
anns, image_meta = ann_data
sample = self._prepare_image(image_data)
# this method is only called if we have annotations
annotations = cast(str, self._annotations)
sample.update(self._ANN_DECODERS[annotations](self, anns, image_meta))
return sample
def _datapipe(self, resource_dps: List[IterDataPipe]) -> IterDataPipe[Dict[str, Any]]:
images_dp, meta_dp = resource_dps
if self._annotations is None:
dp = hint_shuffling(images_dp)
dp = hint_sharding(dp)
return Mapper(dp, self._prepare_image)
meta_dp = Filter(meta_dp, self._filter_meta_files)
meta_dp = JsonParser(meta_dp)
meta_dp = Mapper(meta_dp, getitem(1))
meta_dp: IterDataPipe[Dict[str, Dict[str, Any]]] = MappingIterator(meta_dp)
images_meta_dp, anns_meta_dp = Demultiplexer(
meta_dp,
2,
self._classify_meta,
drop_none=True,
buffer_size=INFINITE_BUFFER_SIZE,
)
images_meta_dp = Mapper(images_meta_dp, getitem(1))
images_meta_dp = UnBatcher(images_meta_dp)
anns_meta_dp = Mapper(anns_meta_dp, getitem(1))
anns_meta_dp = UnBatcher(anns_meta_dp)
anns_meta_dp = Grouper(anns_meta_dp, group_key_fn=getitem("image_id"), buffer_size=INFINITE_BUFFER_SIZE)
anns_meta_dp = hint_shuffling(anns_meta_dp)
anns_meta_dp = hint_sharding(anns_meta_dp)
anns_dp = IterKeyZipper(
anns_meta_dp,
images_meta_dp,
key_fn=getitem(0, "image_id"),
ref_key_fn=getitem("id"),
buffer_size=INFINITE_BUFFER_SIZE,
)
dp = IterKeyZipper(
anns_dp,
images_dp,
key_fn=getitem(1, "file_name"),
ref_key_fn=path_accessor("name"),
buffer_size=INFINITE_BUFFER_SIZE,
)
return Mapper(dp, self._prepare_sample)
def __len__(self) -> int:
return {
("train", "2017"): defaultdict(lambda: 118_287, instances=117_266),
("train", "2014"): defaultdict(lambda: 82_783, instances=82_081),
("val", "2017"): defaultdict(lambda: 5_000, instances=4_952),
("val", "2014"): defaultdict(lambda: 40_504, instances=40_137),
}[(self._split, self._year)][
self._annotations # type: ignore[index]
]
def _generate_categories(self) -> Tuple[Tuple[str, str]]:
self._annotations = "instances"
resources = self._resources()
dp = resources[1].load(self._root)
dp = Filter(dp, self._filter_meta_files)
dp = JsonParser(dp)
_, meta = next(iter(dp))
# List[Tuple[super_category, id, category]]
label_data = [cast(Tuple[str, int, str], tuple(info.values())) for info in meta["categories"]]
# COCO actually defines 91 categories, but only 80 of them have instances. Still, the category_id refers to the
# full set. To keep the labels dense, we fill the gaps with N/A. The ids present in the annotation files only run
# from 1 to 90, so there are just 10 gaps and the total number of categories is 90 rather than 91.
_, ids, _ = zip(*label_data)
missing_ids = set(range(1, max(ids) + 1)) - set(ids)
label_data.extend([("N/A", id, "N/A") for id in missing_ids])
# We also add a background category to be used during segmentation.
label_data.append(("N/A", 0, "__background__"))
super_categories, _, categories = zip(*sorted(label_data, key=lambda info: info[1]))
return cast(Tuple[Tuple[str, str]], tuple(zip(categories, super_categories)))
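The gap-filling step in _generate_categories above is easiest to see on a toy example. The id set below is illustrative, not real COCO metadata; the numbers in the comment come from the categories file earlier in this diff (10 "N/A" rows plus "__background__").

# Hedged sketch, not part of the removed file: fill missing ids with N/A, as in
# _generate_categories. In the real COCO annotation files the ids run from 1 to 90
# with 10 of them unused, which is why the generated file has exactly 10 N/A rows.
present_ids = {1, 2, 3, 6, 7}                                  # illustrative ids
missing_ids = set(range(1, max(present_ids) + 1)) - present_ids
print(sorted(missing_ids))                                     # -> [4, 5]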
AD
AE
AF
AG
AI
AL
AM
AO
AQ
AR
AT
AU
AW
AX
AZ
BA
BB
BD
BE
BF
BG
BH
BJ
BM
BN
BO
BQ
BR
BS
BT
BW
BY
BZ
CA
CD
CF
CH
CI
CK
CL
CM
CN
CO
CR
CU
CV
CW
CY
CZ
DE
DK
DM
DO
DZ
EC
EE
EG
ES
ET
FI
FJ
FK
FO
FR
GA
GB
GD
GE
GF
GG
GH
GI
GL
GM
GP
GR
GS
GT
GU
GY
HK
HN
HR
HT
HU
ID
IE
IL
IM
IN
IQ
IR
IS
IT
JE
JM
JO
JP
KE
KG
KH
KN
KP
KR
KW
KY
KZ
LA
LB
LC
LI
LK
LR
LT
LU
LV
LY
MA
MC
MD
ME
MF
MG
MK
ML
MM
MN
MO
MQ
MR
MT
MU
MV
MW
MX
MY
MZ
NA
NC
NG
NI
NL
NO
NP
NZ
OM
PA
PE
PF
PG
PH
PK
PL
PR
PS
PT
PW
PY
QA
RE
RO
RS
RU
RW
SA
SB
SC
SD
SE
SG
SH
SI
SJ
SK
SL
SM
SN
SO
SS
SV
SX
SY
SZ
TG
TH
TJ
TL
TM
TN
TO
TR
TT
TW
TZ
UA
UG
US
UY
UZ
VA
VE
VG
VI
VN
VU
WS
XK
YE
ZA
ZM
ZW
import pathlib
from typing import Any, Dict, List, Tuple, Union
from torchdata.datapipes.iter import Filter, IterDataPipe, Mapper
from torchvision.prototype.datasets.utils import Dataset, EncodedImage, HttpResource, OnlineResource
from torchvision.prototype.datasets.utils._internal import (
hint_sharding,
hint_shuffling,
path_comparator,
read_categories_file,
)
from torchvision.prototype.tv_tensors import Label
from .._api import register_dataset, register_info
NAME = "country211"
@register_info(NAME)
def _info() -> Dict[str, Any]:
return dict(categories=read_categories_file(NAME))
@register_dataset(NAME)
class Country211(Dataset):
"""
- **homepage**: https://github.com/openai/CLIP/blob/main/data/country211.md
"""
def __init__(
self,
root: Union[str, pathlib.Path],
*,
split: str = "train",
skip_integrity_check: bool = False,
) -> None:
self._split = self._verify_str_arg(split, "split", ("train", "val", "test"))
self._split_folder_name = "valid" if split == "val" else split
self._categories = _info()["categories"]
super().__init__(root, skip_integrity_check=skip_integrity_check)
def _resources(self) -> List[OnlineResource]:
return [
HttpResource(
"https://openaipublic.azureedge.net/clip/data/country211.tgz",
sha256="c011343cdc1296a8c31ff1d7129cf0b5e5b8605462cffd24f89266d6e6f4da3c",
)
]
def _prepare_sample(self, data: Tuple[str, Any]) -> Dict[str, Any]:
path, buffer = data
category = pathlib.Path(path).parent.name
return dict(
label=Label.from_category(category, categories=self._categories),
path=path,
image=EncodedImage.from_file(buffer),
)
def _filter_split(self, data: Tuple[str, Any], *, split: str) -> bool:
return pathlib.Path(data[0]).parent.parent.name == split
def _datapipe(self, resource_dps: List[IterDataPipe]) -> IterDataPipe[Dict[str, Any]]:
dp = resource_dps[0]
dp = Filter(dp, path_comparator("parent.parent.name", self._split_folder_name))
dp = hint_shuffling(dp)
dp = hint_sharding(dp)
return Mapper(dp, self._prepare_sample)
def __len__(self) -> int:
return {
"train": 31_650,
"val": 10_550,
"test": 21_100,
}[self._split]
def _generate_categories(self) -> List[str]:
resources = self._resources()
dp = resources[0].load(self._root)
return sorted({pathlib.Path(path).parent.name for path, _ in dp})
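A small sketch of how a Country211 sample's category and split folder are read off the archive path (see _prepare_sample and _datapipe). The path below is illustrative; note that the on-disk folder for the "val" split is named "valid".

# Hedged sketch, not part of the removed file.
import pathlib

path = pathlib.Path("country211/valid/US/0001.jpg")  # illustrative archive layout
print(path.parent.name)          # -> US    (the category, i.e. the country code)
print(path.parent.parent.name)   # -> valid (compared against _split_folder_name)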
Black_footed_Albatross
Laysan_Albatross
Sooty_Albatross
Groove_billed_Ani
Crested_Auklet
Least_Auklet
Parakeet_Auklet
Rhinoceros_Auklet
Brewer_Blackbird
Red_winged_Blackbird
Rusty_Blackbird
Yellow_headed_Blackbird
Bobolink
Indigo_Bunting
Lazuli_Bunting
Painted_Bunting
Cardinal
Spotted_Catbird
Gray_Catbird
Yellow_breasted_Chat
Eastern_Towhee
Chuck_will_Widow
Brandt_Cormorant
Red_faced_Cormorant
Pelagic_Cormorant
Bronzed_Cowbird
Shiny_Cowbird
Brown_Creeper
American_Crow
Fish_Crow
Black_billed_Cuckoo
Mangrove_Cuckoo
Yellow_billed_Cuckoo
Gray_crowned_Rosy_Finch
Purple_Finch
Northern_Flicker
Acadian_Flycatcher
Great_Crested_Flycatcher
Least_Flycatcher
Olive_sided_Flycatcher
Scissor_tailed_Flycatcher
Vermilion_Flycatcher
Yellow_bellied_Flycatcher
Frigatebird
Northern_Fulmar
Gadwall
American_Goldfinch
European_Goldfinch
Boat_tailed_Grackle
Eared_Grebe
Horned_Grebe
Pied_billed_Grebe
Western_Grebe
Blue_Grosbeak
Evening_Grosbeak
Pine_Grosbeak
Rose_breasted_Grosbeak
Pigeon_Guillemot
California_Gull
Glaucous_winged_Gull
Heermann_Gull
Herring_Gull
Ivory_Gull
Ring_billed_Gull
Slaty_backed_Gull
Western_Gull
Anna_Hummingbird
Ruby_throated_Hummingbird
Rufous_Hummingbird
Green_Violetear
Long_tailed_Jaeger
Pomarine_Jaeger
Blue_Jay
Florida_Jay
Green_Jay
Dark_eyed_Junco
Tropical_Kingbird
Gray_Kingbird
Belted_Kingfisher
Green_Kingfisher
Pied_Kingfisher
Ringed_Kingfisher
White_breasted_Kingfisher
Red_legged_Kittiwake
Horned_Lark
Pacific_Loon
Mallard
Western_Meadowlark
Hooded_Merganser
Red_breasted_Merganser
Mockingbird
Nighthawk
Clark_Nutcracker
White_breasted_Nuthatch
Baltimore_Oriole
Hooded_Oriole
Orchard_Oriole
Scott_Oriole
Ovenbird
Brown_Pelican
White_Pelican
Western_Wood_Pewee
Sayornis
American_Pipit
Whip_poor_Will
Horned_Puffin
Common_Raven
White_necked_Raven
American_Redstart
Geococcyx
Loggerhead_Shrike
Great_Grey_Shrike
Baird_Sparrow
Black_throated_Sparrow
Brewer_Sparrow
Chipping_Sparrow
Clay_colored_Sparrow
House_Sparrow
Field_Sparrow
Fox_Sparrow
Grasshopper_Sparrow
Harris_Sparrow
Henslow_Sparrow
Le_Conte_Sparrow
Lincoln_Sparrow
Nelson_Sharp_tailed_Sparrow
Savannah_Sparrow
Seaside_Sparrow
Song_Sparrow
Tree_Sparrow
Vesper_Sparrow
White_crowned_Sparrow
White_throated_Sparrow
Cape_Glossy_Starling
Bank_Swallow
Barn_Swallow
Cliff_Swallow
Tree_Swallow
Scarlet_Tanager
Summer_Tanager
Artic_Tern
Black_Tern
Caspian_Tern
Common_Tern
Elegant_Tern
Forsters_Tern
Least_Tern
Green_tailed_Towhee
Brown_Thrasher
Sage_Thrasher
Black_capped_Vireo
Blue_headed_Vireo
Philadelphia_Vireo
Red_eyed_Vireo
Warbling_Vireo
White_eyed_Vireo
Yellow_throated_Vireo
Bay_breasted_Warbler
Black_and_white_Warbler
Black_throated_Blue_Warbler
Blue_winged_Warbler
Canada_Warbler
Cape_May_Warbler
Cerulean_Warbler
Chestnut_sided_Warbler
Golden_winged_Warbler
Hooded_Warbler
Kentucky_Warbler
Magnolia_Warbler
Mourning_Warbler
Myrtle_Warbler
Nashville_Warbler
Orange_crowned_Warbler
Palm_Warbler
Pine_Warbler
Prairie_Warbler
Prothonotary_Warbler
Swainson_Warbler
Tennessee_Warbler
Wilson_Warbler
Worm_eating_Warbler
Yellow_Warbler
Northern_Waterthrush
Louisiana_Waterthrush
Bohemian_Waxwing
Cedar_Waxwing
American_Three_toed_Woodpecker
Pileated_Woodpecker
Red_bellied_Woodpecker
Red_cockaded_Woodpecker
Red_headed_Woodpecker
Downy_Woodpecker
Bewick_Wren
Cactus_Wren
Carolina_Wren
House_Wren
Marsh_Wren
Rock_Wren
Winter_Wren
Common_Yellowthroat
import csv
import functools
import pathlib
from typing import Any, BinaryIO, Callable, Dict, List, Optional, Tuple, Union
import torch
from torchdata.datapipes.iter import (
CSVDictParser,
CSVParser,
Demultiplexer,
Filter,
IterDataPipe,
IterKeyZipper,
LineReader,
Mapper,
)
from torchdata.datapipes.map import IterToMapConverter
from torchvision.prototype.datasets.utils import Dataset, EncodedImage, GDriveResource, OnlineResource
from torchvision.prototype.datasets.utils._internal import (
getitem,
hint_sharding,
hint_shuffling,
INFINITE_BUFFER_SIZE,
path_accessor,
path_comparator,
read_categories_file,
read_mat,
)
from torchvision.prototype.tv_tensors import Label
from torchvision.tv_tensors import BoundingBoxes
from .._api import register_dataset, register_info
csv.register_dialect("cub200", delimiter=" ")
NAME = "cub200"
@register_info(NAME)
def _info() -> Dict[str, Any]:
return dict(categories=read_categories_file(NAME))
@register_dataset(NAME)
class CUB200(Dataset):
"""
- **homepage**: http://www.vision.caltech.edu/visipedia/CUB-200.html
"""
def __init__(
self,
root: Union[str, pathlib.Path],
*,
split: str = "train",
year: str = "2011",
skip_integrity_check: bool = False,
) -> None:
self._split = self._verify_str_arg(split, "split", ("train", "test"))
self._year = self._verify_str_arg(year, "year", ("2010", "2011"))
self._categories = _info()["categories"]
super().__init__(
root,
# TODO: this will only be available after https://github.com/pytorch/vision/pull/5473
# dependencies=("scipy",),
skip_integrity_check=skip_integrity_check,
)
def _resources(self) -> List[OnlineResource]:
if self._year == "2011":
archive = GDriveResource(
"1hbzc_P1FuxMkcabkgn9ZKinBwW683j45",
file_name="CUB_200_2011.tgz",
sha256="0c685df5597a8b24909f6a7c9db6d11e008733779a671760afef78feb49bf081",
preprocess="decompress",
)
segmentations = GDriveResource(
"1EamOKGLoTuZdtcVYbHMWNpkn3iAVj8TP",
file_name="segmentations.tgz",
sha256="dc77f6cffea0cbe2e41d4201115c8f29a6320ecb04fffd2444f51b8066e4b84f",
preprocess="decompress",
)
return [archive, segmentations]
else: # self._year == "2010"
split = GDriveResource(
"1vZuZPqha0JjmwkdaS_XtYryE3Jf5Q1AC",
file_name="lists.tgz",
sha256="aeacbd5e3539ae84ea726e8a266a9a119c18f055cd80f3836d5eb4500b005428",
preprocess="decompress",
)
images = GDriveResource(
"1GDr1OkoXdhaXWGA8S3MAq3a522Tak-nx",
file_name="images.tgz",
sha256="2a6d2246bbb9778ca03aa94e2e683ccb4f8821a36b7f235c0822e659d60a803e",
preprocess="decompress",
)
anns = GDriveResource(
"16NsbTpMs5L6hT4hUJAmpW2u7wH326WTR",
file_name="annotations.tgz",
sha256="c17b7841c21a66aa44ba8fe92369cc95dfc998946081828b1d7b8a4b716805c1",
preprocess="decompress",
)
return [split, images, anns]
def _2011_classify_archive(self, data: Tuple[str, Any]) -> Optional[int]:
path = pathlib.Path(data[0])
if path.parents[1].name == "images":
return 0
elif path.name == "train_test_split.txt":
return 1
elif path.name == "images.txt":
return 2
elif path.name == "bounding_boxes.txt":
return 3
else:
return None
def _2011_extract_file_name(self, rel_posix_path: str) -> str:
return rel_posix_path.rsplit("/", maxsplit=1)[1]
def _2011_filter_split(self, row: List[str]) -> bool:
_, split_id = row
return {
"0": "test",
"1": "train",
}[split_id] == self._split
def _2011_segmentation_key(self, data: Tuple[str, Any]) -> str:
path = pathlib.Path(data[0])
return path.with_suffix(".jpg").name
def _2011_prepare_ann(
self, data: Tuple[str, Tuple[List[str], Tuple[str, BinaryIO]]], spatial_size: Tuple[int, int]
) -> Dict[str, Any]:
_, (bounding_boxes_data, segmentation_data) = data
segmentation_path, segmentation_buffer = segmentation_data
return dict(
bounding_boxes=BoundingBoxes(
[float(part) for part in bounding_boxes_data[1:]], format="xywh", spatial_size=spatial_size
),
segmentation_path=segmentation_path,
segmentation=EncodedImage.from_file(segmentation_buffer),
)
def _2010_split_key(self, data: str) -> str:
return data.rsplit("/", maxsplit=1)[1]
def _2010_anns_key(self, data: Tuple[str, BinaryIO]) -> Tuple[str, Tuple[str, BinaryIO]]:
path = pathlib.Path(data[0])
return path.with_suffix(".jpg").name, data
def _2010_prepare_ann(
self, data: Tuple[str, Tuple[str, BinaryIO]], spatial_size: Tuple[int, int]
) -> Dict[str, Any]:
_, (path, buffer) = data
content = read_mat(buffer)
return dict(
ann_path=path,
bounding_boxes=BoundingBoxes(
[int(content["bbox"][coord]) for coord in ("left", "bottom", "right", "top")],
format="xyxy",
spatial_size=spatial_size,
),
segmentation=torch.as_tensor(content["seg"]),
)
def _prepare_sample(
self,
data: Tuple[Tuple[str, Tuple[str, BinaryIO]], Any],
*,
prepare_ann_fn: Callable[[Any, Tuple[int, int]], Dict[str, Any]],
) -> Dict[str, Any]:
data, anns_data = data
_, image_data = data
path, buffer = image_data
image = EncodedImage.from_file(buffer)
return dict(
prepare_ann_fn(anns_data, image.spatial_size),
image=image,
label=Label(
int(pathlib.Path(path).parent.name.rsplit(".", 1)[0]) - 1,
categories=self._categories,
),
)
def _datapipe(self, resource_dps: List[IterDataPipe]) -> IterDataPipe[Dict[str, Any]]:
prepare_ann_fn: Callable
if self._year == "2011":
archive_dp, segmentations_dp = resource_dps
images_dp, split_dp, image_files_dp, bounding_boxes_dp = Demultiplexer(
archive_dp, 4, self._2011_classify_archive, drop_none=True, buffer_size=INFINITE_BUFFER_SIZE
)
image_files_dp = CSVParser(image_files_dp, dialect="cub200")
image_files_dp = Mapper(image_files_dp, self._2011_extract_file_name, input_col=1)
image_files_map = IterToMapConverter(image_files_dp)
split_dp = CSVParser(split_dp, dialect="cub200")
split_dp = Filter(split_dp, self._2011_filter_split)
split_dp = Mapper(split_dp, getitem(0))
split_dp = Mapper(split_dp, image_files_map.__getitem__)
bounding_boxes_dp = CSVParser(bounding_boxes_dp, dialect="cub200")
bounding_boxes_dp = Mapper(bounding_boxes_dp, image_files_map.__getitem__, input_col=0)
anns_dp = IterKeyZipper(
bounding_boxes_dp,
segmentations_dp,
key_fn=getitem(0),
ref_key_fn=self._2011_segmentation_key,
keep_key=True,
buffer_size=INFINITE_BUFFER_SIZE,
)
prepare_ann_fn = self._2011_prepare_ann
else: # self._year == "2010"
split_dp, images_dp, anns_dp = resource_dps
split_dp = Filter(split_dp, path_comparator("name", f"{self._split}.txt"))
split_dp = LineReader(split_dp, decode=True, return_path=False)
split_dp = Mapper(split_dp, self._2010_split_key)
anns_dp = Mapper(anns_dp, self._2010_anns_key)
prepare_ann_fn = self._2010_prepare_ann
split_dp = hint_shuffling(split_dp)
split_dp = hint_sharding(split_dp)
dp = IterKeyZipper(
split_dp,
images_dp,
getitem(),
path_accessor("name"),
buffer_size=INFINITE_BUFFER_SIZE,
)
dp = IterKeyZipper(
dp,
anns_dp,
getitem(0),
buffer_size=INFINITE_BUFFER_SIZE,
)
return Mapper(dp, functools.partial(self._prepare_sample, prepare_ann_fn=prepare_ann_fn))
def __len__(self) -> int:
return {
("train", "2010"): 3_000,
("test", "2010"): 3_033,
("train", "2011"): 5_994,
("test", "2011"): 5_794,
}[(self._split, self._year)]
def _generate_categories(self) -> List[str]:
self._year = "2011"
resources = self._resources()
dp = resources[0].load(self._root)
dp = Filter(dp, path_comparator("name", "classes.txt"))
dp = CSVDictParser(dp, fieldnames=("label", "category"), dialect="cub200")
return [row["category"].split(".")[1] for row in dp]
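A sketch of the label parsing in _prepare_sample above: CUB-200 image folders are named "<id>.<category>", and the 1-based folder id becomes a 0-based label index into the categories list. The path is illustrative.

# Hedged sketch, not part of the removed file.
import pathlib

path = "CUB_200_2011/images/200.Common_Yellowthroat/Common_Yellowthroat_0001.jpg"
folder = pathlib.Path(path).parent.name
print(int(folder.rsplit(".", 1)[0]) - 1)  # -> 199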
banded
blotchy
braided
bubbly
bumpy
chequered
cobwebbed
cracked
crosshatched
crystalline
dotted
fibrous
flecked
freckled
frilly
gauzy
grid
grooved
honeycombed
interlaced
knitted
lacelike
lined
marbled
matted
meshed
paisley
perforated
pitted
pleated
polka-dotted
porous
potholed
scaly
smeared
spiralled
sprinkled
stained
stratified
striped
studded
swirly
veined
waffled
woven
wrinkled
zigzagged
import enum
import pathlib
from typing import Any, BinaryIO, Dict, List, Optional, Tuple, Union
from torchdata.datapipes.iter import CSVParser, Demultiplexer, Filter, IterDataPipe, IterKeyZipper, LineReader, Mapper
from torchvision.prototype.datasets.utils import Dataset, EncodedImage, HttpResource, OnlineResource
from torchvision.prototype.datasets.utils._internal import (
getitem,
hint_sharding,
hint_shuffling,
INFINITE_BUFFER_SIZE,
path_comparator,
read_categories_file,
)
from torchvision.prototype.tv_tensors import Label
from .._api import register_dataset, register_info
NAME = "dtd"
class DTDDemux(enum.IntEnum):
SPLIT = 0
JOINT_CATEGORIES = 1
IMAGES = 2
@register_info(NAME)
def _info() -> Dict[str, Any]:
return dict(categories=read_categories_file(NAME))
@register_dataset(NAME)
class DTD(Dataset):
"""DTD Dataset.
homepage="https://www.robots.ox.ac.uk/~vgg/data/dtd/",
"""
def __init__(
self,
root: Union[str, pathlib.Path],
*,
split: str = "train",
fold: int = 1,
skip_validation_check: bool = False,
) -> None:
self._split = self._verify_str_arg(split, "split", {"train", "val", "test"})
if not (1 <= fold <= 10):
raise ValueError(f"The fold parameter should be an integer in [1, 10]. Got {fold}")
self._fold = fold
self._categories = _info()["categories"]
super().__init__(root, skip_integrity_check=skip_validation_check)
def _resources(self) -> List[OnlineResource]:
archive = HttpResource(
"https://www.robots.ox.ac.uk/~vgg/data/dtd/download/dtd-r1.0.1.tar.gz",
sha256="e42855a52a4950a3b59612834602aa253914755c95b0cff9ead6d07395f8e205",
preprocess="decompress",
)
return [archive]
def _classify_archive(self, data: Tuple[str, Any]) -> Optional[int]:
path = pathlib.Path(data[0])
if path.parent.name == "labels":
if path.name == "labels_joint_anno.txt":
return DTDDemux.JOINT_CATEGORIES
return DTDDemux.SPLIT
elif path.parents[1].name == "images":
return DTDDemux.IMAGES
else:
return None
def _image_key_fn(self, data: Tuple[str, Any]) -> str:
path = pathlib.Path(data[0])
# The split files contain hardcoded posix paths for the images, e.g. banded/banded_0001.jpg
return str(path.relative_to(path.parents[1]).as_posix())
def _prepare_sample(self, data: Tuple[Tuple[str, List[str]], Tuple[str, BinaryIO]]) -> Dict[str, Any]:
(_, joint_categories_data), image_data = data
_, *joint_categories = joint_categories_data
path, buffer = image_data
category = pathlib.Path(path).parent.name
return dict(
joint_categories={category for category in joint_categories if category},
label=Label.from_category(category, categories=self._categories),
path=path,
image=EncodedImage.from_file(buffer),
)
def _datapipe(self, resource_dps: List[IterDataPipe]) -> IterDataPipe[Dict[str, Any]]:
archive_dp = resource_dps[0]
splits_dp, joint_categories_dp, images_dp = Demultiplexer(
archive_dp, 3, self._classify_archive, drop_none=True, buffer_size=INFINITE_BUFFER_SIZE
)
splits_dp = Filter(splits_dp, path_comparator("name", f"{self._split}{self._fold}.txt"))
splits_dp = LineReader(splits_dp, decode=True, return_path=False)
splits_dp = hint_shuffling(splits_dp)
splits_dp = hint_sharding(splits_dp)
joint_categories_dp = CSVParser(joint_categories_dp, delimiter=" ")
dp = IterKeyZipper(
splits_dp,
joint_categories_dp,
key_fn=getitem(),
ref_key_fn=getitem(0),
buffer_size=INFINITE_BUFFER_SIZE,
)
dp = IterKeyZipper(
dp,
images_dp,
key_fn=getitem(0),
ref_key_fn=self._image_key_fn,
buffer_size=INFINITE_BUFFER_SIZE,
)
return Mapper(dp, self._prepare_sample)
def _filter_images(self, data: Tuple[str, Any]) -> bool:
return self._classify_archive(data) == DTDDemux.IMAGES
def _generate_categories(self) -> List[str]:
resources = self._resources()
dp = resources[0].load(self._root)
dp = Filter(dp, self._filter_images)
return sorted({pathlib.Path(path).parent.name for path, _ in dp})
def __len__(self) -> int:
return 1_880 # All splits have the same length
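A sketch of the join key built by _image_key_fn above, i.e. the "<category>/<file>.jpg" form that the DTD split files use. The path is illustrative.

# Hedged sketch, not part of the removed file.
import pathlib

path = pathlib.Path("dtd/images/banded/banded_0002.jpg")
print(str(path.relative_to(path.parents[1]).as_posix()))  # -> banded/banded_0002.jpg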
import pathlib
from typing import Any, Dict, List, Tuple, Union
from torchdata.datapipes.iter import IterDataPipe, Mapper
from torchvision.prototype.datasets.utils import Dataset, EncodedImage, HttpResource, OnlineResource
from torchvision.prototype.datasets.utils._internal import hint_sharding, hint_shuffling
from torchvision.prototype.tv_tensors import Label
from .._api import register_dataset, register_info
NAME = "eurosat"
@register_info(NAME)
def _info() -> Dict[str, Any]:
return dict(
categories=(
"AnnualCrop",
"Forest",
"HerbaceousVegetation",
"Highway",
"Industrial",
"Pasture",
"PermanentCrop",
"Residential",
"River",
"SeaLake",
)
)
@register_dataset(NAME)
class EuroSAT(Dataset):
"""EuroSAT Dataset.
homepage="https://github.com/phelber/eurosat",
"""
def __init__(self, root: Union[str, pathlib.Path], *, skip_integrity_check: bool = False) -> None:
self._categories = _info()["categories"]
super().__init__(root, skip_integrity_check=skip_integrity_check)
def _resources(self) -> List[OnlineResource]:
return [
HttpResource(
"https://madm.dfki.de/files/sentinel/EuroSAT.zip",
sha256="8ebea626349354c5328b142b96d0430e647051f26efc2dc974c843f25ecf70bd",
)
]
def _prepare_sample(self, data: Tuple[str, Any]) -> Dict[str, Any]:
path, buffer = data
category = pathlib.Path(path).parent.name
return dict(
label=Label.from_category(category, categories=self._categories),
path=path,
image=EncodedImage.from_file(buffer),
)
def _datapipe(self, resource_dps: List[IterDataPipe]) -> IterDataPipe[Dict[str, Any]]:
dp = resource_dps[0]
dp = hint_shuffling(dp)
dp = hint_sharding(dp)
return Mapper(dp, self._prepare_sample)
def __len__(self) -> int:
return 27_000
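Since the EuroSAT datapipe yields plain dicts, it can be consumed with a standard DataLoader. This sketch assumes the prototype torchvision.prototype.datasets.load() entry point removed by this PR; batch_size=None keeps the raw, uncollated samples.

# Hedged sketch, not part of the removed file.
from torch.utils.data import DataLoader
from torchvision.prototype import datasets

eurosat = datasets.load("eurosat")
loader = DataLoader(eurosat, batch_size=None)
sample = next(iter(loader))
print(sample["label"], sample["path"])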
import pathlib
from typing import Any, Dict, List, Union
import torch
from torchdata.datapipes.iter import CSVDictParser, IterDataPipe, Mapper
from torchvision.prototype.datasets.utils import Dataset, KaggleDownloadResource, OnlineResource
from torchvision.prototype.datasets.utils._internal import hint_sharding, hint_shuffling
from torchvision.prototype.tv_tensors import Label
from torchvision.tv_tensors import Image
from .._api import register_dataset, register_info
NAME = "fer2013"
@register_info(NAME)
def _info() -> Dict[str, Any]:
return dict(categories=("angry", "disgust", "fear", "happy", "sad", "surprise", "neutral"))
@register_dataset(NAME)
class FER2013(Dataset):
"""FER 2013 Dataset
homepage="https://www.kaggle.com/c/challenges-in-representation-learning-facial-expression-recognition-challenge"
"""
def __init__(
self, root: Union[str, pathlib.Path], *, split: str = "train", skip_integrity_check: bool = False
) -> None:
self._split = self._verify_str_arg(split, "split", {"train", "test"})
self._categories = _info()["categories"]
super().__init__(root, skip_integrity_check=skip_integrity_check)
_CHECKSUMS = {
"train": "a2b7c9360cc0b38d21187e5eece01c2799fce5426cdeecf746889cc96cda2d10",
"test": "dec8dfe8021e30cd6704b85ec813042b4a5d99d81cb55e023291a94104f575c3",
}
def _resources(self) -> List[OnlineResource]:
archive = KaggleDownloadResource(
"https://www.kaggle.com/c/challenges-in-representation-learning-facial-expression-recognition-challenge",
file_name=f"{self._split}.csv.zip",
sha256=self._CHECKSUMS[self._split],
)
return [archive]
def _prepare_sample(self, data: Dict[str, Any]) -> Dict[str, Any]:
label_id = data.get("emotion")
return dict(
image=Image(torch.tensor([int(idx) for idx in data["pixels"].split()], dtype=torch.uint8).reshape(48, 48)),
label=Label(int(label_id), categories=self._categories) if label_id is not None else None,
)
def _datapipe(self, resource_dps: List[IterDataPipe]) -> IterDataPipe[Dict[str, Any]]:
dp = resource_dps[0]
dp = CSVDictParser(dp)
dp = hint_shuffling(dp)
dp = hint_sharding(dp)
return Mapper(dp, self._prepare_sample)
def __len__(self) -> int:
return 28_709 if self._split == "train" else 3_589
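A sketch of the pixel-string decoding performed in _prepare_sample above, on a made-up CSV row; FER2013 images are 48x48 grayscale encoded as a space-separated string.

# Hedged sketch, not part of the removed file.
import torch

row = {"emotion": "3", "pixels": " ".join(["0"] * (48 * 48))}  # illustrative row
image = torch.tensor([int(v) for v in row["pixels"].split()], dtype=torch.uint8).reshape(48, 48)
print(image.shape, int(row["emotion"]))  # -> torch.Size([48, 48]) 3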
apple_pie
baby_back_ribs
baklava
beef_carpaccio
beef_tartare
beet_salad
beignets
bibimbap
bread_pudding
breakfast_burrito
bruschetta
caesar_salad
cannoli
caprese_salad
carrot_cake
ceviche
cheesecake
cheese_plate
chicken_curry
chicken_quesadilla
chicken_wings
chocolate_cake
chocolate_mousse
churros
clam_chowder
club_sandwich
crab_cakes
creme_brulee
croque_madame
cup_cakes
deviled_eggs
donuts
dumplings
edamame
eggs_benedict
escargots
falafel
filet_mignon
fish_and_chips
foie_gras
french_fries
french_onion_soup
french_toast
fried_calamari
fried_rice
frozen_yogurt
garlic_bread
gnocchi
greek_salad
grilled_cheese_sandwich
grilled_salmon
guacamole
gyoza
hamburger
hot_and_sour_soup
hot_dog
huevos_rancheros
hummus
ice_cream
lasagna
lobster_bisque
lobster_roll_sandwich
macaroni_and_cheese
macarons
miso_soup
mussels
nachos
omelette
onion_rings
oysters
pad_thai
paella
pancakes
panna_cotta
peking_duck
pho
pizza
pork_chop
poutine
prime_rib
pulled_pork_sandwich
ramen
ravioli
red_velvet_cake
risotto
samosa
sashimi
scallops
seaweed_salad
shrimp_and_grits
spaghetti_bolognese
spaghetti_carbonara
spring_rolls
steak
strawberry_shortcake
sushi
tacos
takoyaki
tiramisu
tuna_tartare
waffles
from pathlib import Path
from typing import Any, BinaryIO, Dict, List, Optional, Tuple, Union
from torchdata.datapipes.iter import Demultiplexer, Filter, IterDataPipe, IterKeyZipper, LineReader, Mapper
from torchvision.prototype.datasets.utils import Dataset, EncodedImage, HttpResource, OnlineResource
from torchvision.prototype.datasets.utils._internal import (
getitem,
hint_sharding,
hint_shuffling,
INFINITE_BUFFER_SIZE,
path_comparator,
read_categories_file,
)
from torchvision.prototype.tv_tensors import Label
from .._api import register_dataset, register_info
NAME = "food101"
@register_info(NAME)
def _info() -> Dict[str, Any]:
return dict(categories=read_categories_file(NAME))
@register_dataset(NAME)
class Food101(Dataset):
"""Food 101 dataset
homepage="https://data.vision.ee.ethz.ch/cvl/datasets_extra/food-101",
"""
def __init__(self, root: Union[str, Path], *, split: str = "train", skip_integrity_check: bool = False) -> None:
self._split = self._verify_str_arg(split, "split", {"train", "test"})
self._categories = _info()["categories"]
super().__init__(root, skip_integrity_check=skip_integrity_check)
def _resources(self) -> List[OnlineResource]:
return [
HttpResource(
url="http://data.vision.ee.ethz.ch/cvl/food-101.tar.gz",
sha256="d97d15e438b7f4498f96086a4f7e2fa42a32f2712e87d3295441b2b6314053a4",
preprocess="decompress",
)
]
def _classify_archive(self, data: Tuple[str, Any]) -> Optional[int]:
path = Path(data[0])
if path.parents[1].name == "images":
return 0
elif path.parents[0].name == "meta":
return 1
else:
return None
def _prepare_sample(self, data: Tuple[str, Tuple[str, BinaryIO]]) -> Dict[str, Any]:
id, (path, buffer) = data
return dict(
label=Label.from_category(id.split("/", 1)[0], categories=self._categories),
path=path,
image=EncodedImage.from_file(buffer),
)
def _image_key(self, data: Tuple[str, Any]) -> str:
path = Path(data[0])
return path.relative_to(path.parents[1]).with_suffix("").as_posix()
def _datapipe(self, resource_dps: List[IterDataPipe]) -> IterDataPipe[Dict[str, Any]]:
archive_dp = resource_dps[0]
images_dp, split_dp = Demultiplexer(
archive_dp, 2, self._classify_archive, drop_none=True, buffer_size=INFINITE_BUFFER_SIZE
)
split_dp = Filter(split_dp, path_comparator("name", f"{self._split}.txt"))
split_dp = LineReader(split_dp, decode=True, return_path=False)
split_dp = hint_sharding(split_dp)
split_dp = hint_shuffling(split_dp)
dp = IterKeyZipper(
split_dp,
images_dp,
key_fn=getitem(),
ref_key_fn=self._image_key,
buffer_size=INFINITE_BUFFER_SIZE,
)
return Mapper(dp, self._prepare_sample)
def _generate_categories(self) -> List[str]:
resources = self._resources()
dp = resources[0].load(self._root)
dp = Filter(dp, path_comparator("name", "classes.txt"))
dp = LineReader(dp, decode=True, return_path=False)
return list(dp)
def __len__(self) -> int:
return 75_750 if self._split == "train" else 25_250
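A sketch of how a Food-101 split id such as "apple_pie/1005649" is matched to its image file and turned into a category (see _image_key and _prepare_sample above). The id and path are illustrative.

# Hedged sketch, not part of the removed file.
from pathlib import Path

image_id = "apple_pie/1005649"                        # one line of train.txt / test.txt
path = Path("food-101/images/apple_pie/1005649.jpg")  # the corresponding archive member
key = path.relative_to(path.parents[1]).with_suffix("").as_posix()
print(key == image_id)                                # -> True
print(image_id.split("/", 1)[0])                      # -> apple_pie (the category)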
import pathlib
from typing import Any, Dict, List, Optional, Tuple, Union
from torchdata.datapipes.iter import CSVDictParser, Demultiplexer, Filter, IterDataPipe, Mapper, Zipper
from torchvision.prototype.datasets.utils import Dataset, EncodedImage, HttpResource, OnlineResource
from torchvision.prototype.datasets.utils._internal import (
hint_sharding,
hint_shuffling,
INFINITE_BUFFER_SIZE,
path_comparator,
)
from torchvision.prototype.tv_tensors import Label
from torchvision.tv_tensors import BoundingBoxes
from .._api import register_dataset, register_info
NAME = "gtsrb"
@register_info(NAME)
def _info() -> Dict[str, Any]:
return dict(
categories=[f"{label:05d}" for label in range(43)],
)
@register_dataset(NAME)
class GTSRB(Dataset):
"""GTSRB Dataset
homepage="https://benchmark.ini.rub.de"
"""
def __init__(
self, root: Union[str, pathlib.Path], *, split: str = "train", skip_integrity_check: bool = False
) -> None:
self._split = self._verify_str_arg(split, "split", {"train", "test"})
self._categories = _info()["categories"]
super().__init__(root, skip_integrity_check=skip_integrity_check)
_URL_ROOT = "https://sid.erda.dk/public/archives/daaeac0d7ce1152aea9b61d9f1e19370/"
_URLS = {
"train": f"{_URL_ROOT}GTSRB-Training_fixed.zip",
"test": f"{_URL_ROOT}GTSRB_Final_Test_Images.zip",
"test_ground_truth": f"{_URL_ROOT}GTSRB_Final_Test_GT.zip",
}
_CHECKSUMS = {
"train": "df4144942083645bd60b594de348aa6930126c3e0e5de09e39611630abf8455a",
"test": "48ba6fab7e877eb64eaf8de99035b0aaecfbc279bee23e35deca4ac1d0a837fa",
"test_ground_truth": "f94e5a7614d75845c74c04ddb26b8796b9e483f43541dd95dd5b726504e16d6d",
}
def _resources(self) -> List[OnlineResource]:
rsrcs: List[OnlineResource] = [HttpResource(self._URLS[self._split], sha256=self._CHECKSUMS[self._split])]
if self._split == "test":
rsrcs.append(
HttpResource(
self._URLS["test_ground_truth"],
sha256=self._CHECKSUMS["test_ground_truth"],
)
)
return rsrcs
def _classify_train_archive(self, data: Tuple[str, Any]) -> Optional[int]:
path = pathlib.Path(data[0])
if path.suffix == ".ppm":
return 0
elif path.suffix == ".csv":
return 1
else:
return None
def _prepare_sample(self, data: Tuple[Tuple[str, Any], Dict[str, Any]]) -> Dict[str, Any]:
(path, buffer), csv_info = data
label = int(csv_info["ClassId"])
bounding_boxes = BoundingBoxes(
[int(csv_info[k]) for k in ("Roi.X1", "Roi.Y1", "Roi.X2", "Roi.Y2")],
format="xyxy",
spatial_size=(int(csv_info["Height"]), int(csv_info["Width"])),
)
return {
"path": path,
"image": EncodedImage.from_file(buffer),
"label": Label(label, categories=self._categories),
"bounding_boxes": bounding_boxes,
}
def _datapipe(self, resource_dps: List[IterDataPipe]) -> IterDataPipe[Dict[str, Any]]:
if self._split == "train":
images_dp, ann_dp = Demultiplexer(
resource_dps[0], 2, self._classify_train_archive, drop_none=True, buffer_size=INFINITE_BUFFER_SIZE
)
else:
images_dp, ann_dp = resource_dps
images_dp = Filter(images_dp, path_comparator("suffix", ".ppm"))
# The order of the image files in the .zip archives perfectly matches the order of the entries in the
# (possibly concatenated) .csv files, so we're able to use Zipper here instead of an IterKeyZipper.
ann_dp = CSVDictParser(ann_dp, delimiter=";")
dp = Zipper(images_dp, ann_dp)
dp = hint_shuffling(dp)
dp = hint_sharding(dp)
return Mapper(dp, self._prepare_sample)
def __len__(self) -> int:
return 26_640 if self._split == "train" else 12_630
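A sketch of parsing one ";"-delimited GTSRB ground-truth row the way CSVDictParser(ann_dp, delimiter=";") does, then reading the fields used in _prepare_sample above. The row content is illustrative.

# Hedged sketch, not part of the removed file.
import csv
import io

text = "Width;Height;Roi.X1;Roi.Y1;Roi.X2;Roi.Y2;ClassId\n30;31;5;6;24;25;1\n"
row = next(csv.DictReader(io.StringIO(text), delimiter=";"))
print(int(row["ClassId"]))                                               # -> 1
print([int(row[k]) for k in ("Roi.X1", "Roi.Y1", "Roi.X2", "Roi.Y2")])   # xyxy box
print((int(row["Height"]), int(row["Width"])))                           # spatial size (H, W)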
tench,n01440764
goldfish,n01443537
great white shark,n01484850
tiger shark,n01491361
hammerhead,n01494475
electric ray,n01496331
stingray,n01498041
cock,n01514668
hen,n01514859
ostrich,n01518878
brambling,n01530575
goldfinch,n01531178
house finch,n01532829
junco,n01534433
indigo bunting,n01537544
robin,n01558993
bulbul,n01560419
jay,n01580077
magpie,n01582220
chickadee,n01592084
water ouzel,n01601694
kite,n01608432
bald eagle,n01614925
vulture,n01616318
great grey owl,n01622779
European fire salamander,n01629819
common newt,n01630670
eft,n01631663
spotted salamander,n01632458
axolotl,n01632777
bullfrog,n01641577
tree frog,n01644373
tailed frog,n01644900
loggerhead,n01664065
leatherback turtle,n01665541
mud turtle,n01667114
terrapin,n01667778
box turtle,n01669191
banded gecko,n01675722
common iguana,n01677366
American chameleon,n01682714
whiptail,n01685808
agama,n01687978
frilled lizard,n01688243
alligator lizard,n01689811
Gila monster,n01692333
green lizard,n01693334
African chameleon,n01694178
Komodo dragon,n01695060
African crocodile,n01697457
American alligator,n01698640
triceratops,n01704323
thunder snake,n01728572
ringneck snake,n01728920
hognose snake,n01729322
green snake,n01729977
king snake,n01734418
garter snake,n01735189
water snake,n01737021
vine snake,n01739381
night snake,n01740131
boa constrictor,n01742172
rock python,n01744401
Indian cobra,n01748264
green mamba,n01749939
sea snake,n01751748
horned viper,n01753488
diamondback,n01755581
sidewinder,n01756291
trilobite,n01768244
harvestman,n01770081
scorpion,n01770393
black and gold garden spider,n01773157
barn spider,n01773549
garden spider,n01773797
black widow,n01774384
tarantula,n01774750
wolf spider,n01775062
tick,n01776313
centipede,n01784675
black grouse,n01795545
ptarmigan,n01796340
ruffed grouse,n01797886
prairie chicken,n01798484
peacock,n01806143
quail,n01806567
partridge,n01807496
African grey,n01817953
macaw,n01818515
sulphur-crested cockatoo,n01819313
lorikeet,n01820546
coucal,n01824575
bee eater,n01828970
hornbill,n01829413
hummingbird,n01833805
jacamar,n01843065
toucan,n01843383
drake,n01847000
red-breasted merganser,n01855032
goose,n01855672
black swan,n01860187
tusker,n01871265
echidna,n01872401
platypus,n01873310
wallaby,n01877812
koala,n01882714
wombat,n01883070
jellyfish,n01910747
sea anemone,n01914609
brain coral,n01917289
flatworm,n01924916
nematode,n01930112
conch,n01943899
snail,n01944390
slug,n01945685
sea slug,n01950731
chiton,n01955084
chambered nautilus,n01968897
Dungeness crab,n01978287
rock crab,n01978455
fiddler crab,n01980166
king crab,n01981276
American lobster,n01983481
spiny lobster,n01984695
crayfish,n01985128
hermit crab,n01986214
isopod,n01990800
white stork,n02002556
black stork,n02002724
spoonbill,n02006656
flamingo,n02007558
little blue heron,n02009229
American egret,n02009912
bittern,n02011460
crane,n02012849
limpkin,n02013706
European gallinule,n02017213
American coot,n02018207
bustard,n02018795
ruddy turnstone,n02025239
red-backed sandpiper,n02027492
redshank,n02028035
dowitcher,n02033041
oystercatcher,n02037110
pelican,n02051845
king penguin,n02056570
albatross,n02058221
grey whale,n02066245
killer whale,n02071294
dugong,n02074367
sea lion,n02077923
Chihuahua,n02085620
Japanese spaniel,n02085782
Maltese dog,n02085936
Pekinese,n02086079
Shih-Tzu,n02086240
Blenheim spaniel,n02086646
papillon,n02086910
toy terrier,n02087046
Rhodesian ridgeback,n02087394
Afghan hound,n02088094
basset,n02088238
beagle,n02088364
bloodhound,n02088466
bluetick,n02088632
black-and-tan coonhound,n02089078
Walker hound,n02089867
English foxhound,n02089973
redbone,n02090379
borzoi,n02090622
Irish wolfhound,n02090721
Italian greyhound,n02091032
whippet,n02091134
Ibizan hound,n02091244
Norwegian elkhound,n02091467
otterhound,n02091635
Saluki,n02091831
Scottish deerhound,n02092002
Weimaraner,n02092339
Staffordshire bullterrier,n02093256
American Staffordshire terrier,n02093428
Bedlington terrier,n02093647
Border terrier,n02093754
Kerry blue terrier,n02093859
Irish terrier,n02093991
Norfolk terrier,n02094114
Norwich terrier,n02094258
Yorkshire terrier,n02094433
wire-haired fox terrier,n02095314
Lakeland terrier,n02095570
Sealyham terrier,n02095889
Airedale,n02096051
cairn,n02096177
Australian terrier,n02096294
Dandie Dinmont,n02096437
Boston bull,n02096585
miniature schnauzer,n02097047
giant schnauzer,n02097130
standard schnauzer,n02097209
Scotch terrier,n02097298
Tibetan terrier,n02097474
silky terrier,n02097658
soft-coated wheaten terrier,n02098105
West Highland white terrier,n02098286
Lhasa,n02098413
flat-coated retriever,n02099267
curly-coated retriever,n02099429
golden retriever,n02099601
Labrador retriever,n02099712
Chesapeake Bay retriever,n02099849
German short-haired pointer,n02100236
vizsla,n02100583
English setter,n02100735
Irish setter,n02100877
Gordon setter,n02101006
Brittany spaniel,n02101388
clumber,n02101556
English springer,n02102040
Welsh springer spaniel,n02102177
cocker spaniel,n02102318
Sussex spaniel,n02102480
Irish water spaniel,n02102973
kuvasz,n02104029
schipperke,n02104365
groenendael,n02105056
malinois,n02105162
briard,n02105251
kelpie,n02105412
komondor,n02105505
Old English sheepdog,n02105641
Shetland sheepdog,n02105855
collie,n02106030
Border collie,n02106166
Bouvier des Flandres,n02106382
Rottweiler,n02106550
German shepherd,n02106662
Doberman,n02107142
miniature pinscher,n02107312
Greater Swiss Mountain dog,n02107574
Bernese mountain dog,n02107683
Appenzeller,n02107908
EntleBucher,n02108000
boxer,n02108089
bull mastiff,n02108422
Tibetan mastiff,n02108551
French bulldog,n02108915
Great Dane,n02109047
Saint Bernard,n02109525
Eskimo dog,n02109961
malamute,n02110063
Siberian husky,n02110185
dalmatian,n02110341
affenpinscher,n02110627
basenji,n02110806
pug,n02110958
Leonberg,n02111129
Newfoundland,n02111277
Great Pyrenees,n02111500
Samoyed,n02111889
Pomeranian,n02112018
chow,n02112137
keeshond,n02112350
Brabancon griffon,n02112706
Pembroke,n02113023
Cardigan,n02113186
toy poodle,n02113624
miniature poodle,n02113712
standard poodle,n02113799
Mexican hairless,n02113978
timber wolf,n02114367
white wolf,n02114548
red wolf,n02114712
coyote,n02114855
dingo,n02115641
dhole,n02115913
African hunting dog,n02116738
hyena,n02117135
red fox,n02119022
kit fox,n02119789
Arctic fox,n02120079
grey fox,n02120505
tabby,n02123045
tiger cat,n02123159
Persian cat,n02123394
Siamese cat,n02123597
Egyptian cat,n02124075
cougar,n02125311
lynx,n02127052
leopard,n02128385
snow leopard,n02128757
jaguar,n02128925
lion,n02129165
tiger,n02129604
cheetah,n02130308
brown bear,n02132136
American black bear,n02133161
ice bear,n02134084
sloth bear,n02134418
mongoose,n02137549
meerkat,n02138441
tiger beetle,n02165105
ladybug,n02165456
ground beetle,n02167151
long-horned beetle,n02168699
leaf beetle,n02169497
dung beetle,n02172182
rhinoceros beetle,n02174001
weevil,n02177972
fly,n02190166
bee,n02206856
ant,n02219486
grasshopper,n02226429
cricket,n02229544
walking stick,n02231487
cockroach,n02233338
mantis,n02236044
cicada,n02256656
leafhopper,n02259212
lacewing,n02264363
dragonfly,n02268443
damselfly,n02268853
admiral,n02276258
ringlet,n02277742
monarch,n02279972
cabbage butterfly,n02280649
sulphur butterfly,n02281406
lycaenid,n02281787
starfish,n02317335
sea urchin,n02319095
sea cucumber,n02321529
wood rabbit,n02325366
hare,n02326432
Angora,n02328150
hamster,n02342885
porcupine,n02346627
fox squirrel,n02356798
marmot,n02361337
beaver,n02363005
guinea pig,n02364673
sorrel,n02389026
zebra,n02391049
hog,n02395406
wild boar,n02396427
warthog,n02397096
hippopotamus,n02398521
ox,n02403003
water buffalo,n02408429
bison,n02410509
ram,n02412080
bighorn,n02415577
ibex,n02417914
hartebeest,n02422106
impala,n02422699
gazelle,n02423022
Arabian camel,n02437312
llama,n02437616
weasel,n02441942
mink,n02442845
polecat,n02443114
black-footed ferret,n02443484
otter,n02444819
skunk,n02445715
badger,n02447366
armadillo,n02454379
three-toed sloth,n02457408
orangutan,n02480495
gorilla,n02480855
chimpanzee,n02481823
gibbon,n02483362
siamang,n02483708
guenon,n02484975
patas,n02486261
baboon,n02486410
macaque,n02487347
langur,n02488291
colobus,n02488702
proboscis monkey,n02489166
marmoset,n02490219
capuchin,n02492035
howler monkey,n02492660
titi,n02493509
spider monkey,n02493793
squirrel monkey,n02494079
Madagascar cat,n02497673
indri,n02500267
Indian elephant,n02504013
African elephant,n02504458
lesser panda,n02509815
giant panda,n02510455
barracouta,n02514041
eel,n02526121
coho,n02536864
rock beauty,n02606052
anemone fish,n02607072
sturgeon,n02640242
gar,n02641379
lionfish,n02643566
puffer,n02655020
abacus,n02666196
abaya,n02667093
academic gown,n02669723
accordion,n02672831
acoustic guitar,n02676566
aircraft carrier,n02687172
airliner,n02690373
airship,n02692877
altar,n02699494
ambulance,n02701002
amphibian,n02704792
analog clock,n02708093
apiary,n02727426
apron,n02730930
ashcan,n02747177
assault rifle,n02749479
backpack,n02769748
bakery,n02776631
balance beam,n02777292
balloon,n02782093
ballpoint,n02783161
Band Aid,n02786058
banjo,n02787622
bannister,n02788148
barbell,n02790996
barber chair,n02791124
barbershop,n02791270
barn,n02793495
barometer,n02794156
barrel,n02795169
barrow,n02797295
baseball,n02799071
basketball,n02802426
bassinet,n02804414
bassoon,n02804610
bathing cap,n02807133
bath towel,n02808304
bathtub,n02808440
beach wagon,n02814533
beacon,n02814860
beaker,n02815834
bearskin,n02817516
beer bottle,n02823428
beer glass,n02823750
bell cote,n02825657
bib,n02834397
bicycle-built-for-two,n02835271
bikini,n02837789
binder,n02840245
binoculars,n02841315
birdhouse,n02843684
boathouse,n02859443
bobsled,n02860847
bolo tie,n02865351
bonnet,n02869837
bookcase,n02870880
bookshop,n02871525
bottlecap,n02877765
bow,n02879718
bow tie,n02883205
brass,n02892201
brassiere,n02892767
breakwater,n02894605
breastplate,n02895154
broom,n02906734
bucket,n02909870
buckle,n02910353
bulletproof vest,n02916936
bullet train,n02917067
butcher shop,n02927161
cab,n02930766
caldron,n02939185
candle,n02948072
cannon,n02950826
canoe,n02951358
can opener,n02951585
cardigan,n02963159
car mirror,n02965783
carousel,n02966193
carpenter's kit,n02966687
carton,n02971356
car wheel,n02974003
cash machine,n02977058
cassette,n02978881
cassette player,n02979186
castle,n02980441
catamaran,n02981792
CD player,n02988304
cello,n02992211
cellular telephone,n02992529
chain,n02999410
chainlink fence,n03000134
chain mail,n03000247
chain saw,n03000684
chest,n03014705
chiffonier,n03016953
chime,n03017168
china cabinet,n03018349
Christmas stocking,n03026506
church,n03028079
cinema,n03032252
cleaver,n03041632
cliff dwelling,n03042490
cloak,n03045698
clog,n03047690
cocktail shaker,n03062245
coffee mug,n03063599
coffeepot,n03063689
coil,n03065424
combination lock,n03075370
computer keyboard,n03085013
confectionery,n03089624
container ship,n03095699
convertible,n03100240
corkscrew,n03109150
cornet,n03110669
cowboy boot,n03124043
cowboy hat,n03124170
cradle,n03125729
construction crane,n03126707
crash helmet,n03127747
crate,n03127925
crib,n03131574
Crock Pot,n03133878
croquet ball,n03134739
crutch,n03141823
cuirass,n03146219
dam,n03160309
desk,n03179701
desktop computer,n03180011
dial telephone,n03187595
diaper,n03188531
digital clock,n03196217
digital watch,n03197337
dining table,n03201208
dishrag,n03207743
dishwasher,n03207941
disk brake,n03208938
dock,n03216828
dogsled,n03218198
dome,n03220513
doormat,n03223299
drilling platform,n03240683
drum,n03249569
drumstick,n03250847
dumbbell,n03255030
Dutch oven,n03259280
electric fan,n03271574
electric guitar,n03272010
electric locomotive,n03272562
entertainment center,n03290653
envelope,n03291819
espresso maker,n03297495
face powder,n03314780
feather boa,n03325584
file,n03337140
fireboat,n03344393
fire engine,n03345487
fire screen,n03347037
flagpole,n03355925
flute,n03372029
folding chair,n03376595
football helmet,n03379051
forklift,n03384352
fountain,n03388043
fountain pen,n03388183
four-poster,n03388549
freight car,n03393912
French horn,n03394916
frying pan,n03400231
fur coat,n03404251
garbage truck,n03417042
gasmask,n03424325
gas pump,n03425413
goblet,n03443371
go-kart,n03444034
golf ball,n03445777
golfcart,n03445924
gondola,n03447447
gong,n03447721
gown,n03450230
grand piano,n03452741
greenhouse,n03457902
grille,n03459775
grocery store,n03461385
guillotine,n03467068
hair slide,n03476684
hair spray,n03476991
half track,n03478589
hammer,n03481172
hamper,n03482405
hand blower,n03483316
hand-held computer,n03485407
handkerchief,n03485794
hard disc,n03492542
harmonica,n03494278
harp,n03495258
harvester,n03496892
hatchet,n03498962
holster,n03527444
home theater,n03529860
honeycomb,n03530642
hook,n03532672
hoopskirt,n03534580
horizontal bar,n03535780
horse cart,n03538406
hourglass,n03544143
iPod,n03584254
iron,n03584829
jack-o'-lantern,n03590841
jean,n03594734
jeep,n03594945
jersey,n03595614
jigsaw puzzle,n03598930
jinrikisha,n03599486
joystick,n03602883
kimono,n03617480
knee pad,n03623198
knot,n03627232
lab coat,n03630383
ladle,n03633091
lampshade,n03637318
laptop,n03642806
lawn mower,n03649909
lens cap,n03657121
letter opener,n03658185
library,n03661043
lifeboat,n03662601
lighter,n03666591
limousine,n03670208
liner,n03673027
lipstick,n03676483
Loafer,n03680355
lotion,n03690938
loudspeaker,n03691459
loupe,n03692522
lumbermill,n03697007
magnetic compass,n03706229
mailbag,n03709823
mailbox,n03710193
maillot,n03710637
tank suit,n03710721
manhole cover,n03717622
maraca,n03720891
marimba,n03721384
mask,n03724870
matchstick,n03729826
maypole,n03733131
maze,n03733281
measuring cup,n03733805
medicine chest,n03742115
megalith,n03743016
microphone,n03759954
microwave,n03761084
military uniform,n03763968
milk can,n03764736
minibus,n03769881
miniskirt,n03770439
minivan,n03770679
missile,n03773504
mitten,n03775071
mixing bowl,n03775546
mobile home,n03776460
Model T,n03777568
modem,n03777754
monastery,n03781244
monitor,n03782006
moped,n03785016
mortar,n03786901
mortarboard,n03787032
mosque,n03788195
mosquito net,n03788365
motor scooter,n03791053
mountain bike,n03792782
mountain tent,n03792972
mouse,n03793489
mousetrap,n03794056
moving van,n03796401
muzzle,n03803284
nail,n03804744
neck brace,n03814639
necklace,n03814906
nipple,n03825788
notebook,n03832673
obelisk,n03837869
oboe,n03838899
ocarina,n03840681
odometer,n03841143
oil filter,n03843555
organ,n03854065
oscilloscope,n03857828
overskirt,n03866082
oxcart,n03868242
oxygen mask,n03868863
packet,n03871628
paddle,n03873416
paddlewheel,n03874293
padlock,n03874599
paintbrush,n03876231
pajama,n03877472
palace,n03877845
panpipe,n03884397
paper towel,n03887697
parachute,n03888257
parallel bars,n03888605
park bench,n03891251
parking meter,n03891332
passenger car,n03895866
patio,n03899768
pay-phone,n03902125
pedestal,n03903868
pencil box,n03908618
pencil sharpener,n03908714
perfume,n03916031
Petri dish,n03920288
photocopier,n03924679
pick,n03929660
pickelhaube,n03929855
picket fence,n03930313
pickup,n03930630
pier,n03933933
piggy bank,n03935335
pill bottle,n03937543
pillow,n03938244
ping-pong ball,n03942813
pinwheel,n03944341
pirate,n03947888
pitcher,n03950228
plane,n03954731
planetarium,n03956157
plastic bag,n03958227
plate rack,n03961711
plow,n03967562
plunger,n03970156
Polaroid camera,n03976467
pole,n03976657
police van,n03977966
poncho,n03980874
pool table,n03982430
pop bottle,n03983396
pot,n03991062
potter's wheel,n03992509
power drill,n03995372
prayer rug,n03998194
printer,n04004767
prison,n04005630
projectile,n04008634
projector,n04009552
puck,n04019541
punching bag,n04023962
purse,n04026417
quill,n04033901
quilt,n04033995
racer,n04037443
racket,n04039381
radiator,n04040759
radio,n04041544
radio telescope,n04044716
rain barrel,n04049303
recreational vehicle,n04065272
reel,n04067472
reflex camera,n04069434
refrigerator,n04070727
remote control,n04074963
restaurant,n04081281
revolver,n04086273
rifle,n04090263
rocking chair,n04099969
rotisserie,n04111531
rubber eraser,n04116512
rugby ball,n04118538
rule,n04118776
running shoe,n04120489
safe,n04125021
safety pin,n04127249
saltshaker,n04131690
sandal,n04133789
sarong,n04136333
sax,n04141076
scabbard,n04141327
scale,n04141975
school bus,n04146614
schooner,n04147183
scoreboard,n04149813
screen,n04152593
screw,n04153751
screwdriver,n04154565
seat belt,n04162706
sewing machine,n04179913
shield,n04192698
shoe shop,n04200800
shoji,n04201297
shopping basket,n04204238
shopping cart,n04204347
shovel,n04208210
shower cap,n04209133
shower curtain,n04209239
ski,n04228054
ski mask,n04229816
sleeping bag,n04235860
slide rule,n04238763
sliding door,n04239074
slot,n04243546
snorkel,n04251144
snowmobile,n04252077
snowplow,n04252225
soap dispenser,n04254120
soccer ball,n04254680
sock,n04254777
solar dish,n04258138
sombrero,n04259630
soup bowl,n04263257
space bar,n04264628
space heater,n04265275
space shuttle,n04266014
spatula,n04270147
speedboat,n04273569
spider web,n04275548
spindle,n04277352
sports car,n04285008
spotlight,n04286575
stage,n04296562
steam locomotive,n04310018
steel arch bridge,n04311004
steel drum,n04311174
stethoscope,n04317175
stole,n04325704
stone wall,n04326547
stopwatch,n04328186
stove,n04330267
strainer,n04332243
streetcar,n04335435
stretcher,n04336792
studio couch,n04344873
stupa,n04346328
submarine,n04347754
suit,n04350905
sundial,n04355338
sunglass,n04355933
sunglasses,n04356056
sunscreen,n04357314
suspension bridge,n04366367
swab,n04367480
sweatshirt,n04370456
swimming trunks,n04371430
swing,n04371774
switch,n04372370
syringe,n04376876
table lamp,n04380533
tank,n04389033
tape player,n04392985
teapot,n04398044
teddy,n04399382
television,n04404412
tennis ball,n04409515
thatch,n04417672
theater curtain,n04418357
thimble,n04423845
thresher,n04428191
throne,n04429376
tile roof,n04435653
toaster,n04442312
tobacco shop,n04443257
toilet seat,n04447861
torch,n04456115
totem pole,n04458633
tow truck,n04461696
toyshop,n04462240
tractor,n04465501
trailer truck,n04467665
tray,n04476259
trench coat,n04479046
tricycle,n04482393
trimaran,n04483307
tripod,n04485082
triumphal arch,n04486054
trolleybus,n04487081
trombone,n04487394
tub,n04493381
turnstile,n04501370
typewriter keyboard,n04505470
umbrella,n04507155
unicycle,n04509417
upright,n04515003
vacuum,n04517823
vase,n04522168
vault,n04523525
velvet,n04525038
vending machine,n04525305
vestment,n04532106
viaduct,n04532670
violin,n04536866
volleyball,n04540053
waffle iron,n04542943
wall clock,n04548280
wallet,n04548362
wardrobe,n04550184
warplane,n04552348
washbasin,n04553703
washer,n04554684
water bottle,n04557648
water jug,n04560804
water tower,n04562935
whiskey jug,n04579145
whistle,n04579432
wig,n04584207
window screen,n04589890
window shade,n04590129
Windsor tie,n04591157
wine bottle,n04591713
wing,n04592741
wok,n04596742
wooden spoon,n04597913
wool,n04599235
worm fence,n04604644
wreck,n04606251
yawl,n04612504
yurt,n04613696
web site,n06359193
comic book,n06596364
crossword puzzle,n06785654
street sign,n06794110
traffic light,n06874185
book jacket,n07248320
menu,n07565083
plate,n07579787
guacamole,n07583066
consomme,n07584110
hot pot,n07590611
trifle,n07613480
ice cream,n07614500
ice lolly,n07615774
French loaf,n07684084
bagel,n07693725
pretzel,n07695742
cheeseburger,n07697313
hotdog,n07697537
mashed potato,n07711569
head cabbage,n07714571
broccoli,n07714990
cauliflower,n07715103
zucchini,n07716358
spaghetti squash,n07716906
acorn squash,n07717410
butternut squash,n07717556
cucumber,n07718472
artichoke,n07718747
bell pepper,n07720875
cardoon,n07730033
mushroom,n07734744
Granny Smith,n07742313
strawberry,n07745940
orange,n07747607
lemon,n07749582
fig,n07753113
pineapple,n07753275
banana,n07753592
jackfruit,n07754684
custard apple,n07760859
pomegranate,n07768694
hay,n07802026
carbonara,n07831146
chocolate sauce,n07836838
dough,n07860988
meat loaf,n07871810
pizza,n07873807
potpie,n07875152
burrito,n07880968
red wine,n07892512
espresso,n07920052
cup,n07930864
eggnog,n07932039
alp,n09193705
bubble,n09229709
cliff,n09246464
coral reef,n09256479
geyser,n09288635
lakeside,n09332890
promontory,n09399592
sandbar,n09421951
seashore,n09428293
valley,n09468604
volcano,n09472597
ballplayer,n09835506
groom,n10148035
scuba diver,n10565667
rapeseed,n11879895
daisy,n11939491
yellow lady's slipper,n12057211
corn,n12144580
acorn,n12267677
hip,n12620546
buckeye,n12768682
coral fungus,n12985857
agaric,n12998815
gyromitra,n13037406
stinkhorn,n13040303
earthstar,n13044778
hen-of-the-woods,n13052670
bolete,n13054560
ear,n13133613
toilet tissue,n15075141
import enum
import pathlib
import re
from typing import Any, BinaryIO, cast, Dict, Iterator, List, Match, Optional, Tuple, Union
from torchdata.datapipes.iter import (
Demultiplexer,
Enumerator,
Filter,
IterDataPipe,
IterKeyZipper,
LineReader,
Mapper,
TarArchiveLoader,
)
from torchdata.datapipes.map import IterToMapConverter
from torchvision.prototype.datasets.utils import Dataset, EncodedImage, ManualDownloadResource, OnlineResource
from torchvision.prototype.datasets.utils._internal import (
getitem,
hint_sharding,
hint_shuffling,
INFINITE_BUFFER_SIZE,
path_accessor,
read_categories_file,
read_mat,
)
from torchvision.prototype.tv_tensors import Label
from .._api import register_dataset, register_info
NAME = "imagenet"
@register_info(NAME)
def _info() -> Dict[str, Any]:
categories, wnids = zip(*read_categories_file(NAME))
return dict(categories=categories, wnids=wnids)
class ImageNetResource(ManualDownloadResource):
def __init__(self, **kwargs: Any) -> None:
super().__init__("Register on https://image-net.org/ and follow the instructions there.", **kwargs)
class ImageNetDemux(enum.IntEnum):
META = 0
LABEL = 1
class CategoryAndWordNetIDExtractor(IterDataPipe):
# Although the WordNet IDs (wnids) are unique, the corresponding categories are not. For example, both n02012849
    # and n03126707 are labeled 'crane', where the former refers to the bird and the latter to the construction equipment.
_WNID_MAP = {
"n03126707": "construction crane",
"n03710721": "tank suit",
}
def __init__(self, datapipe: IterDataPipe[Tuple[str, BinaryIO]]) -> None:
self.datapipe = datapipe
def __iter__(self) -> Iterator[Tuple[str, str]]:
for _, stream in self.datapipe:
synsets = read_mat(stream, squeeze_me=True)["synsets"]
for _, wnid, category, _, num_children, *_ in synsets:
if num_children > 0:
# we are looking at a superclass that has no direct instance
continue
yield self._WNID_MAP.get(wnid, category.split(",", 1)[0]), wnid
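# A tiny standalone illustration (hypothetical helper, not used by the extractor above) of the disambiguation
# logic: a synset's display name is the first of its comma-separated synonyms, unless its WNID is one of the
# two ambiguous cases listed in `_WNID_MAP`.
def _demo_clean_category(wnid: str, raw_category: str) -> str:
    # e.g. ("n02012849", "crane") stays "crane" (the bird), while the construction-equipment synset
    # ("n03126707", "crane") is renamed to "construction crane" so both categories remain distinguishable.
    return CategoryAndWordNetIDExtractor._WNID_MAP.get(wnid, raw_category.split(",", 1)[0])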
@register_dataset(NAME)
class ImageNet(Dataset):
"""
- **homepage**: https://www.image-net.org/
"""
def __init__(
self,
root: Union[str, pathlib.Path],
*,
split: str = "train",
skip_integrity_check: bool = False,
) -> None:
self._split = self._verify_str_arg(split, "split", {"train", "val", "test"})
info = _info()
categories, wnids = info["categories"], info["wnids"]
self._categories = categories
self._wnids = wnids
self._wnid_to_category = dict(zip(wnids, categories))
super().__init__(root, skip_integrity_check=skip_integrity_check)
_IMAGES_CHECKSUMS = {
"train": "b08200a27a8e34218a0e58fde36b0fe8f73bc377f4acea2d91602057c3ca45bb",
"val": "c7e06a6c0baccf06d8dbeb6577d71efff84673a5dbdd50633ab44f8ea0456ae0",
"test_v10102019": "9cf7f8249639510f17d3d8a0deb47cd22a435886ba8e29e2b3223e65a4079eb4",
}
def _resources(self) -> List[OnlineResource]:
name = "test_v10102019" if self._split == "test" else self._split
images = ImageNetResource(
file_name=f"ILSVRC2012_img_{name}.tar",
sha256=self._IMAGES_CHECKSUMS[name],
)
resources: List[OnlineResource] = [images]
if self._split == "val":
devkit = ImageNetResource(
file_name="ILSVRC2012_devkit_t12.tar.gz",
sha256="b59243268c0d266621fd587d2018f69e906fb22875aca0e295b48cafaa927953",
)
resources.append(devkit)
return resources
_TRAIN_IMAGE_NAME_PATTERN = re.compile(r"(?P<wnid>n\d{8})_\d+[.]JPEG")
def _prepare_train_data(self, data: Tuple[str, BinaryIO]) -> Tuple[Tuple[Label, str], Tuple[str, BinaryIO]]:
path = pathlib.Path(data[0])
wnid = cast(Match[str], self._TRAIN_IMAGE_NAME_PATTERN.match(path.name))["wnid"]
label = Label.from_category(self._wnid_to_category[wnid], categories=self._categories)
return (label, wnid), data
def _prepare_test_data(self, data: Tuple[str, BinaryIO]) -> Tuple[None, Tuple[str, BinaryIO]]:
return None, data
    def _classify_devkit(self, data: Tuple[str, BinaryIO]) -> Optional[int]:
return {
"meta.mat": ImageNetDemux.META,
"ILSVRC2012_validation_ground_truth.txt": ImageNetDemux.LABEL,
}.get(pathlib.Path(data[0]).name)
_VAL_TEST_IMAGE_NAME_PATTERN = re.compile(r"ILSVRC2012_(val|test)_(?P<id>\d{8})[.]JPEG")
def _val_test_image_key(self, path: pathlib.Path) -> int:
return int(self._VAL_TEST_IMAGE_NAME_PATTERN.match(path.name)["id"]) # type: ignore[index]
def _prepare_val_data(
self, data: Tuple[Tuple[int, str], Tuple[str, BinaryIO]]
) -> Tuple[Tuple[Label, str], Tuple[str, BinaryIO]]:
label_data, image_data = data
_, wnid = label_data
label = Label.from_category(self._wnid_to_category[wnid], categories=self._categories)
return (label, wnid), image_data
def _prepare_sample(
self,
data: Tuple[Optional[Tuple[Label, str]], Tuple[str, BinaryIO]],
) -> Dict[str, Any]:
label_data, (path, buffer) = data
return dict(
dict(zip(("label", "wnid"), label_data if label_data else (None, None))),
path=path,
image=EncodedImage.from_file(buffer),
)
def _datapipe(self, resource_dps: List[IterDataPipe]) -> IterDataPipe[Dict[str, Any]]:
if self._split in {"train", "test"}:
dp = resource_dps[0]
# the train archive is a tar of tars
if self._split == "train":
dp = TarArchiveLoader(dp)
dp = hint_shuffling(dp)
dp = hint_sharding(dp)
dp = Mapper(dp, self._prepare_train_data if self._split == "train" else self._prepare_test_data)
        else:  # self._split == "val"
images_dp, devkit_dp = resource_dps
meta_dp, label_dp = Demultiplexer(
                devkit_dp, 2, self._classify_devkit, drop_none=True, buffer_size=INFINITE_BUFFER_SIZE
)
            # We cannot use self._wnids here, since the devkit uses a different ordering than ours
meta_dp = CategoryAndWordNetIDExtractor(meta_dp)
wnid_dp = Mapper(meta_dp, getitem(1))
wnid_dp = Enumerator(wnid_dp, 1)
wnid_map = IterToMapConverter(wnid_dp)
label_dp = LineReader(label_dp, decode=True, return_path=False)
label_dp = Mapper(label_dp, int)
label_dp = Mapper(label_dp, wnid_map.__getitem__)
label_dp: IterDataPipe[Tuple[int, str]] = Enumerator(label_dp, 1)
label_dp = hint_shuffling(label_dp)
label_dp = hint_sharding(label_dp)
dp = IterKeyZipper(
label_dp,
images_dp,
key_fn=getitem(0),
ref_key_fn=path_accessor(self._val_test_image_key),
buffer_size=INFINITE_BUFFER_SIZE,
)
dp = Mapper(dp, self._prepare_val_data)
return Mapper(dp, self._prepare_sample)
def __len__(self) -> int:
return {
"train": 1_281_167,
"val": 50_000,
"test": 100_000,
}[self._split]
def _filter_meta(self, data: Tuple[str, Any]) -> bool:
        return self._classify_devkit(data) == ImageNetDemux.META
def _generate_categories(self) -> List[Tuple[str, ...]]:
self._split = "val"
resources = self._resources()
devkit_dp = resources[1].load(self._root)
meta_dp = Filter(devkit_dp, self._filter_meta)
meta_dp = CategoryAndWordNetIDExtractor(meta_dp)
categories_and_wnids = cast(List[Tuple[str, ...]], list(meta_dp))
categories_and_wnids.sort(key=lambda category_and_wnid: category_and_wnid[1])
return categories_and_wnids
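# A minimal standalone sketch (toy values, hypothetical helper) of how the validation branch of `_datapipe`
# resolves labels: the devkit's meta.mat yields WNIDs in devkit order, the ground-truth file contains one
# 1-based devkit index per line, and images are matched via the numeric id parsed from their file name.
def _demo_val_label_resolution() -> None:
    devkit_wnids = {1: "n01440764", 2: "n01443537"}  # 1-based enumeration of WNIDs, as built by IterToMapConverter
    ground_truth_lines = ["2", "1"]  # contents of ILSVRC2012_validation_ground_truth.txt, one line per image
    # 'ILSVRC2012_val_00000001.JPEG' -> image id 1 -> first ground-truth line -> WNID 'n01443537'
    labels = {image_id: devkit_wnids[int(line)] for image_id, line in enumerate(ground_truth_lines, 1)}
    assert labels == {1: "n01443537", 2: "n01440764"}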
import abc
import functools
import operator
import pathlib
import string
from typing import Any, BinaryIO, cast, Dict, Iterator, List, Optional, Sequence, Tuple, Union
import torch
from torchdata.datapipes.iter import Decompressor, Demultiplexer, IterDataPipe, Mapper, Zipper
from torchvision.prototype.datasets.utils import Dataset, HttpResource, OnlineResource
from torchvision.prototype.datasets.utils._internal import hint_sharding, hint_shuffling, INFINITE_BUFFER_SIZE
from torchvision.prototype.tv_tensors import Label
from torchvision.prototype.utils._internal import fromfile
from torchvision.tv_tensors import Image
from .._api import register_dataset, register_info
prod = functools.partial(functools.reduce, operator.mul)
class MNISTFileReader(IterDataPipe[torch.Tensor]):
_DTYPE_MAP = {
8: torch.uint8,
9: torch.int8,
11: torch.int16,
12: torch.int32,
13: torch.float32,
14: torch.float64,
}
def __init__(
self, datapipe: IterDataPipe[Tuple[Any, BinaryIO]], *, start: Optional[int], stop: Optional[int]
) -> None:
self.datapipe = datapipe
self.start = start
self.stop = stop
def __iter__(self) -> Iterator[torch.Tensor]:
for _, file in self.datapipe:
try:
read = functools.partial(fromfile, file, byte_order="big")
magic = int(read(dtype=torch.int32, count=1))
dtype = self._DTYPE_MAP[magic // 256]
ndim = magic % 256 - 1
num_samples = int(read(dtype=torch.int32, count=1))
shape = cast(List[int], read(dtype=torch.int32, count=ndim).tolist()) if ndim else []
count = prod(shape) if shape else 1
start = self.start or 0
stop = min(self.stop, num_samples) if self.stop else num_samples
if start:
num_bytes_per_value = (torch.finfo if dtype.is_floating_point else torch.iinfo)(dtype).bits // 8
file.seek(num_bytes_per_value * count * start, 1)
for _ in range(stop - start):
yield read(dtype=dtype, count=count).reshape(shape)
finally:
file.close()
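# A small standalone sketch (hypothetical helper, not used by the reader above) of how `MNISTFileReader`
# interprets the IDX magic number: the second-lowest byte encodes the dtype and the lowest byte the number of
# dimensions. The leading dimension is the sample count, which the reader consumes separately, so the
# per-sample `ndim` is reduced by one.
def _demo_decode_idx_magic(magic: int) -> Tuple[torch.dtype, int]:
    dtype = MNISTFileReader._DTYPE_MAP[magic // 256]
    ndim = magic % 256 - 1  # per-sample dimensions, excluding the leading sample count
    return dtype, ndim
# For the MNIST image files the magic number is 2051 == 0x0000_0803, i.e. uint8 data with a per-sample shape of
# length 2 (the 28 x 28 pixels); for the label files it is 2049 == 0x0000_0801, i.e. uint8 scalars.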
class _MNISTBase(Dataset):
_URL_BASE: Union[str, Sequence[str]]
@abc.abstractmethod
def _files_and_checksums(self) -> Tuple[Tuple[str, str], Tuple[str, str]]:
pass
def _resources(self) -> List[OnlineResource]:
(images_file, images_sha256), (
labels_file,
labels_sha256,
) = self._files_and_checksums()
url_bases = self._URL_BASE
if isinstance(url_bases, str):
url_bases = (url_bases,)
images_urls = [f"{url_base}/{images_file}" for url_base in url_bases]
images = HttpResource(images_urls[0], sha256=images_sha256, mirrors=images_urls[1:])
labels_urls = [f"{url_base}/{labels_file}" for url_base in url_bases]
labels = HttpResource(labels_urls[0], sha256=labels_sha256, mirrors=labels_urls[1:])
return [images, labels]
def start_and_stop(self) -> Tuple[Optional[int], Optional[int]]:
return None, None
_categories: List[str]
def _prepare_sample(self, data: Tuple[torch.Tensor, torch.Tensor]) -> Dict[str, Any]:
image, label = data
return dict(
image=Image(image),
label=Label(label, dtype=torch.int64, categories=self._categories),
)
def _datapipe(self, resource_dps: List[IterDataPipe]) -> IterDataPipe[Dict[str, Any]]:
images_dp, labels_dp = resource_dps
start, stop = self.start_and_stop()
images_dp = Decompressor(images_dp)
images_dp = MNISTFileReader(images_dp, start=start, stop=stop)
labels_dp = Decompressor(labels_dp)
labels_dp = MNISTFileReader(labels_dp, start=start, stop=stop)
dp = Zipper(images_dp, labels_dp)
dp = hint_shuffling(dp)
dp = hint_sharding(dp)
return Mapper(dp, self._prepare_sample)
@register_info("mnist")
def _mnist_info() -> Dict[str, Any]:
return dict(
categories=[str(label) for label in range(10)],
)
@register_dataset("mnist")
class MNIST(_MNISTBase):
"""
- **homepage**: http://yann.lecun.com/exdb/mnist
"""
def __init__(
self,
root: Union[str, pathlib.Path],
*,
split: str = "train",
skip_integrity_check: bool = False,
) -> None:
self._split = self._verify_str_arg(split, "split", ("train", "test"))
super().__init__(root, skip_integrity_check=skip_integrity_check)
_URL_BASE: Union[str, Sequence[str]] = (
"http://yann.lecun.com/exdb/mnist",
"https://ossci-datasets.s3.amazonaws.com/mnist",
)
_CHECKSUMS = {
"train-images-idx3-ubyte.gz": "440fcabf73cc546fa21475e81ea370265605f56be210a4024d2ca8f203523609",
"train-labels-idx1-ubyte.gz": "3552534a0a558bbed6aed32b30c495cca23d567ec52cac8be1a0730e8010255c",
"t10k-images-idx3-ubyte.gz": "8d422c7b0a1c1c79245a5bcf07fe86e33eeafee792b84584aec276f5a2dbc4e6",
"t10k-labels-idx1-ubyte.gz": "f7ae60f92e00ec6debd23a6088c31dbd2371eca3ffa0defaefb259924204aec6",
}
def _files_and_checksums(self) -> Tuple[Tuple[str, str], Tuple[str, str]]:
prefix = "train" if self._split == "train" else "t10k"
images_file = f"{prefix}-images-idx3-ubyte.gz"
labels_file = f"{prefix}-labels-idx1-ubyte.gz"
return (images_file, self._CHECKSUMS[images_file]), (
labels_file,
self._CHECKSUMS[labels_file],
)
_categories = _mnist_info()["categories"]
def __len__(self) -> int:
return 60_000 if self._split == "train" else 10_000
@register_info("fashionmnist")
def _fashionmnist_info() -> Dict[str, Any]:
return dict(
categories=[
"T-shirt/top",
"Trouser",
"Pullover",
"Dress",
"Coat",
"Sandal",
"Shirt",
"Sneaker",
"Bag",
"Ankle boot",
],
)
@register_dataset("fashionmnist")
class FashionMNIST(MNIST):
"""
- **homepage**: https://github.com/zalandoresearch/fashion-mnist
"""
_URL_BASE = "http://fashion-mnist.s3-website.eu-central-1.amazonaws.com"
_CHECKSUMS = {
"train-images-idx3-ubyte.gz": "3aede38d61863908ad78613f6a32ed271626dd12800ba2636569512369268a84",
"train-labels-idx1-ubyte.gz": "a04f17134ac03560a47e3764e11b92fc97de4d1bfaf8ba1a3aa29af54cc90845",
"t10k-images-idx3-ubyte.gz": "346e55b948d973a97e58d2351dde16a484bd415d4595297633bb08f03db6a073",
"t10k-labels-idx1-ubyte.gz": "67da17c76eaffca5446c3361aaab5c3cd6d1c2608764d35dfb1850b086bf8dd5",
}
_categories = _fashionmnist_info()["categories"]
@register_info("kmnist")
def _kmnist_info() -> Dict[str, Any]:
return dict(
categories=["o", "ki", "su", "tsu", "na", "ha", "ma", "ya", "re", "wo"],
)
@register_dataset("kmnist")
class KMNIST(MNIST):
"""
- **homepage**: http://codh.rois.ac.jp/kmnist/index.html.en
"""
_URL_BASE = "http://codh.rois.ac.jp/kmnist/dataset/kmnist"
_CHECKSUMS = {
"train-images-idx3-ubyte.gz": "51467d22d8cc72929e2a028a0428f2086b092bb31cfb79c69cc0a90ce135fde4",
"train-labels-idx1-ubyte.gz": "e38f9ebcd0f3ebcdec7fc8eabdcdaef93bb0df8ea12bee65224341c8183d8e17",
"t10k-images-idx3-ubyte.gz": "edd7a857845ad6bb1d0ba43fe7e794d164fe2dce499a1694695a792adfac43c5",
"t10k-labels-idx1-ubyte.gz": "20bb9a0ef54c7db3efc55a92eef5582c109615df22683c380526788f98e42a1c",
}
_categories = _kmnist_info()["categories"]
@register_info("emnist")
def _emnist_info() -> Dict[str, Any]:
return dict(
categories=list(string.digits + string.ascii_uppercase + string.ascii_lowercase),
)
@register_dataset("emnist")
class EMNIST(_MNISTBase):
"""
- **homepage**: https://www.westernsydney.edu.au/icns/reproducible_research/publication_support_materials/emnist
"""
def __init__(
self,
root: Union[str, pathlib.Path],
*,
split: str = "train",
image_set: str = "Balanced",
skip_integrity_check: bool = False,
) -> None:
self._split = self._verify_str_arg(split, "split", ("train", "test"))
self._image_set = self._verify_str_arg(
image_set, "image_set", ("Balanced", "By_Merge", "By_Class", "Letters", "Digits", "MNIST")
)
super().__init__(root, skip_integrity_check=skip_integrity_check)
_URL_BASE = "https://rds.westernsydney.edu.au/Institutes/MARCS/BENS/EMNIST"
def _files_and_checksums(self) -> Tuple[Tuple[str, str], Tuple[str, str]]:
prefix = f"emnist-{self._image_set.replace('_', '').lower()}-{self._split}"
images_file = f"{prefix}-images-idx3-ubyte.gz"
labels_file = f"{prefix}-labels-idx1-ubyte.gz"
# Since EMNIST provides the data files inside an archive, we don't need to provide checksums for them
return (images_file, ""), (labels_file, "")
def _resources(self) -> List[OnlineResource]:
return [
HttpResource(
f"{self._URL_BASE}/emnist-gzip.zip",
sha256="909a2a39c5e86bdd7662425e9b9c4a49bb582bf8d0edad427f3c3a9d0c6f7259",
)
]
def _classify_archive(self, data: Tuple[str, Any]) -> Optional[int]:
path = pathlib.Path(data[0])
(images_file, _), (labels_file, _) = self._files_and_checksums()
if path.name == images_file:
return 0
elif path.name == labels_file:
return 1
else:
return None
_categories = _emnist_info()["categories"]
_LABEL_OFFSETS = {
38: 1,
39: 1,
40: 1,
41: 1,
42: 1,
43: 6,
44: 8,
45: 8,
46: 9,
}
def _prepare_sample(self, data: Tuple[torch.Tensor, torch.Tensor]) -> Dict[str, Any]:
        # In the Balanced and By_Merge image sets, some lowercase letters are merged into their uppercase
        # counterparts (see Fig. 2 in the paper). That means, for example, that there is 'D', 'd', and 'C', but no 'c'.
        # Since the labels are nevertheless dense, i.e. there are no gaps between 0 and 46 for the 47 classes, we
        # need to add an offset to re-create these gaps. For
# example, since there is no 'c', 'd' corresponds to
# label 38 (10 digits + 26 uppercase letters + 3rd unmerged lower case letter - 1 for zero indexing),
# and at the same time corresponds to
# index 39 (10 digits + 26 uppercase letters + 4th lower case letter - 1 for zero indexing)
# in self._categories. Thus, we need to add 1 to the label to correct this.
if self._image_set in ("Balanced", "By_Merge"):
image, label = data
label += self._LABEL_OFFSETS.get(int(label), 0)
data = (image, label)
return super()._prepare_sample(data)
def _datapipe(self, resource_dps: List[IterDataPipe]) -> IterDataPipe[Dict[str, Any]]:
archive_dp = resource_dps[0]
images_dp, labels_dp = Demultiplexer(
archive_dp,
2,
self._classify_archive,
drop_none=True,
buffer_size=INFINITE_BUFFER_SIZE,
)
return super()._datapipe([images_dp, labels_dp])
def __len__(self) -> int:
return {
("train", "Balanced"): 112_800,
("train", "By_Merge"): 697_932,
("train", "By_Class"): 697_932,
("train", "Letters"): 124_800,
("train", "Digits"): 240_000,
("train", "MNIST"): 60_000,
("test", "Balanced"): 18_800,
("test", "By_Merge"): 116_323,
("test", "By_Class"): 116_323,
("test", "Letters"): 20_800,
("test", "Digits"): 40_000,
("test", "MNIST"): 10_000,
}[(self._split, self._image_set)]
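# A quick standalone sanity check (illustrative only) for the `_LABEL_OFFSETS` logic above: in the Balanced and
# By_Merge image sets the 47 dense labels 0-46 have to be shifted so that they index the unmerged entries of
# the full 62-entry digits + uppercase + lowercase category list.
def _demo_emnist_balanced_category(label: int) -> str:
    categories = _emnist_info()["categories"]
    return categories[label + EMNIST._LABEL_OFFSETS.get(label, 0)]
# e.g. label 38 -> 'd' (the merged 'c' has no label of its own), label 43 -> 'n', label 46 -> 't'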
@register_info("qmnist")
def _qmnist_info() -> Dict[str, Any]:
return dict(
categories=[str(label) for label in range(10)],
)
@register_dataset("qmnist")
class QMNIST(_MNISTBase):
"""
- **homepage**: https://github.com/facebookresearch/qmnist
"""
def __init__(
self,
root: Union[str, pathlib.Path],
*,
split: str = "train",
skip_integrity_check: bool = False,
) -> None:
self._split = self._verify_str_arg(split, "split", ("train", "test", "test10k", "test50k", "nist"))
super().__init__(root, skip_integrity_check=skip_integrity_check)
_URL_BASE = "https://raw.githubusercontent.com/facebookresearch/qmnist/master"
_CHECKSUMS = {
"qmnist-train-images-idx3-ubyte.gz": "9e26a7bf1683614e065d7b76460ccd52807165b3f22561fb782bd9f38c52b51d",
"qmnist-train-labels-idx2-int.gz": "2c05dc77f6b916b38e455e97ab129a42a444f3dbef09b278a366f82904e0dd9f",
"qmnist-test-images-idx3-ubyte.gz": "43fc22bf7498b8fc98de98369d72f752d0deabc280a43a7bcc364ab19e57b375",
"qmnist-test-labels-idx2-int.gz": "9fbcbe594c3766fdf4f0b15c5165dc0d1e57ac604e01422608bb72c906030d06",
"xnist-images-idx3-ubyte.xz": "f075553993026d4359ded42208eff77a1941d3963c1eff49d6015814f15f0984",
"xnist-labels-idx2-int.xz": "db042968723ec2b7aed5f1beac25d2b6e983b9286d4f4bf725f1086e5ae55c4f",
}
def _files_and_checksums(self) -> Tuple[Tuple[str, str], Tuple[str, str]]:
prefix = "xnist" if self._split == "nist" else f"qmnist-{'train' if self._split == 'train' else 'test'}"
suffix = "xz" if self._split == "nist" else "gz"
images_file = f"{prefix}-images-idx3-ubyte.{suffix}"
labels_file = f"{prefix}-labels-idx2-int.{suffix}"
return (images_file, self._CHECKSUMS[images_file]), (
labels_file,
self._CHECKSUMS[labels_file],
)
def start_and_stop(self) -> Tuple[Optional[int], Optional[int]]:
start: Optional[int]
stop: Optional[int]
if self._split == "test10k":
start = 0
stop = 10000
elif self._split == "test50k":
start = 10000
stop = None
else:
start = stop = None
return start, stop
    _categories = _qmnist_info()["categories"]
def _prepare_sample(self, data: Tuple[torch.Tensor, torch.Tensor]) -> Dict[str, Any]:
image, ann = data
label, *extra_anns = ann
sample = super()._prepare_sample((image, label))
sample.update(
dict(
zip(
("nist_hsf_series", "nist_writer_id", "digit_index", "nist_label", "global_digit_index"),
[int(value) for value in extra_anns[:5]],
)
)
)
sample.update(dict(zip(("duplicate", "unused"), [bool(value) for value in extra_anns[-2:]])))
return sample
def __len__(self) -> int:
return {
"train": 60_000,
"test": 60_000,
"test10k": 10_000,
"test50k": 50_000,
"nist": 402_953,
}[self._split]
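# A minimal standalone sketch (toy values, hypothetical helper) of how `QMNIST._prepare_sample` above splits an
# annotation row from the qmnist/xnist label files: column 0 is the class label, the next five columns carry the
# NIST provenance metadata, and the last two columns are boolean flags.
def _demo_split_qmnist_annotation(ann: List[int]) -> Dict[str, Any]:
    label, *extra = ann
    keys = ("nist_hsf_series", "nist_writer_id", "digit_index", "nist_label", "global_digit_index")
    return dict(label=label, **dict(zip(keys, extra[:5])), duplicate=bool(extra[-2]), unused=bool(extra[-1]))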
Abyssinian
American Bulldog
American Pit Bull Terrier
Basset Hound
Beagle
Bengal
Birman
Bombay
Boxer
British Shorthair
Chihuahua
Egyptian Mau
English Cocker Spaniel
English Setter
German Shorthaired
Great Pyrenees
Havanese
Japanese Chin
Keeshond
Leonberger
Maine Coon
Miniature Pinscher
Newfoundland
Persian
Pomeranian
Pug
Ragdoll
Russian Blue
Saint Bernard
Samoyed
Scottish Terrier
Shiba Inu
Siamese
Sphynx
Staffordshire Bull Terrier
Wheaten Terrier
Yorkshire Terrier
import enum
import pathlib
from typing import Any, BinaryIO, Dict, List, Optional, Tuple, Union
from torchdata.datapipes.iter import CSVDictParser, Demultiplexer, Filter, IterDataPipe, IterKeyZipper, Mapper
from torchvision.prototype.datasets.utils import Dataset, EncodedImage, HttpResource, OnlineResource
from torchvision.prototype.datasets.utils._internal import (
getitem,
hint_sharding,
hint_shuffling,
INFINITE_BUFFER_SIZE,
path_accessor,
path_comparator,
read_categories_file,
)
from torchvision.prototype.tv_tensors import Label
from .._api import register_dataset, register_info
NAME = "oxford-iiit-pet"
class OxfordIIITPetDemux(enum.IntEnum):
SPLIT_AND_CLASSIFICATION = 0
SEGMENTATIONS = 1
@register_info(NAME)
def _info() -> Dict[str, Any]:
return dict(categories=read_categories_file(NAME))
@register_dataset(NAME)
class OxfordIIITPet(Dataset):
"""Oxford IIIT Pet Dataset
homepage="https://www.robots.ox.ac.uk/~vgg/data/pets/",
"""
def __init__(
self, root: Union[str, pathlib.Path], *, split: str = "trainval", skip_integrity_check: bool = False
) -> None:
self._split = self._verify_str_arg(split, "split", {"trainval", "test"})
self._categories = _info()["categories"]
super().__init__(root, skip_integrity_check=skip_integrity_check)
def _resources(self) -> List[OnlineResource]:
images = HttpResource(
"https://www.robots.ox.ac.uk/~vgg/data/pets/data/images.tar.gz",
sha256="67195c5e1c01f1ab5f9b6a5d22b8c27a580d896ece458917e61d459337fa318d",
preprocess="decompress",
)
anns = HttpResource(
"https://www.robots.ox.ac.uk/~vgg/data/pets/data/annotations.tar.gz",
sha256="52425fb6de5c424942b7626b428656fcbd798db970a937df61750c0f1d358e91",
preprocess="decompress",
)
return [images, anns]
def _classify_anns(self, data: Tuple[str, Any]) -> Optional[int]:
return {
"annotations": OxfordIIITPetDemux.SPLIT_AND_CLASSIFICATION,
"trimaps": OxfordIIITPetDemux.SEGMENTATIONS,
}.get(pathlib.Path(data[0]).parent.name)
def _filter_images(self, data: Tuple[str, Any]) -> bool:
return pathlib.Path(data[0]).suffix == ".jpg"
def _filter_segmentations(self, data: Tuple[str, Any]) -> bool:
return not pathlib.Path(data[0]).name.startswith(".")
def _prepare_sample(
self, data: Tuple[Tuple[Dict[str, str], Tuple[str, BinaryIO]], Tuple[str, BinaryIO]]
) -> Dict[str, Any]:
ann_data, image_data = data
classification_data, segmentation_data = ann_data
segmentation_path, segmentation_buffer = segmentation_data
image_path, image_buffer = image_data
return dict(
label=Label(int(classification_data["label"]) - 1, categories=self._categories),
species="cat" if classification_data["species"] == "1" else "dog",
segmentation_path=segmentation_path,
segmentation=EncodedImage.from_file(segmentation_buffer),
image_path=image_path,
image=EncodedImage.from_file(image_buffer),
)
def _datapipe(self, resource_dps: List[IterDataPipe]) -> IterDataPipe[Dict[str, Any]]:
images_dp, anns_dp = resource_dps
images_dp = Filter(images_dp, self._filter_images)
split_and_classification_dp, segmentations_dp = Demultiplexer(
anns_dp,
2,
self._classify_anns,
drop_none=True,
buffer_size=INFINITE_BUFFER_SIZE,
)
split_and_classification_dp = Filter(split_and_classification_dp, path_comparator("name", f"{self._split}.txt"))
split_and_classification_dp = CSVDictParser(
split_and_classification_dp, fieldnames=("image_id", "label", "species"), delimiter=" "
)
split_and_classification_dp = hint_shuffling(split_and_classification_dp)
split_and_classification_dp = hint_sharding(split_and_classification_dp)
segmentations_dp = Filter(segmentations_dp, self._filter_segmentations)
anns_dp = IterKeyZipper(
split_and_classification_dp,
segmentations_dp,
key_fn=getitem("image_id"),
ref_key_fn=path_accessor("stem"),
buffer_size=INFINITE_BUFFER_SIZE,
)
dp = IterKeyZipper(
anns_dp,
images_dp,
key_fn=getitem(0, "image_id"),
ref_key_fn=path_accessor("stem"),
buffer_size=INFINITE_BUFFER_SIZE,
)
return Mapper(dp, self._prepare_sample)
def _filter_split_and_classification_anns(self, data: Tuple[str, Any]) -> bool:
return self._classify_anns(data) == OxfordIIITPetDemux.SPLIT_AND_CLASSIFICATION
def _generate_categories(self) -> List[str]:
resources = self._resources()
dp = resources[1].load(self._root)
dp = Filter(dp, self._filter_split_and_classification_anns)
dp = Filter(dp, path_comparator("name", "trainval.txt"))
dp = CSVDictParser(dp, fieldnames=("image_id", "label"), delimiter=" ")
raw_categories_and_labels = {(data["image_id"].rsplit("_", 1)[0], data["label"]) for data in dp}
raw_categories, _ = zip(
*sorted(raw_categories_and_labels, key=lambda raw_category_and_label: int(raw_category_and_label[1]))
)
return [" ".join(part.title() for part in raw_category.split("_")) for raw_category in raw_categories]
def __len__(self) -> int:
return 3_680 if self._split == "trainval" else 3_669
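# A small standalone illustration (hypothetical helper) of the category-name derivation in
# `_generate_categories` above: the class name is everything before the trailing index of an image id, with
# underscores replaced by spaces and every word title-cased.
def _demo_pet_category_from_image_id(image_id: str) -> str:
    raw_category = image_id.rsplit("_", 1)[0]
    return " ".join(part.title() for part in raw_category.split("_"))
# e.g. 'Abyssinian_100' -> 'Abyssinian' and 'german_shorthaired_12' -> 'German Shorthaired'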