Unverified commit 1402eb8e authored by Nicolas Hug, committed by GitHub

Add scale option to ToDtype. Remove ConvertDtype. (#7759)


Co-authored-by: vfdev <vfdev.5@gmail.com>
Co-authored-by: Philip Meier <github.pmeier@posteo.de>
parent cc0f9d02
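For readers skimming the diff below: the commit removes the v2 `ConvertDtype` transform (and the `convert_dtype*` kernels) and folds value scaling into `ToDtype` behind an explicit `scale` flag, while `ConvertImageDtype` survives as a thin image-only class. A minimal before/after sketch, assuming torchvision at this commit:

    import torch
    from torchvision.transforms import v2

    img = torch.randint(0, 256, (3, 4, 4), dtype=torch.uint8)

    # Before: v2.ConvertDtype(torch.float32)(img) always rescaled to the new dtype's range.
    # After: ToDtype converts the dtype and only rescales when asked to.
    scaled = v2.ToDtype(torch.float32, scale=True)(img)   # float32 values in [0.0, 1.0]
    raw = v2.ToDtype(torch.float32, scale=False)(img)     # float32 values still in [0, 255]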
@@ -234,7 +234,6 @@ Conversion
v2.PILToTensor
v2.ToImageTensor
ConvertImageDtype
v2.ConvertDtype
v2.ConvertImageDtype
v2.ToDtype
v2.ConvertBoundingBoxFormat
@@ -29,7 +29,7 @@ def show(sample):
image, target = sample
if isinstance(image, PIL.Image.Image):
image = F.to_image_tensor(image)
image = F.convert_dtype(image, torch.uint8)
image = F.to_dtype(image, torch.uint8, scale=True)
annotated_image = draw_bounding_boxes(image, target["boxes"], colors="yellow", width=3)
fig, ax = plt.subplots()
@@ -27,7 +27,7 @@ from PIL import Image
from torch.testing._comparison import BooleanPair, NonePair, not_close_error_metas, NumberPair, TensorLikePair
from torchvision import datapoints, io
from torchvision.transforms._functional_tensor import _max_value as get_max_value
from torchvision.transforms.v2.functional import convert_dtype_image_tensor, to_image_pil, to_image_tensor
from torchvision.transforms.v2.functional import to_dtype_image_tensor, to_image_pil, to_image_tensor
IN_OSS_CI = any(os.getenv(var) == "true" for var in ["CIRCLECI", "GITHUB_ACTIONS"])
@@ -601,7 +601,7 @@ def make_image_loader_for_interpolation(
image_tensor = image_tensor.to(device=device, memory_format=memory_format, copy=True)
else:
image_tensor = image_tensor.to(device=device)
image_tensor = convert_dtype_image_tensor(image_tensor, dtype=dtype)
image_tensor = to_dtype_image_tensor(image_tensor, dtype=dtype, scale=True)
return datapoints.Image(image_tensor)
import itertools
import pathlib
import random
import re
import textwrap
import warnings
from collections import defaultdict
@@ -105,7 +104,7 @@ def normalize_adapter(transform, input, device):
continue
elif check_type(value, (datapoints.Image, datapoints.Video, is_simple_tensor)):
# normalize doesn't support integer images
value = F.convert_dtype(value, torch.float32)
value = F.to_dtype(value, torch.float32, scale=True)
adapted_input[key] = value
return adapted_input
@@ -146,7 +145,7 @@ class TestSmoke:
(transforms.ScaleJitter((16, 16), scale_range=(0.8, 1.2), antialias=True), None),
(transforms.ClampBoundingBox(), None),
(transforms.ConvertBoundingBoxFormat(datapoints.BoundingBoxFormat.CXCYWH), None),
(transforms.ConvertDtype(), None),
(transforms.ConvertImageDtype(), None),
(transforms.GaussianBlur(kernel_size=3), None),
(
transforms.LinearTransformation(
@@ -1326,61 +1325,6 @@ class TestRandomResize:
)
class TestToDtype:
@pytest.mark.parametrize(
("dtype", "expected_dtypes"),
[
(
torch.float64,
{
datapoints.Video: torch.float64,
datapoints.Image: torch.float64,
datapoints.BoundingBox: torch.float64,
},
),
(
{datapoints.Video: torch.int32, datapoints.Image: torch.float32, datapoints.BoundingBox: torch.float64},
{datapoints.Video: torch.int32, datapoints.Image: torch.float32, datapoints.BoundingBox: torch.float64},
),
],
)
def test_call(self, dtype, expected_dtypes):
sample = dict(
video=make_video(dtype=torch.int64),
image=make_image(dtype=torch.uint8),
bounding_box=make_bounding_box(format=datapoints.BoundingBoxFormat.XYXY, dtype=torch.float32),
str="str",
int=0,
)
transform = transforms.ToDtype(dtype)
transformed_sample = transform(sample)
for key, value in sample.items():
value_type = type(value)
transformed_value = transformed_sample[key]
# make sure the transformation retains the type
assert isinstance(transformed_value, value_type)
if isinstance(value, torch.Tensor):
assert transformed_value.dtype is expected_dtypes[value_type]
else:
assert transformed_value is value
@pytest.mark.filterwarnings("error")
def test_plain_tensor_call(self):
tensor = torch.empty((), dtype=torch.float32)
transform = transforms.ToDtype({torch.Tensor: torch.float64})
assert transform(tensor).dtype is torch.float64
@pytest.mark.parametrize("other_type", [datapoints.Image, datapoints.Video])
def test_plain_tensor_warning(self, other_type):
with pytest.warns(UserWarning, match=re.escape("`torch.Tensor` will *not* be transformed")):
transforms.ToDtype(dtype={torch.Tensor: torch.float32, other_type: torch.float64})
class TestUniformTemporalSubsample:
@pytest.mark.parametrize(
"inpt",
@@ -191,7 +191,7 @@ CONSISTENCY_CONFIGS = [
closeness_kwargs=dict(rtol=None, atol=None),
),
ConsistencyConfig(
v2_transforms.ConvertDtype,
v2_transforms.ConvertImageDtype,
legacy_transforms.ConvertImageDtype,
[
ArgsKwargs(torch.float16),
@@ -283,12 +283,12 @@ class TestKernels:
adapted_other_args, adapted_kwargs = info.float32_vs_uint8(other_args, kwargs)
actual = info.kernel(
F.convert_dtype_image_tensor(input, dtype=torch.float32),
F.to_dtype_image_tensor(input, dtype=torch.float32, scale=True),
*adapted_other_args,
**adapted_kwargs,
)
expected = F.convert_dtype_image_tensor(info.kernel(input, *other_args, **kwargs), dtype=torch.float32)
expected = F.to_dtype_image_tensor(info.kernel(input, *other_args, **kwargs), dtype=torch.float32, scale=True)
assert_close(
actual,
@@ -538,7 +538,6 @@ class TestDispatchers:
(F.get_image_num_channels, F.get_num_channels),
(F.to_pil_image, F.to_image_pil),
(F.elastic_transform, F.elastic),
(F.convert_image_dtype, F.convert_dtype_image_tensor),
(F.to_grayscale, F.rgb_to_grayscale),
]
],
@@ -547,24 +546,6 @@ def test_alias(alias, target):
assert alias is target
@pytest.mark.parametrize(
("info", "args_kwargs"),
make_info_args_kwargs_params(
KERNEL_INFOS_MAP[F.convert_dtype_image_tensor],
args_kwargs_fn=lambda info: info.sample_inputs_fn(),
),
)
@pytest.mark.parametrize("device", cpu_and_cuda())
def test_convert_dtype_image_tensor_dtype_and_device(info, args_kwargs, device):
(input, *other_args), kwargs = args_kwargs.load(device)
dtype = other_args[0] if other_args else kwargs.get("dtype", torch.float32)
output = info.kernel(input, dtype)
assert output.dtype == dtype
assert output.device == input.device
@pytest.mark.parametrize("device", cpu_and_cuda())
@pytest.mark.parametrize("num_channels", [1, 3])
def test_normalize_image_tensor_stats(device, num_channels):
import contextlib
import decimal
import inspect
import math
import re
@@ -29,6 +30,7 @@ from common_utils import (
from torch import nn
from torch.testing import assert_close
from torch.utils._pytree import tree_map
from torchvision import datapoints
from torchvision.transforms._functional_tensor import _max_value as get_max_value
@@ -66,11 +68,12 @@ def _check_kernel_cuda_vs_cpu(kernel, input, *args, rtol, atol, **kwargs):
@cache
def _script(fn):
def _script(obj):
try:
return torch.jit.script(fn)
return torch.jit.script(obj)
except Exception as error:
raise AssertionError(f"Trying to `torch.jit.script` '{fn.__name__}' raised the error above.") from error
name = getattr(obj, "__name__", obj.__class__.__name__)
raise AssertionError(f"Trying to `torch.jit.script` '{name}' raised the error above.") from error
def _check_kernel_scripted_vs_eager(kernel, input, *args, rtol, atol, **kwargs):
@@ -127,6 +130,7 @@ def check_kernel(
check_cuda_vs_cpu=True,
check_scripted_vs_eager=True,
check_batched_vs_unbatched=True,
expect_same_dtype=True,
**kwargs,
):
initial_input_version = input._version
@@ -139,7 +143,8 @@
# check that no inplace operation happened
assert input._version == initial_input_version
assert output.dtype == input.dtype
if expect_same_dtype:
assert output.dtype == input.dtype
assert output.device == input.device
if check_cuda_vs_cpu:
@@ -276,7 +281,7 @@ def check_dispatcher_signatures_match(dispatcher, *, kernel, input_type):
def _check_transform_v1_compatibility(transform, input):
"""If the transform defines the ``_v1_transform_cls`` attribute, checks if the transform has a public, static
``get_params`` method, is scriptable, and the scripted version can be called without error."""
if not hasattr(transform, "_v1_transform_cls"):
if transform._v1_transform_cls is None:
return
if type(input) is not torch.Tensor:
@@ -1697,3 +1702,193 @@ class TestCompose:
assert isinstance(output, tuple) and len(output) == 2
assert output[0] is image
assert output[1] is label
class TestToDtype:
@pytest.mark.parametrize(
("kernel", "make_input"),
[
(F.to_dtype_image_tensor, make_image_tensor),
(F.to_dtype_image_tensor, make_image),
(F.to_dtype_video, make_video),
],
)
@pytest.mark.parametrize("input_dtype", [torch.float32, torch.float64, torch.uint8])
@pytest.mark.parametrize("output_dtype", [torch.float32, torch.float64, torch.uint8])
@pytest.mark.parametrize("device", cpu_and_cuda())
@pytest.mark.parametrize("scale", (True, False))
def test_kernel(self, kernel, make_input, input_dtype, output_dtype, device, scale):
check_kernel(
kernel,
make_input(dtype=input_dtype, device=device),
expect_same_dtype=input_dtype is output_dtype,
dtype=output_dtype,
scale=scale,
)
@pytest.mark.parametrize(
("kernel", "make_input"),
[
(F.to_dtype_image_tensor, make_image_tensor),
(F.to_dtype_image_tensor, make_image),
(F.to_dtype_video, make_video),
],
)
@pytest.mark.parametrize("input_dtype", [torch.float32, torch.float64, torch.uint8])
@pytest.mark.parametrize("output_dtype", [torch.float32, torch.float64, torch.uint8])
@pytest.mark.parametrize("device", cpu_and_cuda())
@pytest.mark.parametrize("scale", (True, False))
def test_dispatcher(self, kernel, make_input, input_dtype, output_dtype, device, scale):
check_dispatcher(
F.to_dtype,
kernel,
make_input(dtype=input_dtype, device=device),
# TODO: we could leave check_dispatch to True but it currently fails
# in _check_dispatcher_dispatch because there is no to_dtype() method on the datapoints.
# We should be able to put this back if we change the dispatch
# mechanism e.g. via https://github.com/pytorch/vision/pull/7733
check_dispatch=False,
dtype=output_dtype,
scale=scale,
)
@pytest.mark.parametrize(
"make_input",
[make_image_tensor, make_image, make_bounding_box, make_segmentation_mask, make_video],
)
@pytest.mark.parametrize("input_dtype", [torch.float32, torch.float64, torch.uint8])
@pytest.mark.parametrize("output_dtype", [torch.float32, torch.float64, torch.uint8])
@pytest.mark.parametrize("device", cpu_and_cuda())
@pytest.mark.parametrize("scale", (True, False))
@pytest.mark.parametrize("as_dict", (True, False))
def test_transform(self, make_input, input_dtype, output_dtype, device, scale, as_dict):
input = make_input(dtype=input_dtype, device=device)
if as_dict:
output_dtype = {type(input): output_dtype}
check_transform(transforms.ToDtype, input, dtype=output_dtype, scale=scale)
def reference_convert_dtype_image_tensor(self, image, dtype=torch.float, scale=False):
input_dtype = image.dtype
output_dtype = dtype
if not scale:
return image.to(dtype)
if output_dtype == input_dtype:
return image
def fn(value):
if input_dtype.is_floating_point:
if output_dtype.is_floating_point:
return value
else:
return round(decimal.Decimal(value) * torch.iinfo(output_dtype).max)
else:
input_max_value = torch.iinfo(input_dtype).max
if output_dtype.is_floating_point:
return float(decimal.Decimal(value) / input_max_value)
else:
output_max_value = torch.iinfo(output_dtype).max
if input_max_value > output_max_value:
factor = (input_max_value + 1) // (output_max_value + 1)
return value / factor
else:
factor = (output_max_value + 1) // (input_max_value + 1)
return value * factor
return torch.tensor(tree_map(fn, image.tolist()), dtype=dtype, device=image.device)
@pytest.mark.parametrize("input_dtype", [torch.float32, torch.float64, torch.uint8])
@pytest.mark.parametrize("output_dtype", [torch.float32, torch.float64, torch.uint8])
@pytest.mark.parametrize("device", cpu_and_cuda())
@pytest.mark.parametrize("scale", (True, False))
def test_image_correctness(self, input_dtype, output_dtype, device, scale):
if input_dtype.is_floating_point and output_dtype == torch.int64:
pytest.xfail("float to int64 conversion is not supported")
input = make_image(dtype=input_dtype, device=device)
out = F.to_dtype(input, dtype=output_dtype, scale=scale)
expected = self.reference_convert_dtype_image_tensor(input, dtype=output_dtype, scale=scale)
if input_dtype.is_floating_point and not output_dtype.is_floating_point and scale:
torch.testing.assert_close(out, expected, atol=1, rtol=0)
else:
torch.testing.assert_close(out, expected)
def was_scaled(self, inpt):
# this assumes the target dtype is float
return inpt.max() <= 1
def make_inpt_with_bbox_and_mask(self, make_input):
H, W = 10, 10
inpt_dtype = torch.uint8
bbox_dtype = torch.float32
mask_dtype = torch.bool
sample = {
"inpt": make_input(size=(H, W), dtype=inpt_dtype),
"bbox": make_bounding_box(size=(H, W), dtype=bbox_dtype),
"mask": make_detection_mask(size=(H, W), dtype=mask_dtype),
}
return sample, inpt_dtype, bbox_dtype, mask_dtype
@pytest.mark.parametrize("make_input", (make_image_tensor, make_image, make_video))
@pytest.mark.parametrize("scale", (True, False))
def test_dtype_not_a_dict(self, make_input, scale):
# assert only inpt gets transformed when dtype isn't a dict
sample, inpt_dtype, bbox_dtype, mask_dtype = self.make_inpt_with_bbox_and_mask(make_input)
out = transforms.ToDtype(dtype=torch.float32, scale=scale)(sample)
assert out["inpt"].dtype != inpt_dtype
assert out["inpt"].dtype == torch.float32
if scale:
assert self.was_scaled(out["inpt"])
else:
assert not self.was_scaled(out["inpt"])
assert out["bbox"].dtype == bbox_dtype
assert out["mask"].dtype == mask_dtype
@pytest.mark.parametrize("make_input", (make_image_tensor, make_image, make_video))
def test_others_catch_all_and_none(self, make_input):
# make sure "others" works as a catch-all and that None means no conversion
sample, inpt_dtype, bbox_dtype, mask_dtype = self.make_inpt_with_bbox_and_mask(make_input)
out = transforms.ToDtype(dtype={datapoints.Mask: torch.int64, "others": None})(sample)
assert out["inpt"].dtype == inpt_dtype
assert out["bbox"].dtype == bbox_dtype
assert out["mask"].dtype != mask_dtype
assert out["mask"].dtype == torch.int64
@pytest.mark.parametrize("make_input", (make_image_tensor, make_image, make_video))
def test_typical_use_case(self, make_input):
# Typical use-case: want to convert dtype and scale for inpt and just dtype for masks.
# This just makes sure we now have a decent API for this
sample, inpt_dtype, bbox_dtype, mask_dtype = self.make_inpt_with_bbox_and_mask(make_input)
out = transforms.ToDtype(
dtype={type(sample["inpt"]): torch.float32, datapoints.Mask: torch.int64, "others": None}, scale=True
)(sample)
assert out["inpt"].dtype != inpt_dtype
assert out["inpt"].dtype == torch.float32
assert self.was_scaled(out["inpt"])
assert out["bbox"].dtype == bbox_dtype
assert out["mask"].dtype != mask_dtype
assert out["mask"].dtype == torch.int64
@pytest.mark.parametrize("make_input", (make_image_tensor, make_image, make_video))
def test_errors_warnings(self, make_input):
sample, inpt_dtype, bbox_dtype, mask_dtype = self.make_inpt_with_bbox_and_mask(make_input)
with pytest.raises(ValueError, match="No dtype was specified for"):
out = transforms.ToDtype(dtype={datapoints.Mask: torch.float32})(sample)
with pytest.warns(UserWarning, match=re.escape("plain `torch.Tensor` will *not* be transformed")):
transforms.ToDtype(dtype={torch.Tensor: torch.float32, datapoints.Image: torch.float32})
with pytest.warns(UserWarning, match="no scaling will be done"):
out = transforms.ToDtype(dtype={"others": None}, scale=True)(sample)
assert out["inpt"].dtype == inpt_dtype
assert out["bbox"].dtype == bbox_dtype
assert out["mask"].dtype == mask_dtype
@@ -364,16 +364,6 @@ DISPATCHER_INFOS = [
xfail_jit_python_scalar_arg("std"),
],
),
DispatcherInfo(
F.convert_dtype,
kernels={
datapoints.Image: F.convert_dtype_image_tensor,
datapoints.Video: F.convert_dtype_video,
},
test_marks=[
skip_dispatch_datapoint,
],
),
DispatcherInfo(
F.uniform_temporal_subsample,
kernels={
import decimal
import functools
import itertools
@@ -27,7 +26,6 @@ from common_utils import (
mark_framework_limitation,
TestMark,
)
from torch.utils._pytree import tree_map
from torchvision import datapoints
from torchvision.transforms._functional_tensor import _max_value as get_max_value, _parse_pad_padding
@@ -1566,7 +1564,7 @@ def multi_crop_pil_reference_wrapper(pil_kernel):
def wrapper(input_tensor, *other_args, **kwargs):
output = pil_reference_wrapper(pil_kernel)(input_tensor, *other_args, **kwargs)
return type(output)(
F.convert_dtype_image_tensor(F.to_image_tensor(output_pil), dtype=input_tensor.dtype)
F.to_dtype_image_tensor(F.to_image_tensor(output_pil), dtype=input_tensor.dtype, scale=True)
for output_pil in output
)
@@ -1667,125 +1665,6 @@ KERNEL_INFOS.extend(
)
def sample_inputs_convert_dtype_image_tensor():
for input_dtype, output_dtype in itertools.product(
[torch.uint8, torch.int64, torch.float32, torch.float64], repeat=2
):
if input_dtype.is_floating_point and output_dtype == torch.int64:
# conversion cannot be performed safely
continue
for image_loader in make_image_loaders(
sizes=[DEFAULT_PORTRAIT_SPATIAL_SIZE], color_spaces=["RGB"], dtypes=[input_dtype]
):
yield ArgsKwargs(image_loader, dtype=output_dtype)
def reference_convert_dtype_image_tensor(image, dtype=torch.float):
input_dtype = image.dtype
output_dtype = dtype
if output_dtype == input_dtype:
return image
def fn(value):
if input_dtype.is_floating_point:
if output_dtype.is_floating_point:
return value
else:
return int(decimal.Decimal(value) * torch.iinfo(output_dtype).max)
else:
input_max_value = torch.iinfo(input_dtype).max
if output_dtype.is_floating_point:
return float(decimal.Decimal(value) / input_max_value)
else:
output_max_value = torch.iinfo(output_dtype).max
if input_max_value > output_max_value:
factor = (input_max_value + 1) // (output_max_value + 1)
return value // factor
else:
factor = (output_max_value + 1) // (input_max_value + 1)
return value * factor
return torch.tensor(tree_map(fn, image.tolist()), dtype=dtype)
def reference_inputs_convert_dtype_image_tensor():
for input_dtype, output_dtype in itertools.product(
[
torch.uint8,
torch.int16,
torch.int32,
torch.int64,
torch.float16,
torch.float32,
torch.float64,
torch.bfloat16,
],
repeat=2,
):
if (input_dtype == torch.float32 and output_dtype in {torch.int32, torch.int64}) or (
input_dtype == torch.float64 and output_dtype == torch.int64
):
continue
if input_dtype.is_floating_point:
data = [0.0, 0.5, 1.0]
else:
max_value = torch.iinfo(input_dtype).max
data = [0, max_value // 2, max_value]
image = torch.tensor(data, dtype=input_dtype)
yield ArgsKwargs(image, dtype=output_dtype)
def sample_inputs_convert_dtype_video():
for video_loader in make_video_loaders(sizes=[DEFAULT_PORTRAIT_SPATIAL_SIZE], num_frames=[3]):
yield ArgsKwargs(video_loader)
skip_dtype_consistency = TestMark(
("TestKernels", "test_dtype_and_device_consistency"),
pytest.mark.skip(reason="`convert_dtype_*` kernels convert the dtype by design"),
condition=lambda args_kwargs: args_kwargs.args[0].dtype != args_kwargs.kwargs.get("dtype", torch.float32),
)
KERNEL_INFOS.extend(
[
KernelInfo(
F.convert_dtype_image_tensor,
sample_inputs_fn=sample_inputs_convert_dtype_image_tensor,
reference_fn=reference_convert_dtype_image_tensor,
reference_inputs_fn=reference_inputs_convert_dtype_image_tensor,
test_marks=[
skip_dtype_consistency,
TestMark(
("TestKernels", "test_against_reference"),
pytest.mark.xfail(reason="Conversion overflows"),
condition=lambda args_kwargs: (
args_kwargs.args[0].dtype in {torch.float16, torch.bfloat16}
and not args_kwargs.kwargs["dtype"].is_floating_point
)
or (
args_kwargs.args[0].dtype in {torch.int32, torch.int64}
and args_kwargs.kwargs["dtype"] == torch.float16
),
),
],
),
KernelInfo(
F.convert_dtype_video,
sample_inputs_fn=sample_inputs_convert_dtype_video,
test_marks=[
skip_dtype_consistency,
],
),
]
)
def sample_inputs_uniform_temporal_subsample_video():
for video_loader in make_video_loaders(sizes=[DEFAULT_PORTRAIT_SPATIAL_SIZE], num_frames=[4]):
yield ArgsKwargs(video_loader, num_samples=2)
@@ -39,7 +39,7 @@ from ._geometry import (
ScaleJitter,
TenCrop,
)
from ._meta import ClampBoundingBox, ConvertBoundingBoxFormat, ConvertDtype, ConvertImageDtype
from ._meta import ClampBoundingBox, ConvertBoundingBoxFormat, ConvertImageDtype
from ._misc import GaussianBlur, Identity, Lambda, LinearTransformation, Normalize, SanitizeBoundingBox, ToDtype
from ._temporal import UniformTemporalSubsample
from ._type_conversion import PILToTensor, ToImagePIL, ToImageTensor, ToPILImage
@@ -31,10 +31,13 @@ class ConvertBoundingBoxFormat(Transform):
return F.convert_format_bounding_box(inpt, new_format=self.format) # type: ignore[return-value]
class ConvertDtype(Transform):
"""[BETA] Convert input image or video to the given ``dtype`` and scale the values accordingly.
class ConvertImageDtype(Transform):
"""[BETA] Convert input image to the given ``dtype`` and scale the values accordingly.
.. v2betastatus:: ConvertDtype transform
.. v2betastatus:: ConvertImageDtype transform
.. warning::
Consider using ``ToDtype(dtype, scale=True)`` instead. See :class:`~torchvision.transforms.v2.ToDtype`.
This function does not support PIL Image.
@@ -55,21 +58,14 @@ class ConvertDtype(Transform):
_v1_transform_cls = _transforms.ConvertImageDtype
_transformed_types = (is_simple_tensor, datapoints.Image, datapoints.Video)
_transformed_types = (is_simple_tensor, datapoints.Image)
def __init__(self, dtype: torch.dtype = torch.float32) -> None:
super().__init__()
self.dtype = dtype
def _transform(
self, inpt: Union[datapoints._TensorImageType, datapoints._TensorVideoType], params: Dict[str, Any]
) -> Union[datapoints._TensorImageType, datapoints._TensorVideoType]:
return F.convert_dtype(inpt, self.dtype)
# We changed the name to align it with the new naming scheme. Still, `ConvertImageDtype` is
# prevalent and well understood. Thus, we just alias it without deprecating the old name.
ConvertImageDtype = ConvertDtype
def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
return F.to_dtype(inpt, dtype=self.dtype, scale=True)
class ClampBoundingBox(Transform):
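The warning added above points users at the replacement; after this commit the two spellings below take the same code path and produce identical results (a sketch, assuming the v2 namespace):

    import torch
    from torchvision.transforms import v2

    img = torch.randint(0, 256, (3, 8, 8), dtype=torch.uint8)

    legacy = v2.ConvertImageDtype(torch.float32)(img)  # kept for BC, always scales
    new = v2.ToDtype(torch.float32, scale=True)(img)   # recommended spelling
    assert torch.equal(legacy, new)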
@@ -11,7 +11,7 @@ from torch.utils._pytree import tree_flatten, tree_unflatten
from torchvision import datapoints, transforms as _transforms
from torchvision.transforms.v2 import functional as F, Transform
from ._utils import _get_defaultdict, _setup_float_or_seq, _setup_size
from ._utils import _setup_float_or_seq, _setup_size
from .utils import has_any, is_simple_tensor, query_bounding_box
@@ -225,36 +225,76 @@ class GaussianBlur(Transform):
class ToDtype(Transform):
"""[BETA] Converts the input to a specific dtype - this does not scale values.
"""[BETA] Converts the input to a specific dtype, optionally scaling the values for images or videos.
.. v2betastatus:: ToDtype transform
.. note::
``ToDtype(dtype, scale=True)`` is the recommended replacement for ``ConvertImageDtype(dtype)``.
Args:
dtype (``torch.dtype`` or dict of ``Datapoint`` -> ``torch.dtype``): The dtype to convert to.
If a ``torch.dtype`` is passed, e.g. ``torch.float32``, only images and videos will be converted
to that dtype: this is for compatibility with :class:`~torchvision.transforms.v2.ConvertImageDtype`.
A dict can be passed to specify per-datapoint conversions, e.g.
``dtype={datapoints.Image: torch.float32, datapoints.Video:
torch.float64}``.
``dtype={datapoints.Image: torch.float32, datapoints.Mask: torch.int64, "others":None}``. The "others"
key can be used as a catch-all for any other datapoint type, and ``None`` means no conversion.
scale (bool, optional): Whether to scale the values for images or videos. Default: ``False``.
"""
_transformed_types = (torch.Tensor,)
def __init__(self, dtype: Union[torch.dtype, Dict[Type, Optional[torch.dtype]]]) -> None:
def __init__(
self, dtype: Union[torch.dtype, Dict[Union[Type, str], Optional[torch.dtype]]], scale: bool = False
) -> None:
super().__init__()
if not isinstance(dtype, dict):
dtype = _get_defaultdict(dtype)
if torch.Tensor in dtype and any(cls in dtype for cls in [datapoints.Image, datapoints.Video]):
if not isinstance(dtype, (dict, torch.dtype)):
raise ValueError(f"dtype must be a dict or a torch.dtype, got {type(dtype)} instead")
if (
isinstance(dtype, dict)
and torch.Tensor in dtype
and any(cls in dtype for cls in [datapoints.Image, datapoints.Video])
):
warnings.warn(
"Got `dtype` values for `torch.Tensor` and either `datapoints.Image` or `datapoints.Video`. "
"Note that a plain `torch.Tensor` will *not* be transformed by this (or any other transformation) "
"in case a `datapoints.Image` or `datapoints.Video` is present in the input."
)
self.dtype = dtype
self.scale = scale
def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
dtype = self.dtype[type(inpt)]
if isinstance(self.dtype, torch.dtype):
# For consistency / BC with ConvertImageDtype, we only care about images or videos when dtype
# is a simple torch.dtype
if not is_simple_tensor(inpt) and not isinstance(inpt, (datapoints.Image, datapoints.Video)):
return inpt
dtype: Optional[torch.dtype] = self.dtype
elif type(inpt) in self.dtype:
dtype = self.dtype[type(inpt)]
elif "others" in self.dtype:
dtype = self.dtype["others"]
else:
raise ValueError(
f"No dtype was specified for type {type(inpt)}. "
"If you only need to convert the dtype of images or videos, you can just pass e.g. dtype=torch.float32. "
"If you're passing a dict as dtype, "
'you can use "others" as a catch-all key '
'e.g. dtype={datapoints.Mask: torch.int64, "others": None} to pass-through the rest of the inputs.'
)
supports_scaling = is_simple_tensor(inpt) or isinstance(inpt, (datapoints.Image, datapoints.Video))
if dtype is None:
if self.scale and supports_scaling:
warnings.warn(
"scale was set to True but no dtype was specified for images or videos: no scaling will be done."
)
return inpt
return inpt.to(dtype=dtype)
return F.to_dtype(inpt, dtype=dtype, scale=self.scale)
class SanitizeBoundingBox(Transform):
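To make the new dict form of `dtype` above concrete, here is the typical use case exercised by the tests, spelled out as a sketch (the sample keys are illustrative): scale images to float32 in [0, 1], cast masks without scaling, and pass everything else through unchanged.

    import torch
    from torchvision import datapoints
    from torchvision.transforms import v2

    sample = {
        "image": datapoints.Image(torch.randint(0, 256, (3, 10, 10), dtype=torch.uint8)),
        "mask": datapoints.Mask(torch.zeros(10, 10, dtype=torch.bool)),
    }
    transform = v2.ToDtype(
        dtype={datapoints.Image: torch.float32, datapoints.Mask: torch.int64, "others": None},
        scale=True,  # only applies to images and videos; masks are cast, not scaled
    )
    out = transform(sample)
    assert out["image"].dtype == torch.float32 and out["mask"].dtype == torch.int64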
@@ -5,10 +5,10 @@ from ._utils import is_simple_tensor # usort: skip
from ._meta import (
clamp_bounding_box,
convert_format_bounding_box,
convert_dtype_image_tensor,
convert_dtype,
convert_dtype_video,
convert_image_dtype,
to_dtype,
to_dtype_image_tensor,
to_dtype_video,
get_dimensions_image_tensor,
get_dimensions_image_pil,
get_dimensions,
@@ -9,7 +9,7 @@ from torchvision.transforms._functional_tensor import _max_value
from torchvision.utils import _log_api_usage_once
from ._meta import _num_value_bits, convert_dtype_image_tensor
from ._meta import _num_value_bits, to_dtype_image_tensor
from ._utils import is_simple_tensor
@@ -351,7 +351,7 @@ def adjust_hue_image_tensor(image: torch.Tensor, hue_factor: float) -> torch.Tensor:
return image
orig_dtype = image.dtype
image = convert_dtype_image_tensor(image, torch.float32)
image = to_dtype_image_tensor(image, torch.float32, scale=True)
image = _rgb_to_hsv(image)
h, s, v = image.unbind(dim=-3)
@@ -359,7 +359,7 @@ def adjust_hue_image_tensor(image: torch.Tensor, hue_factor: float) -> torch.Tensor:
image = torch.stack((h, s, v), dim=-3)
image_hue_adj = _hsv_to_rgb(image)
return convert_dtype_image_tensor(image_hue_adj, orig_dtype)
return to_dtype_image_tensor(image_hue_adj, orig_dtype, scale=True)
adjust_hue_image_pil = _FP.adjust_hue
@@ -393,7 +393,7 @@ def adjust_gamma_image_tensor(image: torch.Tensor, gamma: float, gain: float = 1
# The input image is either assumed to be at [0, 1] scale (if float) or is converted to that scale (if integer).
# Since the gamma is non-negative, the output remains at [0, 1] scale.
if not torch.is_floating_point(image):
output = convert_dtype_image_tensor(image, torch.float32).pow_(gamma)
output = to_dtype_image_tensor(image, torch.float32, scale=True).pow_(gamma)
else:
output = image.pow(gamma)
@@ -402,7 +402,7 @@ def adjust_gamma_image_tensor(image: torch.Tensor, gamma: float, gain: float = 1
# of the output can go beyond [0, 1].
output = output.mul_(gain).clamp_(0.0, 1.0)
return convert_dtype_image_tensor(output, image.dtype)
return to_dtype_image_tensor(output, image.dtype, scale=True)
adjust_gamma_image_pil = _FP.adjust_gamma
@@ -565,7 +565,7 @@ def equalize_image_tensor(image: torch.Tensor) -> torch.Tensor:
# Since we need to convert in most cases anyway and out of the acceptable dtypes mentioned in 1. `torch.uint8` is
# by far the most common, we choose it as base.
output_dtype = image.dtype
image = convert_dtype_image_tensor(image, torch.uint8)
image = to_dtype_image_tensor(image, torch.uint8, scale=True)
# The histogram is computed by using the flattened image as index. For example, a pixel value of 127 in the image
# corresponds to adding 1 to index 127 in the histogram.
@@ -616,7 +616,7 @@ def equalize_image_tensor(image: torch.Tensor) -> torch.Tensor:
equalized_image = lut.gather(dim=-1, index=flat_image).view_as(image)
output = torch.where(valid_equalization, equalized_image, image)
return convert_dtype_image_tensor(output, output_dtype)
return to_dtype_image_tensor(output, output_dtype, scale=True)
equalize_image_pil = _FP.equalize
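The color kernels above all follow the same round-trip pattern: scale-convert to float32, do the math in [0, 1], then scale-convert back to the original dtype. A standalone sketch of that pattern (the operation and the name `gamma_like` are placeholders):

    import torch
    from torchvision.transforms.v2 import functional as F

    def gamma_like(image: torch.Tensor, gamma: float) -> torch.Tensor:
        orig_dtype = image.dtype
        x = F.to_dtype_image_tensor(image, torch.float32, scale=True)  # e.g. uint8 [0, 255] -> [0, 1]
        x = x.pow(gamma)  # stays in [0, 1] for non-negative gamma
        return F.to_dtype_image_tensor(x, orig_dtype, scale=True)      # back to e.g. uint8 [0, 255]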
@@ -296,9 +296,12 @@ def _num_value_bits(dtype: torch.dtype) -> int:
raise TypeError(f"Number of value bits is only defined for integer dtypes, but got {dtype}.")
def convert_dtype_image_tensor(image: torch.Tensor, dtype: torch.dtype = torch.float) -> torch.Tensor:
def to_dtype_image_tensor(image: torch.Tensor, dtype: torch.dtype = torch.float, scale: bool = False) -> torch.Tensor:
if image.dtype == dtype:
return image
elif not scale:
return image.to(dtype)
float_input = image.is_floating_point()
if torch.jit.is_scripting():
@@ -345,30 +348,28 @@ def convert_dtype_image_tensor(image: torch.Tensor, dtype: torch.dtype = torch.float) -> torch.Tensor:
return image.to(dtype).bitwise_left_shift_(num_value_bits_output - num_value_bits_input)
# We changed the name to align it with the new naming scheme. Still, `convert_image_dtype` is
# prevalent and well understood. Thus, we just alias it without deprecating the old name.
convert_image_dtype = convert_dtype_image_tensor
# We encourage users to use to_dtype() instead but we keep this for BC
def convert_image_dtype(image: torch.Tensor, dtype: torch.dtype = torch.float32) -> torch.Tensor:
return to_dtype_image_tensor(image, dtype=dtype, scale=True)
def convert_dtype_video(video: torch.Tensor, dtype: torch.dtype = torch.float) -> torch.Tensor:
return convert_dtype_image_tensor(video, dtype)
def to_dtype_video(video: torch.Tensor, dtype: torch.dtype = torch.float, scale: bool = False) -> torch.Tensor:
return to_dtype_image_tensor(video, dtype, scale=scale)
def convert_dtype(
inpt: Union[datapoints._ImageTypeJIT, datapoints._VideoTypeJIT], dtype: torch.dtype = torch.float
) -> torch.Tensor:
def to_dtype(inpt: datapoints._InputTypeJIT, dtype: torch.dtype = torch.float, scale: bool = False) -> torch.Tensor:
if not torch.jit.is_scripting():
_log_api_usage_once(convert_dtype)
_log_api_usage_once(to_dtype)
if torch.jit.is_scripting() or is_simple_tensor(inpt):
return convert_dtype_image_tensor(inpt, dtype)
return to_dtype_image_tensor(inpt, dtype, scale=scale)
elif isinstance(inpt, datapoints.Image):
output = convert_dtype_image_tensor(inpt.as_subclass(torch.Tensor), dtype)
output = to_dtype_image_tensor(inpt.as_subclass(torch.Tensor), dtype, scale=scale)
return datapoints.Image.wrap_like(inpt, output)
elif isinstance(inpt, datapoints.Video):
output = convert_dtype_video(inpt.as_subclass(torch.Tensor), dtype)
output = to_dtype_video(inpt.as_subclass(torch.Tensor), dtype, scale=scale)
return datapoints.Video.wrap_like(inpt, output)
elif isinstance(inpt, datapoints._datapoint.Datapoint):
return inpt.to(dtype)
else:
raise TypeError(
f"Input can either be a plain tensor or an `Image` or `Video` datapoint, " f"but got {type(inpt)} instead."
)
raise TypeError(f"Input can either be a plain tensor or a datapoint, but got {type(inpt)} instead.")