You need to sign in or sign up before continuing.
Unverified Commit 3eafe77a authored by Philip Meier's avatar Philip Meier Committed by GitHub
Browse files

expand ToDtype to support multiple conversions at once (#6756)

* expand ToDtype to support multiple conversions at once

* simplify
parent 6d774c6f
......@@ -1789,3 +1789,41 @@ class TestRandomResize:
mock_resize.assert_called_with(
inpt_sentinel, size_sentinel, interpolation=interpolation_sentinel, antialias=antialias_sentinel
)
@pytest.mark.parametrize(
("dtype", "expected_dtypes"),
[
(
torch.float64,
{torch.Tensor: torch.float64, features.Image: torch.float64, features.BoundingBox: torch.float64},
),
(
{torch.Tensor: torch.int32, features.Image: torch.float32, features.BoundingBox: torch.float64},
{torch.Tensor: torch.int32, features.Image: torch.float32, features.BoundingBox: torch.float64},
),
],
)
def test_to_dtype(dtype, expected_dtypes):
sample = dict(
plain_tensor=torch.testing.make_tensor(5, dtype=torch.int64, device="cpu"),
image=make_image(dtype=torch.uint8),
bounding_box=make_bounding_box(format=features.BoundingBoxFormat.XYXY, dtype=torch.float32),
str="str",
int=0,
)
transform = transforms.ToDtype(dtype)
transformed_sample = transform(sample)
for key, value in sample.items():
value_type = type(value)
transformed_value = transformed_sample[key]
# make sure the transformation retains the type
assert isinstance(transformed_value, value_type)
if isinstance(value, torch.Tensor):
assert transformed_value.dtype is expected_dtypes[value_type]
else:
assert transformed_value is value
import functools
from collections import defaultdict
from typing import Any, Callable, Dict, Sequence, Type, Union
import PIL.Image
......@@ -144,14 +145,22 @@ class GaussianBlur(Transform):
return F.gaussian_blur(inpt, self.kernel_size, **params)
# TODO: Enhance as described at https://github.com/pytorch/vision/issues/6697
class ToDtype(Lambda):
def __init__(self, dtype: torch.dtype, *types: Type) -> None:
class ToDtype(Transform):
_transformed_types = (torch.Tensor,)
def _default_dtype(self, dtype: torch.dtype) -> torch.dtype:
return dtype
def __init__(self, dtype: Union[torch.dtype, Dict[Type, torch.dtype]]) -> None:
super().__init__()
if not isinstance(dtype, dict):
# This weird looking construct only exists, since `lambda`'s cannot be serialized by pickle.
# If it were possible, we could replace this with `defaultdict(lambda: dtype)`
dtype = defaultdict(functools.partial(self._default_dtype, dtype))
self.dtype = dtype
super().__init__(functools.partial(torch.Tensor.to, dtype=dtype), *types or (torch.Tensor,))
def extra_repr(self) -> str:
return ", ".join([f"dtype={self.dtype}", f"types={[type.__name__ for type in self.types]}"])
def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
return inpt.to(self.dtype[type(inpt)])
class RemoveSmallBoundingBoxes(Transform):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment