Commit ae2cb6ec authored by Geovanni Zhang's avatar Geovanni Zhang Committed by Francisco Massa
Browse files

fix:error message of to_tensor (#1000)

* fix:error message of to_tensor

The error "pic should be PIL Image or ndarray. Got '<numpy.ndarray>'" is confusing.

* fix:a clearer function name

_is_numpy_image is clearer than _is_numpy_image_dim

* fix:add a test case

Add a test case in test/test_transforms.py to test the error message

* fix:pass ci check

* fix:wrong random matrix
parent 174e135d
......@@ -434,6 +434,13 @@ class Tester(unittest.TestCase):
height, width = 4, 4
trans = transforms.ToTensor()
with self.assertRaises(TypeError):
trans(np.random.rand(1, height, width).tolist())
with self.assertRaises(ValueError):
trans(np.random.rand(height))
trans(np.random.rand(1, 1, height, width))
for channels in test_channels:
input_data = torch.ByteTensor(channels, height, width).random_(0, 255).float().div_(255)
img = transforms.ToPILImage()(input_data)
......
......@@ -31,8 +31,12 @@ def _is_tensor_image(img):
return torch.is_tensor(img) and img.ndimension() == 3
def _is_numpy(img):
return isinstance(img, np.ndarray)
def _is_numpy_image(img):
return isinstance(img, np.ndarray) and (img.ndim in {2, 3})
return img.ndim in {2, 3}
def to_tensor(pic):
......@@ -46,9 +50,12 @@ def to_tensor(pic):
Returns:
Tensor: Converted image.
"""
if not(_is_pil_image(pic) or _is_numpy_image(pic)):
if not(_is_pil_image(pic) or _is_numpy(pic)):
raise TypeError('pic should be PIL Image or ndarray. Got {}'.format(type(pic)))
if _is_numpy(pic) and not _is_numpy_image(pic):
raise ValueError('pic should be 2/3 dimensional. Got {} dimensions.'.format(pic.ndim))
if isinstance(pic, np.ndarray):
# handle numpy array
if pic.ndim == 2:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment