Commit 459dc59e authored by Soumith Chintala, committed by GitHub

Merge pull request #240 from chsasank/master

Refactor of transforms 
parents addfbd1d 2cc58ed0
@@ -13,6 +13,259 @@ import types
import collections
def _is_pil_image(img):
if accimage is not None:
return isinstance(img, (Image.Image, accimage.Image))
else:
return isinstance(img, Image.Image)
def _is_tensor_image(img):
return torch.is_tensor(img) and img.ndimension() == 3
def _is_numpy_image(img):
return isinstance(img, np.ndarray) and (img.ndim in {2, 3})
def to_tensor(pic):
"""Convert a ``PIL.Image`` or ``numpy.ndarray`` to tensor.
See ``ToTensor`` for more details.
Args:
pic (PIL.Image or numpy.ndarray): Image to be converted to tensor.
Returns:
Tensor: Converted image.
"""
if not(_is_pil_image(pic) or _is_numpy_image(pic)):
raise TypeError('pic should be PIL Image or ndarray. Got {}'.format(type(pic)))
if isinstance(pic, np.ndarray):
# handle numpy array
img = torch.from_numpy(pic.transpose((2, 0, 1)))
# backward compatibility
return img.float().div(255)
if accimage is not None and isinstance(pic, accimage.Image):
nppic = np.zeros([pic.channels, pic.height, pic.width], dtype=np.float32)
pic.copyto(nppic)
return torch.from_numpy(nppic)
# handle PIL Image
if pic.mode == 'I':
img = torch.from_numpy(np.array(pic, np.int32, copy=False))
elif pic.mode == 'I;16':
img = torch.from_numpy(np.array(pic, np.int16, copy=False))
else:
img = torch.ByteTensor(torch.ByteStorage.from_buffer(pic.tobytes()))
# PIL image mode: 1, L, P, I, F, RGB, YCbCr, RGBA, CMYK
if pic.mode == 'YCbCr':
nchannel = 3
elif pic.mode == 'I;16':
nchannel = 1
else:
nchannel = len(pic.mode)
img = img.view(pic.size[1], pic.size[0], nchannel)
# put it from HWC to CHW format
# yikes, this transpose takes 80% of the loading time/CPU
img = img.transpose(0, 1).transpose(0, 2).contiguous()
if isinstance(img, torch.ByteTensor):
return img.float().div(255)
else:
return img
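A minimal usage sketch for ``to_tensor`` (illustrative, assuming the function is in scope; the file path is hypothetical): an 8-bit RGB PIL image becomes a FloatTensor of shape (C, H, W) with values scaled from [0, 255] to [0.0, 1.0].

    from PIL import Image
    img = Image.open('example.jpg').convert('RGB')  # hypothetical path
    t = to_tensor(img)
    print(t.size())  # torch.Size([3, H, W]), values in [0.0, 1.0]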
def to_pil_image(pic):
"""Convert a tensor or an ndarray to PIL Image.
See ``ToPILImage`` for more details.
Args:
pic (Tensor or numpy.ndarray): Image to be converted to PIL.Image.
Returns:
PIL.Image: Image converted to PIL.Image.
"""
if not(_is_numpy_image(pic) or _is_tensor_image(pic)):
raise TypeError('pic should be Tensor or ndarray. Got {}.'.format(type(pic)))
npimg = pic
mode = None
if isinstance(pic, torch.FloatTensor):
pic = pic.mul(255).byte()
if torch.is_tensor(pic):
npimg = np.transpose(pic.numpy(), (1, 2, 0))
assert isinstance(npimg, np.ndarray)
if npimg.shape[2] == 1:
npimg = npimg[:, :, 0]
if npimg.dtype == np.uint8:
mode = 'L'
if npimg.dtype == np.int16:
mode = 'I;16'
if npimg.dtype == np.int32:
mode = 'I'
elif npimg.dtype == np.float32:
mode = 'F'
elif npimg.shape[2] == 4:
if npimg.dtype == np.uint8:
mode = 'RGBA'
else:
if npimg.dtype == np.uint8:
mode = 'RGB'
assert mode is not None, '{} is not supported'.format(npimg.dtype)
return Image.fromarray(npimg, mode=mode)
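A round-trip sketch for ``to_pil_image`` (illustrative): a float tensor in [0, 1] is rescaled to uint8, transposed to H x W x C, and three channels map to mode 'RGB'.

    import torch
    t = torch.rand(3, 32, 32)  # CHW float tensor in [0, 1]
    img = to_pil_image(t)
    assert img.mode == 'RGB' and img.size == (32, 32)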
def normalize(tensor, mean, std):
"""Normalize an tensor image with mean and standard deviation.
See ``Normalize`` for more details.
Args:
tensor (Tensor): Tensor image of size (C, H, W) to be normalized.
        mean (sequence): Sequence of means for R, G, B channels respectively.
        std (sequence): Sequence of standard deviations for R, G, B channels
            respectively.
Returns:
Tensor: Normalized image.
"""
if not _is_tensor_image(tensor):
raise TypeError('tensor is not a torch image.')
# TODO: make efficient
for t, m, s in zip(tensor, mean, std):
t.sub_(m).div_(s)
return tensor
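A sketch of ``normalize`` using the channel statistics commonly used for ImageNet-trained models (the values are a convention, not required by this API). Note that the operation mutates ``tensor`` in place and returns the same tensor.

    import torch
    t = torch.rand(3, 224, 224)
    t = normalize(t, mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])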
def scale(img, size, interpolation=Image.BILINEAR):
"""Rescale the input PIL.Image to the given size.
Args:
img (PIL.Image): Image to be scaled.
size (sequence or int): Desired output size. If size is a sequence like
(h, w), output size will be matched to this. If size is an int,
smaller edge of the image will be matched to this number.
i.e., if height > width, then the image will be rescaled to
(size * height / width, size)
interpolation (int, optional): Desired interpolation. Default is
``PIL.Image.BILINEAR``
Returns:
PIL.Image: Rescaled image.
"""
if not _is_pil_image(img):
raise TypeError('img should be PIL Image. Got {}'.format(type(img)))
if not (isinstance(size, int) or (isinstance(size, collections.Iterable) and len(size) == 2)):
raise TypeError('Got inappropriate size arg: {}'.format(size))
if isinstance(size, int):
w, h = img.size
if (w <= h and w == size) or (h <= w and h == size):
return img
if w < h:
ow = size
oh = int(size * h / w)
return img.resize((ow, oh), interpolation)
else:
oh = size
ow = int(size * w / h)
return img.resize((ow, oh), interpolation)
else:
return img.resize(size[::-1], interpolation)
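A sketch contrasting the two forms of ``size`` accepted by ``scale`` (illustrative; remember PIL reports size as (width, height)):

    from PIL import Image
    img = Image.new('RGB', (400, 300))
    print(scale(img, 200).size)         # smaller edge -> 200: (266, 200)
    print(scale(img, (150, 200)).size)  # explicit (h, w): (200, 150)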
def pad(img, padding, fill=0):
"""Pad the given PIL.Image on all sides with the given "pad" value.
Args:
img (PIL.Image): Image to be padded.
padding (int or tuple): Padding on each border. If a single int is provided this
is used to pad all borders. If tuple of length 2 is provided this is the padding
on left/right and top/bottom respectively. If a tuple of length 4 is provided
this is the padding for the left, top, right and bottom borders
respectively.
fill: Pixel fill value. Default is 0. If a tuple of
length 3, it is used to fill R, G, B channels respectively.
Returns:
PIL.Image: Padded image.
"""
if not _is_pil_image(img):
raise TypeError('img should be PIL Image. Got {}'.format(type(img)))
if not isinstance(padding, (numbers.Number, tuple)):
raise TypeError('Got inappropriate padding arg')
if not isinstance(fill, (numbers.Number, str, tuple)):
raise TypeError('Got inappropriate fill arg')
if isinstance(padding, collections.Sequence) and len(padding) not in [2, 4]:
raise ValueError("Padding must be an int or a 2, or 4 element tuple, not a " +
"{} element tuple".format(len(padding)))
return ImageOps.expand(img, border=padding, fill=fill)
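A sketch of the three padding forms (illustrative):

    from PIL import Image
    img = Image.new('RGB', (10, 10))
    print(pad(img, 2).size)             # all borders: (14, 14)
    print(pad(img, (1, 2)).size)        # left/right=1, top/bottom=2: (12, 14)
    print(pad(img, (1, 2, 3, 4)).size)  # left, top, right, bottom: (14, 16)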
def crop(img, i, j, h, w):
"""Crop the given PIL.Image.
Args:
img (PIL.Image): Image to be cropped.
i: Upper pixel coordinate.
j: Left pixel coordinate.
h: Height of the cropped image.
w: Width of the cropped image.
Returns:
PIL.Image: Cropped image.
"""
if not _is_pil_image(img):
raise TypeError('img should be PIL Image. Got {}'.format(type(img)))
return img.crop((j, i, j + w, i + h))
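Note the coordinate convention: ``i`` and ``j`` are the row and column of the top-left corner (matching tensor indexing), while PIL's box argument is (left, upper, right, lower). A sketch:

    from PIL import Image
    img = Image.new('RGB', (100, 80))
    patch = crop(img, 10, 20, 30, 40)  # top row 10, left col 20, h=30, w=40
    assert patch.size == (40, 30)      # PIL reports (width, height)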
def scaled_crop(img, i, j, h, w, size, interpolation=Image.BILINEAR):
"""Crop the given PIL.Image and scale it to desired size.
Notably used in RandomSizedCrop.
Args:
img (PIL.Image): Image to be cropped.
i: Upper pixel coordinate.
j: Left pixel coordinate.
h: Height of the cropped image.
w: Width of the cropped image.
size (sequence or int): Desired output size. Same semantics as ``scale``.
interpolation (int, optional): Desired interpolation. Default is
``PIL.Image.BILINEAR``.
Returns:
PIL.Image: Cropped image.
"""
assert _is_pil_image(img), 'img should be PIL Image'
img = crop(img, i, j, h, w)
img = scale(img, size, interpolation)
return img
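``scaled_crop`` is simply ``crop`` followed by ``scale``; a sketch of the equivalence (illustrative):

    from PIL import Image
    img = Image.new('RGB', (100, 80))
    out_a = scaled_crop(img, 0, 0, 50, 50, 224)
    out_b = scale(crop(img, 0, 0, 50, 50), 224)
    assert out_a.size == out_b.size == (224, 224)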
def hflip(img):
"""Horizontally flip the given PIL.Image.
Args:
img (PIL.Image): Image to be flipped.
Returns:
PIL.Image: Horizontally flipped image.
"""
if not _is_pil_image(img):
raise TypeError('img should be PIL Image. Got {}'.format(type(img)))
return img.transpose(Image.FLIP_LEFT_RIGHT)
class Compose(object):
"""Composes several transforms together.
@@ -50,43 +303,11 @@ class ToTensor(object):
Returns:
Tensor: Converted image.
"""
if isinstance(pic, np.ndarray):
# handle numpy array
img = torch.from_numpy(pic.transpose((2, 0, 1)))
# backward compatibility
return img.float().div(255)
if accimage is not None and isinstance(pic, accimage.Image):
nppic = np.zeros([pic.channels, pic.height, pic.width], dtype=np.float32)
pic.copyto(nppic)
return torch.from_numpy(nppic)
# handle PIL Image
if pic.mode == 'I':
img = torch.from_numpy(np.array(pic, np.int32, copy=False))
elif pic.mode == 'I;16':
img = torch.from_numpy(np.array(pic, np.int16, copy=False))
else:
img = torch.ByteTensor(torch.ByteStorage.from_buffer(pic.tobytes()))
# PIL image mode: 1, L, P, I, F, RGB, YCbCr, RGBA, CMYK
if pic.mode == 'YCbCr':
nchannel = 3
elif pic.mode == 'I;16':
nchannel = 1
else:
nchannel = len(pic.mode)
img = img.view(pic.size[1], pic.size[0], nchannel)
# put it from HWC to CHW format
# yikes, this transpose takes 80% of the loading time/CPU
img = img.transpose(0, 1).transpose(0, 2).contiguous()
if isinstance(img, torch.ByteTensor):
return img.float().div(255)
else:
return img
return to_tensor(pic)
class ToPILImage(object):
"""Convert a tensor to PIL Image.
"""Convert a tensor or an ndarray to PIL Image.
Converts a torch.*Tensor of shape C x H x W or a numpy ndarray of shape
H x W x C to a PIL.Image while preserving the value range.
@@ -101,32 +322,7 @@ class ToPILImage(object):
PIL.Image: Image converted to PIL.Image.
"""
npimg = pic
mode = None
if isinstance(pic, torch.FloatTensor):
pic = pic.mul(255).byte()
if torch.is_tensor(pic):
npimg = np.transpose(pic.numpy(), (1, 2, 0))
assert isinstance(npimg, np.ndarray), 'pic should be Tensor or ndarray'
if npimg.shape[2] == 1:
npimg = npimg[:, :, 0]
if npimg.dtype == np.uint8:
mode = 'L'
if npimg.dtype == np.int16:
mode = 'I;16'
if npimg.dtype == np.int32:
mode = 'I'
elif npimg.dtype == np.float32:
mode = 'F'
elif npimg.shape[2] == 4:
if npimg.dtype == np.uint8:
mode = 'RGBA'
else:
if npimg.dtype == np.uint8:
mode = 'RGB'
assert mode is not None, '{} is not supported'.format(npimg.dtype)
return Image.fromarray(npimg, mode=mode)
return to_pil_image(pic)
class Normalize(object):
@@ -154,10 +350,7 @@ class Normalize(object):
Returns:
Tensor: Normalized image.
"""
# TODO: make efficient
for t, m, s in zip(tensor, self.mean, self.std):
t.sub_(m).div_(s)
return tensor
return normalize(tensor, self.mean, self.std)
class Scale(object):
@@ -186,20 +379,7 @@ class Scale(object):
Returns:
PIL.Image: Rescaled image.
"""
if isinstance(self.size, int):
w, h = img.size
if (w <= h and w == self.size) or (h <= w and h == self.size):
return img
if w < h:
ow = self.size
oh = int(self.size * h / w)
return img.resize((ow, oh), self.interpolation)
else:
oh = self.size
ow = int(self.size * w / h)
return img.resize((ow, oh), self.interpolation)
else:
return img.resize(self.size[::-1], self.interpolation)
return scale(img, self.size, self.interpolation)
class CenterCrop(object):
@@ -217,6 +397,23 @@ class CenterCrop(object):
else:
self.size = size
@staticmethod
def get_params(img, output_size):
"""Get parameters for ``crop`` for center crop.
Args:
img (PIL.Image): Image to be cropped.
output_size (tuple): Expected output size of the crop.
Returns:
tuple: params (i, j, h, w) to be passed to ``crop`` for center crop.
"""
w, h = img.size
th, tw = output_size
i = int(round((h - th) / 2.))
j = int(round((w - tw) / 2.))
return i, j, th, tw
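Exposing ``get_params`` makes the crop window inspectable before it is applied; a quick check of the arithmetic (illustrative): i = round((80 - 64) / 2) = 8 and j = round((100 - 64) / 2) = 18.

    from PIL import Image
    img = Image.new('RGB', (100, 80))
    print(CenterCrop.get_params(img, (64, 64)))  # (8, 18, 64, 64)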
def __call__(self, img):
"""
Args:
@@ -225,11 +422,8 @@ class CenterCrop(object):
Returns:
PIL.Image: Cropped image.
"""
w, h = img.size
th, tw = self.size
x1 = int(round((w - tw) / 2.))
y1 = int(round((h - th) / 2.))
return img.crop((x1, y1, x1 + tw, y1 + th))
i, j, h, w = self.get_params(img, self.size)
return crop(img, i, j, h, w)
class Pad(object):
@@ -263,7 +457,7 @@ class Pad(object):
Returns:
PIL.Image: Padded image.
"""
return ImageOps.expand(img, border=self.padding, fill=self.fill)
return pad(img, self.padding, self.fill)
class Lambda(object):
@@ -301,6 +495,26 @@ class RandomCrop(object):
self.size = size
self.padding = padding
@staticmethod
def get_params(img, output_size):
"""Get parameters for ``crop`` for a random crop.
Args:
img (PIL.Image): Image to be cropped.
output_size (tuple): Expected output size of the crop.
Returns:
tuple: params (i, j, h, w) to be passed to ``crop`` for random crop.
"""
w, h = img.size
th, tw = output_size
        if w == tw and h == th:
            # degenerate case: the crop covers the whole image; callers
            # unpack a 4-tuple, so return params rather than the image
            return 0, 0, h, w
i = random.randint(0, h - th)
j = random.randint(0, w - tw)
return i, j, th, tw
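One payoff of splitting parameter sampling from application: the same random window can be applied to paired inputs, e.g. an image and a segmentation mask (a sketch; the mask is a hypothetical paired target):

    from PIL import Image
    img = Image.new('RGB', (100, 80))
    mask = Image.new('L', (100, 80))   # hypothetical paired target
    i, j, h, w = RandomCrop.get_params(img, (64, 64))
    img_c = crop(img, i, j, h, w)      # identical window for both inputs
    mask_c = crop(mask, i, j, h, w)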
def __call__(self, img):
"""
Args:
@@ -310,16 +524,11 @@
PIL.Image: Cropped image.
"""
if self.padding > 0:
img = ImageOps.expand(img, border=self.padding, fill=0)
img = pad(img, self.padding)
w, h = img.size
th, tw = self.size
if w == tw and h == th:
return img
i, j, h, w = self.get_params(img, self.size)
x1 = random.randint(0, w - tw)
y1 = random.randint(0, h - th)
return img.crop((x1, y1, x1 + tw, y1 + th))
return crop(img, i, j, h, w)
class RandomHorizontalFlip(object):
@@ -334,7 +543,7 @@ class RandomHorizontalFlip(object):
PIL.Image: Randomly flipped image.
"""
if random.random() < 0.5:
return img.transpose(Image.FLIP_LEFT_RIGHT)
return hflip(img)
return img
@@ -347,15 +556,25 @@ class RandomSizedCrop(object):
This is popularly used to train the Inception networks.
Args:
size: size of the smaller edge
size: expected output size of each edge
interpolation: Default: PIL.Image.BILINEAR
"""
def __init__(self, size, interpolation=Image.BILINEAR):
self.size = size
self.size = (size, size)
self.interpolation = interpolation
def __call__(self, img):
@staticmethod
def get_params(img):
"""Get parameters for ``crop`` for a random sized crop.
Args:
img (PIL.Image): Image to be cropped.
Returns:
tuple: params (i, j, h, w) to be passed to ``crop`` for a random
sized crop.
"""
for attempt in range(10):
area = img.size[0] * img.size[1]
target_area = random.uniform(0.08, 1.0) * area
@@ -368,15 +587,23 @@ class RandomSizedCrop(object):
w, h = h, w
if w <= img.size[0] and h <= img.size[1]:
                x1 = random.randint(0, img.size[0] - w)
                y1 = random.randint(0, img.size[1] - h)
                img = img.crop((x1, y1, x1 + w, y1 + h))
                assert(img.size == (w, h))
                return img.resize((self.size, self.size), self.interpolation)
                i = random.randint(0, img.size[1] - h)
                j = random.randint(0, img.size[0] - w)
                return i, j, h, w

        # Fallback: center-crop a square with side equal to the smaller edge
        w = min(img.size[0], img.size[1])
        i = (img.size[1] - w) // 2
        j = (img.size[0] - w) // 2
        return i, j, w, w
        # Fallback
        scale = Scale(self.size, interpolation=self.interpolation)
        crop = CenterCrop(self.size)
        return crop(scale(img))

    def __call__(self, img):
        """
        Args:
            img (PIL.Image): Image to be cropped and scaled.

        Returns:
            PIL.Image: Randomly cropped and scaled image.
        """
        i, j, h, w = self.get_params(img)
        return scaled_crop(img, i, j, h, w, self.size, self.interpolation)
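The same decomposition works here: draw the parameters once with ``get_params``, then apply them functionally (a sketch, equivalent to calling the transform directly):

    from PIL import Image
    img = Image.new('RGB', (256, 256))
    rsc = RandomSizedCrop(224)
    i, j, h, w = rsc.get_params(img)
    out = scaled_crop(img, i, j, h, w, rsc.size, rsc.interpolation)
    assert out.size == (224, 224)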