Unverified commit e836b3d8 authored by Philip Meier, committed by GitHub

simplify Feature implementation (#5539)

* simplify Feature implementation

* fix mypy
parent 97385df0
 import enum
+from typing import TypeVar, Type
+
+T = TypeVar("T", bound=enum.Enum)
+
 class StrEnumMeta(enum.EnumMeta):
     auto = enum.auto
-    def from_str(self, member: str):
+    def from_str(self: Type[T], member: str) -> T:  # type: ignore[misc]
         try:
             return self[member]
         except KeyError:
...
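The hunk above only tightens typing: annotating `self` as `Type[T]` lets mypy infer the concrete enum type of the member returned by `from_str` instead of a bare `enum.Enum`. A minimal standalone sketch of the pattern (the `Format` enum and the error message are illustrative, not torchvision code):

```python
import enum
from typing import Type, TypeVar

T = TypeVar("T", bound=enum.Enum)


class StrEnumMeta(enum.EnumMeta):
    # Typing `self` as Type[T] makes Format.from_str("XYXY") type-check as
    # returning Format rather than a plain enum.Enum.
    def from_str(self: Type[T], member: str) -> T:  # type: ignore[misc]
        try:
            return self[member]
        except KeyError:
            # Illustrative error handling; the real body is elided in the diff above.
            raise ValueError(f"Unknown member {member!r} of {self.__name__}") from None


class Format(enum.Enum, metaclass=StrEnumMeta):
    XYXY = enum.auto()
    XYWH = enum.auto()


assert Format.from_str("XYXY") is Format.XYXY
```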
@@ -22,20 +22,40 @@ class BoundingBox(_Feature):
         cls,
         data: Any,
         *,
-        dtype: Optional[torch.dtype] = None,
-        device: Optional[torch.device] = None,
         format: Union[BoundingBoxFormat, str],
         image_size: Tuple[int, int],
+        dtype: Optional[torch.dtype] = None,
+        device: Optional[Union[torch.device, str, int]] = None,
+        requires_grad: bool = False,
     ) -> BoundingBox:
-        bounding_box = super().__new__(cls, data, dtype=dtype, device=device)
+        bounding_box = super().__new__(cls, data, dtype=dtype, device=device, requires_grad=requires_grad)
         if isinstance(format, str):
             format = BoundingBoxFormat.from_str(format.upper())
-        bounding_box._metadata.update(dict(format=format, image_size=image_size))
+        bounding_box.format = format
+        bounding_box.image_size = image_size
         return bounding_box
 
+    @classmethod
+    def new_like(
+        cls,
+        other: BoundingBox,
+        data: Any,
+        *,
+        format: Optional[Union[BoundingBoxFormat, str]] = None,
+        image_size: Optional[Tuple[int, int]] = None,
+        **kwargs: Any,
+    ) -> BoundingBox:
+        return super().new_like(
+            other,
+            data,
+            format=format if format is not None else other.format,
+            image_size=image_size if image_size is not None else other.image_size,
+            **kwargs,
+        )
+
     def to_format(self, format: Union[str, BoundingBoxFormat]) -> BoundingBox:
         # TODO: this is useful for developing and debugging but we should remove or at least revisit this before we
         # promote this out of the prototype state
...
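With the `_metadata` dict gone, `format` and `image_size` are plain attributes set in `__new__`, and the new `new_like` override copies them from an existing box unless they are overridden. A usage sketch against the prototype API as it stands in this diff (the `torchvision.prototype.features` import path is assumed):

```python
import torch
from torchvision.prototype import features  # prototype import path assumed

bbox = features.BoundingBox([0, 0, 5, 5], format="XYXY", image_size=(10, 10))

# new_like() re-wraps new data and copies format/image_size from `bbox`
# unless they are passed explicitly.
scaled = features.BoundingBox.new_like(bbox, torch.tensor([0, 0, 10, 10]), image_size=(20, 20))
assert scaled.format == bbox.format
assert scaled.image_size == (20, 20)
```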
+from __future__ import annotations
+
 import os
 import sys
 from typing import BinaryIO, Tuple, Type, TypeVar, Union, Optional, Any
@@ -13,19 +15,25 @@ D = TypeVar("D", bound="EncodedData")
 class EncodedData(_Feature):
-    @classmethod
-    def _to_tensor(cls, data: Any, *, dtype: Optional[torch.dtype], device: Optional[torch.device]) -> torch.Tensor:
+    def __new__(
+        cls,
+        data: Any,
+        *,
+        dtype: Optional[torch.dtype] = None,
+        device: Optional[Union[torch.device, str, int]] = None,
+        requires_grad: bool = False,
+    ) -> EncodedData:
         # TODO: warn / bail out if we encounter a tensor with shape other than (N,) or with dtype other than uint8?
-        return super()._to_tensor(data, dtype=dtype, device=device)
+        return super().__new__(cls, data, dtype=dtype, device=device, requires_grad=requires_grad)
 
     @classmethod
-    def from_file(cls: Type[D], file: BinaryIO) -> D:
-        return cls(fromfile(file, dtype=torch.uint8, byte_order=sys.byteorder))
+    def from_file(cls: Type[D], file: BinaryIO, **kwargs: Any) -> D:
+        return cls(fromfile(file, dtype=torch.uint8, byte_order=sys.byteorder), **kwargs)
 
     @classmethod
-    def from_path(cls: Type[D], path: Union[str, os.PathLike]) -> D:
+    def from_path(cls: Type[D], path: Union[str, os.PathLike], **kwargs: Any) -> D:
         with open(path, "rb") as file:
-            return cls.from_file(file)
+            return cls.from_file(file, **kwargs)
 
 class EncodedImage(EncodedData):
...
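`from_file` and `from_path` now forward `**kwargs` to the constructor, so the `device`/`requires_grad` arguments introduced above can be set when loading raw bytes. A small sketch, assuming `EncodedData` is exported from `torchvision.prototype.features` (the file path is illustrative):

```python
import torch
from torchvision.prototype import features  # prototype import path assumed

path = "example_bytes.bin"  # illustrative path
with open(path, "wb") as f:
    f.write(bytes(range(4)))

# Extra keyword arguments travel through from_path() -> from_file() -> cls(...).
data = features.EncodedData.from_path(path, device="cpu")
assert data.dtype == torch.uint8
assert data.tolist() == [0, 1, 2, 3]
```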
-from typing import Any, cast, Dict, Set, TypeVar, Union, Optional, Type, Callable, Tuple, Sequence, Mapping
+from typing import Any, cast, TypeVar, Union, Optional, Type, Callable, Tuple, Sequence, Mapping
 
 import torch
 from torch._C import _TensorBase, DisableTorchFunction
@@ -8,59 +8,22 @@ F = TypeVar("F", bound="_Feature")
 class _Feature(torch.Tensor):
-    _META_ATTRS: Set[str] = set()
-    _metadata: Dict[str, Any]
-
-    def __init_subclass__(cls) -> None:
-        """
-        For convenient copying of metadata, we store it inside a dictionary rather than multiple individual attributes.
-        By adding the metadata attributes as class annotations on subclasses of :class:`Feature`, this method adds
-        properties to have the same convenient access as regular attributes.
-
-        >>> class Foo(_Feature):
-        ...     bar: str
-        ...     baz: Optional[str]
-        >>> foo = Foo()
-        >>> foo.bar
-        >>> foo.baz
-
-        This has the additional benefit that autocomplete engines and static type checkers are aware of the metadata.
-        """
-        meta_attrs = {attr for attr in cls.__annotations__.keys() - cls.__dict__.keys() if not attr.startswith("_")}
-        for super_cls in cls.__mro__[1:]:
-            if super_cls is _Feature:
-                break
-
-            meta_attrs.update(cast(Type[_Feature], super_cls)._META_ATTRS)
-
-        cls._META_ATTRS = meta_attrs
-        for name in meta_attrs:
-            setattr(cls, name, property(cast(Callable[[F], Any], lambda self, name=name: self._metadata[name])))
-
     def __new__(
         cls: Type[F],
         data: Any,
         *,
         dtype: Optional[torch.dtype] = None,
-        device: Optional[Union[torch.device, str]] = None,
+        device: Optional[Union[torch.device, str, int]] = None,
+        requires_grad: bool = False,
     ) -> F:
-        if isinstance(device, str):
-            device = torch.device(device)
-        feature = cast(
+        return cast(
             F,
             torch.Tensor._make_subclass(
                 cast(_TensorBase, cls),
-                cls._to_tensor(data, dtype=dtype, device=device),
-                # requires_grad
-                False,
+                torch.as_tensor(data, dtype=dtype, device=device),  # type: ignore[arg-type]
+                requires_grad,
             ),
         )
-        feature._metadata = dict()
-        return feature
-
-    @classmethod
-    def _to_tensor(self, data: Any, *, dtype: Optional[torch.dtype], device: Optional[torch.device]) -> torch.Tensor:
-        return torch.as_tensor(data, dtype=dtype, device=device)
 
     @classmethod
     def new_like(
@@ -69,12 +32,17 @@ class _Feature(torch.Tensor):
         data: Any,
         *,
         dtype: Optional[torch.dtype] = None,
-        device: Optional[Union[torch.device, str]] = None,
-        **metadata: Any,
+        device: Optional[Union[torch.device, str, int]] = None,
+        requires_grad: Optional[bool] = None,
+        **kwargs: Any,
     ) -> F:
-        _metadata = other._metadata.copy()
-        _metadata.update(metadata)
-        return cls(data, dtype=dtype or other.dtype, device=device or other.device, **_metadata)
+        return cls(
+            data,
+            dtype=dtype if dtype is not None else other.dtype,
+            device=device if device is not None else other.device,
+            requires_grad=requires_grad if requires_grad is not None else other.requires_grad,
+            **kwargs,
+        )
 
     @classmethod
     def __torch_function__(
...
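This hunk is the heart of the simplification: the `_metadata` dict, the `__init_subclass__` property machinery, and `_to_tensor` are removed. Subclasses now store metadata in plain attributes inside `__new__` and override `new_like` to copy it, while the base `new_like` copies `dtype`, `device`, and `requires_grad`. A standalone, torch-only sketch of the resulting pattern (`MiniFeature`/`MiniLabel` are illustrative names, not the torchvision classes):

```python
from typing import Any, Optional, Sequence, Type, TypeVar

import torch

F = TypeVar("F", bound="MiniFeature")


class MiniFeature(torch.Tensor):
    def __new__(cls: Type[F], data: Any, *, dtype=None, device=None, requires_grad: bool = False) -> F:
        # Wrap the data as a subclass instance, mirroring _Feature.__new__ above.
        return torch.Tensor._make_subclass(cls, torch.as_tensor(data, dtype=dtype, device=device), requires_grad)

    @classmethod
    def new_like(cls: Type[F], other: F, data: Any, *, requires_grad: Optional[bool] = None, **kwargs: Any) -> F:
        # Copy dtype/device/requires_grad from `other`; subclasses forward their metadata via **kwargs.
        return cls(
            data,
            dtype=other.dtype,
            device=other.device,
            requires_grad=requires_grad if requires_grad is not None else other.requires_grad,
            **kwargs,
        )


class MiniLabel(MiniFeature):
    def __new__(cls, data: Any, *, categories: Optional[Sequence[str]] = None, **kwargs: Any) -> "MiniLabel":
        label = super().__new__(cls, data, **kwargs)
        label.categories = categories  # plain attribute instead of a _metadata entry
        return label

    @classmethod
    def new_like(
        cls, other: "MiniLabel", data: Any, *, categories: Optional[Sequence[str]] = None, **kwargs: Any
    ) -> "MiniLabel":
        return super().new_like(
            other, data, categories=categories if categories is not None else other.categories, **kwargs
        )


label = MiniLabel(1, categories=["cat", "dog"])
copy = MiniLabel.new_like(label, 0)
assert copy.categories == ["cat", "dog"]
```

The trade-off is that each subclass repeats a small `new_like` override, but the metadata becomes ordinary attributes that mypy and autocompletion understand without the generated-property indirection that was removed.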
@@ -26,11 +26,17 @@ class Image(_Feature):
         cls,
         data: Any,
         *,
-        dtype: Optional[torch.dtype] = None,
-        device: Optional[torch.device] = None,
         color_space: Optional[Union[ColorSpace, str]] = None,
+        dtype: Optional[torch.dtype] = None,
+        device: Optional[Union[torch.device, str, int]] = None,
+        requires_grad: bool = False,
     ) -> Image:
-        image = super().__new__(cls, data, dtype=dtype, device=device)
+        data = torch.as_tensor(data, dtype=dtype, device=device)  # type: ignore[arg-type]
+        if data.ndim < 2:
+            raise ValueError
+        elif data.ndim == 2:
+            data = data.unsqueeze(0)
+        image = super().__new__(cls, data, requires_grad=requires_grad)
         if color_space is None:
             color_space = cls.guess_color_space(image)
@@ -38,19 +44,19 @@ class Image(_Feature):
             warnings.warn("Unable to guess a specific color space. Consider passing it explicitly.")
         elif isinstance(color_space, str):
             color_space = ColorSpace.from_str(color_space.upper())
-        image._metadata.update(dict(color_space=color_space))
+        elif not isinstance(color_space, ColorSpace):
+            raise ValueError
+        image.color_space = color_space
         return image
 
     @classmethod
-    def _to_tensor(cls, data: Any, *, dtype: Optional[torch.dtype], device: Optional[torch.device]) -> torch.Tensor:
-        tensor = super()._to_tensor(data, dtype=dtype, device=device)
-        if tensor.ndim < 2:
-            raise ValueError
-        elif tensor.ndim == 2:
-            tensor = tensor.unsqueeze(0)
-        return tensor
+    def new_like(
+        cls, other: Image, data: Any, *, color_space: Optional[Union[ColorSpace, str]] = None, **kwargs: Any
+    ) -> Image:
+        return super().new_like(
+            other, data, color_space=color_space if color_space is not None else other.color_space, **kwargs
+        )
 
     @property
     def image_size(self) -> Tuple[int, int]:
...
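`Image.__new__` now normalizes its input itself: data is converted with `torch.as_tensor`, fewer than two dimensions is rejected, 2D `(H, W)` data gains a channel dimension, and an unrecognized `color_space` value raises instead of being stored silently. A usage sketch, assuming `Image` and `ColorSpace` are exported from `torchvision.prototype.features`:

```python
import torch
from torchvision.prototype import features  # prototype import path assumed

# 2D (H, W) data is promoted to (1, H, W).
gray = features.Image(torch.rand(16, 16))
assert gray.shape == (1, 16, 16)

# A string color space is parsed via ColorSpace.from_str; new_like() carries it over.
rgb = features.Image(torch.rand(3, 16, 16), color_space="RGB")
flipped = features.Image.new_like(rgb, rgb.flip(-1))
assert flipped.color_space == rgb.color_space
```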
 from __future__ import annotations
 
-from typing import Any, Optional, Sequence, cast
+from typing import Any, Optional, Sequence, cast, Union
 
 import torch
 from torchvision.prototype.utils._internal import apply_recursively
@@ -15,20 +15,32 @@ class Label(_Feature):
         cls,
         data: Any,
         *,
-        dtype: Optional[torch.dtype] = None,
-        device: Optional[torch.device] = None,
-        like: Optional[Label] = None,
         categories: Optional[Sequence[str]] = None,
+        dtype: Optional[torch.dtype] = None,
+        device: Optional[Union[torch.device, str, int]] = None,
+        requires_grad: bool = False,
     ) -> Label:
-        label = super().__new__(cls, data, dtype=dtype, device=device)
-        label._metadata.update(dict(categories=categories))
+        label = super().__new__(cls, data, dtype=dtype, device=device, requires_grad=requires_grad)
+        label.categories = categories
         return label
 
     @classmethod
-    def from_category(cls, category: str, *, categories: Sequence[str]) -> Label:
-        return cls(categories.index(category), categories=categories)
+    def new_like(cls, other: Label, data: Any, *, categories: Optional[Sequence[str]] = None, **kwargs: Any) -> Label:
+        return super().new_like(
+            other, data, categories=categories if categories is not None else other.categories, **kwargs
+        )
+
+    @classmethod
+    def from_category(
+        cls,
+        category: str,
+        *,
+        categories: Sequence[str],
+        **kwargs: Any,
+    ) -> Label:
+        return cls(categories.index(category), categories=categories, **kwargs)
 
     def to_categories(self) -> Any:
         if not self.categories:
@@ -44,16 +56,24 @@ class OneHotLabel(_Feature):
         cls,
         data: Any,
         *,
-        dtype: Optional[torch.dtype] = None,
-        device: Optional[torch.device] = None,
-        like: Optional[Label] = None,
         categories: Optional[Sequence[str]] = None,
+        dtype: Optional[torch.dtype] = None,
+        device: Optional[Union[torch.device, str, int]] = None,
+        requires_grad: bool = False,
     ) -> OneHotLabel:
-        one_hot_label = super().__new__(cls, data, dtype=dtype, device=device)
+        one_hot_label = super().__new__(cls, data, dtype=dtype, device=device, requires_grad=requires_grad)
         if categories is not None and len(categories) != one_hot_label.shape[-1]:
             raise ValueError()
-        one_hot_label._metadata.update(dict(categories=categories))
+        one_hot_label.categories = categories
         return one_hot_label
+
+    @classmethod
+    def new_like(
+        cls, other: OneHotLabel, data: Any, *, categories: Optional[Sequence[str]] = None, **kwargs: Any
+    ) -> OneHotLabel:
+        return super().new_like(
+            other, data, categories=categories if categories is not None else other.categories, **kwargs
+        )
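`Label` and `OneHotLabel` follow the same scheme: `categories` becomes a plain attribute, `new_like` copies it, and `from_category` forwards extra keyword arguments to the constructor. A usage sketch, again assuming the `torchvision.prototype.features` namespace:

```python
from torchvision.prototype import features  # prototype import path assumed

categories = ["cat", "dog"]

label = features.Label.from_category("dog", categories=categories)
assert int(label) == 1

# new_like() keeps the categories of the source label unless they are overridden.
shifted = features.Label.new_like(label, 0)
assert shifted.categories == categories
```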
@@ -46,7 +46,7 @@ class Resize(Transform):
             return features.SegmentationMask.new_like(input, output)
         elif isinstance(input, features.BoundingBox):
             output = F.resize_bounding_box(input, self.size, image_size=input.image_size)
-            return features.BoundingBox.new_like(input, output, image_size=self.size)
+            return features.BoundingBox.new_like(input, output, image_size=cast(Tuple[int, int], tuple(self.size)))
         elif isinstance(input, PIL.Image.Image):
             return F.resize_image_pil(input, self.size, interpolation=self.interpolation)
         elif isinstance(input, torch.Tensor):
...
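The only change on the transforms side is that `new_like` now receives the new size as a `Tuple[int, int]` (`tuple(self.size)` plus a `cast` for mypy) rather than the raw `self.size` list. A hedged sketch of the same re-wrapping done by hand with the functional API (the `torchvision.prototype.transforms.functional` import path is an assumption; `resize_bounding_box` is called with the same arguments as in the hunk above):

```python
from torchvision.prototype import features
from torchvision.prototype.transforms import functional as F  # import path assumed

bbox = features.BoundingBox([0, 0, 5, 5], format="XYXY", image_size=(10, 10))

new_size = (20, 20)
output = F.resize_bounding_box(bbox, list(new_size), image_size=bbox.image_size)

# Re-wrap the plain tensor output as a BoundingBox that records the new image_size.
resized = features.BoundingBox.new_like(bbox, output, image_size=new_size)
assert resized.image_size == (20, 20)
```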