Unverified Commit d32bc4ba authored by Philip Meier, committed by GitHub

Revamp prototype features and transforms (#5407)

* revamp prototype features (#5283)

* remove decoding from prototype datasets (#5287)

* remove decoder from prototype datasets

* remove unused imports

* cleanup

* fix readme

* use OneHotLabel in SEMEION

* improve voc implementation

* revert unrelated changes

* fix semeion mock data

* fix pcam

* readd functional transforms API to prototype (#5295)

* readd functional transforms

* cleanup

* add missing imports

* remove __torch_function__ dispatch

* readd repr

* readd empty line

* add test for scriptability

* remove function copy

* change import from functional tensor transforms to just functional

* fix import

* fix test

* fix prototype features and functional transforms after review (#5377)

* fix prototype functional transforms after review

* address features review

* make mypy more strict on prototype features

* make mypy more strict for prototype transforms

* fix annotation

* fix kernel tests

* add automatic feature type dispatch to functional transforms (#5323)

* add auto dispatch

* fix missing arguments error message

* remove pil kernel for erase

* automate feature specific parameter detection

* fix typos

* cleanup dispatcher call

* remove __torch_function__ from transform dispatch

* remove auto-generation

* revert unrelated changes

* remove implements decorator

* change register parameter order

* change order of transforms for readability

* add documentation for __torch_function__

* fix mypy

* inline check for support

* refactor kernel registering process

* refactor dispatch to be a regular decorator

* split kernels and dispatchers

* remove sentinels

* replace pass with ...

* appease mypy

* make single kernel dispatchers more concise

* make dispatcher signatures more generic

* make kernel checking more strict

* revert doc changes

* address Francisco's comments

* remove inplace

* rename kernel test module

* fix inplace

* remove special casing for pil and vanilla tensors

* address comments

* update docs

* cleanup features / transforms feature branch (#5406)

* mark candidates for removal

* align signature of resize_bounding_box with corresponding image kernel

* fix documentation of Feature

* remove interpolation mode and antialias option from resize_segmentation_mask

* remove or privatize functionality in features / datasets / transforms
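For orientation, the split described above separates low-level kernels (e.g. `K.horizontal_flip_image` and `K.horizontal_flip_bounding_box`, both exercised by the kernel tests below) from dispatchers that route on the feature type. A minimal sketch of the routing this PR automates; the hand-written `isinstance` checks here are an illustration, not the generated code:

```python
import torch
import torchvision.prototype.transforms.kernels as K
from torchvision.prototype import features

def horizontal_flip(input: torch.Tensor) -> torch.Tensor:
    # Illustrative dispatcher: pick the kernel matching the feature type and
    # re-wrap the result so the feature type and its meta data survive.
    if isinstance(input, features.BoundingBox):
        output = K.horizontal_flip_bounding_box(input, format=input.format, image_size=input.image_size)
        return features.BoundingBox(output, format=input.format, image_size=input.image_size)
    if isinstance(input, features.Image):
        return features.Image(K.horizontal_flip_image(input), color_space=input.color_space)
    return K.horizontal_flip_image(input)  # vanilla tensors use the image kernel
```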
parent f2f490b1
......@@ -6,6 +6,36 @@ pretty = True
allow_redefinition = True
warn_redundant_casts = True
[mypy-torchvision.prototype.features.*]
; untyped definitions and calls
disallow_untyped_defs = True
; None and Optional handling
no_implicit_optional = True
; warnings
warn_unused_ignores = True
warn_return_any = True
; miscellaneous strictness flags
allow_redefinition = True
[mypy-torchvision.prototype.transforms.*]
; untyped definitions and calls
disallow_untyped_defs = True
; None and Optional handling
no_implicit_optional = True
; warnings
warn_unused_ignores = True
warn_return_any = True
; miscellaneous strictness flags
allow_redefinition = True
[mypy-torchvision.prototype.datasets.*]
; untyped definitions and calls
......
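The sections added here opt the prototype features and transforms into stricter type checking. As a rough illustration of what `disallow_untyped_defs` enforces (the example is not from this PR):

```python
import torch

def resize_untyped(image, size):  # rejected: parameters and return are untyped
    ...

def resize_typed(image: torch.Tensor, size: int) -> torch.Tensor:  # accepted
    return image.new_empty((size, size))
```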
......@@ -432,50 +432,52 @@ def caltech256(info, root, config):
@register_mock
def imagenet(info, root, config):
wnids = tuple(info.extra.wnid_to_category.keys())
if config.split == "train":
images_root = root / "ILSVRC2012_img_train"
from scipy.io import savemat
categories = info.categories
wnids = [info.extra.category_to_wnid[category] for category in categories]
if config.split == "train":
num_samples = len(wnids)
archive_name = "ILSVRC2012_img_train.tar"
files = []
for wnid in wnids:
files = create_image_folder(
root=images_root,
create_image_folder(
root=root,
name=wnid,
file_name_fn=lambda image_idx: f"{wnid}_{image_idx:04d}.JPEG",
num_examples=1,
)
make_tar(images_root, f"{wnid}.tar", files[0].parent)
files.append(make_tar(root, f"{wnid}.tar"))
elif config.split == "val":
num_samples = 3
files = create_image_folder(
root=root,
name="ILSVRC2012_img_val",
file_name_fn=lambda image_idx: f"ILSVRC2012_val_{image_idx + 1:08d}.JPEG",
num_examples=num_samples,
)
images_root = files[0].parent
else: # config.split == "test"
images_root = root / "ILSVRC2012_img_test_v10102019"
archive_name = "ILSVRC2012_img_val.tar"
files = [create_image_file(root, f"ILSVRC2012_val_{idx + 1:08d}.JPEG") for idx in range(num_samples)]
num_samples = 3
devkit_root = root / "ILSVRC2012_devkit_t12"
data_root = devkit_root / "data"
data_root.mkdir(parents=True)
create_image_folder(
root=images_root,
name="test",
file_name_fn=lambda image_idx: f"ILSVRC2012_test_{image_idx + 1:08d}.JPEG",
num_examples=num_samples,
)
make_tar(root, f"{images_root.name}.tar", images_root)
with open(data_root / "ILSVRC2012_validation_ground_truth.txt", "w") as file:
for label in torch.randint(0, len(wnids), (num_samples,)).tolist():
file.write(f"{label}\n")
num_children = 0
synsets = [
(idx, wnid, category, "", num_children, [], 0, 0)
for idx, (category, wnid) in enumerate(zip(categories, wnids), 1)
]
num_children = 1
synsets.extend((0, "", "", "", num_children, [], 0, 0) for _ in range(5))
savemat(data_root / "meta.mat", dict(synsets=synsets))
make_tar(root, devkit_root.with_suffix(".tar.gz").name, compression="gz")
else: # config.split == "test"
num_samples = 5
archive_name = "ILSVRC2012_img_test_v10102019.tar"
files = [create_image_file(root, f"ILSVRC2012_test_{idx + 1:08d}.JPEG") for idx in range(num_samples)]
devkit_root = root / "ILSVRC2012_devkit_t12"
devkit_root.mkdir()
data_root = devkit_root / "data"
data_root.mkdir()
with open(data_root / "ILSVRC2012_validation_ground_truth.txt", "w") as file:
for label in torch.randint(0, len(wnids), (num_samples,)).tolist():
file.write(f"{label}\n")
make_tar(root, f"{devkit_root}.tar.gz", devkit_root, compression="gz")
make_tar(root, archive_name, *files)
return num_samples
......@@ -667,14 +669,15 @@ def sbd(info, root, config):
@register_mock
def semeion(info, root, config):
num_samples = 3
num_categories = len(info.categories)
images = torch.rand(num_samples, 256)
labels = one_hot(torch.randint(len(info.categories), size=(num_samples,)))
labels = one_hot(torch.randint(num_categories, size=(num_samples,)), num_classes=num_categories)
with open(root / "semeion.data", "w") as fh:
for image, one_hot_label in zip(images, labels):
image_columns = " ".join([f"{pixel.item():.4f}" for pixel in image])
labels_columns = " ".join([str(label.item()) for label in one_hot_label])
fh.write(f"{image_columns} {labels_columns}\n")
fh.write(f"{image_columns} {labels_columns} \n")
return num_samples
......@@ -729,32 +732,33 @@ class VOCMockData:
def _make_detection_ann_file(cls, root, name):
def add_child(parent, name, text=None):
child = ET.SubElement(parent, name)
child.text = text
child.text = str(text)
return child
def add_name(obj, name="dog"):
add_child(obj, "name", name)
return name
def add_bndbox(obj, bndbox=None):
if bndbox is None:
bndbox = {"xmin": "1", "xmax": "2", "ymin": "3", "ymax": "4"}
def add_size(obj):
obj = add_child(obj, "size")
size = {"width": 0, "height": 0, "depth": 3}
for name, text in size.items():
add_child(obj, name, text)
def add_bndbox(obj):
obj = add_child(obj, "bndbox")
bndbox = {"xmin": 1, "xmax": 2, "ymin": 3, "ymax": 4}
for name, text in bndbox.items():
add_child(obj, name, text)
return bndbox
annotation = ET.Element("annotation")
add_size(annotation)
obj = add_child(annotation, "object")
data = dict(name=add_name(obj), bndbox=add_bndbox(obj))
add_name(obj)
add_bndbox(obj)
with open(root / name, "wb") as fh:
fh.write(ET.tostring(annotation))
return data
@classmethod
def generate(cls, root, *, year, trainval):
archive_folder = root
......
import functools
import io
from pathlib import Path
import pytest
import torch
from builtin_dataset_mocks import parametrize_dataset_mocks, DATASET_MOCKS
from torch.testing._comparison import assert_equal, TensorLikePair, ObjectPair
from torch.utils.data.datapipes.iter.grouping import ShardingFilterIterDataPipe as ShardingFilter
from torch.utils.data.graph import traverse
from torchdata.datapipes.iter import IterDataPipe, Shuffler
......@@ -11,6 +13,11 @@ from torchvision.prototype import transforms, datasets
from torchvision.prototype.utils._internal import sequence_to_str
assert_samples_equal = functools.partial(
assert_equal, pair_types=(TensorLikePair, ObjectPair), rtol=0, atol=0, equal_nan=True
)
@pytest.fixture
def test_home(mocker, tmp_path):
mocker.patch("torchvision.prototype.datasets._api.home", return_value=str(tmp_path))
......@@ -92,6 +99,7 @@ class TestCommon:
f"{sequence_to_str(sorted(vanilla_tensors), separate_last='and ')} contained vanilla tensors."
)
@pytest.mark.xfail
@parametrize_dataset_mocks(DATASET_MOCKS)
def test_transformable(self, test_home, dataset_mock, config):
dataset_mock.prepare(test_home, config)
......@@ -137,6 +145,17 @@ class TestCommon:
if not any(type(dp) is annotation_dp_type for dp in scan(traverse(dataset))):
raise AssertionError(f"The dataset doesn't contain a {annotation_dp_type.__name__}() datapipe.")
@parametrize_dataset_mocks(DATASET_MOCKS)
def test_save_load(self, test_home, dataset_mock, config):
dataset_mock.prepare(test_home, config)
dataset = datasets.load(dataset_mock.name, **config)
sample = next(iter(dataset))
with io.BytesIO() as buffer:
torch.save(sample, buffer)
buffer.seek(0)
assert_samples_equal(torch.load(buffer), sample)
@parametrize_dataset_mocks(DATASET_MOCKS["qmnist"])
class TestQMNIST:
......@@ -171,5 +190,5 @@ class TestGTSRB:
dataset = datasets.load(dataset_mock.name, **config)
for sample in dataset:
label_from_path = int(Path(sample["image_path"]).parent.name)
label_from_path = int(Path(sample["path"]).parent.name)
assert sample["label"] == label_from_path
......@@ -5,8 +5,8 @@ from torchvision.prototype import datasets
from torchvision.prototype.utils._internal import FrozenMapping, FrozenBunch
def make_minimal_dataset_info(name="name", type=datasets.utils.DatasetType.RAW, categories=None, **kwargs):
return datasets.utils.DatasetInfo(name, type=type, categories=categories or [], **kwargs)
def make_minimal_dataset_info(name="name", categories=None, **kwargs):
return datasets.utils.DatasetInfo(name, categories=categories or [], **kwargs)
class TestFrozenMapping:
......@@ -176,7 +176,7 @@ class TestDataset:
# This method is just defined to appease the ABC, but will be overwritten at instantiation
pass
def _make_datapipe(self, resource_dps, *, config, decoder):
def _make_datapipe(self, resource_dps, *, config):
# This method is just defined to appease the ABC, but will be overwritten at instantiation
pass
......@@ -229,12 +229,3 @@ class TestDataset:
(call_args, _) = dataset._make_datapipe.call_args
assert call_args[0][0] is sentinel
def test_decoder(self):
dataset = self.DatasetMock()
sentinel = object()
dataset.load("", decoder=sentinel)
(_, call_kwargs) = dataset._make_datapipe.call_args
assert call_kwargs["decoder"] is sentinel
import functools
import itertools
import pytest
import torch
from torch.testing import make_tensor as _make_tensor, assert_close
from torchvision.prototype import features
from torchvision.prototype.utils._internal import sequence_to_str
make_tensor = functools.partial(_make_tensor, device="cpu", dtype=torch.float32)
def make_image(**kwargs):
data = make_tensor((3, *torch.randint(16, 33, (2,)).tolist()))
return features.Image(data, **kwargs)
def make_bounding_box(*, format="xyxy", image_size=(10, 10)):
if isinstance(format, str):
format = features.BoundingBoxFormat[format]
height, width = image_size
if format == features.BoundingBoxFormat.XYXY:
x1 = torch.randint(0, width // 2, ())
y1 = torch.randint(0, height // 2, ())
x2 = torch.randint(int(x1) + 1, width - int(x1), ()) + x1
y2 = torch.randint(int(y1) + 1, height - int(y1), ()) + y1
parts = (x1, y1, x2, y2)
elif format == features.BoundingBoxFormat.XYWH:
x = torch.randint(0, width // 2, ())
y = torch.randint(0, height // 2, ())
w = torch.randint(1, width - int(x), ())
h = torch.randint(1, height - int(y), ())
parts = (x, y, w, h)
elif format == features.BoundingBoxFormat.CXCYWH:
cx = torch.randint(1, width - 1, ())
cy = torch.randint(1, height - 1, ())
w = torch.randint(1, min(int(cx), width - int(cx)), ())
h = torch.randint(1, min(int(cy), height - int(cy)), ())
parts = (cx, cy, w, h)
else: # format == features.BoundingBoxFormat._SENTINEL:
parts = make_tensor((4,)).unbind()
return features.BoundingBox.from_parts(*parts, format=format, image_size=image_size)
MAKE_DATA_MAP = {
features.Image: make_image,
features.BoundingBox: make_bounding_box,
}
def make_feature(feature_type, **meta_data):
maker = MAKE_DATA_MAP.get(feature_type, lambda **meta_data: feature_type(make_tensor(()), **meta_data))
return maker(**meta_data)
class TestCommon:
FEATURE_TYPES, NON_DEFAULT_META_DATA = zip(
*(
(features.Image, dict(color_space=features.ColorSpace._SENTINEL)),
(features.Label, dict(category="category")),
(features.BoundingBox, dict(format=features.BoundingBoxFormat._SENTINEL, image_size=(-1, -1))),
)
)
feature_types = pytest.mark.parametrize(
"feature_type", FEATURE_TYPES, ids=lambda feature_type: feature_type.__name__
)
features = pytest.mark.parametrize(
"feature",
[
pytest.param(make_feature(feature_type, **meta_data), id=feature_type.__name__)
for feature_type, meta_data in zip(FEATURE_TYPES, NON_DEFAULT_META_DATA)
],
)
def test_consistency(self):
builtin_feature_types = {
name
for name, feature_type in features.__dict__.items()
if not name.startswith("_")
and isinstance(feature_type, type)
and issubclass(feature_type, features.Feature)
and feature_type is not features.Feature
}
untested_feature_types = builtin_feature_types - {feature_type.__name__ for feature_type in self.FEATURE_TYPES}
if untested_feature_types:
raise AssertionError(
f"The feature(s) {sequence_to_str(sorted(untested_feature_types), separate_last='and ')} "
f"is/are exposed at `torchvision.prototype.features`, but is/are not tested by `TestCommon`. "
f"Please add it/them to `TestCommon.FEATURE_TYPES`."
)
@features
def test_meta_data_attribute_access(self, feature):
for name, value in feature._meta_data.items():
assert getattr(feature, name) == value
@feature_types
def test_torch_function(self, feature_type):
input = make_feature(feature_type)
# This can be any Tensor operation besides clone
output = input + 1
assert type(output) is torch.Tensor
assert_close(output, input + 1)
@feature_types
def test_clone(self, feature_type):
input = make_feature(feature_type)
output = input.clone()
assert type(output) is feature_type
assert_close(output, input)
assert output._meta_data == input._meta_data
@features
def test_serialization(self, tmpdir, feature):
file = tmpdir / "test_serialization.pt"
torch.save(feature, str(file))
loaded_feature = torch.load(str(file))
assert isinstance(loaded_feature, type(feature))
assert_close(loaded_feature, feature)
assert loaded_feature._meta_data == feature._meta_data
@features
def test_repr(self, feature):
assert type(feature).__name__ in repr(feature)
class TestBoundingBox:
@pytest.mark.parametrize(("format", "intermediate_format"), itertools.permutations(("xyxy", "xywh"), 2))
def test_cycle_consistency(self, format, intermediate_format):
input = make_bounding_box(format=format)
output = input.convert(intermediate_format).convert(format)
assert_close(input, output)
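`convert` is the user-facing entry point for these format changes. For instance, under the usual cxcywh-to-xyxy arithmetic (values match the jit test below):

```python
from torchvision.prototype import features

box = features.BoundingBox([2, 4, 2, 4], format="cxcywh", image_size=(8, 6))
box.convert("xyxy")  # expected [1, 2, 3, 6]: (cx - w/2, cy - h/2, cx + w/2, cy + h/2)
```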
# For now, tensor subclasses with additional meta data do not work with torchscript.
# See https://github.com/pytorch/vision/pull/4721#discussion_r741676037.
@pytest.mark.xfail
class TestJit:
def test_bounding_box(self):
def resize(input: features.BoundingBox, size: torch.Tensor) -> features.BoundingBox:
old_height, old_width = input.image_size
new_height, new_width = size
height_scale = new_height / old_height
width_scale = new_width / old_width
old_x1, old_y1, old_x2, old_y2 = input.convert("xyxy").to_parts()
new_x1 = old_x1 * width_scale
new_y1 = old_y1 * height_scale
new_x2 = old_x2 * width_scale
new_y2 = old_y2 * height_scale
return features.BoundingBox.from_parts(
new_x1, new_y1, new_x2, new_y2, like=input, format="xyxy", image_size=tuple(size.tolist())
)
def horizontal_flip(input: features.BoundingBox) -> features.BoundingBox:
x, y, w, h = input.convert("xywh").to_parts()
x = input.image_size[1] - (x + w)
return features.BoundingBox.from_parts(x, y, w, h, like=input, format="xywh")
def compose(input: features.BoundingBox, size: torch.Tensor) -> features.BoundingBox:
return horizontal_flip(resize(input, size)).convert("xyxy")
image_size = (8, 6)
input = features.BoundingBox([2, 4, 2, 4], format="cxcywh", image_size=image_size)
size = torch.tensor((4, 12))
expected = features.BoundingBox([6, 1, 10, 3], format="xyxy", image_size=image_size)
actual_eager = compose(input, size)
assert_close(actual_eager, expected)
sample_inputs = (features.BoundingBox(torch.zeros((4,)), image_size=(10, 10)), torch.tensor((20, 5)))
actual_jit = torch.jit.trace(compose, sample_inputs, check_trace=False)(input, size)
assert_close(actual_jit, expected)
import pytest
from torchvision.prototype import transforms, features
from torchvision.prototype.utils._internal import sequence_to_str
FEATURE_TYPES = {
feature_type
for name, feature_type in features.__dict__.items()
if not name.startswith("_")
and isinstance(feature_type, type)
and issubclass(feature_type, features.Feature)
and feature_type is not features.Feature
}
TRANSFORM_TYPES = tuple(
transform_type
for name, transform_type in transforms.__dict__.items()
if not name.startswith("_")
and isinstance(transform_type, type)
and issubclass(transform_type, transforms.Transform)
and transform_type is not transforms.Transform
)
def test_feature_type_support():
missing_feature_types = FEATURE_TYPES - set(transforms.Transform._BUILTIN_FEATURE_TYPES)
if missing_feature_types:
names = sorted([feature_type.__name__ for feature_type in missing_feature_types])
raise AssertionError(
f"The feature(s) {sequence_to_str(names, separate_last='and ')} is/are exposed at "
f"`torchvision.prototype.features`, but are missing in Transform._BUILTIN_FEATURE_TYPES. "
f"Please add it/them to the collection."
)
@pytest.mark.parametrize(
"transform_type",
[transform_type for transform_type in TRANSFORM_TYPES if transform_type is not transforms.Identity],
ids=lambda transform_type: transform_type.__name__,
)
def test_feature_no_op_coverage(transform_type):
unsupported_features = (
FEATURE_TYPES - transform_type.supported_feature_types() - set(transform_type.NO_OP_FEATURE_TYPES)
)
if unsupported_features:
names = sorted([feature_type.__name__ for feature_type in unsupported_features])
raise AssertionError(
f"The feature(s) {sequence_to_str(names, separate_last='and ')} are neither supported nor declared as "
f"no-op for transform `{transform_type.__name__}`. Please either implement a feature transform for them, "
f"or add them to the the `{transform_type.__name__}.NO_OP_FEATURE_TYPES` collection."
)
def test_non_feature_no_op():
class TestTransform(transforms.Transform):
@staticmethod
def image(input):
return input
no_op_sample = dict(int=0, float=0.0, bool=False, str="str")
assert TestTransform()(no_op_sample) == no_op_sample
import functools
import itertools
import pytest
import torch.testing
import torchvision.prototype.transforms.kernels as K
from torch import jit
from torchvision.prototype import features
make_tensor = functools.partial(torch.testing.make_tensor, device="cpu")
def make_image(size=None, *, color_space, extra_dims=(), dtype=torch.float32):
size = size or torch.randint(16, 33, (2,)).tolist()
if isinstance(color_space, str):
color_space = features.ColorSpace[color_space]
num_channels = {
features.ColorSpace.GRAYSCALE: 1,
features.ColorSpace.RGB: 3,
}[color_space]
shape = (*extra_dims, num_channels, *size)
if dtype.is_floating_point:
data = torch.rand(shape, dtype=dtype)
else:
data = torch.randint(0, torch.iinfo(dtype).max, shape, dtype=dtype)
return features.Image(data, color_space=color_space)
make_grayscale_image = functools.partial(make_image, color_space=features.ColorSpace.GRAYSCALE)
make_rgb_image = functools.partial(make_image, color_space=features.ColorSpace.RGB)
def make_images(
sizes=((16, 16), (7, 33), (31, 9)),
color_spaces=(features.ColorSpace.GRAYSCALE, features.ColorSpace.RGB),
dtypes=(torch.float32, torch.uint8),
extra_dims=((4,), (2, 3)),
):
for size, color_space, dtype in itertools.product(sizes, color_spaces, dtypes):
yield make_image(size, color_space=color_space)
for color_space, extra_dims_ in itertools.product(color_spaces, extra_dims):
yield make_image(color_space=color_space, extra_dims=extra_dims_)
def randint_with_tensor_bounds(arg1, arg2=None, **kwargs):
low, high = torch.broadcast_tensors(
*[torch.as_tensor(arg) for arg in ((0, arg1) if arg2 is None else (arg1, arg2))]
)
return torch.stack(
[
torch.randint(low_scalar, high_scalar, (), **kwargs)
for low_scalar, high_scalar in zip(low.flatten().tolist(), high.flatten().tolist())
]
).reshape(low.shape)
def make_bounding_box(*, format, image_size=(32, 32), extra_dims=(), dtype=torch.int64):
if isinstance(format, str):
format = features.BoundingBoxFormat[format]
height, width = image_size
if format == features.BoundingBoxFormat.XYXY:
x1 = torch.randint(0, width // 2, extra_dims)
y1 = torch.randint(0, height // 2, extra_dims)
x2 = randint_with_tensor_bounds(x1 + 1, width - x1) + x1
y2 = randint_with_tensor_bounds(y1 + 1, height - y1) + y1
parts = (x1, y1, x2, y2)
elif format == features.BoundingBoxFormat.XYWH:
x = torch.randint(0, width // 2, extra_dims)
y = torch.randint(0, height // 2, extra_dims)
w = randint_with_tensor_bounds(1, width - x)
h = randint_with_tensor_bounds(1, height - y)
parts = (x, y, w, h)
elif format == features.BoundingBoxFormat.CXCYWH:
cx = torch.randint(1, width - 1, ())
cy = torch.randint(1, height - 1, ())
w = randint_with_tensor_bounds(1, torch.minimum(cx, width - cx) + 1)
h = randint_with_tensor_bounds(1, torch.minimum(cy, height - cy) + 1)
parts = (cx, cy, w, h)
else: # format == features.BoundingBoxFormat._SENTINEL:
raise ValueError(f"Format {format} is not supported")
return features.BoundingBox(torch.stack(parts, dim=-1).to(dtype), format=format, image_size=image_size)
make_xyxy_bounding_box = functools.partial(make_bounding_box, format=features.BoundingBoxFormat.XYXY)
def make_bounding_boxes(
formats=(features.BoundingBoxFormat.XYXY, features.BoundingBoxFormat.XYWH, features.BoundingBoxFormat.CXCYWH),
image_sizes=((32, 32),),
dtypes=(torch.int64, torch.float32),
extra_dims=((4,), (2, 3)),
):
for format, image_size, dtype in itertools.product(formats, image_sizes, dtypes):
yield make_bounding_box(format=format, image_size=image_size, dtype=dtype)
for format, extra_dims_ in itertools.product(formats, extra_dims):
yield make_bounding_box(format=format, extra_dims=extra_dims_)
class SampleInput:
def __init__(self, *args, **kwargs):
self.args = args
self.kwargs = kwargs
class KernelInfo:
def __init__(self, name, *, sample_inputs_fn):
self.name = name
self.kernel = getattr(K, name)
self._sample_inputs_fn = sample_inputs_fn
def sample_inputs(self):
yield from self._sample_inputs_fn()
def __call__(self, *args, **kwargs):
if len(args) == 1 and not kwargs and isinstance(args[0], SampleInput):
sample_input = args[0]
return self.kernel(*sample_input.args, **sample_input.kwargs)
return self.kernel(*args, **kwargs)
KERNEL_INFOS = []
def register_kernel_info_from_sample_inputs_fn(sample_inputs_fn):
KERNEL_INFOS.append(KernelInfo(sample_inputs_fn.__name__, sample_inputs_fn=sample_inputs_fn))
return sample_inputs_fn
@register_kernel_info_from_sample_inputs_fn
def horizontal_flip_image():
for image in make_images():
yield SampleInput(image)
@register_kernel_info_from_sample_inputs_fn
def horizontal_flip_bounding_box():
for bounding_box in make_bounding_boxes(formats=[features.BoundingBoxFormat.XYXY]):
yield SampleInput(bounding_box, format=bounding_box.format, image_size=bounding_box.image_size)
@register_kernel_info_from_sample_inputs_fn
def resize_image():
for image, interpolation in itertools.product(
make_images(),
[
K.InterpolationMode.BILINEAR,
K.InterpolationMode.NEAREST,
],
):
height, width = image.shape[-2:]
for size in [
(height, width),
(int(height * 0.75), int(width * 1.25)),
]:
yield SampleInput(image, size=size, interpolation=interpolation)
@register_kernel_info_from_sample_inputs_fn
def resize_bounding_box():
for bounding_box in make_bounding_boxes():
height, width = bounding_box.image_size
for size in [
(height, width),
(int(height * 0.75), int(width * 1.25)),
]:
yield SampleInput(bounding_box, size=size, image_size=bounding_box.image_size)
class TestKernelsCommon:
@pytest.mark.parametrize("kernel_info", KERNEL_INFOS, ids=lambda kernel_info: kernel_info.name)
def test_scriptable(self, kernel_info):
jit.script(kernel_info.kernel)
@pytest.mark.parametrize(
("kernel_info", "sample_input"),
[
pytest.param(kernel_info, sample_input, id=f"{kernel_info.name}-{idx}")
for kernel_info in KERNEL_INFOS
for idx, sample_input in enumerate(kernel_info.sample_inputs())
],
)
def test_eager_vs_scripted(self, kernel_info, sample_input):
eager = kernel_info(sample_input)
scripted = jit.script(kernel_info.kernel)(*sample_input.args, **sample_input.kwargs)
torch.testing.assert_close(eager, scripted)
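New kernels are picked up by both tests once a sample-inputs function is registered. A hypothetical registration, assuming a `K.vertical_flip_image` kernel existed:

```python
@register_kernel_info_from_sample_inputs_fn
def vertical_flip_image():  # the function name must match the kernel name in K
    for image in make_images():
        yield SampleInput(image)
```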
......@@ -7,9 +7,9 @@ except (ModuleNotFoundError, TypeError) as error:
"Note that you cannot install it with `pip install torchdata`, since this is another package."
) from error
from . import decoder, utils
from . import utils
from ._home import home
# Load this last, since some parts depend on the above being loaded first
from ._api import register, list_datasets, info, load # usort: skip
from ._api import list_datasets, info, load # usort: skip
from ._folder import from_data_folder, from_image_folder
import io
import os
from typing import Any, Callable, Dict, List, Optional
from typing import Any, Dict, List
import torch
from torch.utils.data import IterDataPipe
from torchvision.prototype.datasets import home
from torchvision.prototype.datasets.decoder import raw, pil
from torchvision.prototype.datasets.utils import Dataset, DatasetInfo, DatasetType
from torchvision.prototype.datasets.utils import Dataset, DatasetInfo
from torchvision.prototype.utils._internal import add_suggestion
from . import _builtin
......@@ -48,27 +45,15 @@ def info(name: str) -> DatasetInfo:
return find(name).info
DEFAULT_DECODER = object()
DEFAULT_DECODER_MAP: Dict[DatasetType, Callable[[io.IOBase], torch.Tensor]] = {
DatasetType.RAW: raw,
DatasetType.IMAGE: pil,
}
def load(
name: str,
*,
decoder: Optional[Callable[[io.IOBase], torch.Tensor]] = DEFAULT_DECODER, # type: ignore[assignment]
skip_integrity_check: bool = False,
**options: Any,
) -> IterDataPipe[Dict[str, Any]]:
dataset = find(name)
if decoder is DEFAULT_DECODER:
decoder = DEFAULT_DECODER_MAP.get(dataset.info.type)
config = dataset.info.make_config(**options)
root = os.path.join(home(), dataset.name)
return dataset.load(root, config=config, decoder=decoder, skip_integrity_check=skip_integrity_check)
return dataset.load(root, config=config, skip_integrity_check=skip_integrity_check)
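With the decoder gone, `load` only forwards the config options and the integrity-check flag. A usage sketch (dataset name illustrative):

```python
from torchvision.prototype import datasets

dataset = datasets.load("caltech256")  # IterDataPipe of sample dicts
sample = next(iter(dataset))  # e.g. keys "path", "image" (an EncodedImage), "label"
```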
......@@ -19,10 +19,8 @@ that module create a class that inherits from `datasets.utils.Dataset` and
overwrites at minimum three methods that will be discussed in detail below:
```python
import io
from typing import Any, Callable, Dict, List, Optional
from typing import Any, Dict, List
import torch
from torchdata.datapipes.iter import IterDataPipe
from torchvision.prototype.datasets.utils import Dataset, DatasetInfo, DatasetConfig, OnlineResource
......@@ -34,11 +32,7 @@ class MyDataset(Dataset):
...
def _make_datapipe(
self,
resource_dps: List[IterDataPipe],
*,
config: DatasetConfig,
decoder: Optional[Callable[[io.IOBase], torch.Tensor]],
self, resource_dps: List[IterDataPipe], *, config: DatasetConfig,
) -> IterDataPipe[Dict[str, Any]]:
...
```
......@@ -49,10 +43,6 @@ The `DatasetInfo` carries static information about the dataset. There are two
required fields:
- `name`: Name of the dataset. This will be used to load the dataset with
`datasets.load(name)`. Should only contain lowercase characters.
- `type`: Field of the `datasets.utils.DatasetType` enum. This is used to select
the default decoder in case the user doesn't pass one. There are currently
only two options: `IMAGE` and `RAW` ([see
below](#what-is-the-datasettyperaw-and-when-do-i-use-it) for details).
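With the `type` field gone, a minimal `DatasetInfo` now looks like this (values illustrative):

```python
from torchvision.prototype.datasets.utils import DatasetInfo

info = DatasetInfo(
    "mydataset",  # lowercase; loaded via datasets.load("mydataset")
    categories=["cat", "dog"],
)
```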
There are more optional parameters that can be passed:
......@@ -105,7 +95,7 @@ def sha256sum(path, chunk_size=1024 * 1024):
print(checksum.hexdigest())
```
### `_make_datapipe(resource_dps, *, config, decoder)`
### `_make_datapipe(resource_dps, *, config)`
This method is the heart of the dataset, where we transform the raw data into
a usable form. A major difference compared to the current stable datasets is
......@@ -178,28 +168,6 @@ contains. You can also do that with `resources_dp[1]` or `resources_dp[2]`
(etc.) if they exist. Then follow the instructions above to manipulate these
datapipes and return the appropriate dictionary format.
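Concretely, the decoding-free pattern used throughout this PR maps a `_prepare_sample` over the assembled datapipe and wraps encoded files in `EncodedImage`. A condensed sketch (names illustrative, module-level functions instead of methods):

```python
from typing import Any, BinaryIO, Dict, List, Tuple

from torchdata.datapipes.iter import IterDataPipe, Mapper
from torchvision.prototype.datasets.utils import DatasetConfig
from torchvision.prototype.features import EncodedImage

def _prepare_sample(data: Tuple[str, BinaryIO]) -> Dict[str, Any]:
    path, buffer = data
    return dict(path=path, image=EncodedImage.from_file(buffer))

def _make_datapipe(
    resource_dps: List[IterDataPipe], *, config: DatasetConfig
) -> IterDataPipe[Dict[str, Any]]:
    return Mapper(resource_dps[0], _prepare_sample)
```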
### What is the `DatasetType.RAW` and when do I use it?
`DatasetType.RAW` marks datasets that provide decoded data, i.e. raw pixel values,
rather than encoded image files such as `.jpg` or `.png`. This is usually only
the case for small datasets, since it requires a lot more disk space. The
default decoder `datasets.decoder.raw` is only a sentinel and should not be
called directly. The decoding should look something like
```python
from torchvision.prototype.datasets.decoder import raw
image = ...
if decoder is raw:
image = Image(image)
else:
image_buffer = image_buffer_from_raw(image)
image = decoder(image_buffer) if decoder else image_buffer
```
For examples, have a look at the MNIST, CIFAR, or SEMEION datasets.
### How do I handle a dataset that defines many categories?
As a rule of thumb, `datasets.utils.DatasetInfo(..., categories=)` should only
......
import functools
import io
import pathlib
import re
from typing import Any, Callable, Dict, List, Optional, Tuple
from typing import Any, Dict, List, Tuple, BinaryIO
import numpy as np
import torch
from torchdata.datapipes.iter import (
IterDataPipe,
Mapper,
......@@ -18,17 +15,15 @@ from torchvision.prototype.datasets.utils import (
DatasetInfo,
HttpResource,
OnlineResource,
DatasetType,
)
from torchvision.prototype.datasets.utils._internal import INFINITE_BUFFER_SIZE, read_mat, hint_sharding, hint_shuffling
from torchvision.prototype.features import Label, BoundingBox, Feature
from torchvision.prototype.features import Label, BoundingBox, _Feature, EncodedImage
class Caltech101(Dataset):
def _make_info(self) -> DatasetInfo:
return DatasetInfo(
"caltech101",
type=DatasetType.IMAGE,
dependencies=("scipy",),
homepage="http://www.vision.caltech.edu/Image_Datasets/Caltech101",
)
......@@ -81,33 +76,26 @@ class Caltech101(Dataset):
return category, id
def _collate_and_decode_sample(
self,
data: Tuple[Tuple[str, str], Tuple[Tuple[str, io.IOBase], Tuple[str, io.IOBase]]],
*,
decoder: Optional[Callable[[io.IOBase], torch.Tensor]],
def _prepare_sample(
self, data: Tuple[Tuple[str, str], Tuple[Tuple[str, BinaryIO], Tuple[str, BinaryIO]]]
) -> Dict[str, Any]:
key, (image_data, ann_data) = data
category, _ = key
image_path, image_buffer = image_data
ann_path, ann_buffer = ann_data
label = self.info.categories.index(category)
image = decoder(image_buffer) if decoder else image_buffer
image = EncodedImage.from_file(image_buffer)
ann = read_mat(ann_buffer)
bbox = BoundingBox(ann["box_coord"].astype(np.int64).squeeze()[[2, 0, 3, 1]], format="xyxy")
contour = Feature(ann["obj_contour"].T)
return dict(
category=category,
label=label,
image=image,
label=Label.from_category(category, categories=self.categories),
image_path=image_path,
bbox=bbox,
contour=contour,
image=image,
ann_path=ann_path,
bounding_box=BoundingBox(
ann["box_coord"].astype(np.int64).squeeze()[[2, 0, 3, 1]], format="xyxy", image_size=image.image_size
),
contour=_Feature(ann["obj_contour"].T),
)
def _make_datapipe(
......@@ -115,7 +103,6 @@ class Caltech101(Dataset):
resource_dps: List[IterDataPipe],
*,
config: DatasetConfig,
decoder: Optional[Callable[[io.IOBase], torch.Tensor]],
) -> IterDataPipe[Dict[str, Any]]:
images_dp, anns_dp = resource_dps
......@@ -133,7 +120,7 @@ class Caltech101(Dataset):
buffer_size=INFINITE_BUFFER_SIZE,
keep_key=True,
)
return Mapper(dp, functools.partial(self._collate_and_decode_sample, decoder=decoder))
return Mapper(dp, self._prepare_sample)
def _generate_categories(self, root: pathlib.Path) -> List[str]:
resources = self.resources(self.default_config)
......@@ -148,7 +135,6 @@ class Caltech256(Dataset):
def _make_info(self) -> DatasetInfo:
return DatasetInfo(
"caltech256",
type=DatasetType.IMAGE,
homepage="http://www.vision.caltech.edu/Image_Datasets/Caltech256",
)
......@@ -164,32 +150,26 @@ class Caltech256(Dataset):
path = pathlib.Path(data[0])
return path.name != "RENAME2"
def _collate_and_decode_sample(
self,
data: Tuple[str, io.IOBase],
*,
decoder: Optional[Callable[[io.IOBase], torch.Tensor]],
) -> Dict[str, Any]:
def _prepare_sample(self, data: Tuple[str, BinaryIO]) -> Dict[str, Any]:
path, buffer = data
dir_name = pathlib.Path(path).parent.name
label_str, category = dir_name.split(".")
label = Label(int(label_str), category=category)
return dict(label=label, image=decoder(buffer) if decoder else buffer)
return dict(
path=path,
image=EncodedImage.from_file(buffer),
label=Label(int(pathlib.Path(path).parent.name.split(".", 1)[0]) - 1, categories=self.categories),
)
def _make_datapipe(
self,
resource_dps: List[IterDataPipe],
*,
config: DatasetConfig,
decoder: Optional[Callable[[io.IOBase], torch.Tensor]],
) -> IterDataPipe[Dict[str, Any]]:
dp = resource_dps[0]
dp = Filter(dp, self._is_not_rogue_file)
dp = hint_sharding(dp)
dp = hint_shuffling(dp)
return Mapper(dp, functools.partial(self._collate_and_decode_sample, decoder=decoder))
return Mapper(dp, self._prepare_sample)
def _generate_categories(self, root: pathlib.Path) -> List[str]:
resources = self.resources(self.default_config)
......
import csv
import functools
import io
from typing import Any, Callable, Dict, List, Optional, Tuple, Iterator, Sequence
from typing import Any, Dict, List, Optional, Tuple, Iterator, Sequence, BinaryIO
import torch
from torchdata.datapipes.iter import (
IterDataPipe,
Mapper,
......@@ -17,7 +15,6 @@ from torchvision.prototype.datasets.utils import (
DatasetInfo,
GDriveResource,
OnlineResource,
DatasetType,
)
from torchvision.prototype.datasets.utils._internal import (
INFINITE_BUFFER_SIZE,
......@@ -26,7 +23,8 @@ from torchvision.prototype.datasets.utils._internal import (
hint_sharding,
hint_shuffling,
)
from torchvision.prototype.features import Feature, Label, BoundingBox
from torchvision.prototype.features import EncodedImage, _Feature, Label, BoundingBox
csv.register_dialect("celeba", delimiter=" ", skipinitialspace=True)
......@@ -34,7 +32,7 @@ csv.register_dialect("celeba", delimiter=" ", skipinitialspace=True)
class CelebACSVParser(IterDataPipe[Tuple[str, Dict[str, str]]]):
def __init__(
self,
datapipe: IterDataPipe[Tuple[Any, io.IOBase]],
datapipe: IterDataPipe[Tuple[Any, BinaryIO]],
*,
fieldnames: Optional[Sequence[str]] = None,
) -> None:
......@@ -66,7 +64,6 @@ class CelebA(Dataset):
def _make_info(self) -> DatasetInfo:
return DatasetInfo(
"celeba",
type=DatasetType.IMAGE,
homepage="https://mmlab.ie.cuhk.edu.hk/projects/CelebA.html",
valid_options=dict(split=("train", "val", "test")),
)
......@@ -92,7 +89,7 @@ class CelebA(Dataset):
sha256="f0e5da289d5ccf75ffe8811132694922b60f2af59256ed362afa03fefba324d0",
file_name="list_attr_celeba.txt",
)
bboxes = GDriveResource(
bounding_boxes = GDriveResource(
"0B7EVK8r0v71pbThiMVRxWXZ4dU0",
sha256="7487a82e57c4bb956c5445ae2df4a91ffa717e903c5fa22874ede0820c8ec41b",
file_name="list_bbox_celeba.txt",
......@@ -102,7 +99,7 @@ class CelebA(Dataset):
sha256="6c02a87569907f6db2ba99019085697596730e8129f67a3d61659f198c48d43b",
file_name="list_landmarks_align_celeba.txt",
)
return [splits, images, identities, attributes, bboxes, landmarks]
return [splits, images, identities, attributes, bounding_boxes, landmarks]
_SPLIT_ID_TO_NAME = {
"0": "train",
......@@ -113,38 +110,39 @@ class CelebA(Dataset):
def _filter_split(self, data: Tuple[str, Dict[str, str]], *, split: str) -> bool:
return self._SPLIT_ID_TO_NAME[data[1]["split_id"]] == split
def _collate_anns(self, data: Tuple[Tuple[str, Dict[str, str]], ...]) -> Tuple[str, Dict[str, Dict[str, str]]]:
(image_id, identity), (_, attributes), (_, bbox), (_, landmarks) = data
return image_id, dict(identity=identity, attributes=attributes, bbox=bbox, landmarks=landmarks)
def _collate_and_decode_sample(
def _prepare_sample(
self,
data: Tuple[Tuple[str, Tuple[Tuple[str, Dict[str, Any]], Tuple[str, io.IOBase]]], Tuple[str, Dict[str, Any]]],
*,
decoder: Optional[Callable[[io.IOBase], torch.Tensor]],
data: Tuple[
Tuple[str, Tuple[Tuple[str, List[str]], Tuple[str, BinaryIO]]],
Tuple[
Tuple[str, Dict[str, str]],
Tuple[str, Dict[str, str]],
Tuple[str, Dict[str, str]],
Tuple[str, Dict[str, str]],
],
],
) -> Dict[str, Any]:
split_and_image_data, ann_data = data
_, (_, image_data) = split_and_image_data
path, buffer = image_data
_, ann = ann_data
image = decoder(buffer) if decoder else buffer
identity = Label(int(ann["identity"]["identity"]))
attributes = {attr: value == "1" for attr, value in ann["attributes"].items()}
bbox = BoundingBox([int(ann["bbox"][key]) for key in ("x_1", "y_1", "width", "height")])
landmarks = {
landmark: Feature((int(ann["landmarks"][f"{landmark}_x"]), int(ann["landmarks"][f"{landmark}_y"])))
for landmark in {key[:-2] for key in ann["landmarks"].keys()}
}
image = EncodedImage.from_file(buffer)
(_, identity), (_, attributes), (_, bounding_box), (_, landmarks) = ann_data
return dict(
path=path,
image=image,
identity=identity,
attributes=attributes,
bbox=bbox,
landmarks=landmarks,
identity=Label(int(identity["identity"])),
attributes={attr: value == "1" for attr, value in attributes.items()},
bounding_box=BoundingBox(
[int(bounding_box[key]) for key in ("x_1", "y_1", "width", "height")],
format="xywh",
image_size=image.image_size,
),
landmarks={
landmark: _Feature((int(landmarks[f"{landmark}_x"]), int(landmarks[f"{landmark}_y"])))
for landmark in {key[:-2] for key in landmarks.keys()}
},
)
def _make_datapipe(
......@@ -152,9 +150,8 @@ class CelebA(Dataset):
resource_dps: List[IterDataPipe],
*,
config: DatasetConfig,
decoder: Optional[Callable[[io.IOBase], torch.Tensor]],
) -> IterDataPipe[Dict[str, Any]]:
splits_dp, images_dp, identities_dp, attributes_dp, bboxes_dp, landmarks_dp = resource_dps
splits_dp, images_dp, identities_dp, attributes_dp, bounding_boxes_dp, landmarks_dp = resource_dps
splits_dp = CelebACSVParser(splits_dp, fieldnames=("image_id", "split_id"))
splits_dp = Filter(splits_dp, functools.partial(self._filter_split, split=config.split))
......@@ -167,12 +164,11 @@ class CelebA(Dataset):
for dp, fieldnames in (
(identities_dp, ("image_id", "identity")),
(attributes_dp, None),
(bboxes_dp, None),
(bounding_boxes_dp, None),
(landmarks_dp, None),
)
]
)
anns_dp = Mapper(anns_dp, self._collate_anns)
dp = IterKeyZipper(
splits_dp,
......@@ -182,5 +178,11 @@ class CelebA(Dataset):
buffer_size=INFINITE_BUFFER_SIZE,
keep_key=True,
)
dp = IterKeyZipper(dp, anns_dp, key_fn=getitem(0), buffer_size=INFINITE_BUFFER_SIZE)
return Mapper(dp, functools.partial(self._collate_and_decode_sample, decoder=decoder))
dp = IterKeyZipper(
dp,
anns_dp,
key_fn=getitem(0),
ref_key_fn=getitem(0, 0),
buffer_size=INFINITE_BUFFER_SIZE,
)
return Mapper(dp, self._prepare_sample)
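For intuition, the `getitem` helpers build nested-index key functions, so both zipper stages match on the image id. Rough equivalents (not the library code):

```python
def key_fn(sample):       # getitem(0)
    return sample[0]

def ref_key_fn(sample):   # getitem(0, 0)
    return sample[0][0]
```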
......@@ -3,34 +3,28 @@ import functools
import io
import pathlib
import pickle
from typing import Any, Callable, Dict, List, Optional, Tuple, Union, Iterator, cast
from typing import Any, Dict, List, Optional, Tuple, Iterator, cast, BinaryIO
import numpy as np
import torch
from torchdata.datapipes.iter import (
IterDataPipe,
Filter,
Mapper,
)
from torchvision.prototype.datasets.decoder import raw
from torchvision.prototype.datasets.utils import (
Dataset,
DatasetConfig,
DatasetInfo,
HttpResource,
OnlineResource,
DatasetType,
)
from torchvision.prototype.datasets.utils._internal import (
hint_shuffling,
image_buffer_from_array,
path_comparator,
hint_sharding,
)
from torchvision.prototype.features import Label, Image
__all__ = ["Cifar10", "Cifar100"]
class CifarFileReader(IterDataPipe[Tuple[np.ndarray, int]]):
def __init__(self, datapipe: IterDataPipe[Dict[str, Any]], *, labels_key: str) -> None:
......@@ -52,13 +46,12 @@ class _CifarBase(Dataset):
_CATEGORIES_KEY: str
@abc.abstractmethod
def _is_data_file(self, data: Tuple[str, io.IOBase], *, split: str) -> Optional[int]:
def _is_data_file(self, data: Tuple[str, BinaryIO], *, split: str) -> Optional[int]:
pass
def _make_info(self) -> DatasetInfo:
return DatasetInfo(
type(self).__name__.lower(),
type=DatasetType.RAW,
homepage="https://www.cs.toronto.edu/~kriz/cifar.html",
valid_options=dict(split=("train", "test")),
)
......@@ -75,31 +68,18 @@ class _CifarBase(Dataset):
_, file = data
return cast(Dict[str, Any], pickle.load(file, encoding="latin1"))
def _collate_and_decode(
self,
data: Tuple[np.ndarray, int],
*,
decoder: Optional[Callable[[io.IOBase], torch.Tensor]],
) -> Dict[str, Any]:
def _prepare_sample(self, data: Tuple[np.ndarray, int]) -> Dict[str, Any]:
image_array, category_idx = data
image: Union[Image, io.BytesIO]
if decoder is raw:
image = Image(image_array)
else:
image_buffer = image_buffer_from_array(image_array.transpose((1, 2, 0)))
image = decoder(image_buffer) if decoder else image_buffer # type: ignore[assignment]
label = Label(category_idx, category=self.categories[category_idx])
return dict(image=image, label=label)
return dict(
image=Image(image_array),
label=Label(category_idx, categories=self.categories),
)
def _make_datapipe(
self,
resource_dps: List[IterDataPipe],
*,
config: DatasetConfig,
decoder: Optional[Callable[[io.IOBase], torch.Tensor]],
) -> IterDataPipe[Dict[str, Any]]:
dp = resource_dps[0]
dp = Filter(dp, functools.partial(self._is_data_file, split=config.split))
......@@ -107,7 +87,7 @@ class _CifarBase(Dataset):
dp = CifarFileReader(dp, labels_key=self._LABELS_KEY)
dp = hint_sharding(dp)
dp = hint_shuffling(dp)
return Mapper(dp, functools.partial(self._collate_and_decode, decoder=decoder))
return Mapper(dp, self._prepare_sample)
def _generate_categories(self, root: pathlib.Path) -> List[str]:
resources = self.resources(self.default_config)
......
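CIFAR ships raw pixel arrays inside pickles, so `_prepare_sample` wraps them in an `Image` directly; datasets that distribute encoded files keep the undecoded bytes in an `EncodedImage` instead. A sketch of the distinction (file path hypothetical):

```python
import numpy as np
import torch
from torchvision.prototype.features import EncodedImage, Image

# Raw pixel data (e.g. from a CIFAR pickle): wrap the array directly.
image = Image(torch.from_numpy(np.zeros((3, 32, 32), dtype=np.uint8)))

# Encoded file: keep the bytes as-is and decode later in the pipeline.
with open("img.jpg", "rb") as buffer:  # hypothetical file
    encoded = EncodedImage.from_file(buffer)
```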
import functools
import io
import pathlib
from typing import Any, Callable, Dict, List, Optional, Tuple
from typing import Any, Dict, List, Optional, Tuple, BinaryIO
import torch
from torchdata.datapipes.iter import IterDataPipe, Mapper, Filter, IterKeyZipper, Demultiplexer, JsonParser, UnBatcher
from torchvision.prototype.datasets.utils import (
Dataset,
......@@ -11,7 +8,6 @@ from torchvision.prototype.datasets.utils import (
DatasetInfo,
HttpResource,
OnlineResource,
DatasetType,
)
from torchvision.prototype.datasets.utils._internal import (
INFINITE_BUFFER_SIZE,
......@@ -21,14 +17,13 @@ from torchvision.prototype.datasets.utils._internal import (
path_accessor,
getitem,
)
from torchvision.prototype.features import Label
from torchvision.prototype.features import Label, EncodedImage
class CLEVR(Dataset):
def _make_info(self) -> DatasetInfo:
return DatasetInfo(
"clevr",
type=DatasetType.IMAGE,
homepage="https://cs.stanford.edu/people/jcjohns/clevr/",
valid_options=dict(split=("train", "val", "test")),
)
......@@ -53,21 +48,16 @@ class CLEVR(Dataset):
key, _ = data
return key == "scenes"
def _add_empty_anns(self, data: Tuple[str, io.IOBase]) -> Tuple[Tuple[str, io.IOBase], None]:
def _add_empty_anns(self, data: Tuple[str, BinaryIO]) -> Tuple[Tuple[str, BinaryIO], None]:
return data, None
def _collate_and_decode_sample(
self,
data: Tuple[Tuple[str, io.IOBase], Optional[Dict[str, Any]]],
*,
decoder: Optional[Callable[[io.IOBase], torch.Tensor]],
) -> Dict[str, Any]:
def _prepare_sample(self, data: Tuple[Tuple[str, BinaryIO], Optional[Dict[str, Any]]]) -> Dict[str, Any]:
image_data, scenes_data = data
path, buffer = image_data
return dict(
path=path,
image=decoder(buffer) if decoder else buffer,
image=EncodedImage.from_file(buffer),
label=Label(len(scenes_data["objects"])) if scenes_data else None,
)
......@@ -76,7 +66,6 @@ class CLEVR(Dataset):
resource_dps: List[IterDataPipe],
*,
config: DatasetConfig,
decoder: Optional[Callable[[io.IOBase], torch.Tensor]],
) -> IterDataPipe[Dict[str, Any]]:
archive_dp = resource_dps[0]
images_dp, scenes_dp = Demultiplexer(
......@@ -107,4 +96,4 @@ class CLEVR(Dataset):
else:
dp = Mapper(images_dp, self._add_empty_anns)
return Mapper(dp, functools.partial(self._collate_and_decode_sample, decoder=decoder))
return Mapper(dp, self._prepare_sample)
import functools
import io
import pathlib
import re
from collections import OrderedDict
from typing import Any, Callable, Dict, List, Optional, Tuple, cast
from typing import Any, Dict, List, Optional, Tuple, cast, BinaryIO
import torch
from torchdata.datapipes.iter import (
......@@ -22,7 +21,6 @@ from torchvision.prototype.datasets.utils import (
DatasetInfo,
HttpResource,
OnlineResource,
DatasetType,
)
from torchvision.prototype.datasets.utils._internal import (
MappingIterator,
......@@ -33,7 +31,7 @@ from torchvision.prototype.datasets.utils._internal import (
hint_sharding,
hint_shuffling,
)
from torchvision.prototype.features import BoundingBox, Label, Feature
from torchvision.prototype.features import BoundingBox, Label, _Feature, EncodedImage
from torchvision.prototype.utils._internal import FrozenMapping
......@@ -44,7 +42,6 @@ class Coco(Dataset):
return DatasetInfo(
name,
type=DatasetType.IMAGE,
dependencies=("pycocotools",),
categories=categories,
homepage="https://cocodataset.org/",
......@@ -96,10 +93,9 @@ class Coco(Dataset):
def _decode_instances_anns(self, anns: List[Dict[str, Any]], image_meta: Dict[str, Any]) -> Dict[str, Any]:
image_size = (image_meta["height"], image_meta["width"])
labels = [ann["category_id"] for ann in anns]
categories = [self.info.categories[label] for label in labels]
return dict(
# TODO: create a segmentation feature
segmentations=Feature(
segmentations=_Feature(
torch.stack(
[
self._segmentation_to_mask(ann["segmentation"], is_crowd=ann["iscrowd"], image_size=image_size)
......@@ -107,16 +103,17 @@ class Coco(Dataset):
]
)
),
areas=Feature([ann["area"] for ann in anns]),
crowds=Feature([ann["iscrowd"] for ann in anns], dtype=torch.bool),
areas=_Feature([ann["area"] for ann in anns]),
crowds=_Feature([ann["iscrowd"] for ann in anns], dtype=torch.bool),
bounding_boxes=BoundingBox(
[ann["bbox"] for ann in anns],
format="xywh",
image_size=image_size,
),
labels=Label(labels),
categories=categories,
super_categories=[self.info.extra.category_to_super_category[category] for category in categories],
labels=Label(labels, categories=self.categories),
super_categories=[
self.info.extra.category_to_super_category[self.info.categories[label]] for label in labels
],
ann_ids=[ann["id"] for ann in anns],
)
......@@ -150,26 +147,24 @@ class Coco(Dataset):
else:
return None
def _collate_and_decode_image(
self, data: Tuple[str, io.IOBase], *, decoder: Optional[Callable[[io.IOBase], torch.Tensor]]
) -> Dict[str, Any]:
def _prepare_image(self, data: Tuple[str, BinaryIO]) -> Dict[str, Any]:
path, buffer = data
return dict(path=path, image=decoder(buffer) if decoder else buffer)
return dict(
path=path,
image=EncodedImage.from_file(buffer),
)
def _collate_and_decode_sample(
def _prepare_sample(
self,
data: Tuple[Tuple[List[Dict[str, Any]], Dict[str, Any]], Tuple[str, io.IOBase]],
data: Tuple[Tuple[List[Dict[str, Any]], Dict[str, Any]], Tuple[str, BinaryIO]],
*,
annotations: Optional[str],
decoder: Optional[Callable[[io.IOBase], torch.Tensor]],
annotations: str,
) -> Dict[str, Any]:
ann_data, image_data = data
anns, image_meta = ann_data
sample = self._collate_and_decode_image(image_data, decoder=decoder)
if annotations:
sample.update(self._ANN_DECODERS[annotations](self, anns, image_meta))
sample = self._prepare_image(image_data)
sample.update(self._ANN_DECODERS[annotations](self, anns, image_meta))
return sample
def _make_datapipe(
......@@ -177,14 +172,13 @@ class Coco(Dataset):
resource_dps: List[IterDataPipe],
*,
config: DatasetConfig,
decoder: Optional[Callable[[io.IOBase], torch.Tensor]],
) -> IterDataPipe[Dict[str, Any]]:
images_dp, meta_dp = resource_dps
if config.annotations is None:
dp = hint_sharding(images_dp)
dp = hint_shuffling(dp)
return Mapper(dp, functools.partial(self._collate_and_decode_image, decoder=decoder))
return Mapper(dp, self._prepare_image)
meta_dp = Filter(
meta_dp,
......@@ -230,9 +224,8 @@ class Coco(Dataset):
ref_key_fn=path_accessor("name"),
buffer_size=INFINITE_BUFFER_SIZE,
)
return Mapper(
dp, functools.partial(self._collate_and_decode_sample, annotations=config.annotations, decoder=decoder)
)
return Mapper(dp, functools.partial(self._prepare_sample, annotations=config.annotations))
def _generate_categories(self, root: pathlib.Path) -> Tuple[Tuple[str, str]]:
config = self.default_config
......
import csv
import functools
import io
import pathlib
from typing import Any, Callable, Dict, List, Optional, Tuple
from typing import Any, Dict, List, Optional, Tuple, BinaryIO, Callable
import torch
from torchdata.datapipes.iter import (
IterDataPipe,
Mapper,
......@@ -21,7 +19,6 @@ from torchvision.prototype.datasets.utils import (
DatasetInfo,
HttpResource,
OnlineResource,
DatasetType,
)
from torchvision.prototype.datasets.utils._internal import (
INFINITE_BUFFER_SIZE,
......@@ -32,7 +29,7 @@ from torchvision.prototype.datasets.utils._internal import (
path_comparator,
path_accessor,
)
from torchvision.prototype.features import Label, BoundingBox, Feature
from torchvision.prototype.features import Label, BoundingBox, _Feature, EncodedImage
csv.register_dialect("cub200", delimiter=" ")
......@@ -41,7 +38,6 @@ class CUB200(Dataset):
def _make_info(self) -> DatasetInfo:
return DatasetInfo(
"cub200",
type=DatasetType.IMAGE,
homepage="http://www.vision.caltech.edu/visipedia/CUB-200-2011.html",
dependencies=("scipy",),
valid_options=dict(
......@@ -105,58 +101,55 @@ class CUB200(Dataset):
path = pathlib.Path(data[0])
return path.with_suffix(".jpg").name
def _2011_load_ann(
self,
data: Tuple[str, Tuple[List[str], Tuple[str, io.IOBase]]],
*,
decoder: Optional[Callable[[io.IOBase], torch.Tensor]],
def _2011_prepare_ann(
self, data: Tuple[str, Tuple[List[str], Tuple[str, BinaryIO]]], image_size: Tuple[int, int]
) -> Dict[str, Any]:
_, (bounding_box_data, segmentation_data) = data
segmentation_path, segmentation_buffer = segmentation_data
return dict(
bounding_box=BoundingBox([float(part) for part in bounding_box_data[1:]], format="xywh"),
bounding_box=BoundingBox(
[float(part) for part in bounding_box_data[1:]], format="xywh", image_size=image_size
),
segmentation_path=segmentation_path,
segmentation=Feature(decoder(segmentation_buffer)) if decoder else segmentation_buffer,
segmentation=EncodedImage.from_file(segmentation_buffer),
)
def _2010_split_key(self, data: str) -> str:
return data.rsplit("/", maxsplit=1)[1]
def _2010_anns_key(self, data: Tuple[str, io.IOBase]) -> Tuple[str, Tuple[str, io.IOBase]]:
def _2010_anns_key(self, data: Tuple[str, BinaryIO]) -> Tuple[str, Tuple[str, BinaryIO]]:
path = pathlib.Path(data[0])
return path.with_suffix(".jpg").name, data
def _2010_load_ann(
self, data: Tuple[str, Tuple[str, io.IOBase]], *, decoder: Optional[Callable[[io.IOBase], torch.Tensor]]
) -> Dict[str, Any]:
def _2010_prepare_ann(self, data: Tuple[str, Tuple[str, BinaryIO]], image_size: Tuple[int, int]) -> Dict[str, Any]:
_, (path, buffer) = data
content = read_mat(buffer)
return dict(
ann_path=path,
bounding_box=BoundingBox(
[int(content["bbox"][coord]) for coord in ("left", "bottom", "right", "top")], format="xyxy"
[int(content["bbox"][coord]) for coord in ("left", "bottom", "right", "top")],
format="xyxy",
image_size=image_size,
),
segmentation=Feature(content["seg"]),
segmentation=_Feature(content["seg"]),
)
def _collate_and_decode_sample(
def _prepare_sample(
self,
data: Tuple[Tuple[str, Tuple[str, io.IOBase]], Any],
data: Tuple[Tuple[str, Tuple[str, BinaryIO]], Any],
*,
year: str,
decoder: Optional[Callable[[io.IOBase], torch.Tensor]],
prepare_ann_fn: Callable[[Any, Tuple[int, int]], Dict[str, Any]],
) -> Dict[str, Any]:
data, anns_data = data
_, image_data = data
path, buffer = image_data
dir_name = pathlib.Path(path).parent.name
label_str, category = dir_name.split(".")
image = EncodedImage.from_file(buffer)
return dict(
(self._2011_load_ann if year == "2011" else self._2010_load_ann)(anns_data, decoder=decoder),
image=decoder(buffer) if decoder else buffer,
label=Label(int(label_str), category=category),
prepare_ann_fn(anns_data, image.image_size),
image=image,
label=Label(int(pathlib.Path(path).parent.name.rsplit(".", 1)[0]), categories=self.categories),
)
def _make_datapipe(
......@@ -164,8 +157,8 @@ class CUB200(Dataset):
resource_dps: List[IterDataPipe],
*,
config: DatasetConfig,
decoder: Optional[Callable[[io.IOBase], torch.Tensor]],
) -> IterDataPipe[Dict[str, Any]]:
prepare_ann_fn: Callable
if config.year == "2011":
archive_dp, segmentations_dp = resource_dps
images_dp, split_dp, image_files_dp, bounding_boxes_dp = Demultiplexer(
......@@ -193,6 +186,8 @@ class CUB200(Dataset):
keep_key=True,
buffer_size=INFINITE_BUFFER_SIZE,
)
prepare_ann_fn = self._2011_prepare_ann
else: # config.year == "2010"
split_dp, images_dp, anns_dp = resource_dps
......@@ -202,6 +197,8 @@ class CUB200(Dataset):
anns_dp = Mapper(anns_dp, self._2010_anns_key)
prepare_ann_fn = self._2010_prepare_ann
split_dp = hint_sharding(split_dp)
split_dp = hint_shuffling(split_dp)
......@@ -218,7 +215,7 @@ class CUB200(Dataset):
getitem(0),
buffer_size=INFINITE_BUFFER_SIZE,
)
return Mapper(dp, functools.partial(self._collate_and_decode_sample, year=config.year, decoder=decoder))
return Mapper(dp, functools.partial(self._prepare_sample, prepare_ann_fn=prepare_ann_fn))
def _generate_categories(self, root: pathlib.Path) -> List[str]:
config = self.info.make_config(year="2011")
......
import enum
import functools
import io
import pathlib
from typing import Any, Callable, Dict, List, Optional, Tuple
from typing import Any, Dict, List, Optional, Tuple, BinaryIO
import torch
from torchdata.datapipes.iter import (
IterDataPipe,
Mapper,
......@@ -21,7 +18,6 @@ from torchvision.prototype.datasets.utils import (
DatasetInfo,
HttpResource,
OnlineResource,
DatasetType,
)
from torchvision.prototype.datasets.utils._internal import (
INFINITE_BUFFER_SIZE,
......@@ -29,7 +25,7 @@ from torchvision.prototype.datasets.utils._internal import (
path_comparator,
getitem,
)
from torchvision.prototype.features import Label
from torchvision.prototype.features import Label, EncodedImage
class DTDDemux(enum.IntEnum):
......@@ -42,7 +38,6 @@ class DTD(Dataset):
def _make_info(self) -> DatasetInfo:
return DatasetInfo(
"dtd",
type=DatasetType.IMAGE,
homepage="https://www.robots.ox.ac.uk/~vgg/data/dtd/",
valid_options=dict(
split=("train", "test", "val"),
......@@ -75,12 +70,7 @@ class DTD(Dataset):
# The split files contain hardcoded posix paths for the images, e.g. banded/banded_0001.jpg
return str(path.relative_to(path.parents[1]).as_posix())
def _collate_and_decode_sample(
self,
data: Tuple[Tuple[str, List[str]], Tuple[str, io.IOBase]],
*,
decoder: Optional[Callable[[io.IOBase], torch.Tensor]],
) -> Dict[str, Any]:
def _prepare_sample(self, data: Tuple[Tuple[str, List[str]], Tuple[str, BinaryIO]]) -> Dict[str, Any]:
(_, joint_categories_data), image_data = data
_, *joint_categories = joint_categories_data
path, buffer = image_data
......@@ -89,9 +79,9 @@ class DTD(Dataset):
return dict(
joint_categories={category for category in joint_categories if category},
label=Label(self.info.categories.index(category), category=category),
label=Label.from_category(category, categories=self.categories),
path=path,
image=decoder(buffer) if decoder else buffer,
image=EncodedImage.from_file(buffer),
)
def _make_datapipe(
......@@ -99,7 +89,6 @@ class DTD(Dataset):
resource_dps: List[IterDataPipe],
*,
config: DatasetConfig,
decoder: Optional[Callable[[io.IOBase], torch.Tensor]],
) -> IterDataPipe[Dict[str, Any]]:
archive_dp = resource_dps[0]
......@@ -128,7 +117,7 @@ class DTD(Dataset):
ref_key_fn=self._image_key_fn,
buffer_size=INFINITE_BUFFER_SIZE,
)
return Mapper(dp, functools.partial(self._collate_and_decode_sample, decoder=decoder))
return Mapper(dp, self._prepare_sample)
def _filter_images(self, data: Tuple[str, Any]) -> bool:
return self._classify_archive(data) == DTDDemux.IMAGES
......
import functools
import io
from typing import Any, Callable, Dict, List, Optional, Union, cast
from typing import Any, Dict, List, cast
import torch
from torchdata.datapipes.iter import IterDataPipe, Mapper, CSVDictParser
from torchvision.prototype.datasets.decoder import raw
from torchvision.prototype.datasets.utils import (
Dataset,
DatasetConfig,
DatasetInfo,
OnlineResource,
DatasetType,
KaggleDownloadResource,
)
from torchvision.prototype.datasets.utils._internal import (
hint_sharding,
hint_shuffling,
image_buffer_from_array,
)
from torchvision.prototype.features import Label, Image
......@@ -25,7 +20,6 @@ class FER2013(Dataset):
def _make_info(self) -> DatasetInfo:
return DatasetInfo(
"fer2013",
type=DatasetType.RAW,
homepage="https://www.kaggle.com/c/challenges-in-representation-learning-facial-expression-recognition-challenge",
categories=("angry", "disgust", "fear", "happy", "sad", "surprise", "neutral"),
valid_options=dict(split=("train", "test")),
......@@ -44,26 +38,12 @@ class FER2013(Dataset):
)
return [archive]
def _collate_and_decode_sample(
self,
data: Dict[str, Any],
*,
decoder: Optional[Callable[[io.IOBase], torch.Tensor]],
) -> Dict[str, Any]:
raw_image = torch.tensor([int(idx) for idx in data["pixels"].split()], dtype=torch.uint8).reshape(48, 48)
def _prepare_sample(self, data: Dict[str, Any]) -> Dict[str, Any]:
label_id = data.get("emotion")
label_idx = int(label_id) if label_id is not None else None
image: Union[Image, io.BytesIO]
if decoder is raw:
image = Image(raw_image)
else:
image_buffer = image_buffer_from_array(raw_image.numpy())
image = decoder(image_buffer) if decoder else image_buffer # type: ignore[assignment]
return dict(
image=image,
label=Label(label_idx, category=self.info.categories[label_idx]) if label_idx is not None else None,
image=Image(torch.tensor([int(idx) for idx in data["pixels"].split()], dtype=torch.uint8).reshape(48, 48)),
label=Label(int(label_id), categories=self.categories) if label_id is not None else None,
)
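
FER2013 ships as CSV rather than image files, so the sample preparation above decodes the "pixels" column directly. A minimal sketch with a synthetic row:

import torch

row = {"emotion": "3", "pixels": " ".join("0" for _ in range(48 * 48))}
# 2304 space-separated integers reshape into one 48x48 grayscale image
image = torch.tensor([int(v) for v in row["pixels"].split()], dtype=torch.uint8).reshape(48, 48)
label = int(row["emotion"]) if row.get("emotion") is not None else None
assert image.shape == (48, 48) and label == 3
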
def _make_datapipe(
......@@ -71,10 +51,9 @@ class FER2013(Dataset):
resource_dps: List[IterDataPipe],
*,
config: DatasetConfig,
decoder: Optional[Callable[[io.IOBase], torch.Tensor]],
) -> IterDataPipe[Dict[str, Any]]:
dp = resource_dps[0]
dp = CSVDictParser(dp)
dp = hint_sharding(dp)
dp = hint_shuffling(dp)
return Mapper(dp, functools.partial(self._collate_and_decode_sample, decoder=decoder))
return Mapper(dp, self._prepare_sample)
import io
import pathlib
from functools import partial
from typing import Any, Callable, Dict, List, Optional, Tuple
from typing import Any, Dict, List, Optional, Tuple
import torch
from torchdata.datapipes.iter import IterDataPipe, Mapper, Filter, CSVDictParser, Zipper, Demultiplexer
from torchvision.prototype.datasets.utils import (
Dataset,
DatasetConfig,
DatasetInfo,
OnlineResource,
DatasetType,
HttpResource,
)
from torchvision.prototype.datasets.utils._internal import (
......@@ -19,14 +15,13 @@ from torchvision.prototype.datasets.utils._internal import (
hint_shuffling,
INFINITE_BUFFER_SIZE,
)
from torchvision.prototype.features import Label, BoundingBox
from torchvision.prototype.features import Label, BoundingBox, EncodedImage
class GTSRB(Dataset):
def _make_info(self) -> DatasetInfo:
return DatasetInfo(
"gtsrb",
type=DatasetType.IMAGE,
homepage="https://benchmark.ini.rub.de",
categories=[f"{label:05d}" for label in range(43)],
valid_options=dict(split=("train", "test")),
......@@ -66,33 +61,26 @@ class GTSRB(Dataset):
else:
return None
def _collate_and_decode(
self, data: Tuple[Tuple[str, Any], Dict[str, Any]], decoder: Optional[Callable[[io.IOBase], torch.Tensor]]
) -> Dict[str, Any]:
(image_path, image_buffer), csv_info = data
def _prepare_sample(self, data: Tuple[Tuple[str, Any], Dict[str, Any]]) -> Dict[str, Any]:
(path, buffer), csv_info = data
label = int(csv_info["ClassId"])
bbox = BoundingBox(
torch.tensor([int(csv_info[k]) for k in ("Roi.X1", "Roi.Y1", "Roi.X2", "Roi.Y2")]),
bounding_box = BoundingBox(
[int(csv_info[k]) for k in ("Roi.X1", "Roi.Y1", "Roi.X2", "Roi.Y2")],
format="xyxy",
image_size=(int(csv_info["Height"]), int(csv_info["Width"])),
)
return {
"image_path": image_path,
"image": decoder(image_buffer) if decoder else image_buffer,
"label": Label(label, category=self.categories[label]),
"bbox": bbox,
"path": path,
"image": EncodedImage.from_file(buffer),
"label": Label(label, categories=self.categories),
"bounding_box": bounding_box,
}
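
A minimal sketch of the ROI parsing above with a synthetic csv row (values invented): the four corner columns map directly onto the "xyxy" format, and the image size is given as (height, width):

csv_info = {"ClassId": "7", "Roi.X1": "5", "Roi.Y1": "6", "Roi.X2": "44", "Roi.Y2": "45",
            "Width": "50", "Height": "50"}
xyxy = [int(csv_info[k]) for k in ("Roi.X1", "Roi.Y1", "Roi.X2", "Roi.Y2")]
image_size = (int(csv_info["Height"]), int(csv_info["Width"]))
assert xyxy == [5, 6, 44, 45] and image_size == (50, 50)
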
def _make_datapipe(
self,
resource_dps: List[IterDataPipe],
*,
config: DatasetConfig,
decoder: Optional[Callable[[io.IOBase], torch.Tensor]],
self, resource_dps: List[IterDataPipe], *, config: DatasetConfig
) -> IterDataPipe[Dict[str, Any]]:
if config.split == "train":
images_dp, ann_dp = Demultiplexer(
resource_dps[0], 2, self._classify_train_archive, drop_none=True, buffer_size=INFINITE_BUFFER_SIZE
......@@ -101,13 +89,12 @@ class GTSRB(Dataset):
images_dp, ann_dp = resource_dps
images_dp = Filter(images_dp, path_comparator("suffix", ".ppm"))
# The order of the image files in the the .zip archives perfectly match the order of the entries in
# the (possibly concatenated) .csv files. So we're able to use Zipper here instead of a IterKeyZipper.
# The order of the image files in the .zip archives perfectly matches the order of the entries in the
# (possibly concatenated) .csv files. So we're able to use Zipper here instead of an IterKeyZipper.
ann_dp = CSVDictParser(ann_dp, delimiter=";")
dp = Zipper(images_dp, ann_dp)
dp = hint_sharding(dp)
dp = hint_shuffling(dp)
dp = Mapper(dp, partial(self._collate_and_decode, decoder=decoder))
return dp
return Mapper(dp, self._prepare_sample)
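
The comment above is the load-bearing assumption: Zipper pairs the two streams purely by position. A minimal sketch of why that works when, and only when, both streams share one order:

images = ["00000.ppm", "00001.ppm", "00002.ppm"]
rows = [{"Filename": f"{i:05d}.ppm", "ClassId": str(i)} for i in range(3)]
for image, row in zip(images, rows):  # positional pairing, no key-based buffering
    assert image == row["Filename"]
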
import functools
import io
import pathlib
import re
from typing import Any, Callable, Dict, List, Optional, Tuple, cast
from typing import Any, Dict, List, Optional, Tuple, BinaryIO, Match, cast
import torch
from torchdata.datapipes.iter import IterDataPipe, LineReader, IterKeyZipper, Mapper, TarArchiveReader, Filter
from torchdata.datapipes.iter import IterDataPipe, LineReader, IterKeyZipper, Mapper, Filter, Demultiplexer
from torchdata.datapipes.iter import TarArchiveReader
from torchvision.prototype.datasets.utils import (
Dataset,
DatasetConfig,
DatasetInfo,
OnlineResource,
ManualDownloadResource,
DatasetType,
)
from torchvision.prototype.datasets.utils._internal import (
INFINITE_BUFFER_SIZE,
......@@ -24,7 +22,7 @@ from torchvision.prototype.datasets.utils._internal import (
hint_sharding,
hint_shuffling,
)
from torchvision.prototype.features import Label
from torchvision.prototype.features import Label, EncodedImage
from torchvision.prototype.utils._internal import FrozenMapping
......@@ -40,7 +38,6 @@ class ImageNet(Dataset):
return DatasetInfo(
name,
type=DatasetType.IMAGE,
dependencies=("scipy",),
categories=categories,
homepage="https://www.image-net.org/",
......@@ -61,14 +58,6 @@ class ImageNet(Dataset):
def supports_sharded(self) -> bool:
return True
@property
def category_to_wnid(self) -> Dict[str, str]:
return cast(Dict[str, str], self.info.extra.category_to_wnid)
@property
def wnid_to_category(self) -> Dict[str, str]:
return cast(Dict[str, str], self.info.extra.wnid_to_category)
_IMAGES_CHECKSUMS = {
"train": "b08200a27a8e34218a0e58fde36b0fe8f73bc377f4acea2d91602057c3ca45bb",
"val": "c7e06a6c0baccf06d8dbeb6577d71efff84673a5dbdd50633ab44f8ea0456ae0",
......@@ -77,23 +66,56 @@ class ImageNet(Dataset):
def resources(self, config: DatasetConfig) -> List[OnlineResource]:
name = "test_v10102019" if config.split == "test" else config.split
images = ImageNetResource(file_name=f"ILSVRC2012_img_{name}.tar", sha256=self._IMAGES_CHECKSUMS[name])
devkit = ImageNetResource(
file_name="ILSVRC2012_devkit_t12.tar.gz",
sha256="b59243268c0d266621fd587d2018f69e906fb22875aca0e295b48cafaa927953",
images = ImageNetResource(
file_name=f"ILSVRC2012_img_{name}.tar",
sha256=self._IMAGES_CHECKSUMS[name],
)
resources: List[OnlineResource] = [images]
if config.split == "val":
devkit = ImageNetResource(
file_name="ILSVRC2012_devkit_t12.tar.gz",
sha256="b59243268c0d266621fd587d2018f69e906fb22875aca0e295b48cafaa927953",
)
resources.append(devkit)
return resources
return [images, devkit]
def num_samples(self, config: DatasetConfig) -> int:
return {
"train": 1_281_167,
"val": 50_000,
"test": 100_000,
}[config.split]
_TRAIN_IMAGE_NAME_PATTERN = re.compile(r"(?P<wnid>n\d{8})_\d+[.]JPEG")
def _collate_train_data(self, data: Tuple[str, io.IOBase]) -> Tuple[Tuple[Label, str, str], Tuple[str, io.IOBase]]:
def _prepare_train_data(self, data: Tuple[str, BinaryIO]) -> Tuple[Tuple[Label, str], Tuple[str, BinaryIO]]:
path = pathlib.Path(data[0])
wnid = self._TRAIN_IMAGE_NAME_PATTERN.match(path.name).group("wnid") # type: ignore[union-attr]
category = self.wnid_to_category[wnid]
label_data = (Label(self.categories.index(category)), category, wnid)
return label_data, data
wnid = cast(Match[str], self._TRAIN_IMAGE_NAME_PATTERN.match(path.name))["wnid"]
label = Label.from_category(self.info.extra.wnid_to_category[wnid], categories=self.categories)
return (label, wnid), data
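
The cast(Match[str], ...) replaces the earlier type-ignore: the file-name pattern is trusted to match for every train image. A minimal runnable sketch of the extraction:

import re
from typing import Match, cast

pattern = re.compile(r"(?P<wnid>n\d{8})_\d+[.]JPEG")
# the wnid is encoded in the file name itself
wnid = cast(Match[str], pattern.match("n01440764_10026.JPEG"))["wnid"]
assert wnid == "n01440764"
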
def _prepare_test_data(self, data: Tuple[str, BinaryIO]) -> Tuple[None, Tuple[str, BinaryIO]]:
return None, data
def _classify_devkit(self, data: Tuple[str, BinaryIO]) -> Optional[int]:
return {
"meta.mat": 0,
"ILSVRC2012_validation_ground_truth.txt": 1,
}.get(pathlib.Path(data[0]).name)
def _extract_categories_and_wnids(self, data: Tuple[str, BinaryIO]) -> List[Tuple[str, str]]:
synsets = read_mat(data[1], squeeze_me=True)["synsets"]
return [
(self._WNID_MAP.get(wnid, category.split(",", 1)[0]), wnid)
for _, wnid, category, _, num_children, *_ in synsets
# if num_children > 0, we are looking at a superclass that has no direct instance
if num_children == 0
]
def _imagenet_label_to_wnid(self, imagenet_label: str, *, wnids: List[str]) -> str:
return wnids[int(imagenet_label) - 1]
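
A minimal sketch of the two devkit helpers above, with invented synset rows: only leaf synsets (num_children == 0) contribute categories, and an ILSVRC ground-truth label is a 1-based index into the resulting wnid list:

synsets = [
    (1, "n02012849", "crane", "the bird", 0),
    (2, "n00001740", "entity", "root node", 2),  # superclass, filtered out
]
leaves = [(category.split(",", 1)[0], wnid)
          for _, wnid, category, _, num_children, *_ in synsets
          if num_children == 0]
wnids = [wnid for _, wnid in leaves]
assert wnids[int("1") - 1] == "n02012849"  # the logic of _imagenet_label_to_wnid
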
_VAL_TEST_IMAGE_NAME_PATTERN = re.compile(r"ILSVRC2012_(val|test)_(?P<id>\d{8})[.]JPEG")
......@@ -101,72 +123,65 @@ class ImageNet(Dataset):
path = pathlib.Path(data[0])
return int(self._VAL_TEST_IMAGE_NAME_PATTERN.match(path.name).group("id")) # type: ignore[union-attr]
def _collate_val_data(
self, data: Tuple[Tuple[int, int], Tuple[str, io.IOBase]]
) -> Tuple[Tuple[Label, str, str], Tuple[str, io.IOBase]]:
def _prepare_val_data(
self, data: Tuple[Tuple[int, str], Tuple[str, BinaryIO]]
) -> Tuple[Tuple[Label, str], Tuple[str, BinaryIO]]:
label_data, image_data = data
_, label = label_data
category = self.categories[label]
wnid = self.category_to_wnid[category]
return (Label(label), category, wnid), image_data
def _collate_test_data(self, data: Tuple[str, io.IOBase]) -> Tuple[None, Tuple[str, io.IOBase]]:
return None, data
_, wnid = label_data
label = Label.from_category(self.info.extra.wnid_to_category[wnid], categories=self.categories)
return (label, wnid), image_data
def _collate_and_decode_sample(
def _prepare_sample(
self,
data: Tuple[Optional[Tuple[Label, str, str]], Tuple[str, io.IOBase]],
*,
decoder: Optional[Callable[[io.IOBase], torch.Tensor]],
data: Tuple[Optional[Tuple[Label, str]], Tuple[str, BinaryIO]],
) -> Dict[str, Any]:
label_data, (path, buffer) = data
sample = dict(
return dict(
dict(zip(("label", "wnid"), label_data if label_data else (None, None))),
path=path,
image=decoder(buffer) if decoder else buffer,
image=EncodedImage.from_file(buffer),
)
if label_data:
sample.update(dict(zip(("label", "category", "wnid"), label_data)))
return sample
def _make_datapipe(
self,
resource_dps: List[IterDataPipe],
*,
config: DatasetConfig,
decoder: Optional[Callable[[io.IOBase], torch.Tensor]],
self, resource_dps: List[IterDataPipe], *, config: DatasetConfig
) -> IterDataPipe[Dict[str, Any]]:
images_dp, devkit_dp = resource_dps
if config.split in {"train", "test"}:
dp = resource_dps[0]
if config.split == "train":
# the train archive is a tar of tars
dp = TarArchiveReader(images_dp)
if config.split == "train":
dp = TarArchiveReader(dp)
dp = hint_sharding(dp)
dp = hint_shuffling(dp)
dp = Mapper(dp, self._collate_train_data)
elif config.split == "val":
devkit_dp = Filter(devkit_dp, path_comparator("name", "ILSVRC2012_validation_ground_truth.txt"))
devkit_dp = LineReader(devkit_dp, return_path=False)
devkit_dp = Mapper(devkit_dp, int)
devkit_dp = Enumerator(devkit_dp, 1)
devkit_dp = hint_sharding(devkit_dp)
devkit_dp = hint_shuffling(devkit_dp)
dp = Mapper(dp, self._prepare_train_data if config.split == "train" else self._prepare_test_data)
else: # config.split == "val":
images_dp, devkit_dp = resource_dps
meta_dp, label_dp = Demultiplexer(
devkit_dp, 2, self._classify_devkit, drop_none=True, buffer_size=INFINITE_BUFFER_SIZE
)
meta_dp = Mapper(meta_dp, self._extract_categories_and_wnids)
_, wnids = zip(*next(iter(meta_dp)))
label_dp = LineReader(label_dp, decode=True, return_path=False)
label_dp = Mapper(label_dp, functools.partial(self._imagenet_label_to_wnid, wnids=wnids))
label_dp: IterDataPipe[Tuple[int, str]] = Enumerator(label_dp, 1)
label_dp = hint_sharding(label_dp)
label_dp = hint_shuffling(label_dp)
dp = IterKeyZipper(
devkit_dp,
label_dp,
images_dp,
key_fn=getitem(0),
ref_key_fn=self._val_test_image_key,
buffer_size=INFINITE_BUFFER_SIZE,
)
dp = Mapper(dp, self._collate_val_data)
else: # config.split == "test"
dp = hint_sharding(images_dp)
dp = hint_shuffling(dp)
dp = Mapper(dp, self._collate_test_data)
dp = Mapper(dp, self._prepare_val_data)
return Mapper(dp, functools.partial(self._collate_and_decode_sample, decoder=decoder))
return Mapper(dp, self._prepare_sample)
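
For the val split, this graph joins ground-truth labels to images through the id embedded in each file name. A minimal sketch of that join without datapipes (file names and wnids invented):

import re

pattern = re.compile(r"ILSVRC2012_(val|test)_(?P<id>\d{8})[.]JPEG")
labels = dict(enumerate(["n01440764", "n01443537"], 1))  # like Enumerator(label_dp, 1)
for name in ("ILSVRC2012_val_00000002.JPEG", "ILSVRC2012_val_00000001.JPEG"):
    image_id = int(pattern.match(name)["id"])
    print(name, "->", labels[image_id])
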
# Although the WordNet IDs (wnids) are unique, the corresponding categories are not. For example, both n02012849
# and n03126707 are labeled 'crane' while the first means the bird and the latter means the construction equipment
......@@ -176,22 +191,13 @@ class ImageNet(Dataset):
}
def _generate_categories(self, root: pathlib.Path) -> List[Tuple[str, ...]]:
resources = self.resources(self.default_config)
config = self.info.make_config(split="val")
resources = self.resources(config)
devkit_dp = resources[1].load(root)
devkit_dp = Filter(devkit_dp, path_comparator("name", "meta.mat"))
meta = next(iter(devkit_dp))[1]
synsets = read_mat(meta, squeeze_me=True)["synsets"]
categories_and_wnids = cast(
List[Tuple[str, ...]],
[
(self._WNID_MAP.get(wnid, category.split(",", 1)[0]), wnid)
for _, wnid, category, _, num_children, *_ in synsets
# if num_children > 0, we are looking at a superclass that has no direct instance
if num_children == 0
],
)
categories_and_wnids.sort(key=lambda category_and_wnid: category_and_wnid[1])
meta_dp = Filter(devkit_dp, path_comparator("name", "meta.mat"))
meta_dp = Mapper(meta_dp, self._extract_categories_and_wnids)
categories_and_wnids = cast(List[Tuple[str, ...]], next(iter(meta_dp)))
categories_and_wnids.sort(key=lambda category_and_wnid: category_and_wnid[1])
return categories_and_wnids
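
Sorting by wnid makes the category order, and therefore every label index, independent of the order in which meta.mat lists the synsets. A small sketch of why the wnid, not the (ambiguous) category name, is the sort key:

pairs = [("crane", "n03126707"), ("crane", "n02012849")]  # same name, different wnids
pairs.sort(key=lambda category_and_wnid: category_and_wnid[1])
assert pairs == [("crane", "n02012849"), ("crane", "n03126707")]
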