"csrc/common.cpp" did not exist on "387082e1bb6fdc81ec7c04700927a03abb12ad42"
Unverified Commit 3eafe77a authored by Philip Meier's avatar Philip Meier Committed by GitHub
Browse files

expand ToDtype to support multiple conversions at once (#6756)

* expand ToDtype to support multiple conversions at once

* simplify
parent 6d774c6f
......@@ -1789,3 +1789,41 @@ class TestRandomResize:
mock_resize.assert_called_with(
inpt_sentinel, size_sentinel, interpolation=interpolation_sentinel, antialias=antialias_sentinel
)
@pytest.mark.parametrize(
("dtype", "expected_dtypes"),
[
(
torch.float64,
{torch.Tensor: torch.float64, features.Image: torch.float64, features.BoundingBox: torch.float64},
),
(
{torch.Tensor: torch.int32, features.Image: torch.float32, features.BoundingBox: torch.float64},
{torch.Tensor: torch.int32, features.Image: torch.float32, features.BoundingBox: torch.float64},
),
],
)
def test_to_dtype(dtype, expected_dtypes):
sample = dict(
plain_tensor=torch.testing.make_tensor(5, dtype=torch.int64, device="cpu"),
image=make_image(dtype=torch.uint8),
bounding_box=make_bounding_box(format=features.BoundingBoxFormat.XYXY, dtype=torch.float32),
str="str",
int=0,
)
transform = transforms.ToDtype(dtype)
transformed_sample = transform(sample)
for key, value in sample.items():
value_type = type(value)
transformed_value = transformed_sample[key]
# make sure the transformation retains the type
assert isinstance(transformed_value, value_type)
if isinstance(value, torch.Tensor):
assert transformed_value.dtype is expected_dtypes[value_type]
else:
assert transformed_value is value
import functools
from collections import defaultdict
from typing import Any, Callable, Dict, Sequence, Type, Union
import PIL.Image
......@@ -144,14 +145,22 @@ class GaussianBlur(Transform):
return F.gaussian_blur(inpt, self.kernel_size, **params)
# TODO: Enhance as described at https://github.com/pytorch/vision/issues/6697
class ToDtype(Lambda):
def __init__(self, dtype: torch.dtype, *types: Type) -> None:
class ToDtype(Transform):
_transformed_types = (torch.Tensor,)
def _default_dtype(self, dtype: torch.dtype) -> torch.dtype:
return dtype
def __init__(self, dtype: Union[torch.dtype, Dict[Type, torch.dtype]]) -> None:
super().__init__()
if not isinstance(dtype, dict):
# This weird looking construct only exists, since `lambda`'s cannot be serialized by pickle.
# If it were possible, we could replace this with `defaultdict(lambda: dtype)`
dtype = defaultdict(functools.partial(self._default_dtype, dtype))
self.dtype = dtype
super().__init__(functools.partial(torch.Tensor.to, dtype=dtype), *types or (torch.Tensor,))
def extra_repr(self) -> str:
return ", ".join([f"dtype={self.dtype}", f"types={[type.__name__ for type in self.types]}"])
def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
return inpt.to(self.dtype[type(inpt)])
class RemoveSmallBoundingBoxes(Transform):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment