from __future__ import division

import torch
import math
import random
from PIL import Image, ImageOps

try:
    import accimage
except ImportError:
    accimage = None

import numpy as np
import numbers
import types
import collections


def _is_pil_image(img):
    if accimage is not None:
        return isinstance(img, (Image.Image, accimage.Image))
    else:
        return isinstance(img, Image.Image)


def _is_tensor_image(img):
    return torch.is_tensor(img) and img.ndimension() == 3


def _is_numpy_image(img):
    return isinstance(img, np.ndarray) and (img.ndim in {2, 3})


def to_tensor(pic):
    """Convert a ``PIL.Image`` or ``numpy.ndarray`` to tensor.

    See ``ToTensor`` for more details.

    Args:
        pic (PIL.Image or numpy.ndarray): Image to be converted to tensor.

    Returns:
        Tensor: Converted image.
    """
    if not (_is_pil_image(pic) or _is_numpy_image(pic)):
        raise TypeError('pic should be PIL Image or ndarray. Got {}'.format(type(pic)))

    if isinstance(pic, np.ndarray):
        # handle numpy array
        img = torch.from_numpy(pic.transpose((2, 0, 1)))
        # backward compatibility
        return img.float().div(255)

    if accimage is not None and isinstance(pic, accimage.Image):
        nppic = np.zeros([pic.channels, pic.height, pic.width], dtype=np.float32)
        pic.copyto(nppic)
        return torch.from_numpy(nppic)

    # handle PIL Image
    if pic.mode == 'I':
        img = torch.from_numpy(np.array(pic, np.int32, copy=False))
    elif pic.mode == 'I;16':
        img = torch.from_numpy(np.array(pic, np.int16, copy=False))
    else:
        img = torch.ByteTensor(torch.ByteStorage.from_buffer(pic.tobytes()))
    # PIL image mode: 1, L, P, I, F, RGB, YCbCr, RGBA, CMYK
    if pic.mode == 'YCbCr':
        nchannel = 3
    elif pic.mode == 'I;16':
        nchannel = 1
    else:
        nchannel = len(pic.mode)
    img = img.view(pic.size[1], pic.size[0], nchannel)
    # put it from HWC to CHW format
    # yikes, this transpose takes 80% of the loading time/CPU
    img = img.transpose(0, 1).transpose(0, 2).contiguous()
    if isinstance(img, torch.ByteTensor):
        return img.float().div(255)
    else:
        return img


def to_pil_image(pic):
    """Convert a tensor or an ndarray to PIL Image.

    See ``ToPILImage`` for more details.

    Args:
        pic (Tensor or numpy.ndarray): Image to be converted to PIL.Image.

    Returns:
        PIL.Image: Image converted to PIL.Image.
    """
    if not (_is_numpy_image(pic) or _is_tensor_image(pic)):
        raise TypeError('pic should be Tensor or ndarray. Got {}.'.format(type(pic)))

    npimg = pic
    mode = None
    if isinstance(pic, torch.FloatTensor):
        pic = pic.mul(255).byte()
    if torch.is_tensor(pic):
        npimg = np.transpose(pic.numpy(), (1, 2, 0))
    assert isinstance(npimg, np.ndarray)
    if npimg.shape[2] == 1:
        npimg = npimg[:, :, 0]

        if npimg.dtype == np.uint8:
            mode = 'L'
        if npimg.dtype == np.int16:
            mode = 'I;16'
        if npimg.dtype == np.int32:
            mode = 'I'
        elif npimg.dtype == np.float32:
            mode = 'F'
    elif npimg.shape[2] == 4:
        if npimg.dtype == np.uint8:
            mode = 'RGBA'
    else:
        if npimg.dtype == np.uint8:
            mode = 'RGB'
    assert mode is not None, '{} is not supported'.format(npimg.dtype)
    return Image.fromarray(npimg, mode=mode)


def normalize(tensor, mean, std):
    """Normalize a tensor image with mean and standard deviation.

    See ``Normalize`` for more details.

    Args:
        tensor (Tensor): Tensor image of size (C, H, W) to be normalized.
        mean (sequence): Sequence of means for R, G, B channels respectively.
        std (sequence): Sequence of standard deviations for R, G, B channels
            respectively.

    Returns:
        Tensor: Normalized image.
    """
    if not _is_tensor_image(tensor):
        raise TypeError('tensor is not a torch image.')
    # TODO: make efficient
    for t, m, s in zip(tensor, mean, std):
        t.sub_(m).div_(s)
    return tensor


def scale(img, size, interpolation=Image.BILINEAR):
    """Rescale the input PIL.Image to the given size.

    Args:
        img (PIL.Image): Image to be scaled.
        size (sequence or int): Desired output size. If size is a sequence like
            (h, w), output size will be matched to this. If size is an int,
            smaller edge of the image will be matched to this number.
            i.e., if height > width, then image will be rescaled to
            (size * height / width, size)
        interpolation (int, optional): Desired interpolation. Default is
            ``PIL.Image.BILINEAR``

    Returns:
        PIL.Image: Rescaled image.
    """
    if not _is_pil_image(img):
        raise TypeError('img should be PIL Image. Got {}'.format(type(img)))
    if not (isinstance(size, int) or (isinstance(size, collections.Iterable) and len(size) == 2)):
        raise TypeError('Got inappropriate size arg: {}'.format(size))

    if isinstance(size, int):
        w, h = img.size
        if (w <= h and w == size) or (h <= w and h == size):
            return img
        if w < h:
            ow = size
            oh = int(size * h / w)
            return img.resize((ow, oh), interpolation)
        else:
            oh = size
            ow = int(size * w / h)
            return img.resize((ow, oh), interpolation)
    else:
        return img.resize(size[::-1], interpolation)


def pad(img, padding, fill=0):
    """Pad the given PIL.Image on all sides with the given "pad" value.

    Args:
        img (PIL.Image): Image to be padded.
        padding (int or tuple): Padding on each border. If a single int is provided
            this is used to pad all borders. If tuple of length 2 is provided this
            is the padding on left/right and top/bottom respectively. If a tuple of
            length 4 is provided this is the padding for the left, top, right and
            bottom borders respectively.
        fill: Pixel fill value. Default is 0. If a tuple of length 3, it is used to
            fill R, G, B channels respectively.

    Returns:
        PIL.Image: Padded image.
    """
    if not _is_pil_image(img):
        raise TypeError('img should be PIL Image. Got {}'.format(type(img)))

    if not isinstance(padding, (numbers.Number, tuple)):
        raise TypeError('Got inappropriate padding arg')
    if not isinstance(fill, (numbers.Number, str, tuple)):
        raise TypeError('Got inappropriate fill arg')

    if isinstance(padding, collections.Sequence) and len(padding) not in [2, 4]:
        raise ValueError("Padding must be an int or a 2, or 4 element tuple, not a " +
                         "{} element tuple".format(len(padding)))

    return ImageOps.expand(img, border=padding, fill=fill)


def crop(img, i, j, h, w):
    """Crop the given PIL.Image.

    Args:
        img (PIL.Image): Image to be cropped.
        i: Upper pixel coordinate.
        j: Left pixel coordinate.
        h: Height of the cropped image.
        w: Width of the cropped image.

    Returns:
        PIL.Image: Cropped image.
    """
    if not _is_pil_image(img):
        raise TypeError('img should be PIL Image. Got {}'.format(type(img)))

    return img.crop((j, i, j + w, i + h))


def scaled_crop(img, i, j, h, w, size, interpolation=Image.BILINEAR):
    """Crop the given PIL.Image and scale it to desired size.

    Notably used in RandomSizedCrop.

    Args:
        img (PIL.Image): Image to be cropped.
        i: Upper pixel coordinate.
        j: Left pixel coordinate.
        h: Height of the cropped image.
        w: Width of the cropped image.
        size (sequence or int): Desired output size. Same semantics as ``scale``.
        interpolation (int, optional): Desired interpolation. Default is
            ``PIL.Image.BILINEAR``.

    Returns:
        PIL.Image: Cropped image.
    """
    assert _is_pil_image(img), 'img should be PIL Image'
    img = crop(img, i, j, h, w)
    img = scale(img, size, interpolation)
    return img


def hflip(img):
    """Horizontally flip the given PIL.Image.

    Args:
        img (PIL.Image): Image to be flipped.

    Returns:
        PIL.Image: Horizontally flipped image.
    """
    if not _is_pil_image(img):
        raise TypeError('img should be PIL Image. Got {}'.format(type(img)))

    return img.transpose(Image.FLIP_LEFT_RIGHT)


def vflip(img):
    """Vertically flip the given PIL.Image.

    Args:
        img (PIL.Image): Image to be flipped.

    Returns:
        PIL.Image: Vertically flipped image.
    """
    if not _is_pil_image(img):
        raise TypeError('img should be PIL Image. Got {}'.format(type(img)))

    return img.transpose(Image.FLIP_TOP_BOTTOM)


def five_crop(img, size):
    """Crop the given PIL.Image into four corners and the central crop.

    Note: this transform returns a tuple of images and there may be a mismatch
    in the number of inputs and targets your `Dataset` returns.

    Args:
        img (PIL.Image): Image to be cropped.
        size (sequence or int): Desired output size of the crop. If size is an
            int instead of sequence like (h, w), a square crop (size, size) is
            made.

    Returns:
        tuple: tuple (tl, tr, bl, br, center) corresponding top left,
            top right, bottom left, bottom right and center crop.
    """
    if isinstance(size, numbers.Number):
        size = (int(size), int(size))
    else:
        assert len(size) == 2, "Please provide only two dimensions (h, w) for size."

    w, h = img.size
    crop_h, crop_w = size
    if crop_w > w or crop_h > h:
        raise ValueError("Requested crop size {} is bigger than input size {}".format(size, (h, w)))
    tl = img.crop((0, 0, crop_w, crop_h))
    tr = img.crop((w - crop_w, 0, w, crop_h))
    bl = img.crop((0, h - crop_h, crop_w, h))
    br = img.crop((w - crop_w, h - crop_h, w, h))
    center = CenterCrop((crop_h, crop_w))(img)
    return (tl, tr, bl, br, center)


def ten_crop(img, size, vertical_flip=False):
    """Crop the given PIL.Image into four corners and the central crop plus the
    flipped version of these (horizontal flipping is used by default).

    Note: this transform returns a tuple of images and there may be a mismatch
    in the number of inputs and targets your `Dataset` returns.

    Args:
        img (PIL.Image): Image to be cropped.
        size (sequence or int): Desired output size of the crop. If size is an
            int instead of sequence like (h, w), a square crop (size, size) is
            made.
        vertical_flip (bool): Use vertical flipping instead of horizontal.

    Returns:
        tuple: tuple (tl, tr, bl, br, center, tl_flip, tr_flip, bl_flip,
            br_flip, center_flip) corresponding top left, top right, bottom
            left, bottom right and center crop, and the same for the flipped
            image.
    """
    if isinstance(size, numbers.Number):
        size = (int(size), int(size))
    else:
        assert len(size) == 2, "Please provide only two dimensions (h, w) for size."

    first_five = five_crop(img, size)

    if vertical_flip:
        img = vflip(img)
    else:
        img = hflip(img)

    second_five = five_crop(img, size)
    return first_five + second_five


class Compose(object):
    """Composes several transforms together.

    Args:
        transforms (list of ``Transform`` objects): list of transforms to compose.

    Example:
        >>> transforms.Compose([
        >>>     transforms.CenterCrop(10),
        >>>     transforms.ToTensor(),
        >>> ])
    """

    def __init__(self, transforms):
        self.transforms = transforms

    def __call__(self, img):
        for t in self.transforms:
            img = t(img)
        return img


class ToTensor(object):
    """Convert a ``PIL.Image`` or ``numpy.ndarray`` to tensor.

    Converts a PIL.Image or numpy.ndarray (H x W x C) in the range
    [0, 255] to a torch.FloatTensor of shape (C x H x W) in the range [0.0, 1.0].
    """

    def __call__(self, pic):
        """
        Args:
            pic (PIL.Image or numpy.ndarray): Image to be converted to tensor.

        Returns:
            Tensor: Converted image.
        """
        return to_tensor(pic)
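# Illustrative round trip through the functional API (a minimal sketch; the
# 32x32 RGB image below is synthetic, not something this module provides):
#
#     >>> img = Image.new('RGB', (32, 32))
#     >>> t = to_tensor(img)                  # FloatTensor of shape (3, 32, 32) in [0.0, 1.0]
#     >>> t = normalize(t, mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])  # in-place, now in [-1.0, 1.0]
#     >>> to_pil_image(t.mul(0.5).add(0.5))   # undo the normalization before converting back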
class ToPILImage(object):
    """Convert a tensor or an ndarray to PIL Image.

    Converts a torch.*Tensor of shape C x H x W or a numpy ndarray of shape
    H x W x C to a PIL.Image while preserving the value range.
    """

    def __call__(self, pic):
        """
        Args:
            pic (Tensor or numpy.ndarray): Image to be converted to PIL.Image.

        Returns:
            PIL.Image: Image converted to PIL.Image.
        """
        return to_pil_image(pic)


class Normalize(object):
    """Normalize a tensor image with mean and standard deviation.

    Given mean: (R, G, B) and std: (R, G, B), will normalize each channel of the
    torch.*Tensor, i.e. channel = (channel - mean) / std

    Args:
        mean (sequence): Sequence of means for R, G, B channels respectively.
        std (sequence): Sequence of standard deviations for R, G, B channels
            respectively.
    """

    def __init__(self, mean, std):
        self.mean = mean
        self.std = std

    def __call__(self, tensor):
        """
        Args:
            tensor (Tensor): Tensor image of size (C, H, W) to be normalized.

        Returns:
            Tensor: Normalized image.
        """
        return normalize(tensor, self.mean, self.std)


class Scale(object):
    """Rescale the input PIL.Image to the given size.

    Args:
        size (sequence or int): Desired output size. If size is a sequence like
            (h, w), output size will be matched to this. If size is an int,
            smaller edge of the image will be matched to this number.
            i.e., if height > width, then image will be rescaled to
            (size * height / width, size)
        interpolation (int, optional): Desired interpolation. Default is
            ``PIL.Image.BILINEAR``
    """

    def __init__(self, size, interpolation=Image.BILINEAR):
        assert isinstance(size, int) or (isinstance(size, collections.Iterable) and len(size) == 2)
        self.size = size
        self.interpolation = interpolation

    def __call__(self, img):
        """
        Args:
            img (PIL.Image): Image to be scaled.

        Returns:
            PIL.Image: Rescaled image.
        """
        return scale(img, self.size, self.interpolation)


class CenterCrop(object):
    """Crops the given PIL.Image at the center.

    Args:
        size (sequence or int): Desired output size of the crop. If size is an
            int instead of sequence like (h, w), a square crop (size, size) is
            made.
    """

    def __init__(self, size):
        if isinstance(size, numbers.Number):
            self.size = (int(size), int(size))
        else:
            self.size = size

    @staticmethod
    def get_params(img, output_size):
        """Get parameters for ``crop`` for center crop.

        Args:
            img (PIL.Image): Image to be cropped.
            output_size (tuple): Expected output size of the crop.

        Returns:
            tuple: params (i, j, h, w) to be passed to ``crop`` for center crop.
        """
        w, h = img.size
        th, tw = output_size
        i = int(round((h - th) / 2.))
        j = int(round((w - tw) / 2.))
        return i, j, th, tw

    def __call__(self, img):
        """
        Args:
            img (PIL.Image): Image to be cropped.

        Returns:
            PIL.Image: Cropped image.
        """
        i, j, h, w = self.get_params(img, self.size)
        return crop(img, i, j, h, w)


class Pad(object):
    """Pad the given PIL.Image on all sides with the given "pad" value.

    Args:
        padding (int or tuple): Padding on each border. If a single int is provided
            this is used to pad all borders. If tuple of length 2 is provided this
            is the padding on left/right and top/bottom respectively. If a tuple of
            length 4 is provided this is the padding for the left, top, right and
            bottom borders respectively.
        fill: Pixel fill value. Default is 0. If a tuple of length 3, it is used to
            fill R, G, B channels respectively.
    """

    def __init__(self, padding, fill=0):
        assert isinstance(padding, (numbers.Number, tuple))
        assert isinstance(fill, (numbers.Number, str, tuple))
        if isinstance(padding, collections.Sequence) and len(padding) not in [2, 4]:
            raise ValueError("Padding must be an int or a 2, or 4 element tuple, not a " +
                             "{} element tuple".format(len(padding)))

        self.padding = padding
        self.fill = fill

    def __call__(self, img):
        """
        Args:
            img (PIL.Image): Image to be padded.

        Returns:
            PIL.Image: Padded image.
        """
        return pad(img, self.padding, self.fill)


class Lambda(object):
    """Apply a user-defined lambda as a transform.

    Args:
        lambd (function): Lambda/function to be used for transform.
    """

    def __init__(self, lambd):
        assert isinstance(lambd, types.LambdaType)
        self.lambd = lambd

    def __call__(self, img):
        return self.lambd(img)


class RandomCrop(object):
    """Crop the given PIL.Image at a random location.

    Args:
        size (sequence or int): Desired output size of the crop. If size is an
            int instead of sequence like (h, w), a square crop (size, size) is
            made.
        padding (int or sequence, optional): Optional padding on each border
            of the image. Default is 0, i.e. no padding. If a sequence of
            length 4 is provided, it is used to pad left, top, right, bottom
            borders respectively.
    """

    def __init__(self, size, padding=0):
        if isinstance(size, numbers.Number):
            self.size = (int(size), int(size))
        else:
            self.size = size
        self.padding = padding

    @staticmethod
    def get_params(img, output_size):
        """Get parameters for ``crop`` for a random crop.

        Args:
            img (PIL.Image): Image to be cropped.
            output_size (tuple): Expected output size of the crop.

        Returns:
            tuple: params (i, j, h, w) to be passed to ``crop`` for random crop.
        """
        w, h = img.size
        th, tw = output_size
        if w == tw and h == th:
            return 0, 0, h, w

        i = random.randint(0, h - th)
        j = random.randint(0, w - tw)
        return i, j, th, tw

    def __call__(self, img):
        """
        Args:
            img (PIL.Image): Image to be cropped.

        Returns:
            PIL.Image: Cropped image.
        """
        if self.padding > 0:
            img = pad(img, self.padding)

        i, j, h, w = self.get_params(img, self.size)

        return crop(img, i, j, h, w)


class RandomHorizontalFlip(object):
    """Horizontally flip the given PIL.Image randomly with a probability of 0.5."""

    def __call__(self, img):
        """
        Args:
            img (PIL.Image): Image to be flipped.

        Returns:
            PIL.Image: Randomly flipped image.
        """
        if random.random() < 0.5:
            return hflip(img)
        return img


class RandomVerticalFlip(object):
    """Vertically flip the given PIL.Image randomly with a probability of 0.5."""

    def __call__(self, img):
        """
        Args:
            img (PIL.Image): Image to be flipped.

        Returns:
            PIL.Image: Randomly flipped image.
        """
        if random.random() < 0.5:
            return vflip(img)
        return img


class RandomSizedCrop(object):
    """Crop the given PIL.Image to random size and aspect ratio.

    A crop of random size (0.08 to 1.0 of the original area) and a random
    aspect ratio (3/4 to 4/3 of the original aspect ratio) is made. This crop
    is finally resized to given size. This is popularly used to train the
    Inception networks.

    Args:
        size: expected output size of each edge
        interpolation: Default: PIL.Image.BILINEAR
    """

    def __init__(self, size, interpolation=Image.BILINEAR):
        self.size = (size, size)
        self.interpolation = interpolation

    @staticmethod
    def get_params(img):
        """Get parameters for ``crop`` for a random sized crop.

        Args:
            img (PIL.Image): Image to be cropped.

        Returns:
            tuple: params (i, j, h, w) to be passed to ``crop`` for a random
                sized crop.
        """
        for attempt in range(10):
            area = img.size[0] * img.size[1]
            target_area = random.uniform(0.08, 1.0) * area
            aspect_ratio = random.uniform(3. / 4, 4. / 3)

            w = int(round(math.sqrt(target_area * aspect_ratio)))
            h = int(round(math.sqrt(target_area / aspect_ratio)))

            if random.random() < 0.5:
                w, h = h, w

            if w <= img.size[0] and h <= img.size[1]:
                i = random.randint(0, img.size[1] - h)
                j = random.randint(0, img.size[0] - w)
                return i, j, h, w

        # Fallback: center crop of the largest possible square
        w = min(img.size[0], img.size[1])
        i = (img.size[1] - w) // 2
        j = (img.size[0] - w) // 2
        return i, j, w, w

    def __call__(self, img):
        """
        Args:
            img (PIL.Image): Image to be cropped.

        Returns:
            PIL.Image: Randomly cropped and scaled image.
        """
        i, j, h, w = self.get_params(img)
        return scaled_crop(img, i, j, h, w, self.size, self.interpolation)


class FiveCrop(object):
    """Crop the given PIL.Image into four corners and the central crop.

    Note: this transform returns a tuple of images and there may be a mismatch
    in the number of inputs and targets your `Dataset` returns.

    Args:
        size (sequence or int): Desired output size of the crop. If size is an
            int instead of sequence like (h, w), a square crop (size, size) is
            made.
    """

    def __init__(self, size):
        if isinstance(size, numbers.Number):
            self.size = (int(size), int(size))
        else:
            assert len(size) == 2, "Please provide only two dimensions (h, w) for size."
            self.size = size

    def __call__(self, img):
        return five_crop(img, self.size)


class TenCrop(object):
    """Crop the given PIL.Image into four corners and the central crop plus the
    flipped version of these (horizontal flipping is used by default).

    Note: this transform returns a tuple of images and there may be a mismatch
    in the number of inputs and targets your `Dataset` returns.

    Args:
        size (sequence or int): Desired output size of the crop. If size is an
            int instead of sequence like (h, w), a square crop (size, size) is
            made.
        vertical_flip (bool): Use vertical flipping instead of horizontal.
    """

    def __init__(self, size, vertical_flip=False):
        if isinstance(size, numbers.Number):
            self.size = (int(size), int(size))
        else:
            assert len(size) == 2, "Please provide only two dimensions (h, w) for size."
            self.size = size
        self.vertical_flip = vertical_flip

    def __call__(self, img):
        return ten_crop(img, self.size, self.vertical_flip)
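# Typical composed pipeline (a sketch; the mean/std values shown are the widely
# used ImageNet statistics, an assumption rather than something this module
# prescribes):
#
#     >>> preprocess = Compose([
#     >>>     Scale(256),
#     >>>     CenterCrop(224),
#     >>>     ToTensor(),
#     >>>     Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
#     >>> ])
#     >>> preprocess(Image.new('RGB', (512, 384))).size()  # torch.Size([3, 224, 224])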