"git@developer.sourcefind.cn:wuxk1/megatron-lm.git" did not exist on "5fc301aaee5edbccb02156f8081bb81240a34026"
Unverified Commit 673838f5 authored by YosuaMichael's avatar YosuaMichael Committed by GitHub
Browse files

Removing prototype related things from release/0.14 branch (#6687)

* Remove test related to prototype

* Remove torchvision/prototype dir

* Remove references/depth/stereo because it depend on prototype

* Remove prototype related entries on mypy.ini

* Remove things related to prototype in pytest.ini

* clean setup.py from prototype

* Clean CI from prototype

* Remove unused expect file
parent 07ae61bf
from . import datasets, features, models, transforms, utils
try:
import torchdata
except ModuleNotFoundError:
raise ModuleNotFoundError(
"`torchvision.prototype.datasets` depends on PyTorch's `torchdata` (https://github.com/pytorch/data). "
"You can install it with `pip install --pre torchdata --extra-index-url https://download.pytorch.org/whl/nightly/cpu"
) from None
from . import utils
from ._home import home
# Load this last, since some parts depend on the above being loaded first
from ._api import list_datasets, info, load, register_info, register_dataset # usort: skip
from ._folder import from_data_folder, from_image_folder
from ._builtin import *
import pathlib
from typing import Any, Callable, Dict, List, Optional, Type, TypeVar, Union
from torchvision.prototype.datasets import home
from torchvision.prototype.datasets.utils import Dataset
from torchvision.prototype.utils._internal import add_suggestion
T = TypeVar("T")
D = TypeVar("D", bound=Type[Dataset])
BUILTIN_INFOS: Dict[str, Dict[str, Any]] = {}
def register_info(name: str) -> Callable[[Callable[[], Dict[str, Any]]], Callable[[], Dict[str, Any]]]:
def wrapper(fn: Callable[[], Dict[str, Any]]) -> Callable[[], Dict[str, Any]]:
BUILTIN_INFOS[name] = fn()
return fn
return wrapper
BUILTIN_DATASETS = {}
def register_dataset(name: str) -> Callable[[D], D]:
def wrapper(dataset_cls: D) -> D:
BUILTIN_DATASETS[name] = dataset_cls
return dataset_cls
return wrapper
def list_datasets() -> List[str]:
return sorted(BUILTIN_DATASETS.keys())
def find(dct: Dict[str, T], name: str) -> T:
name = name.lower()
try:
return dct[name]
except KeyError as error:
raise ValueError(
add_suggestion(
f"Unknown dataset '{name}'.",
word=name,
possibilities=dct.keys(),
alternative_hint=lambda _: (
"You can use torchvision.datasets.list_datasets() to get a list of all available datasets."
),
)
) from error
def info(name: str) -> Dict[str, Any]:
return find(BUILTIN_INFOS, name)
def load(name: str, *, root: Optional[Union[str, pathlib.Path]] = None, **config: Any) -> Dataset:
dataset_cls = find(BUILTIN_DATASETS, name)
if root is None:
root = pathlib.Path(home()) / name
return dataset_cls(root, **config)
# How to add new built-in prototype datasets
As the name implies, the datasets are still in a prototype state and thus subject to rapid change. This in turn means
that this document will also change a lot.
If you hit a blocker while adding a dataset, please have a look at another similar dataset to see how it is implemented
there. If you can't resolve it yourself, feel free to send a draft PR in order for us to help you out.
Finally, `from torchvision.prototype import datasets` is implied below.
## Implementation
Before we start with the actual implementation, you should create a module in `torchvision/prototype/datasets/_builtin`
that hints at the dataset you are going to add. For example `caltech.py` for `caltech101` and `caltech256`. In that
module create a class that inherits from `datasets.utils.Dataset` and overwrites four methods that will be discussed in
detail below:
```python
import pathlib
from typing import Any, BinaryIO, Dict, List, Tuple, Union
from torchdata.datapipes.iter import IterDataPipe
from torchvision.prototype.datasets.utils import Dataset, OnlineResource
from .._api import register_dataset, register_info
NAME = "my-dataset"
@register_info(NAME)
def _info() -> Dict[str, Any]:
return dict(
...
)
@register_dataset(NAME)
class MyDataset(Dataset):
def __init__(self, root: Union[str, pathlib.Path], *, ..., skip_integrity_check: bool = False) -> None:
...
super().__init__(root, skip_integrity_check=skip_integrity_check)
def _resources(self) -> List[OnlineResource]:
...
def _datapipe(self, resource_dps: List[IterDataPipe[Tuple[str, BinaryIO]]]) -> IterDataPipe[Dict[str, Any]]:
...
def __len__(self) -> int:
...
```
In addition to the dataset, you also need to implement an `_info()` function that takes no arguments and returns a
dictionary of static information. The most common use case is to provide human-readable categories.
[See below](#how-do-i-handle-a-dataset-that-defines-many-categories) how to handle cases with many categories.
Finally, both the dataset class and the info function need to be registered on the API with the respective decorators.
With that they are loadable through `datasets.load("my-dataset")` and `datasets.info("my-dataset")`, respectively.
### `__init__(self, root, *, ..., skip_integrity_check = False)`
Constructor of the dataset that will be called when the dataset is instantiated. In addition to the parameters of the
base class, it can take arbitrary keyword-only parameters with defaults. The checking of these parameters as well as
setting them as instance attributes has to happen before the call of `super().__init__(...)`, because that will invoke
the other methods, which possibly depend on the parameters. All instance attributes must be private, i.e. prefixed with
an underscore.
If the implementation of the dataset depends on third-party packages, pass them as a collection of strings to the base
class constructor, e.g. `super().__init__(..., dependencies=("scipy",))`. Their availability will be automatically
checked if a user tries to load the dataset. Within the implementation of the dataset, import these packages lazily to
avoid missing dependencies at import time.
### `_resources(self)`
Returns `List[datasets.utils.OnlineResource]` of all the files that need to be present locally before the dataset can be
build. The download will happen automatically.
Currently, the following `OnlineResource`'s are supported:
- `HttpResource`: Used for files that are directly exposed through HTTP(s) and only requires the URL.
- `GDriveResource`: Used for files that are hosted on GDrive and requires the GDrive ID as well as the `file_name`.
- `ManualDownloadResource`: Used files are not publicly accessible and requires instructions how to download them
manually. If the file does not exist, an error will be raised with the supplied instructions.
- `KaggleDownloadResource`: Used for files that are available on Kaggle. This inherits from `ManualDownloadResource`.
Although optional in general, all resources used in the built-in datasets should comprise
[SHA256](https://en.wikipedia.org/wiki/SHA-2) checksum for security. It will be automatically checked after the
download. You can compute the checksum with system utilities e.g `sha256-sum`, or this snippet:
```python
import hashlib
def sha256sum(path, chunk_size=1024 * 1024):
checksum = hashlib.sha256()
with open(path, "rb") as f:
for chunk in iter(lambda: f.read(chunk_size), b""):
checksum.update(chunk)
print(checksum.hexdigest())
```
### `_datapipe(self, resource_dps)`
This method is the heart of the dataset, where we transform the raw data into a usable form. A major difference compared
to the current stable datasets is that everything is performed through `IterDataPipe`'s. From the perspective of someone
that is working with them rather than on them, `IterDataPipe`'s behave just as generators, i.e. you can't do anything
with them besides iterating.
Of course, there are some common building blocks that should suffice in 95% of the cases. The most used are:
- `Mapper`: Apply a callable to every item in the datapipe.
- `Filter`: Keep only items that satisfy a condition.
- `Demultiplexer`: Split a datapipe into multiple ones.
- `IterKeyZipper`: Merge two datapipes into one.
All of them can be imported `from torchdata.datapipes.iter`. In addition, use `functools.partial` in case a callable
needs extra arguments. If the provided `IterDataPipe`'s are not sufficient for the use case, it is also not complicated
to add one. See the MNIST or CelebA datasets for example.
`_datapipe()` receives `resource_dps`, which is a list of datapipes that has a 1-to-1 correspondence with the return
value of `_resources()`. In case of archives with regular suffixes (`.tar`, `.zip`, ...), the datapipe will contain
tuples comprised of the path and the handle for every file in the archive. Otherwise, the datapipe will only contain one
of such tuples for the file specified by the resource.
Since the datapipes are iterable in nature, some datapipes feature an in-memory buffer, e.g. `IterKeyZipper` and
`Grouper`. There are two issues with that:
1. If not used carefully, this can easily overflow the host memory, since most datasets will not fit in completely.
2. This can lead to unnecessarily long warm-up times when data is buffered that is only needed at runtime.
Thus, all buffered datapipes should be used as early as possible, e.g. zipping two datapipes of file handles rather than
trying to zip already loaded images.
There are two special datapipes that are not used through their class, but through the functions `hint_shuffling` and
`hint_sharding`. As the name implies they only hint at a location in the datapipe graph where shuffling and sharding
should take place, but are no-ops by default. They can be imported from `torchvision.prototype.datasets.utils._internal`
and are required in each dataset. `hint_shuffling` has to be placed before `hint_sharding`.
Finally, each item in the final datapipe should be a dictionary with `str` keys. There is no standardization of the
names (yet!).
### `__len__`
This returns an integer denoting the number of samples that can be drawn from the dataset. Please use
[underscores](https://peps.python.org/pep-0515/) after every three digits starting from the right to enhance the
readability. For example, `1_281_167` vs. `1281167`.
If there are only two different numbers, a simple `if` / `else` is fine:
```py
def __len__(self):
return 12_345 if self._split == "train" else 6_789
```
If there are more options, using a dictionary usually is the most readable option:
```py
def __len__(self):
return {
"train": 3,
"val": 2,
"test": 1,
}[self._split]
```
If the number of samples depends on more than one parameter, you can use tuples as dictionary keys:
```py
def __len__(self):
return {
("train", "bar"): 4,
("train", "baz"): 3,
("test", "bar"): 2,
("test", "baz"): 1,
}[(self._split, self._foo)]
```
The length of the datapipe is only an annotation for subsequent processing of the datapipe and not needed during the
development process. Since it is an `@abstractmethod` you still have to implement it from the start. The canonical way
is to define a dummy method like
```py
def __len__(self):
return 1
```
and only fill it with the correct data if the implementation is otherwise finished.
[See below](#how-do-i-compute-the-number-of-samples) for a possible way to compute the number of samples.
## Tests
To test the dataset implementation, you usually don't need to add any tests, but need to provide a mock-up of the data.
This mock-up should resemble the original data as close as necessary, while containing only few examples.
To do this, add a new function in [`test/builtin_dataset_mocks.py`](../../../../test/builtin_dataset_mocks.py) with the
same name as you have used in `@register_info` and `@register_dataset`. This function is called "mock data function".
Decorate it with `@register_mock(configs=[dict(...), ...])`. Each dictionary denotes one configuration that the dataset
will be loaded with, e.g. `datasets.load("my-dataset", **config)`. For the most common case of a product of all options,
you can use the `combinations_grid()` helper function, e.g.
`configs=combinations_grid(split=("train", "test"), foo=("bar", "baz"))`.
In case the name of the dataset includes hyphens `-`, replace them with underscores `_` in the function name and pass
the `name` parameter to `@register_mock`
```py
# this is defined in torchvision/prototype/datasets/_builtin
@register_dataset("my-dataset")
class MyDataset(Dataset):
...
@register_mock(name="my-dataset", configs=...)
def my_dataset(root, config):
...
```
The mock data function receives two arguments:
- `root`: A [`pathlib.Path`](https://docs.python.org/3/library/pathlib.html#pathlib.Path) of a folder, in which the data
needs to be placed.
- `config`: The configuration to generate the data for. This is one of the dictionaries defined in
`@register_mock(configs=...)`
The function should generate all files that are needed for the current `config`. Each file should be complete, e.g. if
the dataset only has a single archive that contains multiple splits, you need to generate the full archive regardless of
the current `config`. Although this seems odd at first, this is important. Consider the following original data setup:
```
root
├── test
│ ├── test_image0.jpg
│ ...
└── train
├── train_image0.jpg
...
```
For map-style datasets (like the one currently in `torchvision.datasets`), one explicitly selects the files they want to
load. For example, something like `(root / split).iterdir()` works fine even if only the specific split folder is
present. With iterable-style datasets though, we get something like `root.iterdir()` from `resource_dps` in
`_datapipe()` and need to manually `Filter` it to only keep the files we want. If we would only generate the data for
the current `config`, the test would also pass if the dataset is missing the filtering, but would fail on the real data.
For datasets that are ported from the old API, we already have some mock data in
[`test/test_datasets.py`](../../../../test/test_datasets.py). You can find the test case corresponding test case there
and have a look at the `inject_fake_data` function. There are a few differences though:
- `tmp_dir` corresponds to `root`, but is a `str` rather than a
[`pathlib.Path`](https://docs.python.org/3/library/pathlib.html#pathlib.Path). Thus, you often see something like
`folder = pathlib.Path(tmp_dir)`. This is not needed.
- The data generated by `inject_fake_data` was supposed to be in an extracted state. This is no longer the case for the
new mock-ups. Thus, you need to use helper functions like `make_zip` or `make_tar` to actually generate the files
specified in the dataset.
- As explained in the paragraph above, the generated data is often "incomplete" and only valid for given the config.
Make sure you follow the instructions above.
The function should return an integer indicating the number of samples in the dataset for the current `config`.
Preferably, this number should be different for different `config`'s to have more confidence in the dataset
implementation.
Finally, you can run the tests with `pytest test/test_prototype_builtin_datasets.py -k {name}`.
## FAQ
### How do I start?
Get the skeleton of your dataset class ready with all 4 methods. For `_datapipe()`, you can just do
`return resources_dp[0]` to get started. Then import the dataset class in
`torchvision/prototype/datasets/_builtin/__init__.py`: this will automatically register the dataset, and it will be
instantiable via `datasets.load("mydataset")`. On a separate script, try something like
```py
from torchvision.prototype import datasets
dataset = datasets.load("mydataset")
for sample in dataset:
print(sample) # this is the content of an item in datapipe returned by _datapipe()
break
# Or you can also inspect the sample in a debugger
```
This will give you an idea of what the first datapipe in `resources_dp` contains. You can also do that with
`resources_dp[1]` or `resources_dp[2]` (etc.) if they exist. Then follow the instructions above to manipulate these
datapipes and return the appropriate dictionary format.
### How do I handle a dataset that defines many categories?
As a rule of thumb, `categories` in the info dictionary should only be set manually for ten categories or fewer. If more
categories are needed, you can add a `$NAME.categories` file to the `_builtin` folder in which each line specifies a
category. To load such a file, use the `from torchvision.prototype.datasets.utils._internal import read_categories_file`
function and pass it `$NAME`.
In case the categories can be generated from the dataset files, e.g. the dataset follows an image folder approach where
each folder denotes the name of the category, the dataset can overwrite the `_generate_categories` method. The method
should return a sequence of strings representing the category names. In the method body, you'll have to manually load
the resources, e.g.
```py
resources = self._resources()
dp = resources[0].load(self._root)
```
Note that it is not necessary here to keep a datapipe until the final step. Stick with datapipes as long as it makes
sense and afterwards materialize the data with `next(iter(dp))` or `list(dp)` and proceed with that.
To generate the `$NAME.categories` file, run `python -m torchvision.prototype.datasets.generate_category_files $NAME`.
### What if a resource file forms an I/O bottleneck?
In general, we are ok with small performance hits of iterating archives rather than their extracted content. However, if
the performance hit becomes significant, the archives can still be preprocessed. `OnlineResource` accepts the
`preprocess` parameter that can be a `Callable[[pathlib.Path], pathlib.Path]` where the input points to the file to be
preprocessed and the return value should be the result of the preprocessing to load. For convenience, `preprocess` also
accepts `"decompress"` and `"extract"` to handle these common scenarios.
### How do I compute the number of samples?
Unless the authors of the dataset published the exact numbers (even in this case we should check), there is no other way
than to iterate over the dataset and count the number of samples:
```py
import itertools
from torchvision.prototype import datasets
def combinations_grid(**kwargs):
return [dict(zip(kwargs.keys(), values)) for values in itertools.product(*kwargs.values())]
# If you have implemented the mock data function for the dataset tests, you can simply copy-paste from there
configs = combinations_grid(split=("train", "test"), foo=("bar", "baz"))
for config in configs:
dataset = datasets.load("my-dataset", **config)
num_samples = 0
for _ in dataset:
num_samples += 1
print(", ".join(f"{key}={value}" for key, value in config.items()), num_samples)
```
To speed this up, it is useful to temporarily comment out all unnecessary I/O, such as loading of images or annotation
files.
from .caltech import Caltech101, Caltech256
from .celeba import CelebA
from .cifar import Cifar10, Cifar100
from .clevr import CLEVR
from .coco import Coco
from .country211 import Country211
from .cub200 import CUB200
from .dtd import DTD
from .eurosat import EuroSAT
from .fer2013 import FER2013
from .food101 import Food101
from .gtsrb import GTSRB
from .imagenet import ImageNet
from .mnist import EMNIST, FashionMNIST, KMNIST, MNIST, QMNIST
from .oxford_iiit_pet import OxfordIIITPet
from .pcam import PCAM
from .sbd import SBD
from .semeion import SEMEION
from .stanford_cars import StanfordCars
from .svhn import SVHN
from .usps import USPS
from .voc import VOC
import pathlib
import re
from typing import Any, BinaryIO, Dict, List, Tuple, Union
import numpy as np
from torchdata.datapipes.iter import Filter, IterDataPipe, IterKeyZipper, Mapper
from torchvision.prototype.datasets.utils import Dataset, GDriveResource, OnlineResource
from torchvision.prototype.datasets.utils._internal import (
hint_sharding,
hint_shuffling,
INFINITE_BUFFER_SIZE,
read_categories_file,
read_mat,
)
from torchvision.prototype.features import _Feature, BoundingBox, EncodedImage, Label
from .._api import register_dataset, register_info
@register_info("caltech101")
def _caltech101_info() -> Dict[str, Any]:
return dict(categories=read_categories_file("caltech101"))
@register_dataset("caltech101")
class Caltech101(Dataset):
"""
- **homepage**: https://data.caltech.edu/records/20086
- **dependencies**:
- <scipy `https://scipy.org/`>_
"""
def __init__(
self,
root: Union[str, pathlib.Path],
skip_integrity_check: bool = False,
) -> None:
self._categories = _caltech101_info()["categories"]
super().__init__(
root,
dependencies=("scipy",),
skip_integrity_check=skip_integrity_check,
)
def _resources(self) -> List[OnlineResource]:
images = GDriveResource(
"137RyRjvTBkBiIfeYBNZBtViDHQ6_Ewsp",
file_name="101_ObjectCategories.tar.gz",
sha256="af6ece2f339791ca20f855943d8b55dd60892c0a25105fcd631ee3d6430f9926",
preprocess="decompress",
)
anns = GDriveResource(
"175kQy3UsZ0wUEHZjqkUDdNVssr7bgh_m",
file_name="Annotations.tar",
sha256="1717f4e10aa837b05956e3f4c94456527b143eec0d95e935028b30aff40663d8",
)
return [images, anns]
_IMAGES_NAME_PATTERN = re.compile(r"image_(?P<id>\d+)[.]jpg")
_ANNS_NAME_PATTERN = re.compile(r"annotation_(?P<id>\d+)[.]mat")
_ANNS_CATEGORY_MAP = {
"Faces_2": "Faces",
"Faces_3": "Faces_easy",
"Motorbikes_16": "Motorbikes",
"Airplanes_Side_2": "airplanes",
}
def _is_not_background_image(self, data: Tuple[str, Any]) -> bool:
path = pathlib.Path(data[0])
return path.parent.name != "BACKGROUND_Google"
def _is_ann(self, data: Tuple[str, Any]) -> bool:
path = pathlib.Path(data[0])
return bool(self._ANNS_NAME_PATTERN.match(path.name))
def _images_key_fn(self, data: Tuple[str, Any]) -> Tuple[str, str]:
path = pathlib.Path(data[0])
category = path.parent.name
id = self._IMAGES_NAME_PATTERN.match(path.name).group("id") # type: ignore[union-attr]
return category, id
def _anns_key_fn(self, data: Tuple[str, Any]) -> Tuple[str, str]:
path = pathlib.Path(data[0])
category = path.parent.name
if category in self._ANNS_CATEGORY_MAP:
category = self._ANNS_CATEGORY_MAP[category]
id = self._ANNS_NAME_PATTERN.match(path.name).group("id") # type: ignore[union-attr]
return category, id
def _prepare_sample(
self, data: Tuple[Tuple[str, str], Tuple[Tuple[str, BinaryIO], Tuple[str, BinaryIO]]]
) -> Dict[str, Any]:
key, (image_data, ann_data) = data
category, _ = key
image_path, image_buffer = image_data
ann_path, ann_buffer = ann_data
image = EncodedImage.from_file(image_buffer)
ann = read_mat(ann_buffer)
return dict(
label=Label.from_category(category, categories=self._categories),
image_path=image_path,
image=image,
ann_path=ann_path,
bounding_box=BoundingBox(
ann["box_coord"].astype(np.int64).squeeze()[[2, 0, 3, 1]], format="xyxy", image_size=image.image_size
),
contour=_Feature(ann["obj_contour"].T),
)
def _datapipe(self, resource_dps: List[IterDataPipe]) -> IterDataPipe[Dict[str, Any]]:
images_dp, anns_dp = resource_dps
images_dp = Filter(images_dp, self._is_not_background_image)
images_dp = hint_shuffling(images_dp)
images_dp = hint_sharding(images_dp)
anns_dp = Filter(anns_dp, self._is_ann)
dp = IterKeyZipper(
images_dp,
anns_dp,
key_fn=self._images_key_fn,
ref_key_fn=self._anns_key_fn,
buffer_size=INFINITE_BUFFER_SIZE,
keep_key=True,
)
return Mapper(dp, self._prepare_sample)
def __len__(self) -> int:
return 8677
def _generate_categories(self) -> List[str]:
resources = self._resources()
dp = resources[0].load(self._root)
dp = Filter(dp, self._is_not_background_image)
return sorted({pathlib.Path(path).parent.name for path, _ in dp})
@register_info("caltech256")
def _caltech256_info() -> Dict[str, Any]:
return dict(categories=read_categories_file("caltech256"))
@register_dataset("caltech256")
class Caltech256(Dataset):
"""
- **homepage**: https://data.caltech.edu/records/20087
"""
def __init__(
self,
root: Union[str, pathlib.Path],
skip_integrity_check: bool = False,
) -> None:
self._categories = _caltech256_info()["categories"]
super().__init__(root, skip_integrity_check=skip_integrity_check)
def _resources(self) -> List[OnlineResource]:
return [
GDriveResource(
"1r6o0pSROcV1_VwT4oSjA2FBUSCWGuxLK",
file_name="256_ObjectCategories.tar",
sha256="08ff01b03c65566014ae88eb0490dbe4419fc7ac4de726ee1163e39fd809543e",
)
]
def _is_not_rogue_file(self, data: Tuple[str, Any]) -> bool:
path = pathlib.Path(data[0])
return path.name != "RENAME2"
def _prepare_sample(self, data: Tuple[str, BinaryIO]) -> Dict[str, Any]:
path, buffer = data
return dict(
path=path,
image=EncodedImage.from_file(buffer),
label=Label(int(pathlib.Path(path).parent.name.split(".", 1)[0]) - 1, categories=self._categories),
)
def _datapipe(self, resource_dps: List[IterDataPipe]) -> IterDataPipe[Dict[str, Any]]:
dp = resource_dps[0]
dp = Filter(dp, self._is_not_rogue_file)
dp = hint_shuffling(dp)
dp = hint_sharding(dp)
return Mapper(dp, self._prepare_sample)
def __len__(self) -> int:
return 30607
def _generate_categories(self) -> List[str]:
resources = self._resources()
dp = resources[0].load(self._root)
dir_names = {pathlib.Path(path).parent.name for path, _ in dp}
return [name.split(".")[1] for name in sorted(dir_names)]
Faces
Faces_easy
Leopards
Motorbikes
accordion
airplanes
anchor
ant
barrel
bass
beaver
binocular
bonsai
brain
brontosaurus
buddha
butterfly
camera
cannon
car_side
ceiling_fan
cellphone
chair
chandelier
cougar_body
cougar_face
crab
crayfish
crocodile
crocodile_head
cup
dalmatian
dollar_bill
dolphin
dragonfly
electric_guitar
elephant
emu
euphonium
ewer
ferry
flamingo
flamingo_head
garfield
gerenuk
gramophone
grand_piano
hawksbill
headphone
hedgehog
helicopter
ibis
inline_skate
joshua_tree
kangaroo
ketch
lamp
laptop
llama
lobster
lotus
mandolin
mayfly
menorah
metronome
minaret
nautilus
octopus
okapi
pagoda
panda
pigeon
pizza
platypus
pyramid
revolver
rhino
rooster
saxophone
schooner
scissors
scorpion
sea_horse
snoopy
soccer_ball
stapler
starfish
stegosaurus
stop_sign
strawberry
sunflower
tick
trilobite
umbrella
watch
water_lilly
wheelchair
wild_cat
windsor_chair
wrench
yin_yang
ak47
american-flag
backpack
baseball-bat
baseball-glove
basketball-hoop
bat
bathtub
bear
beer-mug
billiards
binoculars
birdbath
blimp
bonsai-101
boom-box
bowling-ball
bowling-pin
boxing-glove
brain-101
breadmaker
buddha-101
bulldozer
butterfly
cactus
cake
calculator
camel
cannon
canoe
car-tire
cartman
cd
centipede
cereal-box
chandelier-101
chess-board
chimp
chopsticks
cockroach
coffee-mug
coffin
coin
comet
computer-keyboard
computer-monitor
computer-mouse
conch
cormorant
covered-wagon
cowboy-hat
crab-101
desk-globe
diamond-ring
dice
dog
dolphin-101
doorknob
drinking-straw
duck
dumb-bell
eiffel-tower
electric-guitar-101
elephant-101
elk
ewer-101
eyeglasses
fern
fighter-jet
fire-extinguisher
fire-hydrant
fire-truck
fireworks
flashlight
floppy-disk
football-helmet
french-horn
fried-egg
frisbee
frog
frying-pan
galaxy
gas-pump
giraffe
goat
golden-gate-bridge
goldfish
golf-ball
goose
gorilla
grand-piano-101
grapes
grasshopper
guitar-pick
hamburger
hammock
harmonica
harp
harpsichord
hawksbill-101
head-phones
helicopter-101
hibiscus
homer-simpson
horse
horseshoe-crab
hot-air-balloon
hot-dog
hot-tub
hourglass
house-fly
human-skeleton
hummingbird
ibis-101
ice-cream-cone
iguana
ipod
iris
jesus-christ
joy-stick
kangaroo-101
kayak
ketch-101
killer-whale
knife
ladder
laptop-101
lathe
leopards-101
license-plate
lightbulb
light-house
lightning
llama-101
mailbox
mandolin
mars
mattress
megaphone
menorah-101
microscope
microwave
minaret
minotaur
motorbikes-101
mountain-bike
mushroom
mussels
necktie
octopus
ostrich
owl
palm-pilot
palm-tree
paperclip
paper-shredder
pci-card
penguin
people
pez-dispenser
photocopier
picnic-table
playing-card
porcupine
pram
praying-mantis
pyramid
raccoon
radio-telescope
rainbow
refrigerator
revolver-101
rifle
rotary-phone
roulette-wheel
saddle
saturn
school-bus
scorpion-101
screwdriver
segway
self-propelled-lawn-mower
sextant
sheet-music
skateboard
skunk
skyscraper
smokestack
snail
snake
sneaker
snowmobile
soccer-ball
socks
soda-can
spaghetti
speed-boat
spider
spoon
stained-glass
starfish-101
steering-wheel
stirrups
sunflower-101
superman
sushi
swan
swiss-army-knife
sword
syringe
tambourine
teapot
teddy-bear
teepee
telephone-box
tennis-ball
tennis-court
tennis-racket
theodolite
toaster
tomato
tombstone
top-hat
touring-bike
tower-pisa
traffic-light
treadmill
triceratops
tricycle
trilobite-101
tripod
t-shirt
tuning-fork
tweezer
umbrella-101
unicorn
vcr
video-projector
washing-machine
watch-101
waterfall
watermelon
welding-mask
wheelbarrow
windmill
wine-bottle
xylophone
yarmulke
yo-yo
zebra
airplanes-101
car-side-101
faces-easy-101
greyhound
tennis-shoes
toad
clutter
import csv
import pathlib
from typing import Any, BinaryIO, Dict, Iterator, List, Optional, Sequence, Tuple, Union
from torchdata.datapipes.iter import Filter, IterDataPipe, IterKeyZipper, Mapper, Zipper
from torchvision.prototype.datasets.utils import Dataset, GDriveResource, OnlineResource
from torchvision.prototype.datasets.utils._internal import (
getitem,
hint_sharding,
hint_shuffling,
INFINITE_BUFFER_SIZE,
path_accessor,
)
from torchvision.prototype.features import _Feature, BoundingBox, EncodedImage, Label
from .._api import register_dataset, register_info
csv.register_dialect("celeba", delimiter=" ", skipinitialspace=True)
class CelebACSVParser(IterDataPipe[Tuple[str, Dict[str, str]]]):
def __init__(
self,
datapipe: IterDataPipe[Tuple[Any, BinaryIO]],
*,
fieldnames: Optional[Sequence[str]] = None,
) -> None:
self.datapipe = datapipe
self.fieldnames = fieldnames
def __iter__(self) -> Iterator[Tuple[str, Dict[str, str]]]:
for _, file in self.datapipe:
file = (line.decode() for line in file)
if self.fieldnames:
fieldnames = self.fieldnames
else:
# The first row is skipped, because it only contains the number of samples
next(file)
# Empty field names are filtered out, because some files have an extra white space after the header
# line, which is recognized as extra column
fieldnames = [name for name in next(csv.reader([next(file)], dialect="celeba")) if name]
# Some files do not include a label for the image ID column
if fieldnames[0] != "image_id":
fieldnames.insert(0, "image_id")
for line in csv.DictReader(file, fieldnames=fieldnames, dialect="celeba"):
yield line.pop("image_id"), line
NAME = "celeba"
@register_info(NAME)
def _info() -> Dict[str, Any]:
return dict()
@register_dataset(NAME)
class CelebA(Dataset):
"""
- **homepage**: https://mmlab.ie.cuhk.edu.hk/projects/CelebA.html
"""
def __init__(
self,
root: Union[str, pathlib.Path],
*,
split: str = "train",
skip_integrity_check: bool = False,
) -> None:
self._split = self._verify_str_arg(split, "split", ("train", "val", "test"))
super().__init__(root, skip_integrity_check=skip_integrity_check)
def _resources(self) -> List[OnlineResource]:
splits = GDriveResource(
"0B7EVK8r0v71pY0NSMzRuSXJEVkk",
sha256="fc955bcb3ef8fbdf7d5640d9a8693a8431b5f2ee291a5c1449a1549e7e073fe7",
file_name="list_eval_partition.txt",
)
images = GDriveResource(
"0B7EVK8r0v71pZjFTYXZWM3FlRnM",
sha256="46fb89443c578308acf364d7d379fe1b9efb793042c0af734b6112e4fd3a8c74",
file_name="img_align_celeba.zip",
)
identities = GDriveResource(
"1_ee_0u7vcNLOfNLegJRHmolfH5ICW-XS",
sha256="c6143857c3e2630ac2da9f782e9c1232e5e59be993a9d44e8a7916c78a6158c0",
file_name="identity_CelebA.txt",
)
attributes = GDriveResource(
"0B7EVK8r0v71pblRyaVFSWGxPY0U",
sha256="f0e5da289d5ccf75ffe8811132694922b60f2af59256ed362afa03fefba324d0",
file_name="list_attr_celeba.txt",
)
bounding_boxes = GDriveResource(
"0B7EVK8r0v71pbThiMVRxWXZ4dU0",
sha256="7487a82e57c4bb956c5445ae2df4a91ffa717e903c5fa22874ede0820c8ec41b",
file_name="list_bbox_celeba.txt",
)
landmarks = GDriveResource(
"0B7EVK8r0v71pd0FJY3Blby1HUTQ",
sha256="6c02a87569907f6db2ba99019085697596730e8129f67a3d61659f198c48d43b",
file_name="list_landmarks_align_celeba.txt",
)
return [splits, images, identities, attributes, bounding_boxes, landmarks]
def _filter_split(self, data: Tuple[str, Dict[str, str]]) -> bool:
split_id = {
"train": "0",
"val": "1",
"test": "2",
}[self._split]
return data[1]["split_id"] == split_id
def _prepare_sample(
self,
data: Tuple[
Tuple[str, Tuple[Tuple[str, List[str]], Tuple[str, BinaryIO]]],
Tuple[
Tuple[str, Dict[str, str]],
Tuple[str, Dict[str, str]],
Tuple[str, Dict[str, str]],
Tuple[str, Dict[str, str]],
],
],
) -> Dict[str, Any]:
split_and_image_data, ann_data = data
_, (_, image_data) = split_and_image_data
path, buffer = image_data
image = EncodedImage.from_file(buffer)
(_, identity), (_, attributes), (_, bounding_box), (_, landmarks) = ann_data
return dict(
path=path,
image=image,
identity=Label(int(identity["identity"])),
attributes={attr: value == "1" for attr, value in attributes.items()},
bounding_box=BoundingBox(
[int(bounding_box[key]) for key in ("x_1", "y_1", "width", "height")],
format="xywh",
image_size=image.image_size,
),
landmarks={
landmark: _Feature((int(landmarks[f"{landmark}_x"]), int(landmarks[f"{landmark}_y"])))
for landmark in {key[:-2] for key in landmarks.keys()}
},
)
def _datapipe(self, resource_dps: List[IterDataPipe]) -> IterDataPipe[Dict[str, Any]]:
splits_dp, images_dp, identities_dp, attributes_dp, bounding_boxes_dp, landmarks_dp = resource_dps
splits_dp = CelebACSVParser(splits_dp, fieldnames=("image_id", "split_id"))
splits_dp = Filter(splits_dp, self._filter_split)
splits_dp = hint_shuffling(splits_dp)
splits_dp = hint_sharding(splits_dp)
anns_dp = Zipper(
*[
CelebACSVParser(dp, fieldnames=fieldnames)
for dp, fieldnames in (
(identities_dp, ("image_id", "identity")),
(attributes_dp, None),
(bounding_boxes_dp, None),
(landmarks_dp, None),
)
]
)
dp = IterKeyZipper(
splits_dp,
images_dp,
key_fn=getitem(0),
ref_key_fn=path_accessor("name"),
buffer_size=INFINITE_BUFFER_SIZE,
keep_key=True,
)
dp = IterKeyZipper(
dp,
anns_dp,
key_fn=getitem(0),
ref_key_fn=getitem(0, 0),
buffer_size=INFINITE_BUFFER_SIZE,
)
return Mapper(dp, self._prepare_sample)
def __len__(self) -> int:
return {
"train": 162_770,
"val": 19_867,
"test": 19_962,
}[self._split]
import abc
import io
import pathlib
import pickle
from typing import Any, BinaryIO, cast, Dict, Iterator, List, Optional, Tuple, Union
import numpy as np
from torchdata.datapipes.iter import Filter, IterDataPipe, Mapper
from torchvision.prototype.datasets.utils import Dataset, HttpResource, OnlineResource
from torchvision.prototype.datasets.utils._internal import (
hint_sharding,
hint_shuffling,
path_comparator,
read_categories_file,
)
from torchvision.prototype.features import Image, Label
from .._api import register_dataset, register_info
class CifarFileReader(IterDataPipe[Tuple[np.ndarray, int]]):
def __init__(self, datapipe: IterDataPipe[Dict[str, Any]], *, labels_key: str) -> None:
self.datapipe = datapipe
self.labels_key = labels_key
def __iter__(self) -> Iterator[Tuple[np.ndarray, int]]:
for mapping in self.datapipe:
image_arrays = mapping["data"].reshape((-1, 3, 32, 32))
category_idcs = mapping[self.labels_key]
yield from iter(zip(image_arrays, category_idcs))
class _CifarBase(Dataset):
_FILE_NAME: str
_SHA256: str
_LABELS_KEY: str
_META_FILE_NAME: str
_CATEGORIES_KEY: str
_categories: List[str]
def __init__(
self,
root: Union[str, pathlib.Path],
*,
split: str = "train",
skip_integrity_check: bool = False,
) -> None:
self._split = self._verify_str_arg(split, "split", ("train", "test"))
super().__init__(root, skip_integrity_check=skip_integrity_check)
@abc.abstractmethod
def _is_data_file(self, data: Tuple[str, BinaryIO]) -> Optional[int]:
pass
def _resources(self) -> List[OnlineResource]:
return [
HttpResource(
f"https://www.cs.toronto.edu/~kriz/{self._FILE_NAME}",
sha256=self._SHA256,
)
]
def _unpickle(self, data: Tuple[str, io.BytesIO]) -> Dict[str, Any]:
_, file = data
return cast(Dict[str, Any], pickle.load(file, encoding="latin1"))
def _prepare_sample(self, data: Tuple[np.ndarray, int]) -> Dict[str, Any]:
image_array, category_idx = data
return dict(
image=Image(image_array),
label=Label(category_idx, categories=self._categories),
)
def _datapipe(self, resource_dps: List[IterDataPipe]) -> IterDataPipe[Dict[str, Any]]:
dp = resource_dps[0]
dp = Filter(dp, self._is_data_file)
dp = Mapper(dp, self._unpickle)
dp = CifarFileReader(dp, labels_key=self._LABELS_KEY)
dp = hint_shuffling(dp)
dp = hint_sharding(dp)
return Mapper(dp, self._prepare_sample)
def __len__(self) -> int:
return 50_000 if self._split == "train" else 10_000
def _generate_categories(self) -> List[str]:
resources = self._resources()
dp = resources[0].load(self._root)
dp = Filter(dp, path_comparator("name", self._META_FILE_NAME))
dp = Mapper(dp, self._unpickle)
return cast(List[str], next(iter(dp))[self._CATEGORIES_KEY])
@register_info("cifar10")
def _cifar10_info() -> Dict[str, Any]:
return dict(categories=read_categories_file("cifar10"))
@register_dataset("cifar10")
class Cifar10(_CifarBase):
"""
- **homepage**: https://www.cs.toronto.edu/~kriz/cifar.html
"""
_FILE_NAME = "cifar-10-python.tar.gz"
_SHA256 = "6d958be074577803d12ecdefd02955f39262c83c16fe9348329d7fe0b5c001ce"
_LABELS_KEY = "labels"
_META_FILE_NAME = "batches.meta"
_CATEGORIES_KEY = "label_names"
_categories = _cifar10_info()["categories"]
def _is_data_file(self, data: Tuple[str, Any]) -> bool:
path = pathlib.Path(data[0])
return path.name.startswith("data" if self._split == "train" else "test")
@register_info("cifar100")
def _cifar100_info() -> Dict[str, Any]:
return dict(categories=read_categories_file("cifar100"))
@register_dataset("cifar100")
class Cifar100(_CifarBase):
"""
- **homepage**: https://www.cs.toronto.edu/~kriz/cifar.html
"""
_FILE_NAME = "cifar-100-python.tar.gz"
_SHA256 = "85cd44d02ba6437773c5bbd22e183051d648de2e7d6b014e1ef29b855ba677a7"
_LABELS_KEY = "fine_labels"
_META_FILE_NAME = "meta"
_CATEGORIES_KEY = "fine_label_names"
_categories = _cifar100_info()["categories"]
def _is_data_file(self, data: Tuple[str, Any]) -> bool:
path = pathlib.Path(data[0])
return path.name == self._split
airplane
automobile
bird
cat
deer
dog
frog
horse
ship
truck
apple
aquarium_fish
baby
bear
beaver
bed
bee
beetle
bicycle
bottle
bowl
boy
bridge
bus
butterfly
camel
can
castle
caterpillar
cattle
chair
chimpanzee
clock
cloud
cockroach
couch
crab
crocodile
cup
dinosaur
dolphin
elephant
flatfish
forest
fox
girl
hamster
house
kangaroo
keyboard
lamp
lawn_mower
leopard
lion
lizard
lobster
man
maple_tree
motorcycle
mountain
mouse
mushroom
oak_tree
orange
orchid
otter
palm_tree
pear
pickup_truck
pine_tree
plain
plate
poppy
porcupine
possum
rabbit
raccoon
ray
road
rocket
rose
sea
seal
shark
shrew
skunk
skyscraper
snail
snake
spider
squirrel
streetcar
sunflower
sweet_pepper
table
tank
telephone
television
tiger
tractor
train
trout
tulip
turtle
wardrobe
whale
willow_tree
wolf
woman
worm
import pathlib
from typing import Any, BinaryIO, Dict, List, Optional, Tuple, Union
from torchdata.datapipes.iter import Demultiplexer, Filter, IterDataPipe, IterKeyZipper, JsonParser, Mapper, UnBatcher
from torchvision.prototype.datasets.utils import Dataset, HttpResource, OnlineResource
from torchvision.prototype.datasets.utils._internal import (
getitem,
hint_sharding,
hint_shuffling,
INFINITE_BUFFER_SIZE,
path_accessor,
path_comparator,
)
from torchvision.prototype.features import EncodedImage, Label
from .._api import register_dataset, register_info
NAME = "clevr"
@register_info(NAME)
def _info() -> Dict[str, Any]:
return dict()
@register_dataset(NAME)
class CLEVR(Dataset):
"""
- **homepage**: https://cs.stanford.edu/people/jcjohns/clevr/
"""
def __init__(
self, root: Union[str, pathlib.Path], *, split: str = "train", skip_integrity_check: bool = False
) -> None:
self._split = self._verify_str_arg(split, "split", ("train", "val", "test"))
super().__init__(root, skip_integrity_check=skip_integrity_check)
def _resources(self) -> List[OnlineResource]:
archive = HttpResource(
"https://dl.fbaipublicfiles.com/clevr/CLEVR_v1.0.zip",
sha256="5cd61cf1096ed20944df93c9adb31e74d189b8459a94f54ba00090e5c59936d1",
)
return [archive]
def _classify_archive(self, data: Tuple[str, Any]) -> Optional[int]:
path = pathlib.Path(data[0])
if path.parents[1].name == "images":
return 0
elif path.parent.name == "scenes":
return 1
else:
return None
def _filter_scene_anns(self, data: Tuple[str, Any]) -> bool:
key, _ = data
return key == "scenes"
def _add_empty_anns(self, data: Tuple[str, BinaryIO]) -> Tuple[Tuple[str, BinaryIO], None]:
return data, None
def _prepare_sample(self, data: Tuple[Tuple[str, BinaryIO], Optional[Dict[str, Any]]]) -> Dict[str, Any]:
image_data, scenes_data = data
path, buffer = image_data
return dict(
path=path,
image=EncodedImage.from_file(buffer),
label=Label(len(scenes_data["objects"])) if scenes_data else None,
)
def _datapipe(self, resource_dps: List[IterDataPipe]) -> IterDataPipe[Dict[str, Any]]:
archive_dp = resource_dps[0]
images_dp, scenes_dp = Demultiplexer(
archive_dp,
2,
self._classify_archive,
drop_none=True,
buffer_size=INFINITE_BUFFER_SIZE,
)
images_dp = Filter(images_dp, path_comparator("parent.name", self._split))
images_dp = hint_shuffling(images_dp)
images_dp = hint_sharding(images_dp)
if self._split != "test":
scenes_dp = Filter(scenes_dp, path_comparator("name", f"CLEVR_{self._split}_scenes.json"))
scenes_dp = JsonParser(scenes_dp)
scenes_dp = Mapper(scenes_dp, getitem(1, "scenes"))
scenes_dp = UnBatcher(scenes_dp)
dp = IterKeyZipper(
images_dp,
scenes_dp,
key_fn=path_accessor("name"),
ref_key_fn=getitem("image_filename"),
buffer_size=INFINITE_BUFFER_SIZE,
)
else:
dp = Mapper(images_dp, self._add_empty_anns)
return Mapper(dp, self._prepare_sample)
def __len__(self) -> int:
return 70_000 if self._split == "train" else 15_000
__background__,N/A
person,person
bicycle,vehicle
car,vehicle
motorcycle,vehicle
airplane,vehicle
bus,vehicle
train,vehicle
truck,vehicle
boat,vehicle
traffic light,outdoor
fire hydrant,outdoor
N/A,N/A
stop sign,outdoor
parking meter,outdoor
bench,outdoor
bird,animal
cat,animal
dog,animal
horse,animal
sheep,animal
cow,animal
elephant,animal
bear,animal
zebra,animal
giraffe,animal
N/A,N/A
backpack,accessory
umbrella,accessory
N/A,N/A
N/A,N/A
handbag,accessory
tie,accessory
suitcase,accessory
frisbee,sports
skis,sports
snowboard,sports
sports ball,sports
kite,sports
baseball bat,sports
baseball glove,sports
skateboard,sports
surfboard,sports
tennis racket,sports
bottle,kitchen
N/A,N/A
wine glass,kitchen
cup,kitchen
fork,kitchen
knife,kitchen
spoon,kitchen
bowl,kitchen
banana,food
apple,food
sandwich,food
orange,food
broccoli,food
carrot,food
hot dog,food
pizza,food
donut,food
cake,food
chair,furniture
couch,furniture
potted plant,furniture
bed,furniture
N/A,N/A
dining table,furniture
N/A,N/A
N/A,N/A
toilet,furniture
N/A,N/A
tv,electronic
laptop,electronic
mouse,electronic
remote,electronic
keyboard,electronic
cell phone,electronic
microwave,appliance
oven,appliance
toaster,appliance
sink,appliance
refrigerator,appliance
N/A,N/A
book,indoor
clock,indoor
vase,indoor
scissors,indoor
teddy bear,indoor
hair drier,indoor
toothbrush,indoor
import pathlib
import re
from collections import defaultdict, OrderedDict
from typing import Any, BinaryIO, cast, Dict, List, Optional, Tuple, Union
import torch
from torchdata.datapipes.iter import (
Demultiplexer,
Filter,
Grouper,
IterDataPipe,
IterKeyZipper,
JsonParser,
Mapper,
UnBatcher,
)
from torchvision.prototype.datasets.utils import Dataset, HttpResource, OnlineResource
from torchvision.prototype.datasets.utils._internal import (
getitem,
hint_sharding,
hint_shuffling,
INFINITE_BUFFER_SIZE,
MappingIterator,
path_accessor,
read_categories_file,
)
from torchvision.prototype.features import _Feature, BoundingBox, EncodedImage, Label
from .._api import register_dataset, register_info
NAME = "coco"
@register_info(NAME)
def _info() -> Dict[str, Any]:
categories, super_categories = zip(*read_categories_file(NAME))
return dict(categories=categories, super_categories=super_categories)
@register_dataset(NAME)
class Coco(Dataset):
"""
- **homepage**: https://cocodataset.org/
- **dependencies**:
- <pycocotools `https://github.com/cocodataset/cocoapi`>_
"""
def __init__(
self,
root: Union[str, pathlib.Path],
*,
split: str = "train",
year: str = "2017",
annotations: Optional[str] = "instances",
skip_integrity_check: bool = False,
) -> None:
self._split = self._verify_str_arg(split, "split", {"train", "val"})
self._year = self._verify_str_arg(year, "year", {"2017", "2014"})
self._annotations = (
self._verify_str_arg(annotations, "annotations", self._ANN_DECODERS.keys())
if annotations is not None
else None
)
info = _info()
categories, super_categories = info["categories"], info["super_categories"]
self._categories = categories
self._category_to_super_category = dict(zip(categories, super_categories))
super().__init__(root, dependencies=("pycocotools",), skip_integrity_check=skip_integrity_check)
_IMAGE_URL_BASE = "http://images.cocodataset.org/zips"
_IMAGES_CHECKSUMS = {
("2014", "train"): "ede4087e640bddba550e090eae701092534b554b42b05ac33f0300b984b31775",
("2014", "val"): "fe9be816052049c34717e077d9e34aa60814a55679f804cd043e3cbee3b9fde0",
("2017", "train"): "69a8bb58ea5f8f99d24875f21416de2e9ded3178e903f1f7603e283b9e06d929",
("2017", "val"): "4f7e2ccb2866ec5041993c9cf2a952bbed69647b115d0f74da7ce8f4bef82f05",
}
_META_URL_BASE = "http://images.cocodataset.org/annotations"
_META_CHECKSUMS = {
"2014": "031296bbc80c45a1d1f76bf9a90ead27e94e99ec629208449507a4917a3bf009",
"2017": "113a836d90195ee1f884e704da6304dfaaecff1f023f49b6ca93c4aaae470268",
}
def _resources(self) -> List[OnlineResource]:
images = HttpResource(
f"{self._IMAGE_URL_BASE}/{self._split}{self._year}.zip",
sha256=self._IMAGES_CHECKSUMS[(self._year, self._split)],
)
meta = HttpResource(
f"{self._META_URL_BASE}/annotations_trainval{self._year}.zip",
sha256=self._META_CHECKSUMS[self._year],
)
return [images, meta]
def _segmentation_to_mask(self, segmentation: Any, *, is_crowd: bool, image_size: Tuple[int, int]) -> torch.Tensor:
from pycocotools import mask
if is_crowd:
segmentation = mask.frPyObjects(segmentation, *image_size)
else:
segmentation = mask.merge(mask.frPyObjects(segmentation, *image_size))
return torch.from_numpy(mask.decode(segmentation)).to(torch.bool)
def _decode_instances_anns(self, anns: List[Dict[str, Any]], image_meta: Dict[str, Any]) -> Dict[str, Any]:
image_size = (image_meta["height"], image_meta["width"])
labels = [ann["category_id"] for ann in anns]
return dict(
# TODO: create a segmentation feature
segmentations=_Feature(
torch.stack(
[
self._segmentation_to_mask(ann["segmentation"], is_crowd=ann["iscrowd"], image_size=image_size)
for ann in anns
]
)
),
areas=_Feature([ann["area"] for ann in anns]),
crowds=_Feature([ann["iscrowd"] for ann in anns], dtype=torch.bool),
bounding_boxes=BoundingBox(
[ann["bbox"] for ann in anns],
format="xywh",
image_size=image_size,
),
labels=Label(labels, categories=self._categories),
super_categories=[self._category_to_super_category[self._categories[label]] for label in labels],
ann_ids=[ann["id"] for ann in anns],
)
def _decode_captions_ann(self, anns: List[Dict[str, Any]], image_meta: Dict[str, Any]) -> Dict[str, Any]:
return dict(
captions=[ann["caption"] for ann in anns],
ann_ids=[ann["id"] for ann in anns],
)
_ANN_DECODERS = OrderedDict(
[
("instances", _decode_instances_anns),
("captions", _decode_captions_ann),
]
)
_META_FILE_PATTERN = re.compile(
rf"(?P<annotations>({'|'.join(_ANN_DECODERS.keys())}))_(?P<split>[a-zA-Z]+)(?P<year>\d+)[.]json"
)
def _filter_meta_files(self, data: Tuple[str, Any]) -> bool:
match = self._META_FILE_PATTERN.match(pathlib.Path(data[0]).name)
return bool(
match
and match["split"] == self._split
and match["year"] == self._year
and match["annotations"] == self._annotations
)
def _classify_meta(self, data: Tuple[str, Any]) -> Optional[int]:
key, _ = data
if key == "images":
return 0
elif key == "annotations":
return 1
else:
return None
def _prepare_image(self, data: Tuple[str, BinaryIO]) -> Dict[str, Any]:
path, buffer = data
return dict(
path=path,
image=EncodedImage.from_file(buffer),
)
def _prepare_sample(
self,
data: Tuple[Tuple[List[Dict[str, Any]], Dict[str, Any]], Tuple[str, BinaryIO]],
) -> Dict[str, Any]:
ann_data, image_data = data
anns, image_meta = ann_data
sample = self._prepare_image(image_data)
# this method is only called if we have annotations
annotations = cast(str, self._annotations)
sample.update(self._ANN_DECODERS[annotations](self, anns, image_meta))
return sample
def _datapipe(self, resource_dps: List[IterDataPipe]) -> IterDataPipe[Dict[str, Any]]:
images_dp, meta_dp = resource_dps
if self._annotations is None:
dp = hint_shuffling(images_dp)
dp = hint_sharding(dp)
dp = hint_shuffling(dp)
return Mapper(dp, self._prepare_image)
meta_dp = Filter(meta_dp, self._filter_meta_files)
meta_dp = JsonParser(meta_dp)
meta_dp = Mapper(meta_dp, getitem(1))
meta_dp: IterDataPipe[Dict[str, Dict[str, Any]]] = MappingIterator(meta_dp)
images_meta_dp, anns_meta_dp = Demultiplexer(
meta_dp,
2,
self._classify_meta,
drop_none=True,
buffer_size=INFINITE_BUFFER_SIZE,
)
images_meta_dp = Mapper(images_meta_dp, getitem(1))
images_meta_dp = UnBatcher(images_meta_dp)
anns_meta_dp = Mapper(anns_meta_dp, getitem(1))
anns_meta_dp = UnBatcher(anns_meta_dp)
anns_meta_dp = Grouper(anns_meta_dp, group_key_fn=getitem("image_id"), buffer_size=INFINITE_BUFFER_SIZE)
anns_meta_dp = hint_shuffling(anns_meta_dp)
anns_meta_dp = hint_sharding(anns_meta_dp)
anns_dp = IterKeyZipper(
anns_meta_dp,
images_meta_dp,
key_fn=getitem(0, "image_id"),
ref_key_fn=getitem("id"),
buffer_size=INFINITE_BUFFER_SIZE,
)
dp = IterKeyZipper(
anns_dp,
images_dp,
key_fn=getitem(1, "file_name"),
ref_key_fn=path_accessor("name"),
buffer_size=INFINITE_BUFFER_SIZE,
)
return Mapper(dp, self._prepare_sample)
def __len__(self) -> int:
return {
("train", "2017"): defaultdict(lambda: 118_287, instances=117_266),
("train", "2014"): defaultdict(lambda: 82_783, instances=82_081),
("val", "2017"): defaultdict(lambda: 5_000, instances=4_952),
("val", "2014"): defaultdict(lambda: 40_504, instances=40_137),
}[(self._split, self._year)][
self._annotations # type: ignore[index]
]
def _generate_categories(self) -> Tuple[Tuple[str, str]]:
self._annotations = "instances"
resources = self._resources()
dp = resources[1].load(self._root)
dp = Filter(dp, self._filter_meta_files)
dp = JsonParser(dp)
_, meta = next(iter(dp))
# List[Tuple[super_category, id, category]]
label_data = [cast(Tuple[str, int, str], tuple(info.values())) for info in meta["categories"]]
# COCO actually defines 91 categories, but only 80 of them have instances. Still, the category_id refers to the
# full set. To keep the labels dense, we fill the gaps with N/A. Note that there are only 10 gaps, so the total
# number of categories is 90 rather than 91.
_, ids, _ = zip(*label_data)
missing_ids = set(range(1, max(ids) + 1)) - set(ids)
label_data.extend([("N/A", id, "N/A") for id in missing_ids])
# We also add a background category to be used during segmentation.
label_data.append(("N/A", 0, "__background__"))
super_categories, _, categories = zip(*sorted(label_data, key=lambda info: info[1]))
return cast(Tuple[Tuple[str, str]], tuple(zip(categories, super_categories)))
AD
AE
AF
AG
AI
AL
AM
AO
AQ
AR
AT
AU
AW
AX
AZ
BA
BB
BD
BE
BF
BG
BH
BJ
BM
BN
BO
BQ
BR
BS
BT
BW
BY
BZ
CA
CD
CF
CH
CI
CK
CL
CM
CN
CO
CR
CU
CV
CW
CY
CZ
DE
DK
DM
DO
DZ
EC
EE
EG
ES
ET
FI
FJ
FK
FO
FR
GA
GB
GD
GE
GF
GG
GH
GI
GL
GM
GP
GR
GS
GT
GU
GY
HK
HN
HR
HT
HU
ID
IE
IL
IM
IN
IQ
IR
IS
IT
JE
JM
JO
JP
KE
KG
KH
KN
KP
KR
KW
KY
KZ
LA
LB
LC
LI
LK
LR
LT
LU
LV
LY
MA
MC
MD
ME
MF
MG
MK
ML
MM
MN
MO
MQ
MR
MT
MU
MV
MW
MX
MY
MZ
NA
NC
NG
NI
NL
NO
NP
NZ
OM
PA
PE
PF
PG
PH
PK
PL
PR
PS
PT
PW
PY
QA
RE
RO
RS
RU
RW
SA
SB
SC
SD
SE
SG
SH
SI
SJ
SK
SL
SM
SN
SO
SS
SV
SX
SY
SZ
TG
TH
TJ
TL
TM
TN
TO
TR
TT
TW
TZ
UA
UG
US
UY
UZ
VA
VE
VG
VI
VN
VU
WS
XK
YE
ZA
ZM
ZW
import pathlib
from typing import Any, Dict, List, Tuple, Union
from torchdata.datapipes.iter import Filter, IterDataPipe, Mapper
from torchvision.prototype.datasets.utils import Dataset, HttpResource, OnlineResource
from torchvision.prototype.datasets.utils._internal import (
hint_sharding,
hint_shuffling,
path_comparator,
read_categories_file,
)
from torchvision.prototype.features import EncodedImage, Label
from .._api import register_dataset, register_info
NAME = "country211"
@register_info(NAME)
def _info() -> Dict[str, Any]:
return dict(categories=read_categories_file(NAME))
@register_dataset(NAME)
class Country211(Dataset):
"""
- **homepage**: https://github.com/openai/CLIP/blob/main/data/country211.md
"""
def __init__(
self,
root: Union[str, pathlib.Path],
*,
split: str = "train",
skip_integrity_check: bool = False,
) -> None:
self._split = self._verify_str_arg(split, "split", ("train", "val", "test"))
self._split_folder_name = "valid" if split == "val" else split
self._categories = _info()["categories"]
super().__init__(root, skip_integrity_check=skip_integrity_check)
def _resources(self) -> List[OnlineResource]:
return [
HttpResource(
"https://openaipublic.azureedge.net/clip/data/country211.tgz",
sha256="c011343cdc1296a8c31ff1d7129cf0b5e5b8605462cffd24f89266d6e6f4da3c",
)
]
def _prepare_sample(self, data: Tuple[str, Any]) -> Dict[str, Any]:
path, buffer = data
category = pathlib.Path(path).parent.name
return dict(
label=Label.from_category(category, categories=self._categories),
path=path,
image=EncodedImage.from_file(buffer),
)
def _filter_split(self, data: Tuple[str, Any], *, split: str) -> bool:
return pathlib.Path(data[0]).parent.parent.name == split
def _datapipe(self, resource_dps: List[IterDataPipe]) -> IterDataPipe[Dict[str, Any]]:
dp = resource_dps[0]
dp = Filter(dp, path_comparator("parent.parent.name", self._split_folder_name))
dp = hint_shuffling(dp)
dp = hint_sharding(dp)
return Mapper(dp, self._prepare_sample)
def __len__(self) -> int:
return {
"train": 31_650,
"val": 10_550,
"test": 21_100,
}[self._split]
def _generate_categories(self) -> List[str]:
resources = self._resources()
dp = resources[0].load(self._root)
return sorted({pathlib.Path(path).parent.name for path, _ in dp})
Black_footed_Albatross
Laysan_Albatross
Sooty_Albatross
Groove_billed_Ani
Crested_Auklet
Least_Auklet
Parakeet_Auklet
Rhinoceros_Auklet
Brewer_Blackbird
Red_winged_Blackbird
Rusty_Blackbird
Yellow_headed_Blackbird
Bobolink
Indigo_Bunting
Lazuli_Bunting
Painted_Bunting
Cardinal
Spotted_Catbird
Gray_Catbird
Yellow_breasted_Chat
Eastern_Towhee
Chuck_will_Widow
Brandt_Cormorant
Red_faced_Cormorant
Pelagic_Cormorant
Bronzed_Cowbird
Shiny_Cowbird
Brown_Creeper
American_Crow
Fish_Crow
Black_billed_Cuckoo
Mangrove_Cuckoo
Yellow_billed_Cuckoo
Gray_crowned_Rosy_Finch
Purple_Finch
Northern_Flicker
Acadian_Flycatcher
Great_Crested_Flycatcher
Least_Flycatcher
Olive_sided_Flycatcher
Scissor_tailed_Flycatcher
Vermilion_Flycatcher
Yellow_bellied_Flycatcher
Frigatebird
Northern_Fulmar
Gadwall
American_Goldfinch
European_Goldfinch
Boat_tailed_Grackle
Eared_Grebe
Horned_Grebe
Pied_billed_Grebe
Western_Grebe
Blue_Grosbeak
Evening_Grosbeak
Pine_Grosbeak
Rose_breasted_Grosbeak
Pigeon_Guillemot
California_Gull
Glaucous_winged_Gull
Heermann_Gull
Herring_Gull
Ivory_Gull
Ring_billed_Gull
Slaty_backed_Gull
Western_Gull
Anna_Hummingbird
Ruby_throated_Hummingbird
Rufous_Hummingbird
Green_Violetear
Long_tailed_Jaeger
Pomarine_Jaeger
Blue_Jay
Florida_Jay
Green_Jay
Dark_eyed_Junco
Tropical_Kingbird
Gray_Kingbird
Belted_Kingfisher
Green_Kingfisher
Pied_Kingfisher
Ringed_Kingfisher
White_breasted_Kingfisher
Red_legged_Kittiwake
Horned_Lark
Pacific_Loon
Mallard
Western_Meadowlark
Hooded_Merganser
Red_breasted_Merganser
Mockingbird
Nighthawk
Clark_Nutcracker
White_breasted_Nuthatch
Baltimore_Oriole
Hooded_Oriole
Orchard_Oriole
Scott_Oriole
Ovenbird
Brown_Pelican
White_Pelican
Western_Wood_Pewee
Sayornis
American_Pipit
Whip_poor_Will
Horned_Puffin
Common_Raven
White_necked_Raven
American_Redstart
Geococcyx
Loggerhead_Shrike
Great_Grey_Shrike
Baird_Sparrow
Black_throated_Sparrow
Brewer_Sparrow
Chipping_Sparrow
Clay_colored_Sparrow
House_Sparrow
Field_Sparrow
Fox_Sparrow
Grasshopper_Sparrow
Harris_Sparrow
Henslow_Sparrow
Le_Conte_Sparrow
Lincoln_Sparrow
Nelson_Sharp_tailed_Sparrow
Savannah_Sparrow
Seaside_Sparrow
Song_Sparrow
Tree_Sparrow
Vesper_Sparrow
White_crowned_Sparrow
White_throated_Sparrow
Cape_Glossy_Starling
Bank_Swallow
Barn_Swallow
Cliff_Swallow
Tree_Swallow
Scarlet_Tanager
Summer_Tanager
Artic_Tern
Black_Tern
Caspian_Tern
Common_Tern
Elegant_Tern
Forsters_Tern
Least_Tern
Green_tailed_Towhee
Brown_Thrasher
Sage_Thrasher
Black_capped_Vireo
Blue_headed_Vireo
Philadelphia_Vireo
Red_eyed_Vireo
Warbling_Vireo
White_eyed_Vireo
Yellow_throated_Vireo
Bay_breasted_Warbler
Black_and_white_Warbler
Black_throated_Blue_Warbler
Blue_winged_Warbler
Canada_Warbler
Cape_May_Warbler
Cerulean_Warbler
Chestnut_sided_Warbler
Golden_winged_Warbler
Hooded_Warbler
Kentucky_Warbler
Magnolia_Warbler
Mourning_Warbler
Myrtle_Warbler
Nashville_Warbler
Orange_crowned_Warbler
Palm_Warbler
Pine_Warbler
Prairie_Warbler
Prothonotary_Warbler
Swainson_Warbler
Tennessee_Warbler
Wilson_Warbler
Worm_eating_Warbler
Yellow_Warbler
Northern_Waterthrush
Louisiana_Waterthrush
Bohemian_Waxwing
Cedar_Waxwing
American_Three_toed_Woodpecker
Pileated_Woodpecker
Red_bellied_Woodpecker
Red_cockaded_Woodpecker
Red_headed_Woodpecker
Downy_Woodpecker
Bewick_Wren
Cactus_Wren
Carolina_Wren
House_Wren
Marsh_Wren
Rock_Wren
Winter_Wren
Common_Yellowthroat
import csv
import functools
import pathlib
from typing import Any, BinaryIO, Callable, Dict, List, Optional, Tuple, Union
from torchdata.datapipes.iter import (
CSVDictParser,
CSVParser,
Demultiplexer,
Filter,
IterDataPipe,
IterKeyZipper,
LineReader,
Mapper,
)
from torchdata.datapipes.map import IterToMapConverter
from torchvision.prototype.datasets.utils import Dataset, GDriveResource, OnlineResource
from torchvision.prototype.datasets.utils._internal import (
getitem,
hint_sharding,
hint_shuffling,
INFINITE_BUFFER_SIZE,
path_accessor,
path_comparator,
read_categories_file,
read_mat,
)
from torchvision.prototype.features import _Feature, BoundingBox, EncodedImage, Label
from .._api import register_dataset, register_info
csv.register_dialect("cub200", delimiter=" ")
NAME = "cub200"
@register_info(NAME)
def _info() -> Dict[str, Any]:
return dict(categories=read_categories_file(NAME))
@register_dataset(NAME)
class CUB200(Dataset):
"""
- **homepage**: http://www.vision.caltech.edu/visipedia/CUB-200.html
"""
def __init__(
self,
root: Union[str, pathlib.Path],
*,
split: str = "train",
year: str = "2011",
skip_integrity_check: bool = False,
) -> None:
self._split = self._verify_str_arg(split, "split", ("train", "test"))
self._year = self._verify_str_arg(year, "year", ("2010", "2011"))
self._categories = _info()["categories"]
super().__init__(
root,
# TODO: this will only be available after https://github.com/pytorch/vision/pull/5473
# dependencies=("scipy",),
skip_integrity_check=skip_integrity_check,
)
def _resources(self) -> List[OnlineResource]:
if self._year == "2011":
archive = GDriveResource(
"1hbzc_P1FuxMkcabkgn9ZKinBwW683j45",
file_name="CUB_200_2011.tgz",
sha256="0c685df5597a8b24909f6a7c9db6d11e008733779a671760afef78feb49bf081",
preprocess="decompress",
)
segmentations = GDriveResource(
"1EamOKGLoTuZdtcVYbHMWNpkn3iAVj8TP",
file_name="segmentations.tgz",
sha256="dc77f6cffea0cbe2e41d4201115c8f29a6320ecb04fffd2444f51b8066e4b84f",
preprocess="decompress",
)
return [archive, segmentations]
else: # self._year == "2010"
split = GDriveResource(
"1vZuZPqha0JjmwkdaS_XtYryE3Jf5Q1AC",
file_name="lists.tgz",
sha256="aeacbd5e3539ae84ea726e8a266a9a119c18f055cd80f3836d5eb4500b005428",
preprocess="decompress",
)
images = GDriveResource(
"1GDr1OkoXdhaXWGA8S3MAq3a522Tak-nx",
file_name="images.tgz",
sha256="2a6d2246bbb9778ca03aa94e2e683ccb4f8821a36b7f235c0822e659d60a803e",
preprocess="decompress",
)
anns = GDriveResource(
"16NsbTpMs5L6hT4hUJAmpW2u7wH326WTR",
file_name="annotations.tgz",
sha256="c17b7841c21a66aa44ba8fe92369cc95dfc998946081828b1d7b8a4b716805c1",
preprocess="decompress",
)
return [split, images, anns]
def _2011_classify_archive(self, data: Tuple[str, Any]) -> Optional[int]:
path = pathlib.Path(data[0])
if path.parents[1].name == "images":
return 0
elif path.name == "train_test_split.txt":
return 1
elif path.name == "images.txt":
return 2
elif path.name == "bounding_boxes.txt":
return 3
else:
return None
def _2011_extract_file_name(self, rel_posix_path: str) -> str:
return rel_posix_path.rsplit("/", maxsplit=1)[1]
def _2011_filter_split(self, row: List[str]) -> bool:
_, split_id = row
return {
"0": "test",
"1": "train",
}[split_id] == self._split
def _2011_segmentation_key(self, data: Tuple[str, Any]) -> str:
path = pathlib.Path(data[0])
return path.with_suffix(".jpg").name
def _2011_prepare_ann(
self, data: Tuple[str, Tuple[List[str], Tuple[str, BinaryIO]]], image_size: Tuple[int, int]
) -> Dict[str, Any]:
_, (bounding_box_data, segmentation_data) = data
segmentation_path, segmentation_buffer = segmentation_data
return dict(
bounding_box=BoundingBox(
[float(part) for part in bounding_box_data[1:]], format="xywh", image_size=image_size
),
segmentation_path=segmentation_path,
segmentation=EncodedImage.from_file(segmentation_buffer),
)
def _2010_split_key(self, data: str) -> str:
return data.rsplit("/", maxsplit=1)[1]
def _2010_anns_key(self, data: Tuple[str, BinaryIO]) -> Tuple[str, Tuple[str, BinaryIO]]:
path = pathlib.Path(data[0])
return path.with_suffix(".jpg").name, data
def _2010_prepare_ann(self, data: Tuple[str, Tuple[str, BinaryIO]], image_size: Tuple[int, int]) -> Dict[str, Any]:
_, (path, buffer) = data
content = read_mat(buffer)
return dict(
ann_path=path,
bounding_box=BoundingBox(
[int(content["bbox"][coord]) for coord in ("left", "bottom", "right", "top")],
format="xyxy",
image_size=image_size,
),
segmentation=_Feature(content["seg"]),
)
def _prepare_sample(
self,
data: Tuple[Tuple[str, Tuple[str, BinaryIO]], Any],
*,
prepare_ann_fn: Callable[[Any, Tuple[int, int]], Dict[str, Any]],
) -> Dict[str, Any]:
data, anns_data = data
_, image_data = data
path, buffer = image_data
image = EncodedImage.from_file(buffer)
return dict(
prepare_ann_fn(anns_data, image.image_size),
image=image,
label=Label(int(pathlib.Path(path).parent.name.rsplit(".", 1)[0]), categories=self._categories),
)
def _datapipe(self, resource_dps: List[IterDataPipe]) -> IterDataPipe[Dict[str, Any]]:
prepare_ann_fn: Callable
if self._year == "2011":
archive_dp, segmentations_dp = resource_dps
images_dp, split_dp, image_files_dp, bounding_boxes_dp = Demultiplexer(
archive_dp, 4, self._2011_classify_archive, drop_none=True, buffer_size=INFINITE_BUFFER_SIZE
)
image_files_dp = CSVParser(image_files_dp, dialect="cub200")
image_files_dp = Mapper(image_files_dp, self._2011_extract_file_name, input_col=1)
image_files_map = IterToMapConverter(image_files_dp)
split_dp = CSVParser(split_dp, dialect="cub200")
split_dp = Filter(split_dp, self._2011_filter_split)
split_dp = Mapper(split_dp, getitem(0))
split_dp = Mapper(split_dp, image_files_map.__getitem__)
bounding_boxes_dp = CSVParser(bounding_boxes_dp, dialect="cub200")
bounding_boxes_dp = Mapper(bounding_boxes_dp, image_files_map.__getitem__, input_col=0)
anns_dp = IterKeyZipper(
bounding_boxes_dp,
segmentations_dp,
key_fn=getitem(0),
ref_key_fn=self._2011_segmentation_key,
keep_key=True,
buffer_size=INFINITE_BUFFER_SIZE,
)
prepare_ann_fn = self._2011_prepare_ann
else: # self._year == "2010"
split_dp, images_dp, anns_dp = resource_dps
split_dp = Filter(split_dp, path_comparator("name", f"{self._split}.txt"))
split_dp = LineReader(split_dp, decode=True, return_path=False)
split_dp = Mapper(split_dp, self._2010_split_key)
anns_dp = Mapper(anns_dp, self._2010_anns_key)
prepare_ann_fn = self._2010_prepare_ann
split_dp = hint_shuffling(split_dp)
split_dp = hint_sharding(split_dp)
dp = IterKeyZipper(
split_dp,
images_dp,
getitem(),
path_accessor("name"),
buffer_size=INFINITE_BUFFER_SIZE,
)
dp = IterKeyZipper(
dp,
anns_dp,
getitem(0),
buffer_size=INFINITE_BUFFER_SIZE,
)
return Mapper(dp, functools.partial(self._prepare_sample, prepare_ann_fn=prepare_ann_fn))
def __len__(self) -> int:
return {
("train", "2010"): 3_000,
("test", "2010"): 3_033,
("train", "2011"): 5_994,
("test", "2011"): 5_794,
}[(self._split, self._year)]
def _generate_categories(self) -> List[str]:
self._year = "2011"
resources = self._resources()
dp = resources[0].load(self._root)
dp = Filter(dp, path_comparator("name", "classes.txt"))
dp = CSVDictParser(dp, fieldnames=("label", "category"), dialect="cub200")
return [row["category"].split(".")[1] for row in dp]
banded
blotchy
braided
bubbly
bumpy
chequered
cobwebbed
cracked
crosshatched
crystalline
dotted
fibrous
flecked
freckled
frilly
gauzy
grid
grooved
honeycombed
interlaced
knitted
lacelike
lined
marbled
matted
meshed
paisley
perforated
pitted
pleated
polka-dotted
porous
potholed
scaly
smeared
spiralled
sprinkled
stained
stratified
striped
studded
swirly
veined
waffled
woven
wrinkled
zigzagged
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment