Unverified Commit 3eafe77a authored by Philip Meier's avatar Philip Meier Committed by GitHub
Browse files

expand ToDtype to support multiple conversions at once (#6756)

* expand ToDtype to support multiple conversions at once

* simplify
parent 6d774c6f
...@@ -1789,3 +1789,41 @@ class TestRandomResize: ...@@ -1789,3 +1789,41 @@ class TestRandomResize:
mock_resize.assert_called_with( mock_resize.assert_called_with(
inpt_sentinel, size_sentinel, interpolation=interpolation_sentinel, antialias=antialias_sentinel inpt_sentinel, size_sentinel, interpolation=interpolation_sentinel, antialias=antialias_sentinel
) )
@pytest.mark.parametrize(
("dtype", "expected_dtypes"),
[
(
torch.float64,
{torch.Tensor: torch.float64, features.Image: torch.float64, features.BoundingBox: torch.float64},
),
(
{torch.Tensor: torch.int32, features.Image: torch.float32, features.BoundingBox: torch.float64},
{torch.Tensor: torch.int32, features.Image: torch.float32, features.BoundingBox: torch.float64},
),
],
)
def test_to_dtype(dtype, expected_dtypes):
sample = dict(
plain_tensor=torch.testing.make_tensor(5, dtype=torch.int64, device="cpu"),
image=make_image(dtype=torch.uint8),
bounding_box=make_bounding_box(format=features.BoundingBoxFormat.XYXY, dtype=torch.float32),
str="str",
int=0,
)
transform = transforms.ToDtype(dtype)
transformed_sample = transform(sample)
for key, value in sample.items():
value_type = type(value)
transformed_value = transformed_sample[key]
# make sure the transformation retains the type
assert isinstance(transformed_value, value_type)
if isinstance(value, torch.Tensor):
assert transformed_value.dtype is expected_dtypes[value_type]
else:
assert transformed_value is value
import functools import functools
from collections import defaultdict
from typing import Any, Callable, Dict, Sequence, Type, Union from typing import Any, Callable, Dict, Sequence, Type, Union
import PIL.Image import PIL.Image
...@@ -144,14 +145,22 @@ class GaussianBlur(Transform): ...@@ -144,14 +145,22 @@ class GaussianBlur(Transform):
return F.gaussian_blur(inpt, self.kernel_size, **params) return F.gaussian_blur(inpt, self.kernel_size, **params)
# TODO: Enhance as described at https://github.com/pytorch/vision/issues/6697 class ToDtype(Transform):
class ToDtype(Lambda): _transformed_types = (torch.Tensor,)
def __init__(self, dtype: torch.dtype, *types: Type) -> None:
def _default_dtype(self, dtype: torch.dtype) -> torch.dtype:
return dtype
def __init__(self, dtype: Union[torch.dtype, Dict[Type, torch.dtype]]) -> None:
super().__init__()
if not isinstance(dtype, dict):
# This weird looking construct only exists, since `lambda`'s cannot be serialized by pickle.
# If it were possible, we could replace this with `defaultdict(lambda: dtype)`
dtype = defaultdict(functools.partial(self._default_dtype, dtype))
self.dtype = dtype self.dtype = dtype
super().__init__(functools.partial(torch.Tensor.to, dtype=dtype), *types or (torch.Tensor,))
def extra_repr(self) -> str: def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
return ", ".join([f"dtype={self.dtype}", f"types={[type.__name__ for type in self.types]}"]) return inpt.to(self.dtype[type(inpt)])
class RemoveSmallBoundingBoxes(Transform): class RemoveSmallBoundingBoxes(Transform):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment