Commit f1b79d90 authored by Vishwak Srinivasan's avatar Vishwak Srinivasan Committed by Soumith Chintala
Browse files

Update transforms __repr__ (#419)

* Update transforms __repr__

* Add _pil_interpolation_to_str map for obtaining strings for PIL.Image interpolation options

* Fix typo
parent 75003739
...@@ -20,6 +20,13 @@ __all__ = ["Compose", "ToTensor", "ToPILImage", "Normalize", "Resize", "Scale", ...@@ -20,6 +20,13 @@ __all__ = ["Compose", "ToTensor", "ToPILImage", "Normalize", "Resize", "Scale",
"RandomVerticalFlip", "RandomResizedCrop", "RandomSizedCrop", "FiveCrop", "TenCrop", "LinearTransformation", "RandomVerticalFlip", "RandomResizedCrop", "RandomSizedCrop", "FiveCrop", "TenCrop", "LinearTransformation",
"ColorJitter", "RandomRotation", "Grayscale", "RandomGrayscale"] "ColorJitter", "RandomRotation", "Grayscale", "RandomGrayscale"]
_pil_interpolation_to_str = {
Image.NEAREST: 'PIL.Image.NEAREST',
Image.BILINEAR: 'PIL.Image.BILINEAR',
Image.BICUBIC: 'PIL.Image.BICUBIC',
Image.LANCZOS: 'PIL.Image.LANCZOS',
}
class Compose(object): class Compose(object):
"""Composes several transforms together. """Composes several transforms together.
...@@ -103,7 +110,11 @@ class ToPILImage(object): ...@@ -103,7 +110,11 @@ class ToPILImage(object):
return F.to_pil_image(pic, self.mode) return F.to_pil_image(pic, self.mode)
def __repr__(self): def __repr__(self):
return self.__class__.__name__ + '({0})'.format(self.mode) format_string = self.__class__.__name__ + '('
if self.mode is not None:
format_string += 'mode={0}'.format(self.mode)
format_string += ')'
return format_string
class Normalize(object): class Normalize(object):
...@@ -164,7 +175,8 @@ class Resize(object): ...@@ -164,7 +175,8 @@ class Resize(object):
return F.resize(img, self.size, self.interpolation) return F.resize(img, self.size, self.interpolation)
def __repr__(self): def __repr__(self):
return self.__class__.__name__ + '(size={0})'.format(self.size) interpolate_str = _pil_interpolation_to_str[self.interpolation]
return self.__class__.__name__ + '(size={0}, interpolation={1})'.format(self.size, interpolate_str)
class Scale(Resize): class Scale(Resize):
...@@ -240,7 +252,7 @@ class Pad(object): ...@@ -240,7 +252,7 @@ class Pad(object):
return F.pad(img, self.padding, self.fill) return F.pad(img, self.padding, self.fill)
def __repr__(self): def __repr__(self):
return self.__class__.__name__ + '(padding={0})'.format(self.padding) return self.__class__.__name__ + '(padding={0}, fill={1})'.format(self.padding, self.fill)
class Lambda(object): class Lambda(object):
...@@ -388,7 +400,7 @@ class RandomCrop(object): ...@@ -388,7 +400,7 @@ class RandomCrop(object):
return F.crop(img, i, j, h, w) return F.crop(img, i, j, h, w)
def __repr__(self): def __repr__(self):
return self.__class__.__name__ + '(size={0})'.format(self.size) return self.__class__.__name__ + '(size={0}, padding={1})'.format(self.size, self.padding)
class RandomHorizontalFlip(object): class RandomHorizontalFlip(object):
...@@ -511,7 +523,12 @@ class RandomResizedCrop(object): ...@@ -511,7 +523,12 @@ class RandomResizedCrop(object):
return F.resized_crop(img, i, j, h, w, self.size, self.interpolation) return F.resized_crop(img, i, j, h, w, self.size, self.interpolation)
def __repr__(self): def __repr__(self):
return self.__class__.__name__ + '(size={0})'.format(self.size) interpolate_str = _pil_interpolation_to_str[self.interpolation]
format_string = self.__class__.__name__ + '(size={0}'.format(self.size)
format_string += ', scale={0}'.format(round(self.scale, 4))
format_string += ', ratio={0}'.format(round(self.ratio, 4))
format_string += ', interpolation={0})'.format(interpolate_str)
return format_string
class RandomSizedCrop(RandomResizedCrop): class RandomSizedCrop(RandomResizedCrop):
...@@ -603,7 +620,7 @@ class TenCrop(object): ...@@ -603,7 +620,7 @@ class TenCrop(object):
return F.ten_crop(img, self.size, self.vertical_flip) return F.ten_crop(img, self.size, self.vertical_flip)
def __repr__(self): def __repr__(self):
return self.__class__.__name__ + '(size={0})'.format(self.size) return self.__class__.__name__ + '(size={0}, vertical_flip={1})'.format(self.size, self.vertical_flip)
class LinearTransformation(object): class LinearTransformation(object):
...@@ -716,7 +733,12 @@ class ColorJitter(object): ...@@ -716,7 +733,12 @@ class ColorJitter(object):
return transform(img) return transform(img)
def __repr__(self): def __repr__(self):
return self.__class__.__name__ + '()' format_string = self.__class__.__name__ + '('
format_string += 'brightness={0}'.format(self.brightness)
format_string += ', contrast={0}'.format(self.contrast)
format_string += ', saturation={0}'.format(self.saturation)
format_string += ', hue={0})'.format(self.hue)
return format_string
class RandomRotation(object): class RandomRotation(object):
...@@ -777,7 +799,13 @@ class RandomRotation(object): ...@@ -777,7 +799,13 @@ class RandomRotation(object):
return F.rotate(img, angle, self.resample, self.expand, self.center) return F.rotate(img, angle, self.resample, self.expand, self.center)
def __repr__(self): def __repr__(self):
return self.__class__.__name__ + '(degrees={0})'.format(self.degrees) format_string = self.__class__.__name__ + '(degrees={0}'.format(self.degrees)
format_string += ', resample={0}'.format(self.resample)
format_string += ', expand={0}'.format(self.expand)
if self.center is not None:
format_string += ', center={0}'.format(self.center)
format_string += ')'
return format_string
class Grayscale(object): class Grayscale(object):
...@@ -807,7 +835,7 @@ class Grayscale(object): ...@@ -807,7 +835,7 @@ class Grayscale(object):
return F.to_grayscale(img, num_output_channels=self.num_output_channels) return F.to_grayscale(img, num_output_channels=self.num_output_channels)
def __repr__(self): def __repr__(self):
return self.__class__.__name__ + '()' return self.__class__.__name__ + '(num_output_channels={0})'.format(self.num_output_channels)
class RandomGrayscale(object): class RandomGrayscale(object):
...@@ -841,4 +869,4 @@ class RandomGrayscale(object): ...@@ -841,4 +869,4 @@ class RandomGrayscale(object):
return img return img
def __repr__(self): def __repr__(self):
return self.__class__.__name__ + '()' return self.__class__.__name__ + '(p={0})'.format(self.p)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment