"git@developer.sourcefind.cn:OpenDAS/fairscale.git" did not exist on "4e438ba19456fa6230b5c0163e7ad8f01e7e8e1f"
Unverified Commit fbb4cc54 authored by Philip Meier's avatar Philip Meier Committed by GitHub
Browse files

remove torchvision.prototype module and related tests / CI from release branch (#7983)

parent a90e5846
name: Prototype tests on Linux
on:
pull_request:
jobs:
unittests-prototype:
strategy:
matrix:
python-version:
- "3.8"
- "3.9"
- "3.10"
- "3.11"
runner: ["linux.12xlarge"]
gpu-arch-type: ["cpu"]
include:
- python-version: "3.8"
runner: linux.g5.4xlarge.nvidia.gpu
gpu-arch-type: cuda
gpu-arch-version: "11.8"
fail-fast: false
uses: pytorch/test-infra/.github/workflows/linux_job.yml@release/2.1
with:
repository: pytorch/vision
runner: ${{ matrix.runner }}
gpu-arch-type: ${{ matrix.gpu-arch-type }}
gpu-arch-version: ${{ matrix.gpu-arch-version }}
timeout: 120
script: |
set -euo pipefail
export PYTHON_VERSION=${{ matrix.python-version }}
export GPU_ARCH_TYPE=${{ matrix.gpu-arch-type }}
export GPU_ARCH_VERSION=${{ matrix.gpu-arch-version }}
./.github/scripts/setup-env.sh
# Prepare conda
CONDA_PATH=$(which conda)
eval "$(${CONDA_PATH} shell.bash hook)"
conda activate ci
echo '::group::Install testing utilities'
pip install --progress-bar=off pytest pytest-mock pytest-cov
echo '::endgroup::'
# We don't want to run the prototype datasets tests. Since the positional glob into `pytest`, i.e.
# `test/test_prototype*.py` takes the highest priority, neither `--ignore` nor `--ignore-glob` can help us here.
rm test/test_prototype_datasets*.py
pytest \
-v --durations=25 \
--cov=torchvision/prototype --cov-report=term-missing \
--junit-xml="${RUNNER_TEST_RESULTS_DIR}/test-results.xml" \
test/test_prototype_*.py
import io
import pickle
from collections import deque
from pathlib import Path
import pytest
import torch
import torchvision.transforms.v2 as transforms
from builtin_dataset_mocks import DATASET_MOCKS, parametrize_dataset_mocks
from torch.testing._comparison import not_close_error_metas, ObjectPair, TensorLikePair
# TODO: replace with torchdata.dataloader2.DataLoader2 as soon as it is stable-ish
from torch.utils.data import DataLoader
# TODO: replace with torchdata equivalent as soon as it is available
from torch.utils.data.graph_settings import get_all_graph_pipes
from torchdata.dataloader2.graph.utils import traverse_dps
from torchdata.datapipes.iter import ShardingFilter, Shuffler
from torchdata.datapipes.utils import StreamWrapper
from torchvision import tv_tensors
from torchvision._utils import sequence_to_str
from torchvision.prototype import datasets
from torchvision.prototype.datasets.utils import EncodedImage
from torchvision.prototype.datasets.utils._internal import INFINITE_BUFFER_SIZE
from torchvision.prototype.tv_tensors import Label
from torchvision.transforms.v2._utils import is_pure_tensor
def assert_samples_equal(*args, msg=None, **kwargs):
error_metas = not_close_error_metas(
*args, pair_types=(TensorLikePair, ObjectPair), rtol=0, atol=0, equal_nan=True, **kwargs
)
if error_metas:
raise error_metas[0].to_error(msg)
def extract_datapipes(dp):
return get_all_graph_pipes(traverse_dps(dp))
def consume(iterator):
# Copied from the official itertools recipes: https://docs.python.org/3/library/itertools.html#itertools-recipes
deque(iterator, maxlen=0)
def next_consume(iterator):
item = next(iterator)
consume(iterator)
return item
@pytest.fixture(autouse=True)
def test_home(mocker, tmp_path):
mocker.patch("torchvision.prototype.datasets._api.home", return_value=str(tmp_path))
mocker.patch("torchvision.prototype.datasets.home", return_value=str(tmp_path))
yield tmp_path
def test_coverage():
untested_datasets = set(datasets.list_datasets()) - DATASET_MOCKS.keys()
if untested_datasets:
raise AssertionError(
f"The dataset(s) {sequence_to_str(sorted(untested_datasets), separate_last='and ')} "
f"are exposed through `torchvision.prototype.datasets.load()`, but are not tested. "
f"Please add mock data to `test/builtin_dataset_mocks.py`."
)
@pytest.mark.filterwarnings("error")
class TestCommon:
@pytest.mark.parametrize("name", datasets.list_datasets())
def test_info(self, name):
try:
info = datasets.info(name)
except ValueError:
raise AssertionError("No info available.") from None
if not (isinstance(info, dict) and all(isinstance(key, str) for key in info.keys())):
raise AssertionError("Info should be a dictionary with string keys.")
@parametrize_dataset_mocks(DATASET_MOCKS)
def test_smoke(self, dataset_mock, config):
dataset, _ = dataset_mock.load(config)
if not isinstance(dataset, datasets.utils.Dataset):
raise AssertionError(f"Loading the dataset should return an Dataset, but got {type(dataset)} instead.")
@parametrize_dataset_mocks(DATASET_MOCKS)
def test_sample(self, dataset_mock, config):
dataset, _ = dataset_mock.load(config)
try:
sample = next_consume(iter(dataset))
except StopIteration:
raise AssertionError("Unable to draw any sample.") from None
except Exception as error:
raise AssertionError("Drawing a sample raised the error above.") from error
if not isinstance(sample, dict):
raise AssertionError(f"Samples should be dictionaries, but got {type(sample)} instead.")
if not sample:
raise AssertionError("Sample dictionary is empty.")
@parametrize_dataset_mocks(DATASET_MOCKS)
def test_num_samples(self, dataset_mock, config):
dataset, mock_info = dataset_mock.load(config)
assert len(list(dataset)) == mock_info["num_samples"]
@pytest.fixture
def log_session_streams(self):
debug_unclosed_streams = StreamWrapper.debug_unclosed_streams
try:
StreamWrapper.debug_unclosed_streams = True
yield
finally:
StreamWrapper.debug_unclosed_streams = debug_unclosed_streams
@parametrize_dataset_mocks(DATASET_MOCKS)
def test_stream_closing(self, log_session_streams, dataset_mock, config):
def make_msg_and_close(head):
unclosed_streams = []
for stream in list(StreamWrapper.session_streams.keys()):
unclosed_streams.append(repr(stream.file_obj))
stream.close()
unclosed_streams = "\n".join(unclosed_streams)
return f"{head}\n\n{unclosed_streams}"
if StreamWrapper.session_streams:
raise pytest.UsageError(make_msg_and_close("A previous test did not close the following streams:"))
dataset, _ = dataset_mock.load(config)
consume(iter(dataset))
if StreamWrapper.session_streams:
raise AssertionError(make_msg_and_close("The following streams were not closed after a full iteration:"))
@parametrize_dataset_mocks(DATASET_MOCKS)
def test_no_unaccompanied_pure_tensors(self, dataset_mock, config):
dataset, _ = dataset_mock.load(config)
sample = next_consume(iter(dataset))
pure_tensors = {key for key, value in sample.items() if is_pure_tensor(value)}
if pure_tensors and not any(
isinstance(item, (tv_tensors.Image, tv_tensors.Video, EncodedImage)) for item in sample.values()
):
raise AssertionError(
f"The values of key(s) "
f"{sequence_to_str(sorted(pure_tensors), separate_last='and ')} contained pure tensors, "
f"but didn't find any (encoded) image or video."
)
@parametrize_dataset_mocks(DATASET_MOCKS)
def test_transformable(self, dataset_mock, config):
dataset, _ = dataset_mock.load(config)
dataset = dataset.map(transforms.Identity())
consume(iter(dataset))
@parametrize_dataset_mocks(DATASET_MOCKS)
def test_traversable(self, dataset_mock, config):
dataset, _ = dataset_mock.load(config)
traverse_dps(dataset)
@parametrize_dataset_mocks(DATASET_MOCKS)
def test_serializable(self, dataset_mock, config):
dataset, _ = dataset_mock.load(config)
pickle.dumps(dataset)
# This has to be a proper function, since lambda's or local functions
# cannot be pickled, but this is a requirement for the DataLoader with
# multiprocessing, i.e. num_workers > 0
def _collate_fn(self, batch):
return batch
@pytest.mark.parametrize("num_workers", [0, 1])
@parametrize_dataset_mocks(DATASET_MOCKS)
def test_data_loader(self, dataset_mock, config, num_workers):
dataset, _ = dataset_mock.load(config)
dl = DataLoader(
dataset,
batch_size=2,
num_workers=num_workers,
collate_fn=self._collate_fn,
)
consume(dl)
# TODO: we need to enforce not only that both a Shuffler and a ShardingFilter are part of the datapipe, but also
# that the Shuffler comes before the ShardingFilter. Early commits in https://github.com/pytorch/vision/pull/5680
# contain a custom test for that, but we opted to wait for a potential solution / test from torchdata for now.
@parametrize_dataset_mocks(DATASET_MOCKS)
@pytest.mark.parametrize("annotation_dp_type", (Shuffler, ShardingFilter))
def test_has_annotations(self, dataset_mock, config, annotation_dp_type):
dataset, _ = dataset_mock.load(config)
if not any(isinstance(dp, annotation_dp_type) for dp in extract_datapipes(dataset)):
raise AssertionError(f"The dataset doesn't contain a {annotation_dp_type.__name__}() datapipe.")
@parametrize_dataset_mocks(DATASET_MOCKS)
def test_save_load(self, dataset_mock, config):
dataset, _ = dataset_mock.load(config)
sample = next_consume(iter(dataset))
with io.BytesIO() as buffer:
torch.save(sample, buffer)
buffer.seek(0)
assert_samples_equal(torch.load(buffer), sample)
@parametrize_dataset_mocks(DATASET_MOCKS)
def test_infinite_buffer_size(self, dataset_mock, config):
dataset, _ = dataset_mock.load(config)
for dp in extract_datapipes(dataset):
if hasattr(dp, "buffer_size"):
# TODO: replace this with the proper sentinel as soon as https://github.com/pytorch/data/issues/335 is
# resolved
assert dp.buffer_size == INFINITE_BUFFER_SIZE
@parametrize_dataset_mocks(DATASET_MOCKS)
def test_has_length(self, dataset_mock, config):
dataset, _ = dataset_mock.load(config)
assert len(dataset) > 0
@parametrize_dataset_mocks(DATASET_MOCKS["qmnist"])
class TestQMNIST:
def test_extra_label(self, dataset_mock, config):
dataset, _ = dataset_mock.load(config)
sample = next_consume(iter(dataset))
for key, type in (
("nist_hsf_series", int),
("nist_writer_id", int),
("digit_index", int),
("nist_label", int),
("global_digit_index", int),
("duplicate", bool),
("unused", bool),
):
assert key in sample and isinstance(sample[key], type)
@parametrize_dataset_mocks(DATASET_MOCKS["gtsrb"])
class TestGTSRB:
def test_label_matches_path(self, dataset_mock, config):
# We read the labels from the csv files instead. But for the trainset, the labels are also part of the path.
# This test makes sure that they're both the same
if config["split"] != "train":
return
dataset, _ = dataset_mock.load(config)
for sample in dataset:
label_from_path = int(Path(sample["path"]).parent.name)
assert sample["label"] == label_from_path
@parametrize_dataset_mocks(DATASET_MOCKS["usps"])
class TestUSPS:
def test_sample_content(self, dataset_mock, config):
dataset, _ = dataset_mock.load(config)
for sample in dataset:
assert "image" in sample
assert "label" in sample
assert isinstance(sample["image"], tv_tensors.Image)
assert isinstance(sample["label"], Label)
assert sample["image"].shape == (1, 16, 16)
import gzip
import pathlib
import sys
import numpy as np
import pytest
import torch
from datasets_utils import make_fake_flo_file, make_tar
from torchdata.datapipes.iter import FileOpener, TarArchiveLoader
from torchvision.datasets._optical_flow import _read_flo as read_flo_ref
from torchvision.datasets.utils import _decompress
from torchvision.prototype.datasets.utils import Dataset, GDriveResource, HttpResource, OnlineResource
from torchvision.prototype.datasets.utils._internal import fromfile, read_flo
@pytest.mark.filterwarnings("error:The given NumPy array is not writeable:UserWarning")
@pytest.mark.parametrize(
("np_dtype", "torch_dtype", "byte_order"),
[
(">f4", torch.float32, "big"),
("<f8", torch.float64, "little"),
("<i4", torch.int32, "little"),
(">i8", torch.int64, "big"),
("|u1", torch.uint8, sys.byteorder),
],
)
@pytest.mark.parametrize("count", (-1, 2))
@pytest.mark.parametrize("mode", ("rb", "r+b"))
def test_fromfile(tmpdir, np_dtype, torch_dtype, byte_order, count, mode):
path = tmpdir / "data.bin"
rng = np.random.RandomState(0)
rng.randn(5 if count == -1 else count + 1).astype(np_dtype).tofile(path)
for count_ in (-1, count // 2):
expected = torch.from_numpy(np.fromfile(path, dtype=np_dtype, count=count_).astype(np_dtype[1:]))
with open(path, mode) as file:
actual = fromfile(file, dtype=torch_dtype, byte_order=byte_order, count=count_)
torch.testing.assert_close(actual, expected)
def test_read_flo(tmpdir):
path = tmpdir / "test.flo"
make_fake_flo_file(3, 4, path)
with open(path, "rb") as file:
actual = read_flo(file)
expected = torch.from_numpy(read_flo_ref(path).astype("f4", copy=False))
torch.testing.assert_close(actual, expected)
class TestOnlineResource:
class DummyResource(OnlineResource):
def __init__(self, download_fn=None, **kwargs):
super().__init__(**kwargs)
self._download_fn = download_fn
def _download(self, root):
if self._download_fn is None:
raise pytest.UsageError(
"`_download()` was called, but `DummyResource(...)` was constructed without `download_fn`."
)
return self._download_fn(self, root)
def _make_file(self, root, *, content, name="file.txt"):
file = root / name
with open(file, "w") as fh:
fh.write(content)
return file
def _make_folder(self, root, *, name="folder"):
folder = root / name
subfolder = folder / "subfolder"
subfolder.mkdir(parents=True)
files = {}
for idx, root in enumerate([folder, folder, subfolder]):
content = f"sentinel{idx}"
file = self._make_file(root, name=f"file{idx}.txt", content=content)
files[str(file)] = content
return folder, files
def _make_tar(self, root, *, name="archive.tar", remove=True):
folder, files = self._make_folder(root, name=name.split(".")[0])
archive = make_tar(root, name, folder, remove=remove)
files = {str(archive / pathlib.Path(file).relative_to(root)): content for file, content in files.items()}
return archive, files
def test_load_file(self, tmp_path):
content = "sentinel"
file = self._make_file(tmp_path, content=content)
resource = self.DummyResource(file_name=file.name)
dp = resource.load(tmp_path)
assert isinstance(dp, FileOpener)
data = list(dp)
assert len(data) == 1
path, buffer = data[0]
assert path == str(file)
assert buffer.read().decode() == content
def test_load_folder(self, tmp_path):
folder, files = self._make_folder(tmp_path)
resource = self.DummyResource(file_name=folder.name)
dp = resource.load(tmp_path)
assert isinstance(dp, FileOpener)
assert {path: buffer.read().decode() for path, buffer in dp} == files
def test_load_archive(self, tmp_path):
archive, files = self._make_tar(tmp_path)
resource = self.DummyResource(file_name=archive.name)
dp = resource.load(tmp_path)
assert isinstance(dp, TarArchiveLoader)
assert {path: buffer.read().decode() for path, buffer in dp} == files
def test_priority_decompressed_gt_raw(self, tmp_path):
# We don't need to actually compress here. Adding the suffix is sufficient
self._make_file(tmp_path, content="raw_sentinel", name="file.txt.gz")
file = self._make_file(tmp_path, content="decompressed_sentinel", name="file.txt")
resource = self.DummyResource(file_name=file.name)
dp = resource.load(tmp_path)
path, buffer = next(iter(dp))
assert path == str(file)
assert buffer.read().decode() == "decompressed_sentinel"
def test_priority_extracted_gt_decompressed(self, tmp_path):
archive, _ = self._make_tar(tmp_path, remove=False)
resource = self.DummyResource(file_name=archive.name)
dp = resource.load(tmp_path)
# If the archive had been selected, this would be a `TarArchiveReader`
assert isinstance(dp, FileOpener)
def test_download(self, tmp_path):
download_fn_was_called = False
def download_fn(resource, root):
nonlocal download_fn_was_called
download_fn_was_called = True
return self._make_file(root, content="_", name=resource.file_name)
resource = self.DummyResource(
file_name="file.txt",
download_fn=download_fn,
)
resource.load(tmp_path)
assert download_fn_was_called, "`download_fn()` was never called"
# This tests the `"decompress"` literal as well as a custom callable
@pytest.mark.parametrize(
"preprocess",
[
"decompress",
lambda path: _decompress(str(path), remove_finished=True),
],
)
def test_preprocess_decompress(self, tmp_path, preprocess):
file_name = "file.txt.gz"
content = "sentinel"
def download_fn(resource, root):
file = root / resource.file_name
with gzip.open(file, "wb") as fh:
fh.write(content.encode())
return file
resource = self.DummyResource(file_name=file_name, preprocess=preprocess, download_fn=download_fn)
dp = resource.load(tmp_path)
data = list(dp)
assert len(data) == 1
path, buffer = data[0]
assert path == str(tmp_path / file_name).replace(".gz", "")
assert buffer.read().decode() == content
def test_preprocess_extract(self, tmp_path):
files = None
def download_fn(resource, root):
nonlocal files
archive, files = self._make_tar(root, name=resource.file_name)
return archive
resource = self.DummyResource(file_name="folder.tar", preprocess="extract", download_fn=download_fn)
dp = resource.load(tmp_path)
assert files is not None, "`download_fn()` was never called"
assert isinstance(dp, FileOpener)
actual = {path: buffer.read().decode() for path, buffer in dp}
expected = {
path.replace(resource.file_name, resource.file_name.split(".")[0]): content
for path, content in files.items()
}
assert actual == expected
def test_preprocess_only_after_download(self, tmp_path):
file = self._make_file(tmp_path, content="_")
def preprocess(path):
raise AssertionError("`preprocess` was called although the file was already present.")
resource = self.DummyResource(
file_name=file.name,
preprocess=preprocess,
)
resource.load(tmp_path)
class TestHttpResource:
def test_resolve_to_http(self, mocker):
file_name = "data.tar"
original_url = f"http://downloads.pytorch.org/{file_name}"
redirected_url = original_url.replace("http", "https")
sha256_sentinel = "sha256_sentinel"
def preprocess_sentinel(path):
return path
original_resource = HttpResource(
original_url,
sha256=sha256_sentinel,
preprocess=preprocess_sentinel,
)
mocker.patch("torchvision.prototype.datasets.utils._resource._get_redirect_url", return_value=redirected_url)
redirected_resource = original_resource.resolve()
assert isinstance(redirected_resource, HttpResource)
assert redirected_resource.url == redirected_url
assert redirected_resource.file_name == file_name
assert redirected_resource.sha256 == sha256_sentinel
assert redirected_resource._preprocess is preprocess_sentinel
def test_resolve_to_gdrive(self, mocker):
file_name = "data.tar"
original_url = f"http://downloads.pytorch.org/{file_name}"
id_sentinel = "id-sentinel"
redirected_url = f"https://drive.google.com/file/d/{id_sentinel}/view"
sha256_sentinel = "sha256_sentinel"
def preprocess_sentinel(path):
return path
original_resource = HttpResource(
original_url,
sha256=sha256_sentinel,
preprocess=preprocess_sentinel,
)
mocker.patch("torchvision.prototype.datasets.utils._resource._get_redirect_url", return_value=redirected_url)
redirected_resource = original_resource.resolve()
assert isinstance(redirected_resource, GDriveResource)
assert redirected_resource.id == id_sentinel
assert redirected_resource.file_name == file_name
assert redirected_resource.sha256 == sha256_sentinel
assert redirected_resource._preprocess is preprocess_sentinel
def test_missing_dependency_error():
class DummyDataset(Dataset):
def __init__(self):
super().__init__(root="root", dependencies=("fake_dependency",))
def _resources(self):
pass
def _datapipe(self, resource_dps):
pass
def __len__(self):
pass
with pytest.raises(ModuleNotFoundError, match="depends on the third-party package 'fake_dependency'"):
DummyDataset()
import pytest
import test_models as TM
import torch
from common_utils import cpu_and_cuda, set_rng_seed
from torchvision.prototype import models
@pytest.mark.parametrize("model_fn", (models.depth.stereo.raft_stereo_base,))
@pytest.mark.parametrize("model_mode", ("standard", "scripted"))
@pytest.mark.parametrize("dev", cpu_and_cuda())
def test_raft_stereo(model_fn, model_mode, dev):
# A simple test to make sure the model can do forward pass and jit scriptable
set_rng_seed(0)
# Use corr_pyramid and corr_block with smaller num_levels and radius to prevent nan output
# get the idea from test_models.test_raft
corr_pyramid = models.depth.stereo.raft_stereo.CorrPyramid1d(num_levels=2)
corr_block = models.depth.stereo.raft_stereo.CorrBlock1d(num_levels=2, radius=2)
model = model_fn(corr_pyramid=corr_pyramid, corr_block=corr_block).eval().to(dev)
if model_mode == "scripted":
model = torch.jit.script(model)
img1 = torch.rand(1, 3, 64, 64).to(dev)
img2 = torch.rand(1, 3, 64, 64).to(dev)
num_iters = 3
preds = model(img1, img2, num_iters=num_iters)
depth_pred = preds[-1]
assert len(preds) == num_iters, "Number of predictions should be the same as model.num_iters"
assert depth_pred.shape == torch.Size(
[1, 1, 64, 64]
), f"The output shape of depth_pred should be [1, 1, 64, 64] but instead it is {preds[0].shape}"
# Test against expected file output
TM._assert_expected(depth_pred, name=model_fn.__name__, atol=1e-2, rtol=1e-2)
@pytest.mark.parametrize("model_fn", (models.depth.stereo.crestereo_base,))
@pytest.mark.parametrize("model_mode", ("standard", "scripted"))
@pytest.mark.parametrize("dev", cpu_and_cuda())
def test_crestereo(model_fn, model_mode, dev):
set_rng_seed(0)
model = model_fn().eval().to(dev)
if model_mode == "scripted":
model = torch.jit.script(model)
img1 = torch.rand(1, 3, 64, 64).to(dev)
img2 = torch.rand(1, 3, 64, 64).to(dev)
iterations = 3
preds = model(img1, img2, flow_init=None, num_iters=iterations)
disparity_pred = preds[-1]
# all the pyramid levels except the highest res make only half the number of iterations
expected_iterations = (iterations // 2) * (len(model.resolutions) - 1)
expected_iterations += iterations
assert (
len(preds) == expected_iterations
), "Number of predictions should be the number of iterations multiplied by the number of pyramid levels"
assert disparity_pred.shape == torch.Size(
[1, 2, 64, 64]
), f"Predicted disparity should have the same spatial shape as the input. Inputs shape {img1.shape[2:]}, Prediction shape {disparity_pred.shape[2:]}"
assert all(
d.shape == torch.Size([1, 2, 64, 64]) for d in preds
), "All predicted disparities are expected to have the same shape"
# test a backward pass with a dummy loss as well
preds = torch.stack(preds, dim=0)
targets = torch.ones_like(preds, requires_grad=False)
loss = torch.nn.functional.mse_loss(preds, targets)
try:
loss.backward()
except Exception as e:
assert False, f"Backward pass failed with an unexpected exception: {e.__class__.__name__} {e}"
TM._assert_expected(disparity_pred, name=model_fn.__name__, atol=1e-2, rtol=1e-2)
import re
import PIL.Image
import pytest
import torch
from common_utils import assert_equal
from prototype_common_utils import make_label
from torchvision.prototype import transforms, tv_tensors
from torchvision.transforms.v2._utils import check_type, is_pure_tensor
from torchvision.transforms.v2.functional import clamp_bounding_boxes, InterpolationMode, pil_to_tensor, to_pil_image
from torchvision.tv_tensors import BoundingBoxes, BoundingBoxFormat, Image, Mask, Video
from transforms_v2_legacy_utils import (
DEFAULT_EXTRA_DIMS,
make_bounding_boxes,
make_detection_mask,
make_image,
make_video,
)
BATCH_EXTRA_DIMS = [extra_dims for extra_dims in DEFAULT_EXTRA_DIMS if extra_dims]
def parametrize(transforms_with_inputs):
return pytest.mark.parametrize(
("transform", "input"),
[
pytest.param(
transform,
input,
id=f"{type(transform).__name__}-{type(input).__module__}.{type(input).__name__}-{idx}",
)
for transform, inputs in transforms_with_inputs
for idx, input in enumerate(inputs)
],
)
class TestSimpleCopyPaste:
def create_fake_image(self, mocker, image_type):
if image_type == PIL.Image.Image:
return PIL.Image.new("RGB", (32, 32), 123)
return mocker.MagicMock(spec=image_type)
def test__extract_image_targets_assertion(self, mocker):
transform = transforms.SimpleCopyPaste()
flat_sample = [
# images, batch size = 2
self.create_fake_image(mocker, Image),
# labels, bboxes, masks
mocker.MagicMock(spec=tv_tensors.Label),
mocker.MagicMock(spec=BoundingBoxes),
mocker.MagicMock(spec=Mask),
# labels, bboxes, masks
mocker.MagicMock(spec=BoundingBoxes),
mocker.MagicMock(spec=Mask),
]
with pytest.raises(TypeError, match="requires input sample to contain equal sized list of Images"):
transform._extract_image_targets(flat_sample)
@pytest.mark.parametrize("image_type", [Image, PIL.Image.Image, torch.Tensor])
@pytest.mark.parametrize("label_type", [tv_tensors.Label, tv_tensors.OneHotLabel])
def test__extract_image_targets(self, image_type, label_type, mocker):
transform = transforms.SimpleCopyPaste()
flat_sample = [
# images, batch size = 2
self.create_fake_image(mocker, image_type),
self.create_fake_image(mocker, image_type),
# labels, bboxes, masks
mocker.MagicMock(spec=label_type),
mocker.MagicMock(spec=BoundingBoxes),
mocker.MagicMock(spec=Mask),
# labels, bboxes, masks
mocker.MagicMock(spec=label_type),
mocker.MagicMock(spec=BoundingBoxes),
mocker.MagicMock(spec=Mask),
]
images, targets = transform._extract_image_targets(flat_sample)
assert len(images) == len(targets) == 2
if image_type == PIL.Image.Image:
torch.testing.assert_close(images[0], pil_to_tensor(flat_sample[0]))
torch.testing.assert_close(images[1], pil_to_tensor(flat_sample[1]))
else:
assert images[0] == flat_sample[0]
assert images[1] == flat_sample[1]
for target in targets:
for key, type_ in [
("boxes", BoundingBoxes),
("masks", Mask),
("labels", label_type),
]:
assert key in target
assert isinstance(target[key], type_)
assert target[key] in flat_sample
@pytest.mark.parametrize("label_type", [tv_tensors.Label, tv_tensors.OneHotLabel])
def test__copy_paste(self, label_type):
image = 2 * torch.ones(3, 32, 32)
masks = torch.zeros(2, 32, 32)
masks[0, 3:9, 2:8] = 1
masks[1, 20:30, 20:30] = 1
labels = torch.tensor([1, 2])
blending = True
resize_interpolation = InterpolationMode.BILINEAR
antialias = None
if label_type == tv_tensors.OneHotLabel:
labels = torch.nn.functional.one_hot(labels, num_classes=5)
target = {
"boxes": BoundingBoxes(
torch.tensor([[2.0, 3.0, 8.0, 9.0], [20.0, 20.0, 30.0, 30.0]]), format="XYXY", canvas_size=(32, 32)
),
"masks": Mask(masks),
"labels": label_type(labels),
}
paste_image = 10 * torch.ones(3, 32, 32)
paste_masks = torch.zeros(2, 32, 32)
paste_masks[0, 13:19, 12:18] = 1
paste_masks[1, 15:19, 1:8] = 1
paste_labels = torch.tensor([3, 4])
if label_type == tv_tensors.OneHotLabel:
paste_labels = torch.nn.functional.one_hot(paste_labels, num_classes=5)
paste_target = {
"boxes": BoundingBoxes(
torch.tensor([[12.0, 13.0, 19.0, 18.0], [1.0, 15.0, 8.0, 19.0]]), format="XYXY", canvas_size=(32, 32)
),
"masks": Mask(paste_masks),
"labels": label_type(paste_labels),
}
transform = transforms.SimpleCopyPaste()
random_selection = torch.tensor([0, 1])
output_image, output_target = transform._copy_paste(
image, target, paste_image, paste_target, random_selection, blending, resize_interpolation, antialias
)
assert output_image.unique().tolist() == [2, 10]
assert output_target["boxes"].shape == (4, 4)
torch.testing.assert_close(output_target["boxes"][:2, :], target["boxes"])
torch.testing.assert_close(output_target["boxes"][2:, :], paste_target["boxes"])
expected_labels = torch.tensor([1, 2, 3, 4])
if label_type == tv_tensors.OneHotLabel:
expected_labels = torch.nn.functional.one_hot(expected_labels, num_classes=5)
torch.testing.assert_close(output_target["labels"], label_type(expected_labels))
assert output_target["masks"].shape == (4, 32, 32)
torch.testing.assert_close(output_target["masks"][:2, :], target["masks"])
torch.testing.assert_close(output_target["masks"][2:, :], paste_target["masks"])
class TestFixedSizeCrop:
def test__get_params(self, mocker):
crop_size = (7, 7)
batch_shape = (10,)
canvas_size = (11, 5)
transform = transforms.FixedSizeCrop(size=crop_size)
flat_inputs = [
make_image(size=canvas_size, color_space="RGB"),
make_bounding_boxes(format=BoundingBoxFormat.XYXY, canvas_size=canvas_size, batch_dims=batch_shape),
]
params = transform._get_params(flat_inputs)
assert params["needs_crop"]
assert params["height"] <= crop_size[0]
assert params["width"] <= crop_size[1]
assert (
isinstance(params["is_valid"], torch.Tensor)
and params["is_valid"].dtype is torch.bool
and params["is_valid"].shape == batch_shape
)
assert params["needs_pad"]
assert any(pad > 0 for pad in params["padding"])
def test__transform_culling(self, mocker):
batch_size = 10
canvas_size = (10, 10)
is_valid = torch.randint(0, 2, (batch_size,), dtype=torch.bool)
mocker.patch(
"torchvision.prototype.transforms._geometry.FixedSizeCrop._get_params",
return_value=dict(
needs_crop=True,
top=0,
left=0,
height=canvas_size[0],
width=canvas_size[1],
is_valid=is_valid,
needs_pad=False,
),
)
bounding_boxes = make_bounding_boxes(
format=BoundingBoxFormat.XYXY, canvas_size=canvas_size, batch_dims=(batch_size,)
)
masks = make_detection_mask(size=canvas_size, batch_dims=(batch_size,))
labels = make_label(extra_dims=(batch_size,))
transform = transforms.FixedSizeCrop((-1, -1))
mocker.patch("torchvision.prototype.transforms._geometry.has_any", return_value=True)
output = transform(
dict(
bounding_boxes=bounding_boxes,
masks=masks,
labels=labels,
)
)
assert_equal(output["bounding_boxes"], bounding_boxes[is_valid])
assert_equal(output["masks"], masks[is_valid])
assert_equal(output["labels"], labels[is_valid])
def test__transform_bounding_boxes_clamping(self, mocker):
batch_size = 3
canvas_size = (10, 10)
mocker.patch(
"torchvision.prototype.transforms._geometry.FixedSizeCrop._get_params",
return_value=dict(
needs_crop=True,
top=0,
left=0,
height=canvas_size[0],
width=canvas_size[1],
is_valid=torch.full((batch_size,), fill_value=True),
needs_pad=False,
),
)
bounding_boxes = make_bounding_boxes(
format=BoundingBoxFormat.XYXY, canvas_size=canvas_size, batch_dims=(batch_size,)
)
mock = mocker.patch(
"torchvision.prototype.transforms._geometry.F.clamp_bounding_boxes", wraps=clamp_bounding_boxes
)
transform = transforms.FixedSizeCrop((-1, -1))
mocker.patch("torchvision.prototype.transforms._geometry.has_any", return_value=True)
transform(bounding_boxes)
mock.assert_called_once()
class TestLabelToOneHot:
def test__transform(self):
categories = ["apple", "pear", "pineapple"]
labels = tv_tensors.Label(torch.tensor([0, 1, 2, 1]), categories=categories)
transform = transforms.LabelToOneHot()
ohe_labels = transform(labels)
assert isinstance(ohe_labels, tv_tensors.OneHotLabel)
assert ohe_labels.shape == (4, 3)
assert ohe_labels.categories == labels.categories == categories
class TestPermuteDimensions:
@pytest.mark.parametrize(
("dims", "inverse_dims"),
[
(
{Image: (2, 1, 0), Video: None},
{Image: (2, 1, 0), Video: None},
),
(
{Image: (2, 1, 0), Video: (1, 2, 3, 0)},
{Image: (2, 1, 0), Video: (3, 0, 1, 2)},
),
],
)
def test_call(self, dims, inverse_dims):
sample = dict(
image=make_image(),
bounding_boxes=make_bounding_boxes(format=BoundingBoxFormat.XYXY),
video=make_video(),
str="str",
int=0,
)
transform = transforms.PermuteDimensions(dims)
transformed_sample = transform(sample)
for key, value in sample.items():
value_type = type(value)
transformed_value = transformed_sample[key]
if check_type(value, (Image, is_pure_tensor, Video)):
if transform.dims.get(value_type) is not None:
assert transformed_value.permute(inverse_dims[value_type]).equal(value)
assert type(transformed_value) == torch.Tensor
else:
assert transformed_value is value
@pytest.mark.filterwarnings("error")
def test_plain_tensor_call(self):
tensor = torch.empty((2, 3, 4))
transform = transforms.PermuteDimensions(dims=(1, 2, 0))
assert transform(tensor).shape == (3, 4, 2)
@pytest.mark.parametrize("other_type", [Image, Video])
def test_plain_tensor_warning(self, other_type):
with pytest.warns(UserWarning, match=re.escape("`torch.Tensor` will *not* be transformed")):
transforms.PermuteDimensions(dims={torch.Tensor: (0, 1), other_type: (1, 0)})
class TestTransposeDimensions:
@pytest.mark.parametrize(
"dims",
[
(-1, -2),
{Image: (1, 2), Video: None},
],
)
def test_call(self, dims):
sample = dict(
image=make_image(),
bounding_boxes=make_bounding_boxes(format=BoundingBoxFormat.XYXY),
video=make_video(),
str="str",
int=0,
)
transform = transforms.TransposeDimensions(dims)
transformed_sample = transform(sample)
for key, value in sample.items():
value_type = type(value)
transformed_value = transformed_sample[key]
transposed_dims = transform.dims.get(value_type)
if check_type(value, (Image, is_pure_tensor, Video)):
if transposed_dims is not None:
assert transformed_value.transpose(*transposed_dims).equal(value)
assert type(transformed_value) == torch.Tensor
else:
assert transformed_value is value
@pytest.mark.filterwarnings("error")
def test_plain_tensor_call(self):
tensor = torch.empty((2, 3, 4))
transform = transforms.TransposeDimensions(dims=(0, 2))
assert transform(tensor).shape == (4, 3, 2)
@pytest.mark.parametrize("other_type", [Image, Video])
def test_plain_tensor_warning(self, other_type):
with pytest.warns(UserWarning, match=re.escape("`torch.Tensor` will *not* be transformed")):
transforms.TransposeDimensions(dims={torch.Tensor: (0, 1), other_type: (1, 0)})
import importlib.machinery
import importlib.util
from pathlib import Path
def import_transforms_from_references(reference):
HERE = Path(__file__).parent
PROJECT_ROOT = HERE.parent
loader = importlib.machinery.SourceFileLoader(
"transforms", str(PROJECT_ROOT / "references" / reference / "transforms.py")
)
spec = importlib.util.spec_from_loader("transforms", loader)
module = importlib.util.module_from_spec(spec)
loader.exec_module(module)
return module
det_transforms = import_transforms_from_references("detection")
def test_fixed_sized_crop_against_detection_reference():
def make_tv_tensors():
size = (600, 800)
num_objects = 22
pil_image = to_pil_image(make_image(size=size, color_space="RGB"))
target = {
"boxes": make_bounding_boxes(canvas_size=size, format="XYXY", batch_dims=(num_objects,), dtype=torch.float),
"labels": make_label(extra_dims=(num_objects,), categories=80),
"masks": make_detection_mask(size=size, num_objects=num_objects, dtype=torch.long),
}
yield (pil_image, target)
tensor_image = torch.Tensor(make_image(size=size, color_space="RGB"))
target = {
"boxes": make_bounding_boxes(canvas_size=size, format="XYXY", batch_dims=(num_objects,), dtype=torch.float),
"labels": make_label(extra_dims=(num_objects,), categories=80),
"masks": make_detection_mask(size=size, num_objects=num_objects, dtype=torch.long),
}
yield (tensor_image, target)
tv_tensor_image = make_image(size=size, color_space="RGB")
target = {
"boxes": make_bounding_boxes(canvas_size=size, format="XYXY", batch_dims=(num_objects,), dtype=torch.float),
"labels": make_label(extra_dims=(num_objects,), categories=80),
"masks": make_detection_mask(size=size, num_objects=num_objects, dtype=torch.long),
}
yield (tv_tensor_image, target)
t = transforms.FixedSizeCrop((1024, 1024), fill=0)
t_ref = det_transforms.FixedSizeCrop((1024, 1024), fill=0)
for dp in make_tv_tensors():
# We should use prototype transform first as reference transform performs inplace target update
torch.manual_seed(12)
output = t(dp)
torch.manual_seed(12)
expected_output = t_ref(*dp)
assert_equal(expected_output, output)
from . import models, transforms, tv_tensors, utils
# Status of prototype datasets
These prototype datasets are based on [torchdata](https://github.com/pytorch/data)'s datapipes. Torchdata
development [is
paused](https://github.com/pytorch/data/#torchdata-see-note-below-on-current-status)
as of July 2023, so we are not actively maintaining this module. There is no
estimated date for a stable release of these datasets.
try:
import torchdata
except ModuleNotFoundError:
raise ModuleNotFoundError(
"`torchvision.prototype.datasets` depends on PyTorch's `torchdata` (https://github.com/pytorch/data). "
"You can install it with `pip install --pre torchdata --extra-index-url https://download.pytorch.org/whl/nightly/cpu"
) from None
from . import utils
from ._home import home
# Load this last, since some parts depend on the above being loaded first
from ._api import list_datasets, info, load, register_info, register_dataset # usort: skip
from ._folder import from_data_folder, from_image_folder
from ._builtin import *
import pathlib
from typing import Any, Callable, Dict, List, Optional, Type, TypeVar, Union
from torchvision.prototype.datasets import home
from torchvision.prototype.datasets.utils import Dataset
from torchvision.prototype.utils._internal import add_suggestion
T = TypeVar("T")
D = TypeVar("D", bound=Type[Dataset])
BUILTIN_INFOS: Dict[str, Dict[str, Any]] = {}
def register_info(name: str) -> Callable[[Callable[[], Dict[str, Any]]], Callable[[], Dict[str, Any]]]:
def wrapper(fn: Callable[[], Dict[str, Any]]) -> Callable[[], Dict[str, Any]]:
BUILTIN_INFOS[name] = fn()
return fn
return wrapper
BUILTIN_DATASETS = {}
def register_dataset(name: str) -> Callable[[D], D]:
def wrapper(dataset_cls: D) -> D:
BUILTIN_DATASETS[name] = dataset_cls
return dataset_cls
return wrapper
def list_datasets() -> List[str]:
return sorted(BUILTIN_DATASETS.keys())
def find(dct: Dict[str, T], name: str) -> T:
name = name.lower()
try:
return dct[name]
except KeyError as error:
raise ValueError(
add_suggestion(
f"Unknown dataset '{name}'.",
word=name,
possibilities=dct.keys(),
alternative_hint=lambda _: (
"You can use torchvision.datasets.list_datasets() to get a list of all available datasets."
),
)
) from error
def info(name: str) -> Dict[str, Any]:
return find(BUILTIN_INFOS, name)
def load(name: str, *, root: Optional[Union[str, pathlib.Path]] = None, **config: Any) -> Dataset:
dataset_cls = find(BUILTIN_DATASETS, name)
if root is None:
root = pathlib.Path(home()) / name
return dataset_cls(root, **config)
# How to add new built-in prototype datasets
As the name implies, the datasets are still in a prototype state and thus subject to rapid change. This in turn means
that this document will also change a lot.
If you hit a blocker while adding a dataset, please have a look at another similar dataset to see how it is implemented
there. If you can't resolve it yourself, feel free to send a draft PR in order for us to help you out.
Finally, `from torchvision.prototype import datasets` is implied below.
## Implementation
Before we start with the actual implementation, you should create a module in `torchvision/prototype/datasets/_builtin`
that hints at the dataset you are going to add. For example `caltech.py` for `caltech101` and `caltech256`. In that
module create a class that inherits from `datasets.utils.Dataset` and overwrites four methods that will be discussed in
detail below:
```python
import pathlib
from typing import Any, BinaryIO, Dict, List, Tuple, Union
from torchdata.datapipes.iter import IterDataPipe
from torchvision.prototype.datasets.utils import Dataset, OnlineResource
from .._api import register_dataset, register_info
NAME = "my-dataset"
@register_info(NAME)
def _info() -> Dict[str, Any]:
return dict(
...
)
@register_dataset(NAME)
class MyDataset(Dataset):
def __init__(self, root: Union[str, pathlib.Path], *, ..., skip_integrity_check: bool = False) -> None:
...
super().__init__(root, skip_integrity_check=skip_integrity_check)
def _resources(self) -> List[OnlineResource]:
...
def _datapipe(self, resource_dps: List[IterDataPipe[Tuple[str, BinaryIO]]]) -> IterDataPipe[Dict[str, Any]]:
...
def __len__(self) -> int:
...
```
In addition to the dataset, you also need to implement an `_info()` function that takes no arguments and returns a
dictionary of static information. The most common use case is to provide human-readable categories.
[See below](#how-do-i-handle-a-dataset-that-defines-many-categories) how to handle cases with many categories.
Finally, both the dataset class and the info function need to be registered on the API with the respective decorators.
With that they are loadable through `datasets.load("my-dataset")` and `datasets.info("my-dataset")`, respectively.
### `__init__(self, root, *, ..., skip_integrity_check = False)`
Constructor of the dataset that will be called when the dataset is instantiated. In addition to the parameters of the
base class, it can take arbitrary keyword-only parameters with defaults. The checking of these parameters as well as
setting them as instance attributes has to happen before the call of `super().__init__(...)`, because that will invoke
the other methods, which possibly depend on the parameters. All instance attributes must be private, i.e. prefixed with
an underscore.
If the implementation of the dataset depends on third-party packages, pass them as a collection of strings to the base
class constructor, e.g. `super().__init__(..., dependencies=("scipy",))`. Their availability will be automatically
checked if a user tries to load the dataset. Within the implementation of the dataset, import these packages lazily to
avoid missing dependencies at import time.
### `_resources(self)`
Returns `List[datasets.utils.OnlineResource]` of all the files that need to be present locally before the dataset can be
build. The download will happen automatically.
Currently, the following `OnlineResource`'s are supported:
- `HttpResource`: Used for files that are directly exposed through HTTP(s) and only requires the URL.
- `GDriveResource`: Used for files that are hosted on GDrive and requires the GDrive ID as well as the `file_name`.
- `ManualDownloadResource`: Used files are not publicly accessible and requires instructions how to download them
manually. If the file does not exist, an error will be raised with the supplied instructions.
- `KaggleDownloadResource`: Used for files that are available on Kaggle. This inherits from `ManualDownloadResource`.
Although optional in general, all resources used in the built-in datasets should comprise
[SHA256](https://en.wikipedia.org/wiki/SHA-2) checksum for security. It will be automatically checked after the
download. You can compute the checksum with system utilities e.g `sha256-sum`, or this snippet:
```python
import hashlib
def sha256sum(path, chunk_size=1024 * 1024):
checksum = hashlib.sha256()
with open(path, "rb") as f:
while chunk := f.read(chunk_size):
checksum.update(chunk)
print(checksum.hexdigest())
```
### `_datapipe(self, resource_dps)`
This method is the heart of the dataset, where we transform the raw data into a usable form. A major difference compared
to the current stable datasets is that everything is performed through `IterDataPipe`'s. From the perspective of someone
that is working with them rather than on them, `IterDataPipe`'s behave just as generators, i.e. you can't do anything
with them besides iterating.
Of course, there are some common building blocks that should suffice in 95% of the cases. The most used are:
- `Mapper`: Apply a callable to every item in the datapipe.
- `Filter`: Keep only items that satisfy a condition.
- `Demultiplexer`: Split a datapipe into multiple ones.
- `IterKeyZipper`: Merge two datapipes into one.
All of them can be imported `from torchdata.datapipes.iter`. In addition, use `functools.partial` in case a callable
needs extra arguments. If the provided `IterDataPipe`'s are not sufficient for the use case, it is also not complicated
to add one. See the MNIST or CelebA datasets for example.
`_datapipe()` receives `resource_dps`, which is a list of datapipes that has a 1-to-1 correspondence with the return
value of `_resources()`. In case of archives with regular suffixes (`.tar`, `.zip`, ...), the datapipe will contain
tuples comprised of the path and the handle for every file in the archive. Otherwise, the datapipe will only contain one
of such tuples for the file specified by the resource.
Since the datapipes are iterable in nature, some datapipes feature an in-memory buffer, e.g. `IterKeyZipper` and
`Grouper`. There are two issues with that:
1. If not used carefully, this can easily overflow the host memory, since most datasets will not fit in completely.
2. This can lead to unnecessarily long warm-up times when data is buffered that is only needed at runtime.
Thus, all buffered datapipes should be used as early as possible, e.g. zipping two datapipes of file handles rather than
trying to zip already loaded images.
There are two special datapipes that are not used through their class, but through the functions `hint_shuffling` and
`hint_sharding`. As the name implies they only hint at a location in the datapipe graph where shuffling and sharding
should take place, but are no-ops by default. They can be imported from `torchvision.prototype.datasets.utils._internal`
and are required in each dataset. `hint_shuffling` has to be placed before `hint_sharding`.
Finally, each item in the final datapipe should be a dictionary with `str` keys. There is no standardization of the
names (yet!).
### `__len__`
This returns an integer denoting the number of samples that can be drawn from the dataset. Please use
[underscores](https://peps.python.org/pep-0515/) after every three digits starting from the right to enhance the
readability. For example, `1_281_167` vs. `1281167`.
If there are only two different numbers, a simple `if` / `else` is fine:
```py
def __len__(self):
return 12_345 if self._split == "train" else 6_789
```
If there are more options, using a dictionary usually is the most readable option:
```py
def __len__(self):
return {
"train": 3,
"val": 2,
"test": 1,
}[self._split]
```
If the number of samples depends on more than one parameter, you can use tuples as dictionary keys:
```py
def __len__(self):
return {
("train", "bar"): 4,
("train", "baz"): 3,
("test", "bar"): 2,
("test", "baz"): 1,
}[(self._split, self._foo)]
```
The length of the datapipe is only an annotation for subsequent processing of the datapipe and not needed during the
development process. Since it is an `@abstractmethod` you still have to implement it from the start. The canonical way
is to define a dummy method like
```py
def __len__(self):
return 1
```
and only fill it with the correct data if the implementation is otherwise finished.
[See below](#how-do-i-compute-the-number-of-samples) for a possible way to compute the number of samples.
## Tests
To test the dataset implementation, you usually don't need to add any tests, but need to provide a mock-up of the data.
This mock-up should resemble the original data as close as necessary, while containing only few examples.
To do this, add a new function in [`test/builtin_dataset_mocks.py`](../../../../test/builtin_dataset_mocks.py) with the
same name as you have used in `@register_info` and `@register_dataset`. This function is called "mock data function".
Decorate it with `@register_mock(configs=[dict(...), ...])`. Each dictionary denotes one configuration that the dataset
will be loaded with, e.g. `datasets.load("my-dataset", **config)`. For the most common case of a product of all options,
you can use the `combinations_grid()` helper function, e.g.
`configs=combinations_grid(split=("train", "test"), foo=("bar", "baz"))`.
In case the name of the dataset includes hyphens `-`, replace them with underscores `_` in the function name and pass
the `name` parameter to `@register_mock`
```py
# this is defined in torchvision/prototype/datasets/_builtin
@register_dataset("my-dataset")
class MyDataset(Dataset):
...
@register_mock(name="my-dataset", configs=...)
def my_dataset(root, config):
...
```
The mock data function receives two arguments:
- `root`: A [`pathlib.Path`](https://docs.python.org/3/library/pathlib.html#pathlib.Path) of a folder, in which the data
needs to be placed.
- `config`: The configuration to generate the data for. This is one of the dictionaries defined in
`@register_mock(configs=...)`
The function should generate all files that are needed for the current `config`. Each file should be complete, e.g. if
the dataset only has a single archive that contains multiple splits, you need to generate the full archive regardless of
the current `config`. Although this seems odd at first, this is important. Consider the following original data setup:
```
root
├── test
│ ├── test_image0.jpg
│ ...
└── train
├── train_image0.jpg
...
```
For map-style datasets (like the one currently in `torchvision.datasets`), one explicitly selects the files they want to
load. For example, something like `(root / split).iterdir()` works fine even if only the specific split folder is
present. With iterable-style datasets though, we get something like `root.iterdir()` from `resource_dps` in
`_datapipe()` and need to manually `Filter` it to only keep the files we want. If we would only generate the data for
the current `config`, the test would also pass if the dataset is missing the filtering, but would fail on the real data.
For datasets that are ported from the old API, we already have some mock data in
[`test/test_datasets.py`](../../../../test/test_datasets.py). You can find the test case corresponding test case there
and have a look at the `inject_fake_data` function. There are a few differences though:
- `tmp_dir` corresponds to `root`, but is a `str` rather than a
[`pathlib.Path`](https://docs.python.org/3/library/pathlib.html#pathlib.Path). Thus, you often see something like
`folder = pathlib.Path(tmp_dir)`. This is not needed.
- The data generated by `inject_fake_data` was supposed to be in an extracted state. This is no longer the case for the
new mock-ups. Thus, you need to use helper functions like `make_zip` or `make_tar` to actually generate the files
specified in the dataset.
- As explained in the paragraph above, the generated data is often "incomplete" and only valid for given the config.
Make sure you follow the instructions above.
The function should return an integer indicating the number of samples in the dataset for the current `config`.
Preferably, this number should be different for different `config`'s to have more confidence in the dataset
implementation.
Finally, you can run the tests with `pytest test/test_prototype_builtin_datasets.py -k {name}`.
## FAQ
### How do I start?
Get the skeleton of your dataset class ready with all 4 methods. For `_datapipe()`, you can just do
`return resources_dp[0]` to get started. Then import the dataset class in
`torchvision/prototype/datasets/_builtin/__init__.py`: this will automatically register the dataset, and it will be
instantiable via `datasets.load("mydataset")`. On a separate script, try something like
```py
from torchvision.prototype import datasets
dataset = datasets.load("mydataset")
for sample in dataset:
print(sample) # this is the content of an item in datapipe returned by _datapipe()
break
# Or you can also inspect the sample in a debugger
```
This will give you an idea of what the first datapipe in `resources_dp` contains. You can also do that with
`resources_dp[1]` or `resources_dp[2]` (etc.) if they exist. Then follow the instructions above to manipulate these
datapipes and return the appropriate dictionary format.
### How do I handle a dataset that defines many categories?
As a rule of thumb, `categories` in the info dictionary should only be set manually for ten categories or fewer. If more
categories are needed, you can add a `$NAME.categories` file to the `_builtin` folder in which each line specifies a
category. To load such a file, use the `from torchvision.prototype.datasets.utils._internal import read_categories_file`
function and pass it `$NAME`.
In case the categories can be generated from the dataset files, e.g. the dataset follows an image folder approach where
each folder denotes the name of the category, the dataset can overwrite the `_generate_categories` method. The method
should return a sequence of strings representing the category names. In the method body, you'll have to manually load
the resources, e.g.
```py
resources = self._resources()
dp = resources[0].load(self._root)
```
Note that it is not necessary here to keep a datapipe until the final step. Stick with datapipes as long as it makes
sense and afterwards materialize the data with `next(iter(dp))` or `list(dp)` and proceed with that.
To generate the `$NAME.categories` file, run `python -m torchvision.prototype.datasets.generate_category_files $NAME`.
### What if a resource file forms an I/O bottleneck?
In general, we are ok with small performance hits of iterating archives rather than their extracted content. However, if
the performance hit becomes significant, the archives can still be preprocessed. `OnlineResource` accepts the
`preprocess` parameter that can be a `Callable[[pathlib.Path], pathlib.Path]` where the input points to the file to be
preprocessed and the return value should be the result of the preprocessing to load. For convenience, `preprocess` also
accepts `"decompress"` and `"extract"` to handle these common scenarios.
### How do I compute the number of samples?
Unless the authors of the dataset published the exact numbers (even in this case we should check), there is no other way
than to iterate over the dataset and count the number of samples:
```py
import itertools
from torchvision.prototype import datasets
def combinations_grid(**kwargs):
return [dict(zip(kwargs.keys(), values)) for values in itertools.product(*kwargs.values())]
# If you have implemented the mock data function for the dataset tests, you can simply copy-paste from there
configs = combinations_grid(split=("train", "test"), foo=("bar", "baz"))
for config in configs:
dataset = datasets.load("my-dataset", **config)
num_samples = 0
for _ in dataset:
num_samples += 1
print(", ".join(f"{key}={value}" for key, value in config.items()), num_samples)
```
To speed this up, it is useful to temporarily comment out all unnecessary I/O, such as loading of images or annotation
files.
from .caltech import Caltech101, Caltech256
from .celeba import CelebA
from .cifar import Cifar10, Cifar100
from .clevr import CLEVR
from .coco import Coco
from .country211 import Country211
from .cub200 import CUB200
from .dtd import DTD
from .eurosat import EuroSAT
from .fer2013 import FER2013
from .food101 import Food101
from .gtsrb import GTSRB
from .imagenet import ImageNet
from .mnist import EMNIST, FashionMNIST, KMNIST, MNIST, QMNIST
from .oxford_iiit_pet import OxfordIIITPet
from .pcam import PCAM
from .sbd import SBD
from .semeion import SEMEION
from .stanford_cars import StanfordCars
from .svhn import SVHN
from .usps import USPS
from .voc import VOC
import pathlib
import re
from typing import Any, BinaryIO, Dict, List, Tuple, Union
import numpy as np
import torch
from torchdata.datapipes.iter import Filter, IterDataPipe, IterKeyZipper, Mapper
from torchvision.prototype.datasets.utils import Dataset, EncodedImage, GDriveResource, OnlineResource
from torchvision.prototype.datasets.utils._internal import (
hint_sharding,
hint_shuffling,
INFINITE_BUFFER_SIZE,
read_categories_file,
read_mat,
)
from torchvision.prototype.tv_tensors import Label
from torchvision.tv_tensors import BoundingBoxes
from .._api import register_dataset, register_info
@register_info("caltech101")
def _caltech101_info() -> Dict[str, Any]:
return dict(categories=read_categories_file("caltech101"))
@register_dataset("caltech101")
class Caltech101(Dataset):
"""
- **homepage**: https://data.caltech.edu/records/20086
- **dependencies**:
- <scipy `https://scipy.org/`>_
"""
def __init__(
self,
root: Union[str, pathlib.Path],
skip_integrity_check: bool = False,
) -> None:
self._categories = _caltech101_info()["categories"]
super().__init__(
root,
dependencies=("scipy",),
skip_integrity_check=skip_integrity_check,
)
def _resources(self) -> List[OnlineResource]:
images = GDriveResource(
"137RyRjvTBkBiIfeYBNZBtViDHQ6_Ewsp",
file_name="101_ObjectCategories.tar.gz",
sha256="af6ece2f339791ca20f855943d8b55dd60892c0a25105fcd631ee3d6430f9926",
preprocess="decompress",
)
anns = GDriveResource(
"175kQy3UsZ0wUEHZjqkUDdNVssr7bgh_m",
file_name="Annotations.tar",
sha256="1717f4e10aa837b05956e3f4c94456527b143eec0d95e935028b30aff40663d8",
)
return [images, anns]
_IMAGES_NAME_PATTERN = re.compile(r"image_(?P<id>\d+)[.]jpg")
_ANNS_NAME_PATTERN = re.compile(r"annotation_(?P<id>\d+)[.]mat")
_ANNS_CATEGORY_MAP = {
"Faces_2": "Faces",
"Faces_3": "Faces_easy",
"Motorbikes_16": "Motorbikes",
"Airplanes_Side_2": "airplanes",
}
def _is_not_background_image(self, data: Tuple[str, Any]) -> bool:
path = pathlib.Path(data[0])
return path.parent.name != "BACKGROUND_Google"
def _is_ann(self, data: Tuple[str, Any]) -> bool:
path = pathlib.Path(data[0])
return bool(self._ANNS_NAME_PATTERN.match(path.name))
def _images_key_fn(self, data: Tuple[str, Any]) -> Tuple[str, str]:
path = pathlib.Path(data[0])
category = path.parent.name
id = self._IMAGES_NAME_PATTERN.match(path.name).group("id") # type: ignore[union-attr]
return category, id
def _anns_key_fn(self, data: Tuple[str, Any]) -> Tuple[str, str]:
path = pathlib.Path(data[0])
category = path.parent.name
if category in self._ANNS_CATEGORY_MAP:
category = self._ANNS_CATEGORY_MAP[category]
id = self._ANNS_NAME_PATTERN.match(path.name).group("id") # type: ignore[union-attr]
return category, id
def _prepare_sample(
self, data: Tuple[Tuple[str, str], Tuple[Tuple[str, BinaryIO], Tuple[str, BinaryIO]]]
) -> Dict[str, Any]:
key, (image_data, ann_data) = data
category, _ = key
image_path, image_buffer = image_data
ann_path, ann_buffer = ann_data
image = EncodedImage.from_file(image_buffer)
ann = read_mat(ann_buffer)
return dict(
label=Label.from_category(category, categories=self._categories),
image_path=image_path,
image=image,
ann_path=ann_path,
bounding_boxes=BoundingBoxes(
ann["box_coord"].astype(np.int64).squeeze()[[2, 0, 3, 1]],
format="xyxy",
spatial_size=image.spatial_size,
),
contour=torch.as_tensor(ann["obj_contour"].T),
)
def _datapipe(self, resource_dps: List[IterDataPipe]) -> IterDataPipe[Dict[str, Any]]:
images_dp, anns_dp = resource_dps
images_dp = Filter(images_dp, self._is_not_background_image)
images_dp = hint_shuffling(images_dp)
images_dp = hint_sharding(images_dp)
anns_dp = Filter(anns_dp, self._is_ann)
dp = IterKeyZipper(
images_dp,
anns_dp,
key_fn=self._images_key_fn,
ref_key_fn=self._anns_key_fn,
buffer_size=INFINITE_BUFFER_SIZE,
keep_key=True,
)
return Mapper(dp, self._prepare_sample)
def __len__(self) -> int:
return 8677
def _generate_categories(self) -> List[str]:
resources = self._resources()
dp = resources[0].load(self._root)
dp = Filter(dp, self._is_not_background_image)
return sorted({pathlib.Path(path).parent.name for path, _ in dp})
@register_info("caltech256")
def _caltech256_info() -> Dict[str, Any]:
return dict(categories=read_categories_file("caltech256"))
@register_dataset("caltech256")
class Caltech256(Dataset):
"""
- **homepage**: https://data.caltech.edu/records/20087
"""
def __init__(
self,
root: Union[str, pathlib.Path],
skip_integrity_check: bool = False,
) -> None:
self._categories = _caltech256_info()["categories"]
super().__init__(root, skip_integrity_check=skip_integrity_check)
def _resources(self) -> List[OnlineResource]:
return [
GDriveResource(
"1r6o0pSROcV1_VwT4oSjA2FBUSCWGuxLK",
file_name="256_ObjectCategories.tar",
sha256="08ff01b03c65566014ae88eb0490dbe4419fc7ac4de726ee1163e39fd809543e",
)
]
def _is_not_rogue_file(self, data: Tuple[str, Any]) -> bool:
path = pathlib.Path(data[0])
return path.name != "RENAME2"
def _prepare_sample(self, data: Tuple[str, BinaryIO]) -> Dict[str, Any]:
path, buffer = data
return dict(
path=path,
image=EncodedImage.from_file(buffer),
label=Label(int(pathlib.Path(path).parent.name.split(".", 1)[0]) - 1, categories=self._categories),
)
def _datapipe(self, resource_dps: List[IterDataPipe]) -> IterDataPipe[Dict[str, Any]]:
dp = resource_dps[0]
dp = Filter(dp, self._is_not_rogue_file)
dp = hint_shuffling(dp)
dp = hint_sharding(dp)
return Mapper(dp, self._prepare_sample)
def __len__(self) -> int:
return 30607
def _generate_categories(self) -> List[str]:
resources = self._resources()
dp = resources[0].load(self._root)
dir_names = {pathlib.Path(path).parent.name for path, _ in dp}
return [name.split(".")[1] for name in sorted(dir_names)]
Faces
Faces_easy
Leopards
Motorbikes
accordion
airplanes
anchor
ant
barrel
bass
beaver
binocular
bonsai
brain
brontosaurus
buddha
butterfly
camera
cannon
car_side
ceiling_fan
cellphone
chair
chandelier
cougar_body
cougar_face
crab
crayfish
crocodile
crocodile_head
cup
dalmatian
dollar_bill
dolphin
dragonfly
electric_guitar
elephant
emu
euphonium
ewer
ferry
flamingo
flamingo_head
garfield
gerenuk
gramophone
grand_piano
hawksbill
headphone
hedgehog
helicopter
ibis
inline_skate
joshua_tree
kangaroo
ketch
lamp
laptop
llama
lobster
lotus
mandolin
mayfly
menorah
metronome
minaret
nautilus
octopus
okapi
pagoda
panda
pigeon
pizza
platypus
pyramid
revolver
rhino
rooster
saxophone
schooner
scissors
scorpion
sea_horse
snoopy
soccer_ball
stapler
starfish
stegosaurus
stop_sign
strawberry
sunflower
tick
trilobite
umbrella
watch
water_lilly
wheelchair
wild_cat
windsor_chair
wrench
yin_yang
ak47
american-flag
backpack
baseball-bat
baseball-glove
basketball-hoop
bat
bathtub
bear
beer-mug
billiards
binoculars
birdbath
blimp
bonsai-101
boom-box
bowling-ball
bowling-pin
boxing-glove
brain-101
breadmaker
buddha-101
bulldozer
butterfly
cactus
cake
calculator
camel
cannon
canoe
car-tire
cartman
cd
centipede
cereal-box
chandelier-101
chess-board
chimp
chopsticks
cockroach
coffee-mug
coffin
coin
comet
computer-keyboard
computer-monitor
computer-mouse
conch
cormorant
covered-wagon
cowboy-hat
crab-101
desk-globe
diamond-ring
dice
dog
dolphin-101
doorknob
drinking-straw
duck
dumb-bell
eiffel-tower
electric-guitar-101
elephant-101
elk
ewer-101
eyeglasses
fern
fighter-jet
fire-extinguisher
fire-hydrant
fire-truck
fireworks
flashlight
floppy-disk
football-helmet
french-horn
fried-egg
frisbee
frog
frying-pan
galaxy
gas-pump
giraffe
goat
golden-gate-bridge
goldfish
golf-ball
goose
gorilla
grand-piano-101
grapes
grasshopper
guitar-pick
hamburger
hammock
harmonica
harp
harpsichord
hawksbill-101
head-phones
helicopter-101
hibiscus
homer-simpson
horse
horseshoe-crab
hot-air-balloon
hot-dog
hot-tub
hourglass
house-fly
human-skeleton
hummingbird
ibis-101
ice-cream-cone
iguana
ipod
iris
jesus-christ
joy-stick
kangaroo-101
kayak
ketch-101
killer-whale
knife
ladder
laptop-101
lathe
leopards-101
license-plate
lightbulb
light-house
lightning
llama-101
mailbox
mandolin
mars
mattress
megaphone
menorah-101
microscope
microwave
minaret
minotaur
motorbikes-101
mountain-bike
mushroom
mussels
necktie
octopus
ostrich
owl
palm-pilot
palm-tree
paperclip
paper-shredder
pci-card
penguin
people
pez-dispenser
photocopier
picnic-table
playing-card
porcupine
pram
praying-mantis
pyramid
raccoon
radio-telescope
rainbow
refrigerator
revolver-101
rifle
rotary-phone
roulette-wheel
saddle
saturn
school-bus
scorpion-101
screwdriver
segway
self-propelled-lawn-mower
sextant
sheet-music
skateboard
skunk
skyscraper
smokestack
snail
snake
sneaker
snowmobile
soccer-ball
socks
soda-can
spaghetti
speed-boat
spider
spoon
stained-glass
starfish-101
steering-wheel
stirrups
sunflower-101
superman
sushi
swan
swiss-army-knife
sword
syringe
tambourine
teapot
teddy-bear
teepee
telephone-box
tennis-ball
tennis-court
tennis-racket
theodolite
toaster
tomato
tombstone
top-hat
touring-bike
tower-pisa
traffic-light
treadmill
triceratops
tricycle
trilobite-101
tripod
t-shirt
tuning-fork
tweezer
umbrella-101
unicorn
vcr
video-projector
washing-machine
watch-101
waterfall
watermelon
welding-mask
wheelbarrow
windmill
wine-bottle
xylophone
yarmulke
yo-yo
zebra
airplanes-101
car-side-101
faces-easy-101
greyhound
tennis-shoes
toad
clutter
import csv
import pathlib
from typing import Any, BinaryIO, Dict, Iterator, List, Optional, Sequence, Tuple, Union
import torch
from torchdata.datapipes.iter import Filter, IterDataPipe, IterKeyZipper, Mapper, Zipper
from torchvision.prototype.datasets.utils import Dataset, EncodedImage, GDriveResource, OnlineResource
from torchvision.prototype.datasets.utils._internal import (
getitem,
hint_sharding,
hint_shuffling,
INFINITE_BUFFER_SIZE,
path_accessor,
)
from torchvision.prototype.tv_tensors import Label
from torchvision.tv_tensors import BoundingBoxes
from .._api import register_dataset, register_info
csv.register_dialect("celeba", delimiter=" ", skipinitialspace=True)
class CelebACSVParser(IterDataPipe[Tuple[str, Dict[str, str]]]):
def __init__(
self,
datapipe: IterDataPipe[Tuple[Any, BinaryIO]],
*,
fieldnames: Optional[Sequence[str]] = None,
) -> None:
self.datapipe = datapipe
self.fieldnames = fieldnames
def __iter__(self) -> Iterator[Tuple[str, Dict[str, str]]]:
for _, file in self.datapipe:
try:
lines = (line.decode() for line in file)
if self.fieldnames:
fieldnames = self.fieldnames
else:
# The first row is skipped, because it only contains the number of samples
next(lines)
# Empty field names are filtered out, because some files have an extra white space after the header
# line, which is recognized as extra column
fieldnames = [name for name in next(csv.reader([next(lines)], dialect="celeba")) if name]
# Some files do not include a label for the image ID column
if fieldnames[0] != "image_id":
fieldnames.insert(0, "image_id")
for line in csv.DictReader(lines, fieldnames=fieldnames, dialect="celeba"):
yield line.pop("image_id"), line
finally:
file.close()
NAME = "celeba"
@register_info(NAME)
def _info() -> Dict[str, Any]:
return dict()
@register_dataset(NAME)
class CelebA(Dataset):
"""
- **homepage**: https://mmlab.ie.cuhk.edu.hk/projects/CelebA.html
"""
def __init__(
self,
root: Union[str, pathlib.Path],
*,
split: str = "train",
skip_integrity_check: bool = False,
) -> None:
self._split = self._verify_str_arg(split, "split", ("train", "val", "test"))
super().__init__(root, skip_integrity_check=skip_integrity_check)
def _resources(self) -> List[OnlineResource]:
splits = GDriveResource(
"0B7EVK8r0v71pY0NSMzRuSXJEVkk",
sha256="fc955bcb3ef8fbdf7d5640d9a8693a8431b5f2ee291a5c1449a1549e7e073fe7",
file_name="list_eval_partition.txt",
)
images = GDriveResource(
"0B7EVK8r0v71pZjFTYXZWM3FlRnM",
sha256="46fb89443c578308acf364d7d379fe1b9efb793042c0af734b6112e4fd3a8c74",
file_name="img_align_celeba.zip",
)
identities = GDriveResource(
"1_ee_0u7vcNLOfNLegJRHmolfH5ICW-XS",
sha256="c6143857c3e2630ac2da9f782e9c1232e5e59be993a9d44e8a7916c78a6158c0",
file_name="identity_CelebA.txt",
)
attributes = GDriveResource(
"0B7EVK8r0v71pblRyaVFSWGxPY0U",
sha256="f0e5da289d5ccf75ffe8811132694922b60f2af59256ed362afa03fefba324d0",
file_name="list_attr_celeba.txt",
)
bounding_boxes = GDriveResource(
"0B7EVK8r0v71pbThiMVRxWXZ4dU0",
sha256="7487a82e57c4bb956c5445ae2df4a91ffa717e903c5fa22874ede0820c8ec41b",
file_name="list_bbox_celeba.txt",
)
landmarks = GDriveResource(
"0B7EVK8r0v71pd0FJY3Blby1HUTQ",
sha256="6c02a87569907f6db2ba99019085697596730e8129f67a3d61659f198c48d43b",
file_name="list_landmarks_align_celeba.txt",
)
return [splits, images, identities, attributes, bounding_boxes, landmarks]
def _filter_split(self, data: Tuple[str, Dict[str, str]]) -> bool:
split_id = {
"train": "0",
"val": "1",
"test": "2",
}[self._split]
return data[1]["split_id"] == split_id
def _prepare_sample(
self,
data: Tuple[
Tuple[str, Tuple[Tuple[str, List[str]], Tuple[str, BinaryIO]]],
Tuple[
Tuple[str, Dict[str, str]],
Tuple[str, Dict[str, str]],
Tuple[str, Dict[str, str]],
Tuple[str, Dict[str, str]],
],
],
) -> Dict[str, Any]:
split_and_image_data, ann_data = data
_, (_, image_data) = split_and_image_data
path, buffer = image_data
image = EncodedImage.from_file(buffer)
(_, identity), (_, attributes), (_, bounding_boxes), (_, landmarks) = ann_data
return dict(
path=path,
image=image,
identity=Label(int(identity["identity"])),
attributes={attr: value == "1" for attr, value in attributes.items()},
bounding_boxes=BoundingBoxes(
[int(bounding_boxes[key]) for key in ("x_1", "y_1", "width", "height")],
format="xywh",
spatial_size=image.spatial_size,
),
landmarks={
landmark: torch.tensor((int(landmarks[f"{landmark}_x"]), int(landmarks[f"{landmark}_y"])))
for landmark in {key[:-2] for key in landmarks.keys()}
},
)
def _datapipe(self, resource_dps: List[IterDataPipe]) -> IterDataPipe[Dict[str, Any]]:
splits_dp, images_dp, identities_dp, attributes_dp, bounding_boxes_dp, landmarks_dp = resource_dps
splits_dp = CelebACSVParser(splits_dp, fieldnames=("image_id", "split_id"))
splits_dp = Filter(splits_dp, self._filter_split)
splits_dp = hint_shuffling(splits_dp)
splits_dp = hint_sharding(splits_dp)
anns_dp = Zipper(
*[
CelebACSVParser(dp, fieldnames=fieldnames)
for dp, fieldnames in (
(identities_dp, ("image_id", "identity")),
(attributes_dp, None),
(bounding_boxes_dp, None),
(landmarks_dp, None),
)
]
)
dp = IterKeyZipper(
splits_dp,
images_dp,
key_fn=getitem(0),
ref_key_fn=path_accessor("name"),
buffer_size=INFINITE_BUFFER_SIZE,
keep_key=True,
)
dp = IterKeyZipper(
dp,
anns_dp,
key_fn=getitem(0),
ref_key_fn=getitem(0, 0),
buffer_size=INFINITE_BUFFER_SIZE,
)
return Mapper(dp, self._prepare_sample)
def __len__(self) -> int:
return {
"train": 162_770,
"val": 19_867,
"test": 19_962,
}[self._split]
import abc
import io
import pathlib
import pickle
from typing import Any, BinaryIO, cast, Dict, Iterator, List, Optional, Tuple, Union
import numpy as np
from torchdata.datapipes.iter import Filter, IterDataPipe, Mapper
from torchvision.prototype.datasets.utils import Dataset, HttpResource, OnlineResource
from torchvision.prototype.datasets.utils._internal import (
hint_sharding,
hint_shuffling,
path_comparator,
read_categories_file,
)
from torchvision.prototype.tv_tensors import Label
from torchvision.tv_tensors import Image
from .._api import register_dataset, register_info
class CifarFileReader(IterDataPipe[Tuple[np.ndarray, int]]):
def __init__(self, datapipe: IterDataPipe[Dict[str, Any]], *, labels_key: str) -> None:
self.datapipe = datapipe
self.labels_key = labels_key
def __iter__(self) -> Iterator[Tuple[np.ndarray, int]]:
for mapping in self.datapipe:
image_arrays = mapping["data"].reshape((-1, 3, 32, 32))
category_idcs = mapping[self.labels_key]
yield from iter(zip(image_arrays, category_idcs))
class _CifarBase(Dataset):
_FILE_NAME: str
_SHA256: str
_LABELS_KEY: str
_META_FILE_NAME: str
_CATEGORIES_KEY: str
_categories: List[str]
def __init__(
self,
root: Union[str, pathlib.Path],
*,
split: str = "train",
skip_integrity_check: bool = False,
) -> None:
self._split = self._verify_str_arg(split, "split", ("train", "test"))
super().__init__(root, skip_integrity_check=skip_integrity_check)
@abc.abstractmethod
def _is_data_file(self, data: Tuple[str, BinaryIO]) -> Optional[int]:
pass
def _resources(self) -> List[OnlineResource]:
return [
HttpResource(
f"https://www.cs.toronto.edu/~kriz/{self._FILE_NAME}",
sha256=self._SHA256,
)
]
def _unpickle(self, data: Tuple[str, io.BytesIO]) -> Dict[str, Any]:
_, file = data
content = cast(Dict[str, Any], pickle.load(file, encoding="latin1"))
file.close()
return content
def _prepare_sample(self, data: Tuple[np.ndarray, int]) -> Dict[str, Any]:
image_array, category_idx = data
return dict(
image=Image(image_array),
label=Label(category_idx, categories=self._categories),
)
def _datapipe(self, resource_dps: List[IterDataPipe]) -> IterDataPipe[Dict[str, Any]]:
dp = resource_dps[0]
dp = Filter(dp, self._is_data_file)
dp = Mapper(dp, self._unpickle)
dp = CifarFileReader(dp, labels_key=self._LABELS_KEY)
dp = hint_shuffling(dp)
dp = hint_sharding(dp)
return Mapper(dp, self._prepare_sample)
def __len__(self) -> int:
return 50_000 if self._split == "train" else 10_000
def _generate_categories(self) -> List[str]:
resources = self._resources()
dp = resources[0].load(self._root)
dp = Filter(dp, path_comparator("name", self._META_FILE_NAME))
dp = Mapper(dp, self._unpickle)
return cast(List[str], next(iter(dp))[self._CATEGORIES_KEY])
@register_info("cifar10")
def _cifar10_info() -> Dict[str, Any]:
return dict(categories=read_categories_file("cifar10"))
@register_dataset("cifar10")
class Cifar10(_CifarBase):
"""
- **homepage**: https://www.cs.toronto.edu/~kriz/cifar.html
"""
_FILE_NAME = "cifar-10-python.tar.gz"
_SHA256 = "6d958be074577803d12ecdefd02955f39262c83c16fe9348329d7fe0b5c001ce"
_LABELS_KEY = "labels"
_META_FILE_NAME = "batches.meta"
_CATEGORIES_KEY = "label_names"
_categories = _cifar10_info()["categories"]
def _is_data_file(self, data: Tuple[str, Any]) -> bool:
path = pathlib.Path(data[0])
return path.name.startswith("data" if self._split == "train" else "test")
@register_info("cifar100")
def _cifar100_info() -> Dict[str, Any]:
return dict(categories=read_categories_file("cifar100"))
@register_dataset("cifar100")
class Cifar100(_CifarBase):
"""
- **homepage**: https://www.cs.toronto.edu/~kriz/cifar.html
"""
_FILE_NAME = "cifar-100-python.tar.gz"
_SHA256 = "85cd44d02ba6437773c5bbd22e183051d648de2e7d6b014e1ef29b855ba677a7"
_LABELS_KEY = "fine_labels"
_META_FILE_NAME = "meta"
_CATEGORIES_KEY = "fine_label_names"
_categories = _cifar100_info()["categories"]
def _is_data_file(self, data: Tuple[str, Any]) -> bool:
path = pathlib.Path(data[0])
return path.name == self._split
airplane
automobile
bird
cat
deer
dog
frog
horse
ship
truck
apple
aquarium_fish
baby
bear
beaver
bed
bee
beetle
bicycle
bottle
bowl
boy
bridge
bus
butterfly
camel
can
castle
caterpillar
cattle
chair
chimpanzee
clock
cloud
cockroach
couch
crab
crocodile
cup
dinosaur
dolphin
elephant
flatfish
forest
fox
girl
hamster
house
kangaroo
keyboard
lamp
lawn_mower
leopard
lion
lizard
lobster
man
maple_tree
motorcycle
mountain
mouse
mushroom
oak_tree
orange
orchid
otter
palm_tree
pear
pickup_truck
pine_tree
plain
plate
poppy
porcupine
possum
rabbit
raccoon
ray
road
rocket
rose
sea
seal
shark
shrew
skunk
skyscraper
snail
snake
spider
squirrel
streetcar
sunflower
sweet_pepper
table
tank
telephone
television
tiger
tractor
train
trout
tulip
turtle
wardrobe
whale
willow_tree
wolf
woman
worm
import pathlib
from typing import Any, BinaryIO, Dict, List, Optional, Tuple, Union
from torchdata.datapipes.iter import Demultiplexer, Filter, IterDataPipe, IterKeyZipper, JsonParser, Mapper, UnBatcher
from torchvision.prototype.datasets.utils import Dataset, EncodedImage, HttpResource, OnlineResource
from torchvision.prototype.datasets.utils._internal import (
getitem,
hint_sharding,
hint_shuffling,
INFINITE_BUFFER_SIZE,
path_accessor,
path_comparator,
)
from torchvision.prototype.tv_tensors import Label
from .._api import register_dataset, register_info
NAME = "clevr"
@register_info(NAME)
def _info() -> Dict[str, Any]:
return dict()
@register_dataset(NAME)
class CLEVR(Dataset):
"""
- **homepage**: https://cs.stanford.edu/people/jcjohns/clevr/
"""
def __init__(
self, root: Union[str, pathlib.Path], *, split: str = "train", skip_integrity_check: bool = False
) -> None:
self._split = self._verify_str_arg(split, "split", ("train", "val", "test"))
super().__init__(root, skip_integrity_check=skip_integrity_check)
def _resources(self) -> List[OnlineResource]:
archive = HttpResource(
"https://dl.fbaipublicfiles.com/clevr/CLEVR_v1.0.zip",
sha256="5cd61cf1096ed20944df93c9adb31e74d189b8459a94f54ba00090e5c59936d1",
)
return [archive]
def _classify_archive(self, data: Tuple[str, Any]) -> Optional[int]:
path = pathlib.Path(data[0])
if path.parents[1].name == "images":
return 0
elif path.parent.name == "scenes":
return 1
else:
return None
def _filter_scene_anns(self, data: Tuple[str, Any]) -> bool:
key, _ = data
return key == "scenes"
def _add_empty_anns(self, data: Tuple[str, BinaryIO]) -> Tuple[Tuple[str, BinaryIO], None]:
return data, None
def _prepare_sample(self, data: Tuple[Tuple[str, BinaryIO], Optional[Dict[str, Any]]]) -> Dict[str, Any]:
image_data, scenes_data = data
path, buffer = image_data
return dict(
path=path,
image=EncodedImage.from_file(buffer),
label=Label(len(scenes_data["objects"])) if scenes_data else None,
)
def _datapipe(self, resource_dps: List[IterDataPipe]) -> IterDataPipe[Dict[str, Any]]:
archive_dp = resource_dps[0]
images_dp, scenes_dp = Demultiplexer(
archive_dp,
2,
self._classify_archive,
drop_none=True,
buffer_size=INFINITE_BUFFER_SIZE,
)
images_dp = Filter(images_dp, path_comparator("parent.name", self._split))
images_dp = hint_shuffling(images_dp)
images_dp = hint_sharding(images_dp)
if self._split != "test":
scenes_dp = Filter(scenes_dp, path_comparator("name", f"CLEVR_{self._split}_scenes.json"))
scenes_dp = JsonParser(scenes_dp)
scenes_dp = Mapper(scenes_dp, getitem(1, "scenes"))
scenes_dp = UnBatcher(scenes_dp)
dp = IterKeyZipper(
images_dp,
scenes_dp,
key_fn=path_accessor("name"),
ref_key_fn=getitem("image_filename"),
buffer_size=INFINITE_BUFFER_SIZE,
)
else:
for _, file in scenes_dp:
file.close()
dp = Mapper(images_dp, self._add_empty_anns)
return Mapper(dp, self._prepare_sample)
def __len__(self) -> int:
return 70_000 if self._split == "train" else 15_000
__background__,N/A
person,person
bicycle,vehicle
car,vehicle
motorcycle,vehicle
airplane,vehicle
bus,vehicle
train,vehicle
truck,vehicle
boat,vehicle
traffic light,outdoor
fire hydrant,outdoor
N/A,N/A
stop sign,outdoor
parking meter,outdoor
bench,outdoor
bird,animal
cat,animal
dog,animal
horse,animal
sheep,animal
cow,animal
elephant,animal
bear,animal
zebra,animal
giraffe,animal
N/A,N/A
backpack,accessory
umbrella,accessory
N/A,N/A
N/A,N/A
handbag,accessory
tie,accessory
suitcase,accessory
frisbee,sports
skis,sports
snowboard,sports
sports ball,sports
kite,sports
baseball bat,sports
baseball glove,sports
skateboard,sports
surfboard,sports
tennis racket,sports
bottle,kitchen
N/A,N/A
wine glass,kitchen
cup,kitchen
fork,kitchen
knife,kitchen
spoon,kitchen
bowl,kitchen
banana,food
apple,food
sandwich,food
orange,food
broccoli,food
carrot,food
hot dog,food
pizza,food
donut,food
cake,food
chair,furniture
couch,furniture
potted plant,furniture
bed,furniture
N/A,N/A
dining table,furniture
N/A,N/A
N/A,N/A
toilet,furniture
N/A,N/A
tv,electronic
laptop,electronic
mouse,electronic
remote,electronic
keyboard,electronic
cell phone,electronic
microwave,appliance
oven,appliance
toaster,appliance
sink,appliance
refrigerator,appliance
N/A,N/A
book,indoor
clock,indoor
vase,indoor
scissors,indoor
teddy bear,indoor
hair drier,indoor
toothbrush,indoor
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment