Unverified Commit ed115b34 authored by amyeroberts's avatar amyeroberts Committed by GitHub
Browse files

Normalize floating point cast (#27249)

* Normalize image - cast input images to float32.

This is done if the input image isn't of floating type. Issues can occur when do_rescale=False is set in an image processor. When this happens, the image passed to the call is of type uint8 becuase of the type casting that happens in resize because of the PIL image library. As the mean and std values are cast to match the image dtype, this can cause NaNs and infs to appear in the normalized image, as the floating values being used to divide the image are now set to 0.

The reason the mean and std values are cast is because previously they were set as float32 by default. However, if the input image was of type float16, the normalization would result in the image being upcast to float32 too.

* Add tests

* Remove float32 cast
parent e1c3ac25
...@@ -376,6 +376,11 @@ def normalize( ...@@ -376,6 +376,11 @@ def normalize(
channel_axis = get_channel_dimension_axis(image, input_data_format=input_data_format) channel_axis = get_channel_dimension_axis(image, input_data_format=input_data_format)
num_channels = image.shape[channel_axis] num_channels = image.shape[channel_axis]
# We cast to float32 to avoid errors that can occur when subtracting uint8 values.
# We preserve the original dtype if it is a float type to prevent upcasting float16.
if not np.issubdtype(image.dtype, np.floating):
image = image.astype(np.float32)
if isinstance(mean, Iterable): if isinstance(mean, Iterable):
if len(mean) != num_channels: if len(mean) != num_channels:
raise ValueError(f"mean must have {num_channels} elements if it is an iterable, got {len(mean)}") raise ValueError(f"mean must have {num_channels} elements if it is an iterable, got {len(mean)}")
......
...@@ -302,7 +302,7 @@ class ImageTransformsTester(unittest.TestCase): ...@@ -302,7 +302,7 @@ class ImageTransformsTester(unittest.TestCase):
normalized_image = normalize(image, mean=mean, std=std, data_format="channels_first") normalized_image = normalize(image, mean=mean, std=std, data_format="channels_first")
self.assertIsInstance(normalized_image, np.ndarray) self.assertIsInstance(normalized_image, np.ndarray)
self.assertEqual(normalized_image.shape, (3, 224, 224)) self.assertEqual(normalized_image.shape, (3, 224, 224))
self.assertTrue(np.allclose(normalized_image, expected_image)) self.assertTrue(np.allclose(normalized_image, expected_image, atol=1e-6))
# Test image with 4 channels is normalized correctly # Test image with 4 channels is normalized correctly
image = np.random.randint(0, 256, (224, 224, 4)) / 255 image = np.random.randint(0, 256, (224, 224, 4)) / 255
...@@ -310,9 +310,42 @@ class ImageTransformsTester(unittest.TestCase): ...@@ -310,9 +310,42 @@ class ImageTransformsTester(unittest.TestCase):
std = (0.1, 0.2, 0.3, 0.4) std = (0.1, 0.2, 0.3, 0.4)
expected_image = (image - mean) / std expected_image = (image - mean) / std
self.assertTrue( self.assertTrue(
np.allclose(normalize(image, mean=mean, std=std, input_data_format="channels_last"), expected_image) np.allclose(
normalize(image, mean=mean, std=std, input_data_format="channels_last"), expected_image, atol=1e-6
)
) )
# Test float32 image input keeps float32 dtype
image = np.random.randint(0, 256, (224, 224, 3)).astype(np.float32) / 255
mean = (0.5, 0.6, 0.7)
std = (0.1, 0.2, 0.3)
expected_image = ((image - mean) / std).astype(np.float32)
normalized_image = normalize(image, mean=mean, std=std)
self.assertEqual(normalized_image.dtype, np.float32)
self.assertTrue(np.allclose(normalized_image, expected_image, atol=1e-6))
# Test float16 image input keeps float16 dtype
image = np.random.randint(0, 256, (224, 224, 3)).astype(np.float16) / 255
mean = (0.5, 0.6, 0.7)
std = (0.1, 0.2, 0.3)
# The mean and std are cast to match the dtype of the input image
cast_mean = np.array(mean, dtype=np.float16)
cast_std = np.array(std, dtype=np.float16)
expected_image = (image - cast_mean) / cast_std
normalized_image = normalize(image, mean=mean, std=std)
self.assertEqual(normalized_image.dtype, np.float16)
self.assertTrue(np.allclose(normalized_image, expected_image, atol=1e-6))
# Test int image input is converted to float32
image = np.random.randint(0, 2, (224, 224, 3), dtype=np.uint8)
mean = (0.5, 0.6, 0.7)
std = (0.1, 0.2, 0.3)
expected_image = (image.astype(np.float32) - mean) / std
normalized_image = normalize(image, mean=mean, std=std)
self.assertEqual(normalized_image.dtype, np.float32)
self.assertTrue(np.allclose(normalized_image, expected_image, atol=1e-6))
def test_center_crop(self): def test_center_crop(self):
image = np.random.randint(0, 256, (3, 224, 224)) image = np.random.randint(0, 256, (3, 224, 224))
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment