Unverified Commit 1402eb8e authored by Nicolas Hug, committed by GitHub

Add scale option to ToDtype. Remove ConvertDtype. (#7759)


Co-authored-by: vfdev <vfdev.5@gmail.com>
Co-authored-by: Philip Meier <github.pmeier@posteo.de>
parent cc0f9d02
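
For readers skimming the diff, here is a minimal before/after sketch of the user-facing change, assuming the v2 transforms API on this branch (variable names are illustrative, not part of the patch):

```python
import torch
from torchvision.transforms import v2 as transforms

img = torch.randint(0, 256, (3, 32, 32), dtype=torch.uint8)

# Before this commit: ConvertDtype converted *and* value-scaled images/videos.
# out = transforms.ConvertDtype(torch.float32)(img)   # removed by this PR

# After this commit: ToDtype gains a `scale` flag. scale=True reproduces the old
# ConvertDtype/ConvertImageDtype behavior; the default scale=False is a plain cast.
out = transforms.ToDtype(torch.float32, scale=True)(img)
assert out.dtype == torch.float32 and out.max() <= 1.0
```

`ConvertImageDtype` stays available for backward compatibility and, per the changes below, now simply forwards to `to_dtype(..., scale=True)`.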
@@ -234,7 +234,6 @@ Conversion
v2.PILToTensor
v2.ToImageTensor
ConvertImageDtype
-v2.ConvertDtype
v2.ConvertImageDtype
v2.ToDtype
v2.ConvertBoundingBoxFormat
......
@@ -29,7 +29,7 @@ def show(sample):
image, target = sample
if isinstance(image, PIL.Image.Image):
image = F.to_image_tensor(image)
-image = F.convert_dtype(image, torch.uint8)
+image = F.to_dtype(image, torch.uint8, scale=True)
annotated_image = draw_bounding_boxes(image, target["boxes"], colors="yellow", width=3)
fig, ax = plt.subplots()
......
@@ -27,7 +27,7 @@ from PIL import Image
from torch.testing._comparison import BooleanPair, NonePair, not_close_error_metas, NumberPair, TensorLikePair
from torchvision import datapoints, io
from torchvision.transforms._functional_tensor import _max_value as get_max_value
-from torchvision.transforms.v2.functional import convert_dtype_image_tensor, to_image_pil, to_image_tensor
+from torchvision.transforms.v2.functional import to_dtype_image_tensor, to_image_pil, to_image_tensor
IN_OSS_CI = any(os.getenv(var) == "true" for var in ["CIRCLECI", "GITHUB_ACTIONS"])
@@ -601,7 +601,7 @@ def make_image_loader_for_interpolation(
image_tensor = image_tensor.to(device=device, memory_format=memory_format, copy=True)
else:
image_tensor = image_tensor.to(device=device)
-image_tensor = convert_dtype_image_tensor(image_tensor, dtype=dtype)
+image_tensor = to_dtype_image_tensor(image_tensor, dtype=dtype, scale=True)
return datapoints.Image(image_tensor)
......
import itertools
import pathlib
import random
-import re
import textwrap
import warnings
from collections import defaultdict
@@ -105,7 +104,7 @@ def normalize_adapter(transform, input, device):
continue
elif check_type(value, (datapoints.Image, datapoints.Video, is_simple_tensor)):
# normalize doesn't support integer images
-value = F.convert_dtype(value, torch.float32)
+value = F.to_dtype(value, torch.float32, scale=True)
adapted_input[key] = value
return adapted_input
@@ -146,7 +145,7 @@ class TestSmoke:
(transforms.ScaleJitter((16, 16), scale_range=(0.8, 1.2), antialias=True), None),
(transforms.ClampBoundingBox(), None),
(transforms.ConvertBoundingBoxFormat(datapoints.BoundingBoxFormat.CXCYWH), None),
-(transforms.ConvertDtype(), None),
+(transforms.ConvertImageDtype(), None),
(transforms.GaussianBlur(kernel_size=3), None),
(
transforms.LinearTransformation(
@@ -1326,61 +1325,6 @@ class TestRandomResize:
)
class TestToDtype:
@pytest.mark.parametrize(
("dtype", "expected_dtypes"),
[
(
torch.float64,
{
datapoints.Video: torch.float64,
datapoints.Image: torch.float64,
datapoints.BoundingBox: torch.float64,
},
),
(
{datapoints.Video: torch.int32, datapoints.Image: torch.float32, datapoints.BoundingBox: torch.float64},
{datapoints.Video: torch.int32, datapoints.Image: torch.float32, datapoints.BoundingBox: torch.float64},
),
],
)
def test_call(self, dtype, expected_dtypes):
sample = dict(
video=make_video(dtype=torch.int64),
image=make_image(dtype=torch.uint8),
bounding_box=make_bounding_box(format=datapoints.BoundingBoxFormat.XYXY, dtype=torch.float32),
str="str",
int=0,
)
transform = transforms.ToDtype(dtype)
transformed_sample = transform(sample)
for key, value in sample.items():
value_type = type(value)
transformed_value = transformed_sample[key]
# make sure the transformation retains the type
assert isinstance(transformed_value, value_type)
if isinstance(value, torch.Tensor):
assert transformed_value.dtype is expected_dtypes[value_type]
else:
assert transformed_value is value
@pytest.mark.filterwarnings("error")
def test_plain_tensor_call(self):
tensor = torch.empty((), dtype=torch.float32)
transform = transforms.ToDtype({torch.Tensor: torch.float64})
assert transform(tensor).dtype is torch.float64
@pytest.mark.parametrize("other_type", [datapoints.Image, datapoints.Video])
def test_plain_tensor_warning(self, other_type):
with pytest.warns(UserWarning, match=re.escape("`torch.Tensor` will *not* be transformed")):
transforms.ToDtype(dtype={torch.Tensor: torch.float32, other_type: torch.float64})
class TestUniformTemporalSubsample:
@pytest.mark.parametrize(
"inpt",
......
@@ -191,7 +191,7 @@ CONSISTENCY_CONFIGS = [
closeness_kwargs=dict(rtol=None, atol=None),
),
ConsistencyConfig(
-v2_transforms.ConvertDtype,
+v2_transforms.ConvertImageDtype,
legacy_transforms.ConvertImageDtype,
[
ArgsKwargs(torch.float16),
......
@@ -283,12 +283,12 @@ class TestKernels:
adapted_other_args, adapted_kwargs = info.float32_vs_uint8(other_args, kwargs)
actual = info.kernel(
-F.convert_dtype_image_tensor(input, dtype=torch.float32),
+F.to_dtype_image_tensor(input, dtype=torch.float32, scale=True),
*adapted_other_args,
**adapted_kwargs,
)
-expected = F.convert_dtype_image_tensor(info.kernel(input, *other_args, **kwargs), dtype=torch.float32)
+expected = F.to_dtype_image_tensor(info.kernel(input, *other_args, **kwargs), dtype=torch.float32, scale=True)
assert_close(
actual,
@@ -538,7 +538,6 @@ class TestDispatchers:
(F.get_image_num_channels, F.get_num_channels),
(F.to_pil_image, F.to_image_pil),
(F.elastic_transform, F.elastic),
-(F.convert_image_dtype, F.convert_dtype_image_tensor),
(F.to_grayscale, F.rgb_to_grayscale),
]
],
@@ -547,24 +546,6 @@ def test_alias(alias, target):
assert alias is target
@pytest.mark.parametrize(
("info", "args_kwargs"),
make_info_args_kwargs_params(
KERNEL_INFOS_MAP[F.convert_dtype_image_tensor],
args_kwargs_fn=lambda info: info.sample_inputs_fn(),
),
)
@pytest.mark.parametrize("device", cpu_and_cuda())
def test_convert_dtype_image_tensor_dtype_and_device(info, args_kwargs, device):
(input, *other_args), kwargs = args_kwargs.load(device)
dtype = other_args[0] if other_args else kwargs.get("dtype", torch.float32)
output = info.kernel(input, dtype)
assert output.dtype == dtype
assert output.device == input.device
@pytest.mark.parametrize("device", cpu_and_cuda())
@pytest.mark.parametrize("num_channels", [1, 3])
def test_normalize_image_tensor_stats(device, num_channels):
......
import contextlib
+import decimal
import inspect
import math
import re
@@ -29,6 +30,7 @@ from common_utils import (
from torch import nn
from torch.testing import assert_close
+from torch.utils._pytree import tree_map
from torchvision import datapoints
from torchvision.transforms._functional_tensor import _max_value as get_max_value
@@ -66,11 +68,12 @@ def _check_kernel_cuda_vs_cpu(kernel, input, *args, rtol, atol, **kwargs):
@cache
-def _script(fn):
+def _script(obj):
try:
-return torch.jit.script(fn)
+return torch.jit.script(obj)
except Exception as error:
-raise AssertionError(f"Trying to `torch.jit.script` '{fn.__name__}' raised the error above.") from error
+name = getattr(obj, "__name__", obj.__class__.__name__)
+raise AssertionError(f"Trying to `torch.jit.script` '{name}' raised the error above.") from error
def _check_kernel_scripted_vs_eager(kernel, input, *args, rtol, atol, **kwargs):
@@ -127,6 +130,7 @@ def check_kernel(
check_cuda_vs_cpu=True,
check_scripted_vs_eager=True,
check_batched_vs_unbatched=True,
+expect_same_dtype=True,
**kwargs,
):
initial_input_version = input._version
@@ -139,7 +143,8 @@ def check_kernel(
# check that no inplace operation happened
assert input._version == initial_input_version
-assert output.dtype == input.dtype
+if expect_same_dtype:
+assert output.dtype == input.dtype
assert output.device == input.device
if check_cuda_vs_cpu:
@@ -276,7 +281,7 @@ def check_dispatcher_signatures_match(dispatcher, *, kernel, input_type):
def _check_transform_v1_compatibility(transform, input):
"""If the transform defines the ``_v1_transform_cls`` attribute, checks if the transform has a public, static
``get_params`` method, is scriptable, and the scripted version can be called without error."""
-if not hasattr(transform, "_v1_transform_cls"):
+if transform._v1_transform_cls is None:
return
if type(input) is not torch.Tensor:
@@ -1697,3 +1702,193 @@ class TestCompose:
assert isinstance(output, tuple) and len(output) == 2
assert output[0] is image
assert output[1] is label
class TestToDtype:
@pytest.mark.parametrize(
("kernel", "make_input"),
[
(F.to_dtype_image_tensor, make_image_tensor),
(F.to_dtype_image_tensor, make_image),
(F.to_dtype_video, make_video),
],
)
@pytest.mark.parametrize("input_dtype", [torch.float32, torch.float64, torch.uint8])
@pytest.mark.parametrize("output_dtype", [torch.float32, torch.float64, torch.uint8])
@pytest.mark.parametrize("device", cpu_and_cuda())
@pytest.mark.parametrize("scale", (True, False))
def test_kernel(self, kernel, make_input, input_dtype, output_dtype, device, scale):
check_kernel(
kernel,
make_input(dtype=input_dtype, device=device),
expect_same_dtype=input_dtype is output_dtype,
dtype=output_dtype,
scale=scale,
)
@pytest.mark.parametrize(
("kernel", "make_input"),
[
(F.to_dtype_image_tensor, make_image_tensor),
(F.to_dtype_image_tensor, make_image),
(F.to_dtype_video, make_video),
],
)
@pytest.mark.parametrize("input_dtype", [torch.float32, torch.float64, torch.uint8])
@pytest.mark.parametrize("output_dtype", [torch.float32, torch.float64, torch.uint8])
@pytest.mark.parametrize("device", cpu_and_cuda())
@pytest.mark.parametrize("scale", (True, False))
def test_dispatcher(self, kernel, make_input, input_dtype, output_dtype, device, scale):
check_dispatcher(
F.to_dtype,
kernel,
make_input(dtype=input_dtype, device=device),
# TODO: we could leave check_dispatch to True but it currently fails
# in _check_dispatcher_dispatch because there is no to_dtype() method on the datapoints.
# We should be able to put this back if we change the dispatch
# mechanism e.g. via https://github.com/pytorch/vision/pull/7733
check_dispatch=False,
dtype=output_dtype,
scale=scale,
)
@pytest.mark.parametrize(
"make_input",
[make_image_tensor, make_image, make_bounding_box, make_segmentation_mask, make_video],
)
@pytest.mark.parametrize("input_dtype", [torch.float32, torch.float64, torch.uint8])
@pytest.mark.parametrize("output_dtype", [torch.float32, torch.float64, torch.uint8])
@pytest.mark.parametrize("device", cpu_and_cuda())
@pytest.mark.parametrize("scale", (True, False))
@pytest.mark.parametrize("as_dict", (True, False))
def test_transform(self, make_input, input_dtype, output_dtype, device, scale, as_dict):
input = make_input(dtype=input_dtype, device=device)
if as_dict:
output_dtype = {type(input): output_dtype}
check_transform(transforms.ToDtype, input, dtype=output_dtype, scale=scale)
def reference_convert_dtype_image_tensor(self, image, dtype=torch.float, scale=False):
input_dtype = image.dtype
output_dtype = dtype
if not scale:
return image.to(dtype)
if output_dtype == input_dtype:
return image
def fn(value):
if input_dtype.is_floating_point:
if output_dtype.is_floating_point:
return value
else:
return round(decimal.Decimal(value) * torch.iinfo(output_dtype).max)
else:
input_max_value = torch.iinfo(input_dtype).max
if output_dtype.is_floating_point:
return float(decimal.Decimal(value) / input_max_value)
else:
output_max_value = torch.iinfo(output_dtype).max
if input_max_value > output_max_value:
factor = (input_max_value + 1) // (output_max_value + 1)
return value / factor
else:
factor = (output_max_value + 1) // (input_max_value + 1)
return value * factor
return torch.tensor(tree_map(fn, image.tolist()), dtype=dtype, device=image.device)
@pytest.mark.parametrize("input_dtype", [torch.float32, torch.float64, torch.uint8])
@pytest.mark.parametrize("output_dtype", [torch.float32, torch.float64, torch.uint8])
@pytest.mark.parametrize("device", cpu_and_cuda())
@pytest.mark.parametrize("scale", (True, False))
def test_image_correctness(self, input_dtype, output_dtype, device, scale):
if input_dtype.is_floating_point and output_dtype == torch.int64:
pytest.xfail("float to int64 conversion is not supported")
input = make_image(dtype=input_dtype, device=device)
out = F.to_dtype(input, dtype=output_dtype, scale=scale)
expected = self.reference_convert_dtype_image_tensor(input, dtype=output_dtype, scale=scale)
if input_dtype.is_floating_point and not output_dtype.is_floating_point and scale:
torch.testing.assert_close(out, expected, atol=1, rtol=0)
else:
torch.testing.assert_close(out, expected)
def was_scaled(self, inpt):
# this assumes the target dtype is float
return inpt.max() <= 1
def make_inpt_with_bbox_and_mask(self, make_input):
H, W = 10, 10
inpt_dtype = torch.uint8
bbox_dtype = torch.float32
mask_dtype = torch.bool
sample = {
"inpt": make_input(size=(H, W), dtype=inpt_dtype),
"bbox": make_bounding_box(size=(H, W), dtype=bbox_dtype),
"mask": make_detection_mask(size=(H, W), dtype=mask_dtype),
}
return sample, inpt_dtype, bbox_dtype, mask_dtype
@pytest.mark.parametrize("make_input", (make_image_tensor, make_image, make_video))
@pytest.mark.parametrize("scale", (True, False))
def test_dtype_not_a_dict(self, make_input, scale):
# assert only inpt gets transformed when dtype isn't a dict
sample, inpt_dtype, bbox_dtype, mask_dtype = self.make_inpt_with_bbox_and_mask(make_input)
out = transforms.ToDtype(dtype=torch.float32, scale=scale)(sample)
assert out["inpt"].dtype != inpt_dtype
assert out["inpt"].dtype == torch.float32
if scale:
assert self.was_scaled(out["inpt"])
else:
assert not self.was_scaled(out["inpt"])
assert out["bbox"].dtype == bbox_dtype
assert out["mask"].dtype == mask_dtype
@pytest.mark.parametrize("make_input", (make_image_tensor, make_image, make_video))
def test_others_catch_all_and_none(self, make_input):
# make sure "others" works as a catch-all and that None means no conversion
sample, inpt_dtype, bbox_dtype, mask_dtype = self.make_inpt_with_bbox_and_mask(make_input)
out = transforms.ToDtype(dtype={datapoints.Mask: torch.int64, "others": None})(sample)
assert out["inpt"].dtype == inpt_dtype
assert out["bbox"].dtype == bbox_dtype
assert out["mask"].dtype != mask_dtype
assert out["mask"].dtype == torch.int64
@pytest.mark.parametrize("make_input", (make_image_tensor, make_image, make_video))
def test_typical_use_case(self, make_input):
# Typical use-case: want to convert dtype and scale for inpt and just dtype for masks.
# This just makes sure we now have a decent API for this
sample, inpt_dtype, bbox_dtype, mask_dtype = self.make_inpt_with_bbox_and_mask(make_input)
out = transforms.ToDtype(
dtype={type(sample["inpt"]): torch.float32, datapoints.Mask: torch.int64, "others": None}, scale=True
)(sample)
assert out["inpt"].dtype != inpt_dtype
assert out["inpt"].dtype == torch.float32
assert self.was_scaled(out["inpt"])
assert out["bbox"].dtype == bbox_dtype
assert out["mask"].dtype != mask_dtype
assert out["mask"].dtype == torch.int64
@pytest.mark.parametrize("make_input", (make_image_tensor, make_image, make_video))
def test_errors_warnings(self, make_input):
sample, inpt_dtype, bbox_dtype, mask_dtype = self.make_inpt_with_bbox_and_mask(make_input)
with pytest.raises(ValueError, match="No dtype was specified for"):
out = transforms.ToDtype(dtype={datapoints.Mask: torch.float32})(sample)
with pytest.warns(UserWarning, match=re.escape("plain `torch.Tensor` will *not* be transformed")):
transforms.ToDtype(dtype={torch.Tensor: torch.float32, datapoints.Image: torch.float32})
with pytest.warns(UserWarning, match="no scaling will be done"):
out = transforms.ToDtype(dtype={"others": None}, scale=True)(sample)
assert out["inpt"].dtype == inpt_dtype
assert out["bbox"].dtype == bbox_dtype
assert out["mask"].dtype == mask_dtype
@@ -364,16 +364,6 @@ DISPATCHER_INFOS = [
xfail_jit_python_scalar_arg("std"),
],
),
DispatcherInfo(
F.convert_dtype,
kernels={
datapoints.Image: F.convert_dtype_image_tensor,
datapoints.Video: F.convert_dtype_video,
},
test_marks=[
skip_dispatch_datapoint,
],
),
DispatcherInfo(
F.uniform_temporal_subsample,
kernels={
......
-import decimal
import functools
import itertools
@@ -27,7 +26,6 @@ from common_utils import (
mark_framework_limitation,
TestMark,
)
-from torch.utils._pytree import tree_map
from torchvision import datapoints
from torchvision.transforms._functional_tensor import _max_value as get_max_value, _parse_pad_padding
@@ -1566,7 +1564,7 @@ def multi_crop_pil_reference_wrapper(pil_kernel):
def wrapper(input_tensor, *other_args, **kwargs):
output = pil_reference_wrapper(pil_kernel)(input_tensor, *other_args, **kwargs)
return type(output)(
-F.convert_dtype_image_tensor(F.to_image_tensor(output_pil), dtype=input_tensor.dtype)
+F.to_dtype_image_tensor(F.to_image_tensor(output_pil), dtype=input_tensor.dtype, scale=True)
for output_pil in output
)
@@ -1667,125 +1665,6 @@ KERNEL_INFOS.extend(
)
def sample_inputs_convert_dtype_image_tensor():
for input_dtype, output_dtype in itertools.product(
[torch.uint8, torch.int64, torch.float32, torch.float64], repeat=2
):
if input_dtype.is_floating_point and output_dtype == torch.int64:
# conversion cannot be performed safely
continue
for image_loader in make_image_loaders(
sizes=[DEFAULT_PORTRAIT_SPATIAL_SIZE], color_spaces=["RGB"], dtypes=[input_dtype]
):
yield ArgsKwargs(image_loader, dtype=output_dtype)
def reference_convert_dtype_image_tensor(image, dtype=torch.float):
input_dtype = image.dtype
output_dtype = dtype
if output_dtype == input_dtype:
return image
def fn(value):
if input_dtype.is_floating_point:
if output_dtype.is_floating_point:
return value
else:
return int(decimal.Decimal(value) * torch.iinfo(output_dtype).max)
else:
input_max_value = torch.iinfo(input_dtype).max
if output_dtype.is_floating_point:
return float(decimal.Decimal(value) / input_max_value)
else:
output_max_value = torch.iinfo(output_dtype).max
if input_max_value > output_max_value:
factor = (input_max_value + 1) // (output_max_value + 1)
return value // factor
else:
factor = (output_max_value + 1) // (input_max_value + 1)
return value * factor
return torch.tensor(tree_map(fn, image.tolist()), dtype=dtype)
def reference_inputs_convert_dtype_image_tensor():
for input_dtype, output_dtype in itertools.product(
[
torch.uint8,
torch.int16,
torch.int32,
torch.int64,
torch.float16,
torch.float32,
torch.float64,
torch.bfloat16,
],
repeat=2,
):
if (input_dtype == torch.float32 and output_dtype in {torch.int32, torch.int64}) or (
input_dtype == torch.float64 and output_dtype == torch.int64
):
continue
if input_dtype.is_floating_point:
data = [0.0, 0.5, 1.0]
else:
max_value = torch.iinfo(input_dtype).max
data = [0, max_value // 2, max_value]
image = torch.tensor(data, dtype=input_dtype)
yield ArgsKwargs(image, dtype=output_dtype)
def sample_inputs_convert_dtype_video():
for video_loader in make_video_loaders(sizes=[DEFAULT_PORTRAIT_SPATIAL_SIZE], num_frames=[3]):
yield ArgsKwargs(video_loader)
skip_dtype_consistency = TestMark(
("TestKernels", "test_dtype_and_device_consistency"),
pytest.mark.skip(reason="`convert_dtype_*` kernels convert the dtype by design"),
condition=lambda args_kwargs: args_kwargs.args[0].dtype != args_kwargs.kwargs.get("dtype", torch.float32),
)
KERNEL_INFOS.extend(
[
KernelInfo(
F.convert_dtype_image_tensor,
sample_inputs_fn=sample_inputs_convert_dtype_image_tensor,
reference_fn=reference_convert_dtype_image_tensor,
reference_inputs_fn=reference_inputs_convert_dtype_image_tensor,
test_marks=[
skip_dtype_consistency,
TestMark(
("TestKernels", "test_against_reference"),
pytest.mark.xfail(reason="Conversion overflows"),
condition=lambda args_kwargs: (
args_kwargs.args[0].dtype in {torch.float16, torch.bfloat16}
and not args_kwargs.kwargs["dtype"].is_floating_point
)
or (
args_kwargs.args[0].dtype in {torch.int32, torch.int64}
and args_kwargs.kwargs["dtype"] == torch.float16
),
),
],
),
KernelInfo(
F.convert_dtype_video,
sample_inputs_fn=sample_inputs_convert_dtype_video,
test_marks=[
skip_dtype_consistency,
],
),
]
)
def sample_inputs_uniform_temporal_subsample_video():
for video_loader in make_video_loaders(sizes=[DEFAULT_PORTRAIT_SPATIAL_SIZE], num_frames=[4]):
yield ArgsKwargs(video_loader, num_samples=2)
......
@@ -39,7 +39,7 @@ from ._geometry import (
ScaleJitter,
TenCrop,
)
-from ._meta import ClampBoundingBox, ConvertBoundingBoxFormat, ConvertDtype, ConvertImageDtype
+from ._meta import ClampBoundingBox, ConvertBoundingBoxFormat, ConvertImageDtype
from ._misc import GaussianBlur, Identity, Lambda, LinearTransformation, Normalize, SanitizeBoundingBox, ToDtype
from ._temporal import UniformTemporalSubsample
from ._type_conversion import PILToTensor, ToImagePIL, ToImageTensor, ToPILImage
......
@@ -31,10 +31,13 @@ class ConvertBoundingBoxFormat(Transform):
return F.convert_format_bounding_box(inpt, new_format=self.format)  # type: ignore[return-value]
-class ConvertDtype(Transform):
+class ConvertImageDtype(Transform):
-"""[BETA] Convert input image or video to the given ``dtype`` and scale the values accordingly.
+"""[BETA] Convert input image to the given ``dtype`` and scale the values accordingly.
-.. v2betastatus:: ConvertDtype transform
+.. v2betastatus:: ConvertImageDtype transform
+.. warning::
+Consider using ``ToDtype(dtype, scale=True)`` instead. See :class:`~torchvision.transforms.v2.ToDtype`.
This function does not support PIL Image.
@@ -55,21 +58,14 @@ class ConvertDtype(Transform):
_v1_transform_cls = _transforms.ConvertImageDtype
-_transformed_types = (is_simple_tensor, datapoints.Image, datapoints.Video)
+_transformed_types = (is_simple_tensor, datapoints.Image)
def __init__(self, dtype: torch.dtype = torch.float32) -> None:
super().__init__()
self.dtype = dtype
-def _transform(
-self, inpt: Union[datapoints._TensorImageType, datapoints._TensorVideoType], params: Dict[str, Any]
-) -> Union[datapoints._TensorImageType, datapoints._TensorVideoType]:
-return F.convert_dtype(inpt, self.dtype)
+def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
+return F.to_dtype(inpt, dtype=self.dtype, scale=True)
-# We changed the name to align it with the new naming scheme. Still, `ConvertImageDtype` is
-# prevalent and well understood. Thus, we just alias it without deprecating the old name.
-ConvertImageDtype = ConvertDtype
class ClampBoundingBox(Transform):
......
@@ -11,7 +11,7 @@ from torch.utils._pytree import tree_flatten, tree_unflatten
from torchvision import datapoints, transforms as _transforms
from torchvision.transforms.v2 import functional as F, Transform
-from ._utils import _get_defaultdict, _setup_float_or_seq, _setup_size
+from ._utils import _setup_float_or_seq, _setup_size
from .utils import has_any, is_simple_tensor, query_bounding_box
@@ -225,36 +225,76 @@ class GaussianBlur(Transform):
class ToDtype(Transform):
-"""[BETA] Converts the input to a specific dtype - this does not scale values.
+"""[BETA] Converts the input to a specific dtype, optionally scaling the values for images or videos.
.. v2betastatus:: ToDtype transform
+.. note::
+``ToDtype(dtype, scale=True)`` is the recommended replacement for ``ConvertImageDtype(dtype)``.
Args:
dtype (``torch.dtype`` or dict of ``Datapoint`` -> ``torch.dtype``): The dtype to convert to.
+If a ``torch.dtype`` is passed, e.g. ``torch.float32``, only images and videos will be converted
+to that dtype: this is for compatibility with :class:`~torchvision.transforms.v2.ConvertImageDtype`.
A dict can be passed to specify per-datapoint conversions, e.g.
-``dtype={datapoints.Image: torch.float32, datapoints.Video:
-torch.float64}``.
+``dtype={datapoints.Image: torch.float32, datapoints.Mask: torch.int64, "others":None}``. The "others"
+key can be used as a catch-all for any other datapoint type, and ``None`` means no conversion.
+scale (bool, optional): Whether to scale the values for images or videos. Default: ``False``.
"""
_transformed_types = (torch.Tensor,)
-def __init__(self, dtype: Union[torch.dtype, Dict[Type, Optional[torch.dtype]]]) -> None:
+def __init__(
+self, dtype: Union[torch.dtype, Dict[Union[Type, str], Optional[torch.dtype]]], scale: bool = False
+) -> None:
super().__init__()
-if not isinstance(dtype, dict):
-dtype = _get_defaultdict(dtype)
-if torch.Tensor in dtype and any(cls in dtype for cls in [datapoints.Image, datapoints.Video]):
+if not isinstance(dtype, (dict, torch.dtype)):
+raise ValueError(f"dtype must be a dict or a torch.dtype, got {type(dtype)} instead")
+if (
+isinstance(dtype, dict)
+and torch.Tensor in dtype
+and any(cls in dtype for cls in [datapoints.Image, datapoints.Video])
+):
warnings.warn(
"Got `dtype` values for `torch.Tensor` and either `datapoints.Image` or `datapoints.Video`. "
"Note that a plain `torch.Tensor` will *not* be transformed by this (or any other transformation) "
"in case a `datapoints.Image` or `datapoints.Video` is present in the input."
)
self.dtype = dtype
+self.scale = scale
def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
-dtype = self.dtype[type(inpt)]
+if isinstance(self.dtype, torch.dtype):
+# For consistency / BC with ConvertImageDtype, we only care about images or videos when dtype
+# is a simple torch.dtype
+if not is_simple_tensor(inpt) and not isinstance(inpt, (datapoints.Image, datapoints.Video)):
+return inpt
+dtype: Optional[torch.dtype] = self.dtype
+elif type(inpt) in self.dtype:
+dtype = self.dtype[type(inpt)]
+elif "others" in self.dtype:
+dtype = self.dtype["others"]
+else:
+raise ValueError(
+f"No dtype was specified for type {type(inpt)}. "
+"If you only need to convert the dtype of images or videos, you can just pass e.g. dtype=torch.float32. "
+"If you're passing a dict as dtype, "
+'you can use "others" as a catch-all key '
+'e.g. dtype={datapoints.Mask: torch.int64, "others": None} to pass-through the rest of the inputs.'
+)
+supports_scaling = is_simple_tensor(inpt) or isinstance(inpt, (datapoints.Image, datapoints.Video))
if dtype is None:
+if self.scale and supports_scaling:
+warnings.warn(
+"scale was set to True but no dtype was specified for images or videos: no scaling will be done."
+)
return inpt
-return inpt.to(dtype=dtype)
+return F.to_dtype(inpt, dtype=dtype, scale=self.scale)
class SanitizeBoundingBox(Transform):
......
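
The dict form introduced above is easiest to see with a small example; this is a sketch assuming the datapoints API used elsewhere in this diff (the sample keys are made up for illustration):

```python
import torch
from torchvision import datapoints
from torchvision.transforms import v2 as transforms

sample = {
    "image": datapoints.Image(torch.randint(0, 256, (3, 10, 10), dtype=torch.uint8)),
    "mask": datapoints.Mask(torch.zeros(10, 10, dtype=torch.bool)),
    "label": 3,  # non-tensor entries are never touched by ToDtype
}

transform = transforms.ToDtype(
    # convert and scale images, convert masks without scaling, pass everything else through
    dtype={datapoints.Image: torch.float32, datapoints.Mask: torch.int64, "others": None},
    scale=True,
)
out = transform(sample)
# out["image"]: float32 scaled into [0, 1]; out["mask"]: int64 with unchanged values; out["label"]: 3
```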
@@ -5,10 +5,10 @@ from ._utils import is_simple_tensor # usort: skip
from ._meta import (
clamp_bounding_box,
convert_format_bounding_box,
-convert_dtype_image_tensor,
-convert_dtype,
-convert_dtype_video,
convert_image_dtype,
+to_dtype,
+to_dtype_image_tensor,
+to_dtype_video,
get_dimensions_image_tensor,
get_dimensions_image_pil,
get_dimensions,
......
@@ -9,7 +9,7 @@ from torchvision.transforms._functional_tensor import _max_value
from torchvision.utils import _log_api_usage_once
-from ._meta import _num_value_bits, convert_dtype_image_tensor
+from ._meta import _num_value_bits, to_dtype_image_tensor
from ._utils import is_simple_tensor
@@ -351,7 +351,7 @@ def adjust_hue_image_tensor(image: torch.Tensor, hue_factor: float) -> torch.Ten
return image
orig_dtype = image.dtype
-image = convert_dtype_image_tensor(image, torch.float32)
+image = to_dtype_image_tensor(image, torch.float32, scale=True)
image = _rgb_to_hsv(image)
h, s, v = image.unbind(dim=-3)
@@ -359,7 +359,7 @@ def adjust_hue_image_tensor(image: torch.Tensor, hue_factor: float) -> torch.Ten
image = torch.stack((h, s, v), dim=-3)
image_hue_adj = _hsv_to_rgb(image)
-return convert_dtype_image_tensor(image_hue_adj, orig_dtype)
+return to_dtype_image_tensor(image_hue_adj, orig_dtype, scale=True)
adjust_hue_image_pil = _FP.adjust_hue
@@ -393,7 +393,7 @@ def adjust_gamma_image_tensor(image: torch.Tensor, gamma: float, gain: float = 1
# The input image is either assumed to be at [0, 1] scale (if float) or is converted to that scale (if integer).
# Since the gamma is non-negative, the output remains at [0, 1] scale.
if not torch.is_floating_point(image):
-output = convert_dtype_image_tensor(image, torch.float32).pow_(gamma)
+output = to_dtype_image_tensor(image, torch.float32, scale=True).pow_(gamma)
else:
output = image.pow(gamma)
@@ -402,7 +402,7 @@ def adjust_gamma_image_tensor(image: torch.Tensor, gamma: float, gain: float = 1
# of the output can go beyond [0, 1].
output = output.mul_(gain).clamp_(0.0, 1.0)
-return convert_dtype_image_tensor(output, image.dtype)
+return to_dtype_image_tensor(output, image.dtype, scale=True)
adjust_gamma_image_pil = _FP.adjust_gamma
@@ -565,7 +565,7 @@ def equalize_image_tensor(image: torch.Tensor) -> torch.Tensor:
# Since we need to convert in most cases anyway and out of the acceptable dtypes mentioned in 1. `torch.uint8` is
# by far the most common, we choose it as base.
output_dtype = image.dtype
-image = convert_dtype_image_tensor(image, torch.uint8)
+image = to_dtype_image_tensor(image, torch.uint8, scale=True)
# The histogram is computed by using the flattened image as index. For example, a pixel value of 127 in the image
# corresponds to adding 1 to index 127 in the histogram.
@@ -616,7 +616,7 @@ def equalize_image_tensor(image: torch.Tensor) -> torch.Tensor:
equalized_image = lut.gather(dim=-1, index=flat_image).view_as(image)
output = torch.where(valid_equalization, equalized_image, image)
-return convert_dtype_image_tensor(output, output_dtype)
+return to_dtype_image_tensor(output, output_dtype, scale=True)
equalize_image_pil = _FP.equalize
......
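
The kernel changes above all follow the same round-trip idiom: scale-convert to float32, do the math in float, then scale-convert back to the original dtype. A standalone sketch of that idiom (the operation in the middle is a placeholder, not torchvision code):

```python
import torch
from torchvision.transforms.v2 import functional as F

def apply_in_float(image: torch.Tensor) -> torch.Tensor:
    orig_dtype = image.dtype
    # uint8 [0, 255] -> float32 [0, 1]
    img = F.to_dtype_image_tensor(image, torch.float32, scale=True)
    img = img.clamp(0.0, 1.0)  # placeholder for a float-only operation
    # float32 [0, 1] -> back to the original integer range
    return F.to_dtype_image_tensor(img, orig_dtype, scale=True)

out = apply_in_float(torch.randint(0, 256, (3, 8, 8), dtype=torch.uint8))
assert out.dtype == torch.uint8
```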
@@ -296,9 +296,12 @@ def _num_value_bits(dtype: torch.dtype) -> int:
raise TypeError(f"Number of value bits is only defined for integer dtypes, but got {dtype}.")
-def convert_dtype_image_tensor(image: torch.Tensor, dtype: torch.dtype = torch.float) -> torch.Tensor:
+def to_dtype_image_tensor(image: torch.Tensor, dtype: torch.dtype = torch.float, scale: bool = False) -> torch.Tensor:
if image.dtype == dtype:
return image
+elif not scale:
+return image.to(dtype)
float_input = image.is_floating_point()
if torch.jit.is_scripting():
@@ -345,30 +348,28 @@ def convert_dtype_image_tensor(image: torch.Tensor, dtype: torch.dtype = torch.f
return image.to(dtype).bitwise_left_shift_(num_value_bits_output - num_value_bits_input)
-# We changed the name to align it with the new naming scheme. Still, `convert_image_dtype` is
-# prevalent and well understood. Thus, we just alias it without deprecating the old name.
-convert_image_dtype = convert_dtype_image_tensor
+# We encourage users to use to_dtype() instead but we keep this for BC
+def convert_image_dtype(image: torch.Tensor, dtype: torch.dtype = torch.float32) -> torch.Tensor:
+return to_dtype_image_tensor(image, dtype=dtype, scale=True)
-def convert_dtype_video(video: torch.Tensor, dtype: torch.dtype = torch.float) -> torch.Tensor:
-return convert_dtype_image_tensor(video, dtype)
+def to_dtype_video(video: torch.Tensor, dtype: torch.dtype = torch.float, scale: bool = False) -> torch.Tensor:
+return to_dtype_image_tensor(video, dtype, scale=scale)
-def convert_dtype(
-inpt: Union[datapoints._ImageTypeJIT, datapoints._VideoTypeJIT], dtype: torch.dtype = torch.float
-) -> torch.Tensor:
+def to_dtype(inpt: datapoints._InputTypeJIT, dtype: torch.dtype = torch.float, scale: bool = False) -> torch.Tensor:
if not torch.jit.is_scripting():
-_log_api_usage_once(convert_dtype)
+_log_api_usage_once(to_dtype)
if torch.jit.is_scripting() or is_simple_tensor(inpt):
-return convert_dtype_image_tensor(inpt, dtype)
+return to_dtype_image_tensor(inpt, dtype, scale=scale)
elif isinstance(inpt, datapoints.Image):
-output = convert_dtype_image_tensor(inpt.as_subclass(torch.Tensor), dtype)
+output = to_dtype_image_tensor(inpt.as_subclass(torch.Tensor), dtype, scale=scale)
return datapoints.Image.wrap_like(inpt, output)
elif isinstance(inpt, datapoints.Video):
-output = convert_dtype_video(inpt.as_subclass(torch.Tensor), dtype)
+output = to_dtype_video(inpt.as_subclass(torch.Tensor), dtype, scale=scale)
return datapoints.Video.wrap_like(inpt, output)
+elif isinstance(inpt, datapoints._datapoint.Datapoint):
+return inpt.to(dtype)
else:
-raise TypeError(
-f"Input can either be a plain tensor or an `Image` or `Video` datapoint, " f"but got {type(inpt)} instead."
-)
+raise TypeError(f"Input can either be a plain tensor or a datapoint, but got {type(inpt)} instead.")
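
To summarize the new functional semantics in one place, here is a sketch of the expected behavior (not part of the patch): `scale=False` is a bare cast, `scale=True` rescales value ranges, and `convert_image_dtype` remains as a scaling wrapper for backward compatibility.

```python
import torch
from torchvision.transforms.v2 import functional as F

img = torch.tensor([[[0, 128, 255]]], dtype=torch.uint8)

plain = F.to_dtype(img, torch.float32)               # values stay 0, 128, 255
scaled = F.to_dtype(img, torch.float32, scale=True)  # values become 0.0, ~0.502, 1.0
legacy = F.convert_image_dtype(img, torch.float32)   # same result as `scaled`

assert torch.equal(scaled, legacy)
assert plain.max() == 255
```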