Unverified Commit d7d90f56 authored by Philip Meier's avatar Philip Meier Committed by GitHub
Browse files

handle inplace operations in _Feature.__torch_function__ (#6671)

* prevent feature wrapping for inplace ops

* cleanup

* mypy

* refactor __torch_function__ to be more concise

* avoid double lookup

* fix normalize

* refactor normalize

* mypy
parent f7f38f1d
import pytest
import torch import torch
from torchvision.prototype import features from torchvision.prototype import features
...@@ -48,6 +49,19 @@ def test_clone_wrapping(): ...@@ -48,6 +49,19 @@ def test_clone_wrapping():
assert label_clone.categories is label.categories assert label_clone.categories is label.categories
def test_requires_grad__wrapping():
    """`Tensor.requires_grad_` is inplace, so the feature type must survive it."""
    data = torch.tensor([0, 1, 0], dtype=torch.float32)
    wrapped = features.Label(data, categories=["foo", "bar"])

    assert not wrapped.requires_grad

    result = wrapped.requires_grad_(True)

    # The returned object keeps the feature type, and the flag is visible on
    # both the return value and the original (they alias the same storage).
    assert type(result) is features.Label
    assert wrapped.requires_grad
    assert result.requires_grad
def test_other_op_no_wrapping(): def test_other_op_no_wrapping():
tensor = torch.tensor([0, 1, 0], dtype=torch.int64) tensor = torch.tensor([0, 1, 0], dtype=torch.int64)
label = features.Label(tensor, categories=["foo", "bar"]) label = features.Label(tensor, categories=["foo", "bar"])
...@@ -58,6 +72,33 @@ def test_other_op_no_wrapping(): ...@@ -58,6 +72,33 @@ def test_other_op_no_wrapping():
assert type(output) is torch.Tensor assert type(output) is torch.Tensor
@pytest.mark.parametrize(
    "op",
    [
        lambda t: t.numpy(),
        lambda t: t.tolist(),
        lambda t: t.max(dim=-1),
    ],
)
def test_no_tensor_output_op_no_wrapping(op):
    """Operators whose result is not a plain tensor must not be wrapped into a feature."""
    source = features.Label(
        torch.tensor([0, 1, 0], dtype=torch.int64), categories=["foo", "bar"]
    )

    result = op(source)

    assert type(result) is not features.Label
def test_inplace_op_no_wrapping():
    """Inplace ops return an unwrapped tensor while leaving the input's type intact."""
    data = torch.tensor([0, 1, 0], dtype=torch.int64)
    wrapped = features.Label(data, categories=["foo", "bar"])

    result = wrapped.add_(0)

    # The return value is explicitly unwrapped to a plain tensor ...
    assert type(result) is torch.Tensor
    # ... but the operand itself is untouched and stays a Label.
    assert type(wrapped) is features.Label
def test_new_like(): def test_new_like():
tensor = torch.tensor([0, 1, 0], dtype=torch.int64) tensor = torch.tensor([0, 1, 0], dtype=torch.int64)
label = features.Label(tensor, categories=["foo", "bar"]) label = features.Label(tensor, categories=["foo", "bar"])
......
...@@ -907,15 +907,15 @@ def test_correctness_gaussian_blur_image_tensor(device, image_size, dt, ksize, s ...@@ -907,15 +907,15 @@ def test_correctness_gaussian_blur_image_tensor(device, image_size, dt, ksize, s
torch.testing.assert_close(out, true_out, rtol=0.0, atol=1.0, msg=f"{ksize}, {sigma}") torch.testing.assert_close(out, true_out, rtol=0.0, atol=1.0, msg=f"{ksize}, {sigma}")
def test_midlevel_normalize_output_type(): def test_normalize_output_type():
inpt = torch.rand(1, 3, 32, 32) inpt = torch.rand(1, 3, 32, 32)
output = F.normalize(inpt, mean=[0.5, 0.5, 0.5], std=[1.0, 1.0, 1.0]) output = F.normalize(inpt, mean=[0.5, 0.5, 0.5], std=[1.0, 1.0, 1.0])
assert isinstance(output, torch.Tensor) assert type(output) is torch.Tensor
torch.testing.assert_close(inpt - 0.5, output) torch.testing.assert_close(inpt - 0.5, output)
inpt = make_image(color_space=features.ColorSpace.RGB) inpt = make_image(color_space=features.ColorSpace.RGB)
output = F.normalize(inpt, mean=[0.5, 0.5, 0.5], std=[1.0, 1.0, 1.0]) output = F.normalize(inpt, mean=[0.5, 0.5, 0.5], std=[1.0, 1.0, 1.0])
assert isinstance(output, torch.Tensor) assert type(output) is torch.Tensor
torch.testing.assert_close(inpt - 0.5, output) torch.testing.assert_close(inpt - 0.5, output)
......
...@@ -58,6 +58,16 @@ class _Feature(torch.Tensor): ...@@ -58,6 +58,16 @@ class _Feature(torch.Tensor):
**kwargs, **kwargs,
) )
# Operators that are exempt from the default "no output wrapping" policy of
# `__torch_function__`. Maps an operator to a callable `(cls, input, output) -> output`
# that decides how (or whether) to re-wrap the operator's result into the feature type.
_NO_WRAPPING_EXCEPTIONS = {
    # `clone` produces a tensor that can be wrapped like its input, reusing the
    # input's metadata via `new_like`.
    torch.Tensor.clone: lambda cls, input, output: cls.new_like(input, output),
    # `to` may change dtype and/or device, so those are taken from the output
    # explicitly instead of being inherited from the input.
    torch.Tensor.to: lambda cls, input, output: cls.new_like(
        input, output, dtype=output.dtype, device=output.device
    ),
    # We don't need to wrap the output of `Tensor.requires_grad_`, since it is an inplace operation and thus
    # retains the type automatically
    torch.Tensor.requires_grad_: lambda cls, input, output: output,
}
@classmethod @classmethod
def __torch_function__( def __torch_function__(
cls, cls,
...@@ -73,19 +83,15 @@ class _Feature(torch.Tensor): ...@@ -73,19 +83,15 @@ class _Feature(torch.Tensor):
``__torch_function__`` method. If one is found, it is invoked with the operator as ``func`` as well as the ``__torch_function__`` method. If one is found, it is invoked with the operator as ``func`` as well as the
``args`` and ``kwargs`` of the original call. ``args`` and ``kwargs`` of the original call.
The default behavior of :class:`~torch.Tensor`'s is to retain a custom tensor type. For the :class:`Feature` The default behavior of :class:`~torch.Tensor`'s is to retain a custom tensor type. For the :class:`_Feature`
use case, this has two downsides: use case, this has two downsides:
1. Since some :class:`Feature`'s require metadata to be constructed, the default wrapping, i.e. 1. Since some :class:`Feature`'s require metadata to be constructed, the default wrapping, i.e.
``return cls(func(*args, **kwargs))``, will fail for them. ``return cls(func(*args, **kwargs))``, will fail for them.
2. For most operations, there is no way of knowing if the input type is still valid for the output. 2. For most operations, there is no way of knowing if the input type is still valid for the output.
For these reasons, the automatic output wrapping is turned off for most operators. For these reasons, the automatic output wrapping is turned off for most operators. The only exceptions are
listed in :attr:`~_Feature._NO_WRAPPING_EXCEPTIONS`
Exceptions to this are:
- :meth:`torch.Tensor.clone`
- :meth:`torch.Tensor.to`
""" """
# Since super().__torch_function__ has no hook to prevent the coercing of the output into the input type, we # Since super().__torch_function__ has no hook to prevent the coercing of the output into the input type, we
# need to reimplement the functionality. # need to reimplement the functionality.
...@@ -96,18 +102,21 @@ class _Feature(torch.Tensor): ...@@ -96,18 +102,21 @@ class _Feature(torch.Tensor):
with DisableTorchFunction(): with DisableTorchFunction():
output = func(*args, **kwargs or dict()) output = func(*args, **kwargs or dict())
# The __torch_function__ protocol will invoke this method on all types involved in the computation by walking wrapper = cls._NO_WRAPPING_EXCEPTIONS.get(func)
# the MRO upwards. For example, `torch.Tensor(...).to(features.Image(...))` will invoke # Apart from `func` needing to be an exception, we also require the primary operand, i.e. `args[0]`, to be
# `features.Image.__torch_function__` first. The check below makes sure that we do not try to wrap in such a # an instance of the class that `__torch_function__` was invoked on. The __torch_function__ protocol will
# case. # invoke this method on *all* types involved in the computation by walking the MRO upwards. For example,
if not isinstance(args[0], cls): # `torch.Tensor(...).to(features.Image(...))` will invoke `features.Image.__torch_function__` with
return output # `args = (torch.Tensor(), features.Image())` first. Without this guard, the original `torch.Tensor` would
# be wrapped into a `features.Image`.
if wrapper and isinstance(args[0], cls):
return wrapper(cls, args[0], output) # type: ignore[no-any-return]
# Inplace `func`'s, canonically identified with a trailing underscore in their name like `.add_(...)`,
# will retain the input type. Thus, we need to unwrap here.
if isinstance(output, cls):
return output.as_subclass(torch.Tensor) # type: ignore[arg-type]
if func is torch.Tensor.clone:
return cls.new_like(args[0], output)
elif func is torch.Tensor.to:
return cls.new_like(args[0], output, dtype=output.dtype, device=output.device)
else:
return output return output
def _make_repr(self, **kwargs: Any) -> str: def _make_repr(self, **kwargs: Any) -> str:
......
...@@ -12,12 +12,17 @@ normalize_image_tensor = _FT.normalize ...@@ -12,12 +12,17 @@ normalize_image_tensor = _FT.normalize
def normalize( def normalize(
inpt: features.TensorImageTypeJIT, mean: List[float], std: List[float], inplace: bool = False inpt: features.TensorImageTypeJIT, mean: List[float], std: List[float], inplace: bool = False
) -> torch.Tensor: ) -> torch.Tensor:
if not isinstance(inpt, torch.Tensor): if torch.jit.is_scripting():
raise TypeError(f"img should be Tensor Image. Got {type(inpt)}") correct_type = isinstance(inpt, torch.Tensor)
else: else:
# Image instance after normalization is not Image anymore due to unknown data range correct_type = features.is_simple_tensor(inpt) or isinstance(inpt, features.Image)
# Thus we return Tensor for input Image inpt = inpt.as_subclass(torch.Tensor) # type: ignore[arg-type]
return normalize_image_tensor(inpt, mean=mean, std=std, inplace=inplace) if not correct_type:
raise TypeError(f"img should be Tensor Image. Got {type(inpt)}")
# Image instance after normalization is not Image anymore due to unknown data range
# Thus we return Tensor for input Image
return normalize_image_tensor(inpt, mean=mean, std=std, inplace=inplace)
def gaussian_blur_image_tensor( def gaussian_blur_image_tensor(
......
...@@ -937,8 +937,7 @@ def normalize(tensor: Tensor, mean: List[float], std: List[float], inplace: bool ...@@ -937,8 +937,7 @@ def normalize(tensor: Tensor, mean: List[float], std: List[float], inplace: bool
mean = mean.view(-1, 1, 1) mean = mean.view(-1, 1, 1)
if std.ndim == 1: if std.ndim == 1:
std = std.view(-1, 1, 1) std = std.view(-1, 1, 1)
tensor.sub_(mean).div_(std) return tensor.sub_(mean).div_(std)
return tensor
def erase(img: Tensor, i: int, j: int, h: int, w: int, v: Tensor, inplace: bool = False) -> Tensor: def erase(img: Tensor, i: int, j: int, h: int, w: int, v: Tensor, inplace: bool = False) -> Tensor:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment