Unverified Commit 211563fb authored by Philip Meier's avatar Philip Meier Committed by GitHub
Browse files

improve perf on convert_image_dtype and add tests (#6795)

* improve perf on convert_image_dtype and add tests

* add reference tests

* use bitshifts for int to int

* revert bitshifts for int to int upscale

* fix warning ignore
parent 7a62a545
import decimal
import functools
import itertools
import math
......@@ -21,6 +22,7 @@ from prototype_common_utils import (
mark_framework_limitation,
TestMark,
)
from torch.utils._pytree import tree_map
from torchvision.prototype import features
from torchvision.transforms.functional_tensor import _max_value as get_max_value
......@@ -1947,3 +1949,119 @@ KERNEL_INFOS.extend(
),
]
)
def sample_inputs_convert_image_dtype():
    """Yield sample inputs for the ``convert_image_dtype`` kernel.

    Every ordered pair of dtypes drawn from ``uint8``, ``int64``, ``float32`` and ``float64`` is visited. Pairs
    going from a floating point dtype to ``int64`` are left out, since that conversion cannot be performed
    safely. For each remaining pair, one ``ArgsKwargs`` is produced per RGB image loader of the input dtype.
    """
    candidate_dtypes = [torch.uint8, torch.int64, torch.float32, torch.float64]
    for src_dtype, dst_dtype in itertools.product(candidate_dtypes, repeat=2):
        # conversion cannot be performed safely
        is_unsafe = src_dtype.is_floating_point and dst_dtype == torch.int64
        if not is_unsafe:
            loaders = make_image_loaders(
                sizes=["random"], color_spaces=[features.ColorSpace.RGB], dtypes=[src_dtype]
            )
            for loader in loaders:
                yield ArgsKwargs(loader, dtype=dst_dtype)

    # NOTE(review): one extra sample built without an explicit input dtype. It is assumed to be emitted once,
    # after the loop over the dtype pairs -- confirm against the upstream file.
    yield ArgsKwargs(make_image_loader(color_space=features.ColorSpace.RGB), dtype=torch.uint8)
def reference_convert_image_dtype(image, dtype=torch.float):
    """Reference implementation of ``convert_image_dtype`` that converts one Python scalar at a time.

    Args:
        image: Tensor to convert.
        dtype: Target dtype.

    Returns:
        ``image`` itself if it already has the target dtype. Otherwise a new tensor of dtype ``dtype`` built
        from the element-wise converted values of ``image.tolist()``.
    """
    src_dtype = image.dtype
    dst_dtype = dtype

    if src_dtype == dst_dtype:
        return image

    def convert_scalar(value):
        # float -> float: the value is passed through unchanged
        if src_dtype.is_floating_point and dst_dtype.is_floating_point:
            return value

        # float -> int: scale by the maximum of the target dtype and truncate
        if src_dtype.is_floating_point:
            return int(decimal.Decimal(value) * torch.iinfo(dst_dtype).max)

        src_max = torch.iinfo(src_dtype).max

        # int -> float: divide by the maximum of the source dtype
        if dst_dtype.is_floating_point:
            return float(decimal.Decimal(value) / src_max)

        # int -> int: rescale by the ratio of the two value ranges
        dst_max = torch.iinfo(dst_dtype).max
        if src_max > dst_max:
            return value // ((src_max + 1) // (dst_max + 1))
        return value * ((dst_max + 1) // (src_max + 1))

    return torch.tensor(tree_map(convert_scalar, image.tolist()), dtype=dtype)
def reference_inputs_convert_image_dtype():
    """Yield small hand-picked inputs for comparing ``convert_image_dtype`` against its reference.

    For every ordered pair of the listed dtypes, a three element tensor is produced that holds the minimum,
    middle and maximum of the input range: ``[0.0, 0.5, 1.0]`` for floating point inputs and
    ``[0, max // 2, max]`` for integer inputs. The pairs ``float32 -> int32``, ``float32 -> int64`` and
    ``float64 -> int64`` are skipped.
    """
    all_dtypes = [
        torch.uint8,
        torch.int16,
        torch.int32,
        torch.int64,
        torch.float16,
        torch.float32,
        torch.float64,
        torch.bfloat16,
    ]
    skipped_conversions = {
        (torch.float32, torch.int32),
        (torch.float32, torch.int64),
        (torch.float64, torch.int64),
    }

    for src_dtype, dst_dtype in itertools.product(all_dtypes, repeat=2):
        if (src_dtype, dst_dtype) in skipped_conversions:
            continue

        if src_dtype.is_floating_point:
            values = [0.0, 0.5, 1.0]
        else:
            upper = torch.iinfo(src_dtype).max
            values = [0, upper // 2, upper]

        yield ArgsKwargs(torch.tensor(values, dtype=src_dtype), dtype=dst_dtype)
def _convert_image_dtype_changes_dtype(args_kwargs):
    # True if the requested output dtype (default: float32) differs from the dtype of the input.
    requested_dtype = args_kwargs.kwargs.get("dtype", torch.float32)
    return args_kwargs.args[0].dtype != requested_dtype


def _convert_image_dtype_overflows(args_kwargs):
    # True for the conversions that are marked as overflowing: half precision floats to any integer dtype
    # (this includes int64) and int32 / int64 to float16.
    input_dtype = args_kwargs.args[0].dtype
    if input_dtype in {torch.float16, torch.bfloat16} and not args_kwargs.kwargs["dtype"].is_floating_point:
        return True
    return input_dtype in {torch.int32, torch.int64} and args_kwargs.kwargs["dtype"] == torch.float16


# Register the `convert_image_dtype` kernel together with its sample inputs, reference and test marks.
KERNEL_INFOS.extend(
    [
        KernelInfo(
            F.convert_image_dtype,
            sample_inputs_fn=sample_inputs_convert_image_dtype,
            reference_fn=reference_convert_image_dtype,
            reference_inputs_fn=reference_inputs_convert_image_dtype,
            test_marks=[
                TestMark(
                    ("TestKernels", "test_scripted_vs_eager"),
                    pytest.mark.filterwarnings(f"ignore:{re.escape('operator() profile_node %41')}:UserWarning"),
                ),
                TestMark(
                    ("TestKernels", "test_dtype_and_device_consistency"),
                    pytest.mark.skip(reason="`convert_dtype_*` kernels convert the dtype by design"),
                    condition=_convert_image_dtype_changes_dtype,
                ),
                TestMark(
                    ("TestKernels", "test_against_reference"),
                    pytest.mark.xfail(reason="Conversion overflows"),
                    condition=_convert_image_dtype_overflows,
                ),
            ],
        ),
    ]
)
......@@ -26,6 +26,20 @@ def script(fn):
raise AssertionError(f"Trying to `torch.jit.script` '{fn.__name__}' raised the error above.") from error
def make_info_args_kwargs_params(info, *, args_kwargs_fn, test_id=None):
    """Build one ``pytest.param`` per ``ArgsKwargs`` that ``args_kwargs_fn`` produces for ``info``.

    Args:
        info: Info object. Its ``id`` is used as prefix of the parameter ids and its ``get_marks`` method
            supplies the marks.
        args_kwargs_fn: Callable that takes ``info`` and returns an iterable of samples.
        test_id: Passed to ``info.get_marks``. If falsy, no marks are attached.

    Returns:
        List of ``pytest.param`` objects with ids of the form ``"{info.id}-{index}"``, where the index is
        zero-padded to the number of digits of the sample count.
    """
    samples = list(args_kwargs_fn(info))
    pad_width = len(str(len(samples)))

    params = []
    for position, sample in enumerate(samples):
        marks = info.get_marks(test_id, sample) if test_id else []
        params.append(
            pytest.param(
                info,
                sample,
                marks=marks,
                id=f"{info.id}-{position:0{pad_width}}",
            )
        )
    return params
def make_info_args_kwargs_parametrization(infos, *, args_kwargs_fn, condition=None):
if condition is None:
......@@ -49,18 +63,7 @@ def make_info_args_kwargs_parametrization(infos, *, args_kwargs_fn, condition=No
if not condition(info):
continue
args_kwargs = list(args_kwargs_fn(info))
idx_field_len = len(str(len(args_kwargs)))
for idx, args_kwargs_ in enumerate(args_kwargs):
argvalues.append(
pytest.param(
info,
args_kwargs_,
marks=info.get_marks(test_id, args_kwargs_),
id=f"{info.id}-{idx:0{idx_field_len}}",
)
)
argvalues.extend(make_info_args_kwargs_params(info, args_kwargs_fn=args_kwargs_fn, test_id=test_id))
return pytest.mark.parametrize(argnames, argvalues)(test_fn)
......@@ -232,7 +235,6 @@ class TestDispatchers:
[
F.clamp_bounding_box,
F.convert_color_space,
F.convert_image_dtype,
F.get_dimensions,
F.get_image_num_channels,
F.get_image_size,
......@@ -312,6 +314,24 @@ def test_alias(alias, target):
assert alias is target
@pytest.mark.parametrize(
    ("info", "args_kwargs"),
    make_info_args_kwargs_params(
        next(info for info in KERNEL_INFOS if info.kernel is F.convert_image_dtype),
        args_kwargs_fn=lambda info: info.sample_inputs_fn(),
    ),
)
@pytest.mark.parametrize("device", cpu_and_gpu())
def test_dtype_and_device_convert_image_dtype(info, args_kwargs, device):
    """Check that the kernel returns the requested dtype and keeps the device of its input."""
    (image, *positional_args), kwargs = args_kwargs.load(device)

    # The target dtype is either the first remaining positional argument or the `dtype` keyword argument.
    if positional_args:
        expected_dtype = positional_args[0]
    else:
        expected_dtype = kwargs.get("dtype", torch.float32)

    result = info.kernel(image, expected_dtype)

    assert result.dtype == expected_dtype
    assert result.device == image.device
# TODO: All correctness checks below this line should be ported to be references on a `KernelInfo` in
# `prototype_transforms_kernel_infos.py`
......
......@@ -7,7 +7,7 @@ import torch
from torchvision.io.video import read_video
from torchvision.prototype import features
from torchvision.prototype.utils._internal import ReadOnlyTensorBuffer
from torchvision.transforms import functional as _F
from torchvision.transforms import functional as _F, functional_tensor as _FT
@torch.jit.unused
......@@ -42,4 +42,77 @@ pil_to_tensor = _F.pil_to_tensor
# prevalent and well understood. Thus, we just alias it without deprecating the old name.
to_pil_image = to_image_pil
convert_image_dtype = _F.convert_image_dtype
def _num_value_bits(dtype: torch.dtype) -> int:
if dtype == torch.uint8:
return 8
elif dtype == torch.int8:
return 7
elif dtype == torch.int16:
return 15
elif dtype == torch.int32:
return 31
elif dtype == torch.int64:
return 63
else:
raise TypeError(f"Number of value bits is only defined for integer dtypes, but got {dtype}.")
def convert_image_dtype(image: torch.Tensor, dtype: torch.dtype = torch.float) -> torch.Tensor:
    """Convert ``image`` to ``dtype`` and rescale its values to the value range of the target dtype.

    Args:
        image: Tensor to convert. Floating point data is treated as lying in the range ``[0.0, 1.0]``.
        dtype: Target dtype.

    Returns:
        ``image`` itself if it already has the target dtype, otherwise a new tensor of dtype ``dtype``.

    Raises:
        TypeError: If ``image`` is not a tensor, or if an integer to integer conversion involves a dtype
            that ``_num_value_bits`` does not handle.
        RuntimeError: For ``float32 -> int32``, ``float32 -> int64`` and ``float64 -> int64``, which cannot
            be performed safely.
    """
    if not isinstance(image, torch.Tensor):
        raise TypeError("Input img should be Tensor Image")

    if image.dtype == dtype:
        return image

    float_input = image.is_floating_point()
    if torch.jit.is_scripting():
        # TODO: remove this branch as soon as `dtype.is_floating_point` is supported by JIT
        float_output = torch.tensor(0, dtype=dtype).is_floating_point()
    else:
        float_output = dtype.is_floating_point

    if float_input:
        # float to float
        if float_output:
            return image.to(dtype)

        # float to int
        if (image.dtype == torch.float32 and dtype in (torch.int32, torch.int64)) or (
            image.dtype == torch.float64 and dtype == torch.int64
        ):
            raise RuntimeError(f"The conversion from {image.dtype} to {dtype} cannot be performed safely.")

        # For data in the range `[0.0, 1.0]`, just multiplying by the maximum value of the integer range and converting
        # to the integer dtype is not sufficient. For example, `torch.rand(...).mul(255).to(torch.uint8)` will only
        # be `255` if the input is exactly `1.0`. See https://github.com/pytorch/vision/pull/2078#issuecomment-612045321
        # for a detailed analysis.
        # To mitigate this, we could round before we convert to the integer dtype, but this is an extra operation.
        # Instead, we can also multiply by the maximum value plus something close to `1`. See
        # https://github.com/pytorch/vision/pull/2078#issuecomment-613524965 for details.
        eps = 1e-3
        max_value = float(_FT._max_value(dtype))
        # We need to scale first since the conversion would otherwise turn the input range `[0.0, 1.0]` into the
        # discrete set `{0, 1}`.
        return image.mul(max_value + 1.0 - eps).to(dtype)
    else:
        # int to float
        if float_output:
            return image.to(dtype).div_(_FT._max_value(image.dtype))

        # int to int
        num_value_bits_input = _num_value_bits(image.dtype)
        num_value_bits_output = _num_value_bits(dtype)

        if num_value_bits_input > num_value_bits_output:
            # Downscale: drop the surplus low bits before narrowing the dtype.
            return image.bitwise_right_shift(num_value_bits_input - num_value_bits_output).to(dtype)
        else:
            # The bitshift kernel is not vectorized
            #  https://github.com/pytorch/pytorch/blob/703c19008df4700b6a522b0ae5c4b6d5ffc0906f/aten/src/ATen/native/cpu/BinaryOpsKernel.cpp#L315-L322
            # This results in the multiplication actually being faster.
            # TODO: If the bitshift kernel is optimized in core, replace the computation below with
            #  `image.to(dtype).bitwise_left_shift_(num_value_bits_output - num_value_bits_input)`
            # "input" refers to the dtype of `image`, "output" to the target `dtype`. The factor is the ratio
            # of the output range to the input range.
            max_value_input = float(_FT._max_value(image.dtype))
            max_value_output = float(_FT._max_value(dtype))
            factor = int((max_value_output + 1) // (max_value_input + 1))
            return image.to(dtype).mul_(factor)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment