Unverified Commit 9c112935 authored by Philip Meier, committed by GitHub

Add tests and proper support for videos in `ConvertImageDtype` (#6783)

* add KernelInfo

* split dtype and device consistency tests

* add proper support for video

* fix tests and add DispatcherInfo

* add aliases

* cleanup

* fix typo
parent 9d322cac
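For orientation, here is a minimal sketch of the API surface this commit converges on. The function, class, and module names are taken from the diff below; the example tensors, shapes, and values are illustrative assumptions:

```python
import torch
from torchvision.prototype import features
from torchvision.prototype.transforms import functional as F

# `convert_dtype` dispatches on the input type; `convert_image_dtype` stays as an alias.
image = features.Image(torch.randint(0, 256, (3, 32, 32), dtype=torch.uint8))
out = F.convert_dtype(image, torch.float32)  # still a `features.Image`, values in [0.0, 1.0]

video = torch.randint(0, 256, (4, 3, 32, 32), dtype=torch.uint8)  # plain tensor
out = F.convert_dtype_video(video, torch.float32)  # the new video kernel
```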
@@ -22,7 +22,7 @@ from torch.testing._comparison import (
     UnsupportedInputs,
 )
 from torchvision.prototype import features
-from torchvision.prototype.transforms.functional import convert_image_dtype, to_image_tensor
+from torchvision.prototype.transforms.functional import convert_dtype_image_tensor, to_image_tensor
 from torchvision.transforms.functional_tensor import _max_value as get_max_value
 
 __all__ = [
@@ -97,8 +97,8 @@ class PILImagePair(TensorLikePair):
     def _equalize_attributes(self, actual, expected):
         if actual.dtype != expected.dtype:
             dtype = torch.promote_types(actual.dtype, expected.dtype)
-            actual = convert_image_dtype(actual, dtype)
-            expected = convert_image_dtype(expected, dtype)
+            actual = convert_dtype_image_tensor(actual, dtype)
+            expected = convert_dtype_image_tensor(expected, dtype)
         return super()._equalize_attributes(actual, expected)
@@ -416,4 +416,14 @@ DISPATCHER_INFOS = [
             skip_dispatch_feature,
         ],
     ),
+    DispatcherInfo(
+        F.convert_dtype,
+        kernels={
+            features.Image: F.convert_dtype_image_tensor,
+            features.Video: F.convert_dtype_video,
+        },
+        test_marks=[
+            skip_dispatch_feature,
+        ],
+    ),
 ]
@@ -1979,7 +1979,7 @@ KERNEL_INFOS.extend(
 )
 
 
-def sample_inputs_convert_image_dtype():
+def sample_inputs_convert_dtype_image_tensor():
     for input_dtype, output_dtype in itertools.product(
         [torch.uint8, torch.int64, torch.float32, torch.float64], repeat=2
     ):
@@ -1992,10 +1992,8 @@ def sample_inputs_convert_image_dtype():
     ):
         yield ArgsKwargs(image_loader, dtype=output_dtype)
 
-    yield ArgsKwargs(make_image_loader(color_space=features.ColorSpace.RGB), dtype=torch.uint8)
-
 
 
-def reference_convert_image_dtype(image, dtype=torch.float):
+def reference_convert_dtype_image_tensor(image, dtype=torch.float):
     input_dtype = image.dtype
     output_dtype = dtype
@@ -2026,7 +2024,7 @@ def reference_convert_image_dtype(image, dtype=torch.float):
     return torch.tensor(tree_map(fn, image.tolist()), dtype=dtype)
 
 
-def reference_inputs_convert_image_dtype():
+def reference_inputs_convert_dtype_image_tensor():
     for input_dtype, output_dtype in itertools.product(
         [
             torch.uint8,
@@ -2055,24 +2053,32 @@ def reference_inputs_convert_image_dtype():
             yield ArgsKwargs(image, dtype=output_dtype)
 
 
+def sample_inputs_convert_dtype_video():
+    for video_loader in make_video_loaders(sizes=["random"], num_frames=["random"]):
+        yield ArgsKwargs(video_loader)
+
+
+_common_convert_dtype_marks = [
+    TestMark(
+        ("TestKernels", "test_dtype_and_device_consistency"),
+        pytest.mark.skip(reason="`convert_dtype_*` kernels convert the dtype by design"),
+        condition=lambda args_kwargs: args_kwargs.args[0].dtype != args_kwargs.kwargs.get("dtype", torch.float32),
+    ),
+    TestMark(
+        ("TestKernels", "test_scripted_vs_eager"),
+        pytest.mark.filterwarnings(f"ignore:{re.escape('operator() profile_node %')}:UserWarning"),
+    ),
+]
+
 KERNEL_INFOS.extend(
     [
         KernelInfo(
-            F.convert_image_dtype,
-            sample_inputs_fn=sample_inputs_convert_image_dtype,
-            reference_fn=reference_convert_image_dtype,
-            reference_inputs_fn=reference_inputs_convert_image_dtype,
+            F.convert_dtype_image_tensor,
+            sample_inputs_fn=sample_inputs_convert_dtype_image_tensor,
+            reference_fn=reference_convert_dtype_image_tensor,
+            reference_inputs_fn=reference_inputs_convert_dtype_image_tensor,
             test_marks=[
-                TestMark(
-                    ("TestKernels", "test_scripted_vs_eager"),
-                    pytest.mark.filterwarnings(f"ignore:{re.escape('operator() profile_node %41')}:UserWarning"),
-                ),
-                TestMark(
-                    ("TestKernels", "test_dtype_and_device_consistency"),
-                    pytest.mark.skip(reason="`convert_dtype_*` kernels convert the dtype by design"),
-                    condition=lambda args_kwargs: args_kwargs.args[0].dtype
-                    != args_kwargs.kwargs.get("dtype", torch.float32),
-                ),
+                *_common_convert_dtype_marks,
                 TestMark(
                     ("TestKernels", "test_against_reference"),
                     pytest.mark.xfail(reason="Conversion overflows"),
@@ -2080,10 +2086,6 @@ KERNEL_INFOS.extend(
                         args_kwargs.args[0].dtype in {torch.float16, torch.bfloat16}
                         and not args_kwargs.kwargs["dtype"].is_floating_point
                     )
-                    or (
-                        args_kwargs.args[0].dtype in {torch.float16, torch.bfloat16}
-                        and args_kwargs.kwargs["dtype"] == torch.int64
-                    )
                     or (
                         args_kwargs.args[0].dtype in {torch.int32, torch.int64}
                         and args_kwargs.kwargs["dtype"] == torch.float16
@@ -2091,5 +2093,10 @@ KERNEL_INFOS.extend(
                 ),
             ],
         ),
+        KernelInfo(
+            F.convert_dtype_video,
+            sample_inputs_fn=sample_inputs_convert_dtype_video,
+            test_marks=_common_convert_dtype_marks,
+        ),
     ]
 )
@@ -92,7 +92,7 @@ class TestSmoke:
             transforms.RandomErasing(p=1.0),
             transforms.Resize([16, 16]),
             transforms.CenterCrop([16, 16]),
-            transforms.ConvertImageDtype(),
+            transforms.ConvertDtype(),
            transforms.RandomHorizontalFlip(),
             transforms.Pad(5),
             transforms.RandomZoomOut(),
@@ -153,7 +153,7 @@ CONSISTENCY_CONFIGS = [
         ),
     ),
     ConsistencyConfig(
-        prototype_transforms.ConvertImageDtype,
+        prototype_transforms.ConvertDtype,
         legacy_transforms.ConvertImageDtype,
         [
             ArgsKwargs(torch.float16),
@@ -307,6 +307,7 @@ class TestDispatchers:
             (F.get_image_num_channels, F.get_num_channels),
             (F.to_pil_image, F.to_image_pil),
             (F.elastic_transform, F.elastic),
+            (F.convert_image_dtype, F.convert_dtype_image_tensor),
         ]
     ],
 )
@@ -39,7 +39,7 @@ from ._geometry import (
     ScaleJitter,
     TenCrop,
 )
-from ._meta import ClampBoundingBoxes, ConvertBoundingBoxFormat, ConvertColorSpace, ConvertImageDtype
+from ._meta import ClampBoundingBoxes, ConvertBoundingBoxFormat, ConvertColorSpace, ConvertDtype, ConvertImageDtype
 from ._misc import (
     GaussianBlur,
     Identity,
@@ -25,7 +25,7 @@ class ConvertBoundingBoxFormat(Transform):
         return features.BoundingBox.wrap_like(inpt, output, format=params["format"])
 
 
-class ConvertImageDtype(Transform):
+class ConvertDtype(Transform):
     _transformed_types = (features.is_simple_tensor, features.Image, features.Video)
 
     def __init__(self, dtype: torch.dtype = torch.float32) -> None:
@@ -35,12 +35,12 @@ class ConvertImageDtype(Transform):
     def _transform(
         self, inpt: Union[features.TensorImageType, features.TensorVideoType], params: Dict[str, Any]
     ) -> Union[features.TensorImageType, features.TensorVideoType]:
-        # TODO: the `inpt.as_subclass(torch.Tensor)` call can be removed as soon as we have a proper dispatcher that
-        # handles this. See https://github.com/pytorch/vision/pull/6783 for details.
-        output = F.convert_image_dtype(inpt.as_subclass(torch.Tensor), dtype=self.dtype)
-        return (
-            output if features.is_simple_tensor(inpt) else type(inpt).wrap_like(inpt, output)  # type: ignore[attr-defined]
-        )
+        return F.convert_dtype(inpt, self.dtype)
+
+
+# We changed the name to align it with the new naming scheme. Still, `ConvertImageDtype` is
+# prevalent and well understood. Thus, we just alias it without deprecating the old name.
+ConvertImageDtype = ConvertDtype
 
 
 class ConvertColorSpace(Transform):
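With the dispatcher in place, the transform body shrinks to a single call: `convert_dtype` itself unwraps `Image` and `Video` subclasses, runs the tensor kernel, and re-wraps the result, which is exactly what the deleted TODO asked for. A quick usage sketch of the renamed transform (the input construction and shape are illustrative):

```python
import torch
from torchvision.prototype import features, transforms

video = features.Video(torch.randint(0, 256, (4, 3, 16, 16), dtype=torch.uint8))

transform = transforms.ConvertDtype(torch.float32)  # `ConvertImageDtype` is the same class
out = transform(video)
assert isinstance(out, features.Video) and out.dtype == torch.float32
```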
@@ -8,6 +8,10 @@ from ._meta import (
     convert_color_space_image_pil,
     convert_color_space_video,
     convert_color_space,
+    convert_dtype_image_tensor,
+    convert_dtype,
+    convert_dtype_video,
+    convert_image_dtype,
     get_dimensions_image_tensor,
     get_dimensions_image_pil,
     get_dimensions,
@@ -162,7 +166,6 @@ from ._misc import (
     normalize_video,
 )
 from ._type_conversion import (
-    convert_image_dtype,
     decode_image_with_pil,
     decode_video_with_av,
     pil_to_tensor,
@@ -285,3 +285,99 @@ def convert_color_space(
         return features.Video.wrap_like(inpt, output, color_space=color_space)
     else:
         return convert_color_space_image_pil(inpt, color_space)
+
+
+def _num_value_bits(dtype: torch.dtype) -> int:
+    if dtype == torch.uint8:
+        return 8
+    elif dtype == torch.int8:
+        return 7
+    elif dtype == torch.int16:
+        return 15
+    elif dtype == torch.int32:
+        return 31
+    elif dtype == torch.int64:
+        return 63
+    else:
+        raise TypeError(f"Number of value bits is only defined for integer dtypes, but got {dtype}.")
+
+
+def convert_dtype_image_tensor(image: torch.Tensor, dtype: torch.dtype = torch.float) -> torch.Tensor:
+    if image.dtype == dtype:
+        return image
+
+    float_input = image.is_floating_point()
+    if torch.jit.is_scripting():
+        # TODO: remove this branch as soon as `dtype.is_floating_point` is supported by JIT
+        float_output = torch.tensor(0, dtype=dtype).is_floating_point()
+    else:
+        float_output = dtype.is_floating_point
+
+    if float_input:
+        # float to float
+        if float_output:
+            return image.to(dtype)
+
+        # float to int
+        if (image.dtype == torch.float32 and dtype in (torch.int32, torch.int64)) or (
+            image.dtype == torch.float64 and dtype == torch.int64
+        ):
+            raise RuntimeError(f"The conversion from {image.dtype} to {dtype} cannot be performed safely.")
+
+        # For data in the range `[0.0, 1.0]`, just multiplying by the maximum value of the integer range and
+        # converting to the integer dtype is not sufficient. For example,
+        # `torch.rand(...).mul(255).to(torch.uint8)` will only be `255` if the input is exactly `1.0`. See
+        # https://github.com/pytorch/vision/pull/2078#issuecomment-612045321 for a detailed analysis.
+        # To mitigate this, we could round before we convert to the integer dtype, but this is an extra operation.
+        # Instead, we can also multiply by the maximum value plus something close to `1`. See
+        # https://github.com/pytorch/vision/pull/2078#issuecomment-613524965 for details.
+        eps = 1e-3
+        max_value = float(_FT._max_value(dtype))
+        # We need to scale first since the conversion would otherwise turn the input range `[0.0, 1.0]` into the
+        # discrete set `{0, 1}`.
+        return image.mul(max_value + 1.0 - eps).to(dtype)
+    else:
+        # int to float
+        if float_output:
+            return image.to(dtype).div_(_FT._max_value(image.dtype))
+
+        # int to int
+        num_value_bits_input = _num_value_bits(image.dtype)
+        num_value_bits_output = _num_value_bits(dtype)
+
+        if num_value_bits_input > num_value_bits_output:
+            return image.bitwise_right_shift(num_value_bits_input - num_value_bits_output).to(dtype)
+        else:
+            # The bitshift kernel is not vectorized
+            # https://github.com/pytorch/pytorch/blob/703c19008df4700b6a522b0ae5c4b6d5ffc0906f/aten/src/ATen/native/cpu/BinaryOpsKernel.cpp#L315-L322
+            # This results in the multiplication actually being faster.
+            # TODO: If the bitshift kernel is optimized in core, replace the computation below with
+            # `image.to(dtype).bitwise_left_shift_(num_value_bits_output - num_value_bits_input)`
+            max_value_output = float(_FT._max_value(dtype))
+            max_value_input = float(_FT._max_value(image.dtype))
+            factor = int((max_value_output + 1) // (max_value_input + 1))
+            return image.to(dtype).mul_(factor)
+
+
+# We changed the name to align it with the new naming scheme. Still, `convert_image_dtype` is
+# prevalent and well understood. Thus, we just alias it without deprecating the old name.
+convert_image_dtype = convert_dtype_image_tensor
+
+
+def convert_dtype_video(video: torch.Tensor, dtype: torch.dtype = torch.float) -> torch.Tensor:
+    return convert_dtype_image_tensor(video, dtype)
+
+
+def convert_dtype(
+    inpt: Union[features.ImageTypeJIT, features.VideoTypeJIT], dtype: torch.dtype = torch.float
+) -> torch.Tensor:
+    if isinstance(inpt, torch.Tensor) and (
+        torch.jit.is_scripting() or not isinstance(inpt, (features.Image, features.Video))
+    ):
+        return convert_dtype_image_tensor(inpt, dtype)
+    elif isinstance(inpt, features.Image):
+        output = convert_dtype_image_tensor(inpt.as_subclass(torch.Tensor), dtype)
+        return features.Image.wrap_like(inpt, output)
+    else:  # isinstance(inpt, features.Video):
+        output = convert_dtype_video(inpt.as_subclass(torch.Tensor), dtype)
+        return features.Video.wrap_like(inpt, output)
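The two scaling tricks in the kernel above can be sanity-checked with concrete numbers. A standalone sketch (not part of the commit; it re-derives the constants instead of importing the private `_max_value`/`_num_value_bits` helpers):

```python
import torch

# float -> int: plain `mul(255)` only reaches 255 for an input of exactly 1.0,
# while `mul(256 - eps)` maps the whole top bucket of [0.0, 1.0] onto 255.
x = torch.tensor([0.0, 0.5, 0.999, 1.0])
print(x.mul(255.0).to(torch.uint8))         # tensor([  0, 127, 254, 255], dtype=torch.uint8)
print(x.mul(256.0 - 1e-3).to(torch.uint8))  # tensor([  0, 127, 255, 255], dtype=torch.uint8)

# int -> int upscale: multiplying by the ratio of the value ranges equals a
# left bitshift by the difference in value bits, e.g. uint8 (8) -> int16 (15).
factor = (32767 + 1) // (255 + 1)  # 128, i.e. 1 << (15 - 8)
y = torch.tensor([0, 1, 255], dtype=torch.uint8).to(torch.int16)
assert torch.equal(y.mul(factor), y.bitwise_left_shift(15 - 8))
```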
@@ -7,7 +7,7 @@ import torch
 from torchvision.io.video import read_video
 from torchvision.prototype import features
 from torchvision.prototype.utils._internal import ReadOnlyTensorBuffer
-from torchvision.transforms import functional as _F, functional_tensor as _FT
+from torchvision.transforms import functional as _F
 
 
 @torch.jit.unused
@@ -41,78 +41,3 @@ pil_to_tensor = _F.pil_to_tensor
 # We changed the names to align them with the new naming scheme. Still, `to_pil_image` is
 # prevalent and well understood. Thus, we just alias it without deprecating the old name.
 to_pil_image = to_image_pil
-
-
-def _num_value_bits(dtype: torch.dtype) -> int:
-    if dtype == torch.uint8:
-        return 8
-    elif dtype == torch.int8:
-        return 7
-    elif dtype == torch.int16:
-        return 15
-    elif dtype == torch.int32:
-        return 31
-    elif dtype == torch.int64:
-        return 63
-    else:
-        raise TypeError(f"Number of value bits is only defined for integer dtypes, but got {dtype}.")
-
-
-def convert_image_dtype(image: torch.Tensor, dtype: torch.dtype = torch.float) -> torch.Tensor:
-    if not isinstance(image, torch.Tensor):
-        raise TypeError("Input img should be Tensor Image")
-
-    if image.dtype == dtype:
-        return image
-
-    float_input = image.is_floating_point()
-    if torch.jit.is_scripting():
-        # TODO: remove this branch as soon as `dtype.is_floating_point` is supported by JIT
-        float_output = torch.tensor(0, dtype=dtype).is_floating_point()
-    else:
-        float_output = dtype.is_floating_point
-
-    if float_input:
-        # float to float
-        if float_output:
-            return image.to(dtype)
-
-        # float to int
-        if (image.dtype == torch.float32 and dtype in (torch.int32, torch.int64)) or (
-            image.dtype == torch.float64 and dtype == torch.int64
-        ):
-            raise RuntimeError(f"The conversion from {image.dtype} to {dtype} cannot be performed safely.")
-
-        # For data in the range `[0.0, 1.0]`, just multiplying by the maximum value of the integer range and
-        # converting to the integer dtype is not sufficient. For example,
-        # `torch.rand(...).mul(255).to(torch.uint8)` will only be `255` if the input is exactly `1.0`. See
-        # https://github.com/pytorch/vision/pull/2078#issuecomment-612045321 for a detailed analysis.
-        # To mitigate this, we could round before we convert to the integer dtype, but this is an extra operation.
-        # Instead, we can also multiply by the maximum value plus something close to `1`. See
-        # https://github.com/pytorch/vision/pull/2078#issuecomment-613524965 for details.
-        eps = 1e-3
-        max_value = float(_FT._max_value(dtype))
-        # We need to scale first since the conversion would otherwise turn the input range `[0.0, 1.0]` into the
-        # discrete set `{0, 1}`.
-        return image.mul(max_value + 1.0 - eps).to(dtype)
-    else:
-        # int to float
-        if float_output:
-            return image.to(dtype).div_(_FT._max_value(image.dtype))
-
-        # int to int
-        num_value_bits_input = _num_value_bits(image.dtype)
-        num_value_bits_output = _num_value_bits(dtype)
-
-        if num_value_bits_input > num_value_bits_output:
-            return image.bitwise_right_shift(num_value_bits_input - num_value_bits_output).to(dtype)
-        else:
-            # The bitshift kernel is not vectorized
-            # https://github.com/pytorch/pytorch/blob/703c19008df4700b6a522b0ae5c4b6d5ffc0906f/aten/src/ATen/native/cpu/BinaryOpsKernel.cpp#L315-L322
-            # This results in the multiplication actually being faster.
-            # TODO: If the bitshift kernel is optimized in core, replace the computation below with
-            # `image.to(dtype).bitwise_left_shift_(num_value_bits_output - num_value_bits_input)`
-            max_value_input = float(_FT._max_value(dtype))
-            max_value_output = float(_FT._max_value(image.dtype))
-            factor = int((max_value_input + 1) // (max_value_output + 1))
-            return image.to(dtype).mul_(factor)