Unverified commit 5f0edb97, authored by Philip Meier and committed by GitHub

Add ufmt (usort + black) as code formatter (#4384)



* add ufmt as code formatter

* cleanup

* quote ufmt requirement

* split imports into more groups

* regenerate circleci config

* fix CI

* clarify local testing utils section

* use ufmt pre-commit hook

* split relative imports into local category

* Revert "split relative imports into local category"

This reverts commit f2e224cde2008c56c9347c1f69746d39065cdd51.

* pin black and usort dependencies

* fix local test utils detection

* fix ufmt rev

* add reference utils to local category

* fix usort config

* remove custom categories sorting

* Run pre-commit without fixing flake8

* fix a double import introduced by the merge
Co-authored-by: Nicolas Hug <nicolashug@fb.com>
parent e45489b1
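The diff below is mechanical: usort regroups imports into standard-library, third-party, and first-party/local blocks, while black normalizes string quotes, wraps over-long argument lists against the configured line length, and adds trailing commas. As a rough illustration of the effect, here is a made-up snippet (not taken from the changed files, and assuming a 120-character line length, which matches the wrapped lines in this diff):

Before ufmt:

    import torch
    from typing import Optional

    def make_conv(in_channels: int, out_channels: int, kernel_size: int = 3, stride: int = 1, padding: Optional[int] = None, dilation: int = 1, groups: int = 1, bias: bool = False) -> torch.nn.Conv2d:
        if padding is None:
            padding = (kernel_size - 1) // 2 * dilation
        return torch.nn.Conv2d(in_channels, out_channels, kernel_size, stride, padding, dilation=dilation, groups=groups, bias=bias)

After ufmt (usort puts the standard-library `typing` import before the third-party `torch` import with a blank line between the groups; black explodes the over-long signature one parameter per line with a trailing comma and folds the over-long call onto an indented continuation line):

    from typing import Optional

    import torch


    def make_conv(
        in_channels: int,
        out_channels: int,
        kernel_size: int = 3,
        stride: int = 1,
        padding: Optional[int] = None,
        dilation: int = 1,
        groups: int = 1,
        bias: bool = False,
    ) -> torch.nn.Conv2d:
        # Same default-padding behavior as before; only the formatting changes.
        if padding is None:
            padding = (kernel_size - 1) // 2 * dilation
        return torch.nn.Conv2d(
            in_channels, out_channels, kernel_size, stride, padding, dilation=dilation, groups=groups, bias=bias
        )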
@@ -9,9 +9,10 @@ is implemented
"""
import warnings
from typing import Callable, List, Optional
import torch
from torch import Tensor
from typing import Callable, List, Optional
class Conv2d(torch.nn.Conv2d):
@@ -19,7 +20,9 @@ class Conv2d(torch.nn.Conv2d):
super().__init__(*args, **kwargs)
warnings.warn(
"torchvision.ops.misc.Conv2d is deprecated and will be "
"removed in future versions, use torch.nn.Conv2d instead.", FutureWarning)
"removed in future versions, use torch.nn.Conv2d instead.",
FutureWarning,
)
class ConvTranspose2d(torch.nn.ConvTranspose2d):
@@ -27,7 +30,9 @@ class ConvTranspose2d(torch.nn.ConvTranspose2d):
super().__init__(*args, **kwargs)
warnings.warn(
"torchvision.ops.misc.ConvTranspose2d is deprecated and will be "
"removed in future versions, use torch.nn.ConvTranspose2d instead.", FutureWarning)
"removed in future versions, use torch.nn.ConvTranspose2d instead.",
FutureWarning,
)
class BatchNorm2d(torch.nn.BatchNorm2d):
@@ -35,7 +40,9 @@ class BatchNorm2d(torch.nn.BatchNorm2d):
super().__init__(*args, **kwargs)
warnings.warn(
"torchvision.ops.misc.BatchNorm2d is deprecated and will be "
"removed in future versions, use torch.nn.BatchNorm2d instead.", FutureWarning)
"removed in future versions, use torch.nn.BatchNorm2d instead.",
FutureWarning,
)
interpolate = torch.nn.functional.interpolate
@@ -56,8 +63,7 @@ class FrozenBatchNorm2d(torch.nn.Module):
):
# n=None for backward-compatibility
if n is not None:
warnings.warn("`n` argument is deprecated and has been renamed `num_features`",
DeprecationWarning)
warnings.warn("`n` argument is deprecated and has been renamed `num_features`", DeprecationWarning)
num_features = n
super(FrozenBatchNorm2d, self).__init__()
self.eps = eps
@@ -76,13 +82,13 @@ class FrozenBatchNorm2d(torch.nn.Module):
unexpected_keys: List[str],
error_msgs: List[str],
):
num_batches_tracked_key = prefix + 'num_batches_tracked'
num_batches_tracked_key = prefix + "num_batches_tracked"
if num_batches_tracked_key in state_dict:
del state_dict[num_batches_tracked_key]
super(FrozenBatchNorm2d, self)._load_from_state_dict(
state_dict, prefix, local_metadata, strict,
missing_keys, unexpected_keys, error_msgs)
state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs
)
def forward(self, x: Tensor) -> Tensor:
# move reshapes to the beginning
@@ -115,8 +121,18 @@ class ConvNormActivation(torch.nn.Sequential):
) -> None:
if padding is None:
padding = (kernel_size - 1) // 2 * dilation
layers = [torch.nn.Conv2d(in_channels, out_channels, kernel_size, stride, padding,
dilation=dilation, groups=groups, bias=norm_layer is None)]
layers = [
torch.nn.Conv2d(
in_channels,
out_channels,
kernel_size,
stride,
padding,
dilation=dilation,
groups=groups,
bias=norm_layer is None,
)
]
if norm_layer is not None:
layers.append(norm_layer(out_channels))
if activation_layer is not None:
......
import torch
from torch import nn, Tensor
from typing import Optional, List, Dict, Tuple, Union
import torch
import torchvision
from .roi_align import roi_align
from torch import nn, Tensor
from torchvision.ops.boxes import box_area
from typing import Optional, List, Dict, Tuple, Union
from .roi_align import roi_align
# copying result_idx_in_level to a specific index in result[]
@@ -16,15 +16,17 @@ from typing import Optional, List, Dict, Tuple, Union
def _onnx_merge_levels(levels: Tensor, unmerged_results: List[Tensor]) -> Tensor:
first_result = unmerged_results[0]
dtype, device = first_result.dtype, first_result.device
res = torch.zeros((levels.size(0), first_result.size(1),
first_result.size(2), first_result.size(3)),
dtype=dtype, device=device)
res = torch.zeros(
(levels.size(0), first_result.size(1), first_result.size(2), first_result.size(3)), dtype=dtype, device=device
)
for level in range(len(unmerged_results)):
index = torch.where(levels == level)[0].view(-1, 1, 1, 1)
index = index.expand(index.size(0),
unmerged_results[level].size(1),
unmerged_results[level].size(2),
unmerged_results[level].size(3))
index = index.expand(
index.size(0),
unmerged_results[level].size(1),
unmerged_results[level].size(2),
unmerged_results[level].size(3),
)
res = res.scatter(0, index, unmerged_results[level])
return res
@@ -116,10 +118,7 @@ class MultiScaleRoIAlign(nn.Module):
"""
__annotations__ = {
'scales': Optional[List[float]],
'map_levels': Optional[LevelMapper]
}
__annotations__ = {"scales": Optional[List[float]], "map_levels": Optional[LevelMapper]}
def __init__(
self,
@@ -224,10 +223,11 @@ class MultiScaleRoIAlign(nn.Module):
if num_levels == 1:
return roi_align(
x_filtered[0], rois,
x_filtered[0],
rois,
output_size=self.output_size,
spatial_scale=scales[0],
sampling_ratio=self.sampling_ratio
sampling_ratio=self.sampling_ratio,
)
mapper = self.map_levels
@@ -240,7 +240,11 @@ class MultiScaleRoIAlign(nn.Module):
dtype, device = x_filtered[0].dtype, x_filtered[0].device
result = torch.zeros(
(num_rois, num_channels,) + self.output_size,
(
num_rois,
num_channels,
)
+ self.output_size,
dtype=dtype,
device=device,
)
@@ -251,9 +255,12 @@ class MultiScaleRoIAlign(nn.Module):
rois_per_level = rois[idx_in_level]
result_idx_in_level = roi_align(
per_level_feature, rois_per_level,
per_level_feature,
rois_per_level,
output_size=self.output_size,
spatial_scale=scale, sampling_ratio=self.sampling_ratio)
spatial_scale=scale,
sampling_ratio=self.sampling_ratio,
)
if torchvision._is_tracing():
tracing_results.append(result_idx_in_level.to(dtype))
@@ -273,5 +280,7 @@ class MultiScaleRoIAlign(nn.Module):
return result
def __repr__(self) -> str:
return (f"{self.__class__.__name__}(featmap_names={self.featmap_names}, "
f"output_size={self.output_size}, sampling_ratio={self.sampling_ratio})")
return (
f"{self.__class__.__name__}(featmap_names={self.featmap_names}, "
f"output_size={self.output_size}, sampling_ratio={self.sampling_ratio})"
)
import torch
from torch import nn, Tensor
from torch.nn.modules.utils import _pair
from torchvision.extension import _assert_has_ops
from ._utils import convert_boxes_to_roi_format, check_roi_boxes_shape
@@ -49,10 +48,9 @@ def ps_roi_align(
output_size = _pair(output_size)
if not isinstance(rois, torch.Tensor):
rois = convert_boxes_to_roi_format(rois)
output, _ = torch.ops.torchvision.ps_roi_align(input, rois, spatial_scale,
output_size[0],
output_size[1],
sampling_ratio)
output, _ = torch.ops.torchvision.ps_roi_align(
input, rois, spatial_scale, output_size[0], output_size[1], sampling_ratio
)
return output
@@ -60,6 +58,7 @@ class PSRoIAlign(nn.Module):
"""
See :func:`ps_roi_align`.
"""
def __init__(
self,
output_size: int,
@@ -72,13 +71,12 @@ class PSRoIAlign(nn.Module):
self.sampling_ratio = sampling_ratio
def forward(self, input: Tensor, rois: Tensor) -> Tensor:
return ps_roi_align(input, rois, self.output_size, self.spatial_scale,
self.sampling_ratio)
return ps_roi_align(input, rois, self.output_size, self.spatial_scale, self.sampling_ratio)
def __repr__(self) -> str:
tmpstr = self.__class__.__name__ + '('
tmpstr += 'output_size=' + str(self.output_size)
tmpstr += ', spatial_scale=' + str(self.spatial_scale)
tmpstr += ', sampling_ratio=' + str(self.sampling_ratio)
tmpstr += ')'
tmpstr = self.__class__.__name__ + "("
tmpstr += "output_size=" + str(self.output_size)
tmpstr += ", spatial_scale=" + str(self.spatial_scale)
tmpstr += ", sampling_ratio=" + str(self.sampling_ratio)
tmpstr += ")"
return tmpstr
import torch
from torch import nn, Tensor
from torch.nn.modules.utils import _pair
from torchvision.extension import _assert_has_ops
from ._utils import convert_boxes_to_roi_format, check_roi_boxes_shape
@@ -43,9 +42,7 @@ def ps_roi_pool(
output_size = _pair(output_size)
if not isinstance(rois, torch.Tensor):
rois = convert_boxes_to_roi_format(rois)
output, _ = torch.ops.torchvision.ps_roi_pool(input, rois, spatial_scale,
output_size[0],
output_size[1])
output, _ = torch.ops.torchvision.ps_roi_pool(input, rois, spatial_scale, output_size[0], output_size[1])
return output
@@ -53,6 +50,7 @@ class PSRoIPool(nn.Module):
"""
See :func:`ps_roi_pool`.
"""
def __init__(self, output_size: int, spatial_scale: float):
super(PSRoIPool, self).__init__()
self.output_size = output_size
@@ -62,8 +60,8 @@ class PSRoIPool(nn.Module):
return ps_roi_pool(input, rois, self.output_size, self.spatial_scale)
def __repr__(self) -> str:
tmpstr = self.__class__.__name__ + '('
tmpstr += 'output_size=' + str(self.output_size)
tmpstr += ', spatial_scale=' + str(self.spatial_scale)
tmpstr += ')'
tmpstr = self.__class__.__name__ + "("
tmpstr += "output_size=" + str(self.output_size)
tmpstr += ", spatial_scale=" + str(self.spatial_scale)
tmpstr += ")"
return tmpstr
@@ -2,8 +2,8 @@ from typing import List, Union
import torch
from torch import nn, Tensor
from torch.nn.modules.utils import _pair
from torch.jit.annotations import BroadcastingList2
from torch.nn.modules.utils import _pair
from torchvision.extension import _assert_has_ops
from ._utils import convert_boxes_to_roi_format, check_roi_boxes_shape
@@ -55,15 +55,16 @@ def roi_align(
output_size = _pair(output_size)
if not isinstance(rois, torch.Tensor):
rois = convert_boxes_to_roi_format(rois)
return torch.ops.torchvision.roi_align(input, rois, spatial_scale,
output_size[0], output_size[1],
sampling_ratio, aligned)
return torch.ops.torchvision.roi_align(
input, rois, spatial_scale, output_size[0], output_size[1], sampling_ratio, aligned
)
class RoIAlign(nn.Module):
"""
See :func:`roi_align`.
"""
def __init__(
self,
output_size: BroadcastingList2[int],
@@ -81,10 +82,10 @@ class RoIAlign(nn.Module):
return roi_align(input, rois, self.output_size, self.spatial_scale, self.sampling_ratio, self.aligned)
def __repr__(self) -> str:
tmpstr = self.__class__.__name__ + '('
tmpstr += 'output_size=' + str(self.output_size)
tmpstr += ', spatial_scale=' + str(self.spatial_scale)
tmpstr += ', sampling_ratio=' + str(self.sampling_ratio)
tmpstr += ', aligned=' + str(self.aligned)
tmpstr += ')'
tmpstr = self.__class__.__name__ + "("
tmpstr += "output_size=" + str(self.output_size)
tmpstr += ", spatial_scale=" + str(self.spatial_scale)
tmpstr += ", sampling_ratio=" + str(self.sampling_ratio)
tmpstr += ", aligned=" + str(self.aligned)
tmpstr += ")"
return tmpstr
@@ -2,8 +2,8 @@ from typing import List, Union
import torch
from torch import nn, Tensor
from torch.nn.modules.utils import _pair
from torch.jit.annotations import BroadcastingList2
from torch.nn.modules.utils import _pair
from torchvision.extension import _assert_has_ops
from ._utils import convert_boxes_to_roi_format, check_roi_boxes_shape
@@ -44,8 +44,7 @@ def roi_pool(
output_size = _pair(output_size)
if not isinstance(rois, torch.Tensor):
rois = convert_boxes_to_roi_format(rois)
output, _ = torch.ops.torchvision.roi_pool(input, rois, spatial_scale,
output_size[0], output_size[1])
output, _ = torch.ops.torchvision.roi_pool(input, rois, spatial_scale, output_size[0], output_size[1])
return output
@@ -53,6 +52,7 @@ class RoIPool(nn.Module):
"""
See :func:`roi_pool`.
"""
def __init__(self, output_size: BroadcastingList2[int], spatial_scale: float):
super(RoIPool, self).__init__()
self.output_size = output_size
@@ -62,8 +62,8 @@ class RoIPool(nn.Module):
return roi_pool(input, rois, self.output_size, self.spatial_scale)
def __repr__(self) -> str:
tmpstr = self.__class__.__name__ + '('
tmpstr += 'output_size=' + str(self.output_size)
tmpstr += ', spatial_scale=' + str(self.spatial_scale)
tmpstr += ')'
tmpstr = self.__class__.__name__ + "("
tmpstr += "output_size=" + str(self.output_size)
tmpstr += ", spatial_scale=" + str(self.spatial_scale)
tmpstr += ")"
return tmpstr
@@ -38,13 +38,14 @@ def stochastic_depth(input: Tensor, p: float, mode: str, training: bool = True)
return input * noise
torch.fx.wrap('stochastic_depth')
torch.fx.wrap("stochastic_depth")
class StochasticDepth(nn.Module):
"""
See :func:`stochastic_depth`.
"""
def __init__(self, p: float, mode: str) -> None:
super().__init__()
self.p = p
@@ -54,8 +55,8 @@ class StochasticDepth(nn.Module):
return stochastic_depth(input, self.p, self.mode, self.training)
def __repr__(self) -> str:
tmpstr = self.__class__.__name__ + '('
tmpstr += 'p=' + str(self.p)
tmpstr += ', mode=' + str(self.mode)
tmpstr += ')'
tmpstr = self.__class__.__name__ + "("
tmpstr += "p=" + str(self.p)
tmpstr += ", mode=" + str(self.mode)
tmpstr += ")"
return tmpstr
@@ -8,9 +8,9 @@ except (ModuleNotFoundError, TypeError) as error:
) from error
from ._home import home
from . import decoder, utils
# Load this last, since some parts depend on the above being loaded first
from ._api import register, _list as list, info, load
from ._folder import from_data_folder, from_image_folder
from ._home import home
@@ -3,11 +3,11 @@ from typing import Any, Callable, Dict, List, Optional
import torch
from torch.utils.data import IterDataPipe
from torchvision.prototype.datasets import home
from torchvision.prototype.datasets.decoder import pil
from torchvision.prototype.datasets.utils import Dataset, DatasetInfo
from torchvision.prototype.datasets.utils._internal import add_suggestion
from . import _builtin
DATASETS: Dict[str, Dataset] = {}
@@ -18,12 +18,7 @@ def register(dataset: Dataset) -> None:
for name, obj in _builtin.__dict__.items():
if (
not name.startswith("_")
and isinstance(obj, type)
and issubclass(obj, Dataset)
and obj is not Dataset
):
if not name.startswith("_") and isinstance(obj, type) and issubclass(obj, Dataset) and obj is not Dataset:
register(obj())
......
import io
import pathlib
from typing import Any, Callable, Dict, List, Optional, Tuple, Union
import re
from typing import Any, Callable, Dict, List, Optional, Tuple, Union
import numpy as np
import torch
from torch.utils.data import IterDataPipe
from torch.utils.data.datapipes.iter import (
@@ -13,7 +12,6 @@ from torch.utils.data.datapipes.iter import (
Shuffler,
Filter,
)
from torchdata.datapipes.iter import KeyZipper
from torchvision.prototype.datasets.utils import (
Dataset,
......
@@ -8,7 +8,6 @@ from typing import Union, Tuple, List, Dict, Any
import torch
from torch.utils.data import IterDataPipe
from torch.utils.data.datapipes.iter import FileLister, FileLoader, Mapper, Shuffler, Filter
from torchvision.prototype.datasets.decoder import pil
from torchvision.prototype.datasets.utils._internal import INFINITE_BUFFER_SIZE
......
@@ -2,7 +2,6 @@ import io
import PIL.Image
import torch
from torchvision.transforms.functional import pil_to_tensor
__all__ = ["pil"]
......
@@ -19,11 +19,11 @@ from typing import (
import torch
from torch.utils.data import IterDataPipe
from torchvision.prototype.datasets.utils._internal import (
add_suggestion,
sequence_to_str,
)
from ._resource import OnlineResource
@@ -64,9 +64,7 @@ class DatasetConfig(Mapping):
try:
return self[name]
except KeyError as error:
raise AttributeError(
f"'{type(self).__name__}' object has no attribute '{name}'"
) from error
raise AttributeError(f"'{type(self).__name__}' object has no attribute '{name}'") from error
def __setitem__(self, key: Any, value: Any) -> NoReturn:
raise RuntimeError(f"'{type(self).__name__}' object is immutable")
@@ -133,9 +131,7 @@ class DatasetInfo:
@property
def default_config(self) -> DatasetConfig:
return DatasetConfig(
{name: valid_args[0] for name, valid_args in self._valid_options.items()}
)
return DatasetConfig({name: valid_args[0] for name, valid_args in self._valid_options.items()})
def make_config(self, **options: Any) -> DatasetConfig:
for name, arg in options.items():
@@ -167,12 +163,7 @@ class DatasetInfo:
value = getattr(self, key)
if value is not None:
items.append((key, value))
items.extend(
sorted(
(key, sequence_to_str(value))
for key, value in self._valid_options.items()
)
)
items.extend(sorted((key, sequence_to_str(value)) for key, value in self._valid_options.items()))
return make_repr(type(self).__name__, items)
@@ -214,7 +205,5 @@ class Dataset(abc.ABC):
if not config:
config = self.info.default_config
resource_dps = [
resource.to_datapipe(root) for resource in self.resources(config)
]
resource_dps = [resource.to_datapipe(root) for resource in self.resources(config)]
return self._make_datapipe(resource_dps, config=config, decoder=decoder)
@@ -5,13 +5,7 @@ import pathlib
from typing import Collection, Sequence, Callable, Union, Any
__all__ = [
"INFINITE_BUFFER_SIZE",
"sequence_to_str",
"add_suggestion",
"create_categories_file",
"read_mat"
]
__all__ = ["INFINITE_BUFFER_SIZE", "sequence_to_str", "add_suggestion", "create_categories_file", "read_mat"]
# pseudo-infinite until a true infinite buffer is supported by all datapipes
INFINITE_BUFFER_SIZE = 1_000_000_000
@@ -21,10 +15,7 @@ def sequence_to_str(seq: Sequence, separate_last: str = "") -> str:
if len(seq) == 1:
return f"'{seq[0]}'"
return (
f"""'{"', '".join([str(item) for item in seq[:-1]])}', """
f"""{separate_last}'{seq[-1]}'."""
)
return f"""'{"', '".join([str(item) for item in seq[:-1]])}', """ f"""{separate_last}'{seq[-1]}'."""
def add_suggestion(
@@ -32,9 +23,7 @@ def add_suggestion(
*,
word: str,
possibilities: Collection[str],
close_match_hint: Callable[
[str], str
] = lambda close_match: f"Did you mean '{close_match}'?",
close_match_hint: Callable[[str], str] = lambda close_match: f"Did you mean '{close_match}'?",
alternative_hint: Callable[
[Sequence[str]], str
] = lambda possibilities: f"Can be {sequence_to_str(possibilities, separate_last='or ')}.",
@@ -42,17 +31,11 @@ def add_suggestion(
if not isinstance(possibilities, collections.abc.Sequence):
possibilities = sorted(possibilities)
suggestions = difflib.get_close_matches(word, possibilities, 1)
hint = (
close_match_hint(suggestions[0])
if suggestions
else alternative_hint(possibilities)
)
hint = close_match_hint(suggestions[0]) if suggestions else alternative_hint(possibilities)
return f"{msg.strip()} {hint}"
def create_categories_file(
root: Union[str, pathlib.Path], name: str, categories: Sequence[str]
) -> None:
def create_categories_file(root: Union[str, pathlib.Path], name: str, categories: Sequence[str]) -> None:
with open(pathlib.Path(root) / f"{name}.categories", "w") as fh:
fh.write("\n".join(categories) + "\n")
@@ -61,8 +44,6 @@ def read_mat(buffer: io.IOBase, **kwargs: Any) -> Any:
try:
import scipy.io as sio
except ImportError as error:
raise ModuleNotFoundError(
"Package `scipy` is required to be installed to read .mat files."
) from error
raise ModuleNotFoundError("Package `scipy` is required to be installed to read .mat files.") from error
return sio.loadmat(buffer, **kwargs)
@@ -13,9 +13,7 @@ def compute_sha256(_) -> str:
class LocalResource:
def __init__(
self, path: Union[str, pathlib.Path], *, sha256: Optional[str] = None
) -> None:
def __init__(self, path: Union[str, pathlib.Path], *, sha256: Optional[str] = None) -> None:
self.path = pathlib.Path(path).expanduser().resolve()
self.file_name = self.path.name
self.sha256 = sha256 or compute_sha256(self.path)
@@ -39,9 +37,7 @@ class OnlineResource:
# TODO: add support for mirrors
# TODO: add support for http -> https
class HttpResource(OnlineResource):
def __init__(
self, url: str, *, sha256: str, file_name: Optional[str] = None
) -> None:
def __init__(self, url: str, *, sha256: str, file_name: Optional[str] = None) -> None:
if not file_name:
file_name = os.path.basename(urlparse(url).path)
super().__init__(url, sha256=sha256, file_name=file_name)
......
import torch
import warnings
import torch
warnings.warn(
"The _functional_video module is deprecated. Please use the functional module instead."
)
warnings.warn("The _functional_video module is deprecated. Please use the functional module instead.")
def _is_tensor_video_clip(clip):
@@ -23,14 +22,12 @@ def crop(clip, i, j, h, w):
clip (torch.tensor): Video clip to be cropped. Size is (C, T, H, W)
"""
assert len(clip.size()) == 4, "clip should be a 4D tensor"
return clip[..., i:i + h, j:j + w]
return clip[..., i : i + h, j : j + w]
def resize(clip, target_size, interpolation_mode):
assert len(target_size) == 2, "target size should be tuple (height, width)"
return torch.nn.functional.interpolate(
clip, size=target_size, mode=interpolation_mode, align_corners=False
)
return torch.nn.functional.interpolate(clip, size=target_size, mode=interpolation_mode, align_corners=False)
def resized_crop(clip, i, j, h, w, size, interpolation_mode="bilinear"):
......
@@ -22,9 +22,7 @@ __all__ = [
]
warnings.warn(
"The _transforms_video module is deprecated. Please use the transforms module instead."
)
warnings.warn("The _transforms_video module is deprecated. Please use the transforms module instead.")
class RandomCropVideo(RandomCrop):
@@ -46,7 +44,7 @@ class RandomCropVideo(RandomCrop):
return F.crop(clip, i, j, h, w)
def __repr__(self):
return self.__class__.__name__ + '(size={0})'.format(self.size)
return self.__class__.__name__ + "(size={0})".format(self.size)
class RandomResizedCropVideo(RandomResizedCrop):
@@ -79,10 +77,9 @@ class RandomResizedCropVideo(RandomResizedCrop):
return F.resized_crop(clip, i, j, h, w, self.size, self.interpolation_mode)
def __repr__(self):
return self.__class__.__name__ + \
'(size={0}, interpolation_mode={1}, scale={2}, ratio={3})'.format(
self.size, self.interpolation_mode, self.scale, self.ratio
)
return self.__class__.__name__ + "(size={0}, interpolation_mode={1}, scale={2}, ratio={3})".format(
self.size, self.interpolation_mode, self.scale, self.ratio
)
class CenterCropVideo(object):
@@ -103,7 +100,7 @@ class CenterCropVideo(object):
return F.center_crop(clip, self.crop_size)
def __repr__(self):
return self.__class__.__name__ + '(crop_size={0})'.format(self.crop_size)
return self.__class__.__name__ + "(crop_size={0})".format(self.crop_size)
class NormalizeVideo(object):
@@ -128,8 +125,7 @@ class NormalizeVideo(object):
return F.normalize(clip, self.mean, self.std, self.inplace)
def __repr__(self):
return self.__class__.__name__ + '(mean={0}, std={1}, inplace={2})'.format(
self.mean, self.std, self.inplace)
return self.__class__.__name__ + "(mean={0}, std={1}, inplace={2})".format(self.mean, self.std, self.inplace)
class ToTensorVideo(object):
......
import math
import torch
from enum import Enum
from torch import Tensor
from typing import List, Tuple, Optional, Dict
import torch
from torch import Tensor
from . import functional as F, InterpolationMode
__all__ = ["AutoAugmentPolicy", "AutoAugment", "RandAugment", "TrivialAugmentWide"]
def _apply_op(img: Tensor, op_name: str, magnitude: float,
interpolation: InterpolationMode, fill: Optional[List[float]]):
def _apply_op(
img: Tensor, op_name: str, magnitude: float, interpolation: InterpolationMode, fill: Optional[List[float]]
):
if op_name == "ShearX":
img = F.affine(img, angle=0.0, translate=[0, 0], scale=1.0, shear=[math.degrees(magnitude), 0.0],
interpolation=interpolation, fill=fill)
img = F.affine(
img,
angle=0.0,
translate=[0, 0],
scale=1.0,
shear=[math.degrees(magnitude), 0.0],
interpolation=interpolation,
fill=fill,
)
elif op_name == "ShearY":
img = F.affine(img, angle=0.0, translate=[0, 0], scale=1.0, shear=[0.0, math.degrees(magnitude)],
interpolation=interpolation, fill=fill)
img = F.affine(
img,
angle=0.0,
translate=[0, 0],
scale=1.0,
shear=[0.0, math.degrees(magnitude)],
interpolation=interpolation,
fill=fill,
)
elif op_name == "TranslateX":
img = F.affine(img, angle=0.0, translate=[int(magnitude), 0], scale=1.0,
interpolation=interpolation, shear=[0.0, 0.0], fill=fill)
img = F.affine(
img,
angle=0.0,
translate=[int(magnitude), 0],
scale=1.0,
interpolation=interpolation,
shear=[0.0, 0.0],
fill=fill,
)
elif op_name == "TranslateY":
img = F.affine(img, angle=0.0, translate=[0, int(magnitude)], scale=1.0,
interpolation=interpolation, shear=[0.0, 0.0], fill=fill)
img = F.affine(
img,
angle=0.0,
translate=[0, int(magnitude)],
scale=1.0,
interpolation=interpolation,
shear=[0.0, 0.0],
fill=fill,
)
elif op_name == "Rotate":
img = F.rotate(img, magnitude, interpolation=interpolation, fill=fill)
elif op_name == "Brightness":
@@ -55,6 +84,7 @@ class AutoAugmentPolicy(Enum):
"""AutoAugment policies learned on different datasets.
Available policies are IMAGENET, CIFAR10 and SVHN.
"""
IMAGENET = "imagenet"
CIFAR10 = "cifar10"
SVHN = "svhn"
@@ -82,7 +112,7 @@ class AutoAugment(torch.nn.Module):
self,
policy: AutoAugmentPolicy = AutoAugmentPolicy.IMAGENET,
interpolation: InterpolationMode = InterpolationMode.NEAREST,
fill: Optional[List[float]] = None
fill: Optional[List[float]] = None,
) -> None:
super().__init__()
self.policy = policy
@@ -91,8 +121,7 @@ class AutoAugment(torch.nn.Module):
self.policies = self._get_policies(policy)
def _get_policies(
self,
policy: AutoAugmentPolicy
self, policy: AutoAugmentPolicy
) -> List[Tuple[Tuple[str, float, Optional[int]], Tuple[str, float, Optional[int]]]]:
if policy == AutoAugmentPolicy.IMAGENET:
return [
@@ -241,7 +270,7 @@ class AutoAugment(torch.nn.Module):
return img
def __repr__(self) -> str:
return self.__class__.__name__ + '(policy={}, fill={})'.format(self.policy, self.fill)
return self.__class__.__name__ + "(policy={}, fill={})".format(self.policy, self.fill)
class RandAugment(torch.nn.Module):
@@ -261,11 +290,16 @@ class RandAugment(torch.nn.Module):
If input is Tensor, only ``InterpolationMode.NEAREST``, ``InterpolationMode.BILINEAR`` are supported.
fill (sequence or number, optional): Pixel fill value for the area outside the transformed
image. If given a number, the value is used for all bands respectively.
"""
"""
def __init__(self, num_ops: int = 2, magnitude: int = 9, num_magnitude_bins: int = 31,
interpolation: InterpolationMode = InterpolationMode.NEAREST,
fill: Optional[List[float]] = None) -> None:
def __init__(
self,
num_ops: int = 2,
magnitude: int = 9,
num_magnitude_bins: int = 31,
interpolation: InterpolationMode = InterpolationMode.NEAREST,
fill: Optional[List[float]] = None,
) -> None:
super().__init__()
self.num_ops = num_ops
self.magnitude = magnitude
@@ -319,13 +353,13 @@ class RandAugment(torch.nn.Module):
return img
def __repr__(self) -> str:
s = self.__class__.__name__ + '('
s += 'num_ops={num_ops}'
s += ', magnitude={magnitude}'
s += ', num_magnitude_bins={num_magnitude_bins}'
s += ', interpolation={interpolation}'
s += ', fill={fill}'
s += ')'
s = self.__class__.__name__ + "("
s += "num_ops={num_ops}"
s += ", magnitude={magnitude}"
s += ", num_magnitude_bins={num_magnitude_bins}"
s += ", interpolation={interpolation}"
s += ", fill={fill}"
s += ")"
return s.format(**self.__dict__)
@@ -343,10 +377,14 @@ class TrivialAugmentWide(torch.nn.Module):
If input is Tensor, only ``InterpolationMode.NEAREST``, ``InterpolationMode.BILINEAR`` are supported.
fill (sequence or number, optional): Pixel fill value for the area outside the transformed
image. If given a number, the value is used for all bands respectively.
"""
"""
def __init__(self, num_magnitude_bins: int = 31, interpolation: InterpolationMode = InterpolationMode.NEAREST,
fill: Optional[List[float]] = None) -> None:
def __init__(
self,
num_magnitude_bins: int = 31,
interpolation: InterpolationMode = InterpolationMode.NEAREST,
fill: Optional[List[float]] = None,
) -> None:
super().__init__()
self.num_magnitude_bins = num_magnitude_bins
self.interpolation = interpolation
@@ -389,17 +427,20 @@ class TrivialAugmentWide(torch.nn.Module):
op_index = int(torch.randint(len(op_meta), (1,)).item())
op_name = list(op_meta.keys())[op_index]
magnitudes, signed = op_meta[op_name]
magnitude = float(magnitudes[torch.randint(len(magnitudes), (1,), dtype=torch.long)].item()) \
if magnitudes.ndim > 0 else 0.0
magnitude = (
float(magnitudes[torch.randint(len(magnitudes), (1,), dtype=torch.long)].item())
if magnitudes.ndim > 0
else 0.0
)
if signed and torch.randint(2, (1,)):
magnitude *= -1.0
return _apply_op(img, op_name, magnitude, interpolation=self.interpolation, fill=fill)
def __repr__(self) -> str:
s = self.__class__.__name__ + '('
s += 'num_magnitude_bins={num_magnitude_bins}'
s += ', interpolation={interpolation}'
s += ', fill={fill}'
s += ')'
s = self.__class__.__name__ + "("
s += "num_magnitude_bins={num_magnitude_bins}"
s += ", interpolation={interpolation}"
s += ", fill={fill}"
s += ")"
return s.format(**self.__dict__)
@@ -2,13 +2,12 @@ import math
import numbers
import warnings
from enum import Enum
from typing import List, Tuple, Any, Optional
import numpy as np
from PIL import Image
import torch
from PIL import Image
from torch import Tensor
from typing import List, Tuple, Any, Optional
try:
import accimage
@@ -23,6 +22,7 @@ class InterpolationMode(Enum):
"""Interpolation modes
Available interpolation methods are ``nearest``, ``bilinear``, ``bicubic``, ``box``, ``hamming``, and ``lanczos``.
"""
NEAREST = "nearest"
BILINEAR = "bilinear"
BICUBIC = "bicubic"
@@ -110,11 +110,11 @@ def to_tensor(pic):
Returns:
Tensor: Converted image.
"""
if not(F_pil._is_pil_image(pic) or _is_numpy(pic)):
raise TypeError('pic should be PIL Image or ndarray. Got {}'.format(type(pic)))
if not (F_pil._is_pil_image(pic) or _is_numpy(pic)):
raise TypeError("pic should be PIL Image or ndarray. Got {}".format(type(pic)))
if _is_numpy(pic) and not _is_numpy_image(pic):
raise ValueError('pic should be 2/3 dimensional. Got {} dimensions.'.format(pic.ndim))
raise ValueError("pic should be 2/3 dimensional. Got {} dimensions.".format(pic.ndim))
default_float_dtype = torch.get_default_dtype()
@@ -136,12 +136,10 @@ def to_tensor(pic):
return torch.from_numpy(nppic).to(dtype=default_float_dtype)
# handle PIL Image
mode_to_nptype = {'I': np.int32, 'I;16': np.int16, 'F': np.float32}
img = torch.from_numpy(
np.array(pic, mode_to_nptype.get(pic.mode, np.uint8), copy=True)
)
mode_to_nptype = {"I": np.int32, "I;16": np.int16, "F": np.float32}
img = torch.from_numpy(np.array(pic, mode_to_nptype.get(pic.mode, np.uint8), copy=True))
if pic.mode == '1':
if pic.mode == "1":
img = 255 * img
img = img.view(pic.size[1], pic.size[0], len(pic.getbands()))
# put it from HWC to CHW format
@@ -165,7 +163,7 @@ def pil_to_tensor(pic):
Tensor: Converted image.
"""
if not F_pil._is_pil_image(pic):
raise TypeError('pic should be PIL Image. Got {}'.format(type(pic)))
raise TypeError("pic should be PIL Image. Got {}".format(type(pic)))
if accimage is not None and isinstance(pic, accimage.Image):
# accimage format is always uint8 internally, so always return uint8 here
@@ -204,7 +202,7 @@ def convert_image_dtype(image: torch.Tensor, dtype: torch.dtype = torch.float) -
of the integer ``dtype``.
"""
if not isinstance(image, torch.Tensor):
raise TypeError('Input img should be Tensor Image')
raise TypeError("Input img should be Tensor Image")
return F_t.convert_image_dtype(image, dtype)
@@ -223,12 +221,12 @@ def to_pil_image(pic, mode=None):
Returns:
PIL Image: Image converted to PIL Image.
"""
if not(isinstance(pic, torch.Tensor) or isinstance(pic, np.ndarray)):
raise TypeError('pic should be Tensor or ndarray. Got {}.'.format(type(pic)))
if not (isinstance(pic, torch.Tensor) or isinstance(pic, np.ndarray)):
raise TypeError("pic should be Tensor or ndarray. Got {}.".format(type(pic)))
elif isinstance(pic, torch.Tensor):
if pic.ndimension() not in {2, 3}:
raise ValueError('pic should be 2/3 dimensional. Got {} dimensions.'.format(pic.ndimension()))
raise ValueError("pic should be 2/3 dimensional. Got {} dimensions.".format(pic.ndimension()))
elif pic.ndimension() == 2:
# if 2D image, add channel dimension (CHW)
@@ -236,11 +234,11 @@ def to_pil_image(pic, mode=None):
# check number of channels
if pic.shape[-3] > 4:
raise ValueError('pic should not have > 4 channels. Got {} channels.'.format(pic.shape[-3]))
raise ValueError("pic should not have > 4 channels. Got {} channels.".format(pic.shape[-3]))
elif isinstance(pic, np.ndarray):
if pic.ndim not in {2, 3}:
raise ValueError('pic should be 2/3 dimensional. Got {} dimensions.'.format(pic.ndim))
raise ValueError("pic should be 2/3 dimensional. Got {} dimensions.".format(pic.ndim))
elif pic.ndim == 2:
# if 2D image, add channel dimension (HWC)
@@ -248,58 +246,58 @@ def to_pil_image(pic, mode=None):
# check number of channels
if pic.shape[-1] > 4:
raise ValueError('pic should not have > 4 channels. Got {} channels.'.format(pic.shape[-1]))
raise ValueError("pic should not have > 4 channels. Got {} channels.".format(pic.shape[-1]))
npimg = pic
if isinstance(pic, torch.Tensor):
if pic.is_floating_point() and mode != 'F':
if pic.is_floating_point() and mode != "F":
pic = pic.mul(255).byte()
npimg = np.transpose(pic.cpu().numpy(), (1, 2, 0))
if not isinstance(npimg, np.ndarray):
raise TypeError('Input pic must be a torch.Tensor or NumPy ndarray, ' +
'not {}'.format(type(npimg)))
raise TypeError("Input pic must be a torch.Tensor or NumPy ndarray, " + "not {}".format(type(npimg)))
if npimg.shape[2] == 1:
expected_mode = None
npimg = npimg[:, :, 0]
if npimg.dtype == np.uint8:
expected_mode = 'L'
expected_mode = "L"
elif npimg.dtype == np.int16:
expected_mode = 'I;16'
expected_mode = "I;16"
elif npimg.dtype == np.int32:
expected_mode = 'I'
expected_mode = "I"
elif npimg.dtype == np.float32:
expected_mode = 'F'
expected_mode = "F"
if mode is not None and mode != expected_mode:
raise ValueError("Incorrect mode ({}) supplied for input type {}. Should be {}"
.format(mode, np.dtype, expected_mode))
raise ValueError(
"Incorrect mode ({}) supplied for input type {}. Should be {}".format(mode, np.dtype, expected_mode)
)
mode = expected_mode
elif npimg.shape[2] == 2:
permitted_2_channel_modes = ['LA']
permitted_2_channel_modes = ["LA"]
if mode is not None and mode not in permitted_2_channel_modes:
raise ValueError("Only modes {} are supported for 2D inputs".format(permitted_2_channel_modes))
if mode is None and npimg.dtype == np.uint8:
mode = 'LA'
mode = "LA"
elif npimg.shape[2] == 4:
permitted_4_channel_modes = ['RGBA', 'CMYK', 'RGBX']
permitted_4_channel_modes = ["RGBA", "CMYK", "RGBX"]
if mode is not None and mode not in permitted_4_channel_modes:
raise ValueError("Only modes {} are supported for 4D inputs".format(permitted_4_channel_modes))
if mode is None and npimg.dtype == np.uint8:
mode = 'RGBA'
mode = "RGBA"
else:
permitted_3_channel_modes = ['RGB', 'YCbCr', 'HSV']
permitted_3_channel_modes = ["RGB", "YCbCr", "HSV"]
if mode is not None and mode not in permitted_3_channel_modes:
raise ValueError("Only modes {} are supported for 3D inputs".format(permitted_3_channel_modes))
if mode is None and npimg.dtype == np.uint8:
mode = 'RGB'
mode = "RGB"
if mode is None:
raise TypeError('Input type {} is not supported'.format(npimg.dtype))
raise TypeError("Input type {} is not supported".format(npimg.dtype))
return Image.fromarray(npimg, mode=mode)
@@ -323,14 +321,16 @@ def normalize(tensor: Tensor, mean: List[float], std: List[float], inplace: bool
Tensor: Normalized Tensor image.
"""
if not isinstance(tensor, torch.Tensor):
raise TypeError('Input tensor should be a torch tensor. Got {}.'.format(type(tensor)))
raise TypeError("Input tensor should be a torch tensor. Got {}.".format(type(tensor)))
if not tensor.is_floating_point():
raise TypeError('Input tensor should be a float tensor. Got {}.'.format(tensor.dtype))
raise TypeError("Input tensor should be a float tensor. Got {}.".format(tensor.dtype))
if tensor.ndim < 3:
raise ValueError('Expected tensor to be a tensor image of size (..., C, H, W). Got tensor.size() = '
'{}.'.format(tensor.size()))
raise ValueError(
"Expected tensor to be a tensor image of size (..., C, H, W). Got tensor.size() = "
"{}.".format(tensor.size())
)
if not inplace:
tensor = tensor.clone()
@@ -339,7 +339,7 @@ def normalize(tensor: Tensor, mean: List[float], std: List[float], inplace: bool
mean = torch.as_tensor(mean, dtype=dtype, device=tensor.device)
std = torch.as_tensor(std, dtype=dtype, device=tensor.device)
if (std == 0).any():
raise ValueError('std evaluated to zero after conversion to {}, leading to division by zero.'.format(dtype))
raise ValueError("std evaluated to zero after conversion to {}, leading to division by zero.".format(dtype))
if mean.ndim == 1:
mean = mean.view(-1, 1, 1)
if std.ndim == 1:
@@ -348,8 +348,13 @@ def normalize(tensor: Tensor, mean: List[float], std: List[float], inplace: bool
return tensor
def resize(img: Tensor, size: List[int], interpolation: InterpolationMode = InterpolationMode.BILINEAR,
max_size: Optional[int] = None, antialias: Optional[bool] = None) -> Tensor:
def resize(
img: Tensor,
size: List[int],
interpolation: InterpolationMode = InterpolationMode.BILINEAR,
max_size: Optional[int] = None,
antialias: Optional[bool] = None,
) -> Tensor:
r"""Resize the input image to the given size.
If the image is torch Tensor, it is expected
to have [..., H, W] shape, where ... means an arbitrary number of leading dimensions
@@ -408,9 +413,7 @@ def resize(img: Tensor, size: List[int], interpolation: InterpolationMode = Inte
if not isinstance(img, torch.Tensor):
if antialias is not None and not antialias:
warnings.warn(
"Anti-alias option is always applied for PIL Image input. Argument antialias is ignored."
)
warnings.warn("Anti-alias option is always applied for PIL Image input. Argument antialias is ignored.")
pil_interpolation = pil_modes_mapping[interpolation]
return F_pil.resize(img, size=size, interpolation=pil_interpolation, max_size=max_size)
@@ -418,8 +421,7 @@ def resize(img: Tensor, size: List[int], interpolation: InterpolationMode = Inte
def scale(*args, **kwargs):
warnings.warn("The use of the transforms.Scale transform is deprecated, " +
"please use transforms.Resize instead.")
warnings.warn("The use of the transforms.Scale transform is deprecated, " + "please use transforms.Resize instead.")
return resize(*args, **kwargs)
@@ -527,14 +529,19 @@ def center_crop(img: Tensor, output_size: List[int]) -> Tensor:
if crop_width == image_width and crop_height == image_height:
return img
crop_top = int(round((image_height - crop_height) / 2.))
crop_left = int(round((image_width - crop_width) / 2.))
crop_top = int(round((image_height - crop_height) / 2.0))
crop_left = int(round((image_width - crop_width) / 2.0))
return crop(img, crop_top, crop_left, crop_height, crop_width)
def resized_crop(
img: Tensor, top: int, left: int, height: int, width: int, size: List[int],
interpolation: InterpolationMode = InterpolationMode.BILINEAR
img: Tensor,
top: int,
left: int,
height: int,
width: int,
size: List[int],
interpolation: InterpolationMode = InterpolationMode.BILINEAR,
) -> Tensor:
"""Crop the given image and resize it to desired size.
If the image is torch Tensor, it is expected
@@ -581,9 +588,7 @@ def hflip(img: Tensor) -> Tensor:
return F_t.hflip(img)
def _get_perspective_coeffs(
startpoints: List[List[int]], endpoints: List[List[int]]
) -> List[float]:
def _get_perspective_coeffs(startpoints: List[List[int]], endpoints: List[List[int]]) -> List[float]:
"""Helper function to get the coefficients (a, b, c, d, e, f, g, h) for the perspective transforms.
In Perspective Transform each pixel (x, y) in the original image gets transformed as,
@@ -605,18 +610,18 @@ def _get_perspective_coeffs(
a_matrix[2 * i + 1, :] = torch.tensor([0, 0, 0, p1[0], p1[1], 1, -p2[1] * p1[0], -p2[1] * p1[1]])
b_matrix = torch.tensor(startpoints, dtype=torch.float).view(8)
res = torch.linalg.lstsq(a_matrix, b_matrix, driver='gels').solution
res = torch.linalg.lstsq(a_matrix, b_matrix, driver="gels").solution
output: List[float] = res.tolist()
return output
def perspective(
img: Tensor,
startpoints: List[List[int]],
endpoints: List[List[int]],
interpolation: InterpolationMode = InterpolationMode.BILINEAR,
fill: Optional[List[float]] = None
img: Tensor,
startpoints: List[List[int]],
endpoints: List[List[int]],
interpolation: InterpolationMode = InterpolationMode.BILINEAR,
fill: Optional[List[float]] = None,
) -> Tensor:
"""Perform perspective transform of the given image.
If the image is torch Tensor, it is expected
@@ -892,7 +897,7 @@ def adjust_gamma(img: Tensor, gamma: float, gain: float = 1) -> Tensor:
def _get_inverse_affine_matrix(
center: List[float], angle: float, translate: List[float], scale: float, shear: List[float]
center: List[float], angle: float, translate: List[float], scale: float, shear: List[float]
) -> List[float]:
# Helper method to compute inverse matrix for affine transformation
@@ -942,9 +947,13 @@ def _get_inverse_affine_matrix(
def rotate(
img: Tensor, angle: float, interpolation: InterpolationMode = InterpolationMode.NEAREST,
expand: bool = False, center: Optional[List[int]] = None,
fill: Optional[List[float]] = None, resample: Optional[int] = None
img: Tensor,
angle: float,
interpolation: InterpolationMode = InterpolationMode.NEAREST,
expand: bool = False,
center: Optional[List[int]] = None,
fill: Optional[List[float]] = None,
resample: Optional[int] = None,
) -> Tensor:
"""Rotate the image by angle.
If the image is torch Tensor, it is expected
@@ -1016,9 +1025,15 @@ def rotate(
def affine(
img: Tensor, angle: float, translate: List[int], scale: float, shear: List[float],
interpolation: InterpolationMode = InterpolationMode.NEAREST, fill: Optional[List[float]] = None,
resample: Optional[int] = None, fillcolor: Optional[List[float]] = None
img: Tensor,
angle: float,
translate: List[int],
scale: float,
shear: List[float],
interpolation: InterpolationMode = InterpolationMode.NEAREST,
fill: Optional[List[float]] = None,
resample: Optional[int] = None,
fillcolor: Optional[List[float]] = None,
) -> Tensor:
"""Apply affine transformation on the image keeping image center invariant.
If the image is torch Tensor, it is expected
@@ -1065,9 +1080,7 @@ def affine(
interpolation = _interpolation_modes_from_int(interpolation)
if fillcolor is not None:
warnings.warn(
"Argument fillcolor is deprecated and will be removed since v0.10.0. Please, use fill instead"
)
warnings.warn("Argument fillcolor is deprecated and will be removed since v0.10.0. Please, use fill instead")
fill = fillcolor
if not isinstance(angle, (int, float)):
@@ -1168,7 +1181,7 @@ def rgb_to_grayscale(img: Tensor, num_output_channels: int = 1) -> Tensor:
def erase(img: Tensor, i: int, j: int, h: int, w: int, v: Tensor, inplace: bool = False) -> Tensor:
""" Erase the input Tensor Image with given value.
"""Erase the input Tensor Image with given value.
This transform does not support PIL Image.
Args:
@@ -1184,12 +1197,12 @@ def erase(img: Tensor, i: int, j: int, h: int, w: int, v: Tensor, inplace: bool
Tensor Image: Erased image.
"""
if not isinstance(img, torch.Tensor):
raise TypeError('img should be Tensor Image. Got {}'.format(type(img)))
raise TypeError("img should be Tensor Image. Got {}".format(type(img)))
if not inplace:
img = img.clone()
img[..., i:i + h, j:j + w] = v
img[..., i : i + h, j : j + w] = v
return img
@@ -1220,34 +1233,34 @@ def gaussian_blur(img: Tensor, kernel_size: List[int], sigma: Optional[List[floa
PIL Image or Tensor: Gaussian Blurred version of the image.
"""
if not isinstance(kernel_size, (int, list, tuple)):
raise TypeError('kernel_size should be int or a sequence of integers. Got {}'.format(type(kernel_size)))
raise TypeError("kernel_size should be int or a sequence of integers. Got {}".format(type(kernel_size)))
if isinstance(kernel_size, int):
kernel_size = [kernel_size, kernel_size]
if len(kernel_size) != 2:
raise ValueError('If kernel_size is a sequence its length should be 2. Got {}'.format(len(kernel_size)))
raise ValueError("If kernel_size is a sequence its length should be 2. Got {}".format(len(kernel_size)))
for ksize in kernel_size:
if ksize % 2 == 0 or ksize < 0:
raise ValueError('kernel_size should have odd and positive integers. Got {}'.format(kernel_size))
raise ValueError("kernel_size should have odd and positive integers. Got {}".format(kernel_size))
if sigma is None:
sigma = [ksize * 0.15 + 0.35 for ksize in kernel_size]
if sigma is not None and not isinstance(sigma, (int, float, list, tuple)):
raise TypeError('sigma should be either float or sequence of floats. Got {}'.format(type(sigma)))
raise TypeError("sigma should be either float or sequence of floats. Got {}".format(type(sigma)))
if isinstance(sigma, (int, float)):
sigma = [float(sigma), float(sigma)]
if isinstance(sigma, (list, tuple)) and len(sigma) == 1:
sigma = [sigma[0], sigma[0]]
if len(sigma) != 2:
raise ValueError('If sigma is a sequence, its length should be 2. Got {}'.format(len(sigma)))
raise ValueError("If sigma is a sequence, its length should be 2. Got {}".format(len(sigma)))
for s in sigma:
if s <= 0.:
raise ValueError('sigma should have positive values. Got {}'.format(sigma))
if s <= 0.0:
raise ValueError("sigma should have positive values. Got {}".format(sigma))
t_img = img
if not isinstance(img, torch.Tensor):
if not F_pil._is_pil_image(img):
raise TypeError('img should be PIL Image or Tensor. Got {}'.format(type(img)))
raise TypeError("img should be PIL Image or Tensor. Got {}".format(type(img)))
t_img = to_tensor(img)
@@ -1290,7 +1303,7 @@ def posterize(img: Tensor, bits: int) -> Tensor:
PIL Image or Tensor: Posterized image.
"""
if not (0 <= bits <= 8):
raise ValueError('The number if bits should be between 0 and 8. Got {}'.format(bits))
raise ValueError("The number if bits should be between 0 and 8. Got {}".format(bits))
if not isinstance(img, torch.Tensor):
return F_pil.posterize(img, bits)
......
@@ -29,14 +29,14 @@ def get_image_size(img: Any) -> List[int]:
@torch.jit.unused
def get_image_num_channels(img: Any) -> int:
if _is_pil_image(img):
return 1 if img.mode == 'L' else 3
return 1 if img.mode == "L" else 3
raise TypeError("Unexpected type {}".format(type(img)))
@torch.jit.unused
def hflip(img: Image.Image) -> Image.Image:
if not _is_pil_image(img):
raise TypeError('img should be PIL Image. Got {}'.format(type(img)))
raise TypeError("img should be PIL Image. Got {}".format(type(img)))
return img.transpose(Image.FLIP_LEFT_RIGHT)
@@ -44,7 +44,7 @@ def hflip(img: Image.Image) -> Image.Image:
@torch.jit.unused
def vflip(img: Image.Image) -> Image.Image:
if not _is_pil_image(img):
raise TypeError('img should be PIL Image. Got {}'.format(type(img)))
raise TypeError("img should be PIL Image. Got {}".format(type(img)))
return img.transpose(Image.FLIP_TOP_BOTTOM)
@@ -52,7 +52,7 @@ def vflip(img: Image.Image) -> Image.Image:
@torch.jit.unused
def adjust_brightness(img: Image.Image, brightness_factor: float) -> Image.Image:
if not _is_pil_image(img):
raise TypeError('img should be PIL Image. Got {}'.format(type(img)))
raise TypeError("img should be PIL Image. Got {}".format(type(img)))
enhancer = ImageEnhance.Brightness(img)
img = enhancer.enhance(brightness_factor)
@@ -62,7 +62,7 @@ def adjust_brightness(img: Image.Image, brightness_factor: float) -> Image.Image
@torch.jit.unused
def adjust_contrast(img: Image.Image, contrast_factor: float) -> Image.Image:
if not _is_pil_image(img):
raise TypeError('img should be PIL Image. Got {}'.format(type(img)))
raise TypeError("img should be PIL Image. Got {}".format(type(img)))
enhancer = ImageEnhance.Contrast(img)
img = enhancer.enhance(contrast_factor)
@@ -72,7 +72,7 @@ def adjust_contrast(img: Image.Image, contrast_factor: float) -> Image.Image:
@torch.jit.unused
def adjust_saturation(img: Image.Image, saturation_factor: float) -> Image.Image:
if not _is_pil_image(img):
raise TypeError('img should be PIL Image. Got {}'.format(type(img)))
raise TypeError("img should be PIL Image. Got {}".format(type(img)))
enhancer = ImageEnhance.Color(img)
img = enhancer.enhance(saturation_factor)
@@ -81,25 +81,25 @@ def adjust_saturation(img: Image.Image, saturation_factor: float) -> Image.Image
@torch.jit.unused
def adjust_hue(img: Image.Image, hue_factor: float) -> Image.Image:
if not(-0.5 <= hue_factor <= 0.5):
raise ValueError('hue_factor ({}) is not in [-0.5, 0.5].'.format(hue_factor))
if not (-0.5 <= hue_factor <= 0.5):
raise ValueError("hue_factor ({}) is not in [-0.5, 0.5].".format(hue_factor))
if not _is_pil_image(img):
raise TypeError('img should be PIL Image. Got {}'.format(type(img)))
raise TypeError("img should be PIL Image. Got {}".format(type(img)))
input_mode = img.mode
if input_mode in {'L', '1', 'I', 'F'}:
if input_mode in {"L", "1", "I", "F"}:
return img
h, s, v = img.convert('HSV').split()
h, s, v = img.convert("HSV").split()
np_h = np.array(h, dtype=np.uint8)
# uint8 addition take cares of rotation across boundaries
with np.errstate(over='ignore'):
with np.errstate(over="ignore"):
np_h += np.uint8(hue_factor * 255)
h = Image.fromarray(np_h, 'L')
h = Image.fromarray(np_h, "L")
img = Image.merge('HSV', (h, s, v)).convert(input_mode)
img = Image.merge("HSV", (h, s, v)).convert(input_mode)
return img
@@ -111,14 +111,14 @@ def adjust_gamma(
) -> Image.Image:
if not _is_pil_image(img):
raise TypeError('img should be PIL Image. Got {}'.format(type(img)))
raise TypeError("img should be PIL Image. Got {}".format(type(img)))
if gamma < 0:
raise ValueError('Gamma should be a non-negative real number')
raise ValueError("Gamma should be a non-negative real number")
input_mode = img.mode
img = img.convert('RGB')
gamma_map = [(255 + 1 - 1e-3) * gain * pow(ele / 255., gamma) for ele in range(256)] * 3
img = img.convert("RGB")
gamma_map = [(255 + 1 - 1e-3) * gain * pow(ele / 255.0, gamma) for ele in range(256)] * 3
img = img.point(gamma_map) # use PIL's point-function to accelerate this part
img = img.convert(input_mode)
@@ -147,8 +147,9 @@ def pad(
padding = tuple(padding)
if isinstance(padding, tuple) and len(padding) not in [1, 2, 4]:
raise ValueError("Padding must be an int or a 1, 2, or 4 element tuple, not a " +
"{} element tuple".format(len(padding)))
raise ValueError(
"Padding must be an int or a 1, 2, or 4 element tuple, not a " + "{} element tuple".format(len(padding))
)
if isinstance(padding, tuple) and len(padding) == 1:
# Compatibility with `functional_tensor.pad`
@@ -187,7 +188,7 @@ def pad(
pad_left, pad_top, pad_right, pad_bottom = np.maximum(p, 0)
if img.mode == 'P':
if img.mode == "P":
palette = img.getpalette()
img = np.asarray(img)
img = np.pad(img, ((pad_top, pad_bottom), (pad_left, pad_right)), padding_mode)
@@ -216,7 +217,7 @@ def crop(
) -> Image.Image:
if not _is_pil_image(img):
raise TypeError('img should be PIL Image. Got {}'.format(type(img)))
raise TypeError("img should be PIL Image. Got {}".format(type(img)))
return img.crop((left, top, left + width, top + height))
@@ -230,9 +231,9 @@ def resize(
) -> Image.Image:
if not _is_pil_image(img):
raise TypeError('img should be PIL Image. Got {}'.format(type(img)))
raise TypeError("img should be PIL Image. Got {}".format(type(img)))
if not (isinstance(size, int) or (isinstance(size, Sequence) and len(size) in (1, 2))):
raise TypeError('Got inappropriate size arg: {}'.format(size))
raise TypeError("Got inappropriate size arg: {}".format(size))
if isinstance(size, Sequence) and len(size) == 1:
size = size[0]
@@ -280,8 +281,7 @@ def _parse_fill(
fill = tuple([fill] * num_bands)
if isinstance(fill, (list, tuple)):
if len(fill) != num_bands:
msg = ("The number of elements in 'fill' does not match the number of "
"bands of the image ({} != {})")
msg = "The number of elements in 'fill' does not match the number of " "bands of the image ({} != {})"
raise ValueError(msg.format(len(fill), num_bands))
fill = tuple(fill)
@@ -298,7 +298,7 @@ def affine(
) -> Image.Image:
if not _is_pil_image(img):
raise TypeError('img should be PIL Image. Got {}'.format(type(img)))
raise TypeError("img should be PIL Image. Got {}".format(type(img)))
output_size = img.size
opts = _parse_fill(fill, img)
@@ -331,7 +331,7 @@ def perspective(
) -> Image.Image:
if not _is_pil_image(img):
raise TypeError('img should be PIL Image. Got {}'.format(type(img)))
raise TypeError("img should be PIL Image. Got {}".format(type(img)))
opts = _parse_fill(fill, img)
@@ -341,17 +341,17 @@ def perspective(
@torch.jit.unused
def to_grayscale(img: Image.Image, num_output_channels: int) -> Image.Image:
if not _is_pil_image(img):
raise TypeError('img should be PIL Image. Got {}'.format(type(img)))
raise TypeError("img should be PIL Image. Got {}".format(type(img)))
if num_output_channels == 1:
img = img.convert('L')
img = img.convert("L")
elif num_output_channels == 3:
img = img.convert('L')
img = img.convert("L")
np_img = np.array(img, dtype=np.uint8)
np_img = np.dstack([np_img, np_img, np_img])
img = Image.fromarray(np_img, 'RGB')
img = Image.fromarray(np_img, "RGB")
else:
raise ValueError('num_output_channels should be either 1 or 3')
raise ValueError("num_output_channels should be either 1 or 3")
return img
@@ -359,28 +359,28 @@ def to_grayscale(img: Image.Image, num_output_channels: int) -> Image.Image:
@torch.jit.unused
def invert(img: Image.Image) -> Image.Image:
if not _is_pil_image(img):
raise TypeError('img should be PIL Image. Got {}'.format(type(img)))
raise TypeError("img should be PIL Image. Got {}".format(type(img)))
return ImageOps.invert(img)
@torch.jit.unused
def posterize(img: Image.Image, bits: int) -> Image.Image:
if not _is_pil_image(img):
raise TypeError('img should be PIL Image. Got {}'.format(type(img)))
raise TypeError("img should be PIL Image. Got {}".format(type(img)))
return ImageOps.posterize(img, bits)
@torch.jit.unused
def solarize(img: Image.Image, threshold: int) -> Image.Image:
if not _is_pil_image(img):
raise TypeError('img should be PIL Image. Got {}'.format(type(img)))
raise TypeError("img should be PIL Image. Got {}".format(type(img)))
return ImageOps.solarize(img, threshold)
@torch.jit.unused
def adjust_sharpness(img: Image.Image, sharpness_factor: float) -> Image.Image:
if not _is_pil_image(img):
raise TypeError('img should be PIL Image. Got {}'.format(type(img)))
raise TypeError("img should be PIL Image. Got {}".format(type(img)))
enhancer = ImageEnhance.Sharpness(img)
img = enhancer.enhance(sharpness_factor)
......@@ -390,12 +390,12 @@ def adjust_sharpness(img: Image.Image, sharpness_factor: float) -> Image.Image:
@torch.jit.unused
def autocontrast(img: Image.Image) -> Image.Image:
if not _is_pil_image(img):
raise TypeError('img should be PIL Image. Got {}'.format(type(img)))
raise TypeError("img should be PIL Image. Got {}".format(type(img)))
return ImageOps.autocontrast(img)
@torch.jit.unused
def equalize(img: Image.Image) -> Image.Image:
if not _is_pil_image(img):
raise TypeError('img should be PIL Image. Got {}'.format(type(img)))
raise TypeError("img should be PIL Image. Got {}".format(type(img)))
return ImageOps.equalize(img)