Unverified Commit 631ff912 authored by Alessio Falai's avatar Alessio Falai Committed by GitHub
Browse files

add float check in rcnn normalize (#3266)


Co-authored-by: default avatarVasilis Vryniotis <datumbox@users.noreply.github.com>
parent a2908d88
...@@ -58,6 +58,13 @@ class Tester(unittest.TestCase): ...@@ -58,6 +58,13 @@ class Tester(unittest.TestCase):
self.assertTrue(torch.equal(targets[0]['boxes'], targets_copy[0]['boxes'])) self.assertTrue(torch.equal(targets[0]['boxes'], targets_copy[0]['boxes']))
self.assertTrue(torch.equal(targets[1]['boxes'], targets_copy[1]['boxes'])) self.assertTrue(torch.equal(targets[1]['boxes'], targets_copy[1]['boxes']))
def test_not_float_normalize(self):
transform = GeneralizedRCNNTransform(300, 500, torch.zeros(3), torch.ones(3))
image = [torch.randint(0, 255, (3, 200, 300), dtype=torch.uint8)]
targets = [{'boxes': torch.rand(3, 4)}]
with self.assertRaises(TypeError):
out = transform(image, targets) # noqa: F841
if __name__ == '__main__': if __name__ == '__main__':
unittest.main() unittest.main()
...@@ -117,6 +117,11 @@ class GeneralizedRCNNTransform(nn.Module): ...@@ -117,6 +117,11 @@ class GeneralizedRCNNTransform(nn.Module):
return image_list, targets return image_list, targets
def normalize(self, image): def normalize(self, image):
if not image.is_floating_point():
raise TypeError(
f"Expected input images to be of floating type (in range [0, 1]), "
f"but found type {image.dtype} instead"
)
dtype, device = image.dtype, image.device dtype, device = image.dtype, image.device
mean = torch.as_tensor(self.image_mean, dtype=dtype, device=device) mean = torch.as_tensor(self.image_mean, dtype=dtype, device=device)
std = torch.as_tensor(self.image_std, dtype=dtype, device=device) std = torch.as_tensor(self.image_std, dtype=dtype, device=device)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment