Unverified Commit fe78a8ae authored by Philip Meier, committed by GitHub

add prototype features (#4721)

* add prototype features

* add some JIT tests

* refactor input data handling

* refactor tests

* cleanup tests

* add BoundingBox feature

* mypy

* xfail torchscript tests for now

* cleanup

* fix imports
parent 49ec677c
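For orientation, here is a brief usage sketch of the API this commit introduces. It is editorial rather than part of the commit, and the concrete values are only illustrative:

import torch
from torchvision.prototype import features

# BoundingBox carries its format and image size as meta data and can convert between formats.
box = features.BoundingBox([2, 4, 2, 4], format="cxcywh", image_size=(8, 6))
xyxy = box.convert("xyxy")  # values (1, 2, 3, 6)

# Label optionally carries a human-readable category.
label = features.Label(torch.tensor(7), category="seven")

# Image guesses its color space from the shape unless one is passed explicitly.
image = features.Image(torch.rand(3, 16, 16))
assert image.color_space is features.ColorSpace.RGB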
@@ -15,7 +15,7 @@ from torch.testing import make_tensor as _make_tensor
from torchdata.datapipes.iter import IterDataPipe
from torchvision.prototype import datasets
from torchvision.prototype.datasets._api import find
-from torchvision.prototype.datasets.utils._internal import add_suggestion
+from torchvision.prototype.utils._internal import add_suggestion
make_tensor = functools.partial(_make_tensor, device="cpu")
......
@@ -5,7 +5,7 @@ import builtin_dataset_mocks
import pytest
from torchdata.datapipes.iter import IterDataPipe
from torchvision.prototype import datasets
-from torchvision.prototype.datasets.utils._internal import sequence_to_str
+from torchvision.prototype.utils._internal import sequence_to_str
_loaders = []
......
import functools
import itertools
import pytest
import torch
from torch.testing import make_tensor as _make_tensor, assert_close
from torchvision.prototype import features
from torchvision.prototype.utils._internal import sequence_to_str
make_tensor = functools.partial(_make_tensor, device="cpu", dtype=torch.float32)
def make_bounding_box(*, format="xyxy", image_size=(10, 10)):
    if isinstance(format, str):
        format = features.BoundingBoxFormat[format]

    height, width = image_size

    if format == features.BoundingBoxFormat.XYXY:
        x1 = torch.randint(0, width // 2, ())
        y1 = torch.randint(0, height // 2, ())
        x2 = torch.randint(int(x1) + 1, width - int(x1), ()) + x1
        y2 = torch.randint(int(y1) + 1, height - int(y1), ()) + y1
        parts = (x1, y1, x2, y2)
    elif format == features.BoundingBoxFormat.XYWH:
        x = torch.randint(0, width // 2, ())
        y = torch.randint(0, height // 2, ())
        w = torch.randint(1, width - int(x), ())
        h = torch.randint(1, height - int(y), ())
        parts = (x, y, w, h)
    elif format == features.BoundingBoxFormat.CXCYWH:
        cx = torch.randint(1, width - 1, ())
        cy = torch.randint(1, height - 1, ())
        w = torch.randint(1, min(int(cx), width - int(cx)), ())
        h = torch.randint(1, min(int(cy), height - int(cy)), ())
        parts = (cx, cy, w, h)
    else:  # format == features.BoundingBoxFormat._SENTINEL
        parts = make_tensor((4,)).unbind()

    return features.BoundingBox.from_parts(*parts, format=format, image_size=image_size)
MAKE_DATA_MAP = {
    features.BoundingBox: make_bounding_box,
}


def make_feature(feature_type, **meta_data):
    maker = MAKE_DATA_MAP.get(feature_type, lambda **meta_data: feature_type(make_tensor(()), **meta_data))
    return maker(**meta_data)
class TestCommon:
    FEATURE_TYPES, NON_DEFAULT_META_DATA = zip(
        *(
            (features.Image, dict(color_space=features.ColorSpace._SENTINEL)),
            (features.Label, dict(category="category")),
            (features.BoundingBox, dict(format=features.BoundingBoxFormat._SENTINEL, image_size=(-1, -1))),
        )
    )
    feature_types = pytest.mark.parametrize(
        "feature_type", FEATURE_TYPES, ids=lambda feature_type: feature_type.__name__
    )
    features = pytest.mark.parametrize(
        "feature",
        [
            pytest.param(make_feature(feature_type, **meta_data), id=feature_type.__name__)
            for feature_type, meta_data in zip(FEATURE_TYPES, NON_DEFAULT_META_DATA)
        ],
    )

    def test_consistency(self):
        builtin_feature_types = {
            name
            for name, feature_type in features.__dict__.items()
            if not name.startswith("_")
            and isinstance(feature_type, type)
            and issubclass(feature_type, features.Feature)
            and feature_type is not features.Feature
        }
        untested_feature_types = builtin_feature_types - {feature_type.__name__ for feature_type in self.FEATURE_TYPES}
        if untested_feature_types:
            raise AssertionError(
                f"The feature(s) {sequence_to_str(sorted(untested_feature_types), separate_last='and ')} "
                f"is/are exposed at `torchvision.prototype.features`, but is/are not tested by `TestCommon`. "
                f"Please add it/them to `TestCommon.FEATURE_TYPES`."
            )

    @features
    def test_meta_data_attribute_access(self, feature):
        for name, value in feature._meta_data.items():
            assert getattr(feature, name) == feature._meta_data[name]

    @feature_types
    def test_torch_function(self, feature_type):
        input = make_feature(feature_type)
        # This can be any Tensor operation besides clone
        output = input + 1

        assert type(output) is torch.Tensor
        assert_close(output, input + 1)

    @feature_types
    def test_clone(self, feature_type):
        input = make_feature(feature_type)
        output = input.clone()

        assert type(output) is feature_type
        assert_close(output, input)
        assert output._meta_data == input._meta_data

    @features
    def test_serialization(self, tmpdir, feature):
        file = tmpdir / "test_serialization.pt"
        torch.save(feature, str(file))
        loaded_feature = torch.load(str(file))

        assert isinstance(loaded_feature, type(feature))
        assert_close(loaded_feature, feature)
        assert loaded_feature._meta_data == feature._meta_data

    @features
    def test_repr(self, feature):
        assert type(feature).__name__ in repr(feature)
class TestBoundingBox:
    @pytest.mark.parametrize(("format", "intermediate_format"), itertools.permutations(("xyxy", "xywh"), 2))
    def test_cycle_consistency(self, format, intermediate_format):
        input = make_bounding_box(format=format)
        output = input.convert(intermediate_format).convert(format)
        assert_close(input, output)
# For now, tensor subclasses with additional meta data do not work with torchscript.
# See https://github.com/pytorch/vision/pull/4721#discussion_r741676037.
@pytest.mark.xfail
class TestJit:
    def test_bounding_box(self):
        def resize(input: features.BoundingBox, size: torch.Tensor) -> features.BoundingBox:
            old_height, old_width = input.image_size
            new_height, new_width = size

            height_scale = new_height / old_height
            width_scale = new_width / old_width

            old_x1, old_y1, old_x2, old_y2 = input.convert("xyxy").to_parts()

            new_x1 = old_x1 * width_scale
            new_y1 = old_y1 * height_scale
            new_x2 = old_x2 * width_scale
            new_y2 = old_y2 * height_scale

            return features.BoundingBox.from_parts(
                new_x1, new_y1, new_x2, new_y2, like=input, format="xyxy", image_size=tuple(size.tolist())
            )

        def horizontal_flip(input: features.BoundingBox) -> features.BoundingBox:
            x, y, w, h = input.convert("xywh").to_parts()
            x = input.image_size[1] - (x + w)
            return features.BoundingBox.from_parts(x, y, w, h, like=input, format="xywh")

        def compose(input: features.BoundingBox, size: torch.Tensor) -> features.BoundingBox:
            return horizontal_flip(resize(input, size)).convert("xyxy")

        image_size = (8, 6)
        input = features.BoundingBox([2, 4, 2, 4], format="cxcywh", image_size=image_size)
        size = torch.tensor((4, 12))
        expected = features.BoundingBox([6, 1, 10, 3], format="xyxy", image_size=image_size)

        actual_eager = compose(input, size)
        assert_close(actual_eager, expected)

        sample_inputs = (features.BoundingBox(torch.zeros((4,)), image_size=(10, 10)), torch.tensor((20, 5)))
        actual_jit = torch.jit.trace(compose, sample_inputs, check_trace=False)(input, size)
        assert_close(actual_jit, expected)
from . import datasets
from . import features
from . import models
from . import transforms
from . import utils
@@ -7,7 +7,7 @@ from torch.utils.data import IterDataPipe
from torchvision.prototype.datasets import home
from torchvision.prototype.datasets.decoder import raw, pil
from torchvision.prototype.datasets.utils import Dataset, DatasetInfo, DatasetType
-from torchvision.prototype.datasets.utils._internal import add_suggestion
+from torchvision.prototype.utils._internal import add_suggestion
from . import _builtin
......
@@ -31,6 +31,7 @@ from torchvision.prototype.datasets.utils._internal import (
    Decompressor,
    INFINITE_BUFFER_SIZE,
)
+from torchvision.prototype.features import Image, Label
__all__ = ["MNIST", "FashionMNIST", "KMNIST", "EMNIST", "QMNIST"]
@@ -126,15 +127,14 @@ class _MNISTBase(Dataset):
        image, label = data

        if decoder is raw:
            image = image.unsqueeze(0)
+            image = Image(image)
        else:
            image_buffer = image_buffer_from_array(image.numpy())
            image = decoder(image_buffer) if decoder else image_buffer  # type: ignore[assignment]

-        category = self.info.categories[int(label)]
-        label = label.to(torch.int64)
+        label = Label(label, dtype=torch.int64, category=self.info.categories[int(label)])

-        return dict(image=image, category=category, label=label)
+        return dict(image=image, label=label)

    def _make_datapipe(
        self,
......
import io
from typing import cast
import PIL.Image
import torch
from torchvision.prototype import features
from torchvision.transforms.functional import pil_to_tensor
__all__ = ["raw", "pil"]
@@ -12,5 +12,5 @@ def raw(buffer: io.IOBase) -> torch.Tensor:
    raise RuntimeError("This is just a sentinel and should never be called.")

-def pil(buffer: io.IOBase, mode: str = "RGB") -> torch.Tensor:
-    return cast(torch.Tensor, pil_to_tensor(PIL.Image.open(buffer).convert(mode.upper())))
+def pil(buffer: io.IOBase) -> features.Image:
+    return features.Image(pil_to_tensor(PIL.Image.open(buffer)))
@@ -9,10 +9,7 @@ from typing import Any, Callable, Dict, List, Optional, Sequence, Union, Tuple
import torch
from torch.utils.data import IterDataPipe
-from torchvision.prototype.datasets.utils._internal import (
-    add_suggestion,
-    sequence_to_str,
-)
+from torchvision.prototype.utils._internal import add_suggestion, sequence_to_str
from .._home import use_sharded_dataset
from ._internal import FrozenBunch, make_repr, BUILTIN_DIR, _make_sharded_datapipe
......
import collections.abc
import csv
import difflib
import enum
import gzip
import io
@@ -11,7 +9,6 @@ import pathlib
import pickle
import textwrap
from typing import (
-    Collection,
    Sequence,
    Callable,
    Union,
@@ -41,8 +38,6 @@ from torchdata.datapipes.utils import StreamWrapper
__all__ = [
    "INFINITE_BUFFER_SIZE",
    "BUILTIN_DIR",
-    "sequence_to_str",
-    "add_suggestion",
    "make_repr",
    "FrozenMapping",
    "FrozenBunch",
@@ -67,33 +62,6 @@ INFINITE_BUFFER_SIZE = 1_000_000_000
BUILTIN_DIR = pathlib.Path(__file__).parent.parent / "_builtin"

-def sequence_to_str(seq: Sequence, separate_last: str = "") -> str:
-    if len(seq) == 1:
-        return f"'{seq[0]}'"
-    return f"""'{"', '".join([str(item) for item in seq[:-1]])}', {separate_last}'{seq[-1]}'."""
-def add_suggestion(
-    msg: str,
-    *,
-    word: str,
-    possibilities: Collection[str],
-    close_match_hint: Callable[[str], str] = lambda close_match: f"Did you mean '{close_match}'?",
-    alternative_hint: Callable[
-        [Sequence[str]], str
-    ] = lambda possibilities: f"Can be {sequence_to_str(possibilities, separate_last='or ')}.",
-) -> str:
-    if not isinstance(possibilities, collections.abc.Sequence):
-        possibilities = sorted(possibilities)
-    suggestions = difflib.get_close_matches(word, possibilities, 1)
-    hint = close_match_hint(suggestions[0]) if suggestions else alternative_hint(possibilities)
-    if not hint:
-        return msg
-    return f"{msg.strip()} {hint}"

def make_repr(name: str, items: Iterable[Tuple[str, Any]]) -> str:
    def to_str(sep: str) -> str:
        return sep.join([f"{key}={value}" for key, value in items])
......
from ._bounding_box import BoundingBoxFormat, BoundingBox
from ._feature import Feature
from ._image import Image, ColorSpace
from ._label import Label
import enum
import functools
from typing import Callable, Union, Tuple, Dict, Any, Optional, cast

import torch
from torchvision.prototype.utils._internal import StrEnum

from ._feature import Feature, DEFAULT


class BoundingBoxFormat(StrEnum):
    # this is just for test purposes
    _SENTINEL = -1
    XYXY = enum.auto()
    XYWH = enum.auto()
    CXCYWH = enum.auto()


def to_parts(input: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
    return input.unbind(dim=-1)  # type: ignore[return-value]


def from_parts(a: torch.Tensor, b: torch.Tensor, c: torch.Tensor, d: torch.Tensor) -> torch.Tensor:
    return torch.stack((a, b, c, d), dim=-1)


def format_converter_wrapper(
    part_converter: Callable[
        [torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor],
        Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor],
    ]
):
    def wrapper(input: torch.Tensor) -> torch.Tensor:
        return from_parts(*part_converter(*to_parts(input)))

    return wrapper


@format_converter_wrapper
def xywh_to_xyxy(
    x: torch.Tensor, y: torch.Tensor, w: torch.Tensor, h: torch.Tensor
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
    x1 = x
    y1 = y
    x2 = x + w
    y2 = y + h
    return x1, y1, x2, y2


@format_converter_wrapper
def xyxy_to_xywh(
    x1: torch.Tensor, y1: torch.Tensor, x2: torch.Tensor, y2: torch.Tensor
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
    x = x1
    y = y1
    w = x2 - x1
    h = y2 - y1
    return x, y, w, h


@format_converter_wrapper
def cxcywh_to_xyxy(
    cx: torch.Tensor, cy: torch.Tensor, w: torch.Tensor, h: torch.Tensor
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
    x1 = cx - 0.5 * w
    y1 = cy - 0.5 * h
    x2 = cx + 0.5 * w
    y2 = cy + 0.5 * h
    return x1, y1, x2, y2


@format_converter_wrapper
def xyxy_to_cxcywh(
    x1: torch.Tensor, y1: torch.Tensor, x2: torch.Tensor, y2: torch.Tensor
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
    cx = (x1 + x2) / 2
    cy = (y1 + y2) / 2
    w = x2 - x1
    h = y2 - y1
    return cx, cy, w, h
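# Editorial worked example (not part of this file): applying cxcywh_to_xyxy to
# torch.tensor([5.0, 5.0, 4.0, 2.0]) yields tensor([3.0, 4.0, 7.0, 6.0]), and
# xyxy_to_cxcywh maps that back to tensor([5.0, 5.0, 4.0, 2.0]).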
class BoundingBox(Feature):
    formats = BoundingBoxFormat
    format: BoundingBoxFormat
    image_size: Tuple[int, int]

    @classmethod
    def _parse_meta_data(
        cls,
        format: Union[str, BoundingBoxFormat] = DEFAULT,  # type: ignore[assignment]
        image_size: Optional[Tuple[int, int]] = DEFAULT,  # type: ignore[assignment]
    ) -> Dict[str, Tuple[Any, Any]]:
        if isinstance(format, str):
            format = BoundingBoxFormat[format]
        format_fallback = BoundingBoxFormat.XYXY
        return dict(
            format=(format, format_fallback),
            image_size=(image_size, functools.partial(cls.guess_image_size, format=format_fallback)),
        )

    _TO_XYXY_MAP = {
        BoundingBoxFormat.XYWH: xywh_to_xyxy,
        BoundingBoxFormat.CXCYWH: cxcywh_to_xyxy,
    }
    _FROM_XYXY_MAP = {
        BoundingBoxFormat.XYWH: xyxy_to_xywh,
        BoundingBoxFormat.CXCYWH: xyxy_to_cxcywh,
    }

    @classmethod
    def guess_image_size(cls, data: torch.Tensor, *, format: BoundingBoxFormat) -> Tuple[int, int]:
        if format not in (BoundingBoxFormat.XYWH, BoundingBoxFormat.CXCYWH):
            if format != BoundingBoxFormat.XYXY:
                data = cls._TO_XYXY_MAP[format](data)
            data = cls._FROM_XYXY_MAP[BoundingBoxFormat.XYWH](data)
        *_, w, h = to_parts(data)
        return int(h.ceil()), int(w.ceil())

    @classmethod
    def from_parts(
        cls,
        a,
        b,
        c,
        d,
        *,
        like: Optional["BoundingBox"] = None,
        format: Union[str, BoundingBoxFormat] = DEFAULT,  # type: ignore[assignment]
        image_size: Optional[Tuple[int, int]] = DEFAULT,  # type: ignore[assignment]
    ) -> "BoundingBox":
        return cls(from_parts(a, b, c, d), like=like, image_size=image_size, format=format)

    def to_parts(self) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
        return to_parts(self)

    def convert(self, format: Union[str, BoundingBoxFormat]) -> "BoundingBox":
        if isinstance(format, str):
            format = BoundingBoxFormat[format]

        if format == self.format:
            return cast(BoundingBox, self.clone())

        data = self
        if self.format != BoundingBoxFormat.XYXY:
            data = self._TO_XYXY_MAP[self.format](data)
        if format != BoundingBoxFormat.XYXY:
            data = self._FROM_XYXY_MAP[format](data)

        return BoundingBox(data, like=self, format=format)
from typing import Tuple, cast, TypeVar, Set, Dict, Any, Callable, Optional, Mapping, Type, Sequence

import torch
from torch._C import _TensorBase, DisableTorchFunction
from torchvision.prototype.utils._internal import add_suggestion

F = TypeVar("F", bound="Feature")

DEFAULT = object()


class Feature(torch.Tensor):
    _META_ATTRS: Set[str]
    _meta_data: Dict[str, Any]

    def __init_subclass__(cls):
        if not hasattr(cls, "_META_ATTRS"):
            cls._META_ATTRS = {
                attr for attr in cls.__annotations__.keys() - cls.__dict__.keys() if not attr.startswith("_")
            }
        for attr in cls._META_ATTRS:
            if not hasattr(cls, attr):
                setattr(cls, attr, property(lambda self, attr=attr: self._meta_data[attr]))

    def __new__(cls, data, *, dtype=None, device=None, like=None, **kwargs):
        unknown_meta_attrs = kwargs.keys() - cls._META_ATTRS
        if unknown_meta_attrs:
            unknown_meta_attr = sorted(unknown_meta_attrs)[0]
            raise TypeError(
                add_suggestion(
                    f"{cls.__name__}() got unexpected keyword '{unknown_meta_attr}'.",
                    word=unknown_meta_attr,
                    possibilities=cls._META_ATTRS,
                )
            )

        if like is not None:
            dtype = dtype or like.dtype
            device = device or like.device
        data = cls._to_tensor(data, dtype=dtype, device=device)
        requires_grad = False
        self = torch.Tensor._make_subclass(cast(_TensorBase, cls), data, requires_grad)

        meta_data = dict()
        for attr, (explicit, fallback) in cls._parse_meta_data(**kwargs).items():
            if explicit is not DEFAULT:
                value = explicit
            elif like is not None:
                value = getattr(like, attr)
            else:
                value = fallback(data) if callable(fallback) else fallback
            meta_data[attr] = value
        self._meta_data = meta_data

        return self

    @classmethod
    def _to_tensor(cls, data, *, dtype, device):
        return torch.as_tensor(data, dtype=dtype, device=device)

    @classmethod
    def _parse_meta_data(cls) -> Dict[str, Tuple[Any, Any]]:
        return dict()

    @classmethod
    def __torch_function__(
        cls,
        func: Callable[..., torch.Tensor],
        types: Tuple[Type[torch.Tensor], ...],
        args: Sequence[Any] = (),
        kwargs: Optional[Mapping[str, Any]] = None,
    ) -> torch.Tensor:
        with DisableTorchFunction():
            output = func(*args, **(kwargs or dict()))
        if func is not torch.Tensor.clone:
            return output

        return cls(output, like=args[0])

    def __repr__(self):
        return torch.Tensor.__repr__(self).replace("tensor", type(self).__name__)
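As an editorial aside (not part of this diff), the meta data mechanism above can be illustrated with a small sketch; the Rating subclass and its source attribute are hypothetical:

import torch
from torchvision.prototype.features._feature import DEFAULT, Feature


class Rating(Feature):
    source: str  # the bare annotation registers `source` as a meta data attribute

    @classmethod
    def _parse_meta_data(cls, source: str = DEFAULT):  # type: ignore[assignment]
        # maps each attribute to an (explicit value, fallback) pair consumed by Feature.__new__
        return dict(source=(source, "unknown"))


rating = Rating(torch.tensor(4.5), source="imdb")
assert rating.source == "imdb"
assert Rating(torch.tensor(3.0)).source == "unknown"  # fallback is used when nothing is passed
assert type(rating + 0.5) is torch.Tensor  # ordinary ops return plain tensors
assert rating.clone().source == "imdb"  # clone keeps the subclass and its meta data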
from typing import Dict, Any, Union, Tuple

import torch
from torchvision.prototype.utils._internal import StrEnum

from ._feature import Feature, DEFAULT


class ColorSpace(StrEnum):
    # this is just for test purposes
    _SENTINEL = -1
    OTHER = 0
    GRAYSCALE = 1
    RGB = 3


class Image(Feature):
    color_spaces = ColorSpace
    color_space: ColorSpace

    @classmethod
    def _parse_meta_data(
        cls,
        color_space: Union[str, ColorSpace] = DEFAULT,  # type: ignore[assignment]
    ) -> Dict[str, Tuple[Any, Any]]:
        if isinstance(color_space, str):
            color_space = ColorSpace[color_space]
        return dict(color_space=(color_space, cls.guess_color_space))

    @staticmethod
    def guess_color_space(data: torch.Tensor) -> ColorSpace:
        if data.ndim < 2:
            return ColorSpace.OTHER
        elif data.ndim == 2:
            return ColorSpace.GRAYSCALE

        num_channels = data.shape[-3]
        if num_channels == 1:
            return ColorSpace.GRAYSCALE
        elif num_channels == 3:
            return ColorSpace.RGB
        else:
            return ColorSpace.OTHER
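A short editorial example (not part of the diff) of how guess_color_space resolves the default; the tensor shapes are illustrative:

import torch
from torchvision.prototype.features import ColorSpace, Image

assert Image(torch.rand(8, 8)).color_space is ColorSpace.GRAYSCALE  # 2d data
assert Image(torch.rand(1, 8, 8)).color_space is ColorSpace.GRAYSCALE  # single channel
assert Image(torch.rand(4, 8, 8)).color_space is ColorSpace.OTHER  # unrecognized channel count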
from typing import Dict, Any, Optional, Tuple

from ._feature import Feature, DEFAULT


class Label(Feature):
    category: Optional[str]

    @classmethod
    def _parse_meta_data(
        cls,
        category: Optional[str] = DEFAULT,  # type: ignore[assignment]
    ) -> Dict[str, Tuple[Any, Any]]:
        return dict(category=(category, None))
import collections.abc
import difflib
import enum
from typing import Sequence, Collection, Callable

__all__ = ["StrEnum", "sequence_to_str", "add_suggestion"]


class StrEnumMeta(enum.EnumMeta):
    def __getitem__(self, item):
        return super().__getitem__(item.upper() if isinstance(item, str) else item)


class StrEnum(enum.Enum, metaclass=StrEnumMeta):
    pass


def sequence_to_str(seq: Sequence, separate_last: str = "") -> str:
    if len(seq) == 1:
        return f"'{seq[0]}'"

    return f"""'{"', '".join([str(item) for item in seq[:-1]])}', {separate_last}'{seq[-1]}'."""


def add_suggestion(
    msg: str,
    *,
    word: str,
    possibilities: Collection[str],
    close_match_hint: Callable[[str], str] = lambda close_match: f"Did you mean '{close_match}'?",
    alternative_hint: Callable[
        [Sequence[str]], str
    ] = lambda possibilities: f"Can be {sequence_to_str(possibilities, separate_last='or ')}.",
) -> str:
    if not isinstance(possibilities, collections.abc.Sequence):
        possibilities = sorted(possibilities)
    suggestions = difflib.get_close_matches(word, possibilities, 1)
    hint = close_match_hint(suggestions[0]) if suggestions else alternative_hint(possibilities)
    if not hint:
        return msg

    return f"{msg.strip()} {hint}"