import torch
from torch import Tensor
from torch.jit.annotations import List, BroadcastingList2


def _is_tensor_a_torch_image(input):
    """Return True when ``input`` has at least two dimensions (H and W)."""
    return input.ndim >= 2


def vflip(img):
    # type: (Tensor) -> Tensor
    """Vertically flip the given Image Tensor.

    Args:
        img (Tensor): Image Tensor to be flipped in the form [C, H, W].

    Returns:
        Tensor: Vertically flipped image Tensor.

    Raises:
        TypeError: If ``img`` has fewer than two dimensions.
    """
    if not _is_tensor_a_torch_image(img):
        raise TypeError('tensor is not a torch image.')

    # Height is the second-to-last dimension.
    return img.flip(-2)


def hflip(img):
    # type: (Tensor) -> Tensor
    """Horizontally flip the given Image Tensor.

    Args:
        img (Tensor): Image Tensor to be flipped in the form [C, H, W].

    Returns:
        Tensor: Horizontally flipped image Tensor.

    Raises:
        TypeError: If ``img`` has fewer than two dimensions.
    """
    if not _is_tensor_a_torch_image(img):
        raise TypeError('tensor is not a torch image.')

    # Width is the last dimension.
    return img.flip(-1)


def crop(img, top, left, height, width):
    # type: (Tensor, int, int, int, int) -> Tensor
    """Crop the given Image Tensor.

    Args:
        img (Tensor): Image to be cropped in the form [C, H, W]. (0,0) denotes
            the top left corner of the image.
        top (int): Vertical component of the top left corner of the crop box.
        left (int): Horizontal component of the top left corner of the crop box.
        height (int): Height of the crop box.
        width (int): Width of the crop box.

    Returns:
        Tensor: Cropped image.

    Raises:
        TypeError: If ``img`` has fewer than two dimensions.
    """
    if not _is_tensor_a_torch_image(img):
        raise TypeError('tensor is not a torch image.')

    return img[..., top:top + height, left:left + width]


def rgb_to_grayscale(img):
    # type: (Tensor) -> Tensor
    """Convert the given RGB Image Tensor to Grayscale.

    For RGB to Grayscale conversion, ITU-R 601-2 luma transform is performed
    which is L = R * 0.2989 + G * 0.5870 + B * 0.1140

    Args:
        img (Tensor): Image to be converted to Grayscale in the form [C, H, W].

    Returns:
        Tensor: Grayscale image (the channel dimension is dropped).

    Raises:
        TypeError: If the first dimension of ``img`` is not of size 3.
    """
    if img.shape[0] != 3:
        raise TypeError('Input Image does not contain 3 Channels')

    # The weighted sum is cast back to the dtype of the input.
    return (0.2989 * img[0] + 0.5870 * img[1] + 0.1140 * img[2]).to(img.dtype)


def adjust_brightness(img, brightness_factor):
    # type: (Tensor, float) -> Tensor
    """Adjust brightness of an RGB image.

    Args:
        img (Tensor): Image to be adjusted.
        brightness_factor (float): How much to adjust the brightness. Can be
            any non negative number. 0 gives a black image, 1 gives the
            original image while 2 increases the brightness by a factor of 2.

    Returns:
        Tensor: Brightness adjusted image.

    Raises:
        TypeError: If ``img`` has fewer than two dimensions.
    """
    if not _is_tensor_a_torch_image(img):
        raise TypeError('tensor is not a torch image.')

    # Blending with an all-black image scales the pixel values.
    return _blend(img, torch.zeros_like(img), brightness_factor)


def adjust_contrast(img, contrast_factor):
    # type: (Tensor, float) -> Tensor
    """Adjust contrast of an RGB image.

    Args:
        img (Tensor): Image to be adjusted.
        contrast_factor (float): How much to adjust the contrast. Can be any
            non negative number. 0 gives a solid gray image, 1 gives the
            original image while 2 increases the contrast by a factor of 2.

    Returns:
        Tensor: Contrast adjusted image.

    Raises:
        TypeError: If ``img`` has fewer than two dimensions, or does not have
            3 channels.
    """
    if not _is_tensor_a_torch_image(img):
        raise TypeError('tensor is not a torch image.')

    # Blend against the mean gray level of the whole image.
    mean = torch.mean(rgb_to_grayscale(img).to(torch.float))

    return _blend(img, mean, contrast_factor)


def adjust_hue(img, hue_factor):
    # type: (Tensor, float) -> Tensor
    """Adjust hue of an image.

    The image hue is adjusted by converting the image to HSV and cyclically
    shifting the intensities in the hue channel (H). The image is then
    converted back to original image mode.

    `hue_factor` is the amount of shift in H channel and must be in the
    interval `[-0.5, 0.5]`.

    See `Hue`_ for more details.

    .. _Hue: https://en.wikipedia.org/wiki/Hue

    Args:
        img (Tensor): Image to be adjusted. Image type is either uint8 or float.
        hue_factor (float): How much to shift the hue channel. Should be in
            [-0.5, 0.5]. 0.5 and -0.5 give complete reversal of hue channel in
            HSV space in positive and negative direction respectively.
            0 means no shift. Therefore, both -0.5 and 0.5 will give an image
            with complementary colors while 0 gives the original image.

    Returns:
        Tensor: Hue adjusted image.

    Raises:
        ValueError: If ``hue_factor`` is outside [-0.5, 0.5].
        TypeError: If ``img`` has fewer than two dimensions.
    """
    if not (-0.5 <= hue_factor <= 0.5):
        raise ValueError('hue_factor ({}) is not in [-0.5, 0.5].'.format(hue_factor))

    if not _is_tensor_a_torch_image(img):
        raise TypeError('tensor is not a torch image.')

    orig_dtype = img.dtype
    if img.dtype == torch.uint8:
        # The HSV helpers work on values in [0, 1].
        img = img.to(dtype=torch.float32) / 255.0

    img = _rgb2hsv(img)
    h, s, v = img.unbind(0)
    # Hue is cyclic, so wrap the shifted value back into [0, 1).
    h = (h + hue_factor) % 1.0
    img = torch.stack((h, s, v))
    img_hue_adj = _hsv2rgb(img)

    if orig_dtype == torch.uint8:
        img_hue_adj = (img_hue_adj * 255.0).to(dtype=orig_dtype)

    return img_hue_adj


def adjust_saturation(img, saturation_factor):
    # type: (Tensor, float) -> Tensor
    """Adjust color saturation of an RGB image.

    Args:
        img (Tensor): Image to be adjusted.
        saturation_factor (float): How much to adjust the saturation. 0 will
            give a black and white image, 1 will give the original image while
            2 will enhance the saturation by a factor of 2.

    Returns:
        Tensor: Saturation adjusted image.

    Raises:
        TypeError: If ``img`` has fewer than two dimensions, or does not have
            3 channels.
    """
    if not _is_tensor_a_torch_image(img):
        raise TypeError('tensor is not a torch image.')

    return _blend(img, rgb_to_grayscale(img), saturation_factor)


def center_crop(img, output_size):
    # type: (Tensor, BroadcastingList2[int]) -> Tensor
    """Crop the Image Tensor around its center to the desired size.

    Args:
        img (Tensor): Image to be cropped, in the form [..., H, W]. (0,0)
            denotes the top left corner of the image.
        output_size (sequence or int): (height, width) of the crop box. If int
            (TorchScript broadcasting only), it is used for both directions.

    Returns:
        Tensor: Cropped image.

    Raises:
        TypeError: If ``img`` has fewer than two dimensions.
    """
    if not _is_tensor_a_torch_image(img):
        raise TypeError('tensor is not a torch image.')

    # Tensors are laid out as [..., H, W]: height comes before width.
    image_height = img.shape[-2]
    image_width = img.shape[-1]
    crop_height, crop_width = output_size
    crop_top = int(round((image_height - crop_height) / 2.))
    crop_left = int(round((image_width - crop_width) / 2.))

    return crop(img, crop_top, crop_left, crop_height, crop_width)


def five_crop(img, size):
    # type: (Tensor, BroadcastingList2[int]) -> List[Tensor]
    """Crop the given Image Tensor into four corners and the central crop.

    .. Note::
        This transform returns a List of Tensors and there may be a
        mismatch in the number of inputs and targets your ``Dataset`` returns.

    Args:
        img (Tensor): Image to be cropped, in the form [..., H, W].
        size (sequence or int): Desired output size of the crop. If size is an
            int instead of sequence like (h, w), a square crop (size, size) is
            made.

    Returns:
        List: List (tl, tr, bl, br, center)
            Corresponding top left, top right, bottom left, bottom right and
            center crop.

    Raises:
        TypeError: If ``img`` has fewer than two dimensions.
        ValueError: If the requested crop is larger than the image.
    """
    if not _is_tensor_a_torch_image(img):
        raise TypeError('tensor is not a torch image.')

    assert len(size) == 2, "Please provide only two dimensions (h, w) for size."

    # Tensors are laid out as [..., H, W]: height comes before width.
    image_height = img.shape[-2]
    image_width = img.shape[-1]
    crop_height, crop_width = size
    if crop_width > image_width or crop_height > image_height:
        msg = "Requested crop size {} is bigger than input size {}"
        raise ValueError(msg.format(size, (image_height, image_width)))

    # crop() takes (top, left, height, width); every corner crop has the same
    # extent and differs only in its top-left anchor.
    bottom = image_height - crop_height
    right = image_width - crop_width
    tl = crop(img, 0, 0, crop_height, crop_width)
    tr = crop(img, 0, right, crop_height, crop_width)
    bl = crop(img, bottom, 0, crop_height, crop_width)
    br = crop(img, bottom, right, crop_height, crop_width)
    center = center_crop(img, (crop_height, crop_width))

    return [tl, tr, bl, br, center]


def ten_crop(img, size, vertical_flip=False):
    # type: (Tensor, BroadcastingList2[int], bool) -> List[Tensor]
    """Crop the given Image Tensor into four corners and the central crop plus
    the flipped version of these (horizontal flipping is used by default).

    .. Note::
        This transform returns a List of images and there may be a
        mismatch in the number of inputs and targets your ``Dataset`` returns.

    Args:
        img (Tensor): Image to be cropped, in the form [..., H, W].
        size (sequence or int): Desired output size of the crop. If size is an
            int instead of sequence like (h, w), a square crop (size, size) is
            made.
        vertical_flip (bool): Use vertical flipping instead of horizontal

    Returns:
        List: List (tl, tr, bl, br, center, tl_flip, tr_flip, bl_flip,
            br_flip, center_flip)
            Corresponding top left, top right, bottom left, bottom right and
            center crop and same for the flipped image's tensor.

    Raises:
        TypeError: If ``img`` has fewer than two dimensions.
        ValueError: If the requested crop is larger than the image.
    """
    if not _is_tensor_a_torch_image(img):
        raise TypeError('tensor is not a torch image.')

    assert len(size) == 2, "Please provide only two dimensions (h, w) for size."
    first_five = five_crop(img, size)

    if vertical_flip:
        img = vflip(img)
    else:
        img = hflip(img)

    second_five = five_crop(img, size)

    return first_five + second_five


def _blend(img1, img2, ratio):
    # type: (Tensor, Tensor, float) -> Tensor
    """Return ``ratio * img1 + (1 - ratio) * img2`` clamped to the valid range.

    The upper bound is 1 for floating point images and 255 otherwise; the
    result is cast back to the dtype of ``img1``.
    """
    bound = 1 if img1.is_floating_point() else 255
    return (ratio * img1 + (1 - ratio) * img2).clamp(0, bound).to(img1.dtype)


def _rgb2hsv(img):
    # type: (Tensor) -> Tensor
    """Convert an RGB tensor (channels on dim 0) to a stacked (h, s, v) tensor."""
    r, g, b = img.unbind(0)

    maxc = torch.max(img, dim=0).values
    minc = torch.min(img, dim=0).values

    # The algorithm erases S and H channel where `maxc = minc`. This avoids NaN
    # from happening in the results, because
    #   + S channel has division by `maxc`, which is zero only if `maxc = minc`
    #   + H channel has division by `(maxc - minc)`.
    #
    # Instead of overwriting NaN afterwards, we just prevent it from occurring
    # so we don't need to deal with it in case we save the NaN in a buffer in
    # backprop, if it is ever supported, but it doesn't hurt to do so.
    eqc = maxc == minc

    cr = maxc - minc
    # Since `eqc => cr = 0`, replacing denominator with 1 when `eqc` is fine.
    s = cr / torch.where(eqc, maxc.new_ones(()), maxc)
    # Note that `eqc => maxc = minc = r = g = b`. So the following calculation
    # of `h` would reduce to `bc - gc + 2 + rc - bc + 4 + rc - bc = 6` so it
    # would not matter what values `rc`, `gc`, and `bc` have here, and thus
    # replacing denominator with 1 when `eqc` is fine.
    cr_divisor = torch.where(eqc, maxc.new_ones(()), cr)
    rc = (maxc - r) / cr_divisor
    gc = (maxc - g) / cr_divisor
    bc = (maxc - b) / cr_divisor

    # Exactly one of the three masks is set per pixel (red wins ties, then green).
    hr = (maxc == r) * (bc - gc)
    hg = ((maxc == g) & (maxc != r)) * (2.0 + rc - bc)
    hb = ((maxc != g) & (maxc != r)) * (4.0 + gc - rc)
    h = (hr + hg + hb)
    h = torch.fmod((h / 6.0 + 1.0), 1.0)
    return torch.stack((h, s, maxc))


def _hsv2rgb(img):
    # type: (Tensor) -> Tensor
    """Convert a stacked (h, s, v) tensor back to an RGB tensor."""
    h, s, v = img.unbind(0)
    i = torch.floor(h * 6.0)
    f = (h * 6.0) - i
    i = i.to(dtype=torch.int32)

    p = torch.clamp((v * (1.0 - s)), 0.0, 1.0)
    q = torch.clamp((v * (1.0 - s * f)), 0.0, 1.0)
    t = torch.clamp((v * (1.0 - s * (1.0 - f))), 0.0, 1.0)
    i = i % 6

    # One-hot mask over the six hue sectors; build the sector index on the
    # same device as the image so non-CPU inputs do not hit a device mismatch.
    mask = i == torch.arange(6, device=i.device)[:, None, None]

    a1 = torch.stack((v, q, p, p, t, v))
    a2 = torch.stack((t, v, v, q, p, p))
    a3 = torch.stack((p, p, t, v, v, q))
    a4 = torch.stack((a1, a2, a3))

    # Select, per pixel, the (r, g, b) candidates of the active sector.
    return torch.einsum("ijk, xijk -> xjk", mask.to(dtype=img.dtype), a4)


def pad(img: Tensor, padding: List[int], fill: int = 0, padding_mode: str = "constant") -> Tensor:
    r"""Pad the given Tensor Image on all sides with specified padding mode and fill value.

    Args:
        img (Tensor): Image to be padded.
        padding (int or tuple or list): Padding on each border. If a single int is provided this
            is used to pad all borders. If a tuple or list of length 2 is provided this is the padding
            on left/right and top/bottom respectively. If a tuple or list of length 4 is provided
            this is the padding for the left, top, right and bottom borders
            respectively. In torchscript mode padding as single int is not supported, use a tuple or
            list of length 1: ``[padding, ]``.
        fill (int): Pixel fill value for constant fill. Default is 0.
            This value is only used when the padding_mode is constant
        padding_mode (str): Type of padding. Only "constant" is supported for Tensors as of now.

            - constant: pads with a constant value, this value is specified with fill

    Returns:
        Tensor: Padded image.

    Raises:
        TypeError: If ``img`` has fewer than two dimensions or an argument has
            an inappropriate type.
        ValueError: If ``padding`` has an unsupported length or
            ``padding_mode`` is not "constant".
    """
    if not _is_tensor_a_torch_image(img):
        raise TypeError("tensor is not a torch image.")

    if not isinstance(padding, (int, tuple, list)):
        raise TypeError("Got inappropriate padding arg")
    if not isinstance(fill, (int, float)):
        raise TypeError("Got inappropriate fill arg")
    if not isinstance(padding_mode, str):
        raise TypeError("Got inappropriate padding_mode arg")

    if isinstance(padding, tuple):
        padding = list(padding)

    if isinstance(padding, list) and len(padding) not in [1, 2, 4]:
        raise ValueError("Padding must be an int or a 1, 2, or 4 element tuple, not a " +
                         "{} element tuple".format(len(padding)))

    if padding_mode not in ["constant", ]:
        raise ValueError("Only constant padding_mode supported for torch tensors")

    if isinstance(padding, int):
        if torch.jit.is_scripting():
            raise ValueError("padding can't be an int while torchscripting, set it as a list [value, ]")
        pad_left = pad_right = pad_top = pad_bottom = padding
    elif len(padding) == 1:
        pad_left = pad_right = pad_top = pad_bottom = padding[0]
    elif len(padding) == 2:
        pad_left = pad_right = padding[0]
        pad_top = pad_bottom = padding[1]
    else:
        pad_left = padding[0]
        pad_top = padding[1]
        pad_right = padding[2]
        pad_bottom = padding[3]

    # torch.nn.functional.pad expects the last dimension first:
    # (left, right, top, bottom).
    p = [pad_left, pad_right, pad_top, pad_bottom]

    img = torch.nn.functional.pad(img, p, mode=padding_mode, value=float(fill))
    return img