Commit 2b2aa9c7 authored by Vishwak Srinivasan's avatar Vishwak Srinivasan Committed by Soumith Chintala
Browse files

Add descriptions for Transform objects (#380)

* add __repr__ for all transform objects

* add tests for __repr__
parent 9889de1d
...@@ -113,6 +113,11 @@ class Tester(unittest.TestCase): ...@@ -113,6 +113,11 @@ class Tester(unittest.TestCase):
img = to_pil_image(torch.FloatTensor(3, h, w).uniform_()) img = to_pil_image(torch.FloatTensor(3, h, w).uniform_())
results = transform(img) results = transform(img)
expected_output = five_crop(img) expected_output = five_crop(img)
# Checking if FiveCrop and TenCrop can be printed as string
transform.__repr__()
five_crop.__repr__()
if should_vflip: if should_vflip:
vflipped_img = img.transpose(Image.FLIP_TOP_BOTTOM) vflipped_img = img.transpose(Image.FLIP_TOP_BOTTOM)
expected_output += five_crop(vflipped_img) expected_output += five_crop(vflipped_img)
...@@ -226,6 +231,9 @@ class Tester(unittest.TestCase): ...@@ -226,6 +231,9 @@ class Tester(unittest.TestCase):
assert output.size[0] == width + padding[0] + padding[2] assert output.size[0] == width + padding[0] + padding[2]
assert output.size[1] == height + padding[1] + padding[3] assert output.size[1] == height + padding[1] + padding[3]
# Checking if Padding can be printed as string
transforms.Pad(padding).__repr__()
def test_pad_raises_with_invalid_pad_sequence_len(self): def test_pad_raises_with_invalid_pad_sequence_len(self):
with self.assertRaises(ValueError): with self.assertRaises(ValueError):
transforms.Pad(()) transforms.Pad(())
...@@ -247,6 +255,9 @@ class Tester(unittest.TestCase): ...@@ -247,6 +255,9 @@ class Tester(unittest.TestCase):
y = trans(x) y = trans(x)
assert (y.equal(x)) assert (y.equal(x))
# Checking if Lambda can be printed as string
trans.__repr__()
def test_to_tensor(self): def test_to_tensor(self):
test_channels = [1, 3, 4] test_channels = [1, 3, 4]
height, width = 4, 4 height, width = 4, 4
...@@ -280,6 +291,9 @@ class Tester(unittest.TestCase): ...@@ -280,6 +291,9 @@ class Tester(unittest.TestCase):
transforms.ToTensor(), transforms.ToTensor(),
]) ])
# Checking if Compose, Resize and ToTensor can be printed as string
trans.__repr__()
expected_output = trans(Image.open(GRACE_HOPPER).convert('RGB')) expected_output = trans(Image.open(GRACE_HOPPER).convert('RGB'))
output = trans(accimage.Image(GRACE_HOPPER)) output = trans(accimage.Image(GRACE_HOPPER))
...@@ -296,6 +310,9 @@ class Tester(unittest.TestCase): ...@@ -296,6 +310,9 @@ class Tester(unittest.TestCase):
transforms.ToTensor(), transforms.ToTensor(),
]) ])
# Checking if Compose, CenterCrop and ToTensor can be printed as string
trans.__repr__()
expected_output = trans(Image.open(GRACE_HOPPER).convert('RGB')) expected_output = trans(Image.open(GRACE_HOPPER).convert('RGB'))
output = trans(accimage.Image(GRACE_HOPPER)) output = trans(accimage.Image(GRACE_HOPPER))
...@@ -375,6 +392,9 @@ class Tester(unittest.TestCase): ...@@ -375,6 +392,9 @@ class Tester(unittest.TestCase):
for mode in [None, 'RGB', 'HSV', 'YCbCr']: for mode in [None, 'RGB', 'HSV', 'YCbCr']:
verify_img_data(img_data, mode) verify_img_data(img_data, mode)
# Checking if ToPILImage can be printed as string
transforms.ToPILImage().__repr__()
with self.assertRaises(ValueError): with self.assertRaises(ValueError):
# should raise if we try a mode for 4 or 1 channel images # should raise if we try a mode for 4 or 1 channel images
transforms.ToPILImage(mode='RGBA')(img_data) transforms.ToPILImage(mode='RGBA')(img_data)
...@@ -450,6 +470,9 @@ class Tester(unittest.TestCase): ...@@ -450,6 +470,9 @@ class Tester(unittest.TestCase):
random.setstate(random_state) random.setstate(random_state)
assert p_value > 0.0001 assert p_value > 0.0001
# Checking if RandomVerticalFlip can be printed as string
transforms.RandomVerticalFlip().__repr__()
@unittest.skipIf(stats is None, 'scipy.stats not available') @unittest.skipIf(stats is None, 'scipy.stats not available')
def test_random_horizontal_flip(self): def test_random_horizontal_flip(self):
random_state = random.getstate() random_state = random.getstate()
...@@ -468,6 +491,9 @@ class Tester(unittest.TestCase): ...@@ -468,6 +491,9 @@ class Tester(unittest.TestCase):
random.setstate(random_state) random.setstate(random_state)
assert p_value > 0.0001 assert p_value > 0.0001
# Checking if RandomHorizontalFlip can be printed as string
transforms.RandomHorizontalFlip().__repr__()
@unittest.skipIf(stats is None, 'scipt.stats is not available') @unittest.skipIf(stats is None, 'scipt.stats is not available')
def test_normalize(self): def test_normalize(self):
def samples_from_standard_normal(tensor): def samples_from_standard_normal(tensor):
...@@ -484,6 +510,9 @@ class Tester(unittest.TestCase): ...@@ -484,6 +510,9 @@ class Tester(unittest.TestCase):
assert samples_from_standard_normal(normalized) assert samples_from_standard_normal(normalized)
random.setstate(random_state) random.setstate(random_state)
# Checking if Normalize can be printed as string
transforms.Normalize(mean, std).__repr__()
def test_adjust_brightness(self): def test_adjust_brightness(self):
x_shape = [2, 2, 3] x_shape = [2, 2, 3]
x_data = [0, 5, 13, 54, 135, 226, 37, 8, 234, 90, 255, 1] x_data = [0, 5, 13, 54, 135, 226, 37, 8, 234, 90, 255, 1]
...@@ -645,6 +674,9 @@ class Tester(unittest.TestCase): ...@@ -645,6 +674,9 @@ class Tester(unittest.TestCase):
y_pil_2 = color_jitter(x_pil_2) y_pil_2 = color_jitter(x_pil_2)
assert y_pil_2.mode == x_pil_2.mode assert y_pil_2.mode == x_pil_2.mode
# Checking if ColorJitter can be printed as string
color_jitter.__repr__()
def test_linear_transformation(self): def test_linear_transformation(self):
x = torch.randn(250, 10, 10, 3) x = torch.randn(250, 10, 10, 3)
flat_x = x.view(x.size(0), x.size(1) * x.size(2) * x.size(3)) flat_x = x.view(x.size(0), x.size(1) * x.size(2) * x.size(3))
...@@ -664,6 +696,9 @@ class Tester(unittest.TestCase): ...@@ -664,6 +696,9 @@ class Tester(unittest.TestCase):
cov = np.dot(xwhite, xwhite.T) / x.size(0) cov = np.dot(xwhite, xwhite.T) / x.size(0)
assert np.allclose(cov, np.identity(1), rtol=1e-3) assert np.allclose(cov, np.identity(1), rtol=1e-3)
# Checking if LinearTransformation can be printed as string
whitening.__repr__()
def test_rotate(self): def test_rotate(self):
x = np.zeros((100, 100, 3), dtype=np.uint8) x = np.zeros((100, 100, 3), dtype=np.uint8)
x[40, 40] = [255, 255, 255] x[40, 40] = [255, 255, 255]
...@@ -714,6 +749,9 @@ class Tester(unittest.TestCase): ...@@ -714,6 +749,9 @@ class Tester(unittest.TestCase):
angle = t.get_params(t.degrees) angle = t.get_params(t.degrees)
assert angle > -10 and angle < 10 assert angle > -10 and angle < 10
# Checking if RandomRotation can be printed as string
t.__repr__()
def test_to_grayscale(self): def test_to_grayscale(self):
"""Unit tests for grayscale transform""" """Unit tests for grayscale transform"""
...@@ -761,6 +799,9 @@ class Tester(unittest.TestCase): ...@@ -761,6 +799,9 @@ class Tester(unittest.TestCase):
np.testing.assert_equal(gray_np_4[:, :, 1], gray_np_4[:, :, 2]) np.testing.assert_equal(gray_np_4[:, :, 1], gray_np_4[:, :, 2])
np.testing.assert_equal(gray_np, gray_np_4[:, :, 0]) np.testing.assert_equal(gray_np, gray_np_4[:, :, 0])
# Checking if Grayscale can be printed as string
trans4.__repr__()
@unittest.skipIf(stats is None, 'scipy.stats not available') @unittest.skipIf(stats is None, 'scipy.stats not available')
def test_random_grayscale(self): def test_random_grayscale(self):
"""Unit tests for random grayscale transform""" """Unit tests for random grayscale transform"""
...@@ -851,6 +892,9 @@ class Tester(unittest.TestCase): ...@@ -851,6 +892,9 @@ class Tester(unittest.TestCase):
assert gray_np_3.shape == tuple(x_shape[0:2]), 'should be 1 channel' assert gray_np_3.shape == tuple(x_shape[0:2]), 'should be 1 channel'
np.testing.assert_equal(gray_np, gray_np_3) np.testing.assert_equal(gray_np, gray_np_3)
# Checking if RandomGrayscale can be printed as string
trans3.__repr__()
if __name__ == '__main__': if __name__ == '__main__':
unittest.main() unittest.main()
...@@ -42,6 +42,14 @@ class Compose(object): ...@@ -42,6 +42,14 @@ class Compose(object):
img = t(img) img = t(img)
return img return img
def __repr__(self):
format_string = self.__class__.__name__ + '('
for t in self.transforms:
format_string += '\n'
format_string += ' {0}'.format(t)
format_string += '\n)'
return format_string
class ToTensor(object): class ToTensor(object):
"""Convert a ``PIL Image`` or ``numpy.ndarray`` to tensor. """Convert a ``PIL Image`` or ``numpy.ndarray`` to tensor.
...@@ -60,6 +68,9 @@ class ToTensor(object): ...@@ -60,6 +68,9 @@ class ToTensor(object):
""" """
return F.to_tensor(pic) return F.to_tensor(pic)
def __repr__(self):
return self.__class__.__name__ + '()'
class ToPILImage(object): class ToPILImage(object):
"""Convert a tensor or an ndarray to PIL Image. """Convert a tensor or an ndarray to PIL Image.
...@@ -91,6 +102,9 @@ class ToPILImage(object): ...@@ -91,6 +102,9 @@ class ToPILImage(object):
""" """
return F.to_pil_image(pic, self.mode) return F.to_pil_image(pic, self.mode)
def __repr__(self):
return self.__class__.__name__ + '({0})'.format(self.mode)
class Normalize(object): class Normalize(object):
"""Normalize an tensor image with mean and standard deviation. """Normalize an tensor image with mean and standard deviation.
...@@ -117,6 +131,9 @@ class Normalize(object): ...@@ -117,6 +131,9 @@ class Normalize(object):
""" """
return F.normalize(tensor, self.mean, self.std) return F.normalize(tensor, self.mean, self.std)
def __repr__(self):
return self.__class__.__name__ + '(mean={0}, std={1})'.format(self.mean, self.std)
class Resize(object): class Resize(object):
"""Resize the input PIL Image to the given size. """Resize the input PIL Image to the given size.
...@@ -146,6 +163,9 @@ class Resize(object): ...@@ -146,6 +163,9 @@ class Resize(object):
""" """
return F.resize(img, self.size, self.interpolation) return F.resize(img, self.size, self.interpolation)
def __repr__(self):
return self.__class__.__name__ + '(size={0})'.format(self.size)
class Scale(Resize): class Scale(Resize):
""" """
...@@ -182,6 +202,9 @@ class CenterCrop(object): ...@@ -182,6 +202,9 @@ class CenterCrop(object):
""" """
return F.center_crop(img, self.size) return F.center_crop(img, self.size)
def __repr__(self):
return self.__class__.__name__ + '(size={0})'.format(self.size)
class Pad(object): class Pad(object):
"""Pad the given PIL Image on all sides with the given "pad" value. """Pad the given PIL Image on all sides with the given "pad" value.
...@@ -216,6 +239,9 @@ class Pad(object): ...@@ -216,6 +239,9 @@ class Pad(object):
""" """
return F.pad(img, self.padding, self.fill) return F.pad(img, self.padding, self.fill)
def __repr__(self):
return self.__class__.__name__ + '(padding={0})'.format(self.padding)
class Lambda(object): class Lambda(object):
"""Apply a user-defined lambda as a transform. """Apply a user-defined lambda as a transform.
...@@ -231,6 +257,9 @@ class Lambda(object): ...@@ -231,6 +257,9 @@ class Lambda(object):
def __call__(self, img): def __call__(self, img):
return self.lambd(img) return self.lambd(img)
def __repr__(self):
return self.__class__.__name__ + '()'
class RandomCrop(object): class RandomCrop(object):
"""Crop the given PIL Image at a random location. """Crop the given PIL Image at a random location.
...@@ -287,6 +316,9 @@ class RandomCrop(object): ...@@ -287,6 +316,9 @@ class RandomCrop(object):
return F.crop(img, i, j, h, w) return F.crop(img, i, j, h, w)
def __repr__(self):
return self.__class__.__name__ + '(size={0})'.format(self.size)
class RandomHorizontalFlip(object): class RandomHorizontalFlip(object):
"""Horizontally flip the given PIL Image randomly with a probability of 0.5.""" """Horizontally flip the given PIL Image randomly with a probability of 0.5."""
...@@ -303,6 +335,9 @@ class RandomHorizontalFlip(object): ...@@ -303,6 +335,9 @@ class RandomHorizontalFlip(object):
return F.hflip(img) return F.hflip(img)
return img return img
def __repr__(self):
return self.__class__.__name__ + '()'
class RandomVerticalFlip(object): class RandomVerticalFlip(object):
"""Vertically flip the given PIL Image randomly with a probability of 0.5.""" """Vertically flip the given PIL Image randomly with a probability of 0.5."""
...@@ -319,6 +354,9 @@ class RandomVerticalFlip(object): ...@@ -319,6 +354,9 @@ class RandomVerticalFlip(object):
return F.vflip(img) return F.vflip(img)
return img return img
def __repr__(self):
return self.__class__.__name__ + '()'
class RandomResizedCrop(object): class RandomResizedCrop(object):
"""Crop the given PIL Image to random size and aspect ratio. """Crop the given PIL Image to random size and aspect ratio.
...@@ -387,6 +425,9 @@ class RandomResizedCrop(object): ...@@ -387,6 +425,9 @@ class RandomResizedCrop(object):
i, j, h, w = self.get_params(img, self.scale, self.ratio) i, j, h, w = self.get_params(img, self.scale, self.ratio)
return F.resized_crop(img, i, j, h, w, self.size, self.interpolation) return F.resized_crop(img, i, j, h, w, self.size, self.interpolation)
def __repr__(self):
return self.__class__.__name__ + '(size={0})'.format(self.size)
class RandomSizedCrop(RandomResizedCrop): class RandomSizedCrop(RandomResizedCrop):
""" """
...@@ -433,6 +474,9 @@ class FiveCrop(object): ...@@ -433,6 +474,9 @@ class FiveCrop(object):
def __call__(self, img): def __call__(self, img):
return F.five_crop(img, self.size) return F.five_crop(img, self.size)
def __repr__(self):
return self.__class__.__name__ + '(size={0})'.format(self.size)
class TenCrop(object): class TenCrop(object):
"""Crop the given PIL Image into four corners and the central crop plus the flipped version of """Crop the given PIL Image into four corners and the central crop plus the flipped version of
...@@ -473,6 +517,9 @@ class TenCrop(object): ...@@ -473,6 +517,9 @@ class TenCrop(object):
def __call__(self, img): def __call__(self, img):
return F.ten_crop(img, self.size, self.vertical_flip) return F.ten_crop(img, self.size, self.vertical_flip)
def __repr__(self):
return self.__class__.__name__ + '(size={0})'.format(self.size)
class LinearTransformation(object): class LinearTransformation(object):
"""Transform a tensor image with a square transformation matrix computed """Transform a tensor image with a square transformation matrix computed
...@@ -514,6 +561,11 @@ class LinearTransformation(object): ...@@ -514,6 +561,11 @@ class LinearTransformation(object):
tensor = transformed_tensor.view(tensor.size()) tensor = transformed_tensor.view(tensor.size())
return tensor return tensor
def __repr__(self):
format_string = self.__class__.__name__ + '('
format_string += (str(self.transformation_matrix.numpy().tolist()) + ')')
return format_string
class ColorJitter(object): class ColorJitter(object):
"""Randomly change the brightness, contrast and saturation of an image. """Randomly change the brightness, contrast and saturation of an image.
...@@ -578,6 +630,9 @@ class ColorJitter(object): ...@@ -578,6 +630,9 @@ class ColorJitter(object):
self.saturation, self.hue) self.saturation, self.hue)
return transform(img) return transform(img)
def __repr__(self):
return self.__class__.__name__ + '()'
class RandomRotation(object): class RandomRotation(object):
"""Rotate the image by angle. """Rotate the image by angle.
...@@ -636,6 +691,9 @@ class RandomRotation(object): ...@@ -636,6 +691,9 @@ class RandomRotation(object):
return F.rotate(img, angle, self.resample, self.expand, self.center) return F.rotate(img, angle, self.resample, self.expand, self.center)
def __repr__(self):
return self.__class__.__name__ + '(degrees={0})'.format(self.degrees)
class Grayscale(object): class Grayscale(object):
"""Convert image to grayscale. """Convert image to grayscale.
...@@ -663,6 +721,9 @@ class Grayscale(object): ...@@ -663,6 +721,9 @@ class Grayscale(object):
""" """
return F.to_grayscale(img, num_output_channels=self.num_output_channels) return F.to_grayscale(img, num_output_channels=self.num_output_channels)
def __repr__(self):
return self.__class__.__name__ + '()'
class RandomGrayscale(object): class RandomGrayscale(object):
"""Randomly convert image to grayscale with a probability of p (default 0.1). """Randomly convert image to grayscale with a probability of p (default 0.1).
...@@ -693,3 +754,6 @@ class RandomGrayscale(object): ...@@ -693,3 +754,6 @@ class RandomGrayscale(object):
if random.random() < self.p: if random.random() < self.p:
return F.to_grayscale(img, num_output_channels=num_output_channels) return F.to_grayscale(img, num_output_channels=num_output_channels)
return img return img
def __repr__(self):
return self.__class__.__name__ + '()'
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment