Unverified commit 5f0edb97, authored by Philip Meier, committed by GitHub
Browse files

Add ufmt (usort + black) as code formatter (#4384)



* add ufmt as code formatter

* cleanup

* quote ufmt requirement

* split imports into more groups

* regenerate circleci config

* fix CI

* clarify local testing utils section

* use ufmt pre-commit hook

* split relative imports into local category

* Revert "split relative imports into local category"

This reverts commit f2e224cde2008c56c9347c1f69746d39065cdd51.

* pin black and usort dependencies

* fix local test utils detection

* fix ufmt rev

* add reference utils to local category

* fix usort config

* remove custom categories sorting

* Run pre-commit without fixing flake8

* got a double import in merge
Co-authored-by: Nicolas Hug <nicolashug@fb.com>
parent e45489b1
import warnings import warnings
from typing import Optional, Tuple, List
import torch import torch
from torch import Tensor from torch import Tensor
from torch.nn.functional import grid_sample, conv2d, interpolate, pad as torch_pad
from torch.jit.annotations import BroadcastingList2 from torch.jit.annotations import BroadcastingList2
from typing import Optional, Tuple, List from torch.nn.functional import grid_sample, conv2d, interpolate, pad as torch_pad
def _is_tensor_a_torch_image(x: Tensor) -> bool: def _is_tensor_a_torch_image(x: Tensor) -> bool:
...@@ -97,7 +97,7 @@ def convert_image_dtype(image: torch.Tensor, dtype: torch.dtype = torch.float) - ...@@ -97,7 +97,7 @@ def convert_image_dtype(image: torch.Tensor, dtype: torch.dtype = torch.float) -
# factor should be forced to int for torch jit script # factor should be forced to int for torch jit script
# otherwise factor is a float and image // factor can produce different results # otherwise factor is a float and image // factor can produce different results
factor = int((input_max + 1) // (output_max + 1)) factor = int((input_max + 1) // (output_max + 1))
image = torch.div(image, factor, rounding_mode='floor') image = torch.div(image, factor, rounding_mode="floor")
return image.to(dtype) return image.to(dtype)
else: else:
# factor should be forced to int for torch jit script # factor should be forced to int for torch jit script
...@@ -128,7 +128,7 @@ def crop(img: Tensor, top: int, left: int, height: int, width: int) -> Tensor: ...@@ -128,7 +128,7 @@ def crop(img: Tensor, top: int, left: int, height: int, width: int) -> Tensor:
if left < 0 or top < 0 or right > w or bottom > h: if left < 0 or top < 0 or right > w or bottom > h:
padding_ltrb = [max(-left, 0), max(-top, 0), max(right - w, 0), max(bottom - h, 0)] padding_ltrb = [max(-left, 0), max(-top, 0), max(right - w, 0), max(bottom - h, 0)]
return pad(img[..., max(top, 0):bottom, max(left, 0):right], padding_ltrb, fill=0) return pad(img[..., max(top, 0) : bottom, max(left, 0) : right], padding_ltrb, fill=0)
return img[..., top:bottom, left:right] return img[..., top:bottom, left:right]
...@@ -138,7 +138,7 @@ def rgb_to_grayscale(img: Tensor, num_output_channels: int = 1) -> Tensor: ...@@ -138,7 +138,7 @@ def rgb_to_grayscale(img: Tensor, num_output_channels: int = 1) -> Tensor:
_assert_channels(img, [3]) _assert_channels(img, [3])
if num_output_channels not in (1, 3): if num_output_channels not in (1, 3):
raise ValueError('num_output_channels should be either 1 or 3') raise ValueError("num_output_channels should be either 1 or 3")
r, g, b = img.unbind(dim=-3) r, g, b = img.unbind(dim=-3)
# This implementation closely follows the TF one: # This implementation closely follows the TF one:
...@@ -154,7 +154,7 @@ def rgb_to_grayscale(img: Tensor, num_output_channels: int = 1) -> Tensor: ...@@ -154,7 +154,7 @@ def rgb_to_grayscale(img: Tensor, num_output_channels: int = 1) -> Tensor:
def adjust_brightness(img: Tensor, brightness_factor: float) -> Tensor: def adjust_brightness(img: Tensor, brightness_factor: float) -> Tensor:
if brightness_factor < 0: if brightness_factor < 0:
raise ValueError('brightness_factor ({}) is not non-negative.'.format(brightness_factor)) raise ValueError("brightness_factor ({}) is not non-negative.".format(brightness_factor))
_assert_image_tensor(img) _assert_image_tensor(img)
...@@ -165,7 +165,7 @@ def adjust_brightness(img: Tensor, brightness_factor: float) -> Tensor: ...@@ -165,7 +165,7 @@ def adjust_brightness(img: Tensor, brightness_factor: float) -> Tensor:
def adjust_contrast(img: Tensor, contrast_factor: float) -> Tensor: def adjust_contrast(img: Tensor, contrast_factor: float) -> Tensor:
if contrast_factor < 0: if contrast_factor < 0:
raise ValueError('contrast_factor ({}) is not non-negative.'.format(contrast_factor)) raise ValueError("contrast_factor ({}) is not non-negative.".format(contrast_factor))
_assert_image_tensor(img) _assert_image_tensor(img)
...@@ -182,10 +182,10 @@ def adjust_contrast(img: Tensor, contrast_factor: float) -> Tensor: ...@@ -182,10 +182,10 @@ def adjust_contrast(img: Tensor, contrast_factor: float) -> Tensor:
def adjust_hue(img: Tensor, hue_factor: float) -> Tensor: def adjust_hue(img: Tensor, hue_factor: float) -> Tensor:
if not (-0.5 <= hue_factor <= 0.5): if not (-0.5 <= hue_factor <= 0.5):
raise ValueError('hue_factor ({}) is not in [-0.5, 0.5].'.format(hue_factor)) raise ValueError("hue_factor ({}) is not in [-0.5, 0.5].".format(hue_factor))
if not (isinstance(img, torch.Tensor)): if not (isinstance(img, torch.Tensor)):
raise TypeError('Input img should be Tensor image') raise TypeError("Input img should be Tensor image")
_assert_image_tensor(img) _assert_image_tensor(img)
...@@ -211,7 +211,7 @@ def adjust_hue(img: Tensor, hue_factor: float) -> Tensor: ...@@ -211,7 +211,7 @@ def adjust_hue(img: Tensor, hue_factor: float) -> Tensor:
def adjust_saturation(img: Tensor, saturation_factor: float) -> Tensor: def adjust_saturation(img: Tensor, saturation_factor: float) -> Tensor:
if saturation_factor < 0: if saturation_factor < 0:
raise ValueError('saturation_factor ({}) is not non-negative.'.format(saturation_factor)) raise ValueError("saturation_factor ({}) is not non-negative.".format(saturation_factor))
_assert_image_tensor(img) _assert_image_tensor(img)
...@@ -225,12 +225,12 @@ def adjust_saturation(img: Tensor, saturation_factor: float) -> Tensor: ...@@ -225,12 +225,12 @@ def adjust_saturation(img: Tensor, saturation_factor: float) -> Tensor:
def adjust_gamma(img: Tensor, gamma: float, gain: float = 1) -> Tensor: def adjust_gamma(img: Tensor, gamma: float, gain: float = 1) -> Tensor:
if not isinstance(img, torch.Tensor): if not isinstance(img, torch.Tensor):
raise TypeError('Input img should be a Tensor.') raise TypeError("Input img should be a Tensor.")
_assert_channels(img, [1, 3]) _assert_channels(img, [1, 3])
if gamma < 0: if gamma < 0:
raise ValueError('Gamma should be a non-negative real number') raise ValueError("Gamma should be a non-negative real number")
result = img result = img
dtype = img.dtype dtype = img.dtype
...@@ -244,11 +244,9 @@ def adjust_gamma(img: Tensor, gamma: float, gain: float = 1) -> Tensor: ...@@ -244,11 +244,9 @@ def adjust_gamma(img: Tensor, gamma: float, gain: float = 1) -> Tensor:
def center_crop(img: Tensor, output_size: BroadcastingList2[int]) -> Tensor: def center_crop(img: Tensor, output_size: BroadcastingList2[int]) -> Tensor:
"""DEPRECATED """DEPRECATED"""
"""
warnings.warn( warnings.warn(
"This method is deprecated and will be removed in future releases. " "This method is deprecated and will be removed in future releases. " "Please, use ``F.center_crop`` instead."
"Please, use ``F.center_crop`` instead."
) )
_assert_image_tensor(img) _assert_image_tensor(img)
...@@ -268,11 +266,9 @@ def center_crop(img: Tensor, output_size: BroadcastingList2[int]) -> Tensor: ...@@ -268,11 +266,9 @@ def center_crop(img: Tensor, output_size: BroadcastingList2[int]) -> Tensor:
def five_crop(img: Tensor, size: BroadcastingList2[int]) -> List[Tensor]: def five_crop(img: Tensor, size: BroadcastingList2[int]) -> List[Tensor]:
"""DEPRECATED """DEPRECATED"""
"""
warnings.warn( warnings.warn(
"This method is deprecated and will be removed in future releases. " "This method is deprecated and will be removed in future releases. " "Please, use ``F.five_crop`` instead."
"Please, use ``F.five_crop`` instead."
) )
_assert_image_tensor(img) _assert_image_tensor(img)
...@@ -295,11 +291,9 @@ def five_crop(img: Tensor, size: BroadcastingList2[int]) -> List[Tensor]: ...@@ -295,11 +291,9 @@ def five_crop(img: Tensor, size: BroadcastingList2[int]) -> List[Tensor]:
def ten_crop(img: Tensor, size: BroadcastingList2[int], vertical_flip: bool = False) -> List[Tensor]: def ten_crop(img: Tensor, size: BroadcastingList2[int], vertical_flip: bool = False) -> List[Tensor]:
"""DEPRECATED """DEPRECATED"""
"""
warnings.warn( warnings.warn(
"This method is deprecated and will be removed in future releases. " "This method is deprecated and will be removed in future releases. " "Please, use ``F.ten_crop`` instead."
"Please, use ``F.ten_crop`` instead."
) )
_assert_image_tensor(img) _assert_image_tensor(img)
...@@ -357,7 +351,7 @@ def _rgb2hsv(img: Tensor) -> Tensor: ...@@ -357,7 +351,7 @@ def _rgb2hsv(img: Tensor) -> Tensor:
hr = (maxc == r) * (bc - gc) hr = (maxc == r) * (bc - gc)
hg = ((maxc == g) & (maxc != r)) * (2.0 + rc - bc) hg = ((maxc == g) & (maxc != r)) * (2.0 + rc - bc)
hb = ((maxc != g) & (maxc != r)) * (4.0 + gc - rc) hb = ((maxc != g) & (maxc != r)) * (4.0 + gc - rc)
h = (hr + hg + hb) h = hr + hg + hb
h = torch.fmod((h / 6.0 + 1.0), 1.0) h = torch.fmod((h / 6.0 + 1.0), 1.0)
return torch.stack((h, s, maxc), dim=-3) return torch.stack((h, s, maxc), dim=-3)
...@@ -389,7 +383,7 @@ def _pad_symmetric(img: Tensor, padding: List[int]) -> Tensor: ...@@ -389,7 +383,7 @@ def _pad_symmetric(img: Tensor, padding: List[int]) -> Tensor:
# crop if needed # crop if needed
if padding[0] < 0 or padding[1] < 0 or padding[2] < 0 or padding[3] < 0: if padding[0] < 0 or padding[1] < 0 or padding[2] < 0 or padding[3] < 0:
crop_left, crop_right, crop_top, crop_bottom = [-min(x, 0) for x in padding] crop_left, crop_right, crop_top, crop_bottom = [-min(x, 0) for x in padding]
img = img[..., crop_top:img.shape[-2] - crop_bottom, crop_left:img.shape[-1] - crop_right] img = img[..., crop_top : img.shape[-2] - crop_bottom, crop_left : img.shape[-1] - crop_right]
padding = [max(x, 0) for x in padding] padding = [max(x, 0) for x in padding]
in_sizes = img.size() in_sizes = img.size()
...@@ -427,8 +421,9 @@ def pad(img: Tensor, padding: List[int], fill: int = 0, padding_mode: str = "con ...@@ -427,8 +421,9 @@ def pad(img: Tensor, padding: List[int], fill: int = 0, padding_mode: str = "con
padding = list(padding) padding = list(padding)
if isinstance(padding, list) and len(padding) not in [1, 2, 4]: if isinstance(padding, list) and len(padding) not in [1, 2, 4]:
raise ValueError("Padding must be an int or a 1, 2, or 4 element tuple, not a " + raise ValueError(
"{} element tuple".format(len(padding))) "Padding must be an int or a 1, 2, or 4 element tuple, not a " + "{} element tuple".format(len(padding))
)
if padding_mode not in ["constant", "edge", "reflect", "symmetric"]: if padding_mode not in ["constant", "edge", "reflect", "symmetric"]:
raise ValueError("Padding mode should be either constant, edge, reflect or symmetric") raise ValueError("Padding mode should be either constant, edge, reflect or symmetric")
...@@ -488,7 +483,7 @@ def resize( ...@@ -488,7 +483,7 @@ def resize(
size: List[int], size: List[int],
interpolation: str = "bilinear", interpolation: str = "bilinear",
max_size: Optional[int] = None, max_size: Optional[int] = None,
antialias: Optional[bool] = None antialias: Optional[bool] = None,
) -> Tensor: ) -> Tensor:
_assert_image_tensor(img) _assert_image_tensor(img)
...@@ -505,8 +500,9 @@ def resize( ...@@ -505,8 +500,9 @@ def resize(
if isinstance(size, list): if isinstance(size, list):
if len(size) not in [1, 2]: if len(size) not in [1, 2]:
raise ValueError("Size must be an int or a 1 or 2 element tuple/list, not a " raise ValueError(
"{} element tuple/list".format(len(size))) "Size must be an int or a 1 or 2 element tuple/list, not a " "{} element tuple/list".format(len(size))
)
if max_size is not None and len(size) != 1: if max_size is not None and len(size) != 1:
raise ValueError( raise ValueError(
"max_size should only be passed if size specifies the length of the smaller edge, " "max_size should only be passed if size specifies the length of the smaller edge, "
...@@ -594,8 +590,10 @@ def _assert_grid_transform_inputs( ...@@ -594,8 +590,10 @@ def _assert_grid_transform_inputs(
# Check fill # Check fill
num_channels = get_image_num_channels(img) num_channels = get_image_num_channels(img)
if isinstance(fill, (tuple, list)) and (len(fill) > 1 and len(fill) != num_channels): if isinstance(fill, (tuple, list)) and (len(fill) > 1 and len(fill) != num_channels):
msg = ("The number of elements in 'fill' cannot broadcast to match the number of " msg = (
"channels of the image ({} != {})") "The number of elements in 'fill' cannot broadcast to match the number of "
"channels of the image ({} != {})"
)
raise ValueError(msg.format(len(fill), num_channels)) raise ValueError(msg.format(len(fill), num_channels))
if interpolation not in supported_interpolation_modes: if interpolation not in supported_interpolation_modes:
...@@ -633,7 +631,12 @@ def _cast_squeeze_out(img: Tensor, need_cast: bool, need_squeeze: bool, out_dtyp ...@@ -633,7 +631,12 @@ def _cast_squeeze_out(img: Tensor, need_cast: bool, need_squeeze: bool, out_dtyp
def _apply_grid_transform(img: Tensor, grid: Tensor, mode: str, fill: Optional[List[float]]) -> Tensor: def _apply_grid_transform(img: Tensor, grid: Tensor, mode: str, fill: Optional[List[float]]) -> Tensor:
img, need_cast, need_squeeze, out_dtype = _cast_squeeze_in(img, [grid.dtype, ]) img, need_cast, need_squeeze, out_dtype = _cast_squeeze_in(
img,
[
grid.dtype,
],
)
if img.shape[0] > 1: if img.shape[0] > 1:
# Apply same grid to a batch of images # Apply same grid to a batch of images
...@@ -653,7 +656,7 @@ def _apply_grid_transform(img: Tensor, grid: Tensor, mode: str, fill: Optional[L ...@@ -653,7 +656,7 @@ def _apply_grid_transform(img: Tensor, grid: Tensor, mode: str, fill: Optional[L
mask = mask.expand_as(img) mask = mask.expand_as(img)
len_fill = len(fill) if isinstance(fill, (tuple, list)) else 1 len_fill = len(fill) if isinstance(fill, (tuple, list)) else 1
fill_img = torch.tensor(fill, dtype=img.dtype, device=img.device).view(1, len_fill, 1, 1).expand_as(img) fill_img = torch.tensor(fill, dtype=img.dtype, device=img.device).view(1, len_fill, 1, 1).expand_as(img)
if mode == 'nearest': if mode == "nearest":
mask = mask < 0.5 mask = mask < 0.5
img[mask] = fill_img[mask] img[mask] = fill_img[mask]
else: # 'bilinear' else: # 'bilinear'
...@@ -664,7 +667,11 @@ def _apply_grid_transform(img: Tensor, grid: Tensor, mode: str, fill: Optional[L ...@@ -664,7 +667,11 @@ def _apply_grid_transform(img: Tensor, grid: Tensor, mode: str, fill: Optional[L
def _gen_affine_grid( def _gen_affine_grid(
theta: Tensor, w: int, h: int, ow: int, oh: int, theta: Tensor,
w: int,
h: int,
ow: int,
oh: int,
) -> Tensor: ) -> Tensor:
# https://github.com/pytorch/pytorch/blob/74b65c32be68b15dc7c9e8bb62459efbfbde33d8/aten/src/ATen/native/ # https://github.com/pytorch/pytorch/blob/74b65c32be68b15dc7c9e8bb62459efbfbde33d8/aten/src/ATen/native/
# AffineGridGenerator.cpp#L18 # AffineGridGenerator.cpp#L18
...@@ -686,7 +693,7 @@ def _gen_affine_grid( ...@@ -686,7 +693,7 @@ def _gen_affine_grid(
def affine( def affine(
img: Tensor, matrix: List[float], interpolation: str = "nearest", fill: Optional[List[float]] = None img: Tensor, matrix: List[float], interpolation: str = "nearest", fill: Optional[List[float]] = None
) -> Tensor: ) -> Tensor:
_assert_grid_transform_inputs(img, matrix, interpolation, fill, ["nearest", "bilinear"]) _assert_grid_transform_inputs(img, matrix, interpolation, fill, ["nearest", "bilinear"])
...@@ -704,12 +711,14 @@ def _compute_output_size(matrix: List[float], w: int, h: int) -> Tuple[int, int] ...@@ -704,12 +711,14 @@ def _compute_output_size(matrix: List[float], w: int, h: int) -> Tuple[int, int]
# https://github.com/python-pillow/Pillow/blob/11de3318867e4398057373ee9f12dcb33db7335c/src/PIL/Image.py#L2054 # https://github.com/python-pillow/Pillow/blob/11de3318867e4398057373ee9f12dcb33db7335c/src/PIL/Image.py#L2054
# pts are Top-Left, Top-Right, Bottom-Left, Bottom-Right points. # pts are Top-Left, Top-Right, Bottom-Left, Bottom-Right points.
pts = torch.tensor([ pts = torch.tensor(
[-0.5 * w, -0.5 * h, 1.0], [
[-0.5 * w, 0.5 * h, 1.0], [-0.5 * w, -0.5 * h, 1.0],
[0.5 * w, 0.5 * h, 1.0], [-0.5 * w, 0.5 * h, 1.0],
[0.5 * w, -0.5 * h, 1.0], [0.5 * w, 0.5 * h, 1.0],
]) [0.5 * w, -0.5 * h, 1.0],
]
)
theta = torch.tensor(matrix, dtype=torch.float).reshape(1, 2, 3) theta = torch.tensor(matrix, dtype=torch.float).reshape(1, 2, 3)
new_pts = pts.view(1, 4, 3).bmm(theta.transpose(1, 2)).view(4, 2) new_pts = pts.view(1, 4, 3).bmm(theta.transpose(1, 2)).view(4, 2)
min_vals, _ = new_pts.min(dim=0) min_vals, _ = new_pts.min(dim=0)
...@@ -724,8 +733,11 @@ def _compute_output_size(matrix: List[float], w: int, h: int) -> Tuple[int, int] ...@@ -724,8 +733,11 @@ def _compute_output_size(matrix: List[float], w: int, h: int) -> Tuple[int, int]
def rotate( def rotate(
img: Tensor, matrix: List[float], interpolation: str = "nearest", img: Tensor,
expand: bool = False, fill: Optional[List[float]] = None matrix: List[float],
interpolation: str = "nearest",
expand: bool = False,
fill: Optional[List[float]] = None,
) -> Tensor: ) -> Tensor:
_assert_grid_transform_inputs(img, matrix, interpolation, fill, ["nearest", "bilinear"]) _assert_grid_transform_inputs(img, matrix, interpolation, fill, ["nearest", "bilinear"])
w, h = img.shape[-1], img.shape[-2] w, h = img.shape[-1], img.shape[-2]
...@@ -746,14 +758,10 @@ def _perspective_grid(coeffs: List[float], ow: int, oh: int, dtype: torch.dtype, ...@@ -746,14 +758,10 @@ def _perspective_grid(coeffs: List[float], ow: int, oh: int, dtype: torch.dtype,
# x_out = (coeffs[0] * x + coeffs[1] * y + coeffs[2]) / (coeffs[6] * x + coeffs[7] * y + 1) # x_out = (coeffs[0] * x + coeffs[1] * y + coeffs[2]) / (coeffs[6] * x + coeffs[7] * y + 1)
# y_out = (coeffs[3] * x + coeffs[4] * y + coeffs[5]) / (coeffs[6] * x + coeffs[7] * y + 1) # y_out = (coeffs[3] * x + coeffs[4] * y + coeffs[5]) / (coeffs[6] * x + coeffs[7] * y + 1)
# #
theta1 = torch.tensor([[ theta1 = torch.tensor(
[coeffs[0], coeffs[1], coeffs[2]], [[[coeffs[0], coeffs[1], coeffs[2]], [coeffs[3], coeffs[4], coeffs[5]]]], dtype=dtype, device=device
[coeffs[3], coeffs[4], coeffs[5]] )
]], dtype=dtype, device=device) theta2 = torch.tensor([[[coeffs[6], coeffs[7], 1.0], [coeffs[6], coeffs[7], 1.0]]], dtype=dtype, device=device)
theta2 = torch.tensor([[
[coeffs[6], coeffs[7], 1.0],
[coeffs[6], coeffs[7], 1.0]
]], dtype=dtype, device=device)
d = 0.5 d = 0.5
base_grid = torch.empty(1, oh, ow, 3, dtype=dtype, device=device) base_grid = torch.empty(1, oh, ow, 3, dtype=dtype, device=device)
...@@ -775,7 +783,7 @@ def perspective( ...@@ -775,7 +783,7 @@ def perspective(
img: Tensor, perspective_coeffs: List[float], interpolation: str = "bilinear", fill: Optional[List[float]] = None img: Tensor, perspective_coeffs: List[float], interpolation: str = "bilinear", fill: Optional[List[float]] = None
) -> Tensor: ) -> Tensor:
if not (isinstance(img, torch.Tensor)): if not (isinstance(img, torch.Tensor)):
raise TypeError('Input img should be Tensor.') raise TypeError("Input img should be Tensor.")
_assert_image_tensor(img) _assert_image_tensor(img)
...@@ -785,7 +793,7 @@ def perspective( ...@@ -785,7 +793,7 @@ def perspective(
interpolation=interpolation, interpolation=interpolation,
fill=fill, fill=fill,
supported_interpolation_modes=["nearest", "bilinear"], supported_interpolation_modes=["nearest", "bilinear"],
coeffs=perspective_coeffs coeffs=perspective_coeffs,
) )
ow, oh = img.shape[-1], img.shape[-2] ow, oh = img.shape[-1], img.shape[-2]
...@@ -805,7 +813,7 @@ def _get_gaussian_kernel1d(kernel_size: int, sigma: float) -> Tensor: ...@@ -805,7 +813,7 @@ def _get_gaussian_kernel1d(kernel_size: int, sigma: float) -> Tensor:
def _get_gaussian_kernel2d( def _get_gaussian_kernel2d(
kernel_size: List[int], sigma: List[float], dtype: torch.dtype, device: torch.device kernel_size: List[int], sigma: List[float], dtype: torch.dtype, device: torch.device
) -> Tensor: ) -> Tensor:
kernel1d_x = _get_gaussian_kernel1d(kernel_size[0], sigma[0]).to(device, dtype=dtype) kernel1d_x = _get_gaussian_kernel1d(kernel_size[0], sigma[0]).to(device, dtype=dtype)
kernel1d_y = _get_gaussian_kernel1d(kernel_size[1], sigma[1]).to(device, dtype=dtype) kernel1d_y = _get_gaussian_kernel1d(kernel_size[1], sigma[1]).to(device, dtype=dtype)
...@@ -815,7 +823,7 @@ def _get_gaussian_kernel2d( ...@@ -815,7 +823,7 @@ def _get_gaussian_kernel2d(
def gaussian_blur(img: Tensor, kernel_size: List[int], sigma: List[float]) -> Tensor: def gaussian_blur(img: Tensor, kernel_size: List[int], sigma: List[float]) -> Tensor:
if not (isinstance(img, torch.Tensor)): if not (isinstance(img, torch.Tensor)):
raise TypeError('img should be Tensor. Got {}'.format(type(img))) raise TypeError("img should be Tensor. Got {}".format(type(img)))
_assert_image_tensor(img) _assert_image_tensor(img)
...@@ -823,7 +831,12 @@ def gaussian_blur(img: Tensor, kernel_size: List[int], sigma: List[float]) -> Te ...@@ -823,7 +831,12 @@ def gaussian_blur(img: Tensor, kernel_size: List[int], sigma: List[float]) -> Te
kernel = _get_gaussian_kernel2d(kernel_size, sigma, dtype=dtype, device=img.device) kernel = _get_gaussian_kernel2d(kernel_size, sigma, dtype=dtype, device=img.device)
kernel = kernel.expand(img.shape[-3], 1, kernel.shape[0], kernel.shape[1]) kernel = kernel.expand(img.shape[-3], 1, kernel.shape[0], kernel.shape[1])
img, need_cast, need_squeeze, out_dtype = _cast_squeeze_in(img, [kernel.dtype, ]) img, need_cast, need_squeeze, out_dtype = _cast_squeeze_in(
img,
[
kernel.dtype,
],
)
# padding = (left, right, top, bottom) # padding = (left, right, top, bottom)
padding = [kernel_size[0] // 2, kernel_size[0] // 2, kernel_size[1] // 2, kernel_size[1] // 2] padding = [kernel_size[0] // 2, kernel_size[0] // 2, kernel_size[1] // 2, kernel_size[1] // 2]
...@@ -857,7 +870,7 @@ def posterize(img: Tensor, bits: int) -> Tensor: ...@@ -857,7 +870,7 @@ def posterize(img: Tensor, bits: int) -> Tensor:
raise TypeError("Only torch.uint8 image tensors are supported, but found {}".format(img.dtype)) raise TypeError("Only torch.uint8 image tensors are supported, but found {}".format(img.dtype))
_assert_channels(img, [1, 3]) _assert_channels(img, [1, 3])
mask = -int(2**(8 - bits)) # JIT-friendly for: ~(2 ** (8 - bits) - 1) mask = -int(2 ** (8 - bits)) # JIT-friendly for: ~(2 ** (8 - bits) - 1)
return img & mask return img & mask
...@@ -882,7 +895,12 @@ def _blurred_degenerate_image(img: Tensor) -> Tensor: ...@@ -882,7 +895,12 @@ def _blurred_degenerate_image(img: Tensor) -> Tensor:
kernel /= kernel.sum() kernel /= kernel.sum()
kernel = kernel.expand(img.shape[-3], 1, kernel.shape[0], kernel.shape[1]) kernel = kernel.expand(img.shape[-3], 1, kernel.shape[0], kernel.shape[1])
result_tmp, need_cast, need_squeeze, out_dtype = _cast_squeeze_in(img, [kernel.dtype, ]) result_tmp, need_cast, need_squeeze, out_dtype = _cast_squeeze_in(
img,
[
kernel.dtype,
],
)
result_tmp = conv2d(result_tmp, kernel, groups=result_tmp.shape[-3]) result_tmp = conv2d(result_tmp, kernel, groups=result_tmp.shape[-3])
result_tmp = _cast_squeeze_out(result_tmp, need_cast, need_squeeze, out_dtype) result_tmp = _cast_squeeze_out(result_tmp, need_cast, need_squeeze, out_dtype)
...@@ -894,7 +912,7 @@ def _blurred_degenerate_image(img: Tensor) -> Tensor: ...@@ -894,7 +912,7 @@ def _blurred_degenerate_image(img: Tensor) -> Tensor:
def adjust_sharpness(img: Tensor, sharpness_factor: float) -> Tensor: def adjust_sharpness(img: Tensor, sharpness_factor: float) -> Tensor:
if sharpness_factor < 0: if sharpness_factor < 0:
raise ValueError('sharpness_factor ({}) is not non-negative.'.format(sharpness_factor)) raise ValueError("sharpness_factor ({}) is not non-negative.".format(sharpness_factor))
_assert_image_tensor(img) _assert_image_tensor(img)
...@@ -939,13 +957,11 @@ def _scale_channel(img_chan: Tensor) -> Tensor: ...@@ -939,13 +957,11 @@ def _scale_channel(img_chan: Tensor) -> Tensor:
hist = torch.bincount(img_chan.view(-1), minlength=256) hist = torch.bincount(img_chan.view(-1), minlength=256)
nonzero_hist = hist[hist != 0] nonzero_hist = hist[hist != 0]
step = torch.div(nonzero_hist[:-1].sum(), 255, rounding_mode='floor') step = torch.div(nonzero_hist[:-1].sum(), 255, rounding_mode="floor")
if step == 0: if step == 0:
return img_chan return img_chan
lut = torch.div( lut = torch.div(torch.cumsum(hist, 0) + torch.div(step, 2, rounding_mode="floor"), step, rounding_mode="floor")
torch.cumsum(hist, 0) + torch.div(step, 2, rounding_mode='floor'),
step, rounding_mode='floor')
lut = torch.nn.functional.pad(lut, [1, 0])[:-1].clamp(0, 255) lut = torch.nn.functional.pad(lut, [1, 0])[:-1].clamp(0, 255)
return lut[img_chan.to(torch.int64)].to(torch.uint8) return lut[img_chan.to(torch.int64)].to(torch.uint8)
......
...@@ -17,12 +17,45 @@ from . import functional as F ...@@ -17,12 +17,45 @@ from . import functional as F
from .functional import InterpolationMode, _interpolation_modes_from_int from .functional import InterpolationMode, _interpolation_modes_from_int
__all__ = ["Compose", "ToTensor", "PILToTensor", "ConvertImageDtype", "ToPILImage", "Normalize", "Resize", "Scale", __all__ = [
"CenterCrop", "Pad", "Lambda", "RandomApply", "RandomChoice", "RandomOrder", "RandomCrop", "Compose",
"RandomHorizontalFlip", "RandomVerticalFlip", "RandomResizedCrop", "RandomSizedCrop", "FiveCrop", "TenCrop", "ToTensor",
"LinearTransformation", "ColorJitter", "RandomRotation", "RandomAffine", "Grayscale", "RandomGrayscale", "PILToTensor",
"RandomPerspective", "RandomErasing", "GaussianBlur", "InterpolationMode", "RandomInvert", "RandomPosterize", "ConvertImageDtype",
"RandomSolarize", "RandomAdjustSharpness", "RandomAutocontrast", "RandomEqualize"] "ToPILImage",
"Normalize",
"Resize",
"Scale",
"CenterCrop",
"Pad",
"Lambda",
"RandomApply",
"RandomChoice",
"RandomOrder",
"RandomCrop",
"RandomHorizontalFlip",
"RandomVerticalFlip",
"RandomResizedCrop",
"RandomSizedCrop",
"FiveCrop",
"TenCrop",
"LinearTransformation",
"ColorJitter",
"RandomRotation",
"RandomAffine",
"Grayscale",
"RandomGrayscale",
"RandomPerspective",
"RandomErasing",
"GaussianBlur",
"InterpolationMode",
"RandomInvert",
"RandomPosterize",
"RandomSolarize",
"RandomAdjustSharpness",
"RandomAutocontrast",
"RandomEqualize",
]
class Compose: class Compose:
...@@ -62,11 +95,11 @@ class Compose: ...@@ -62,11 +95,11 @@ class Compose:
return img return img
def __repr__(self): def __repr__(self):
format_string = self.__class__.__name__ + '(' format_string = self.__class__.__name__ + "("
for t in self.transforms: for t in self.transforms:
format_string += '\n' format_string += "\n"
format_string += ' {0}'.format(t) format_string += " {0}".format(t)
format_string += '\n)' format_string += "\n)"
return format_string return format_string
...@@ -98,7 +131,7 @@ class ToTensor: ...@@ -98,7 +131,7 @@ class ToTensor:
return F.to_tensor(pic) return F.to_tensor(pic)
def __repr__(self): def __repr__(self):
return self.__class__.__name__ + '()' return self.__class__.__name__ + "()"
class PILToTensor: class PILToTensor:
...@@ -118,7 +151,7 @@ class PILToTensor: ...@@ -118,7 +151,7 @@ class PILToTensor:
return F.pil_to_tensor(pic) return F.pil_to_tensor(pic)
def __repr__(self): def __repr__(self):
return self.__class__.__name__ + '()' return self.__class__.__name__ + "()"
class ConvertImageDtype(torch.nn.Module): class ConvertImageDtype(torch.nn.Module):
...@@ -165,6 +198,7 @@ class ToPILImage: ...@@ -165,6 +198,7 @@ class ToPILImage:
.. _PIL.Image mode: https://pillow.readthedocs.io/en/latest/handbook/concepts.html#concept-modes .. _PIL.Image mode: https://pillow.readthedocs.io/en/latest/handbook/concepts.html#concept-modes
""" """
def __init__(self, mode=None): def __init__(self, mode=None):
self.mode = mode self.mode = mode
...@@ -180,10 +214,10 @@ class ToPILImage: ...@@ -180,10 +214,10 @@ class ToPILImage:
return F.to_pil_image(pic, self.mode) return F.to_pil_image(pic, self.mode)
def __repr__(self): def __repr__(self):
format_string = self.__class__.__name__ + '(' format_string = self.__class__.__name__ + "("
if self.mode is not None: if self.mode is not None:
format_string += 'mode={0}'.format(self.mode) format_string += "mode={0}".format(self.mode)
format_string += ')' format_string += ")"
return format_string return format_string
...@@ -222,7 +256,7 @@ class Normalize(torch.nn.Module): ...@@ -222,7 +256,7 @@ class Normalize(torch.nn.Module):
return F.normalize(tensor, self.mean, self.std, self.inplace) return F.normalize(tensor, self.mean, self.std, self.inplace)
def __repr__(self): def __repr__(self):
return self.__class__.__name__ + '(mean={0}, std={1})'.format(self.mean, self.std) return self.__class__.__name__ + "(mean={0}, std={1})".format(self.mean, self.std)
class Resize(torch.nn.Module): class Resize(torch.nn.Module):
...@@ -301,17 +335,20 @@ class Resize(torch.nn.Module): ...@@ -301,17 +335,20 @@ class Resize(torch.nn.Module):
def __repr__(self): def __repr__(self):
interpolate_str = self.interpolation.value interpolate_str = self.interpolation.value
return self.__class__.__name__ + '(size={0}, interpolation={1}, max_size={2}, antialias={3})'.format( return self.__class__.__name__ + "(size={0}, interpolation={1}, max_size={2}, antialias={3})".format(
self.size, interpolate_str, self.max_size, self.antialias) self.size, interpolate_str, self.max_size, self.antialias
)
class Scale(Resize): class Scale(Resize):
""" """
Note: This transform is deprecated in favor of Resize. Note: This transform is deprecated in favor of Resize.
""" """
def __init__(self, *args, **kwargs): def __init__(self, *args, **kwargs):
warnings.warn("The use of the transforms.Scale transform is deprecated, " + warnings.warn(
"please use transforms.Resize instead.") "The use of the transforms.Scale transform is deprecated, " + "please use transforms.Resize instead."
)
super(Scale, self).__init__(*args, **kwargs) super(Scale, self).__init__(*args, **kwargs)
...@@ -342,7 +379,7 @@ class CenterCrop(torch.nn.Module): ...@@ -342,7 +379,7 @@ class CenterCrop(torch.nn.Module):
return F.center_crop(img, self.size) return F.center_crop(img, self.size)
def __repr__(self): def __repr__(self):
return self.__class__.__name__ + '(size={0})'.format(self.size) return self.__class__.__name__ + "(size={0})".format(self.size)
class Pad(torch.nn.Module): class Pad(torch.nn.Module):
...@@ -395,8 +432,9 @@ class Pad(torch.nn.Module): ...@@ -395,8 +432,9 @@ class Pad(torch.nn.Module):
raise ValueError("Padding mode should be either constant, edge, reflect or symmetric") raise ValueError("Padding mode should be either constant, edge, reflect or symmetric")
if isinstance(padding, Sequence) and len(padding) not in [1, 2, 4]: if isinstance(padding, Sequence) and len(padding) not in [1, 2, 4]:
raise ValueError("Padding must be an int or a 1, 2, or 4 element tuple, not a " + raise ValueError(
"{} element tuple".format(len(padding))) "Padding must be an int or a 1, 2, or 4 element tuple, not a " + "{} element tuple".format(len(padding))
)
self.padding = padding self.padding = padding
self.fill = fill self.fill = fill
...@@ -413,8 +451,9 @@ class Pad(torch.nn.Module): ...@@ -413,8 +451,9 @@ class Pad(torch.nn.Module):
return F.pad(img, self.padding, self.fill, self.padding_mode) return F.pad(img, self.padding, self.fill, self.padding_mode)
def __repr__(self): def __repr__(self):
return self.__class__.__name__ + '(padding={0}, fill={1}, padding_mode={2})'.\ return self.__class__.__name__ + "(padding={0}, fill={1}, padding_mode={2})".format(
format(self.padding, self.fill, self.padding_mode) self.padding, self.fill, self.padding_mode
)
class Lambda: class Lambda:
...@@ -433,7 +472,7 @@ class Lambda: ...@@ -433,7 +472,7 @@ class Lambda:
return self.lambd(img) return self.lambd(img)
def __repr__(self): def __repr__(self):
return self.__class__.__name__ + '()' return self.__class__.__name__ + "()"
class RandomTransforms: class RandomTransforms:
...@@ -452,11 +491,11 @@ class RandomTransforms: ...@@ -452,11 +491,11 @@ class RandomTransforms:
raise NotImplementedError() raise NotImplementedError()
def __repr__(self): def __repr__(self):
format_string = self.__class__.__name__ + '(' format_string = self.__class__.__name__ + "("
for t in self.transforms: for t in self.transforms:
format_string += '\n' format_string += "\n"
format_string += ' {0}'.format(t) format_string += " {0}".format(t)
format_string += '\n)' format_string += "\n)"
return format_string return format_string
...@@ -493,18 +532,18 @@ class RandomApply(torch.nn.Module): ...@@ -493,18 +532,18 @@ class RandomApply(torch.nn.Module):
return img return img
def __repr__(self): def __repr__(self):
format_string = self.__class__.__name__ + '(' format_string = self.__class__.__name__ + "("
format_string += '\n p={}'.format(self.p) format_string += "\n p={}".format(self.p)
for t in self.transforms: for t in self.transforms:
format_string += '\n' format_string += "\n"
format_string += ' {0}'.format(t) format_string += " {0}".format(t)
format_string += '\n)' format_string += "\n)"
return format_string return format_string
class RandomOrder(RandomTransforms): class RandomOrder(RandomTransforms):
"""Apply a list of transformations in a random order. This transform does not support torchscript. """Apply a list of transformations in a random order. This transform does not support torchscript."""
"""
def __call__(self, img): def __call__(self, img):
order = list(range(len(self.transforms))) order = list(range(len(self.transforms)))
random.shuffle(order) random.shuffle(order)
...@@ -514,8 +553,8 @@ class RandomOrder(RandomTransforms): ...@@ -514,8 +553,8 @@ class RandomOrder(RandomTransforms):
class RandomChoice(RandomTransforms): class RandomChoice(RandomTransforms):
"""Apply single transformation randomly picked from a list. This transform does not support torchscript. """Apply single transformation randomly picked from a list. This transform does not support torchscript."""
"""
def __init__(self, transforms, p=None): def __init__(self, transforms, p=None):
super().__init__(transforms) super().__init__(transforms)
if p is not None and not isinstance(p, Sequence): if p is not None and not isinstance(p, Sequence):
...@@ -528,7 +567,7 @@ class RandomChoice(RandomTransforms): ...@@ -528,7 +567,7 @@ class RandomChoice(RandomTransforms):
def __repr__(self): def __repr__(self):
format_string = super().__repr__() format_string = super().__repr__()
format_string += '(p={0})'.format(self.p) format_string += "(p={0})".format(self.p)
return format_string return format_string
...@@ -591,23 +630,19 @@ class RandomCrop(torch.nn.Module): ...@@ -591,23 +630,19 @@ class RandomCrop(torch.nn.Module):
th, tw = output_size th, tw = output_size
if h + 1 < th or w + 1 < tw: if h + 1 < th or w + 1 < tw:
raise ValueError( raise ValueError("Required crop size {} is larger then input image size {}".format((th, tw), (h, w)))
"Required crop size {} is larger then input image size {}".format((th, tw), (h, w))
)
if w == tw and h == th: if w == tw and h == th:
return 0, 0, h, w return 0, 0, h, w
i = torch.randint(0, h - th + 1, size=(1, )).item() i = torch.randint(0, h - th + 1, size=(1,)).item()
j = torch.randint(0, w - tw + 1, size=(1, )).item() j = torch.randint(0, w - tw + 1, size=(1,)).item()
return i, j, th, tw return i, j, th, tw
def __init__(self, size, padding=None, pad_if_needed=False, fill=0, padding_mode="constant"): def __init__(self, size, padding=None, pad_if_needed=False, fill=0, padding_mode="constant"):
super().__init__() super().__init__()
self.size = tuple(_setup_size( self.size = tuple(_setup_size(size, error_msg="Please provide only two dimensions (h, w) for size."))
size, error_msg="Please provide only two dimensions (h, w) for size."
))
self.padding = padding self.padding = padding
self.pad_if_needed = pad_if_needed self.pad_if_needed = pad_if_needed
...@@ -670,7 +705,7 @@ class RandomHorizontalFlip(torch.nn.Module): ...@@ -670,7 +705,7 @@ class RandomHorizontalFlip(torch.nn.Module):
return img return img
def __repr__(self): def __repr__(self):
return self.__class__.__name__ + '(p={})'.format(self.p) return self.__class__.__name__ + "(p={})".format(self.p)
class RandomVerticalFlip(torch.nn.Module): class RandomVerticalFlip(torch.nn.Module):
...@@ -700,7 +735,7 @@ class RandomVerticalFlip(torch.nn.Module): ...@@ -700,7 +735,7 @@ class RandomVerticalFlip(torch.nn.Module):
return img return img
def __repr__(self): def __repr__(self):
return self.__class__.__name__ + '(p={})'.format(self.p) return self.__class__.__name__ + "(p={})".format(self.p)
class RandomPerspective(torch.nn.Module): class RandomPerspective(torch.nn.Module):
...@@ -780,27 +815,27 @@ class RandomPerspective(torch.nn.Module): ...@@ -780,27 +815,27 @@ class RandomPerspective(torch.nn.Module):
half_height = height // 2 half_height = height // 2
half_width = width // 2 half_width = width // 2
topleft = [ topleft = [
int(torch.randint(0, int(distortion_scale * half_width) + 1, size=(1, )).item()), int(torch.randint(0, int(distortion_scale * half_width) + 1, size=(1,)).item()),
int(torch.randint(0, int(distortion_scale * half_height) + 1, size=(1, )).item()) int(torch.randint(0, int(distortion_scale * half_height) + 1, size=(1,)).item()),
] ]
topright = [ topright = [
int(torch.randint(width - int(distortion_scale * half_width) - 1, width, size=(1, )).item()), int(torch.randint(width - int(distortion_scale * half_width) - 1, width, size=(1,)).item()),
int(torch.randint(0, int(distortion_scale * half_height) + 1, size=(1, )).item()) int(torch.randint(0, int(distortion_scale * half_height) + 1, size=(1,)).item()),
] ]
botright = [ botright = [
int(torch.randint(width - int(distortion_scale * half_width) - 1, width, size=(1, )).item()), int(torch.randint(width - int(distortion_scale * half_width) - 1, width, size=(1,)).item()),
int(torch.randint(height - int(distortion_scale * half_height) - 1, height, size=(1, )).item()) int(torch.randint(height - int(distortion_scale * half_height) - 1, height, size=(1,)).item()),
] ]
botleft = [ botleft = [
int(torch.randint(0, int(distortion_scale * half_width) + 1, size=(1, )).item()), int(torch.randint(0, int(distortion_scale * half_width) + 1, size=(1,)).item()),
int(torch.randint(height - int(distortion_scale * half_height) - 1, height, size=(1, )).item()) int(torch.randint(height - int(distortion_scale * half_height) - 1, height, size=(1,)).item()),
] ]
startpoints = [[0, 0], [width - 1, 0], [width - 1, height - 1], [0, height - 1]] startpoints = [[0, 0], [width - 1, 0], [width - 1, height - 1], [0, height - 1]]
endpoints = [topleft, topright, botright, botleft] endpoints = [topleft, topright, botright, botleft]
return startpoints, endpoints return startpoints, endpoints
def __repr__(self): def __repr__(self):
return self.__class__.__name__ + '(p={})'.format(self.p) return self.__class__.__name__ + "(p={})".format(self.p)
class RandomResizedCrop(torch.nn.Module): class RandomResizedCrop(torch.nn.Module):
...@@ -832,7 +867,7 @@ class RandomResizedCrop(torch.nn.Module): ...@@ -832,7 +867,7 @@ class RandomResizedCrop(torch.nn.Module):
""" """
def __init__(self, size, scale=(0.08, 1.0), ratio=(3. / 4., 4. / 3.), interpolation=InterpolationMode.BILINEAR): def __init__(self, size, scale=(0.08, 1.0), ratio=(3.0 / 4.0, 4.0 / 3.0), interpolation=InterpolationMode.BILINEAR):
super().__init__() super().__init__()
self.size = _setup_size(size, error_msg="Please provide only two dimensions (h, w) for size.") self.size = _setup_size(size, error_msg="Please provide only two dimensions (h, w) for size.")
...@@ -856,9 +891,7 @@ class RandomResizedCrop(torch.nn.Module): ...@@ -856,9 +891,7 @@ class RandomResizedCrop(torch.nn.Module):
self.ratio = ratio self.ratio = ratio
@staticmethod @staticmethod
def get_params( def get_params(img: Tensor, scale: List[float], ratio: List[float]) -> Tuple[int, int, int, int]:
img: Tensor, scale: List[float], ratio: List[float]
) -> Tuple[int, int, int, int]:
"""Get parameters for ``crop`` for a random sized crop. """Get parameters for ``crop`` for a random sized crop.
Args: Args:
...@@ -876,9 +909,7 @@ class RandomResizedCrop(torch.nn.Module): ...@@ -876,9 +909,7 @@ class RandomResizedCrop(torch.nn.Module):
log_ratio = torch.log(torch.tensor(ratio)) log_ratio = torch.log(torch.tensor(ratio))
for _ in range(10): for _ in range(10):
target_area = area * torch.empty(1).uniform_(scale[0], scale[1]).item() target_area = area * torch.empty(1).uniform_(scale[0], scale[1]).item()
aspect_ratio = torch.exp( aspect_ratio = torch.exp(torch.empty(1).uniform_(log_ratio[0], log_ratio[1])).item()
torch.empty(1).uniform_(log_ratio[0], log_ratio[1])
).item()
w = int(round(math.sqrt(target_area * aspect_ratio))) w = int(round(math.sqrt(target_area * aspect_ratio)))
h = int(round(math.sqrt(target_area / aspect_ratio))) h = int(round(math.sqrt(target_area / aspect_ratio)))
...@@ -916,10 +947,10 @@ class RandomResizedCrop(torch.nn.Module): ...@@ -916,10 +947,10 @@ class RandomResizedCrop(torch.nn.Module):
def __repr__(self): def __repr__(self):
interpolate_str = self.interpolation.value interpolate_str = self.interpolation.value
format_string = self.__class__.__name__ + '(size={0}'.format(self.size) format_string = self.__class__.__name__ + "(size={0}".format(self.size)
format_string += ', scale={0}'.format(tuple(round(s, 4) for s in self.scale)) format_string += ", scale={0}".format(tuple(round(s, 4) for s in self.scale))
format_string += ', ratio={0}'.format(tuple(round(r, 4) for r in self.ratio)) format_string += ", ratio={0}".format(tuple(round(r, 4) for r in self.ratio))
format_string += ', interpolation={0})'.format(interpolate_str) format_string += ", interpolation={0})".format(interpolate_str)
return format_string return format_string
...@@ -927,9 +958,12 @@ class RandomSizedCrop(RandomResizedCrop): ...@@ -927,9 +958,12 @@ class RandomSizedCrop(RandomResizedCrop):
""" """
Note: This transform is deprecated in favor of RandomResizedCrop. Note: This transform is deprecated in favor of RandomResizedCrop.
""" """
def __init__(self, *args, **kwargs): def __init__(self, *args, **kwargs):
warnings.warn("The use of the transforms.RandomSizedCrop transform is deprecated, " + warnings.warn(
"please use transforms.RandomResizedCrop instead.") "The use of the transforms.RandomSizedCrop transform is deprecated, "
+ "please use transforms.RandomResizedCrop instead."
)
super(RandomSizedCrop, self).__init__(*args, **kwargs) super(RandomSizedCrop, self).__init__(*args, **kwargs)
...@@ -976,7 +1010,7 @@ class FiveCrop(torch.nn.Module): ...@@ -976,7 +1010,7 @@ class FiveCrop(torch.nn.Module):
return F.five_crop(img, self.size) return F.five_crop(img, self.size)
def __repr__(self): def __repr__(self):
return self.__class__.__name__ + '(size={0})'.format(self.size) return self.__class__.__name__ + "(size={0})".format(self.size)
class TenCrop(torch.nn.Module): class TenCrop(torch.nn.Module):
...@@ -1025,7 +1059,7 @@ class TenCrop(torch.nn.Module): ...@@ -1025,7 +1059,7 @@ class TenCrop(torch.nn.Module):
return F.ten_crop(img, self.size, self.vertical_flip) return F.ten_crop(img, self.size, self.vertical_flip)
def __repr__(self): def __repr__(self):
return self.__class__.__name__ + '(size={0}, vertical_flip={1})'.format(self.size, self.vertical_flip) return self.__class__.__name__ + "(size={0}, vertical_flip={1})".format(self.size, self.vertical_flip)
class LinearTransformation(torch.nn.Module): class LinearTransformation(torch.nn.Module):
...@@ -1050,17 +1084,25 @@ class LinearTransformation(torch.nn.Module): ...@@ -1050,17 +1084,25 @@ class LinearTransformation(torch.nn.Module):
def __init__(self, transformation_matrix, mean_vector): def __init__(self, transformation_matrix, mean_vector):
super().__init__() super().__init__()
if transformation_matrix.size(0) != transformation_matrix.size(1): if transformation_matrix.size(0) != transformation_matrix.size(1):
raise ValueError("transformation_matrix should be square. Got " + raise ValueError(
"[{} x {}] rectangular matrix.".format(*transformation_matrix.size())) "transformation_matrix should be square. Got "
+ "[{} x {}] rectangular matrix.".format(*transformation_matrix.size())
)
if mean_vector.size(0) != transformation_matrix.size(0): if mean_vector.size(0) != transformation_matrix.size(0):
raise ValueError("mean_vector should have the same length {}".format(mean_vector.size(0)) + raise ValueError(
" as any one of the dimensions of the transformation_matrix [{}]" "mean_vector should have the same length {}".format(mean_vector.size(0))
.format(tuple(transformation_matrix.size()))) + " as any one of the dimensions of the transformation_matrix [{}]".format(
tuple(transformation_matrix.size())
)
)
if transformation_matrix.device != mean_vector.device: if transformation_matrix.device != mean_vector.device:
raise ValueError("Input tensors should be on the same device. Got {} and {}" raise ValueError(
.format(transformation_matrix.device, mean_vector.device)) "Input tensors should be on the same device. Got {} and {}".format(
transformation_matrix.device, mean_vector.device
)
)
self.transformation_matrix = transformation_matrix self.transformation_matrix = transformation_matrix
self.mean_vector = mean_vector self.mean_vector = mean_vector
...@@ -1076,13 +1118,17 @@ class LinearTransformation(torch.nn.Module): ...@@ -1076,13 +1118,17 @@ class LinearTransformation(torch.nn.Module):
shape = tensor.shape shape = tensor.shape
n = shape[-3] * shape[-2] * shape[-1] n = shape[-3] * shape[-2] * shape[-1]
if n != self.transformation_matrix.shape[0]: if n != self.transformation_matrix.shape[0]:
raise ValueError("Input tensor and transformation matrix have incompatible shape." + raise ValueError(
"[{} x {} x {}] != ".format(shape[-3], shape[-2], shape[-1]) + "Input tensor and transformation matrix have incompatible shape."
"{}".format(self.transformation_matrix.shape[0])) + "[{} x {} x {}] != ".format(shape[-3], shape[-2], shape[-1])
+ "{}".format(self.transformation_matrix.shape[0])
)
if tensor.device.type != self.mean_vector.device.type: if tensor.device.type != self.mean_vector.device.type:
raise ValueError("Input tensor should be on the same device as transformation matrix and mean vector. " raise ValueError(
"Got {} vs {}".format(tensor.device, self.mean_vector.device)) "Input tensor should be on the same device as transformation matrix and mean vector. "
"Got {} vs {}".format(tensor.device, self.mean_vector.device)
)
flat_tensor = tensor.view(-1, n) - self.mean_vector flat_tensor = tensor.view(-1, n) - self.mean_vector
transformed_tensor = torch.mm(flat_tensor, self.transformation_matrix) transformed_tensor = torch.mm(flat_tensor, self.transformation_matrix)
...@@ -1090,9 +1136,9 @@ class LinearTransformation(torch.nn.Module): ...@@ -1090,9 +1136,9 @@ class LinearTransformation(torch.nn.Module):
return tensor return tensor
def __repr__(self): def __repr__(self):
format_string = self.__class__.__name__ + '(transformation_matrix=' format_string = self.__class__.__name__ + "(transformation_matrix="
format_string += (str(self.transformation_matrix.tolist()) + ')') format_string += str(self.transformation_matrix.tolist()) + ")"
format_string += (", (mean_vector=" + str(self.mean_vector.tolist()) + ')') format_string += ", (mean_vector=" + str(self.mean_vector.tolist()) + ")"
return format_string return format_string
...@@ -1119,14 +1165,13 @@ class ColorJitter(torch.nn.Module): ...@@ -1119,14 +1165,13 @@ class ColorJitter(torch.nn.Module):
def __init__(self, brightness=0, contrast=0, saturation=0, hue=0): def __init__(self, brightness=0, contrast=0, saturation=0, hue=0):
super().__init__() super().__init__()
self.brightness = self._check_input(brightness, 'brightness') self.brightness = self._check_input(brightness, "brightness")
self.contrast = self._check_input(contrast, 'contrast') self.contrast = self._check_input(contrast, "contrast")
self.saturation = self._check_input(saturation, 'saturation') self.saturation = self._check_input(saturation, "saturation")
self.hue = self._check_input(hue, 'hue', center=0, bound=(-0.5, 0.5), self.hue = self._check_input(hue, "hue", center=0, bound=(-0.5, 0.5), clip_first_on_zero=False)
clip_first_on_zero=False)
@torch.jit.unused @torch.jit.unused
def _check_input(self, value, name, center=1, bound=(0, float('inf')), clip_first_on_zero=True): def _check_input(self, value, name, center=1, bound=(0, float("inf")), clip_first_on_zero=True):
if isinstance(value, numbers.Number): if isinstance(value, numbers.Number):
if value < 0: if value < 0:
raise ValueError("If {} is a single number, it must be non negative.".format(name)) raise ValueError("If {} is a single number, it must be non negative.".format(name))
...@@ -1146,11 +1191,12 @@ class ColorJitter(torch.nn.Module): ...@@ -1146,11 +1191,12 @@ class ColorJitter(torch.nn.Module):
return value return value
@staticmethod @staticmethod
def get_params(brightness: Optional[List[float]], def get_params(
contrast: Optional[List[float]], brightness: Optional[List[float]],
saturation: Optional[List[float]], contrast: Optional[List[float]],
hue: Optional[List[float]] saturation: Optional[List[float]],
) -> Tuple[Tensor, Optional[float], Optional[float], Optional[float], Optional[float]]: hue: Optional[List[float]],
) -> Tuple[Tensor, Optional[float], Optional[float], Optional[float], Optional[float]]:
"""Get the parameters for the randomized transform to be applied on image. """Get the parameters for the randomized transform to be applied on image.
Args: Args:
...@@ -1184,8 +1230,9 @@ class ColorJitter(torch.nn.Module): ...@@ -1184,8 +1230,9 @@ class ColorJitter(torch.nn.Module):
Returns: Returns:
PIL Image or Tensor: Color jittered image. PIL Image or Tensor: Color jittered image.
""" """
fn_idx, brightness_factor, contrast_factor, saturation_factor, hue_factor = \ fn_idx, brightness_factor, contrast_factor, saturation_factor, hue_factor = self.get_params(
self.get_params(self.brightness, self.contrast, self.saturation, self.hue) self.brightness, self.contrast, self.saturation, self.hue
)
for fn_id in fn_idx: for fn_id in fn_idx:
if fn_id == 0 and brightness_factor is not None: if fn_id == 0 and brightness_factor is not None:
...@@ -1200,11 +1247,11 @@ class ColorJitter(torch.nn.Module): ...@@ -1200,11 +1247,11 @@ class ColorJitter(torch.nn.Module):
return img return img
def __repr__(self): def __repr__(self):
format_string = self.__class__.__name__ + '(' format_string = self.__class__.__name__ + "("
format_string += 'brightness={0}'.format(self.brightness) format_string += "brightness={0}".format(self.brightness)
format_string += ', contrast={0}'.format(self.contrast) format_string += ", contrast={0}".format(self.contrast)
format_string += ', saturation={0}'.format(self.saturation) format_string += ", saturation={0}".format(self.saturation)
format_string += ', hue={0})'.format(self.hue) format_string += ", hue={0})".format(self.hue)
return format_string return format_string
...@@ -1254,10 +1301,10 @@ class RandomRotation(torch.nn.Module): ...@@ -1254,10 +1301,10 @@ class RandomRotation(torch.nn.Module):
) )
interpolation = _interpolation_modes_from_int(interpolation) interpolation = _interpolation_modes_from_int(interpolation)
self.degrees = _setup_angle(degrees, name="degrees", req_sizes=(2, )) self.degrees = _setup_angle(degrees, name="degrees", req_sizes=(2,))
if center is not None: if center is not None:
_check_sequence_input(center, "center", req_sizes=(2, )) _check_sequence_input(center, "center", req_sizes=(2,))
self.center = center self.center = center
...@@ -1301,14 +1348,14 @@ class RandomRotation(torch.nn.Module): ...@@ -1301,14 +1348,14 @@ class RandomRotation(torch.nn.Module):
def __repr__(self): def __repr__(self):
interpolate_str = self.interpolation.value interpolate_str = self.interpolation.value
format_string = self.__class__.__name__ + '(degrees={0}'.format(self.degrees) format_string = self.__class__.__name__ + "(degrees={0}".format(self.degrees)
format_string += ', interpolation={0}'.format(interpolate_str) format_string += ", interpolation={0}".format(interpolate_str)
format_string += ', expand={0}'.format(self.expand) format_string += ", expand={0}".format(self.expand)
if self.center is not None: if self.center is not None:
format_string += ', center={0}'.format(self.center) format_string += ", center={0}".format(self.center)
if self.fill is not None: if self.fill is not None:
format_string += ', fill={0}'.format(self.fill) format_string += ", fill={0}".format(self.fill)
format_string += ')' format_string += ")"
return format_string return format_string
...@@ -1349,8 +1396,15 @@ class RandomAffine(torch.nn.Module): ...@@ -1349,8 +1396,15 @@ class RandomAffine(torch.nn.Module):
""" """
def __init__( def __init__(
self, degrees, translate=None, scale=None, shear=None, interpolation=InterpolationMode.NEAREST, fill=0, self,
fillcolor=None, resample=None degrees,
translate=None,
scale=None,
shear=None,
interpolation=InterpolationMode.NEAREST,
fill=0,
fillcolor=None,
resample=None,
): ):
super().__init__() super().__init__()
if resample is not None: if resample is not None:
...@@ -1373,17 +1427,17 @@ class RandomAffine(torch.nn.Module): ...@@ -1373,17 +1427,17 @@ class RandomAffine(torch.nn.Module):
) )
fill = fillcolor fill = fillcolor
self.degrees = _setup_angle(degrees, name="degrees", req_sizes=(2, )) self.degrees = _setup_angle(degrees, name="degrees", req_sizes=(2,))
if translate is not None: if translate is not None:
_check_sequence_input(translate, "translate", req_sizes=(2, )) _check_sequence_input(translate, "translate", req_sizes=(2,))
for t in translate: for t in translate:
if not (0.0 <= t <= 1.0): if not (0.0 <= t <= 1.0):
raise ValueError("translation values should be between 0 and 1") raise ValueError("translation values should be between 0 and 1")
self.translate = translate self.translate = translate
if scale is not None: if scale is not None:
_check_sequence_input(scale, "scale", req_sizes=(2, )) _check_sequence_input(scale, "scale", req_sizes=(2,))
for s in scale: for s in scale:
if s <= 0: if s <= 0:
raise ValueError("scale values should be positive") raise ValueError("scale values should be positive")
...@@ -1405,11 +1459,11 @@ class RandomAffine(torch.nn.Module): ...@@ -1405,11 +1459,11 @@ class RandomAffine(torch.nn.Module):
@staticmethod @staticmethod
def get_params( def get_params(
degrees: List[float], degrees: List[float],
translate: Optional[List[float]], translate: Optional[List[float]],
scale_ranges: Optional[List[float]], scale_ranges: Optional[List[float]],
shears: Optional[List[float]], shears: Optional[List[float]],
img_size: List[int] img_size: List[int],
) -> Tuple[float, Tuple[int, int], float, Tuple[float, float]]: ) -> Tuple[float, Tuple[int, int], float, Tuple[float, float]]:
"""Get parameters for affine transformation """Get parameters for affine transformation
...@@ -1462,20 +1516,20 @@ class RandomAffine(torch.nn.Module): ...@@ -1462,20 +1516,20 @@ class RandomAffine(torch.nn.Module):
return F.affine(img, *ret, interpolation=self.interpolation, fill=fill) return F.affine(img, *ret, interpolation=self.interpolation, fill=fill)
def __repr__(self): def __repr__(self):
s = '{name}(degrees={degrees}' s = "{name}(degrees={degrees}"
if self.translate is not None: if self.translate is not None:
s += ', translate={translate}' s += ", translate={translate}"
if self.scale is not None: if self.scale is not None:
s += ', scale={scale}' s += ", scale={scale}"
if self.shear is not None: if self.shear is not None:
s += ', shear={shear}' s += ", shear={shear}"
if self.interpolation != InterpolationMode.NEAREST: if self.interpolation != InterpolationMode.NEAREST:
s += ', interpolation={interpolation}' s += ", interpolation={interpolation}"
if self.fill != 0: if self.fill != 0:
s += ', fill={fill}' s += ", fill={fill}"
s += ')' s += ")"
d = dict(self.__dict__) d = dict(self.__dict__)
d['interpolation'] = self.interpolation.value d["interpolation"] = self.interpolation.value
return s.format(name=self.__class__.__name__, **d) return s.format(name=self.__class__.__name__, **d)
...@@ -1510,7 +1564,7 @@ class Grayscale(torch.nn.Module): ...@@ -1510,7 +1564,7 @@ class Grayscale(torch.nn.Module):
return F.rgb_to_grayscale(img, num_output_channels=self.num_output_channels) return F.rgb_to_grayscale(img, num_output_channels=self.num_output_channels)
def __repr__(self): def __repr__(self):
return self.__class__.__name__ + '(num_output_channels={0})'.format(self.num_output_channels) return self.__class__.__name__ + "(num_output_channels={0})".format(self.num_output_channels)
class RandomGrayscale(torch.nn.Module): class RandomGrayscale(torch.nn.Module):
...@@ -1547,11 +1601,11 @@ class RandomGrayscale(torch.nn.Module): ...@@ -1547,11 +1601,11 @@ class RandomGrayscale(torch.nn.Module):
return img return img
def __repr__(self): def __repr__(self):
return self.__class__.__name__ + '(p={0})'.format(self.p) return self.__class__.__name__ + "(p={0})".format(self.p)
class RandomErasing(torch.nn.Module): class RandomErasing(torch.nn.Module):
""" Randomly selects a rectangle region in an torch Tensor image and erases its pixels. """Randomly selects a rectangle region in an torch Tensor image and erases its pixels.
This transform does not support PIL Image. This transform does not support PIL Image.
'Random Erasing Data Augmentation' by Zhong et al. See https://arxiv.org/abs/1708.04896 'Random Erasing Data Augmentation' by Zhong et al. See https://arxiv.org/abs/1708.04896
...@@ -1603,7 +1657,7 @@ class RandomErasing(torch.nn.Module): ...@@ -1603,7 +1657,7 @@ class RandomErasing(torch.nn.Module):
@staticmethod @staticmethod
def get_params( def get_params(
img: Tensor, scale: Tuple[float, float], ratio: Tuple[float, float], value: Optional[List[float]] = None img: Tensor, scale: Tuple[float, float], ratio: Tuple[float, float], value: Optional[List[float]] = None
) -> Tuple[int, int, int, int, Tensor]: ) -> Tuple[int, int, int, int, Tensor]:
"""Get parameters for ``erase`` for a random erasing. """Get parameters for ``erase`` for a random erasing.
...@@ -1624,9 +1678,7 @@ class RandomErasing(torch.nn.Module): ...@@ -1624,9 +1678,7 @@ class RandomErasing(torch.nn.Module):
log_ratio = torch.log(torch.tensor(ratio)) log_ratio = torch.log(torch.tensor(ratio))
for _ in range(10): for _ in range(10):
erase_area = area * torch.empty(1).uniform_(scale[0], scale[1]).item() erase_area = area * torch.empty(1).uniform_(scale[0], scale[1]).item()
aspect_ratio = torch.exp( aspect_ratio = torch.exp(torch.empty(1).uniform_(log_ratio[0], log_ratio[1])).item()
torch.empty(1).uniform_(log_ratio[0], log_ratio[1])
).item()
h = int(round(math.sqrt(erase_area * aspect_ratio))) h = int(round(math.sqrt(erase_area * aspect_ratio)))
w = int(round(math.sqrt(erase_area / aspect_ratio))) w = int(round(math.sqrt(erase_area / aspect_ratio)))
...@@ -1638,8 +1690,8 @@ class RandomErasing(torch.nn.Module): ...@@ -1638,8 +1690,8 @@ class RandomErasing(torch.nn.Module):
else: else:
v = torch.tensor(value)[:, None, None] v = torch.tensor(value)[:, None, None]
i = torch.randint(0, img_h - h + 1, size=(1, )).item() i = torch.randint(0, img_h - h + 1, size=(1,)).item()
j = torch.randint(0, img_w - w + 1, size=(1, )).item() j = torch.randint(0, img_w - w + 1, size=(1,)).item()
return i, j, h, w, v return i, j, h, w, v
# Return original image # Return original image
...@@ -1657,7 +1709,9 @@ class RandomErasing(torch.nn.Module): ...@@ -1657,7 +1709,9 @@ class RandomErasing(torch.nn.Module):
# cast self.value to script acceptable type # cast self.value to script acceptable type
if isinstance(self.value, (int, float)): if isinstance(self.value, (int, float)):
value = [self.value, ] value = [
self.value,
]
elif isinstance(self.value, str): elif isinstance(self.value, str):
value = None value = None
elif isinstance(self.value, tuple): elif isinstance(self.value, tuple):
...@@ -1676,11 +1730,11 @@ class RandomErasing(torch.nn.Module): ...@@ -1676,11 +1730,11 @@ class RandomErasing(torch.nn.Module):
return img return img
def __repr__(self): def __repr__(self):
s = '(p={}, '.format(self.p) s = "(p={}, ".format(self.p)
s += 'scale={}, '.format(self.scale) s += "scale={}, ".format(self.scale)
s += 'ratio={}, '.format(self.ratio) s += "ratio={}, ".format(self.ratio)
s += 'value={}, '.format(self.value) s += "value={}, ".format(self.value)
s += 'inplace={})'.format(self.inplace) s += "inplace={})".format(self.inplace)
return self.__class__.__name__ + s return self.__class__.__name__ + s
...@@ -1713,7 +1767,7 @@ class GaussianBlur(torch.nn.Module): ...@@ -1713,7 +1767,7 @@ class GaussianBlur(torch.nn.Module):
raise ValueError("If sigma is a single number, it must be positive.") raise ValueError("If sigma is a single number, it must be positive.")
sigma = (sigma, sigma) sigma = (sigma, sigma)
elif isinstance(sigma, Sequence) and len(sigma) == 2: elif isinstance(sigma, Sequence) and len(sigma) == 2:
if not 0. < sigma[0] <= sigma[1]: if not 0.0 < sigma[0] <= sigma[1]:
raise ValueError("sigma values should be positive and of the form (min, max).") raise ValueError("sigma values should be positive and of the form (min, max).")
else: else:
raise ValueError("sigma should be a single number or a list/tuple with length 2.") raise ValueError("sigma should be a single number or a list/tuple with length 2.")
...@@ -1745,8 +1799,8 @@ class GaussianBlur(torch.nn.Module): ...@@ -1745,8 +1799,8 @@ class GaussianBlur(torch.nn.Module):
return F.gaussian_blur(img, self.kernel_size, [sigma, sigma]) return F.gaussian_blur(img, self.kernel_size, [sigma, sigma])
def __repr__(self): def __repr__(self):
s = '(kernel_size={}, '.format(self.kernel_size) s = "(kernel_size={}, ".format(self.kernel_size)
s += 'sigma={})'.format(self.sigma) s += "sigma={})".format(self.sigma)
return self.__class__.__name__ + s return self.__class__.__name__ + s
...@@ -1771,7 +1825,7 @@ def _check_sequence_input(x, name, req_sizes): ...@@ -1771,7 +1825,7 @@ def _check_sequence_input(x, name, req_sizes):
raise ValueError("{} should be sequence of length {}.".format(name, msg)) raise ValueError("{} should be sequence of length {}.".format(name, msg))
def _setup_angle(x, name, req_sizes=(2, )): def _setup_angle(x, name, req_sizes=(2,)):
if isinstance(x, numbers.Number): if isinstance(x, numbers.Number):
if x < 0: if x < 0:
raise ValueError("If {} is a single number, it must be positive.".format(name)) raise ValueError("If {} is a single number, it must be positive.".format(name))
...@@ -1809,7 +1863,7 @@ class RandomInvert(torch.nn.Module): ...@@ -1809,7 +1863,7 @@ class RandomInvert(torch.nn.Module):
return img return img
def __repr__(self): def __repr__(self):
return self.__class__.__name__ + '(p={})'.format(self.p) return self.__class__.__name__ + "(p={})".format(self.p)
class RandomPosterize(torch.nn.Module): class RandomPosterize(torch.nn.Module):
...@@ -1841,7 +1895,7 @@ class RandomPosterize(torch.nn.Module): ...@@ -1841,7 +1895,7 @@ class RandomPosterize(torch.nn.Module):
return img return img
def __repr__(self): def __repr__(self):
return self.__class__.__name__ + '(bits={},p={})'.format(self.bits, self.p) return self.__class__.__name__ + "(bits={},p={})".format(self.bits, self.p)
class RandomSolarize(torch.nn.Module): class RandomSolarize(torch.nn.Module):
...@@ -1873,7 +1927,7 @@ class RandomSolarize(torch.nn.Module): ...@@ -1873,7 +1927,7 @@ class RandomSolarize(torch.nn.Module):
return img return img
def __repr__(self): def __repr__(self):
return self.__class__.__name__ + '(threshold={},p={})'.format(self.threshold, self.p) return self.__class__.__name__ + "(threshold={},p={})".format(self.threshold, self.p)
class RandomAdjustSharpness(torch.nn.Module): class RandomAdjustSharpness(torch.nn.Module):
...@@ -1905,7 +1959,7 @@ class RandomAdjustSharpness(torch.nn.Module): ...@@ -1905,7 +1959,7 @@ class RandomAdjustSharpness(torch.nn.Module):
return img return img
def __repr__(self): def __repr__(self):
return self.__class__.__name__ + '(sharpness_factor={},p={})'.format(self.sharpness_factor, self.p) return self.__class__.__name__ + "(sharpness_factor={},p={})".format(self.sharpness_factor, self.p)
class RandomAutocontrast(torch.nn.Module): class RandomAutocontrast(torch.nn.Module):
...@@ -1935,7 +1989,7 @@ class RandomAutocontrast(torch.nn.Module): ...@@ -1935,7 +1989,7 @@ class RandomAutocontrast(torch.nn.Module):
return img return img
def __repr__(self): def __repr__(self):
return self.__class__.__name__ + '(p={})'.format(self.p) return self.__class__.__name__ + "(p={})".format(self.p)
class RandomEqualize(torch.nn.Module): class RandomEqualize(torch.nn.Module):
...@@ -1965,4 +2019,4 @@ class RandomEqualize(torch.nn.Module): ...@@ -1965,4 +2019,4 @@ class RandomEqualize(torch.nn.Module):
return img return img
def __repr__(self): def __repr__(self):
return self.__class__.__name__ + '(p={})'.format(self.p) return self.__class__.__name__ + "(p={})".format(self.p)
from typing import Union, Optional, List, Tuple, Text, BinaryIO
import pathlib
import torch
import math import math
import pathlib
import warnings import warnings
from typing import Union, Optional, List, Tuple, Text, BinaryIO
import numpy as np import numpy as np
import torch
from PIL import Image, ImageDraw, ImageFont, ImageColor from PIL import Image, ImageDraw, ImageFont, ImageColor
__all__ = ["make_grid", "save_image", "draw_bounding_boxes", "draw_segmentation_masks"] __all__ = ["make_grid", "save_image", "draw_bounding_boxes", "draw_segmentation_masks"]
...@@ -18,7 +19,7 @@ def make_grid( ...@@ -18,7 +19,7 @@ def make_grid(
value_range: Optional[Tuple[int, int]] = None, value_range: Optional[Tuple[int, int]] = None,
scale_each: bool = False, scale_each: bool = False,
pad_value: int = 0, pad_value: int = 0,
**kwargs **kwargs,
) -> torch.Tensor: ) -> torch.Tensor:
""" """
Make a grid of images. Make a grid of images.
...@@ -41,9 +42,8 @@ def make_grid( ...@@ -41,9 +42,8 @@ def make_grid(
Returns: Returns:
grid (Tensor): the tensor containing grid of images. grid (Tensor): the tensor containing grid of images.
""" """
if not (torch.is_tensor(tensor) or if not (torch.is_tensor(tensor) or (isinstance(tensor, list) and all(torch.is_tensor(t) for t in tensor))):
(isinstance(tensor, list) and all(torch.is_tensor(t) for t in tensor))): raise TypeError(f"tensor or list of tensors expected, got {type(tensor)}")
raise TypeError(f'tensor or list of tensors expected, got {type(tensor)}')
if "range" in kwargs.keys(): if "range" in kwargs.keys():
warning = "range will be deprecated, please use value_range instead." warning = "range will be deprecated, please use value_range instead."
...@@ -67,8 +67,9 @@ def make_grid( ...@@ -67,8 +67,9 @@ def make_grid(
if normalize is True: if normalize is True:
tensor = tensor.clone() # avoid modifying tensor in-place tensor = tensor.clone() # avoid modifying tensor in-place
if value_range is not None: if value_range is not None:
assert isinstance(value_range, tuple), \ assert isinstance(
"value_range has to be a tuple (min, max) if specified. min and max are numbers" value_range, tuple
), "value_range has to be a tuple (min, max) if specified. min and max are numbers"
def norm_ip(img, low, high): def norm_ip(img, low, high):
img.clamp_(min=low, max=high) img.clamp_(min=low, max=high)
...@@ -115,7 +116,7 @@ def save_image( ...@@ -115,7 +116,7 @@ def save_image(
tensor: Union[torch.Tensor, List[torch.Tensor]], tensor: Union[torch.Tensor, List[torch.Tensor]],
fp: Union[Text, pathlib.Path, BinaryIO], fp: Union[Text, pathlib.Path, BinaryIO],
format: Optional[str] = None, format: Optional[str] = None,
**kwargs **kwargs,
) -> None: ) -> None:
""" """
Save a given Tensor into an image file. Save a given Tensor into an image file.
...@@ -131,7 +132,7 @@ def save_image( ...@@ -131,7 +132,7 @@ def save_image(
grid = make_grid(tensor, **kwargs) grid = make_grid(tensor, **kwargs)
# Add 0.5 after unnormalizing to [0, 255] to round to nearest integer # Add 0.5 after unnormalizing to [0, 255] to round to nearest integer
ndarr = grid.mul(255).add_(0.5).clamp_(0, 255).permute(1, 2, 0).to('cpu', torch.uint8).numpy() ndarr = grid.mul(255).add_(0.5).clamp_(0, 255).permute(1, 2, 0).to("cpu", torch.uint8).numpy()
im = Image.fromarray(ndarr) im = Image.fromarray(ndarr)
im.save(fp, format=format) im.save(fp, format=format)
...@@ -145,7 +146,7 @@ def draw_bounding_boxes( ...@@ -145,7 +146,7 @@ def draw_bounding_boxes(
fill: Optional[bool] = False, fill: Optional[bool] = False,
width: int = 1, width: int = 1,
font: Optional[str] = None, font: Optional[str] = None,
font_size: int = 10 font_size: int = 10,
) -> torch.Tensor: ) -> torch.Tensor:
""" """
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment