Unverified Commit e1aacdd9 authored by Vasilis Vryniotis's avatar Vasilis Vryniotis Committed by GitHub
Browse files

Update `ToDtype` to avoid unnecessary `to()` calls and fixing types on `Transform` (#6773)

* Fix `ToDtype` to avoid errors when a type is not defined.

* Nit `(features.is_simple_tensor, features._Feature)` to `(Tensor,)`

* Fixing linter

* Adding comment.

* Switch back to indexing. Python's default dict seems to have a nasty behaviour.
parent 8ec7a70f
...@@ -157,7 +157,10 @@ class ToDtype(Transform): ...@@ -157,7 +157,10 @@ class ToDtype(Transform):
self.dtype = dtype self.dtype = dtype
def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any: def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
return inpt.to(self.dtype[type(inpt)]) dtype = self.dtype[type(inpt)]
if dtype is None:
return inpt
return inpt.to(dtype=dtype)
class RemoveSmallBoundingBoxes(Transform): class RemoveSmallBoundingBoxes(Transform):
......
...@@ -5,7 +5,6 @@ import PIL.Image ...@@ -5,7 +5,6 @@ import PIL.Image
import torch import torch
from torch import nn from torch import nn
from torch.utils._pytree import tree_flatten, tree_unflatten from torch.utils._pytree import tree_flatten, tree_unflatten
from torchvision.prototype import features
from torchvision.prototype.transforms._utils import _isinstance from torchvision.prototype.transforms._utils import _isinstance
from torchvision.utils import _log_api_usage_once from torchvision.utils import _log_api_usage_once
...@@ -13,11 +12,8 @@ from torchvision.utils import _log_api_usage_once ...@@ -13,11 +12,8 @@ from torchvision.utils import _log_api_usage_once
class Transform(nn.Module): class Transform(nn.Module):
# Class attribute defining transformed types. Other types are passed-through without any transformation # Class attribute defining transformed types. Other types are passed-through without any transformation
_transformed_types: Tuple[Union[Type, Callable[[Any], bool]], ...] = ( # We support both Types and callables that are able to do further checks on the type of the input.
features.is_simple_tensor, _transformed_types: Tuple[Union[Type, Callable[[Any], bool]], ...] = (torch.Tensor, PIL.Image.Image)
features._Feature,
PIL.Image.Image,
)
def __init__(self) -> None: def __init__(self) -> None:
super().__init__() super().__init__()
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment