Commit 2b2aa9c7 authored by Vishwak Srinivasan's avatar Vishwak Srinivasan Committed by Soumith Chintala
Browse files

Add descriptions for Transform objects (#380)

* add __repr__ for all transform objects

* add tests for __repr__
parent 9889de1d
......@@ -113,6 +113,11 @@ class Tester(unittest.TestCase):
img = to_pil_image(torch.FloatTensor(3, h, w).uniform_())
results = transform(img)
expected_output = five_crop(img)
# Checking if FiveCrop and TenCrop can be printed as string
transform.__repr__()
five_crop.__repr__()
if should_vflip:
vflipped_img = img.transpose(Image.FLIP_TOP_BOTTOM)
expected_output += five_crop(vflipped_img)
......@@ -226,6 +231,9 @@ class Tester(unittest.TestCase):
assert output.size[0] == width + padding[0] + padding[2]
assert output.size[1] == height + padding[1] + padding[3]
# Checking if Padding can be printed as string
transforms.Pad(padding).__repr__()
def test_pad_raises_with_invalid_pad_sequence_len(self):
with self.assertRaises(ValueError):
transforms.Pad(())
......@@ -247,6 +255,9 @@ class Tester(unittest.TestCase):
y = trans(x)
assert (y.equal(x))
# Checking if Lambda can be printed as string
trans.__repr__()
def test_to_tensor(self):
test_channels = [1, 3, 4]
height, width = 4, 4
......@@ -280,6 +291,9 @@ class Tester(unittest.TestCase):
transforms.ToTensor(),
])
# Checking if Compose, Resize and ToTensor can be printed as string
trans.__repr__()
expected_output = trans(Image.open(GRACE_HOPPER).convert('RGB'))
output = trans(accimage.Image(GRACE_HOPPER))
......@@ -296,6 +310,9 @@ class Tester(unittest.TestCase):
transforms.ToTensor(),
])
# Checking if Compose, CenterCrop and ToTensor can be printed as string
trans.__repr__()
expected_output = trans(Image.open(GRACE_HOPPER).convert('RGB'))
output = trans(accimage.Image(GRACE_HOPPER))
......@@ -375,6 +392,9 @@ class Tester(unittest.TestCase):
for mode in [None, 'RGB', 'HSV', 'YCbCr']:
verify_img_data(img_data, mode)
# Checking if ToPILImage can be printed as string
transforms.ToPILImage().__repr__()
with self.assertRaises(ValueError):
# should raise if we try a mode for 4 or 1 channel images
transforms.ToPILImage(mode='RGBA')(img_data)
......@@ -450,6 +470,9 @@ class Tester(unittest.TestCase):
random.setstate(random_state)
assert p_value > 0.0001
# Checking if RandomVerticalFlip can be printed as string
transforms.RandomVerticalFlip().__repr__()
@unittest.skipIf(stats is None, 'scipy.stats not available')
def test_random_horizontal_flip(self):
random_state = random.getstate()
......@@ -468,6 +491,9 @@ class Tester(unittest.TestCase):
random.setstate(random_state)
assert p_value > 0.0001
# Checking if RandomHorizontalFlip can be printed as string
transforms.RandomHorizontalFlip().__repr__()
@unittest.skipIf(stats is None, 'scipt.stats is not available')
def test_normalize(self):
def samples_from_standard_normal(tensor):
......@@ -484,6 +510,9 @@ class Tester(unittest.TestCase):
assert samples_from_standard_normal(normalized)
random.setstate(random_state)
# Checking if Normalize can be printed as string
transforms.Normalize(mean, std).__repr__()
def test_adjust_brightness(self):
x_shape = [2, 2, 3]
x_data = [0, 5, 13, 54, 135, 226, 37, 8, 234, 90, 255, 1]
......@@ -645,6 +674,9 @@ class Tester(unittest.TestCase):
y_pil_2 = color_jitter(x_pil_2)
assert y_pil_2.mode == x_pil_2.mode
# Checking if ColorJitter can be printed as string
color_jitter.__repr__()
def test_linear_transformation(self):
x = torch.randn(250, 10, 10, 3)
flat_x = x.view(x.size(0), x.size(1) * x.size(2) * x.size(3))
......@@ -664,6 +696,9 @@ class Tester(unittest.TestCase):
cov = np.dot(xwhite, xwhite.T) / x.size(0)
assert np.allclose(cov, np.identity(1), rtol=1e-3)
# Checking if LinearTransformation can be printed as string
whitening.__repr__()
def test_rotate(self):
x = np.zeros((100, 100, 3), dtype=np.uint8)
x[40, 40] = [255, 255, 255]
......@@ -714,6 +749,9 @@ class Tester(unittest.TestCase):
angle = t.get_params(t.degrees)
assert angle > -10 and angle < 10
# Checking if RandomRotation can be printed as string
t.__repr__()
def test_to_grayscale(self):
"""Unit tests for grayscale transform"""
......@@ -761,6 +799,9 @@ class Tester(unittest.TestCase):
np.testing.assert_equal(gray_np_4[:, :, 1], gray_np_4[:, :, 2])
np.testing.assert_equal(gray_np, gray_np_4[:, :, 0])
# Checking if Grayscale can be printed as string
trans4.__repr__()
@unittest.skipIf(stats is None, 'scipy.stats not available')
def test_random_grayscale(self):
"""Unit tests for random grayscale transform"""
......@@ -851,6 +892,9 @@ class Tester(unittest.TestCase):
assert gray_np_3.shape == tuple(x_shape[0:2]), 'should be 1 channel'
np.testing.assert_equal(gray_np, gray_np_3)
# Checking if RandomGrayscale can be printed as string
trans3.__repr__()
if __name__ == '__main__':
unittest.main()
......@@ -42,6 +42,14 @@ class Compose(object):
img = t(img)
return img
def __repr__(self):
format_string = self.__class__.__name__ + '('
for t in self.transforms:
format_string += '\n'
format_string += ' {0}'.format(t)
format_string += '\n)'
return format_string
class ToTensor(object):
"""Convert a ``PIL Image`` or ``numpy.ndarray`` to tensor.
......@@ -60,6 +68,9 @@ class ToTensor(object):
"""
return F.to_tensor(pic)
def __repr__(self):
return self.__class__.__name__ + '()'
class ToPILImage(object):
"""Convert a tensor or an ndarray to PIL Image.
......@@ -91,6 +102,9 @@ class ToPILImage(object):
"""
return F.to_pil_image(pic, self.mode)
def __repr__(self):
return self.__class__.__name__ + '({0})'.format(self.mode)
class Normalize(object):
"""Normalize an tensor image with mean and standard deviation.
......@@ -117,6 +131,9 @@ class Normalize(object):
"""
return F.normalize(tensor, self.mean, self.std)
def __repr__(self):
return self.__class__.__name__ + '(mean={0}, std={1})'.format(self.mean, self.std)
class Resize(object):
"""Resize the input PIL Image to the given size.
......@@ -146,6 +163,9 @@ class Resize(object):
"""
return F.resize(img, self.size, self.interpolation)
def __repr__(self):
return self.__class__.__name__ + '(size={0})'.format(self.size)
class Scale(Resize):
"""
......@@ -182,6 +202,9 @@ class CenterCrop(object):
"""
return F.center_crop(img, self.size)
def __repr__(self):
return self.__class__.__name__ + '(size={0})'.format(self.size)
class Pad(object):
"""Pad the given PIL Image on all sides with the given "pad" value.
......@@ -216,6 +239,9 @@ class Pad(object):
"""
return F.pad(img, self.padding, self.fill)
def __repr__(self):
return self.__class__.__name__ + '(padding={0})'.format(self.padding)
class Lambda(object):
"""Apply a user-defined lambda as a transform.
......@@ -231,6 +257,9 @@ class Lambda(object):
def __call__(self, img):
return self.lambd(img)
def __repr__(self):
return self.__class__.__name__ + '()'
class RandomCrop(object):
"""Crop the given PIL Image at a random location.
......@@ -287,6 +316,9 @@ class RandomCrop(object):
return F.crop(img, i, j, h, w)
def __repr__(self):
return self.__class__.__name__ + '(size={0})'.format(self.size)
class RandomHorizontalFlip(object):
"""Horizontally flip the given PIL Image randomly with a probability of 0.5."""
......@@ -303,6 +335,9 @@ class RandomHorizontalFlip(object):
return F.hflip(img)
return img
def __repr__(self):
return self.__class__.__name__ + '()'
class RandomVerticalFlip(object):
"""Vertically flip the given PIL Image randomly with a probability of 0.5."""
......@@ -319,6 +354,9 @@ class RandomVerticalFlip(object):
return F.vflip(img)
return img
def __repr__(self):
return self.__class__.__name__ + '()'
class RandomResizedCrop(object):
"""Crop the given PIL Image to random size and aspect ratio.
......@@ -387,6 +425,9 @@ class RandomResizedCrop(object):
i, j, h, w = self.get_params(img, self.scale, self.ratio)
return F.resized_crop(img, i, j, h, w, self.size, self.interpolation)
def __repr__(self):
return self.__class__.__name__ + '(size={0})'.format(self.size)
class RandomSizedCrop(RandomResizedCrop):
"""
......@@ -433,6 +474,9 @@ class FiveCrop(object):
def __call__(self, img):
return F.five_crop(img, self.size)
def __repr__(self):
return self.__class__.__name__ + '(size={0})'.format(self.size)
class TenCrop(object):
"""Crop the given PIL Image into four corners and the central crop plus the flipped version of
......@@ -473,6 +517,9 @@ class TenCrop(object):
def __call__(self, img):
return F.ten_crop(img, self.size, self.vertical_flip)
def __repr__(self):
return self.__class__.__name__ + '(size={0})'.format(self.size)
class LinearTransformation(object):
"""Transform a tensor image with a square transformation matrix computed
......@@ -514,6 +561,11 @@ class LinearTransformation(object):
tensor = transformed_tensor.view(tensor.size())
return tensor
def __repr__(self):
format_string = self.__class__.__name__ + '('
format_string += (str(self.transformation_matrix.numpy().tolist()) + ')')
return format_string
class ColorJitter(object):
"""Randomly change the brightness, contrast and saturation of an image.
......@@ -578,6 +630,9 @@ class ColorJitter(object):
self.saturation, self.hue)
return transform(img)
def __repr__(self):
return self.__class__.__name__ + '()'
class RandomRotation(object):
"""Rotate the image by angle.
......@@ -636,6 +691,9 @@ class RandomRotation(object):
return F.rotate(img, angle, self.resample, self.expand, self.center)
def __repr__(self):
return self.__class__.__name__ + '(degrees={0})'.format(self.degrees)
class Grayscale(object):
"""Convert image to grayscale.
......@@ -663,6 +721,9 @@ class Grayscale(object):
"""
return F.to_grayscale(img, num_output_channels=self.num_output_channels)
def __repr__(self):
return self.__class__.__name__ + '()'
class RandomGrayscale(object):
"""Randomly convert image to grayscale with a probability of p (default 0.1).
......@@ -693,3 +754,6 @@ class RandomGrayscale(object):
if random.random() < self.p:
return F.to_grayscale(img, num_output_channels=num_output_channels)
return img
def __repr__(self):
return self.__class__.__name__ + '()'
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment