Unverified Commit 57c8de7f authored by vfdev's avatar vfdev Committed by GitHub
Browse files

Recoded _max_value method using a dictionary (#5566)

* Removed _max_value method and added a dictionary

Related to https://github.com/pytorch/vision/issues/5502

* Addressed failing tests and restored _max_value method

* Added xfailing test to switch quicker

* Switch to if/else impl
parent d8654bb0
......@@ -1486,6 +1486,15 @@ def test_max_value(dtype):
# self.assertGreater(F_t._max_value(dtype), torch.finfo(dtype).max)
@pytest.mark.xfail(
reason="torch.iinfo() is not supported by torchscript. See https://github.com/pytorch/pytorch/issues/41492."
)
def test_max_value_iinfo():
@torch.jit.script
def max_value(image: torch.Tensor) -> int:
return 1 if image.is_floating_point() else torch.iinfo(image.dtype).max
@pytest.mark.parametrize("should_vflip", [True, False])
@pytest.mark.parametrize("single_dim", [True, False])
def test_ten_crop(should_vflip, single_dim):
......
......@@ -44,22 +44,19 @@ def get_image_num_channels(img: Tensor) -> int:
raise TypeError(f"Input ndim should be 2 or more. Got {img.ndim}")
def _max_value(dtype: torch.dtype) -> float:
# TODO: replace this method with torch.iinfo when it gets torchscript support.
# https://github.com/pytorch/pytorch/issues/41492
a = torch.tensor(2, dtype=dtype)
signed = 1 if torch.tensor(0, dtype=dtype).is_signed() else 0
bits = 1
max_value = torch.tensor(-signed, dtype=torch.long)
while True:
next_value = a.pow(bits - signed).sub(1)
if next_value > max_value:
max_value = next_value
bits *= 2
else:
break
return max_value.item()
def _max_value(dtype: torch.dtype) -> int:
if dtype == torch.uint8:
return int(2 ** 8) - 1
elif dtype == torch.int8:
return int(2 ** 7) - 1
elif dtype == torch.int16:
return int(2 ** 15) - 1
elif dtype == torch.int32:
return int(2 ** 31) - 1
elif dtype == torch.int64:
return int(2 ** 63) - 1
else:
return 1
def _assert_channels(img: Tensor, permitted: List[int]) -> None:
......@@ -91,11 +88,11 @@ def convert_image_dtype(image: torch.Tensor, dtype: torch.dtype = torch.float) -
# `max + 1 - epsilon` provides more evenly distributed mapping of
# ranges of floats to ints.
eps = 1e-3
max_val = _max_value(dtype)
max_val = float(_max_value(dtype))
result = image.mul(max_val + 1.0 - eps)
return result.to(dtype)
else:
input_max = _max_value(image.dtype)
input_max = float(_max_value(image.dtype))
# int to float
# TODO: replace with dtype.is_floating_point when torchscript supports it
......@@ -103,7 +100,7 @@ def convert_image_dtype(image: torch.Tensor, dtype: torch.dtype = torch.float) -
image = image.to(dtype)
return image / input_max
output_max = _max_value(dtype)
output_max = float(_max_value(dtype))
# int to int
if input_max > output_max:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment