"sgl-kernel/vscode:/vscode.git/clone" did not exist on "4540a4666a112a82dcf21505b781f3e31e50d178"
Unverified commit 5f0edb97, authored by Philip Meier, committed by GitHub
Browse files

Add ufmt (usort + black) as code formatter (#4384)



* add ufmt as code formatter

* cleanup

* quote ufmt requirement

* split imports into more groups

* regenerate circleci config

* fix CI

* clarify local testing utils section

* use ufmt pre-commit hook

* split relative imports into local category

* Revert "split relative imports into local category"

This reverts commit f2e224cde2008c56c9347c1f69746d39065cdd51.

* pin black and usort dependencies

* fix local test utils detection

* fix ufmt rev

* add reference utils to local category

* fix usort config

* remove custom categories sorting

* Run pre-commit without fixing flake8

* got a double import in merge
Co-authored-by: Nicolas Hug <nicolashug@fb.com>
parent e45489b1
import warnings
from typing import Optional, Tuple, List
import torch
from torch import Tensor
from torch.nn.functional import grid_sample, conv2d, interpolate, pad as torch_pad
from torch.jit.annotations import BroadcastingList2
from typing import Optional, Tuple, List
from torch.nn.functional import grid_sample, conv2d, interpolate, pad as torch_pad
def _is_tensor_a_torch_image(x: Tensor) -> bool:
......@@ -97,7 +97,7 @@ def convert_image_dtype(image: torch.Tensor, dtype: torch.dtype = torch.float) -
# factor should be forced to int for torch jit script
# otherwise factor is a float and image // factor can produce different results
factor = int((input_max + 1) // (output_max + 1))
image = torch.div(image, factor, rounding_mode='floor')
image = torch.div(image, factor, rounding_mode="floor")
return image.to(dtype)
else:
# factor should be forced to int for torch jit script
......@@ -128,7 +128,7 @@ def crop(img: Tensor, top: int, left: int, height: int, width: int) -> Tensor:
if left < 0 or top < 0 or right > w or bottom > h:
padding_ltrb = [max(-left, 0), max(-top, 0), max(right - w, 0), max(bottom - h, 0)]
return pad(img[..., max(top, 0):bottom, max(left, 0):right], padding_ltrb, fill=0)
return pad(img[..., max(top, 0) : bottom, max(left, 0) : right], padding_ltrb, fill=0)
return img[..., top:bottom, left:right]
......@@ -138,7 +138,7 @@ def rgb_to_grayscale(img: Tensor, num_output_channels: int = 1) -> Tensor:
_assert_channels(img, [3])
if num_output_channels not in (1, 3):
raise ValueError('num_output_channels should be either 1 or 3')
raise ValueError("num_output_channels should be either 1 or 3")
r, g, b = img.unbind(dim=-3)
# This implementation closely follows the TF one:
......@@ -154,7 +154,7 @@ def rgb_to_grayscale(img: Tensor, num_output_channels: int = 1) -> Tensor:
def adjust_brightness(img: Tensor, brightness_factor: float) -> Tensor:
if brightness_factor < 0:
raise ValueError('brightness_factor ({}) is not non-negative.'.format(brightness_factor))
raise ValueError("brightness_factor ({}) is not non-negative.".format(brightness_factor))
_assert_image_tensor(img)
......@@ -165,7 +165,7 @@ def adjust_brightness(img: Tensor, brightness_factor: float) -> Tensor:
def adjust_contrast(img: Tensor, contrast_factor: float) -> Tensor:
if contrast_factor < 0:
raise ValueError('contrast_factor ({}) is not non-negative.'.format(contrast_factor))
raise ValueError("contrast_factor ({}) is not non-negative.".format(contrast_factor))
_assert_image_tensor(img)
......@@ -182,10 +182,10 @@ def adjust_contrast(img: Tensor, contrast_factor: float) -> Tensor:
def adjust_hue(img: Tensor, hue_factor: float) -> Tensor:
if not (-0.5 <= hue_factor <= 0.5):
raise ValueError('hue_factor ({}) is not in [-0.5, 0.5].'.format(hue_factor))
raise ValueError("hue_factor ({}) is not in [-0.5, 0.5].".format(hue_factor))
if not (isinstance(img, torch.Tensor)):
raise TypeError('Input img should be Tensor image')
raise TypeError("Input img should be Tensor image")
_assert_image_tensor(img)
......@@ -211,7 +211,7 @@ def adjust_hue(img: Tensor, hue_factor: float) -> Tensor:
def adjust_saturation(img: Tensor, saturation_factor: float) -> Tensor:
if saturation_factor < 0:
raise ValueError('saturation_factor ({}) is not non-negative.'.format(saturation_factor))
raise ValueError("saturation_factor ({}) is not non-negative.".format(saturation_factor))
_assert_image_tensor(img)
......@@ -225,12 +225,12 @@ def adjust_saturation(img: Tensor, saturation_factor: float) -> Tensor:
def adjust_gamma(img: Tensor, gamma: float, gain: float = 1) -> Tensor:
if not isinstance(img, torch.Tensor):
raise TypeError('Input img should be a Tensor.')
raise TypeError("Input img should be a Tensor.")
_assert_channels(img, [1, 3])
if gamma < 0:
raise ValueError('Gamma should be a non-negative real number')
raise ValueError("Gamma should be a non-negative real number")
result = img
dtype = img.dtype
......@@ -244,11 +244,9 @@ def adjust_gamma(img: Tensor, gamma: float, gain: float = 1) -> Tensor:
def center_crop(img: Tensor, output_size: BroadcastingList2[int]) -> Tensor:
"""DEPRECATED
"""
"""DEPRECATED"""
warnings.warn(
"This method is deprecated and will be removed in future releases. "
"Please, use ``F.center_crop`` instead."
"This method is deprecated and will be removed in future releases. " "Please, use ``F.center_crop`` instead."
)
_assert_image_tensor(img)
......@@ -268,11 +266,9 @@ def center_crop(img: Tensor, output_size: BroadcastingList2[int]) -> Tensor:
def five_crop(img: Tensor, size: BroadcastingList2[int]) -> List[Tensor]:
"""DEPRECATED
"""
"""DEPRECATED"""
warnings.warn(
"This method is deprecated and will be removed in future releases. "
"Please, use ``F.five_crop`` instead."
"This method is deprecated and will be removed in future releases. " "Please, use ``F.five_crop`` instead."
)
_assert_image_tensor(img)
......@@ -295,11 +291,9 @@ def five_crop(img: Tensor, size: BroadcastingList2[int]) -> List[Tensor]:
def ten_crop(img: Tensor, size: BroadcastingList2[int], vertical_flip: bool = False) -> List[Tensor]:
"""DEPRECATED
"""
"""DEPRECATED"""
warnings.warn(
"This method is deprecated and will be removed in future releases. "
"Please, use ``F.ten_crop`` instead."
"This method is deprecated and will be removed in future releases. " "Please, use ``F.ten_crop`` instead."
)
_assert_image_tensor(img)
......@@ -357,7 +351,7 @@ def _rgb2hsv(img: Tensor) -> Tensor:
hr = (maxc == r) * (bc - gc)
hg = ((maxc == g) & (maxc != r)) * (2.0 + rc - bc)
hb = ((maxc != g) & (maxc != r)) * (4.0 + gc - rc)
h = (hr + hg + hb)
h = hr + hg + hb
h = torch.fmod((h / 6.0 + 1.0), 1.0)
return torch.stack((h, s, maxc), dim=-3)
......@@ -389,7 +383,7 @@ def _pad_symmetric(img: Tensor, padding: List[int]) -> Tensor:
# crop if needed
if padding[0] < 0 or padding[1] < 0 or padding[2] < 0 or padding[3] < 0:
crop_left, crop_right, crop_top, crop_bottom = [-min(x, 0) for x in padding]
img = img[..., crop_top:img.shape[-2] - crop_bottom, crop_left:img.shape[-1] - crop_right]
img = img[..., crop_top : img.shape[-2] - crop_bottom, crop_left : img.shape[-1] - crop_right]
padding = [max(x, 0) for x in padding]
in_sizes = img.size()
......@@ -427,8 +421,9 @@ def pad(img: Tensor, padding: List[int], fill: int = 0, padding_mode: str = "con
padding = list(padding)
if isinstance(padding, list) and len(padding) not in [1, 2, 4]:
raise ValueError("Padding must be an int or a 1, 2, or 4 element tuple, not a " +
"{} element tuple".format(len(padding)))
raise ValueError(
"Padding must be an int or a 1, 2, or 4 element tuple, not a " + "{} element tuple".format(len(padding))
)
if padding_mode not in ["constant", "edge", "reflect", "symmetric"]:
raise ValueError("Padding mode should be either constant, edge, reflect or symmetric")
......@@ -488,7 +483,7 @@ def resize(
size: List[int],
interpolation: str = "bilinear",
max_size: Optional[int] = None,
antialias: Optional[bool] = None
antialias: Optional[bool] = None,
) -> Tensor:
_assert_image_tensor(img)
......@@ -505,8 +500,9 @@ def resize(
if isinstance(size, list):
if len(size) not in [1, 2]:
raise ValueError("Size must be an int or a 1 or 2 element tuple/list, not a "
"{} element tuple/list".format(len(size)))
raise ValueError(
"Size must be an int or a 1 or 2 element tuple/list, not a " "{} element tuple/list".format(len(size))
)
if max_size is not None and len(size) != 1:
raise ValueError(
"max_size should only be passed if size specifies the length of the smaller edge, "
......@@ -594,8 +590,10 @@ def _assert_grid_transform_inputs(
# Check fill
num_channels = get_image_num_channels(img)
if isinstance(fill, (tuple, list)) and (len(fill) > 1 and len(fill) != num_channels):
msg = ("The number of elements in 'fill' cannot broadcast to match the number of "
"channels of the image ({} != {})")
msg = (
"The number of elements in 'fill' cannot broadcast to match the number of "
"channels of the image ({} != {})"
)
raise ValueError(msg.format(len(fill), num_channels))
if interpolation not in supported_interpolation_modes:
......@@ -633,7 +631,12 @@ def _cast_squeeze_out(img: Tensor, need_cast: bool, need_squeeze: bool, out_dtyp
def _apply_grid_transform(img: Tensor, grid: Tensor, mode: str, fill: Optional[List[float]]) -> Tensor:
img, need_cast, need_squeeze, out_dtype = _cast_squeeze_in(img, [grid.dtype, ])
img, need_cast, need_squeeze, out_dtype = _cast_squeeze_in(
img,
[
grid.dtype,
],
)
if img.shape[0] > 1:
# Apply same grid to a batch of images
......@@ -653,7 +656,7 @@ def _apply_grid_transform(img: Tensor, grid: Tensor, mode: str, fill: Optional[L
mask = mask.expand_as(img)
len_fill = len(fill) if isinstance(fill, (tuple, list)) else 1
fill_img = torch.tensor(fill, dtype=img.dtype, device=img.device).view(1, len_fill, 1, 1).expand_as(img)
if mode == 'nearest':
if mode == "nearest":
mask = mask < 0.5
img[mask] = fill_img[mask]
else: # 'bilinear'
......@@ -664,7 +667,11 @@ def _apply_grid_transform(img: Tensor, grid: Tensor, mode: str, fill: Optional[L
def _gen_affine_grid(
theta: Tensor, w: int, h: int, ow: int, oh: int,
theta: Tensor,
w: int,
h: int,
ow: int,
oh: int,
) -> Tensor:
# https://github.com/pytorch/pytorch/blob/74b65c32be68b15dc7c9e8bb62459efbfbde33d8/aten/src/ATen/native/
# AffineGridGenerator.cpp#L18
......@@ -686,7 +693,7 @@ def _gen_affine_grid(
def affine(
img: Tensor, matrix: List[float], interpolation: str = "nearest", fill: Optional[List[float]] = None
img: Tensor, matrix: List[float], interpolation: str = "nearest", fill: Optional[List[float]] = None
) -> Tensor:
_assert_grid_transform_inputs(img, matrix, interpolation, fill, ["nearest", "bilinear"])
......@@ -704,12 +711,14 @@ def _compute_output_size(matrix: List[float], w: int, h: int) -> Tuple[int, int]
# https://github.com/python-pillow/Pillow/blob/11de3318867e4398057373ee9f12dcb33db7335c/src/PIL/Image.py#L2054
# pts are Top-Left, Top-Right, Bottom-Left, Bottom-Right points.
pts = torch.tensor([
[-0.5 * w, -0.5 * h, 1.0],
[-0.5 * w, 0.5 * h, 1.0],
[0.5 * w, 0.5 * h, 1.0],
[0.5 * w, -0.5 * h, 1.0],
])
pts = torch.tensor(
[
[-0.5 * w, -0.5 * h, 1.0],
[-0.5 * w, 0.5 * h, 1.0],
[0.5 * w, 0.5 * h, 1.0],
[0.5 * w, -0.5 * h, 1.0],
]
)
theta = torch.tensor(matrix, dtype=torch.float).reshape(1, 2, 3)
new_pts = pts.view(1, 4, 3).bmm(theta.transpose(1, 2)).view(4, 2)
min_vals, _ = new_pts.min(dim=0)
......@@ -724,8 +733,11 @@ def _compute_output_size(matrix: List[float], w: int, h: int) -> Tuple[int, int]
def rotate(
img: Tensor, matrix: List[float], interpolation: str = "nearest",
expand: bool = False, fill: Optional[List[float]] = None
img: Tensor,
matrix: List[float],
interpolation: str = "nearest",
expand: bool = False,
fill: Optional[List[float]] = None,
) -> Tensor:
_assert_grid_transform_inputs(img, matrix, interpolation, fill, ["nearest", "bilinear"])
w, h = img.shape[-1], img.shape[-2]
......@@ -746,14 +758,10 @@ def _perspective_grid(coeffs: List[float], ow: int, oh: int, dtype: torch.dtype,
# x_out = (coeffs[0] * x + coeffs[1] * y + coeffs[2]) / (coeffs[6] * x + coeffs[7] * y + 1)
# y_out = (coeffs[3] * x + coeffs[4] * y + coeffs[5]) / (coeffs[6] * x + coeffs[7] * y + 1)
#
theta1 = torch.tensor([[
[coeffs[0], coeffs[1], coeffs[2]],
[coeffs[3], coeffs[4], coeffs[5]]
]], dtype=dtype, device=device)
theta2 = torch.tensor([[
[coeffs[6], coeffs[7], 1.0],
[coeffs[6], coeffs[7], 1.0]
]], dtype=dtype, device=device)
theta1 = torch.tensor(
[[[coeffs[0], coeffs[1], coeffs[2]], [coeffs[3], coeffs[4], coeffs[5]]]], dtype=dtype, device=device
)
theta2 = torch.tensor([[[coeffs[6], coeffs[7], 1.0], [coeffs[6], coeffs[7], 1.0]]], dtype=dtype, device=device)
d = 0.5
base_grid = torch.empty(1, oh, ow, 3, dtype=dtype, device=device)
......@@ -775,7 +783,7 @@ def perspective(
img: Tensor, perspective_coeffs: List[float], interpolation: str = "bilinear", fill: Optional[List[float]] = None
) -> Tensor:
if not (isinstance(img, torch.Tensor)):
raise TypeError('Input img should be Tensor.')
raise TypeError("Input img should be Tensor.")
_assert_image_tensor(img)
......@@ -785,7 +793,7 @@ def perspective(
interpolation=interpolation,
fill=fill,
supported_interpolation_modes=["nearest", "bilinear"],
coeffs=perspective_coeffs
coeffs=perspective_coeffs,
)
ow, oh = img.shape[-1], img.shape[-2]
......@@ -805,7 +813,7 @@ def _get_gaussian_kernel1d(kernel_size: int, sigma: float) -> Tensor:
def _get_gaussian_kernel2d(
kernel_size: List[int], sigma: List[float], dtype: torch.dtype, device: torch.device
kernel_size: List[int], sigma: List[float], dtype: torch.dtype, device: torch.device
) -> Tensor:
kernel1d_x = _get_gaussian_kernel1d(kernel_size[0], sigma[0]).to(device, dtype=dtype)
kernel1d_y = _get_gaussian_kernel1d(kernel_size[1], sigma[1]).to(device, dtype=dtype)
......@@ -815,7 +823,7 @@ def _get_gaussian_kernel2d(
def gaussian_blur(img: Tensor, kernel_size: List[int], sigma: List[float]) -> Tensor:
if not (isinstance(img, torch.Tensor)):
raise TypeError('img should be Tensor. Got {}'.format(type(img)))
raise TypeError("img should be Tensor. Got {}".format(type(img)))
_assert_image_tensor(img)
......@@ -823,7 +831,12 @@ def gaussian_blur(img: Tensor, kernel_size: List[int], sigma: List[float]) -> Te
kernel = _get_gaussian_kernel2d(kernel_size, sigma, dtype=dtype, device=img.device)
kernel = kernel.expand(img.shape[-3], 1, kernel.shape[0], kernel.shape[1])
img, need_cast, need_squeeze, out_dtype = _cast_squeeze_in(img, [kernel.dtype, ])
img, need_cast, need_squeeze, out_dtype = _cast_squeeze_in(
img,
[
kernel.dtype,
],
)
# padding = (left, right, top, bottom)
padding = [kernel_size[0] // 2, kernel_size[0] // 2, kernel_size[1] // 2, kernel_size[1] // 2]
......@@ -857,7 +870,7 @@ def posterize(img: Tensor, bits: int) -> Tensor:
raise TypeError("Only torch.uint8 image tensors are supported, but found {}".format(img.dtype))
_assert_channels(img, [1, 3])
mask = -int(2**(8 - bits)) # JIT-friendly for: ~(2 ** (8 - bits) - 1)
mask = -int(2 ** (8 - bits)) # JIT-friendly for: ~(2 ** (8 - bits) - 1)
return img & mask
......@@ -882,7 +895,12 @@ def _blurred_degenerate_image(img: Tensor) -> Tensor:
kernel /= kernel.sum()
kernel = kernel.expand(img.shape[-3], 1, kernel.shape[0], kernel.shape[1])
result_tmp, need_cast, need_squeeze, out_dtype = _cast_squeeze_in(img, [kernel.dtype, ])
result_tmp, need_cast, need_squeeze, out_dtype = _cast_squeeze_in(
img,
[
kernel.dtype,
],
)
result_tmp = conv2d(result_tmp, kernel, groups=result_tmp.shape[-3])
result_tmp = _cast_squeeze_out(result_tmp, need_cast, need_squeeze, out_dtype)
......@@ -894,7 +912,7 @@ def _blurred_degenerate_image(img: Tensor) -> Tensor:
def adjust_sharpness(img: Tensor, sharpness_factor: float) -> Tensor:
if sharpness_factor < 0:
raise ValueError('sharpness_factor ({}) is not non-negative.'.format(sharpness_factor))
raise ValueError("sharpness_factor ({}) is not non-negative.".format(sharpness_factor))
_assert_image_tensor(img)
......@@ -939,13 +957,11 @@ def _scale_channel(img_chan: Tensor) -> Tensor:
hist = torch.bincount(img_chan.view(-1), minlength=256)
nonzero_hist = hist[hist != 0]
step = torch.div(nonzero_hist[:-1].sum(), 255, rounding_mode='floor')
step = torch.div(nonzero_hist[:-1].sum(), 255, rounding_mode="floor")
if step == 0:
return img_chan
lut = torch.div(
torch.cumsum(hist, 0) + torch.div(step, 2, rounding_mode='floor'),
step, rounding_mode='floor')
lut = torch.div(torch.cumsum(hist, 0) + torch.div(step, 2, rounding_mode="floor"), step, rounding_mode="floor")
lut = torch.nn.functional.pad(lut, [1, 0])[:-1].clamp(0, 255)
return lut[img_chan.to(torch.int64)].to(torch.uint8)
......
......@@ -17,12 +17,45 @@ from . import functional as F
from .functional import InterpolationMode, _interpolation_modes_from_int
__all__ = ["Compose", "ToTensor", "PILToTensor", "ConvertImageDtype", "ToPILImage", "Normalize", "Resize", "Scale",
"CenterCrop", "Pad", "Lambda", "RandomApply", "RandomChoice", "RandomOrder", "RandomCrop",
"RandomHorizontalFlip", "RandomVerticalFlip", "RandomResizedCrop", "RandomSizedCrop", "FiveCrop", "TenCrop",
"LinearTransformation", "ColorJitter", "RandomRotation", "RandomAffine", "Grayscale", "RandomGrayscale",
"RandomPerspective", "RandomErasing", "GaussianBlur", "InterpolationMode", "RandomInvert", "RandomPosterize",
"RandomSolarize", "RandomAdjustSharpness", "RandomAutocontrast", "RandomEqualize"]
__all__ = [
"Compose",
"ToTensor",
"PILToTensor",
"ConvertImageDtype",
"ToPILImage",
"Normalize",
"Resize",
"Scale",
"CenterCrop",
"Pad",
"Lambda",
"RandomApply",
"RandomChoice",
"RandomOrder",
"RandomCrop",
"RandomHorizontalFlip",
"RandomVerticalFlip",
"RandomResizedCrop",
"RandomSizedCrop",
"FiveCrop",
"TenCrop",
"LinearTransformation",
"ColorJitter",
"RandomRotation",
"RandomAffine",
"Grayscale",
"RandomGrayscale",
"RandomPerspective",
"RandomErasing",
"GaussianBlur",
"InterpolationMode",
"RandomInvert",
"RandomPosterize",
"RandomSolarize",
"RandomAdjustSharpness",
"RandomAutocontrast",
"RandomEqualize",
]
class Compose:
......@@ -62,11 +95,11 @@ class Compose:
return img
def __repr__(self):
format_string = self.__class__.__name__ + '('
format_string = self.__class__.__name__ + "("
for t in self.transforms:
format_string += '\n'
format_string += ' {0}'.format(t)
format_string += '\n)'
format_string += "\n"
format_string += " {0}".format(t)
format_string += "\n)"
return format_string
......@@ -98,7 +131,7 @@ class ToTensor:
return F.to_tensor(pic)
def __repr__(self):
return self.__class__.__name__ + '()'
return self.__class__.__name__ + "()"
class PILToTensor:
......@@ -118,7 +151,7 @@ class PILToTensor:
return F.pil_to_tensor(pic)
def __repr__(self):
return self.__class__.__name__ + '()'
return self.__class__.__name__ + "()"
class ConvertImageDtype(torch.nn.Module):
......@@ -165,6 +198,7 @@ class ToPILImage:
.. _PIL.Image mode: https://pillow.readthedocs.io/en/latest/handbook/concepts.html#concept-modes
"""
def __init__(self, mode=None):
self.mode = mode
......@@ -180,10 +214,10 @@ class ToPILImage:
return F.to_pil_image(pic, self.mode)
def __repr__(self):
format_string = self.__class__.__name__ + '('
format_string = self.__class__.__name__ + "("
if self.mode is not None:
format_string += 'mode={0}'.format(self.mode)
format_string += ')'
format_string += "mode={0}".format(self.mode)
format_string += ")"
return format_string
......@@ -222,7 +256,7 @@ class Normalize(torch.nn.Module):
return F.normalize(tensor, self.mean, self.std, self.inplace)
def __repr__(self):
return self.__class__.__name__ + '(mean={0}, std={1})'.format(self.mean, self.std)
return self.__class__.__name__ + "(mean={0}, std={1})".format(self.mean, self.std)
class Resize(torch.nn.Module):
......@@ -301,17 +335,20 @@ class Resize(torch.nn.Module):
def __repr__(self):
interpolate_str = self.interpolation.value
return self.__class__.__name__ + '(size={0}, interpolation={1}, max_size={2}, antialias={3})'.format(
self.size, interpolate_str, self.max_size, self.antialias)
return self.__class__.__name__ + "(size={0}, interpolation={1}, max_size={2}, antialias={3})".format(
self.size, interpolate_str, self.max_size, self.antialias
)
class Scale(Resize):
"""
Note: This transform is deprecated in favor of Resize.
"""
def __init__(self, *args, **kwargs):
warnings.warn("The use of the transforms.Scale transform is deprecated, " +
"please use transforms.Resize instead.")
warnings.warn(
"The use of the transforms.Scale transform is deprecated, " + "please use transforms.Resize instead."
)
super(Scale, self).__init__(*args, **kwargs)
......@@ -342,7 +379,7 @@ class CenterCrop(torch.nn.Module):
return F.center_crop(img, self.size)
def __repr__(self):
return self.__class__.__name__ + '(size={0})'.format(self.size)
return self.__class__.__name__ + "(size={0})".format(self.size)
class Pad(torch.nn.Module):
......@@ -395,8 +432,9 @@ class Pad(torch.nn.Module):
raise ValueError("Padding mode should be either constant, edge, reflect or symmetric")
if isinstance(padding, Sequence) and len(padding) not in [1, 2, 4]:
raise ValueError("Padding must be an int or a 1, 2, or 4 element tuple, not a " +
"{} element tuple".format(len(padding)))
raise ValueError(
"Padding must be an int or a 1, 2, or 4 element tuple, not a " + "{} element tuple".format(len(padding))
)
self.padding = padding
self.fill = fill
......@@ -413,8 +451,9 @@ class Pad(torch.nn.Module):
return F.pad(img, self.padding, self.fill, self.padding_mode)
def __repr__(self):
return self.__class__.__name__ + '(padding={0}, fill={1}, padding_mode={2})'.\
format(self.padding, self.fill, self.padding_mode)
return self.__class__.__name__ + "(padding={0}, fill={1}, padding_mode={2})".format(
self.padding, self.fill, self.padding_mode
)
class Lambda:
......@@ -433,7 +472,7 @@ class Lambda:
return self.lambd(img)
def __repr__(self):
return self.__class__.__name__ + '()'
return self.__class__.__name__ + "()"
class RandomTransforms:
......@@ -452,11 +491,11 @@ class RandomTransforms:
raise NotImplementedError()
def __repr__(self):
format_string = self.__class__.__name__ + '('
format_string = self.__class__.__name__ + "("
for t in self.transforms:
format_string += '\n'
format_string += ' {0}'.format(t)
format_string += '\n)'
format_string += "\n"
format_string += " {0}".format(t)
format_string += "\n)"
return format_string
......@@ -493,18 +532,18 @@ class RandomApply(torch.nn.Module):
return img
def __repr__(self):
format_string = self.__class__.__name__ + '('
format_string += '\n p={}'.format(self.p)
format_string = self.__class__.__name__ + "("
format_string += "\n p={}".format(self.p)
for t in self.transforms:
format_string += '\n'
format_string += ' {0}'.format(t)
format_string += '\n)'
format_string += "\n"
format_string += " {0}".format(t)
format_string += "\n)"
return format_string
class RandomOrder(RandomTransforms):
"""Apply a list of transformations in a random order. This transform does not support torchscript.
"""
"""Apply a list of transformations in a random order. This transform does not support torchscript."""
def __call__(self, img):
order = list(range(len(self.transforms)))
random.shuffle(order)
......@@ -514,8 +553,8 @@ class RandomOrder(RandomTransforms):
class RandomChoice(RandomTransforms):
"""Apply single transformation randomly picked from a list. This transform does not support torchscript.
"""
"""Apply single transformation randomly picked from a list. This transform does not support torchscript."""
def __init__(self, transforms, p=None):
super().__init__(transforms)
if p is not None and not isinstance(p, Sequence):
......@@ -528,7 +567,7 @@ class RandomChoice(RandomTransforms):
def __repr__(self):
format_string = super().__repr__()
format_string += '(p={0})'.format(self.p)
format_string += "(p={0})".format(self.p)
return format_string
......@@ -591,23 +630,19 @@ class RandomCrop(torch.nn.Module):
th, tw = output_size
if h + 1 < th or w + 1 < tw:
raise ValueError(
"Required crop size {} is larger then input image size {}".format((th, tw), (h, w))
)
raise ValueError("Required crop size {} is larger then input image size {}".format((th, tw), (h, w)))
if w == tw and h == th:
return 0, 0, h, w
i = torch.randint(0, h - th + 1, size=(1, )).item()
j = torch.randint(0, w - tw + 1, size=(1, )).item()
i = torch.randint(0, h - th + 1, size=(1,)).item()
j = torch.randint(0, w - tw + 1, size=(1,)).item()
return i, j, th, tw
def __init__(self, size, padding=None, pad_if_needed=False, fill=0, padding_mode="constant"):
super().__init__()
self.size = tuple(_setup_size(
size, error_msg="Please provide only two dimensions (h, w) for size."
))
self.size = tuple(_setup_size(size, error_msg="Please provide only two dimensions (h, w) for size."))
self.padding = padding
self.pad_if_needed = pad_if_needed
......@@ -670,7 +705,7 @@ class RandomHorizontalFlip(torch.nn.Module):
return img
def __repr__(self):
return self.__class__.__name__ + '(p={})'.format(self.p)
return self.__class__.__name__ + "(p={})".format(self.p)
class RandomVerticalFlip(torch.nn.Module):
......@@ -700,7 +735,7 @@ class RandomVerticalFlip(torch.nn.Module):
return img
def __repr__(self):
return self.__class__.__name__ + '(p={})'.format(self.p)
return self.__class__.__name__ + "(p={})".format(self.p)
class RandomPerspective(torch.nn.Module):
......@@ -780,27 +815,27 @@ class RandomPerspective(torch.nn.Module):
half_height = height // 2
half_width = width // 2
topleft = [
int(torch.randint(0, int(distortion_scale * half_width) + 1, size=(1, )).item()),
int(torch.randint(0, int(distortion_scale * half_height) + 1, size=(1, )).item())
int(torch.randint(0, int(distortion_scale * half_width) + 1, size=(1,)).item()),
int(torch.randint(0, int(distortion_scale * half_height) + 1, size=(1,)).item()),
]
topright = [
int(torch.randint(width - int(distortion_scale * half_width) - 1, width, size=(1, )).item()),
int(torch.randint(0, int(distortion_scale * half_height) + 1, size=(1, )).item())
int(torch.randint(width - int(distortion_scale * half_width) - 1, width, size=(1,)).item()),
int(torch.randint(0, int(distortion_scale * half_height) + 1, size=(1,)).item()),
]
botright = [
int(torch.randint(width - int(distortion_scale * half_width) - 1, width, size=(1, )).item()),
int(torch.randint(height - int(distortion_scale * half_height) - 1, height, size=(1, )).item())
int(torch.randint(width - int(distortion_scale * half_width) - 1, width, size=(1,)).item()),
int(torch.randint(height - int(distortion_scale * half_height) - 1, height, size=(1,)).item()),
]
botleft = [
int(torch.randint(0, int(distortion_scale * half_width) + 1, size=(1, )).item()),
int(torch.randint(height - int(distortion_scale * half_height) - 1, height, size=(1, )).item())
int(torch.randint(0, int(distortion_scale * half_width) + 1, size=(1,)).item()),
int(torch.randint(height - int(distortion_scale * half_height) - 1, height, size=(1,)).item()),
]
startpoints = [[0, 0], [width - 1, 0], [width - 1, height - 1], [0, height - 1]]
endpoints = [topleft, topright, botright, botleft]
return startpoints, endpoints
def __repr__(self):
return self.__class__.__name__ + '(p={})'.format(self.p)
return self.__class__.__name__ + "(p={})".format(self.p)
class RandomResizedCrop(torch.nn.Module):
......@@ -832,7 +867,7 @@ class RandomResizedCrop(torch.nn.Module):
"""
def __init__(self, size, scale=(0.08, 1.0), ratio=(3. / 4., 4. / 3.), interpolation=InterpolationMode.BILINEAR):
def __init__(self, size, scale=(0.08, 1.0), ratio=(3.0 / 4.0, 4.0 / 3.0), interpolation=InterpolationMode.BILINEAR):
super().__init__()
self.size = _setup_size(size, error_msg="Please provide only two dimensions (h, w) for size.")
......@@ -856,9 +891,7 @@ class RandomResizedCrop(torch.nn.Module):
self.ratio = ratio
@staticmethod
def get_params(
img: Tensor, scale: List[float], ratio: List[float]
) -> Tuple[int, int, int, int]:
def get_params(img: Tensor, scale: List[float], ratio: List[float]) -> Tuple[int, int, int, int]:
"""Get parameters for ``crop`` for a random sized crop.
Args:
......@@ -876,9 +909,7 @@ class RandomResizedCrop(torch.nn.Module):
log_ratio = torch.log(torch.tensor(ratio))
for _ in range(10):
target_area = area * torch.empty(1).uniform_(scale[0], scale[1]).item()
aspect_ratio = torch.exp(
torch.empty(1).uniform_(log_ratio[0], log_ratio[1])
).item()
aspect_ratio = torch.exp(torch.empty(1).uniform_(log_ratio[0], log_ratio[1])).item()
w = int(round(math.sqrt(target_area * aspect_ratio)))
h = int(round(math.sqrt(target_area / aspect_ratio)))
......@@ -916,10 +947,10 @@ class RandomResizedCrop(torch.nn.Module):
def __repr__(self):
interpolate_str = self.interpolation.value
format_string = self.__class__.__name__ + '(size={0}'.format(self.size)
format_string += ', scale={0}'.format(tuple(round(s, 4) for s in self.scale))
format_string += ', ratio={0}'.format(tuple(round(r, 4) for r in self.ratio))
format_string += ', interpolation={0})'.format(interpolate_str)
format_string = self.__class__.__name__ + "(size={0}".format(self.size)
format_string += ", scale={0}".format(tuple(round(s, 4) for s in self.scale))
format_string += ", ratio={0}".format(tuple(round(r, 4) for r in self.ratio))
format_string += ", interpolation={0})".format(interpolate_str)
return format_string
......@@ -927,9 +958,12 @@ class RandomSizedCrop(RandomResizedCrop):
"""
Note: This transform is deprecated in favor of RandomResizedCrop.
"""
def __init__(self, *args, **kwargs):
warnings.warn("The use of the transforms.RandomSizedCrop transform is deprecated, " +
"please use transforms.RandomResizedCrop instead.")
warnings.warn(
"The use of the transforms.RandomSizedCrop transform is deprecated, "
+ "please use transforms.RandomResizedCrop instead."
)
super(RandomSizedCrop, self).__init__(*args, **kwargs)
......@@ -976,7 +1010,7 @@ class FiveCrop(torch.nn.Module):
return F.five_crop(img, self.size)
def __repr__(self):
return self.__class__.__name__ + '(size={0})'.format(self.size)
return self.__class__.__name__ + "(size={0})".format(self.size)
class TenCrop(torch.nn.Module):
......@@ -1025,7 +1059,7 @@ class TenCrop(torch.nn.Module):
return F.ten_crop(img, self.size, self.vertical_flip)
def __repr__(self):
return self.__class__.__name__ + '(size={0}, vertical_flip={1})'.format(self.size, self.vertical_flip)
return self.__class__.__name__ + "(size={0}, vertical_flip={1})".format(self.size, self.vertical_flip)
class LinearTransformation(torch.nn.Module):
......@@ -1050,17 +1084,25 @@ class LinearTransformation(torch.nn.Module):
def __init__(self, transformation_matrix, mean_vector):
super().__init__()
if transformation_matrix.size(0) != transformation_matrix.size(1):
raise ValueError("transformation_matrix should be square. Got " +
"[{} x {}] rectangular matrix.".format(*transformation_matrix.size()))
raise ValueError(
"transformation_matrix should be square. Got "
+ "[{} x {}] rectangular matrix.".format(*transformation_matrix.size())
)
if mean_vector.size(0) != transformation_matrix.size(0):
raise ValueError("mean_vector should have the same length {}".format(mean_vector.size(0)) +
" as any one of the dimensions of the transformation_matrix [{}]"
.format(tuple(transformation_matrix.size())))
raise ValueError(
"mean_vector should have the same length {}".format(mean_vector.size(0))
+ " as any one of the dimensions of the transformation_matrix [{}]".format(
tuple(transformation_matrix.size())
)
)
if transformation_matrix.device != mean_vector.device:
raise ValueError("Input tensors should be on the same device. Got {} and {}"
.format(transformation_matrix.device, mean_vector.device))
raise ValueError(
"Input tensors should be on the same device. Got {} and {}".format(
transformation_matrix.device, mean_vector.device
)
)
self.transformation_matrix = transformation_matrix
self.mean_vector = mean_vector
......@@ -1076,13 +1118,17 @@ class LinearTransformation(torch.nn.Module):
shape = tensor.shape
n = shape[-3] * shape[-2] * shape[-1]
if n != self.transformation_matrix.shape[0]:
raise ValueError("Input tensor and transformation matrix have incompatible shape." +
"[{} x {} x {}] != ".format(shape[-3], shape[-2], shape[-1]) +
"{}".format(self.transformation_matrix.shape[0]))
raise ValueError(
"Input tensor and transformation matrix have incompatible shape."
+ "[{} x {} x {}] != ".format(shape[-3], shape[-2], shape[-1])
+ "{}".format(self.transformation_matrix.shape[0])
)
if tensor.device.type != self.mean_vector.device.type:
raise ValueError("Input tensor should be on the same device as transformation matrix and mean vector. "
"Got {} vs {}".format(tensor.device, self.mean_vector.device))
raise ValueError(
"Input tensor should be on the same device as transformation matrix and mean vector. "
"Got {} vs {}".format(tensor.device, self.mean_vector.device)
)
flat_tensor = tensor.view(-1, n) - self.mean_vector
transformed_tensor = torch.mm(flat_tensor, self.transformation_matrix)
......@@ -1090,9 +1136,9 @@ class LinearTransformation(torch.nn.Module):
return tensor
def __repr__(self):
format_string = self.__class__.__name__ + '(transformation_matrix='
format_string += (str(self.transformation_matrix.tolist()) + ')')
format_string += (", (mean_vector=" + str(self.mean_vector.tolist()) + ')')
format_string = self.__class__.__name__ + "(transformation_matrix="
format_string += str(self.transformation_matrix.tolist()) + ")"
format_string += ", (mean_vector=" + str(self.mean_vector.tolist()) + ")"
return format_string
......@@ -1119,14 +1165,13 @@ class ColorJitter(torch.nn.Module):
def __init__(self, brightness=0, contrast=0, saturation=0, hue=0):
super().__init__()
self.brightness = self._check_input(brightness, 'brightness')
self.contrast = self._check_input(contrast, 'contrast')
self.saturation = self._check_input(saturation, 'saturation')
self.hue = self._check_input(hue, 'hue', center=0, bound=(-0.5, 0.5),
clip_first_on_zero=False)
self.brightness = self._check_input(brightness, "brightness")
self.contrast = self._check_input(contrast, "contrast")
self.saturation = self._check_input(saturation, "saturation")
self.hue = self._check_input(hue, "hue", center=0, bound=(-0.5, 0.5), clip_first_on_zero=False)
@torch.jit.unused
def _check_input(self, value, name, center=1, bound=(0, float('inf')), clip_first_on_zero=True):
def _check_input(self, value, name, center=1, bound=(0, float("inf")), clip_first_on_zero=True):
if isinstance(value, numbers.Number):
if value < 0:
raise ValueError("If {} is a single number, it must be non negative.".format(name))
......@@ -1146,11 +1191,12 @@ class ColorJitter(torch.nn.Module):
return value
@staticmethod
def get_params(brightness: Optional[List[float]],
contrast: Optional[List[float]],
saturation: Optional[List[float]],
hue: Optional[List[float]]
) -> Tuple[Tensor, Optional[float], Optional[float], Optional[float], Optional[float]]:
def get_params(
brightness: Optional[List[float]],
contrast: Optional[List[float]],
saturation: Optional[List[float]],
hue: Optional[List[float]],
) -> Tuple[Tensor, Optional[float], Optional[float], Optional[float], Optional[float]]:
"""Get the parameters for the randomized transform to be applied on image.
Args:
......@@ -1184,8 +1230,9 @@ class ColorJitter(torch.nn.Module):
Returns:
PIL Image or Tensor: Color jittered image.
"""
fn_idx, brightness_factor, contrast_factor, saturation_factor, hue_factor = \
self.get_params(self.brightness, self.contrast, self.saturation, self.hue)
fn_idx, brightness_factor, contrast_factor, saturation_factor, hue_factor = self.get_params(
self.brightness, self.contrast, self.saturation, self.hue
)
for fn_id in fn_idx:
if fn_id == 0 and brightness_factor is not None:
......@@ -1200,11 +1247,11 @@ class ColorJitter(torch.nn.Module):
return img
def __repr__(self):
format_string = self.__class__.__name__ + '('
format_string += 'brightness={0}'.format(self.brightness)
format_string += ', contrast={0}'.format(self.contrast)
format_string += ', saturation={0}'.format(self.saturation)
format_string += ', hue={0})'.format(self.hue)
format_string = self.__class__.__name__ + "("
format_string += "brightness={0}".format(self.brightness)
format_string += ", contrast={0}".format(self.contrast)
format_string += ", saturation={0}".format(self.saturation)
format_string += ", hue={0})".format(self.hue)
return format_string
......@@ -1254,10 +1301,10 @@ class RandomRotation(torch.nn.Module):
)
interpolation = _interpolation_modes_from_int(interpolation)
self.degrees = _setup_angle(degrees, name="degrees", req_sizes=(2, ))
self.degrees = _setup_angle(degrees, name="degrees", req_sizes=(2,))
if center is not None:
_check_sequence_input(center, "center", req_sizes=(2, ))
_check_sequence_input(center, "center", req_sizes=(2,))
self.center = center
......@@ -1301,14 +1348,14 @@ class RandomRotation(torch.nn.Module):
def __repr__(self):
interpolate_str = self.interpolation.value
format_string = self.__class__.__name__ + '(degrees={0}'.format(self.degrees)
format_string += ', interpolation={0}'.format(interpolate_str)
format_string += ', expand={0}'.format(self.expand)
format_string = self.__class__.__name__ + "(degrees={0}".format(self.degrees)
format_string += ", interpolation={0}".format(interpolate_str)
format_string += ", expand={0}".format(self.expand)
if self.center is not None:
format_string += ', center={0}'.format(self.center)
format_string += ", center={0}".format(self.center)
if self.fill is not None:
format_string += ', fill={0}'.format(self.fill)
format_string += ')'
format_string += ", fill={0}".format(self.fill)
format_string += ")"
return format_string
......@@ -1349,8 +1396,15 @@ class RandomAffine(torch.nn.Module):
"""
def __init__(
self, degrees, translate=None, scale=None, shear=None, interpolation=InterpolationMode.NEAREST, fill=0,
fillcolor=None, resample=None
self,
degrees,
translate=None,
scale=None,
shear=None,
interpolation=InterpolationMode.NEAREST,
fill=0,
fillcolor=None,
resample=None,
):
super().__init__()
if resample is not None:
......@@ -1373,17 +1427,17 @@ class RandomAffine(torch.nn.Module):
)
fill = fillcolor
self.degrees = _setup_angle(degrees, name="degrees", req_sizes=(2, ))
self.degrees = _setup_angle(degrees, name="degrees", req_sizes=(2,))
if translate is not None:
_check_sequence_input(translate, "translate", req_sizes=(2, ))
_check_sequence_input(translate, "translate", req_sizes=(2,))
for t in translate:
if not (0.0 <= t <= 1.0):
raise ValueError("translation values should be between 0 and 1")
self.translate = translate
if scale is not None:
_check_sequence_input(scale, "scale", req_sizes=(2, ))
_check_sequence_input(scale, "scale", req_sizes=(2,))
for s in scale:
if s <= 0:
raise ValueError("scale values should be positive")
......@@ -1405,11 +1459,11 @@ class RandomAffine(torch.nn.Module):
@staticmethod
def get_params(
degrees: List[float],
translate: Optional[List[float]],
scale_ranges: Optional[List[float]],
shears: Optional[List[float]],
img_size: List[int]
degrees: List[float],
translate: Optional[List[float]],
scale_ranges: Optional[List[float]],
shears: Optional[List[float]],
img_size: List[int],
) -> Tuple[float, Tuple[int, int], float, Tuple[float, float]]:
"""Get parameters for affine transformation
......@@ -1462,20 +1516,20 @@ class RandomAffine(torch.nn.Module):
return F.affine(img, *ret, interpolation=self.interpolation, fill=fill)
def __repr__(self):
s = '{name}(degrees={degrees}'
s = "{name}(degrees={degrees}"
if self.translate is not None:
s += ', translate={translate}'
s += ", translate={translate}"
if self.scale is not None:
s += ', scale={scale}'
s += ", scale={scale}"
if self.shear is not None:
s += ', shear={shear}'
s += ", shear={shear}"
if self.interpolation != InterpolationMode.NEAREST:
s += ', interpolation={interpolation}'
s += ", interpolation={interpolation}"
if self.fill != 0:
s += ', fill={fill}'
s += ')'
s += ", fill={fill}"
s += ")"
d = dict(self.__dict__)
d['interpolation'] = self.interpolation.value
d["interpolation"] = self.interpolation.value
return s.format(name=self.__class__.__name__, **d)
......@@ -1510,7 +1564,7 @@ class Grayscale(torch.nn.Module):
return F.rgb_to_grayscale(img, num_output_channels=self.num_output_channels)
def __repr__(self):
return self.__class__.__name__ + '(num_output_channels={0})'.format(self.num_output_channels)
return self.__class__.__name__ + "(num_output_channels={0})".format(self.num_output_channels)
class RandomGrayscale(torch.nn.Module):
......@@ -1547,11 +1601,11 @@ class RandomGrayscale(torch.nn.Module):
return img
def __repr__(self):
return self.__class__.__name__ + '(p={0})'.format(self.p)
return self.__class__.__name__ + "(p={0})".format(self.p)
class RandomErasing(torch.nn.Module):
""" Randomly selects a rectangle region in an torch Tensor image and erases its pixels.
"""Randomly selects a rectangle region in an torch Tensor image and erases its pixels.
This transform does not support PIL Image.
'Random Erasing Data Augmentation' by Zhong et al. See https://arxiv.org/abs/1708.04896
......@@ -1603,7 +1657,7 @@ class RandomErasing(torch.nn.Module):
@staticmethod
def get_params(
img: Tensor, scale: Tuple[float, float], ratio: Tuple[float, float], value: Optional[List[float]] = None
img: Tensor, scale: Tuple[float, float], ratio: Tuple[float, float], value: Optional[List[float]] = None
) -> Tuple[int, int, int, int, Tensor]:
"""Get parameters for ``erase`` for a random erasing.
......@@ -1624,9 +1678,7 @@ class RandomErasing(torch.nn.Module):
log_ratio = torch.log(torch.tensor(ratio))
for _ in range(10):
erase_area = area * torch.empty(1).uniform_(scale[0], scale[1]).item()
aspect_ratio = torch.exp(
torch.empty(1).uniform_(log_ratio[0], log_ratio[1])
).item()
aspect_ratio = torch.exp(torch.empty(1).uniform_(log_ratio[0], log_ratio[1])).item()
h = int(round(math.sqrt(erase_area * aspect_ratio)))
w = int(round(math.sqrt(erase_area / aspect_ratio)))
......@@ -1638,8 +1690,8 @@ class RandomErasing(torch.nn.Module):
else:
v = torch.tensor(value)[:, None, None]
i = torch.randint(0, img_h - h + 1, size=(1, )).item()
j = torch.randint(0, img_w - w + 1, size=(1, )).item()
i = torch.randint(0, img_h - h + 1, size=(1,)).item()
j = torch.randint(0, img_w - w + 1, size=(1,)).item()
return i, j, h, w, v
# Return original image
......@@ -1657,7 +1709,9 @@ class RandomErasing(torch.nn.Module):
# cast self.value to script acceptable type
if isinstance(self.value, (int, float)):
value = [self.value, ]
value = [
self.value,
]
elif isinstance(self.value, str):
value = None
elif isinstance(self.value, tuple):
......@@ -1676,11 +1730,11 @@ class RandomErasing(torch.nn.Module):
return img
def __repr__(self):
s = '(p={}, '.format(self.p)
s += 'scale={}, '.format(self.scale)
s += 'ratio={}, '.format(self.ratio)
s += 'value={}, '.format(self.value)
s += 'inplace={})'.format(self.inplace)
s = "(p={}, ".format(self.p)
s += "scale={}, ".format(self.scale)
s += "ratio={}, ".format(self.ratio)
s += "value={}, ".format(self.value)
s += "inplace={})".format(self.inplace)
return self.__class__.__name__ + s
......@@ -1713,7 +1767,7 @@ class GaussianBlur(torch.nn.Module):
raise ValueError("If sigma is a single number, it must be positive.")
sigma = (sigma, sigma)
elif isinstance(sigma, Sequence) and len(sigma) == 2:
if not 0. < sigma[0] <= sigma[1]:
if not 0.0 < sigma[0] <= sigma[1]:
raise ValueError("sigma values should be positive and of the form (min, max).")
else:
raise ValueError("sigma should be a single number or a list/tuple with length 2.")
......@@ -1745,8 +1799,8 @@ class GaussianBlur(torch.nn.Module):
return F.gaussian_blur(img, self.kernel_size, [sigma, sigma])
def __repr__(self):
s = '(kernel_size={}, '.format(self.kernel_size)
s += 'sigma={})'.format(self.sigma)
s = "(kernel_size={}, ".format(self.kernel_size)
s += "sigma={})".format(self.sigma)
return self.__class__.__name__ + s
......@@ -1771,7 +1825,7 @@ def _check_sequence_input(x, name, req_sizes):
raise ValueError("{} should be sequence of length {}.".format(name, msg))
def _setup_angle(x, name, req_sizes=(2, )):
def _setup_angle(x, name, req_sizes=(2,)):
if isinstance(x, numbers.Number):
if x < 0:
raise ValueError("If {} is a single number, it must be positive.".format(name))
......@@ -1809,7 +1863,7 @@ class RandomInvert(torch.nn.Module):
return img
def __repr__(self):
return self.__class__.__name__ + '(p={})'.format(self.p)
return self.__class__.__name__ + "(p={})".format(self.p)
class RandomPosterize(torch.nn.Module):
......@@ -1841,7 +1895,7 @@ class RandomPosterize(torch.nn.Module):
return img
def __repr__(self):
return self.__class__.__name__ + '(bits={},p={})'.format(self.bits, self.p)
return self.__class__.__name__ + "(bits={},p={})".format(self.bits, self.p)
class RandomSolarize(torch.nn.Module):
......@@ -1873,7 +1927,7 @@ class RandomSolarize(torch.nn.Module):
return img
def __repr__(self):
return self.__class__.__name__ + '(threshold={},p={})'.format(self.threshold, self.p)
return self.__class__.__name__ + "(threshold={},p={})".format(self.threshold, self.p)
class RandomAdjustSharpness(torch.nn.Module):
......@@ -1905,7 +1959,7 @@ class RandomAdjustSharpness(torch.nn.Module):
return img
def __repr__(self):
return self.__class__.__name__ + '(sharpness_factor={},p={})'.format(self.sharpness_factor, self.p)
return self.__class__.__name__ + "(sharpness_factor={},p={})".format(self.sharpness_factor, self.p)
class RandomAutocontrast(torch.nn.Module):
......@@ -1935,7 +1989,7 @@ class RandomAutocontrast(torch.nn.Module):
return img
def __repr__(self):
return self.__class__.__name__ + '(p={})'.format(self.p)
return self.__class__.__name__ + "(p={})".format(self.p)
class RandomEqualize(torch.nn.Module):
......@@ -1965,4 +2019,4 @@ class RandomEqualize(torch.nn.Module):
return img
def __repr__(self):
return self.__class__.__name__ + '(p={})'.format(self.p)
return self.__class__.__name__ + "(p={})".format(self.p)
from typing import Union, Optional, List, Tuple, Text, BinaryIO
import pathlib
import torch
import math
import pathlib
import warnings
from typing import Union, Optional, List, Tuple, Text, BinaryIO
import numpy as np
import torch
from PIL import Image, ImageDraw, ImageFont, ImageColor
__all__ = ["make_grid", "save_image", "draw_bounding_boxes", "draw_segmentation_masks"]
......@@ -18,7 +19,7 @@ def make_grid(
value_range: Optional[Tuple[int, int]] = None,
scale_each: bool = False,
pad_value: int = 0,
**kwargs
**kwargs,
) -> torch.Tensor:
"""
Make a grid of images.
......@@ -41,9 +42,8 @@ def make_grid(
Returns:
grid (Tensor): the tensor containing grid of images.
"""
if not (torch.is_tensor(tensor) or
(isinstance(tensor, list) and all(torch.is_tensor(t) for t in tensor))):
raise TypeError(f'tensor or list of tensors expected, got {type(tensor)}')
if not (torch.is_tensor(tensor) or (isinstance(tensor, list) and all(torch.is_tensor(t) for t in tensor))):
raise TypeError(f"tensor or list of tensors expected, got {type(tensor)}")
if "range" in kwargs.keys():
warning = "range will be deprecated, please use value_range instead."
......@@ -67,8 +67,9 @@ def make_grid(
if normalize is True:
tensor = tensor.clone() # avoid modifying tensor in-place
if value_range is not None:
assert isinstance(value_range, tuple), \
"value_range has to be a tuple (min, max) if specified. min and max are numbers"
assert isinstance(
value_range, tuple
), "value_range has to be a tuple (min, max) if specified. min and max are numbers"
def norm_ip(img, low, high):
img.clamp_(min=low, max=high)
......@@ -115,7 +116,7 @@ def save_image(
tensor: Union[torch.Tensor, List[torch.Tensor]],
fp: Union[Text, pathlib.Path, BinaryIO],
format: Optional[str] = None,
**kwargs
**kwargs,
) -> None:
"""
Save a given Tensor into an image file.
......@@ -131,7 +132,7 @@ def save_image(
grid = make_grid(tensor, **kwargs)
# Add 0.5 after unnormalizing to [0, 255] to round to nearest integer
ndarr = grid.mul(255).add_(0.5).clamp_(0, 255).permute(1, 2, 0).to('cpu', torch.uint8).numpy()
ndarr = grid.mul(255).add_(0.5).clamp_(0, 255).permute(1, 2, 0).to("cpu", torch.uint8).numpy()
im = Image.fromarray(ndarr)
im.save(fp, format=format)
......@@ -145,7 +146,7 @@ def draw_bounding_boxes(
fill: Optional[bool] = False,
width: int = 1,
font: Optional[str] = None,
font_size: int = 10
font_size: int = 10,
) -> torch.Tensor:
"""
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment