You need to sign in or sign up before continuing.
Unverified Commit dab47572 authored by Francisco Massa's avatar Francisco Massa Committed by GitHub
Browse files

Fix potential overflow in convert_image_dtype (#3107)

We could have errors such as aten/src/ATen/native/cpu/PowKernel.cpp:41:5:  runtime error: 5.7896e+76 is outside the range of representable values of type 'float'
parent 0a75a0c1
...@@ -105,7 +105,6 @@ def convert_image_dtype(image: torch.Tensor, dtype: torch.dtype = torch.float) - ...@@ -105,7 +105,6 @@ def convert_image_dtype(image: torch.Tensor, dtype: torch.dtype = torch.float) -
return result.to(dtype) return result.to(dtype)
else: else:
input_max = _max_value(image.dtype) input_max = _max_value(image.dtype)
output_max = _max_value(dtype)
# int to float # int to float
# TODO: replace with dtype.is_floating_point when torchscript supports it # TODO: replace with dtype.is_floating_point when torchscript supports it
...@@ -113,6 +112,8 @@ def convert_image_dtype(image: torch.Tensor, dtype: torch.dtype = torch.float) - ...@@ -113,6 +112,8 @@ def convert_image_dtype(image: torch.Tensor, dtype: torch.dtype = torch.float) -
image = image.to(dtype) image = image.to(dtype)
return image / input_max return image / input_max
output_max = _max_value(dtype)
# int to int # int to int
if input_max > output_max: if input_max > output_max:
# factor should be forced to int for torch jit script # factor should be forced to int for torch jit script
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment