Unverified Commit f7f38f1d authored by Philip Meier's avatar Philip Meier Committed by GitHub
Browse files

cleanup features._Feature (#5806)


Co-authored-by: default avatarvfdev <vfdev.5@gmail.com>
parent 7f4c55b1
from __future__ import annotations from __future__ import annotations
from types import ModuleType from types import ModuleType
from typing import Any, Callable, cast, List, Mapping, Optional, Sequence, Tuple, Type, TypeVar, Union from typing import Any, Callable, List, Mapping, Optional, Sequence, Tuple, Type, TypeVar, Union
import PIL.Image import PIL.Image
import torch import torch
from torch._C import _TensorBase, DisableTorchFunction from torch._C import DisableTorchFunction
from torchvision.transforms import InterpolationMode from torchvision.transforms import InterpolationMode
F = TypeVar("F", bound="_Feature") F = TypeVar("F", bound="_Feature")
FillType = Union[int, float, Sequence[int], Sequence[float], None] FillType = Union[int, float, Sequence[int], Sequence[float], None]
FillTypeJIT = Union[int, float, List[float], None] FillTypeJIT = Union[int, float, List[float], None]
...@@ -28,13 +29,14 @@ class _Feature(torch.Tensor): ...@@ -28,13 +29,14 @@ class _Feature(torch.Tensor):
device: Optional[Union[torch.device, str, int]] = None, device: Optional[Union[torch.device, str, int]] = None,
requires_grad: bool = False, requires_grad: bool = False,
) -> F: ) -> F:
return cast( return (
F, torch.as_tensor( # type: ignore[return-value]
torch.Tensor._make_subclass( data,
cast(_TensorBase, cls), dtype=dtype, # type: ignore[arg-type]
torch.as_tensor(data, dtype=dtype, device=device), # type: ignore[arg-type] device=device, # type: ignore[arg-type]
requires_grad, )
), .as_subclass(cls) # type: ignore[arg-type]
.requires_grad_(requires_grad)
) )
@classmethod @classmethod
...@@ -82,12 +84,17 @@ class _Feature(torch.Tensor): ...@@ -82,12 +84,17 @@ class _Feature(torch.Tensor):
Exceptions to this are: Exceptions to this are:
- :func:`torch.clone` - :meth:`torch.Tensor.clone`
- :meth:`torch.Tensor.to` - :meth:`torch.Tensor.to`
""" """
kwargs = kwargs or dict() # Since super().__torch_function__ has no hook to prevent the coercing of the output into the input type, we
# need to reimplement the functionality.
if not all(issubclass(cls, t) for t in types):
return NotImplemented
with DisableTorchFunction(): with DisableTorchFunction():
output = func(*args, **kwargs) output = func(*args, **kwargs or dict())
# The __torch_function__ protocol will invoke this method on all types involved in the computation by walking # The __torch_function__ protocol will invoke this method on all types involved in the computation by walking
# the MRO upwards. For example, `torch.Tensor(...).to(features.Image(...))` will invoke # the MRO upwards. For example, `torch.Tensor(...).to(features.Image(...))` will invoke
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment