Unverified Commit e836b3d8 authored by Philip Meier's avatar Philip Meier Committed by GitHub
Browse files

simplify Feature implementation (#5539)

* simplify Feature implementation

* fix mypy
parent 97385df0
import enum
from typing import TypeVar, Type
T = TypeVar("T", bound=enum.Enum)
class StrEnumMeta(enum.EnumMeta):
auto = enum.auto
def from_str(self, member: str):
def from_str(self: Type[T], member: str) -> T: # type: ignore[misc]
try:
return self[member]
except KeyError:
......
......@@ -22,20 +22,40 @@ class BoundingBox(_Feature):
cls,
data: Any,
*,
dtype: Optional[torch.dtype] = None,
device: Optional[torch.device] = None,
format: Union[BoundingBoxFormat, str],
image_size: Tuple[int, int],
dtype: Optional[torch.dtype] = None,
device: Optional[Union[torch.device, str, int]] = None,
requires_grad: bool = False,
) -> BoundingBox:
bounding_box = super().__new__(cls, data, dtype=dtype, device=device)
bounding_box = super().__new__(cls, data, dtype=dtype, device=device, requires_grad=requires_grad)
if isinstance(format, str):
format = BoundingBoxFormat.from_str(format.upper())
bounding_box.format = format
bounding_box._metadata.update(dict(format=format, image_size=image_size))
bounding_box.image_size = image_size
return bounding_box
@classmethod
def new_like(
    cls,
    other: BoundingBox,
    data: Any,
    *,
    format: Optional[Union[BoundingBoxFormat, str]] = None,
    image_size: Optional[Tuple[int, int]] = None,
    **kwargs: Any,
) -> BoundingBox:
    """Wrap ``data`` as a new :class:`BoundingBox`, inheriting metadata from ``other``.

    Any of ``format`` / ``image_size`` that is not explicitly provided is
    copied from the template box ``other``; remaining keyword arguments are
    forwarded to the base ``new_like``.
    """
    # Fill in every metadata field the caller left unset from the template.
    if format is None:
        format = other.format
    if image_size is None:
        image_size = other.image_size
    return super().new_like(other, data, format=format, image_size=image_size, **kwargs)
def to_format(self, format: Union[str, BoundingBoxFormat]) -> BoundingBox:
# TODO: this is useful for developing and debugging but we should remove or at least revisit this before we
# promote this out of the prototype state
......
from __future__ import annotations
import os
import sys
from typing import BinaryIO, Tuple, Type, TypeVar, Union, Optional, Any
......@@ -13,19 +15,25 @@ D = TypeVar("D", bound="EncodedData")
class EncodedData(_Feature):
@classmethod
def _to_tensor(cls, data: Any, *, dtype: Optional[torch.dtype], device: Optional[torch.device]) -> torch.Tensor:
def __new__(
cls,
data: Any,
*,
dtype: Optional[torch.dtype] = None,
device: Optional[Union[torch.device, str, int]] = None,
requires_grad: bool = False,
) -> EncodedData:
# TODO: warn / bail out if we encounter a tensor with shape other than (N,) or with dtype other than uint8?
return super()._to_tensor(data, dtype=dtype, device=device)
return super().__new__(cls, data, dtype=dtype, device=device, requires_grad=requires_grad)
@classmethod
def from_file(cls: Type[D], file: BinaryIO) -> D:
return cls(fromfile(file, dtype=torch.uint8, byte_order=sys.byteorder))
def from_file(cls: Type[D], file: BinaryIO, **kwargs: Any) -> D:
return cls(fromfile(file, dtype=torch.uint8, byte_order=sys.byteorder), **kwargs)
@classmethod
def from_path(cls: Type[D], path: Union[str, os.PathLike]) -> D:
def from_path(cls: Type[D], path: Union[str, os.PathLike], **kwargs: Any) -> D:
with open(path, "rb") as file:
return cls.from_file(file)
return cls.from_file(file, **kwargs)
class EncodedImage(EncodedData):
......
from typing import Any, cast, Dict, Set, TypeVar, Union, Optional, Type, Callable, Tuple, Sequence, Mapping
from typing import Any, cast, TypeVar, Union, Optional, Type, Callable, Tuple, Sequence, Mapping
import torch
from torch._C import _TensorBase, DisableTorchFunction
......@@ -8,59 +8,22 @@ F = TypeVar("F", bound="_Feature")
class _Feature(torch.Tensor):
_META_ATTRS: Set[str] = set()
_metadata: Dict[str, Any]
def __init_subclass__(cls) -> None:
    """
    For convenient copying of metadata, we store it inside a dictionary rather than multiple individual attributes.
    By adding the metadata attributes as class annotations on subclasses of :class:`Feature`, this method adds
    properties to have the same convenient access as regular attributes.
    >>> class Foo(_Feature):
    ...     bar: str
    ...     baz: Optional[str]
    >>> foo = Foo()
    >>> foo.bar
    >>> foo.baz
    This has the additional benefit that autocomplete engines and static type checkers are aware of the metadata.
    """
    # A metadata attribute is any public name that is annotated on the subclass
    # but has no concrete class attribute; leading-underscore names are excluded.
    meta_attrs = {attr for attr in cls.__annotations__.keys() - cls.__dict__.keys() if not attr.startswith("_")}
    # Also inherit the metadata names collected by every ancestor between this
    # subclass and _Feature itself (exclusive), walking the MRO.
    for super_cls in cls.__mro__[1:]:
        if super_cls is _Feature:
            break
        meta_attrs.update(cast(Type[_Feature], super_cls)._META_ATTRS)
    cls._META_ATTRS = meta_attrs
    # Expose each metadata name as a read-only property that looks it up in
    # self._metadata. The `name=name` default binds the loop variable eagerly,
    # avoiding the late-binding-closure pitfall where every property would
    # otherwise read the last key.
    for name in meta_attrs:
        setattr(cls, name, property(cast(Callable[[F], Any], lambda self, name=name: self._metadata[name])))
def __new__(
cls: Type[F],
data: Any,
*,
dtype: Optional[torch.dtype] = None,
device: Optional[Union[torch.device, str]] = None,
device: Optional[Union[torch.device, str, int]] = None,
requires_grad: bool = False,
) -> F:
if isinstance(device, str):
device = torch.device(device)
feature = cast(
return cast(
F,
torch.Tensor._make_subclass(
cast(_TensorBase, cls),
cls._to_tensor(data, dtype=dtype, device=device),
# requires_grad
False,
torch.as_tensor(data, dtype=dtype, device=device), # type: ignore[arg-type]
requires_grad,
),
)
feature._metadata = dict()
return feature
@classmethod
def _to_tensor(cls, data: Any, *, dtype: Optional[torch.dtype], device: Optional[torch.device]) -> torch.Tensor:
    """Convert ``data`` to a plain :class:`torch.Tensor` with the requested dtype/device.

    Fix: the first parameter of this ``@classmethod`` was named ``self``;
    it receives the class object, so it is renamed to the conventional ``cls``
    (callers pass it implicitly, so the interface is unchanged).
    """
    return torch.as_tensor(data, dtype=dtype, device=device)
@classmethod
def new_like(
......@@ -69,12 +32,17 @@ class _Feature(torch.Tensor):
data: Any,
*,
dtype: Optional[torch.dtype] = None,
device: Optional[Union[torch.device, str]] = None,
**metadata: Any,
device: Optional[Union[torch.device, str, int]] = None,
requires_grad: Optional[bool] = None,
**kwargs: Any,
) -> F:
_metadata = other._metadata.copy()
_metadata.update(metadata)
return cls(data, dtype=dtype or other.dtype, device=device or other.device, **_metadata)
return cls(
data,
dtype=dtype if dtype is not None else other.dtype,
device=device if device is not None else other.device,
requires_grad=requires_grad if requires_grad is not None else other.requires_grad,
**kwargs,
)
@classmethod
def __torch_function__(
......
......@@ -26,11 +26,17 @@ class Image(_Feature):
cls,
data: Any,
*,
dtype: Optional[torch.dtype] = None,
device: Optional[torch.device] = None,
color_space: Optional[Union[ColorSpace, str]] = None,
dtype: Optional[torch.dtype] = None,
device: Optional[Union[torch.device, str, int]] = None,
requires_grad: bool = False,
) -> Image:
image = super().__new__(cls, data, dtype=dtype, device=device)
data = torch.as_tensor(data, dtype=dtype, device=device) # type: ignore[arg-type]
if data.ndim < 2:
raise ValueError
elif data.ndim == 2:
data = data.unsqueeze(0)
image = super().__new__(cls, data, requires_grad=requires_grad)
if color_space is None:
color_space = cls.guess_color_space(image)
......@@ -38,19 +44,19 @@ class Image(_Feature):
warnings.warn("Unable to guess a specific color space. Consider passing it explicitly.")
elif isinstance(color_space, str):
color_space = ColorSpace.from_str(color_space.upper())
image._metadata.update(dict(color_space=color_space))
elif not isinstance(color_space, ColorSpace):
raise ValueError
image.color_space = color_space
return image
@classmethod
def _to_tensor(cls, data: Any, *, dtype: Optional[torch.dtype], device: Optional[torch.device]) -> torch.Tensor:
tensor = super()._to_tensor(data, dtype=dtype, device=device)
if tensor.ndim < 2:
raise ValueError
elif tensor.ndim == 2:
tensor = tensor.unsqueeze(0)
return tensor
@classmethod
def new_like(
    cls, other: Image, data: Any, *, color_space: Optional[Union[ColorSpace, str]] = None, **kwargs: Any
) -> Image:
    """Wrap ``data`` as a new :class:`Image`, inheriting metadata from ``other``.

    ``color_space`` defaults to the color space of the template image; other
    keyword arguments are forwarded to the base ``new_like``.

    Fix: added the missing ``@classmethod`` decorator — the first parameter is
    ``cls`` and the sibling ``new_like`` implementations on the other feature
    classes are decorated, so this one must be as well to be callable as
    ``Image.new_like(...)``.
    """
    return super().new_like(
        other, data, color_space=color_space if color_space is not None else other.color_space, **kwargs
    )
@property
def image_size(self) -> Tuple[int, int]:
......
from __future__ import annotations
from typing import Any, Optional, Sequence, cast
from typing import Any, Optional, Sequence, cast, Union
import torch
from torchvision.prototype.utils._internal import apply_recursively
......@@ -15,20 +15,32 @@ class Label(_Feature):
cls,
data: Any,
*,
dtype: Optional[torch.dtype] = None,
device: Optional[torch.device] = None,
like: Optional[Label] = None,
categories: Optional[Sequence[str]] = None,
dtype: Optional[torch.dtype] = None,
device: Optional[Union[torch.device, str, int]] = None,
requires_grad: bool = False,
) -> Label:
label = super().__new__(cls, data, dtype=dtype, device=device)
label = super().__new__(cls, data, dtype=dtype, device=device, requires_grad=requires_grad)
label._metadata.update(dict(categories=categories))
label.categories = categories
return label
@classmethod
def from_category(cls, category: str, *, categories: Sequence[str]) -> Label:
return cls(categories.index(category), categories=categories)
@classmethod
def new_like(cls, other: Label, data: Any, *, categories: Optional[Sequence[str]] = None, **kwargs: Any) -> Label:
    """Wrap ``data`` as a new :class:`Label`, inheriting metadata from ``other``.

    ``categories`` defaults to the categories of the template label; other
    keyword arguments are forwarded to the base ``new_like``.

    Fix: added the missing ``@classmethod`` decorator — the first parameter is
    ``cls`` and the sibling ``new_like`` implementations (e.g. on
    ``OneHotLabel``) are decorated, so this one must be as well to be callable
    as ``Label.new_like(...)``.
    """
    return super().new_like(
        other, data, categories=categories if categories is not None else other.categories, **kwargs
    )
@classmethod
def from_category(
    cls,
    category: str,
    *,
    categories: Sequence[str],
    **kwargs: Any,
) -> Label:
    """Build a :class:`Label` whose value is the position of ``category`` inside ``categories``.

    Raises ``ValueError`` (from ``list.index``) when ``category`` is not one of
    ``categories``. Extra keyword arguments are forwarded to the constructor.
    """
    index = categories.index(category)
    return cls(index, categories=categories, **kwargs)
def to_categories(self) -> Any:
if not self.categories:
......@@ -44,16 +56,24 @@ class OneHotLabel(_Feature):
cls,
data: Any,
*,
dtype: Optional[torch.dtype] = None,
device: Optional[torch.device] = None,
like: Optional[Label] = None,
categories: Optional[Sequence[str]] = None,
dtype: Optional[torch.dtype] = None,
device: Optional[Union[torch.device, str, int]] = None,
requires_grad: bool = False,
) -> OneHotLabel:
one_hot_label = super().__new__(cls, data, dtype=dtype, device=device)
one_hot_label = super().__new__(cls, data, dtype=dtype, device=device, requires_grad=requires_grad)
if categories is not None and len(categories) != one_hot_label.shape[-1]:
raise ValueError()
one_hot_label._metadata.update(dict(categories=categories))
one_hot_label.categories = categories
return one_hot_label
@classmethod
def new_like(
    cls, other: OneHotLabel, data: Any, *, categories: Optional[Sequence[str]] = None, **kwargs: Any
) -> OneHotLabel:
    """Wrap ``data`` as a new :class:`OneHotLabel`, inheriting metadata from ``other``.

    ``categories`` defaults to the categories of the template label; remaining
    keyword arguments are forwarded to the base ``new_like``.
    """
    # Copy the template's categories unless the caller overrode them.
    if categories is None:
        categories = other.categories
    return super().new_like(other, data, categories=categories, **kwargs)
......@@ -46,7 +46,7 @@ class Resize(Transform):
return features.SegmentationMask.new_like(input, output)
elif isinstance(input, features.BoundingBox):
output = F.resize_bounding_box(input, self.size, image_size=input.image_size)
return features.BoundingBox.new_like(input, output, image_size=self.size)
return features.BoundingBox.new_like(input, output, image_size=cast(Tuple[int, int], tuple(self.size)))
elif isinstance(input, PIL.Image.Image):
return F.resize_image_pil(input, self.size, interpolation=self.interpolation)
elif isinstance(input, torch.Tensor):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment