Unverified Commit bf6a8dc2 authored by Nicolas Hug's avatar Nicolas Hug Committed by GitHub
Browse files

Simplify _NO_WRAPPING_EXCEPTIONS (#7806)

parent f2b6f43a
...@@ -209,4 +209,3 @@ def test_deepcopy(datapoint, requires_grad): ...@@ -209,4 +209,3 @@ def test_deepcopy(datapoint, requires_grad):
assert type(datapoint_deepcopied) is type(datapoint) assert type(datapoint_deepcopied) is type(datapoint)
assert datapoint_deepcopied.requires_grad is requires_grad assert datapoint_deepcopied.requires_grad is requires_grad
assert datapoint_deepcopied.is_leaf
...@@ -33,14 +33,9 @@ class Datapoint(torch.Tensor): ...@@ -33,14 +33,9 @@ class Datapoint(torch.Tensor):
def wrap_like(cls: Type[D], other: D, tensor: torch.Tensor) -> D: def wrap_like(cls: Type[D], other: D, tensor: torch.Tensor) -> D:
return tensor.as_subclass(cls) return tensor.as_subclass(cls)
_NO_WRAPPING_EXCEPTIONS = { # The ops in this set are those that should *preserve* the Datapoint type,
torch.Tensor.clone: lambda cls, input, output: cls.wrap_like(input, output), # i.e. they are exceptions to the "no wrapping" rule.
torch.Tensor.to: lambda cls, input, output: cls.wrap_like(input, output), _NO_WRAPPING_EXCEPTIONS = {torch.Tensor.clone, torch.Tensor.to, torch.Tensor.detach, torch.Tensor.requires_grad_}
torch.Tensor.detach: lambda cls, input, output: cls.wrap_like(input, output),
# We don't need to wrap the output of `Tensor.requires_grad_`, since it is an inplace operation and thus
# retains the type automatically
torch.Tensor.requires_grad_: lambda cls, input, output: output,
}
@classmethod @classmethod
def __torch_function__( def __torch_function__(
...@@ -76,22 +71,21 @@ class Datapoint(torch.Tensor): ...@@ -76,22 +71,21 @@ class Datapoint(torch.Tensor):
with DisableTorchFunctionSubclass(): with DisableTorchFunctionSubclass():
output = func(*args, **kwargs or dict()) output = func(*args, **kwargs or dict())
wrapper = cls._NO_WRAPPING_EXCEPTIONS.get(func) if func in cls._NO_WRAPPING_EXCEPTIONS and isinstance(args[0], cls):
# Apart from `func` needing to be an exception, we also require the primary operand, i.e. `args[0]`, to be # We also require the primary operand, i.e. `args[0]`, to be
# an instance of the class that `__torch_function__` was invoked on. The __torch_function__ protocol will # an instance of the class that `__torch_function__` was invoked on. The __torch_function__ protocol will
# invoke this method on *all* types involved in the computation by walking the MRO upwards. For example, # invoke this method on *all* types involved in the computation by walking the MRO upwards. For example,
# `torch.Tensor(...).to(datapoints.Image(...))` will invoke `datapoints.Image.__torch_function__` with # `torch.Tensor(...).to(datapoints.Image(...))` will invoke `datapoints.Image.__torch_function__` with
# `args = (torch.Tensor(), datapoints.Image())` first. Without this guard, the original `torch.Tensor` would # `args = (torch.Tensor(), datapoints.Image())` first. Without this guard, the original `torch.Tensor` would
# be wrapped into a `datapoints.Image`. # be wrapped into a `datapoints.Image`.
if wrapper and isinstance(args[0], cls): return cls.wrap_like(args[0], output)
return wrapper(cls, args[0], output)
# Inplace `func`'s, canonically identified with a trailing underscore in their name like `.add_(...)`, if isinstance(output, cls):
# will retain the input type. Thus, we need to unwrap here. # DisableTorchFunctionSubclass is ignored by inplace ops like `.add_(...)`,
if isinstance(output, cls): # so for those, the output is still a Datapoint. Thus, we need to manually unwrap.
return output.as_subclass(torch.Tensor) return output.as_subclass(torch.Tensor)
return output return output
def _make_repr(self, **kwargs: Any) -> str: def _make_repr(self, **kwargs: Any) -> str:
# This is a poor man's implementation of the proposal in https://github.com/pytorch/pytorch/issues/76532. # This is a poor man's implementation of the proposal in https://github.com/pytorch/pytorch/issues/76532.
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment