Unverified Commit 1353df9e authored by Philip Meier's avatar Philip Meier Committed by GitHub
Browse files

[FBcode->GH] [reland] rename DisableTorchFunction to DisableTorchFunctionSubclass (#88218) (#7062)

[FBcode->GH]
https://www.internalfb.com/diff/D41268423

Co-authored-by: default avatarSamantha Andow <samdow@meta.com>
parent 2fb9c49c
...@@ -5,7 +5,7 @@ from typing import Any, Callable, List, Mapping, Optional, Sequence, Tuple, Type ...@@ -5,7 +5,7 @@ from typing import Any, Callable, List, Mapping, Optional, Sequence, Tuple, Type
import PIL.Image import PIL.Image
import torch import torch
from torch._C import DisableTorchFunction from torch._C import DisableTorchFunctionSubclass
from torch.types import _device, _dtype, _size from torch.types import _device, _dtype, _size
from torchvision.transforms import InterpolationMode from torchvision.transforms import InterpolationMode
...@@ -87,7 +87,7 @@ class Datapoint(torch.Tensor): ...@@ -87,7 +87,7 @@ class Datapoint(torch.Tensor):
if not all(issubclass(cls, t) for t in types): if not all(issubclass(cls, t) for t in types):
return NotImplemented return NotImplemented
with DisableTorchFunction(): with DisableTorchFunctionSubclass():
output = func(*args, **kwargs or dict()) output = func(*args, **kwargs or dict())
wrapper = cls._NO_WRAPPING_EXCEPTIONS.get(func) wrapper = cls._NO_WRAPPING_EXCEPTIONS.get(func)
...@@ -129,22 +129,22 @@ class Datapoint(torch.Tensor): ...@@ -129,22 +129,22 @@ class Datapoint(torch.Tensor):
# this way we return the result without passing into __torch_function__ # this way we return the result without passing into __torch_function__
@property @property
def shape(self) -> _size: # type: ignore[override] def shape(self) -> _size: # type: ignore[override]
with DisableTorchFunction(): with DisableTorchFunctionSubclass():
return super().shape return super().shape
@property @property
def ndim(self) -> int: # type: ignore[override] def ndim(self) -> int: # type: ignore[override]
with DisableTorchFunction(): with DisableTorchFunctionSubclass():
return super().ndim return super().ndim
@property @property
def device(self, *args: Any, **kwargs: Any) -> _device: # type: ignore[override] def device(self, *args: Any, **kwargs: Any) -> _device: # type: ignore[override]
with DisableTorchFunction(): with DisableTorchFunctionSubclass():
return super().device return super().device
@property @property
def dtype(self) -> _dtype: # type: ignore[override] def dtype(self) -> _dtype: # type: ignore[override]
with DisableTorchFunction(): with DisableTorchFunctionSubclass():
return super().dtype return super().dtype
def horizontal_flip(self) -> Datapoint: def horizontal_flip(self) -> Datapoint:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment