Unverified Commit 2bc8a14d authored by Philip Meier's avatar Philip Meier Committed by GitHub
Browse files

fix requires_grad passthrough (#7138)

parent 455eda68
...@@ -3,6 +3,25 @@ import torch ...@@ -3,6 +3,25 @@ import torch
from torchvision.prototype import datapoints from torchvision.prototype import datapoints
@pytest.mark.parametrize(
("data", "input_requires_grad", "expected_requires_grad"),
[
([0.0], None, False),
([0.0], False, False),
([0.0], True, True),
(torch.tensor([0.0], requires_grad=False), None, False),
(torch.tensor([0.0], requires_grad=False), False, False),
(torch.tensor([0.0], requires_grad=False), True, True),
(torch.tensor([0.0], requires_grad=True), None, True),
(torch.tensor([0.0], requires_grad=True), False, False),
(torch.tensor([0.0], requires_grad=True), True, True),
],
)
def test_new_requires_grad(data, input_requires_grad, expected_requires_grad):
datapoint = datapoints.Label(data, requires_grad=input_requires_grad)
assert datapoint.requires_grad is expected_requires_grad
def test_isinstance(): def test_isinstance():
assert isinstance( assert isinstance(
datapoints.Label([0, 1, 0], categories=["foo", "bar"]), datapoints.Label([0, 1, 0], categories=["foo", "bar"]),
......
...@@ -34,7 +34,7 @@ class BoundingBox(Datapoint): ...@@ -34,7 +34,7 @@ class BoundingBox(Datapoint):
spatial_size: Tuple[int, int], spatial_size: Tuple[int, int],
dtype: Optional[torch.dtype] = None, dtype: Optional[torch.dtype] = None,
device: Optional[Union[torch.device, str, int]] = None, device: Optional[Union[torch.device, str, int]] = None,
requires_grad: bool = False, requires_grad: Optional[bool] = None,
) -> BoundingBox: ) -> BoundingBox:
tensor = cls._to_tensor(data, dtype=dtype, device=device, requires_grad=requires_grad) tensor = cls._to_tensor(data, dtype=dtype, device=device, requires_grad=requires_grad)
......
...@@ -23,8 +23,10 @@ class Datapoint(torch.Tensor): ...@@ -23,8 +23,10 @@ class Datapoint(torch.Tensor):
data: Any, data: Any,
dtype: Optional[torch.dtype] = None, dtype: Optional[torch.dtype] = None,
device: Optional[Union[torch.device, str, int]] = None, device: Optional[Union[torch.device, str, int]] = None,
requires_grad: bool = False, requires_grad: Optional[bool] = None,
) -> torch.Tensor: ) -> torch.Tensor:
if requires_grad is None:
requires_grad = data.requires_grad if isinstance(data, torch.Tensor) else False
return torch.as_tensor(data, dtype=dtype, device=device).requires_grad_(requires_grad) return torch.as_tensor(data, dtype=dtype, device=device).requires_grad_(requires_grad)
# FIXME: this is just here for BC with the prototype datasets. Some datasets use the Datapoint directly to have a # FIXME: this is just here for BC with the prototype datasets. Some datasets use the Datapoint directly to have a
...@@ -36,7 +38,7 @@ class Datapoint(torch.Tensor): ...@@ -36,7 +38,7 @@ class Datapoint(torch.Tensor):
data: Any, data: Any,
dtype: Optional[torch.dtype] = None, dtype: Optional[torch.dtype] = None,
device: Optional[Union[torch.device, str, int]] = None, device: Optional[Union[torch.device, str, int]] = None,
requires_grad: bool = False, requires_grad: Optional[bool] = None,
) -> Datapoint: ) -> Datapoint:
tensor = cls._to_tensor(data, dtype=dtype, device=device, requires_grad=requires_grad) tensor = cls._to_tensor(data, dtype=dtype, device=device, requires_grad=requires_grad)
return tensor.as_subclass(Datapoint) return tensor.as_subclass(Datapoint)
......
...@@ -21,7 +21,7 @@ class Image(Datapoint): ...@@ -21,7 +21,7 @@ class Image(Datapoint):
*, *,
dtype: Optional[torch.dtype] = None, dtype: Optional[torch.dtype] = None,
device: Optional[Union[torch.device, str, int]] = None, device: Optional[Union[torch.device, str, int]] = None,
requires_grad: bool = False, requires_grad: Optional[bool] = None,
) -> Image: ) -> Image:
tensor = cls._to_tensor(data, dtype=dtype, device=device, requires_grad=requires_grad) tensor = cls._to_tensor(data, dtype=dtype, device=device, requires_grad=requires_grad)
if tensor.ndim < 2: if tensor.ndim < 2:
......
...@@ -27,7 +27,7 @@ class _LabelBase(Datapoint): ...@@ -27,7 +27,7 @@ class _LabelBase(Datapoint):
categories: Optional[Sequence[str]] = None, categories: Optional[Sequence[str]] = None,
dtype: Optional[torch.dtype] = None, dtype: Optional[torch.dtype] = None,
device: Optional[Union[torch.device, str, int]] = None, device: Optional[Union[torch.device, str, int]] = None,
requires_grad: bool = False, requires_grad: Optional[bool] = None,
) -> L: ) -> L:
tensor = cls._to_tensor(data, dtype=dtype, device=device, requires_grad=requires_grad) tensor = cls._to_tensor(data, dtype=dtype, device=device, requires_grad=requires_grad)
return cls._wrap(tensor, categories=categories) return cls._wrap(tensor, categories=categories)
......
...@@ -19,7 +19,7 @@ class Mask(Datapoint): ...@@ -19,7 +19,7 @@ class Mask(Datapoint):
*, *,
dtype: Optional[torch.dtype] = None, dtype: Optional[torch.dtype] = None,
device: Optional[Union[torch.device, str, int]] = None, device: Optional[Union[torch.device, str, int]] = None,
requires_grad: bool = False, requires_grad: Optional[bool] = None,
) -> Mask: ) -> Mask:
tensor = cls._to_tensor(data, dtype=dtype, device=device, requires_grad=requires_grad) tensor = cls._to_tensor(data, dtype=dtype, device=device, requires_grad=requires_grad)
return cls._wrap(tensor) return cls._wrap(tensor)
......
...@@ -20,7 +20,7 @@ class Video(Datapoint): ...@@ -20,7 +20,7 @@ class Video(Datapoint):
*, *,
dtype: Optional[torch.dtype] = None, dtype: Optional[torch.dtype] = None,
device: Optional[Union[torch.device, str, int]] = None, device: Optional[Union[torch.device, str, int]] = None,
requires_grad: bool = False, requires_grad: Optional[bool] = None,
) -> Video: ) -> Video:
tensor = cls._to_tensor(data, dtype=dtype, device=device, requires_grad=requires_grad) tensor = cls._to_tensor(data, dtype=dtype, device=device, requires_grad=requires_grad)
if data.ndim < 4: if data.ndim < 4:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment