"src/vscode:/vscode.git/clone" did not exist on "5ded26cdc70387e7d08144275c138dc0c175a4a2"
Commit f1b79d90 authored by Vishwak Srinivasan's avatar Vishwak Srinivasan Committed by Soumith Chintala
Browse files

Update transforms __repr__ (#419)

* Update transforms __repr__

* Add _pil_interpolation_to_str map for obtaining strings for PIL.Image interpolation options

* Fix typo
parent 75003739
......@@ -20,6 +20,13 @@ __all__ = ["Compose", "ToTensor", "ToPILImage", "Normalize", "Resize", "Scale",
"RandomVerticalFlip", "RandomResizedCrop", "RandomSizedCrop", "FiveCrop", "TenCrop", "LinearTransformation",
"ColorJitter", "RandomRotation", "Grayscale", "RandomGrayscale"]
_pil_interpolation_to_str = {
Image.NEAREST: 'PIL.Image.NEAREST',
Image.BILINEAR: 'PIL.Image.BILINEAR',
Image.BICUBIC: 'PIL.Image.BICUBIC',
Image.LANCZOS: 'PIL.Image.LANCZOS',
}
class Compose(object):
"""Composes several transforms together.
......@@ -103,7 +110,11 @@ class ToPILImage(object):
return F.to_pil_image(pic, self.mode)
def __repr__(self):
return self.__class__.__name__ + '({0})'.format(self.mode)
format_string = self.__class__.__name__ + '('
if self.mode is not None:
format_string += 'mode={0}'.format(self.mode)
format_string += ')'
return format_string
class Normalize(object):
......@@ -164,7 +175,8 @@ class Resize(object):
return F.resize(img, self.size, self.interpolation)
def __repr__(self):
return self.__class__.__name__ + '(size={0})'.format(self.size)
interpolate_str = _pil_interpolation_to_str[self.interpolation]
return self.__class__.__name__ + '(size={0}, interpolation={1})'.format(self.size, interpolate_str)
class Scale(Resize):
......@@ -240,7 +252,7 @@ class Pad(object):
return F.pad(img, self.padding, self.fill)
def __repr__(self):
return self.__class__.__name__ + '(padding={0})'.format(self.padding)
return self.__class__.__name__ + '(padding={0}, fill={1})'.format(self.padding, self.fill)
class Lambda(object):
......@@ -388,7 +400,7 @@ class RandomCrop(object):
return F.crop(img, i, j, h, w)
def __repr__(self):
return self.__class__.__name__ + '(size={0})'.format(self.size)
return self.__class__.__name__ + '(size={0}, padding={1})'.format(self.size, self.padding)
class RandomHorizontalFlip(object):
......@@ -511,7 +523,12 @@ class RandomResizedCrop(object):
return F.resized_crop(img, i, j, h, w, self.size, self.interpolation)
def __repr__(self):
return self.__class__.__name__ + '(size={0})'.format(self.size)
interpolate_str = _pil_interpolation_to_str[self.interpolation]
format_string = self.__class__.__name__ + '(size={0}'.format(self.size)
format_string += ', scale={0}'.format(round(self.scale, 4))
format_string += ', ratio={0}'.format(round(self.ratio, 4))
format_string += ', interpolation={0})'.format(interpolate_str)
return format_string
class RandomSizedCrop(RandomResizedCrop):
......@@ -603,7 +620,7 @@ class TenCrop(object):
return F.ten_crop(img, self.size, self.vertical_flip)
def __repr__(self):
return self.__class__.__name__ + '(size={0})'.format(self.size)
return self.__class__.__name__ + '(size={0}, vertical_flip={1})'.format(self.size, self.vertical_flip)
class LinearTransformation(object):
......@@ -716,7 +733,12 @@ class ColorJitter(object):
return transform(img)
def __repr__(self):
return self.__class__.__name__ + '()'
format_string = self.__class__.__name__ + '('
format_string += 'brightness={0}'.format(self.brightness)
format_string += ', contrast={0}'.format(self.contrast)
format_string += ', saturation={0}'.format(self.saturation)
format_string += ', hue={0})'.format(self.hue)
return format_string
class RandomRotation(object):
......@@ -777,7 +799,13 @@ class RandomRotation(object):
return F.rotate(img, angle, self.resample, self.expand, self.center)
def __repr__(self):
return self.__class__.__name__ + '(degrees={0})'.format(self.degrees)
format_string = self.__class__.__name__ + '(degrees={0}'.format(self.degrees)
format_string += ', resample={0}'.format(self.resample)
format_string += ', expand={0}'.format(self.expand)
if self.center is not None:
format_string += ', center={0}'.format(self.center)
format_string += ')'
return format_string
class Grayscale(object):
......@@ -807,7 +835,7 @@ class Grayscale(object):
return F.to_grayscale(img, num_output_channels=self.num_output_channels)
def __repr__(self):
return self.__class__.__name__ + '()'
return self.__class__.__name__ + '(num_output_channels={0})'.format(self.num_output_channels)
class RandomGrayscale(object):
......@@ -841,4 +869,4 @@ class RandomGrayscale(object):
return img
def __repr__(self):
return self.__class__.__name__ + '()'
return self.__class__.__name__ + '(p={0})'.format(self.p)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment