"docs/vscode:/vscode.git/clone" did not exist on "6e1af3a777ba0f27a6071861ec916e7dc35efe9d"
Unverified Commit e1aacdd9 authored by Vasilis Vryniotis's avatar Vasilis Vryniotis Committed by GitHub
Browse files

Update `ToDtype` to avoid unnecessary `to()` calls and fixing types on `Transform` (#6773)

* Fix `ToDtype` to avoid errors when a type is not defined.

* Nit `(features.is_simple_tensor, features._Feature)` to `(Tensor,)`

* Fixing linter

* Adding comment.

* Switch back to indexing. Python's default dict seems to have a nasty behaviour.
parent 8ec7a70f
......@@ -157,7 +157,10 @@ class ToDtype(Transform):
self.dtype = dtype
def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
return inpt.to(self.dtype[type(inpt)])
dtype = self.dtype[type(inpt)]
if dtype is None:
return inpt
return inpt.to(dtype=dtype)
class RemoveSmallBoundingBoxes(Transform):
......
......@@ -5,7 +5,6 @@ import PIL.Image
import torch
from torch import nn
from torch.utils._pytree import tree_flatten, tree_unflatten
from torchvision.prototype import features
from torchvision.prototype.transforms._utils import _isinstance
from torchvision.utils import _log_api_usage_once
......@@ -13,11 +12,8 @@ from torchvision.utils import _log_api_usage_once
class Transform(nn.Module):
# Class attribute defining transformed types. Other types are passed-through without any transformation
_transformed_types: Tuple[Union[Type, Callable[[Any], bool]], ...] = (
features.is_simple_tensor,
features._Feature,
PIL.Image.Image,
)
# We support both Types and callables that are able to do further checks on the type of the input.
_transformed_types: Tuple[Union[Type, Callable[[Any], bool]], ...] = (torch.Tensor, PIL.Image.Image)
def __init__(self) -> None:
super().__init__()
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment