Unverified Commit 055708d2 authored by Philip Meier's avatar Philip Meier Committed by GitHub
Browse files

add prototypes for `Caltech(101|256)` datasets (#4510)

* add prototype for `Caltech256` dataset

* silence mypy
parent 4bf60863
...@@ -495,7 +495,7 @@ if __name__ == "__main__": ...@@ -495,7 +495,7 @@ if __name__ == "__main__":
# Package info # Package info
packages=find_packages(exclude=('test',)), packages=find_packages(exclude=('test',)),
package_data={ package_data={
package_name: ['*.dll', '*.dylib', '*.so'] package_name: ['*.dll', '*.dylib', '*.so', '*.categories']
}, },
zip_safe=False, zip_safe=False,
install_requires=requirements, install_requires=requirements,
......
...@@ -8,7 +8,7 @@ from torchvision.prototype.datasets import home ...@@ -8,7 +8,7 @@ from torchvision.prototype.datasets import home
from torchvision.prototype.datasets.decoder import pil from torchvision.prototype.datasets.decoder import pil
from torchvision.prototype.datasets.utils import Dataset, DatasetInfo from torchvision.prototype.datasets.utils import Dataset, DatasetInfo
from torchvision.prototype.datasets.utils._internal import add_suggestion from torchvision.prototype.datasets.utils._internal import add_suggestion
from . import _builtin
DATASETS: Dict[str, Dataset] = {} DATASETS: Dict[str, Dataset] = {}
...@@ -17,6 +17,16 @@ def register(dataset: Dataset) -> None: ...@@ -17,6 +17,16 @@ def register(dataset: Dataset) -> None:
DATASETS[dataset.name] = dataset DATASETS[dataset.name] = dataset
# Automatically register every public, concrete ``Dataset`` subclass that the
# ``_builtin`` package exposes, so each new built-in dataset only needs to be
# defined and imported there to become available.
for name, obj in _builtin.__dict__.items():
    if (
        not name.startswith("_")  # skip private and dunder names
        and isinstance(obj, type)  # only classes, not functions or constants
        and issubclass(obj, Dataset)
        and obj is not Dataset  # exclude the abstract base class itself
    ):
        register(obj())
# This is exposed as 'list', but we avoid that here to not shadow the built-in 'list' # This is exposed as 'list', but we avoid that here to not shadow the built-in 'list'
def _list() -> List[str]: def _list() -> List[str]:
return sorted(DATASETS.keys()) return sorted(DATASETS.keys())
......
from .caltech import Caltech101, Caltech256
import io
import pathlib
from typing import Any, Callable, Dict, List, Optional, Tuple, Union
import re
import numpy as np
import torch
from torch.utils.data import IterDataPipe
from torch.utils.data.datapipes.iter import (
Mapper,
TarArchiveReader,
Shuffler,
Filter,
)
from torchdata.datapipes.iter import KeyZipper
from torchvision.prototype.datasets.utils import (
Dataset,
DatasetConfig,
DatasetInfo,
HttpResource,
OnlineResource,
)
from torchvision.prototype.datasets.utils._internal import create_categories_file, INFINITE_BUFFER_SIZE, read_mat
# Directory containing this module; used to locate the bundled *.categories files.
HERE = pathlib.Path(__file__).parent
class Caltech101(Dataset):
    """Prototype datapipe-based dataset for Caltech-101.

    Each sample is a dict with the (optionally decoded) image, its integer
    label and category name, the bounding box and object contour from the
    matching ``.mat`` annotation, and the source paths of both files.
    """

    @property
    def info(self) -> DatasetInfo:
        # Categories come from the pre-generated file next to this module
        # (see `generate_categories_file` below).
        return DatasetInfo(
            "caltech101",
            categories=HERE / "caltech101.categories",
            homepage="http://www.vision.caltech.edu/Image_Datasets/Caltech101",
        )

    def resources(self, config: DatasetConfig) -> List[OnlineResource]:
        # Two separate archives: the images (.tar.gz) and the .mat annotations (.tar).
        images = HttpResource(
            "http://www.vision.caltech.edu/Image_Datasets/Caltech101/101_ObjectCategories.tar.gz",
            sha256="af6ece2f339791ca20f855943d8b55dd60892c0a25105fcd631ee3d6430f9926",
        )
        anns = HttpResource(
            "http://www.vision.caltech.edu/Image_Datasets/Caltech101/Annotations.tar",
            sha256="1717f4e10aa837b05956e3f4c94456527b143eec0d95e935028b30aff40663d8",
        )
        return [images, anns]

    # File-name patterns used to extract the numeric sample id.
    _IMAGES_NAME_PATTERN = re.compile(r"image_(?P<id>\d+)[.]jpg")
    _ANNS_NAME_PATTERN = re.compile(r"annotation_(?P<id>\d+)[.]mat")
    # The annotation archive uses different directory names than the image
    # archive for these four categories; map them back to the image names so
    # the keys of the two archives line up.
    _ANNS_CATEGORY_MAP = {
        "Faces_2": "Faces",
        "Faces_3": "Faces_easy",
        "Motorbikes_16": "Motorbikes",
        "Airplanes_Side_2": "airplanes",
    }

    def _is_not_background_image(self, data: Tuple[str, Any]) -> bool:
        # Drop the "BACKGROUND_Google" clutter folder, which is not a real category.
        path = pathlib.Path(data[0])
        return path.parent.name != "BACKGROUND_Google"

    def _is_ann(self, data: Tuple[str, Any]) -> bool:
        # Keep only files whose name matches the annotation pattern.
        path = pathlib.Path(data[0])
        return bool(self._ANNS_NAME_PATTERN.match(path.name))

    def _images_key_fn(self, data: Tuple[str, Any]) -> Tuple[str, str]:
        # Key an image by (category, numeric id) for joining with annotations.
        path = pathlib.Path(data[0])
        category = path.parent.name
        id = self._IMAGES_NAME_PATTERN.match(path.name).group("id")  # type: ignore[union-attr]
        return category, id

    def _anns_key_fn(self, data: Tuple[str, Any]) -> Tuple[str, str]:
        # Key an annotation by (category, numeric id), normalizing the four
        # category directory names that differ between the two archives.
        path = pathlib.Path(data[0])
        category = path.parent.name
        if category in self._ANNS_CATEGORY_MAP:
            category = self._ANNS_CATEGORY_MAP[category]
        id = self._ANNS_NAME_PATTERN.match(path.name).group("id")  # type: ignore[union-attr]
        return category, id

    def _collate_and_decode_sample(
        self, data, *, decoder: Optional[Callable[[io.IOBase], torch.Tensor]]
    ) -> Dict[str, Any]:
        """Turn one joined triple into a sample dict.

        ``data`` is ``((category, id), (image_path, image_buffer),
        (ann_path, ann_buffer))`` as produced by the ``KeyZipper`` in
        ``_make_datapipe``.
        """
        key, image_data, ann_data = data
        category, _ = key
        image_path, image_buffer = image_data
        ann_path, ann_buffer = ann_data
        # The integer label is the category's index in the category list.
        label = self.info.categories.index(category)
        # Without a decoder, expose the raw file buffer instead of a tensor.
        image = decoder(image_buffer) if decoder else image_buffer
        ann = read_mat(ann_buffer)
        bbox = torch.as_tensor(ann["box_coord"].astype(np.int64))
        contour = torch.as_tensor(ann["obj_contour"])
        return dict(
            category=category,
            label=label,
            image=image,
            image_path=image_path,
            bbox=bbox,
            contour=contour,
            ann_path=ann_path,
        )

    def _make_datapipe(
        self,
        resource_dps: List[IterDataPipe],
        *,
        config: DatasetConfig,
        decoder: Optional[Callable[[io.IOBase], torch.Tensor]],
    ) -> IterDataPipe[Dict[str, Any]]:
        """Join each image with its annotation by (category, id) and decode."""
        images_dp, anns_dp = resource_dps
        images_dp = TarArchiveReader(images_dp)
        images_dp = Filter(images_dp, self._is_not_background_image)
        # FIXME: add this after https://github.com/pytorch/pytorch/issues/65808 is resolved
        # images_dp = Shuffler(images_dp, buffer_size=INFINITE_BUFFER_SIZE)
        anns_dp = TarArchiveReader(anns_dp)
        anns_dp = Filter(anns_dp, self._is_ann)
        # Annotations are buffered pseudo-infinitely since the iteration order
        # of the two archives is not guaranteed to match.
        dp = KeyZipper(
            images_dp,
            anns_dp,
            key_fn=self._images_key_fn,
            ref_key_fn=self._anns_key_fn,
            buffer_size=INFINITE_BUFFER_SIZE,
            keep_key=True,
        )
        return Mapper(dp, self._collate_and_decode_sample, fn_kwargs=dict(decoder=decoder))

    def generate_categories_file(self, root: Union[str, pathlib.Path]) -> None:
        # One-off maintainer helper: derive the sorted category list from the
        # downloaded image archive and write it next to this module.
        dp = self.resources(self.default_config)[0].to_datapipe(pathlib.Path(root) / self.name)
        dp = TarArchiveReader(dp)
        dp = Filter(dp, self._is_not_background_image)
        dir_names = {pathlib.Path(path).parent.name for path, _ in dp}
        create_categories_file(HERE, self.name, sorted(dir_names))
class Caltech256(Dataset):
    """Prototype datapipe-based dataset for Caltech-256.

    Each sample is a dict with the (optionally decoded) image, its integer
    label, and its category name, both derived from the directory the image
    lives in (e.g. ``"001.ak47"``).
    """

    @property
    def info(self) -> DatasetInfo:
        # Categories come from the pre-generated file next to this module
        # (see `generate_categories_file` below).
        return DatasetInfo(
            "caltech256",
            categories=HERE / "caltech256.categories",
            homepage="http://www.vision.caltech.edu/Image_Datasets/Caltech256",
        )

    def resources(self, config: DatasetConfig) -> List[OnlineResource]:
        # A single archive containing one directory per category.
        archive = HttpResource(
            "http://www.vision.caltech.edu/Image_Datasets/Caltech256/256_ObjectCategories.tar",
            sha256="08ff01b03c65566014ae88eb0490dbe4419fc7ac4de726ee1163e39fd809543e",
        )
        return [archive]

    def _is_not_rogue_file(self, data: Tuple[str, Any]) -> bool:
        # The archive ships a stray "RENAME2" file that is not an image.
        return pathlib.Path(data[0]).name != "RENAME2"

    def _collate_and_decode_sample(
        self,
        data: Tuple[str, io.IOBase],
        *,
        decoder: Optional[Callable[[io.IOBase], torch.Tensor]],
    ) -> Dict[str, Any]:
        # Directory names have the form "<label>.<category>", e.g. "001.ak47".
        path, buffer = data
        prefix, category = pathlib.Path(path).parent.name.split(".")
        # Without a decoder, expose the raw file buffer instead of a tensor.
        image = buffer if decoder is None else decoder(buffer)
        return dict(label=torch.tensor(int(prefix)), category=category, image=image)

    def _make_datapipe(
        self,
        resource_dps: List[IterDataPipe],
        *,
        config: DatasetConfig,
        decoder: Optional[Callable[[io.IOBase], torch.Tensor]],
    ) -> IterDataPipe[Dict[str, Any]]:
        """Extract the archive, drop the rogue file, and decode each sample."""
        samples_dp = TarArchiveReader(resource_dps[0])
        samples_dp = Filter(samples_dp, self._is_not_rogue_file)
        # FIXME: add this after https://github.com/pytorch/pytorch/issues/65808 is resolved
        # samples_dp = Shuffler(samples_dp, buffer_size=INFINITE_BUFFER_SIZE)
        return Mapper(samples_dp, self._collate_and_decode_sample, fn_kwargs=dict(decoder=decoder))

    def generate_categories_file(self, root: Union[str, pathlib.Path]) -> None:
        # One-off maintainer helper: derive the category list from the
        # downloaded archive and write it next to this module.
        dp = self.resources(self.default_config)[0].to_datapipe(pathlib.Path(root) / self.name)
        dp = TarArchiveReader(dp)
        dir_names = sorted({pathlib.Path(path).parent.name for path, _ in dp})
        create_categories_file(HERE, self.name, [name.split(".")[1] for name in dir_names])
if __name__ == "__main__":
    from torchvision.prototype.datasets import home

    # Regenerate the bundled *.categories files from freshly downloaded archives.
    root = home()
    for dataset in (Caltech101(), Caltech256()):
        dataset.generate_categories_file(root)
Faces
Faces_easy
Leopards
Motorbikes
accordion
airplanes
anchor
ant
barrel
bass
beaver
binocular
bonsai
brain
brontosaurus
buddha
butterfly
camera
cannon
car_side
ceiling_fan
cellphone
chair
chandelier
cougar_body
cougar_face
crab
crayfish
crocodile
crocodile_head
cup
dalmatian
dollar_bill
dolphin
dragonfly
electric_guitar
elephant
emu
euphonium
ewer
ferry
flamingo
flamingo_head
garfield
gerenuk
gramophone
grand_piano
hawksbill
headphone
hedgehog
helicopter
ibis
inline_skate
joshua_tree
kangaroo
ketch
lamp
laptop
llama
lobster
lotus
mandolin
mayfly
menorah
metronome
minaret
nautilus
octopus
okapi
pagoda
panda
pigeon
pizza
platypus
pyramid
revolver
rhino
rooster
saxophone
schooner
scissors
scorpion
sea_horse
snoopy
soccer_ball
stapler
starfish
stegosaurus
stop_sign
strawberry
sunflower
tick
trilobite
umbrella
watch
water_lilly
wheelchair
wild_cat
windsor_chair
wrench
yin_yang
ak47
american-flag
backpack
baseball-bat
baseball-glove
basketball-hoop
bat
bathtub
bear
beer-mug
billiards
binoculars
birdbath
blimp
bonsai-101
boom-box
bowling-ball
bowling-pin
boxing-glove
brain-101
breadmaker
buddha-101
bulldozer
butterfly
cactus
cake
calculator
camel
cannon
canoe
car-tire
cartman
cd
centipede
cereal-box
chandelier-101
chess-board
chimp
chopsticks
cockroach
coffee-mug
coffin
coin
comet
computer-keyboard
computer-monitor
computer-mouse
conch
cormorant
covered-wagon
cowboy-hat
crab-101
desk-globe
diamond-ring
dice
dog
dolphin-101
doorknob
drinking-straw
duck
dumb-bell
eiffel-tower
electric-guitar-101
elephant-101
elk
ewer-101
eyeglasses
fern
fighter-jet
fire-extinguisher
fire-hydrant
fire-truck
fireworks
flashlight
floppy-disk
football-helmet
french-horn
fried-egg
frisbee
frog
frying-pan
galaxy
gas-pump
giraffe
goat
golden-gate-bridge
goldfish
golf-ball
goose
gorilla
grand-piano-101
grapes
grasshopper
guitar-pick
hamburger
hammock
harmonica
harp
harpsichord
hawksbill-101
head-phones
helicopter-101
hibiscus
homer-simpson
horse
horseshoe-crab
hot-air-balloon
hot-dog
hot-tub
hourglass
house-fly
human-skeleton
hummingbird
ibis-101
ice-cream-cone
iguana
ipod
iris
jesus-christ
joy-stick
kangaroo-101
kayak
ketch-101
killer-whale
knife
ladder
laptop-101
lathe
leopards-101
license-plate
lightbulb
light-house
lightning
llama-101
mailbox
mandolin
mars
mattress
megaphone
menorah-101
microscope
microwave
minaret
minotaur
motorbikes-101
mountain-bike
mushroom
mussels
necktie
octopus
ostrich
owl
palm-pilot
palm-tree
paperclip
paper-shredder
pci-card
penguin
people
pez-dispenser
photocopier
picnic-table
playing-card
porcupine
pram
praying-mantis
pyramid
raccoon
radio-telescope
rainbow
refrigerator
revolver-101
rifle
rotary-phone
roulette-wheel
saddle
saturn
school-bus
scorpion-101
screwdriver
segway
self-propelled-lawn-mower
sextant
sheet-music
skateboard
skunk
skyscraper
smokestack
snail
snake
sneaker
snowmobile
soccer-ball
socks
soda-can
spaghetti
speed-boat
spider
spoon
stained-glass
starfish-101
steering-wheel
stirrups
sunflower-101
superman
sushi
swan
swiss-army-knife
sword
syringe
tambourine
teapot
teddy-bear
teepee
telephone-box
tennis-ball
tennis-court
tennis-racket
theodolite
toaster
tomato
tombstone
top-hat
touring-bike
tower-pisa
traffic-light
treadmill
triceratops
tricycle
trilobite-101
tripod
t-shirt
tuning-fork
tweezer
umbrella-101
unicorn
vcr
video-projector
washing-machine
watch-101
waterfall
watermelon
welding-mask
wheelbarrow
windmill
wine-bottle
xylophone
yarmulke
yo-yo
zebra
airplanes-101
car-side-101
faces-easy-101
greyhound
tennis-shoes
toad
clutter
...@@ -8,5 +8,5 @@ from torchvision.transforms.functional import pil_to_tensor ...@@ -8,5 +8,5 @@ from torchvision.transforms.functional import pil_to_tensor
__all__ = ["pil"] __all__ = ["pil"]
def pil(file: io.IOBase, mode: str = "RGB") -> torch.Tensor: def pil(buffer: io.IOBase, mode: str = "RGB") -> torch.Tensor:
return pil_to_tensor(PIL.Image.open(file).convert(mode.upper())) return pil_to_tensor(PIL.Image.open(buffer).convert(mode.upper()))
...@@ -112,7 +112,7 @@ class DatasetInfo: ...@@ -112,7 +112,7 @@ class DatasetInfo:
categories = [str(label) for label in range(categories)] categories = [str(label) for label in range(categories)]
elif isinstance(categories, (str, pathlib.Path)): elif isinstance(categories, (str, pathlib.Path)):
with open(pathlib.Path(categories).expanduser().resolve(), "r") as fh: with open(pathlib.Path(categories).expanduser().resolve(), "r") as fh:
categories = fh.readlines() categories = [line.strip() for line in fh]
self.categories = categories self.categories = categories
self.citation = citation self.citation = citation
......
import collections.abc import collections.abc
import difflib import difflib
from typing import Collection, Sequence, Callable import io
import pathlib
from typing import Collection, Sequence, Callable, Union, Any
__all__ = [ __all__ = [
"INFINITE_BUFFER_SIZE", "INFINITE_BUFFER_SIZE",
"sequence_to_str", "sequence_to_str",
"add_suggestion", "add_suggestion",
"create_categories_file",
"read_mat"
] ]
# pseudo-infinite until a true infinite buffer is supported by all datapipes # pseudo-infinite until a true infinite buffer is supported by all datapipes
...@@ -44,3 +48,21 @@ def add_suggestion( ...@@ -44,3 +48,21 @@ def add_suggestion(
else alternative_hint(possibilities) else alternative_hint(possibilities)
) )
return f"{msg.strip()} {hint}" return f"{msg.strip()} {hint}"
def create_categories_file(
    root: Union[str, pathlib.Path], name: str, categories: Sequence[str]
) -> None:
    """Write *categories* to ``<root>/<name>.categories``, one entry per line.

    The file always ends in a trailing newline, even for an empty sequence.
    """
    with open(pathlib.Path(root) / f"{name}.categories", "w") as fh:
        print(*categories, sep="\n", file=fh)
def read_mat(buffer: io.IOBase, **kwargs: Any) -> Any:
    """Parse a MATLAB ``.mat`` file from *buffer* via ``scipy.io.loadmat``.

    Extra keyword arguments are forwarded to ``loadmat``. Raises
    ``ModuleNotFoundError`` if ``scipy`` is not installed.
    """
    try:
        import scipy.io
    except ImportError as error:
        raise ModuleNotFoundError(
            "Package `scipy` is required to be installed to read .mat files."
        ) from error

    return scipy.io.loadmat(buffer, **kwargs)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment