Unverified Commit d814772e authored by Philip Meier's avatar Philip Meier Committed by GitHub
Browse files

make datapoints deepcopyable (#7701)

parent 357a40f1
from copy import deepcopy
import pytest
import torch
from common_utils import assert_equal
from PIL import Image
from torchvision import datapoints
......@@ -30,3 +33,154 @@ def test_bbox_instance(data, format):
if isinstance(format, str):
format = datapoints.BoundingBoxFormat[(format.upper())]
assert bboxes.format == format
@pytest.mark.parametrize(
("data", "input_requires_grad", "expected_requires_grad"),
[
([[[0.0, 1.0], [0.0, 1.0]]], None, False),
([[[0.0, 1.0], [0.0, 1.0]]], False, False),
([[[0.0, 1.0], [0.0, 1.0]]], True, True),
(torch.rand(3, 16, 16, requires_grad=False), None, False),
(torch.rand(3, 16, 16, requires_grad=False), False, False),
(torch.rand(3, 16, 16, requires_grad=False), True, True),
(torch.rand(3, 16, 16, requires_grad=True), None, True),
(torch.rand(3, 16, 16, requires_grad=True), False, False),
(torch.rand(3, 16, 16, requires_grad=True), True, True),
],
)
def test_new_requires_grad(data, input_requires_grad, expected_requires_grad):
datapoint = datapoints.Image(data, requires_grad=input_requires_grad)
assert datapoint.requires_grad is expected_requires_grad
def test_isinstance():
assert isinstance(datapoints.Image(torch.rand(3, 16, 16)), torch.Tensor)
def test_wrapping_no_copy():
tensor = torch.rand(3, 16, 16)
image = datapoints.Image(tensor)
assert image.data_ptr() == tensor.data_ptr()
def test_to_wrapping():
image = datapoints.Image(torch.rand(3, 16, 16))
image_to = image.to(torch.float64)
assert type(image_to) is datapoints.Image
assert image_to.dtype is torch.float64
def test_to_datapoint_reference():
tensor = torch.rand((3, 16, 16), dtype=torch.float64)
image = datapoints.Image(tensor)
tensor_to = tensor.to(image)
assert type(tensor_to) is torch.Tensor
assert tensor_to.dtype is torch.float64
def test_clone_wrapping():
image = datapoints.Image(torch.rand(3, 16, 16))
image_clone = image.clone()
assert type(image_clone) is datapoints.Image
assert image_clone.data_ptr() != image.data_ptr()
def test_requires_grad__wrapping():
image = datapoints.Image(torch.rand(3, 16, 16))
assert not image.requires_grad
image_requires_grad = image.requires_grad_(True)
assert type(image_requires_grad) is datapoints.Image
assert image.requires_grad
assert image_requires_grad.requires_grad
def test_detach_wrapping():
image = datapoints.Image(torch.rand(3, 16, 16), requires_grad=True)
image_detached = image.detach()
assert type(image_detached) is datapoints.Image
def test_other_op_no_wrapping():
image = datapoints.Image(torch.rand(3, 16, 16))
# any operation besides the ones listed in `Datapoint._NO_WRAPPING_EXCEPTIONS` will do here
output = image * 2
assert type(output) is torch.Tensor
@pytest.mark.parametrize(
"op",
[
lambda t: t.numpy(),
lambda t: t.tolist(),
lambda t: t.max(dim=-1),
],
)
def test_no_tensor_output_op_no_wrapping(op):
image = datapoints.Image(torch.rand(3, 16, 16))
output = op(image)
assert type(output) is not datapoints.Image
def test_inplace_op_no_wrapping():
image = datapoints.Image(torch.rand(3, 16, 16))
output = image.add_(0)
assert type(output) is torch.Tensor
assert type(image) is datapoints.Image
def test_wrap_like():
image = datapoints.Image(torch.rand(3, 16, 16))
# any operation besides the ones listed in `Datapoint._NO_WRAPPING_EXCEPTIONS` will do here
output = image * 2
image_new = datapoints.Image.wrap_like(image, output)
assert type(image_new) is datapoints.Image
assert image_new.data_ptr() == output.data_ptr()
@pytest.mark.parametrize(
"datapoint",
[
datapoints.Image(torch.rand(3, 16, 16)),
datapoints.Video(torch.rand(2, 3, 16, 16)),
datapoints.BoundingBox([0.0, 1.0, 2.0, 3.0], format=datapoints.BoundingBoxFormat.XYXY, spatial_size=(10, 10)),
datapoints.Mask(torch.randint(0, 256, (16, 16), dtype=torch.uint8)),
],
)
@pytest.mark.parametrize("requires_grad", [False, True])
def test_deepcopy(datapoint, requires_grad):
if requires_grad and not datapoint.dtype.is_floating_point:
return
datapoint.requires_grad_(requires_grad)
datapoint_deepcopied = deepcopy(datapoint)
assert datapoint_deepcopied is not datapoint
assert datapoint_deepcopied.data_ptr() != datapoint.data_ptr()
assert_equal(datapoint_deepcopied, datapoint)
assert type(datapoint_deepcopied) is type(datapoint)
assert datapoint_deepcopied.requires_grad is requires_grad
assert datapoint_deepcopied.is_leaf
import pytest
import torch
from torchvision.prototype import datapoints as proto_datapoints
@pytest.mark.parametrize(
("data", "input_requires_grad", "expected_requires_grad"),
[
([0.0], None, False),
([0.0], False, False),
([0.0], True, True),
(torch.tensor([0.0], requires_grad=False), None, False),
(torch.tensor([0.0], requires_grad=False), False, False),
(torch.tensor([0.0], requires_grad=False), True, True),
(torch.tensor([0.0], requires_grad=True), None, True),
(torch.tensor([0.0], requires_grad=True), False, False),
(torch.tensor([0.0], requires_grad=True), True, True),
],
)
def test_new_requires_grad(data, input_requires_grad, expected_requires_grad):
datapoint = proto_datapoints.Label(data, requires_grad=input_requires_grad)
assert datapoint.requires_grad is expected_requires_grad
def test_isinstance():
assert isinstance(
proto_datapoints.Label([0, 1, 0], categories=["foo", "bar"]),
torch.Tensor,
)
def test_wrapping_no_copy():
tensor = torch.tensor([0, 1, 0], dtype=torch.int64)
label = proto_datapoints.Label(tensor, categories=["foo", "bar"])
assert label.data_ptr() == tensor.data_ptr()
def test_to_wrapping():
tensor = torch.tensor([0, 1, 0], dtype=torch.int64)
label = proto_datapoints.Label(tensor, categories=["foo", "bar"])
label_to = label.to(torch.int32)
assert type(label_to) is proto_datapoints.Label
assert label_to.dtype is torch.int32
assert label_to.categories is label.categories
def test_to_datapoint_reference():
tensor = torch.tensor([0, 1, 0], dtype=torch.int64)
label = proto_datapoints.Label(tensor, categories=["foo", "bar"]).to(torch.int32)
tensor_to = tensor.to(label)
assert type(tensor_to) is torch.Tensor
assert tensor_to.dtype is torch.int32
def test_clone_wrapping():
tensor = torch.tensor([0, 1, 0], dtype=torch.int64)
label = proto_datapoints.Label(tensor, categories=["foo", "bar"])
label_clone = label.clone()
assert type(label_clone) is proto_datapoints.Label
assert label_clone.data_ptr() != label.data_ptr()
assert label_clone.categories is label.categories
def test_requires_grad__wrapping():
tensor = torch.tensor([0, 1, 0], dtype=torch.float32)
label = proto_datapoints.Label(tensor, categories=["foo", "bar"])
assert not label.requires_grad
label_requires_grad = label.requires_grad_(True)
assert type(label_requires_grad) is proto_datapoints.Label
assert label.requires_grad
assert label_requires_grad.requires_grad
def test_other_op_no_wrapping():
tensor = torch.tensor([0, 1, 0], dtype=torch.int64)
label = proto_datapoints.Label(tensor, categories=["foo", "bar"])
# any operation besides .to() and .clone() will do here
output = label * 2
assert type(output) is torch.Tensor
@pytest.mark.parametrize(
"op",
[
lambda t: t.numpy(),
lambda t: t.tolist(),
lambda t: t.max(dim=-1),
],
)
def test_no_tensor_output_op_no_wrapping(op):
tensor = torch.tensor([0, 1, 0], dtype=torch.int64)
label = proto_datapoints.Label(tensor, categories=["foo", "bar"])
output = op(label)
assert type(output) is not proto_datapoints.Label
def test_inplace_op_no_wrapping():
tensor = torch.tensor([0, 1, 0], dtype=torch.int64)
label = proto_datapoints.Label(tensor, categories=["foo", "bar"])
output = label.add_(0)
assert type(output) is torch.Tensor
assert type(label) is proto_datapoints.Label
def test_wrap_like():
tensor = torch.tensor([0, 1, 0], dtype=torch.int64)
label = proto_datapoints.Label(tensor, categories=["foo", "bar"])
# any operation besides .to() and .clone() will do here
output = label * 2
label_new = proto_datapoints.Label.wrap_like(label, output)
assert type(label_new) is proto_datapoints.Label
assert label_new.data_ptr() == output.data_ptr()
assert label_new.categories is label.categories
from __future__ import annotations
from types import ModuleType
from typing import Any, Callable, List, Mapping, Optional, Sequence, Tuple, Type, TypeVar, Union
from typing import Any, Callable, Dict, List, Mapping, Optional, Sequence, Tuple, Type, TypeVar, Union
import PIL.Image
import torch
......@@ -36,6 +36,7 @@ class Datapoint(torch.Tensor):
_NO_WRAPPING_EXCEPTIONS = {
torch.Tensor.clone: lambda cls, input, output: cls.wrap_like(input, output),
torch.Tensor.to: lambda cls, input, output: cls.wrap_like(input, output),
torch.Tensor.detach: lambda cls, input, output: cls.wrap_like(input, output),
# We don't need to wrap the output of `Tensor.requires_grad_`, since it is an inplace operation and thus
# retains the type automatically
torch.Tensor.requires_grad_: lambda cls, input, output: output,
......@@ -132,6 +133,15 @@ class Datapoint(torch.Tensor):
with DisableTorchFunctionSubclass():
return super().dtype
def __deepcopy__(self: D, memo: Dict[int, Any]) -> D:
# We need to detach first, since a plain `Tensor.clone` will be part of the computation graph, which does
# *not* happen for `deepcopy(Tensor)`. A side-effect from detaching is that the `Tensor.requires_grad`
# attribute is cleared, so we need to refill it before we return.
# Note: We don't explicitly handle deep-copying of the metadata here. The only metadata we currently have is
# `BoundingBox.format` and `BoundingBox.spatial_size`, which are immutable and thus implicitly deep-copied by
# `BoundingBox.clone()`.
return self.detach().clone().requires_grad_(self.requires_grad) # type: ignore[return-value]
def horizontal_flip(self) -> Datapoint:
return self
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment