from __future__ import division import torch import math import random from PIL import Image, ImageOps, ImageEnhance try: import accimage except ImportError: accimage = None import numpy as np import numbers import types import collections import warnings from . import functional as F __all__ = ["Compose", "ToTensor", "ToPILImage", "Normalize", "Resize", "Scale", "CenterCrop", "Pad", "Lambda", "RandomCrop", "RandomHorizontalFlip", "RandomVerticalFlip", "RandomResizedCrop", "RandomSizedCrop", "FiveCrop", "TenCrop", "LinearTransformation", "ColorJitter"] class Compose(object): """Composes several transforms together. Args: transforms (list of ``Transform`` objects): list of transforms to compose. Example: >>> transforms.Compose([ >>> transforms.CenterCrop(10), >>> transforms.ToTensor(), >>> ]) """ def __init__(self, transforms): self.transforms = transforms def __call__(self, img): for t in self.transforms: img = t(img) return img class ToTensor(object): """Convert a ``PIL Image`` or ``numpy.ndarray`` to tensor. Converts a PIL Image or numpy.ndarray (H x W x C) in the range [0, 255] to a torch.FloatTensor of shape (C x H x W) in the range [0.0, 1.0]. """ def __call__(self, pic): """ Args: pic (PIL Image or numpy.ndarray): Image to be converted to tensor. Returns: Tensor: Converted image. """ return F.to_tensor(pic) class ToPILImage(object): """Convert a tensor or an ndarray to PIL Image. Converts a torch.*Tensor of shape C x H x W or a numpy ndarray of shape H x W x C to a PIL Image while preserving the value range. Args: mode (`PIL.Image mode`_): color space and pixel depth of input data (optional). If ``mode`` is ``None`` (default) there are some assumptions made about the input data: 1. If the input has 3 channels, the ``mode`` is assumed to be ``RGB``. 2. If the input has 4 channels, the ``mode`` is assumed to be ``RGBA``. 3. If the input has 1 channel, the ``mode`` is determined by the data type (i,e, ``int``, ``float``, ``short``). .. _PIL.Image mode: http://pillow.readthedocs.io/en/3.4.x/handbook/concepts.html#modes """ def __init__(self, mode=None): self.mode = mode def __call__(self, pic): """ Args: pic (Tensor or numpy.ndarray): Image to be converted to PIL Image. Returns: PIL Image: Image converted to PIL Image. """ return F.to_pil_image(pic, self.mode) class Normalize(object): """Normalize an tensor image with mean and standard deviation. Given mean: ``(M1,...,Mn)`` and std: ``(M1,..,Mn)`` for ``n`` channels, this transform will normalize each channel of the input ``torch.*Tensor`` i.e. ``input[channel] = (input[channel] - mean[channel]) / std[channel]`` Args: mean (sequence): Sequence of means for each channel. std (sequence): Sequence of standard deviations for each channel. """ def __init__(self, mean, std): self.mean = mean self.std = std def __call__(self, tensor): """ Args: tensor (Tensor): Tensor image of size (C, H, W) to be normalized. Returns: Tensor: Normalized Tensor image. """ return F.normalize(tensor, self.mean, self.std) class Resize(object): """Resize the input PIL Image to the given size. Args: size (sequence or int): Desired output size. If size is a sequence like (h, w), output size will be matched to this. If size is an int, smaller edge of the image will be matched to this number. i.e, if height > width, then image will be rescaled to (size * height / width, size) interpolation (int, optional): Desired interpolation. Default is ``PIL.Image.BILINEAR`` """ def __init__(self, size, interpolation=Image.BILINEAR): assert isinstance(size, int) or (isinstance(size, collections.Iterable) and len(size) == 2) self.size = size self.interpolation = interpolation def __call__(self, img): """ Args: img (PIL Image): Image to be scaled. Returns: PIL Image: Rescaled image. """ return F.resize(img, self.size, self.interpolation) class Scale(Resize): """ Note: This transform is deprecated in favor of Resize. """ def __init__(self, *args, **kwargs): warnings.warn("The use of the transforms.Scale transform is deprecated, " + "please use transforms.Resize instead.") super(Scale, self).__init__(*args, **kwargs) class CenterCrop(object): """Crops the given PIL Image at the center. Args: size (sequence or int): Desired output size of the crop. If size is an int instead of sequence like (h, w), a square crop (size, size) is made. """ def __init__(self, size): if isinstance(size, numbers.Number): self.size = (int(size), int(size)) else: self.size = size def __call__(self, img): """ Args: img (PIL Image): Image to be cropped. Returns: PIL Image: Cropped image. """ return F.center_crop(img, self.size) class Pad(object): """Pad the given PIL Image on all sides with the given "pad" value. Args: padding (int or tuple): Padding on each border. If a single int is provided this is used to pad all borders. If tuple of length 2 is provided this is the padding on left/right and top/bottom respectively. If a tuple of length 4 is provided this is the padding for the left, top, right and bottom borders respectively. fill: Pixel fill value. Default is 0. If a tuple of length 3, it is used to fill R, G, B channels respectively. """ def __init__(self, padding, fill=0): assert isinstance(padding, (numbers.Number, tuple)) assert isinstance(fill, (numbers.Number, str, tuple)) if isinstance(padding, collections.Sequence) and len(padding) not in [2, 4]: raise ValueError("Padding must be an int or a 2, or 4 element tuple, not a " + "{} element tuple".format(len(padding))) self.padding = padding self.fill = fill def __call__(self, img): """ Args: img (PIL Image): Image to be padded. Returns: PIL Image: Padded image. """ return F.pad(img, self.padding, self.fill) class Lambda(object): """Apply a user-defined lambda as a transform. Args: lambd (function): Lambda/function to be used for transform. """ def __init__(self, lambd): assert isinstance(lambd, types.LambdaType) self.lambd = lambd def __call__(self, img): return self.lambd(img) class RandomCrop(object): """Crop the given PIL Image at a random location. Args: size (sequence or int): Desired output size of the crop. If size is an int instead of sequence like (h, w), a square crop (size, size) is made. padding (int or sequence, optional): Optional padding on each border of the image. Default is 0, i.e no padding. If a sequence of length 4 is provided, it is used to pad left, top, right, bottom borders respectively. """ def __init__(self, size, padding=0): if isinstance(size, numbers.Number): self.size = (int(size), int(size)) else: self.size = size self.padding = padding @staticmethod def get_params(img, output_size): """Get parameters for ``crop`` for a random crop. Args: img (PIL Image): Image to be cropped. output_size (tuple): Expected output size of the crop. Returns: tuple: params (i, j, h, w) to be passed to ``crop`` for random crop. """ w, h = img.size th, tw = output_size if w == tw and h == th: return 0, 0, h, w i = random.randint(0, h - th) j = random.randint(0, w - tw) return i, j, th, tw def __call__(self, img): """ Args: img (PIL Image): Image to be cropped. Returns: PIL Image: Cropped image. """ if self.padding > 0: img = F.pad(img, self.padding) i, j, h, w = self.get_params(img, self.size) return F.crop(img, i, j, h, w) class RandomHorizontalFlip(object): """Horizontally flip the given PIL Image randomly with a probability of 0.5.""" def __call__(self, img): """ Args: img (PIL Image): Image to be flipped. Returns: PIL Image: Randomly flipped image. """ if random.random() < 0.5: return F.hflip(img) return img class RandomVerticalFlip(object): """Vertically flip the given PIL Image randomly with a probability of 0.5.""" def __call__(self, img): """ Args: img (PIL Image): Image to be flipped. Returns: PIL Image: Randomly flipped image. """ if random.random() < 0.5: return F.vflip(img) return img class RandomResizedCrop(object): """Crop the given PIL Image to random size and aspect ratio. A crop of random size of (0.08 to 1.0) of the original size and a random aspect ratio of 3/4 to 4/3 of the original aspect ratio is made. This crop is finally resized to given size. This is popularly used to train the Inception networks. Args: size: expected output size of each edge interpolation: Default: PIL.Image.BILINEAR """ def __init__(self, size, interpolation=Image.BILINEAR): self.size = (size, size) self.interpolation = interpolation @staticmethod def get_params(img): """Get parameters for ``crop`` for a random sized crop. Args: img (PIL Image): Image to be cropped. Returns: tuple: params (i, j, h, w) to be passed to ``crop`` for a random sized crop. """ for attempt in range(10): area = img.size[0] * img.size[1] target_area = random.uniform(0.08, 1.0) * area aspect_ratio = random.uniform(3. / 4, 4. / 3) w = int(round(math.sqrt(target_area * aspect_ratio))) h = int(round(math.sqrt(target_area / aspect_ratio))) if random.random() < 0.5: w, h = h, w if w <= img.size[0] and h <= img.size[1]: i = random.randint(0, img.size[1] - h) j = random.randint(0, img.size[0] - w) return i, j, h, w # Fallback w = min(img.size[0], img.size[1]) i = (img.size[1] - w) // 2 j = (img.size[0] - w) // 2 return i, j, w, w def __call__(self, img): """ Args: img (PIL Image): Image to be flipped. Returns: PIL Image: Randomly cropped and resize image. """ i, j, h, w = self.get_params(img) return F.resized_crop(img, i, j, h, w, self.size, self.interpolation) class RandomSizedCrop(RandomResizedCrop): """ Note: This transform is deprecated in favor of RandomResizedCrop. """ def __init__(self, *args, **kwargs): warnings.warn("The use of the transforms.RandomSizedCrop transform is deprecated, " + "please use transforms.RandomResizedCrop instead.") super(RandomSizedCrop, self).__init__(*args, **kwargs) class FiveCrop(object): """Crop the given PIL Image into four corners and the central crop .. Note:: This transform returns a tuple of images and there may be a mismatch in the number of inputs and targets your Dataset returns. See below for an example of how to deal with this. Args: size (sequence or int): Desired output size of the crop. If size is an ``int`` instead of sequence like (h, w), a square crop of size (size, size) is made. Example: >>> transform = Compose([ >>> FiveCrop(size), # this is a list of PIL Images >>> Lambda(lambda crops: torch.stack([ToTensor()(crop) for crop in crops])) # returns a 4D tensor >>> ]) >>> #In your test loop you can do the following: >>> input, target = batch # input is a 5d tensor, target is 2d >>> bs, ncrops, c, h, w = input.size() >>> result = model(input.view(-1, c, h, w)) # fuse batch size and ncrops >>> result_avg = result.view(bs, ncrops, -1).mean(1) # avg over crops """ def __init__(self, size): self.size = size if isinstance(size, numbers.Number): self.size = (int(size), int(size)) else: assert len(size) == 2, "Please provide only two dimensions (h, w) for size." self.size = size def __call__(self, img): return F.five_crop(img, self.size) class TenCrop(object): """Crop the given PIL Image into four corners and the central crop plus the flipped version of these (horizontal flipping is used by default) .. Note:: This transform returns a tuple of images and there may be a mismatch in the number of inputs and targets your Dataset returns. See below for an example of how to deal with this. Args: size (sequence or int): Desired output size of the crop. If size is an int instead of sequence like (h, w), a square crop (size, size) is made. vertical_flip(bool): Use vertical flipping instead of horizontal Example: >>> transform = Compose([ >>> TenCrop(size), # this is a list of PIL Images >>> Lambda(lambda crops: torch.stack([ToTensor()(crop) for crop in crops])) # returns a 4D tensor >>> ]) >>> #In your test loop you can do the following: >>> input, target = batch # input is a 5d tensor, target is 2d >>> bs, ncrops, c, h, w = input.size() >>> result = model(input.view(-1, c, h, w)) # fuse batch size and ncrops >>> result_avg = result.view(bs, ncrops, -1).mean(1) # avg over crops """ def __init__(self, size, vertical_flip=False): self.size = size if isinstance(size, numbers.Number): self.size = (int(size), int(size)) else: assert len(size) == 2, "Please provide only two dimensions (h, w) for size." self.size = size self.vertical_flip = vertical_flip def __call__(self, img): return F.ten_crop(img, self.size, self.vertical_flip) class LinearTransformation(object): """Transform a tensor image with a square transformation matrix computed offline. Given transformation_matrix, will flatten the torch.*Tensor, compute the dot product with the transformation matrix and reshape the tensor to its original shape. Applications: - whitening: zero-center the data, compute the data covariance matrix [D x D] with np.dot(X.T, X), perform SVD on this matrix and pass it as transformation_matrix. Args: transformation_matrix (Tensor): tensor [D x D], D = C x H x W """ def __init__(self, transformation_matrix): if transformation_matrix.size(0) != transformation_matrix.size(1): raise ValueError("transformation_matrix should be square. Got " + "[{} x {}] rectangular matrix.".format(*transformation_matrix.size())) self.transformation_matrix = transformation_matrix def __call__(self, tensor): """ Args: tensor (Tensor): Tensor image of size (C, H, W) to be whitened. Returns: Tensor: Transformed image. """ if tensor.size(0) * tensor.size(1) * tensor.size(2) != self.transformation_matrix.size(0): raise ValueError("tensor and transformation matrix have incompatible shape." + "[{} x {} x {}] != ".format(*tensor.size()) + "{}".format(self.transformation_matrix.size(0))) flat_tensor = tensor.view(1, -1) transformed_tensor = torch.mm(flat_tensor, self.transformation_matrix) tensor = transformed_tensor.view(tensor.size()) return tensor class ColorJitter(object): """Randomly change the brightness, contrast and saturation of an image. Args: brightness (float): How much to jitter brightness. brightness_factor is chosen uniformly from [max(0, 1 - brightness), 1 + brightness]. contrast (float): How much to jitter contrast. contrast_factor is chosen uniformly from [max(0, 1 - contrast), 1 + contrast]. saturation (float): How much to jitter saturation. saturation_factor is chosen uniformly from [max(0, 1 - saturation), 1 + saturation]. hue(float): How much to jitter hue. hue_factor is chosen uniformly from [-hue, hue]. Should be >=0 and <= 0.5. """ def __init__(self, brightness=0, contrast=0, saturation=0, hue=0): self.brightness = brightness self.contrast = contrast self.saturation = saturation self.hue = hue @staticmethod def get_params(brightness, contrast, saturation, hue): """Get a randomized transform to be applied on image. Arguments are same as that of __init__. Returns: Transform which randomly adjusts brightness, contrast and saturation in a random order. """ transforms = [] if brightness > 0: brightness_factor = np.random.uniform(max(0, 1 - brightness), 1 + brightness) transforms.append(Lambda(lambda img: F.adjust_brightness(img, brightness_factor))) if contrast > 0: contrast_factor = np.random.uniform(max(0, 1 - contrast), 1 + contrast) transforms.append(Lambda(lambda img: F.adjust_contrast(img, contrast_factor))) if saturation > 0: saturation_factor = np.random.uniform(max(0, 1 - saturation), 1 + saturation) transforms.append(Lambda(lambda img: F.adjust_saturation(img, saturation_factor))) if hue > 0: hue_factor = np.random.uniform(-hue, hue) transforms.append(Lambda(lambda img: F.adjust_hue(img, hue_factor))) np.random.shuffle(transforms) transform = Compose(transforms) return transform def __call__(self, img): """ Args: img (PIL Image): Input image. Returns: PIL Image: Color jittered image. """ transform = self.get_params(self.brightness, self.contrast, self.saturation, self.hue) return transform(img)