Unverified Commit 1353df9e authored by Philip Meier's avatar Philip Meier Committed by GitHub
Browse files

[FBcode->GH] [reland] rename DisableTorchFunction to DisableTorchFunctionSubclass (#88218) (#7062)

[FBcode->GH]
https://www.internalfb.com/diff/D41268423

Co-authored-by: default avatarSamantha Andow <samdow@meta.com>
parent 2fb9c49c
......@@ -5,7 +5,7 @@ from typing import Any, Callable, List, Mapping, Optional, Sequence, Tuple, Type
import PIL.Image
import torch
from torch._C import DisableTorchFunction
from torch._C import DisableTorchFunctionSubclass
from torch.types import _device, _dtype, _size
from torchvision.transforms import InterpolationMode
......@@ -87,7 +87,7 @@ class Datapoint(torch.Tensor):
if not all(issubclass(cls, t) for t in types):
return NotImplemented
with DisableTorchFunction():
with DisableTorchFunctionSubclass():
output = func(*args, **kwargs or dict())
wrapper = cls._NO_WRAPPING_EXCEPTIONS.get(func)
......@@ -129,22 +129,22 @@ class Datapoint(torch.Tensor):
# this way we return the result without passing into __torch_function__
@property
def shape(self) -> _size: # type: ignore[override]
with DisableTorchFunction():
with DisableTorchFunctionSubclass():
return super().shape
@property
def ndim(self) -> int: # type: ignore[override]
with DisableTorchFunction():
with DisableTorchFunctionSubclass():
return super().ndim
@property
def device(self, *args: Any, **kwargs: Any) -> _device: # type: ignore[override]
with DisableTorchFunction():
with DisableTorchFunctionSubclass():
return super().device
@property
def dtype(self) -> _dtype: # type: ignore[override]
with DisableTorchFunction():
with DisableTorchFunctionSubclass():
return super().dtype
def horizontal_flip(self) -> Datapoint:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment