Unverified Commit d27cb7c0 authored by Vasilis Vryniotis's avatar Vasilis Vryniotis Committed by GitHub
Browse files

Ensure input type of normalize is float. (#3621)

parent 226126b8
...@@ -446,12 +446,15 @@ class Tester(TransformsTester): ...@@ -446,12 +446,15 @@ class Tester(TransformsTester):
) )
def test_normalize(self): def test_normalize(self):
fn = T.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
tensor, _ = self._create_data(26, 34, device=self.device) tensor, _ = self._create_data(26, 34, device=self.device)
batch_tensors = torch.rand(4, 3, 44, 56, device=self.device)
with self.assertRaisesRegex(TypeError, r"Input tensor should be a float tensor"):
fn(tensor)
batch_tensors = torch.rand(4, 3, 44, 56, device=self.device)
tensor = tensor.to(dtype=torch.float32) / 255.0 tensor = tensor.to(dtype=torch.float32) / 255.0
# test for class interface # test for class interface
fn = T.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
scripted_fn = torch.jit.script(fn) scripted_fn = torch.jit.script(fn)
self._test_transform_vs_scripted(fn, scripted_fn, tensor) self._test_transform_vs_scripted(fn, scripted_fn, tensor)
......
...@@ -297,7 +297,7 @@ def to_pil_image(pic, mode=None): ...@@ -297,7 +297,7 @@ def to_pil_image(pic, mode=None):
def normalize(tensor: Tensor, mean: List[float], std: List[float], inplace: bool = False) -> Tensor: def normalize(tensor: Tensor, mean: List[float], std: List[float], inplace: bool = False) -> Tensor:
"""Normalize a tensor image with mean and standard deviation. """Normalize a float tensor image with mean and standard deviation.
This transform does not support PIL Image. This transform does not support PIL Image.
.. note:: .. note::
...@@ -306,7 +306,7 @@ def normalize(tensor: Tensor, mean: List[float], std: List[float], inplace: bool ...@@ -306,7 +306,7 @@ def normalize(tensor: Tensor, mean: List[float], std: List[float], inplace: bool
See :class:`~torchvision.transforms.Normalize` for more details. See :class:`~torchvision.transforms.Normalize` for more details.
Args: Args:
tensor (Tensor): Tensor image of size (C, H, W) or (B, C, H, W) to be normalized. tensor (Tensor): Float tensor image of size (C, H, W) or (B, C, H, W) to be normalized.
mean (sequence): Sequence of means for each channel. mean (sequence): Sequence of means for each channel.
std (sequence): Sequence of standard deviations for each channel. std (sequence): Sequence of standard deviations for each channel.
inplace(bool,optional): Bool to make this operation inplace. inplace(bool,optional): Bool to make this operation inplace.
...@@ -317,6 +317,9 @@ def normalize(tensor: Tensor, mean: List[float], std: List[float], inplace: bool ...@@ -317,6 +317,9 @@ def normalize(tensor: Tensor, mean: List[float], std: List[float], inplace: bool
if not isinstance(tensor, torch.Tensor): if not isinstance(tensor, torch.Tensor):
raise TypeError('Input tensor should be a torch tensor. Got {}.'.format(type(tensor))) raise TypeError('Input tensor should be a torch tensor. Got {}.'.format(type(tensor)))
if not tensor.is_floating_point():
raise TypeError('Input tensor should be a float tensor. Got {}.'.format(tensor.dtype))
if tensor.ndim < 3: if tensor.ndim < 3:
raise ValueError('Expected tensor to be a tensor image of size (..., C, H, W). Got tensor.size() = ' raise ValueError('Expected tensor to be a tensor image of size (..., C, H, W). Got tensor.size() = '
'{}.'.format(tensor.size())) '{}.'.format(tensor.size()))
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment