Unverified Commit d7d90f56 authored by Philip Meier's avatar Philip Meier Committed by GitHub
Browse files

handle inplace operations in _Feature.__torch_function__ (#6671)

* prevent feature wrapping for inplace ops

* cleanup

* mypy

* refactor __torch_function__ to be more concise

* avoid double lookup

* fix normalize

* refactor normalize

* mypy
parent f7f38f1d
import pytest
import torch
from torchvision.prototype import features
......@@ -48,6 +49,19 @@ def test_clone_wrapping():
assert label_clone.categories is label.categories
def test_requires_grad__wrapping():
    """``requires_grad_`` is inplace, so the feature type must survive the call."""
    data = torch.tensor([0, 1, 0], dtype=torch.float32)
    label = features.Label(data, categories=["foo", "bar"])

    assert not label.requires_grad

    result = label.requires_grad_(True)

    # The inplace op returns the same (still wrapped) object with the flag set.
    assert type(result) is features.Label
    assert label.requires_grad
    assert result.requires_grad
def test_other_op_no_wrapping():
tensor = torch.tensor([0, 1, 0], dtype=torch.int64)
label = features.Label(tensor, categories=["foo", "bar"])
......@@ -58,6 +72,33 @@ def test_other_op_no_wrapping():
assert type(output) is torch.Tensor
@pytest.mark.parametrize(
    "op",
    [
        lambda tensor: tensor.numpy(),
        lambda tensor: tensor.tolist(),
        lambda tensor: tensor.max(dim=-1),
    ],
)
def test_no_tensor_output_op_no_wrapping(op):
    """Ops whose result is not a plain tensor must not be wrapped into a feature."""
    data = torch.tensor([0, 1, 0], dtype=torch.int64)
    label = features.Label(data, categories=["foo", "bar"])

    result = op(label)

    assert type(result) is not features.Label
def test_inplace_op_no_wrapping():
    """Inplace ops (trailing underscore) return a plain tensor, but leave the
    original feature object's type untouched."""
    data = torch.tensor([0, 1, 0], dtype=torch.int64)
    label = features.Label(data, categories=["foo", "bar"])

    result = label.add_(0)

    assert type(result) is torch.Tensor
    assert type(label) is features.Label
def test_new_like():
tensor = torch.tensor([0, 1, 0], dtype=torch.int64)
label = features.Label(tensor, categories=["foo", "bar"])
......
......@@ -907,15 +907,15 @@ def test_correctness_gaussian_blur_image_tensor(device, image_size, dt, ksize, s
torch.testing.assert_close(out, true_out, rtol=0.0, atol=1.0, msg=f"{ksize}, {sigma}")
def test_normalize_output_type():
    """F.normalize must always return a plain ``torch.Tensor``.

    After normalization the data range is unknown, so even an ``features.Image``
    input can no longer be considered an image and must be unwrapped.
    """
    inpt = torch.rand(1, 3, 32, 32)
    output = F.normalize(inpt, mean=[0.5, 0.5, 0.5], std=[1.0, 1.0, 1.0])
    # `type(...) is` (not `isinstance`) — a wrapped subclass would still pass isinstance.
    assert type(output) is torch.Tensor
    torch.testing.assert_close(inpt - 0.5, output)

    inpt = make_image(color_space=features.ColorSpace.RGB)
    output = F.normalize(inpt, mean=[0.5, 0.5, 0.5], std=[1.0, 1.0, 1.0])
    assert type(output) is torch.Tensor
    torch.testing.assert_close(inpt - 0.5, output)
......
......@@ -58,6 +58,16 @@ class _Feature(torch.Tensor):
**kwargs,
)
# Operators whose output should, exceptionally, be re-wrapped into the feature
# type. Maps the operator to a callable ``(cls, input, output) -> wrapped``.
# NOTE(review): presumably consulted by ``__torch_function__`` — confirm against
# the rest of the class, which is not fully visible here.
_NO_WRAPPING_EXCEPTIONS = {
    # ``clone`` preserves the data contract, so the copy can safely stay wrapped.
    torch.Tensor.clone: lambda cls, input, output: cls.new_like(input, output),
    # ``to`` may change dtype/device; propagate both to the new wrapper.
    torch.Tensor.to: lambda cls, input, output: cls.new_like(
        input, output, dtype=output.dtype, device=output.device
    ),
    # We don't need to wrap the output of `Tensor.requires_grad_`, since it is an inplace operation and thus
    # retains the type automatically
    torch.Tensor.requires_grad_: lambda cls, input, output: output,
}
@classmethod
def __torch_function__(
cls,
......@@ -73,19 +83,15 @@ class _Feature(torch.Tensor):
``__torch_function__`` method. If one is found, it is invoked with the operator as ``func`` as well as the
``args`` and ``kwargs`` of the original call.
The default behavior of :class:`~torch.Tensor`'s is to retain a custom tensor type. For the :class:`Feature`
The default behavior of :class:`~torch.Tensor`'s is to retain a custom tensor type. For the :class:`_Feature`
use case, this has two downsides:
1. Since some :class:`Feature`'s require metadata to be constructed, the default wrapping, i.e.
``return cls(func(*args, **kwargs))``, will fail for them.
2. For most operations, there is no way of knowing if the input type is still valid for the output.
For these reasons, the automatic output wrapping is turned off for most operators.
Exceptions to this are:
- :meth:`torch.Tensor.clone`
- :meth:`torch.Tensor.to`
For these reasons, the automatic output wrapping is turned off for most operators. The only exceptions are
listed in :attr:`~_Feature._NO_WRAPPING_EXCEPTIONS`
"""
# Since super().__torch_function__ has no hook to prevent the coercing of the output into the input type, we
# need to reimplement the functionality.
......@@ -96,18 +102,21 @@ class _Feature(torch.Tensor):
with DisableTorchFunction():
output = func(*args, **kwargs or dict())
# The __torch_function__ protocol will invoke this method on all types involved in the computation by walking
# the MRO upwards. For example, `torch.Tensor(...).to(features.Image(...))` will invoke
# `features.Image.__torch_function__` first. The check below makes sure that we do not try to wrap in such a
# case.
if not isinstance(args[0], cls):
return output
wrapper = cls._NO_WRAPPING_EXCEPTIONS.get(func)
# Apart from `func` needing to be an exception, we also require the primary operand, i.e. `args[0]`, to be
# an instance of the class that `__torch_function__` was invoked on. The __torch_function__ protocol will
# invoke this method on *all* types involved in the computation by walking the MRO upwards. For example,
# `torch.Tensor(...).to(features.Image(...))` will invoke `features.Image.__torch_function__` with
# `args = (torch.Tensor(), features.Image())` first. Without this guard, the original `torch.Tensor` would
# be wrapped into a `features.Image`.
if wrapper and isinstance(args[0], cls):
return wrapper(cls, args[0], output) # type: ignore[no-any-return]
# Inplace `func`'s, canonically identified with a trailing underscore in their name like `.add_(...)`,
# will retain the input type. Thus, we need to unwrap here.
if isinstance(output, cls):
return output.as_subclass(torch.Tensor) # type: ignore[arg-type]
if func is torch.Tensor.clone:
return cls.new_like(args[0], output)
elif func is torch.Tensor.to:
return cls.new_like(args[0], output, dtype=output.dtype, device=output.device)
else:
return output
def _make_repr(self, **kwargs: Any) -> str:
......
......@@ -12,12 +12,17 @@ normalize_image_tensor = _FT.normalize
def normalize(
    inpt: features.TensorImageTypeJIT, mean: List[float], std: List[float], inplace: bool = False
) -> torch.Tensor:
    """Normalize a tensor image with ``mean`` and ``std``.

    Always returns a plain ``torch.Tensor``: after normalization the data range
    is unknown, so a ``features.Image`` input can no longer be treated as an
    image and is deliberately unwrapped.

    Raises:
        TypeError: if ``inpt`` is neither a simple tensor nor a ``features.Image``.
    """
    if torch.jit.is_scripting():
        # TorchScript cannot see the feature subclasses; accept any tensor.
        correct_type = isinstance(inpt, torch.Tensor)
    else:
        correct_type = features.is_simple_tensor(inpt) or isinstance(inpt, features.Image)
        # Unwrap eagerly so the kernel's output is never re-wrapped as an Image.
        inpt = inpt.as_subclass(torch.Tensor)  # type: ignore[arg-type]
    if not correct_type:
        raise TypeError(f"img should be Tensor Image. Got {type(inpt)}")

    # Image instance after normalization is not Image anymore due to unknown data range
    # Thus we return Tensor for input Image
    return normalize_image_tensor(inpt, mean=mean, std=std, inplace=inplace)
def gaussian_blur_image_tensor(
......
......@@ -937,8 +937,7 @@ def normalize(tensor: Tensor, mean: List[float], std: List[float], inplace: bool
mean = mean.view(-1, 1, 1)
if std.ndim == 1:
std = std.view(-1, 1, 1)
tensor.sub_(mean).div_(std)
return tensor
return tensor.sub_(mean).div_(std)
def erase(img: Tensor, i: int, j: int, h: int, w: int, v: Tensor, inplace: bool = False) -> Tensor:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment