Commit ae2cb6ec authored by Geovanni Zhang's avatar Geovanni Zhang Committed by Francisco Massa
Browse files

fix:error message of to_tensor (#1000)

* fix:error message of to_tensor

The error "pic should be PIL Image or ndarray. Got '<numpy.ndarray>'" is confusing.

* fix:a clearer function name

_is_numpy_image is clearer than _is_numpy_image_dim

* fix:add a test case

Add a test case in test/test_transforms.py to test the error message

* fix:pass ci check

* fix:wrong random matrix
parent 174e135d
...@@ -434,6 +434,13 @@ class Tester(unittest.TestCase): ...@@ -434,6 +434,13 @@ class Tester(unittest.TestCase):
height, width = 4, 4 height, width = 4, 4
trans = transforms.ToTensor() trans = transforms.ToTensor()
with self.assertRaises(TypeError):
trans(np.random.rand(1, height, width).tolist())
with self.assertRaises(ValueError):
trans(np.random.rand(height))
trans(np.random.rand(1, 1, height, width))
for channels in test_channels: for channels in test_channels:
input_data = torch.ByteTensor(channels, height, width).random_(0, 255).float().div_(255) input_data = torch.ByteTensor(channels, height, width).random_(0, 255).float().div_(255)
img = transforms.ToPILImage()(input_data) img = transforms.ToPILImage()(input_data)
......
...@@ -31,8 +31,12 @@ def _is_tensor_image(img): ...@@ -31,8 +31,12 @@ def _is_tensor_image(img):
return torch.is_tensor(img) and img.ndimension() == 3 return torch.is_tensor(img) and img.ndimension() == 3
def _is_numpy(img):
return isinstance(img, np.ndarray)
def _is_numpy_image(img): def _is_numpy_image(img):
return isinstance(img, np.ndarray) and (img.ndim in {2, 3}) return img.ndim in {2, 3}
def to_tensor(pic): def to_tensor(pic):
...@@ -46,9 +50,12 @@ def to_tensor(pic): ...@@ -46,9 +50,12 @@ def to_tensor(pic):
Returns: Returns:
Tensor: Converted image. Tensor: Converted image.
""" """
if not(_is_pil_image(pic) or _is_numpy_image(pic)): if not(_is_pil_image(pic) or _is_numpy(pic)):
raise TypeError('pic should be PIL Image or ndarray. Got {}'.format(type(pic))) raise TypeError('pic should be PIL Image or ndarray. Got {}'.format(type(pic)))
if _is_numpy(pic) and not _is_numpy_image(pic):
raise ValueError('pic should be 2/3 dimensional. Got {} dimensions.'.format(pic.ndim))
if isinstance(pic, np.ndarray): if isinstance(pic, np.ndarray):
# handle numpy array # handle numpy array
if pic.ndim == 2: if pic.ndim == 2:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment