"clients/vscode:/vscode.git/clone" did not exist on "2c5df5d2affc442f30a21ee628bfee89a3666cb5"
Unverified Commit 2f433e0a authored by Phoenix Meadowlark's avatar Phoenix Meadowlark Committed by GitHub
Browse files

Improved error messages for transforms.functional.normalize(). (#1915)

* Improved error messages for transforms.functional.normalize().

Split the original TypeError into 1. a TypeError if `tensor` is not a torch.Tensor and 2. a ValueError if `tensor` does not have the correct dimensionality.

Added more detail to the error for when `tensor` has the wrong dimension to make it easier to diagnose. This is useful when this function isn't called directly by the user (e.g. when the user uses transforms.Normalize and can't directly see this functions doc string).

Deleted hanging function `_is_tensor_image()`. It isn't used in this file and isn't used internally anywhere else in torchvision that I can see. (Some users will have used it despite the underscore prefix, but a quick google search for "F._is_tensor_image" suggests this is rare).

* Value checking to prevent division by zero runtime crashes.

Added a ValueError to check for and avoid division by zero in `div_`. Not preventing the call leads to runtime crashes, at least in some environments.

* Fixed div by zero check for non-scalar inputs.
parent 2937d77d
...@@ -28,10 +28,6 @@ def _is_pil_image(img): ...@@ -28,10 +28,6 @@ def _is_pil_image(img):
return isinstance(img, Image.Image) return isinstance(img, Image.Image)
def _is_tensor_image(img):
return torch.is_tensor(img) and img.ndimension() == 3
def _is_numpy(img): def _is_numpy(img):
return isinstance(img, np.ndarray) return isinstance(img, np.ndarray)
...@@ -200,8 +196,12 @@ def normalize(tensor, mean, std, inplace=False): ...@@ -200,8 +196,12 @@ def normalize(tensor, mean, std, inplace=False):
Returns: Returns:
Tensor: Normalized Tensor image. Tensor: Normalized Tensor image.
""" """
if not _is_tensor_image(tensor): if not torch.is_tensor(tensor):
raise TypeError('tensor is not a torch image.') raise TypeError('tensor should be a torch tensor. Got {}.'.format(type(tensor)))
if tensor.ndimension() != 3:
raise ValueError('Expected tensor to be a tensor image of size (C, H, W). Got tensor.size() = '
'{}.'.format(tensor.size()))
if not inplace: if not inplace:
tensor = tensor.clone() tensor = tensor.clone()
...@@ -209,6 +209,8 @@ def normalize(tensor, mean, std, inplace=False): ...@@ -209,6 +209,8 @@ def normalize(tensor, mean, std, inplace=False):
dtype = tensor.dtype dtype = tensor.dtype
mean = torch.as_tensor(mean, dtype=dtype, device=tensor.device) mean = torch.as_tensor(mean, dtype=dtype, device=tensor.device)
std = torch.as_tensor(std, dtype=dtype, device=tensor.device) std = torch.as_tensor(std, dtype=dtype, device=tensor.device)
if (std == 0).any():
raise ValueError('std evaluated to zero after conversion to {}, leading to division by zero.'.format(dtype))
tensor.sub_(mean[:, None, None]).div_(std[:, None, None]) tensor.sub_(mean[:, None, None]).div_(std[:, None, None])
return tensor return tensor
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment