Unverified Commit 9b82df43 authored by Nicolas Hug's avatar Nicolas Hug Committed by GitHub
Browse files

Remove `_wrap()` class method from base class Datapoint (#7805)

parent 2030d208
...@@ -113,6 +113,26 @@ def test_detach_wrapping(): ...@@ -113,6 +113,26 @@ def test_detach_wrapping():
assert type(image_detached) is datapoints.Image assert type(image_detached) is datapoints.Image
def test_no_wrapping_exceptions_with_metadata():
# Sanity checks for the ops in _NO_WRAPPING_EXCEPTIONS and datapoints with metadata
format, canvas_size = "XYXY", (32, 32)
bbox = datapoints.BoundingBoxes([[0, 0, 5, 5], [2, 2, 7, 7]], format=format, canvas_size=canvas_size)
bbox = bbox.clone()
assert bbox.format, bbox.canvas_size == (format, canvas_size)
bbox = bbox.to(torch.float64)
assert bbox.format, bbox.canvas_size == (format, canvas_size)
bbox = bbox.detach()
assert bbox.format, bbox.canvas_size == (format, canvas_size)
assert not bbox.requires_grad
bbox.requires_grad_(True)
assert bbox.format, bbox.canvas_size == (format, canvas_size)
assert bbox.requires_grad
def test_other_op_no_wrapping(): def test_other_op_no_wrapping():
image = datapoints.Image(torch.rand(3, 16, 16)) image = datapoints.Image(torch.rand(3, 16, 16))
......
...@@ -42,7 +42,9 @@ class BoundingBoxes(Datapoint): ...@@ -42,7 +42,9 @@ class BoundingBoxes(Datapoint):
canvas_size: Tuple[int, int] canvas_size: Tuple[int, int]
@classmethod @classmethod
def _wrap(cls, tensor: torch.Tensor, *, format: BoundingBoxFormat, canvas_size: Tuple[int, int]) -> BoundingBoxes: # type: ignore[override] def _wrap(cls, tensor: torch.Tensor, *, format: Union[BoundingBoxFormat, str], canvas_size: Tuple[int, int]) -> BoundingBoxes: # type: ignore[override]
if isinstance(format, str):
format = BoundingBoxFormat[format.upper()]
bounding_boxes = tensor.as_subclass(cls) bounding_boxes = tensor.as_subclass(cls)
bounding_boxes.format = format bounding_boxes.format = format
bounding_boxes.canvas_size = canvas_size bounding_boxes.canvas_size = canvas_size
...@@ -59,10 +61,6 @@ class BoundingBoxes(Datapoint): ...@@ -59,10 +61,6 @@ class BoundingBoxes(Datapoint):
requires_grad: Optional[bool] = None, requires_grad: Optional[bool] = None,
) -> BoundingBoxes: ) -> BoundingBoxes:
tensor = cls._to_tensor(data, dtype=dtype, device=device, requires_grad=requires_grad) tensor = cls._to_tensor(data, dtype=dtype, device=device, requires_grad=requires_grad)
if isinstance(format, str):
format = BoundingBoxFormat[format.upper()]
return cls._wrap(tensor, format=format, canvas_size=canvas_size) return cls._wrap(tensor, format=format, canvas_size=canvas_size)
@classmethod @classmethod
...@@ -71,7 +69,7 @@ class BoundingBoxes(Datapoint): ...@@ -71,7 +69,7 @@ class BoundingBoxes(Datapoint):
other: BoundingBoxes, other: BoundingBoxes,
tensor: torch.Tensor, tensor: torch.Tensor,
*, *,
format: Optional[BoundingBoxFormat] = None, format: Optional[Union[BoundingBoxFormat, str]] = None,
canvas_size: Optional[Tuple[int, int]] = None, canvas_size: Optional[Tuple[int, int]] = None,
) -> BoundingBoxes: ) -> BoundingBoxes:
"""Wrap a :class:`torch.Tensor` as :class:`BoundingBoxes` from a reference. """Wrap a :class:`torch.Tensor` as :class:`BoundingBoxes` from a reference.
...@@ -85,9 +83,6 @@ class BoundingBoxes(Datapoint): ...@@ -85,9 +83,6 @@ class BoundingBoxes(Datapoint):
omitted, it is taken from the reference. omitted, it is taken from the reference.
""" """
if isinstance(format, str):
format = BoundingBoxFormat[format.upper()]
return cls._wrap( return cls._wrap(
tensor, tensor,
format=format if format is not None else other.format, format=format if format is not None else other.format,
......
...@@ -32,13 +32,9 @@ class Datapoint(torch.Tensor): ...@@ -32,13 +32,9 @@ class Datapoint(torch.Tensor):
requires_grad = data.requires_grad if isinstance(data, torch.Tensor) else False requires_grad = data.requires_grad if isinstance(data, torch.Tensor) else False
return torch.as_tensor(data, dtype=dtype, device=device).requires_grad_(requires_grad) return torch.as_tensor(data, dtype=dtype, device=device).requires_grad_(requires_grad)
@classmethod
def _wrap(cls: Type[D], tensor: torch.Tensor) -> D:
return tensor.as_subclass(cls)
@classmethod @classmethod
def wrap_like(cls: Type[D], other: D, tensor: torch.Tensor) -> D: def wrap_like(cls: Type[D], other: D, tensor: torch.Tensor) -> D:
return cls._wrap(tensor) return tensor.as_subclass(cls)
_NO_WRAPPING_EXCEPTIONS = { _NO_WRAPPING_EXCEPTIONS = {
torch.Tensor.clone: lambda cls, input, output: cls.wrap_like(input, output), torch.Tensor.clone: lambda cls, input, output: cls.wrap_like(input, output),
......
...@@ -41,7 +41,7 @@ class Image(Datapoint): ...@@ -41,7 +41,7 @@ class Image(Datapoint):
elif tensor.ndim == 2: elif tensor.ndim == 2:
tensor = tensor.unsqueeze(0) tensor = tensor.unsqueeze(0)
return cls._wrap(tensor) return tensor.as_subclass(cls)
def __repr__(self, *, tensor_contents: Any = None) -> str: # type: ignore[override] def __repr__(self, *, tensor_contents: Any = None) -> str: # type: ignore[override]
return self._make_repr() return self._make_repr()
......
...@@ -36,4 +36,4 @@ class Mask(Datapoint): ...@@ -36,4 +36,4 @@ class Mask(Datapoint):
data = F.pil_to_tensor(data) data = F.pil_to_tensor(data)
tensor = cls._to_tensor(data, dtype=dtype, device=device, requires_grad=requires_grad) tensor = cls._to_tensor(data, dtype=dtype, device=device, requires_grad=requires_grad)
return cls._wrap(tensor) return tensor.as_subclass(cls)
...@@ -31,7 +31,7 @@ class Video(Datapoint): ...@@ -31,7 +31,7 @@ class Video(Datapoint):
tensor = cls._to_tensor(data, dtype=dtype, device=device, requires_grad=requires_grad) tensor = cls._to_tensor(data, dtype=dtype, device=device, requires_grad=requires_grad)
if data.ndim < 4: if data.ndim < 4:
raise ValueError raise ValueError
return cls._wrap(tensor) return tensor.as_subclass(cls)
def __repr__(self, *, tensor_contents: Any = None) -> str: # type: ignore[override] def __repr__(self, *, tensor_contents: Any = None) -> str: # type: ignore[override]
return self._make_repr() return self._make_repr()
......
...@@ -15,7 +15,7 @@ class _LabelBase(Datapoint): ...@@ -15,7 +15,7 @@ class _LabelBase(Datapoint):
categories: Optional[Sequence[str]] categories: Optional[Sequence[str]]
@classmethod @classmethod
def _wrap(cls: Type[L], tensor: torch.Tensor, *, categories: Optional[Sequence[str]]) -> L: # type: ignore[override] def _wrap(cls: Type[L], tensor: torch.Tensor, *, categories: Optional[Sequence[str]]) -> L:
label_base = tensor.as_subclass(cls) label_base = tensor.as_subclass(cls)
label_base.categories = categories label_base.categories = categories
return label_base return label_base
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment