Unverified commit 5f0edb97, authored by Philip Meier, committed by GitHub

Add ufmt (usort + black) as code formatter (#4384)



* add ufmt as code formatter

* cleanup

* quote ufmt requirement

* split imports into more groups

* regenerate circleci config

* fix CI

* clarify local testing utils section

* use ufmt pre-commit hook

* split relative imports into local category

* Revert "split relative imports into local category"

This reverts commit f2e224cde2008c56c9347c1f69746d39065cdd51.

* pin black and usort dependencies

* fix local test utils detection

* fix ufmt rev

* add reference utils to local category

* fix usort config

* remove custom categories sorting

* Run pre-commit without fixing flake8

* got a double import in merge
Co-authored-by: Nicolas Hug <nicolashug@fb.com>
parent e45489b1
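
ufmt chains usort (import sorting) and black (code formatting), so nearly all of the churn in the diff below is mechanical: single-quoted strings become double-quoted, long call sites are exploded to one argument per line with a trailing comma, and imports are regrouped. As a key for reading the import hunks, this is the layout usort produces; a minimal sketch whose import lines are copied from the MultiScaleRoIAlign hunk below, with comments added for illustration:

# Import layout after usort (illustrative sketch; lines taken from the diff below).
from typing import Optional, List, Dict, Tuple, Union  # standard library first

import torch  # third-party group: plain imports sort ahead of from-imports here
import torchvision
from torch import nn, Tensor
from torchvision.ops.boxes import box_area

from .roi_align import roi_align  # relative (local) imports come last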
@@ -9,9 +9,10 @@ is implemented
 """
 import warnings
+from typing import Callable, List, Optional

 import torch
 from torch import Tensor
-from typing import Callable, List, Optional


 class Conv2d(torch.nn.Conv2d):
@@ -19,7 +20,9 @@ class Conv2d(torch.nn.Conv2d):
         super().__init__(*args, **kwargs)
         warnings.warn(
             "torchvision.ops.misc.Conv2d is deprecated and will be "
-            "removed in future versions, use torch.nn.Conv2d instead.", FutureWarning)
+            "removed in future versions, use torch.nn.Conv2d instead.",
+            FutureWarning,
+        )


 class ConvTranspose2d(torch.nn.ConvTranspose2d):
@@ -27,7 +30,9 @@ class ConvTranspose2d(torch.nn.ConvTranspose2d):
         super().__init__(*args, **kwargs)
         warnings.warn(
             "torchvision.ops.misc.ConvTranspose2d is deprecated and will be "
-            "removed in future versions, use torch.nn.ConvTranspose2d instead.", FutureWarning)
+            "removed in future versions, use torch.nn.ConvTranspose2d instead.",
+            FutureWarning,
+        )


 class BatchNorm2d(torch.nn.BatchNorm2d):
@@ -35,7 +40,9 @@ class BatchNorm2d(torch.nn.BatchNorm2d):
         super().__init__(*args, **kwargs)
         warnings.warn(
             "torchvision.ops.misc.BatchNorm2d is deprecated and will be "
-            "removed in future versions, use torch.nn.BatchNorm2d instead.", FutureWarning)
+            "removed in future versions, use torch.nn.BatchNorm2d instead.",
+            FutureWarning,
+        )


 interpolate = torch.nn.functional.interpolate
@@ -56,8 +63,7 @@ class FrozenBatchNorm2d(torch.nn.Module):
     ):
         # n=None for backward-compatibility
         if n is not None:
-            warnings.warn("`n` argument is deprecated and has been renamed `num_features`",
-                          DeprecationWarning)
+            warnings.warn("`n` argument is deprecated and has been renamed `num_features`", DeprecationWarning)
             num_features = n
         super(FrozenBatchNorm2d, self).__init__()
         self.eps = eps
@@ -76,13 +82,13 @@ class FrozenBatchNorm2d(torch.nn.Module):
         unexpected_keys: List[str],
         error_msgs: List[str],
     ):
-        num_batches_tracked_key = prefix + 'num_batches_tracked'
+        num_batches_tracked_key = prefix + "num_batches_tracked"
         if num_batches_tracked_key in state_dict:
             del state_dict[num_batches_tracked_key]

         super(FrozenBatchNorm2d, self)._load_from_state_dict(
-            state_dict, prefix, local_metadata, strict,
-            missing_keys, unexpected_keys, error_msgs)
+            state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs
+        )

     def forward(self, x: Tensor) -> Tensor:
         # move reshapes to the beginning
@@ -115,8 +121,18 @@ class ConvNormActivation(torch.nn.Sequential):
     ) -> None:
         if padding is None:
             padding = (kernel_size - 1) // 2 * dilation
-        layers = [torch.nn.Conv2d(in_channels, out_channels, kernel_size, stride, padding,
-                                  dilation=dilation, groups=groups, bias=norm_layer is None)]
+        layers = [
+            torch.nn.Conv2d(
+                in_channels,
+                out_channels,
+                kernel_size,
+                stride,
+                padding,
+                dilation=dilation,
+                groups=groups,
+                bias=norm_layer is None,
+            )
+        ]
         if norm_layer is not None:
             layers.append(norm_layer(out_channels))
         if activation_layer is not None:
...
-import torch
-from torch import nn, Tensor
-import torchvision
-from .roi_align import roi_align
-from torchvision.ops.boxes import box_area
-from typing import Optional, List, Dict, Tuple, Union
+from typing import Optional, List, Dict, Tuple, Union
+
+import torch
+import torchvision
+from torch import nn, Tensor
+from torchvision.ops.boxes import box_area
+
+from .roi_align import roi_align

 # copying result_idx_in_level to a specific index in result[]
@@ -16,15 +16,17 @@ from typing import Optional, List, Dict, Tuple, Union
 def _onnx_merge_levels(levels: Tensor, unmerged_results: List[Tensor]) -> Tensor:
     first_result = unmerged_results[0]
     dtype, device = first_result.dtype, first_result.device
-    res = torch.zeros((levels.size(0), first_result.size(1),
-                       first_result.size(2), first_result.size(3)),
-                      dtype=dtype, device=device)
+    res = torch.zeros(
+        (levels.size(0), first_result.size(1), first_result.size(2), first_result.size(3)), dtype=dtype, device=device
+    )
     for level in range(len(unmerged_results)):
         index = torch.where(levels == level)[0].view(-1, 1, 1, 1)
-        index = index.expand(index.size(0),
-                             unmerged_results[level].size(1),
-                             unmerged_results[level].size(2),
-                             unmerged_results[level].size(3))
+        index = index.expand(
+            index.size(0),
+            unmerged_results[level].size(1),
+            unmerged_results[level].size(2),
+            unmerged_results[level].size(3),
+        )
         res = res.scatter(0, index, unmerged_results[level])
     return res
@@ -116,10 +118,7 @@ class MultiScaleRoIAlign(nn.Module):
     """

-    __annotations__ = {
-        'scales': Optional[List[float]],
-        'map_levels': Optional[LevelMapper]
-    }
+    __annotations__ = {"scales": Optional[List[float]], "map_levels": Optional[LevelMapper]}

     def __init__(
         self,
@@ -224,10 +223,11 @@ class MultiScaleRoIAlign(nn.Module):
         if num_levels == 1:
             return roi_align(
-                x_filtered[0], rois,
+                x_filtered[0],
+                rois,
                 output_size=self.output_size,
                 spatial_scale=scales[0],
-                sampling_ratio=self.sampling_ratio
+                sampling_ratio=self.sampling_ratio,
             )

         mapper = self.map_levels
@@ -240,7 +240,11 @@ class MultiScaleRoIAlign(nn.Module):
         dtype, device = x_filtered[0].dtype, x_filtered[0].device
         result = torch.zeros(
-            (num_rois, num_channels,) + self.output_size,
+            (
+                num_rois,
+                num_channels,
+            )
+            + self.output_size,
             dtype=dtype,
             device=device,
         )
@@ -251,9 +255,12 @@ class MultiScaleRoIAlign(nn.Module):
             rois_per_level = rois[idx_in_level]

             result_idx_in_level = roi_align(
-                per_level_feature, rois_per_level,
+                per_level_feature,
+                rois_per_level,
                 output_size=self.output_size,
-                spatial_scale=scale, sampling_ratio=self.sampling_ratio)
+                spatial_scale=scale,
+                sampling_ratio=self.sampling_ratio,
+            )

             if torchvision._is_tracing():
                 tracing_results.append(result_idx_in_level.to(dtype))
@@ -273,5 +280,7 @@ class MultiScaleRoIAlign(nn.Module):
         return result

     def __repr__(self) -> str:
-        return (f"{self.__class__.__name__}(featmap_names={self.featmap_names}, "
-                f"output_size={self.output_size}, sampling_ratio={self.sampling_ratio})")
+        return (
+            f"{self.__class__.__name__}(featmap_names={self.featmap_names}, "
+            f"output_size={self.output_size}, sampling_ratio={self.sampling_ratio})"
+        )
...
 import torch
 from torch import nn, Tensor
 from torch.nn.modules.utils import _pair
 from torchvision.extension import _assert_has_ops

 from ._utils import convert_boxes_to_roi_format, check_roi_boxes_shape
@@ -49,10 +48,9 @@ def ps_roi_align(
     output_size = _pair(output_size)
     if not isinstance(rois, torch.Tensor):
         rois = convert_boxes_to_roi_format(rois)
-    output, _ = torch.ops.torchvision.ps_roi_align(input, rois, spatial_scale,
-                                                   output_size[0],
-                                                   output_size[1],
-                                                   sampling_ratio)
+    output, _ = torch.ops.torchvision.ps_roi_align(
+        input, rois, spatial_scale, output_size[0], output_size[1], sampling_ratio
+    )
     return output
@@ -60,6 +58,7 @@ class PSRoIAlign(nn.Module):
     """
     See :func:`ps_roi_align`.
     """
+
     def __init__(
         self,
         output_size: int,
@@ -72,13 +71,12 @@ class PSRoIAlign(nn.Module):
         self.sampling_ratio = sampling_ratio

     def forward(self, input: Tensor, rois: Tensor) -> Tensor:
-        return ps_roi_align(input, rois, self.output_size, self.spatial_scale,
-                            self.sampling_ratio)
+        return ps_roi_align(input, rois, self.output_size, self.spatial_scale, self.sampling_ratio)

     def __repr__(self) -> str:
-        tmpstr = self.__class__.__name__ + '('
-        tmpstr += 'output_size=' + str(self.output_size)
-        tmpstr += ', spatial_scale=' + str(self.spatial_scale)
-        tmpstr += ', sampling_ratio=' + str(self.sampling_ratio)
-        tmpstr += ')'
+        tmpstr = self.__class__.__name__ + "("
+        tmpstr += "output_size=" + str(self.output_size)
+        tmpstr += ", spatial_scale=" + str(self.spatial_scale)
+        tmpstr += ", sampling_ratio=" + str(self.sampling_ratio)
+        tmpstr += ")"
         return tmpstr
...
 import torch
 from torch import nn, Tensor
 from torch.nn.modules.utils import _pair
 from torchvision.extension import _assert_has_ops

 from ._utils import convert_boxes_to_roi_format, check_roi_boxes_shape
@@ -43,9 +42,7 @@ def ps_roi_pool(
     output_size = _pair(output_size)
     if not isinstance(rois, torch.Tensor):
         rois = convert_boxes_to_roi_format(rois)
-    output, _ = torch.ops.torchvision.ps_roi_pool(input, rois, spatial_scale,
-                                                  output_size[0],
-                                                  output_size[1])
+    output, _ = torch.ops.torchvision.ps_roi_pool(input, rois, spatial_scale, output_size[0], output_size[1])
     return output
@@ -53,6 +50,7 @@ class PSRoIPool(nn.Module):
     """
     See :func:`ps_roi_pool`.
     """
+
     def __init__(self, output_size: int, spatial_scale: float):
         super(PSRoIPool, self).__init__()
         self.output_size = output_size
@@ -62,8 +60,8 @@ class PSRoIPool(nn.Module):
         return ps_roi_pool(input, rois, self.output_size, self.spatial_scale)

     def __repr__(self) -> str:
-        tmpstr = self.__class__.__name__ + '('
-        tmpstr += 'output_size=' + str(self.output_size)
-        tmpstr += ', spatial_scale=' + str(self.spatial_scale)
-        tmpstr += ')'
+        tmpstr = self.__class__.__name__ + "("
+        tmpstr += "output_size=" + str(self.output_size)
+        tmpstr += ", spatial_scale=" + str(self.spatial_scale)
+        tmpstr += ")"
         return tmpstr
...
@@ -2,8 +2,8 @@ from typing import List, Union
 import torch
 from torch import nn, Tensor
-from torch.nn.modules.utils import _pair
 from torch.jit.annotations import BroadcastingList2
+from torch.nn.modules.utils import _pair
 from torchvision.extension import _assert_has_ops

 from ._utils import convert_boxes_to_roi_format, check_roi_boxes_shape
@@ -55,15 +55,16 @@ def roi_align(
     output_size = _pair(output_size)
     if not isinstance(rois, torch.Tensor):
         rois = convert_boxes_to_roi_format(rois)
-    return torch.ops.torchvision.roi_align(input, rois, spatial_scale,
-                                           output_size[0], output_size[1],
-                                           sampling_ratio, aligned)
+    return torch.ops.torchvision.roi_align(
+        input, rois, spatial_scale, output_size[0], output_size[1], sampling_ratio, aligned
+    )


 class RoIAlign(nn.Module):
     """
     See :func:`roi_align`.
     """
+
     def __init__(
         self,
         output_size: BroadcastingList2[int],
@@ -81,10 +82,10 @@ class RoIAlign(nn.Module):
         return roi_align(input, rois, self.output_size, self.spatial_scale, self.sampling_ratio, self.aligned)

     def __repr__(self) -> str:
-        tmpstr = self.__class__.__name__ + '('
-        tmpstr += 'output_size=' + str(self.output_size)
-        tmpstr += ', spatial_scale=' + str(self.spatial_scale)
-        tmpstr += ', sampling_ratio=' + str(self.sampling_ratio)
-        tmpstr += ', aligned=' + str(self.aligned)
-        tmpstr += ')'
+        tmpstr = self.__class__.__name__ + "("
+        tmpstr += "output_size=" + str(self.output_size)
+        tmpstr += ", spatial_scale=" + str(self.spatial_scale)
+        tmpstr += ", sampling_ratio=" + str(self.sampling_ratio)
+        tmpstr += ", aligned=" + str(self.aligned)
+        tmpstr += ")"
         return tmpstr
...
@@ -2,8 +2,8 @@ from typing import List, Union
 import torch
 from torch import nn, Tensor
-from torch.nn.modules.utils import _pair
 from torch.jit.annotations import BroadcastingList2
+from torch.nn.modules.utils import _pair
 from torchvision.extension import _assert_has_ops

 from ._utils import convert_boxes_to_roi_format, check_roi_boxes_shape
@@ -44,8 +44,7 @@ def roi_pool(
     output_size = _pair(output_size)
     if not isinstance(rois, torch.Tensor):
         rois = convert_boxes_to_roi_format(rois)
-    output, _ = torch.ops.torchvision.roi_pool(input, rois, spatial_scale,
-                                               output_size[0], output_size[1])
+    output, _ = torch.ops.torchvision.roi_pool(input, rois, spatial_scale, output_size[0], output_size[1])
     return output
@@ -53,6 +52,7 @@ class RoIPool(nn.Module):
     """
     See :func:`roi_pool`.
     """
+
     def __init__(self, output_size: BroadcastingList2[int], spatial_scale: float):
         super(RoIPool, self).__init__()
         self.output_size = output_size
@@ -62,8 +62,8 @@ class RoIPool(nn.Module):
         return roi_pool(input, rois, self.output_size, self.spatial_scale)

     def __repr__(self) -> str:
-        tmpstr = self.__class__.__name__ + '('
-        tmpstr += 'output_size=' + str(self.output_size)
-        tmpstr += ', spatial_scale=' + str(self.spatial_scale)
-        tmpstr += ')'
+        tmpstr = self.__class__.__name__ + "("
+        tmpstr += "output_size=" + str(self.output_size)
+        tmpstr += ", spatial_scale=" + str(self.spatial_scale)
+        tmpstr += ")"
         return tmpstr
...
@@ -38,13 +38,14 @@ def stochastic_depth(input: Tensor, p: float, mode: str, training: bool = True)
     return input * noise


-torch.fx.wrap('stochastic_depth')
+torch.fx.wrap("stochastic_depth")


 class StochasticDepth(nn.Module):
     """
     See :func:`stochastic_depth`.
     """
+
     def __init__(self, p: float, mode: str) -> None:
         super().__init__()
         self.p = p
@@ -54,8 +55,8 @@ class StochasticDepth(nn.Module):
         return stochastic_depth(input, self.p, self.mode, self.training)

     def __repr__(self) -> str:
-        tmpstr = self.__class__.__name__ + '('
-        tmpstr += 'p=' + str(self.p)
-        tmpstr += ', mode=' + str(self.mode)
-        tmpstr += ')'
+        tmpstr = self.__class__.__name__ + "("
+        tmpstr += "p=" + str(self.p)
+        tmpstr += ", mode=" + str(self.mode)
+        tmpstr += ")"
         return tmpstr
...
@@ -8,9 +8,9 @@ except (ModuleNotFoundError, TypeError) as error:
     ) from error

+from ._home import home
+
 from . import decoder, utils

 # Load this last, since some parts depend on the above being loaded first
 from ._api import register, _list as list, info, load
 from ._folder import from_data_folder, from_image_folder
-from ._home import home
...
@@ -3,11 +3,11 @@ from typing import Any, Callable, Dict, List, Optional
 import torch
 from torch.utils.data import IterDataPipe

 from torchvision.prototype.datasets import home
 from torchvision.prototype.datasets.decoder import pil
 from torchvision.prototype.datasets.utils import Dataset, DatasetInfo
 from torchvision.prototype.datasets.utils._internal import add_suggestion

 from . import _builtin

 DATASETS: Dict[str, Dataset] = {}
@@ -18,12 +18,7 @@ def register(dataset: Dataset) -> None:
 for name, obj in _builtin.__dict__.items():
-    if (
-        not name.startswith("_")
-        and isinstance(obj, type)
-        and issubclass(obj, Dataset)
-        and obj is not Dataset
-    ):
+    if not name.startswith("_") and isinstance(obj, type) and issubclass(obj, Dataset) and obj is not Dataset:
         register(obj())
...
 import io
 import pathlib
-from typing import Any, Callable, Dict, List, Optional, Tuple, Union
 import re
+from typing import Any, Callable, Dict, List, Optional, Tuple, Union

 import numpy as np
 import torch
 from torch.utils.data import IterDataPipe
 from torch.utils.data.datapipes.iter import (
@@ -13,7 +12,6 @@ from torch.utils.data.datapipes.iter import (
     Shuffler,
     Filter,
 )
-
 from torchdata.datapipes.iter import KeyZipper
 from torchvision.prototype.datasets.utils import (
     Dataset,
...
@@ -8,7 +8,6 @@ from typing import Union, Tuple, List, Dict, Any
 import torch
 from torch.utils.data import IterDataPipe
 from torch.utils.data.datapipes.iter import FileLister, FileLoader, Mapper, Shuffler, Filter
-
 from torchvision.prototype.datasets.decoder import pil
 from torchvision.prototype.datasets.utils._internal import INFINITE_BUFFER_SIZE
...
@@ -2,7 +2,6 @@ import io
 import PIL.Image
 import torch
-
 from torchvision.transforms.functional import pil_to_tensor

 __all__ = ["pil"]
...
@@ -19,11 +19,11 @@ from typing import (
 import torch
 from torch.utils.data import IterDataPipe

 from torchvision.prototype.datasets.utils._internal import (
     add_suggestion,
     sequence_to_str,
 )

 from ._resource import OnlineResource
@@ -64,9 +64,7 @@ class DatasetConfig(Mapping):
         try:
             return self[name]
         except KeyError as error:
-            raise AttributeError(
-                f"'{type(self).__name__}' object has no attribute '{name}'"
-            ) from error
+            raise AttributeError(f"'{type(self).__name__}' object has no attribute '{name}'") from error

     def __setitem__(self, key: Any, value: Any) -> NoReturn:
         raise RuntimeError(f"'{type(self).__name__}' object is immutable")
@@ -133,9 +131,7 @@ class DatasetInfo:
     @property
     def default_config(self) -> DatasetConfig:
-        return DatasetConfig(
-            {name: valid_args[0] for name, valid_args in self._valid_options.items()}
-        )
+        return DatasetConfig({name: valid_args[0] for name, valid_args in self._valid_options.items()})

     def make_config(self, **options: Any) -> DatasetConfig:
         for name, arg in options.items():
@@ -167,12 +163,7 @@ class DatasetInfo:
             value = getattr(self, key)
             if value is not None:
                 items.append((key, value))
-        items.extend(
-            sorted(
-                (key, sequence_to_str(value))
-                for key, value in self._valid_options.items()
-            )
-        )
+        items.extend(sorted((key, sequence_to_str(value)) for key, value in self._valid_options.items()))
         return make_repr(type(self).__name__, items)
@@ -214,7 +205,5 @@ class Dataset(abc.ABC):
         if not config:
             config = self.info.default_config

-        resource_dps = [
-            resource.to_datapipe(root) for resource in self.resources(config)
-        ]
+        resource_dps = [resource.to_datapipe(root) for resource in self.resources(config)]
         return self._make_datapipe(resource_dps, config=config, decoder=decoder)
...
@@ -5,13 +5,7 @@ import pathlib
 from typing import Collection, Sequence, Callable, Union, Any

-__all__ = [
-    "INFINITE_BUFFER_SIZE",
-    "sequence_to_str",
-    "add_suggestion",
-    "create_categories_file",
-    "read_mat"
-]
+__all__ = ["INFINITE_BUFFER_SIZE", "sequence_to_str", "add_suggestion", "create_categories_file", "read_mat"]

 # pseudo-infinite until a true infinite buffer is supported by all datapipes
 INFINITE_BUFFER_SIZE = 1_000_000_000
@@ -21,10 +15,7 @@ def sequence_to_str(seq: Sequence, separate_last: str = "") -> str:
     if len(seq) == 1:
         return f"'{seq[0]}'"

-    return (
-        f"""'{"', '".join([str(item) for item in seq[:-1]])}', """
-        f"""{separate_last}'{seq[-1]}'."""
-    )
+    return f"""'{"', '".join([str(item) for item in seq[:-1]])}', """ f"""{separate_last}'{seq[-1]}'."""


 def add_suggestion(
@@ -32,9 +23,7 @@ def add_suggestion(
     *,
     word: str,
     possibilities: Collection[str],
-    close_match_hint: Callable[
-        [str], str
-    ] = lambda close_match: f"Did you mean '{close_match}'?",
+    close_match_hint: Callable[[str], str] = lambda close_match: f"Did you mean '{close_match}'?",
     alternative_hint: Callable[
         [Sequence[str]], str
     ] = lambda possibilities: f"Can be {sequence_to_str(possibilities, separate_last='or ')}.",
@@ -42,17 +31,11 @@ def add_suggestion(
     if not isinstance(possibilities, collections.abc.Sequence):
         possibilities = sorted(possibilities)
     suggestions = difflib.get_close_matches(word, possibilities, 1)
-    hint = (
-        close_match_hint(suggestions[0])
-        if suggestions
-        else alternative_hint(possibilities)
-    )
+    hint = close_match_hint(suggestions[0]) if suggestions else alternative_hint(possibilities)
     return f"{msg.strip()} {hint}"


-def create_categories_file(
-    root: Union[str, pathlib.Path], name: str, categories: Sequence[str]
-) -> None:
+def create_categories_file(root: Union[str, pathlib.Path], name: str, categories: Sequence[str]) -> None:
     with open(pathlib.Path(root) / f"{name}.categories", "w") as fh:
         fh.write("\n".join(categories) + "\n")
@@ -61,8 +44,6 @@ def read_mat(buffer: io.IOBase, **kwargs: Any) -> Any:
     try:
         import scipy.io as sio
     except ImportError as error:
-        raise ModuleNotFoundError(
-            "Package `scipy` is required to be installed to read .mat files."
-        ) from error
+        raise ModuleNotFoundError("Package `scipy` is required to be installed to read .mat files.") from error

     return sio.loadmat(buffer, **kwargs)
...
@@ -13,9 +13,7 @@ def compute_sha256(_) -> str:
 class LocalResource:
-    def __init__(
-        self, path: Union[str, pathlib.Path], *, sha256: Optional[str] = None
-    ) -> None:
+    def __init__(self, path: Union[str, pathlib.Path], *, sha256: Optional[str] = None) -> None:
         self.path = pathlib.Path(path).expanduser().resolve()
         self.file_name = self.path.name
         self.sha256 = sha256 or compute_sha256(self.path)
@@ -39,9 +37,7 @@ class OnlineResource:
 # TODO: add support for mirrors
 # TODO: add support for http -> https
 class HttpResource(OnlineResource):
-    def __init__(
-        self, url: str, *, sha256: str, file_name: Optional[str] = None
-    ) -> None:
+    def __init__(self, url: str, *, sha256: str, file_name: Optional[str] = None) -> None:
         if not file_name:
             file_name = os.path.basename(urlparse(url).path)
         super().__init__(url, sha256=sha256, file_name=file_name)
...
-import torch
 import warnings
+
+import torch

-warnings.warn(
-    "The _functional_video module is deprecated. Please use the functional module instead."
-)
+warnings.warn("The _functional_video module is deprecated. Please use the functional module instead.")


 def _is_tensor_video_clip(clip):
@@ -23,14 +22,12 @@ def crop(clip, i, j, h, w):
         clip (torch.tensor): Video clip to be cropped. Size is (C, T, H, W)
     """
     assert len(clip.size()) == 4, "clip should be a 4D tensor"
-    return clip[..., i:i + h, j:j + w]
+    return clip[..., i : i + h, j : j + w]


 def resize(clip, target_size, interpolation_mode):
     assert len(target_size) == 2, "target size should be tuple (height, width)"
-    return torch.nn.functional.interpolate(
-        clip, size=target_size, mode=interpolation_mode, align_corners=False
-    )
+    return torch.nn.functional.interpolate(clip, size=target_size, mode=interpolation_mode, align_corners=False)


 def resized_crop(clip, i, j, h, w, size, interpolation_mode="bilinear"):
...
@@ -22,9 +22,7 @@ __all__ = [
 ]

-warnings.warn(
-    "The _transforms_video module is deprecated. Please use the transforms module instead."
-)
+warnings.warn("The _transforms_video module is deprecated. Please use the transforms module instead.")


 class RandomCropVideo(RandomCrop):
@@ -46,7 +44,7 @@ class RandomCropVideo(RandomCrop):
         return F.crop(clip, i, j, h, w)

     def __repr__(self):
-        return self.__class__.__name__ + '(size={0})'.format(self.size)
+        return self.__class__.__name__ + "(size={0})".format(self.size)


 class RandomResizedCropVideo(RandomResizedCrop):
@@ -79,10 +77,9 @@ class RandomResizedCropVideo(RandomResizedCrop):
         return F.resized_crop(clip, i, j, h, w, self.size, self.interpolation_mode)

     def __repr__(self):
-        return self.__class__.__name__ + \
-            '(size={0}, interpolation_mode={1}, scale={2}, ratio={3})'.format(
-                self.size, self.interpolation_mode, self.scale, self.ratio
-            )
+        return self.__class__.__name__ + "(size={0}, interpolation_mode={1}, scale={2}, ratio={3})".format(
+            self.size, self.interpolation_mode, self.scale, self.ratio
+        )


 class CenterCropVideo(object):
@@ -103,7 +100,7 @@ class CenterCropVideo(object):
         return F.center_crop(clip, self.crop_size)

     def __repr__(self):
-        return self.__class__.__name__ + '(crop_size={0})'.format(self.crop_size)
+        return self.__class__.__name__ + "(crop_size={0})".format(self.crop_size)


 class NormalizeVideo(object):
@@ -128,8 +125,7 @@ class NormalizeVideo(object):
         return F.normalize(clip, self.mean, self.std, self.inplace)

     def __repr__(self):
-        return self.__class__.__name__ + '(mean={0}, std={1}, inplace={2})'.format(
-            self.mean, self.std, self.inplace)
+        return self.__class__.__name__ + "(mean={0}, std={1}, inplace={2})".format(self.mean, self.std, self.inplace)


 class ToTensorVideo(object):
...
 import math
-import torch
 from enum import Enum
-from torch import Tensor
 from typing import List, Tuple, Optional, Dict
+
+import torch
+from torch import Tensor
+
 from . import functional as F, InterpolationMode

 __all__ = ["AutoAugmentPolicy", "AutoAugment", "RandAugment", "TrivialAugmentWide"]


-def _apply_op(img: Tensor, op_name: str, magnitude: float,
-              interpolation: InterpolationMode, fill: Optional[List[float]]):
+def _apply_op(
+    img: Tensor, op_name: str, magnitude: float, interpolation: InterpolationMode, fill: Optional[List[float]]
+):
     if op_name == "ShearX":
-        img = F.affine(img, angle=0.0, translate=[0, 0], scale=1.0, shear=[math.degrees(magnitude), 0.0],
-                       interpolation=interpolation, fill=fill)
+        img = F.affine(
+            img,
+            angle=0.0,
+            translate=[0, 0],
+            scale=1.0,
+            shear=[math.degrees(magnitude), 0.0],
+            interpolation=interpolation,
+            fill=fill,
+        )
     elif op_name == "ShearY":
-        img = F.affine(img, angle=0.0, translate=[0, 0], scale=1.0, shear=[0.0, math.degrees(magnitude)],
-                       interpolation=interpolation, fill=fill)
+        img = F.affine(
+            img,
+            angle=0.0,
+            translate=[0, 0],
+            scale=1.0,
+            shear=[0.0, math.degrees(magnitude)],
+            interpolation=interpolation,
+            fill=fill,
+        )
     elif op_name == "TranslateX":
-        img = F.affine(img, angle=0.0, translate=[int(magnitude), 0], scale=1.0,
-                       interpolation=interpolation, shear=[0.0, 0.0], fill=fill)
+        img = F.affine(
+            img,
+            angle=0.0,
+            translate=[int(magnitude), 0],
+            scale=1.0,
+            interpolation=interpolation,
+            shear=[0.0, 0.0],
+            fill=fill,
+        )
     elif op_name == "TranslateY":
-        img = F.affine(img, angle=0.0, translate=[0, int(magnitude)], scale=1.0,
-                       interpolation=interpolation, shear=[0.0, 0.0], fill=fill)
+        img = F.affine(
+            img,
+            angle=0.0,
+            translate=[0, int(magnitude)],
+            scale=1.0,
+            interpolation=interpolation,
+            shear=[0.0, 0.0],
+            fill=fill,
+        )
     elif op_name == "Rotate":
         img = F.rotate(img, magnitude, interpolation=interpolation, fill=fill)
     elif op_name == "Brightness":
@@ -55,6 +84,7 @@ class AutoAugmentPolicy(Enum):
     """AutoAugment policies learned on different datasets.
     Available policies are IMAGENET, CIFAR10 and SVHN.
     """
+
     IMAGENET = "imagenet"
     CIFAR10 = "cifar10"
     SVHN = "svhn"
@@ -82,7 +112,7 @@ class AutoAugment(torch.nn.Module):
         self,
         policy: AutoAugmentPolicy = AutoAugmentPolicy.IMAGENET,
         interpolation: InterpolationMode = InterpolationMode.NEAREST,
-        fill: Optional[List[float]] = None
+        fill: Optional[List[float]] = None,
     ) -> None:
         super().__init__()
         self.policy = policy
@@ -91,8 +121,7 @@ class AutoAugment(torch.nn.Module):
         self.policies = self._get_policies(policy)

     def _get_policies(
-        self,
-        policy: AutoAugmentPolicy
+        self, policy: AutoAugmentPolicy
     ) -> List[Tuple[Tuple[str, float, Optional[int]], Tuple[str, float, Optional[int]]]]:
         if policy == AutoAugmentPolicy.IMAGENET:
             return [
@@ -241,7 +270,7 @@ class AutoAugment(torch.nn.Module):
         return img

     def __repr__(self) -> str:
-        return self.__class__.__name__ + '(policy={}, fill={})'.format(self.policy, self.fill)
+        return self.__class__.__name__ + "(policy={}, fill={})".format(self.policy, self.fill)


 class RandAugment(torch.nn.Module):
@@ -261,11 +290,16 @@ class RandAugment(torch.nn.Module):
             If input is Tensor, only ``InterpolationMode.NEAREST``, ``InterpolationMode.BILINEAR`` are supported.
         fill (sequence or number, optional): Pixel fill value for the area outside the transformed
             image. If given a number, the value is used for all bands respectively.
     """

-    def __init__(self, num_ops: int = 2, magnitude: int = 9, num_magnitude_bins: int = 31,
-                 interpolation: InterpolationMode = InterpolationMode.NEAREST,
-                 fill: Optional[List[float]] = None) -> None:
+    def __init__(
+        self,
+        num_ops: int = 2,
+        magnitude: int = 9,
+        num_magnitude_bins: int = 31,
+        interpolation: InterpolationMode = InterpolationMode.NEAREST,
+        fill: Optional[List[float]] = None,
+    ) -> None:
         super().__init__()
         self.num_ops = num_ops
         self.magnitude = magnitude
@@ -319,13 +353,13 @@ class RandAugment(torch.nn.Module):
         return img

     def __repr__(self) -> str:
-        s = self.__class__.__name__ + '('
-        s += 'num_ops={num_ops}'
-        s += ', magnitude={magnitude}'
-        s += ', num_magnitude_bins={num_magnitude_bins}'
-        s += ', interpolation={interpolation}'
-        s += ', fill={fill}'
-        s += ')'
+        s = self.__class__.__name__ + "("
+        s += "num_ops={num_ops}"
+        s += ", magnitude={magnitude}"
+        s += ", num_magnitude_bins={num_magnitude_bins}"
+        s += ", interpolation={interpolation}"
+        s += ", fill={fill}"
+        s += ")"
         return s.format(**self.__dict__)
@@ -343,10 +377,14 @@ class TrivialAugmentWide(torch.nn.Module):
             If input is Tensor, only ``InterpolationMode.NEAREST``, ``InterpolationMode.BILINEAR`` are supported.
         fill (sequence or number, optional): Pixel fill value for the area outside the transformed
             image. If given a number, the value is used for all bands respectively.
     """

-    def __init__(self, num_magnitude_bins: int = 31, interpolation: InterpolationMode = InterpolationMode.NEAREST,
-                 fill: Optional[List[float]] = None) -> None:
+    def __init__(
+        self,
+        num_magnitude_bins: int = 31,
+        interpolation: InterpolationMode = InterpolationMode.NEAREST,
+        fill: Optional[List[float]] = None,
+    ) -> None:
         super().__init__()
         self.num_magnitude_bins = num_magnitude_bins
         self.interpolation = interpolation
@@ -389,17 +427,20 @@ class TrivialAugmentWide(torch.nn.Module):
         op_index = int(torch.randint(len(op_meta), (1,)).item())
         op_name = list(op_meta.keys())[op_index]
         magnitudes, signed = op_meta[op_name]
-        magnitude = float(magnitudes[torch.randint(len(magnitudes), (1,), dtype=torch.long)].item()) \
-            if magnitudes.ndim > 0 else 0.0
+        magnitude = (
+            float(magnitudes[torch.randint(len(magnitudes), (1,), dtype=torch.long)].item())
+            if magnitudes.ndim > 0
+            else 0.0
+        )
         if signed and torch.randint(2, (1,)):
             magnitude *= -1.0

         return _apply_op(img, op_name, magnitude, interpolation=self.interpolation, fill=fill)

     def __repr__(self) -> str:
-        s = self.__class__.__name__ + '('
-        s += 'num_magnitude_bins={num_magnitude_bins}'
-        s += ', interpolation={interpolation}'
-        s += ', fill={fill}'
-        s += ')'
+        s = self.__class__.__name__ + "("
+        s += "num_magnitude_bins={num_magnitude_bins}"
+        s += ", interpolation={interpolation}"
+        s += ", fill={fill}"
+        s += ")"
         return s.format(**self.__dict__)
...
...@@ -2,13 +2,12 @@ import math ...@@ -2,13 +2,12 @@ import math
import numbers import numbers
import warnings import warnings
from enum import Enum from enum import Enum
from typing import List, Tuple, Any, Optional
import numpy as np import numpy as np
from PIL import Image
import torch import torch
from PIL import Image
from torch import Tensor from torch import Tensor
from typing import List, Tuple, Any, Optional
try: try:
import accimage import accimage
...@@ -23,6 +22,7 @@ class InterpolationMode(Enum): ...@@ -23,6 +22,7 @@ class InterpolationMode(Enum):
"""Interpolation modes """Interpolation modes
Available interpolation methods are ``nearest``, ``bilinear``, ``bicubic``, ``box``, ``hamming``, and ``lanczos``. Available interpolation methods are ``nearest``, ``bilinear``, ``bicubic``, ``box``, ``hamming``, and ``lanczos``.
""" """
NEAREST = "nearest" NEAREST = "nearest"
BILINEAR = "bilinear" BILINEAR = "bilinear"
BICUBIC = "bicubic" BICUBIC = "bicubic"
...@@ -110,11 +110,11 @@ def to_tensor(pic): ...@@ -110,11 +110,11 @@ def to_tensor(pic):
Returns: Returns:
Tensor: Converted image. Tensor: Converted image.
""" """
if not(F_pil._is_pil_image(pic) or _is_numpy(pic)): if not (F_pil._is_pil_image(pic) or _is_numpy(pic)):
raise TypeError('pic should be PIL Image or ndarray. Got {}'.format(type(pic))) raise TypeError("pic should be PIL Image or ndarray. Got {}".format(type(pic)))
if _is_numpy(pic) and not _is_numpy_image(pic): if _is_numpy(pic) and not _is_numpy_image(pic):
raise ValueError('pic should be 2/3 dimensional. Got {} dimensions.'.format(pic.ndim)) raise ValueError("pic should be 2/3 dimensional. Got {} dimensions.".format(pic.ndim))
default_float_dtype = torch.get_default_dtype() default_float_dtype = torch.get_default_dtype()
...@@ -136,12 +136,10 @@ def to_tensor(pic): ...@@ -136,12 +136,10 @@ def to_tensor(pic):
return torch.from_numpy(nppic).to(dtype=default_float_dtype) return torch.from_numpy(nppic).to(dtype=default_float_dtype)
# handle PIL Image # handle PIL Image
mode_to_nptype = {'I': np.int32, 'I;16': np.int16, 'F': np.float32} mode_to_nptype = {"I": np.int32, "I;16": np.int16, "F": np.float32}
img = torch.from_numpy( img = torch.from_numpy(np.array(pic, mode_to_nptype.get(pic.mode, np.uint8), copy=True))
np.array(pic, mode_to_nptype.get(pic.mode, np.uint8), copy=True)
)
if pic.mode == '1': if pic.mode == "1":
img = 255 * img img = 255 * img
img = img.view(pic.size[1], pic.size[0], len(pic.getbands())) img = img.view(pic.size[1], pic.size[0], len(pic.getbands()))
# put it from HWC to CHW format # put it from HWC to CHW format
...@@ -165,7 +163,7 @@ def pil_to_tensor(pic): ...@@ -165,7 +163,7 @@ def pil_to_tensor(pic):
Tensor: Converted image. Tensor: Converted image.
""" """
if not F_pil._is_pil_image(pic): if not F_pil._is_pil_image(pic):
raise TypeError('pic should be PIL Image. Got {}'.format(type(pic))) raise TypeError("pic should be PIL Image. Got {}".format(type(pic)))
if accimage is not None and isinstance(pic, accimage.Image): if accimage is not None and isinstance(pic, accimage.Image):
# accimage format is always uint8 internally, so always return uint8 here # accimage format is always uint8 internally, so always return uint8 here
...@@ -204,7 +202,7 @@ def convert_image_dtype(image: torch.Tensor, dtype: torch.dtype = torch.float) - ...@@ -204,7 +202,7 @@ def convert_image_dtype(image: torch.Tensor, dtype: torch.dtype = torch.float) -
of the integer ``dtype``. of the integer ``dtype``.
""" """
if not isinstance(image, torch.Tensor): if not isinstance(image, torch.Tensor):
raise TypeError('Input img should be Tensor Image') raise TypeError("Input img should be Tensor Image")
return F_t.convert_image_dtype(image, dtype) return F_t.convert_image_dtype(image, dtype)
...@@ -223,12 +221,12 @@ def to_pil_image(pic, mode=None): ...@@ -223,12 +221,12 @@ def to_pil_image(pic, mode=None):
Returns: Returns:
PIL Image: Image converted to PIL Image. PIL Image: Image converted to PIL Image.
""" """
if not(isinstance(pic, torch.Tensor) or isinstance(pic, np.ndarray)): if not (isinstance(pic, torch.Tensor) or isinstance(pic, np.ndarray)):
raise TypeError('pic should be Tensor or ndarray. Got {}.'.format(type(pic))) raise TypeError("pic should be Tensor or ndarray. Got {}.".format(type(pic)))
elif isinstance(pic, torch.Tensor): elif isinstance(pic, torch.Tensor):
if pic.ndimension() not in {2, 3}: if pic.ndimension() not in {2, 3}:
raise ValueError('pic should be 2/3 dimensional. Got {} dimensions.'.format(pic.ndimension())) raise ValueError("pic should be 2/3 dimensional. Got {} dimensions.".format(pic.ndimension()))
elif pic.ndimension() == 2: elif pic.ndimension() == 2:
# if 2D image, add channel dimension (CHW) # if 2D image, add channel dimension (CHW)
...@@ -236,11 +234,11 @@ def to_pil_image(pic, mode=None): ...@@ -236,11 +234,11 @@ def to_pil_image(pic, mode=None):
# check number of channels # check number of channels
if pic.shape[-3] > 4: if pic.shape[-3] > 4:
raise ValueError('pic should not have > 4 channels. Got {} channels.'.format(pic.shape[-3])) raise ValueError("pic should not have > 4 channels. Got {} channels.".format(pic.shape[-3]))
elif isinstance(pic, np.ndarray): elif isinstance(pic, np.ndarray):
if pic.ndim not in {2, 3}: if pic.ndim not in {2, 3}:
raise ValueError('pic should be 2/3 dimensional. Got {} dimensions.'.format(pic.ndim)) raise ValueError("pic should be 2/3 dimensional. Got {} dimensions.".format(pic.ndim))
elif pic.ndim == 2: elif pic.ndim == 2:
# if 2D image, add channel dimension (HWC) # if 2D image, add channel dimension (HWC)
...@@ -248,58 +246,58 @@ def to_pil_image(pic, mode=None): ...@@ -248,58 +246,58 @@ def to_pil_image(pic, mode=None):
# check number of channels # check number of channels
if pic.shape[-1] > 4: if pic.shape[-1] > 4:
raise ValueError('pic should not have > 4 channels. Got {} channels.'.format(pic.shape[-1])) raise ValueError("pic should not have > 4 channels. Got {} channels.".format(pic.shape[-1]))
npimg = pic npimg = pic
if isinstance(pic, torch.Tensor): if isinstance(pic, torch.Tensor):
if pic.is_floating_point() and mode != 'F': if pic.is_floating_point() and mode != "F":
pic = pic.mul(255).byte() pic = pic.mul(255).byte()
npimg = np.transpose(pic.cpu().numpy(), (1, 2, 0)) npimg = np.transpose(pic.cpu().numpy(), (1, 2, 0))
if not isinstance(npimg, np.ndarray): if not isinstance(npimg, np.ndarray):
raise TypeError('Input pic must be a torch.Tensor or NumPy ndarray, ' + raise TypeError("Input pic must be a torch.Tensor or NumPy ndarray, " + "not {}".format(type(npimg)))
'not {}'.format(type(npimg)))
if npimg.shape[2] == 1: if npimg.shape[2] == 1:
expected_mode = None expected_mode = None
npimg = npimg[:, :, 0] npimg = npimg[:, :, 0]
if npimg.dtype == np.uint8: if npimg.dtype == np.uint8:
expected_mode = 'L' expected_mode = "L"
elif npimg.dtype == np.int16: elif npimg.dtype == np.int16:
expected_mode = 'I;16' expected_mode = "I;16"
elif npimg.dtype == np.int32: elif npimg.dtype == np.int32:
expected_mode = 'I' expected_mode = "I"
elif npimg.dtype == np.float32: elif npimg.dtype == np.float32:
expected_mode = 'F' expected_mode = "F"
if mode is not None and mode != expected_mode: if mode is not None and mode != expected_mode:
raise ValueError("Incorrect mode ({}) supplied for input type {}. Should be {}" raise ValueError(
.format(mode, np.dtype, expected_mode)) "Incorrect mode ({}) supplied for input type {}. Should be {}".format(mode, np.dtype, expected_mode)
)
mode = expected_mode mode = expected_mode
elif npimg.shape[2] == 2: elif npimg.shape[2] == 2:
permitted_2_channel_modes = ['LA'] permitted_2_channel_modes = ["LA"]
if mode is not None and mode not in permitted_2_channel_modes: if mode is not None and mode not in permitted_2_channel_modes:
raise ValueError("Only modes {} are supported for 2D inputs".format(permitted_2_channel_modes)) raise ValueError("Only modes {} are supported for 2D inputs".format(permitted_2_channel_modes))
if mode is None and npimg.dtype == np.uint8: if mode is None and npimg.dtype == np.uint8:
mode = 'LA' mode = "LA"
elif npimg.shape[2] == 4: elif npimg.shape[2] == 4:
permitted_4_channel_modes = ['RGBA', 'CMYK', 'RGBX'] permitted_4_channel_modes = ["RGBA", "CMYK", "RGBX"]
if mode is not None and mode not in permitted_4_channel_modes: if mode is not None and mode not in permitted_4_channel_modes:
raise ValueError("Only modes {} are supported for 4D inputs".format(permitted_4_channel_modes)) raise ValueError("Only modes {} are supported for 4D inputs".format(permitted_4_channel_modes))
if mode is None and npimg.dtype == np.uint8: if mode is None and npimg.dtype == np.uint8:
mode = 'RGBA' mode = "RGBA"
else: else:
permitted_3_channel_modes = ['RGB', 'YCbCr', 'HSV'] permitted_3_channel_modes = ["RGB", "YCbCr", "HSV"]
if mode is not None and mode not in permitted_3_channel_modes: if mode is not None and mode not in permitted_3_channel_modes:
raise ValueError("Only modes {} are supported for 3D inputs".format(permitted_3_channel_modes)) raise ValueError("Only modes {} are supported for 3D inputs".format(permitted_3_channel_modes))
if mode is None and npimg.dtype == np.uint8: if mode is None and npimg.dtype == np.uint8:
mode = 'RGB' mode = "RGB"
if mode is None: if mode is None:
raise TypeError('Input type {} is not supported'.format(npimg.dtype)) raise TypeError("Input type {} is not supported".format(npimg.dtype))
return Image.fromarray(npimg, mode=mode) return Image.fromarray(npimg, mode=mode)
...@@ -323,14 +321,16 @@ def normalize(tensor: Tensor, mean: List[float], std: List[float], inplace: bool ...@@ -323,14 +321,16 @@ def normalize(tensor: Tensor, mean: List[float], std: List[float], inplace: bool
Tensor: Normalized Tensor image. Tensor: Normalized Tensor image.
""" """
if not isinstance(tensor, torch.Tensor): if not isinstance(tensor, torch.Tensor):
raise TypeError('Input tensor should be a torch tensor. Got {}.'.format(type(tensor))) raise TypeError("Input tensor should be a torch tensor. Got {}.".format(type(tensor)))
if not tensor.is_floating_point(): if not tensor.is_floating_point():
raise TypeError('Input tensor should be a float tensor. Got {}.'.format(tensor.dtype)) raise TypeError("Input tensor should be a float tensor. Got {}.".format(tensor.dtype))
if tensor.ndim < 3: if tensor.ndim < 3:
raise ValueError('Expected tensor to be a tensor image of size (..., C, H, W). Got tensor.size() = ' raise ValueError(
'{}.'.format(tensor.size())) "Expected tensor to be a tensor image of size (..., C, H, W). Got tensor.size() = "
"{}.".format(tensor.size())
)
if not inplace: if not inplace:
tensor = tensor.clone() tensor = tensor.clone()
...@@ -339,7 +339,7 @@ def normalize(tensor: Tensor, mean: List[float], std: List[float], inplace: bool ...@@ -339,7 +339,7 @@ def normalize(tensor: Tensor, mean: List[float], std: List[float], inplace: bool
mean = torch.as_tensor(mean, dtype=dtype, device=tensor.device) mean = torch.as_tensor(mean, dtype=dtype, device=tensor.device)
std = torch.as_tensor(std, dtype=dtype, device=tensor.device) std = torch.as_tensor(std, dtype=dtype, device=tensor.device)
if (std == 0).any(): if (std == 0).any():
raise ValueError('std evaluated to zero after conversion to {}, leading to division by zero.'.format(dtype)) raise ValueError("std evaluated to zero after conversion to {}, leading to division by zero.".format(dtype))
if mean.ndim == 1: if mean.ndim == 1:
mean = mean.view(-1, 1, 1) mean = mean.view(-1, 1, 1)
if std.ndim == 1: if std.ndim == 1:
...@@ -348,8 +348,13 @@ def normalize(tensor: Tensor, mean: List[float], std: List[float], inplace: bool ...@@ -348,8 +348,13 @@ def normalize(tensor: Tensor, mean: List[float], std: List[float], inplace: bool
return tensor return tensor
def resize(img: Tensor, size: List[int], interpolation: InterpolationMode = InterpolationMode.BILINEAR, def resize(
max_size: Optional[int] = None, antialias: Optional[bool] = None) -> Tensor: img: Tensor,
size: List[int],
interpolation: InterpolationMode = InterpolationMode.BILINEAR,
max_size: Optional[int] = None,
antialias: Optional[bool] = None,
) -> Tensor:
r"""Resize the input image to the given size. r"""Resize the input image to the given size.
If the image is torch Tensor, it is expected If the image is torch Tensor, it is expected
to have [..., H, W] shape, where ... means an arbitrary number of leading dimensions to have [..., H, W] shape, where ... means an arbitrary number of leading dimensions
...@@ -408,9 +413,7 @@ def resize(img: Tensor, size: List[int], interpolation: InterpolationMode = Inte ...@@ -408,9 +413,7 @@ def resize(img: Tensor, size: List[int], interpolation: InterpolationMode = Inte
if not isinstance(img, torch.Tensor): if not isinstance(img, torch.Tensor):
if antialias is not None and not antialias: if antialias is not None and not antialias:
warnings.warn( warnings.warn("Anti-alias option is always applied for PIL Image input. Argument antialias is ignored.")
"Anti-alias option is always applied for PIL Image input. Argument antialias is ignored."
)
pil_interpolation = pil_modes_mapping[interpolation] pil_interpolation = pil_modes_mapping[interpolation]
return F_pil.resize(img, size=size, interpolation=pil_interpolation, max_size=max_size) return F_pil.resize(img, size=size, interpolation=pil_interpolation, max_size=max_size)
@@ -418,8 +421,7 @@ def resize(img: Tensor, size: List[int], interpolation: InterpolationMode = Inte
 def scale(*args, **kwargs):
-    warnings.warn("The use of the transforms.Scale transform is deprecated, " +
-                  "please use transforms.Resize instead.")
+    warnings.warn("The use of the transforms.Scale transform is deprecated, " + "please use transforms.Resize instead.")
     return resize(*args, **kwargs)
@@ -527,14 +529,19 @@ def center_crop(img: Tensor, output_size: List[int]) -> Tensor:
     if crop_width == image_width and crop_height == image_height:
         return img

-    crop_top = int(round((image_height - crop_height) / 2.))
-    crop_left = int(round((image_width - crop_width) / 2.))
+    crop_top = int(round((image_height - crop_height) / 2.0))
+    crop_left = int(round((image_width - crop_width) / 2.0))
     return crop(img, crop_top, crop_left, crop_height, crop_width)


 def resized_crop(
-    img: Tensor, top: int, left: int, height: int, width: int, size: List[int],
-    interpolation: InterpolationMode = InterpolationMode.BILINEAR
+    img: Tensor,
+    top: int,
+    left: int,
+    height: int,
+    width: int,
+    size: List[int],
+    interpolation: InterpolationMode = InterpolationMode.BILINEAR,
 ) -> Tensor:
     """Crop the given image and resize it to desired size.
     If the image is torch Tensor, it is expected
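[Editor's note, not part of the diff: a usage sketch of `resized_crop` as reformatted above; the box coordinates are illustrative. For comparison, `center_crop` above computes the offset as e.g. int(round((240 - 224) / 2.0)) = 8 for a 224 crop of a 240-pixel edge.]

import torch
from torchvision.transforms.functional import resized_crop

img = torch.rand(3, 240, 320)
# Crop the 100x150 box whose top-left corner is at (top=20, left=30), then resize to 224x224.
out = resized_crop(img, top=20, left=30, height=100, width=150, size=[224, 224])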
@@ -581,9 +588,7 @@ def hflip(img: Tensor) -> Tensor:
     return F_t.hflip(img)


-def _get_perspective_coeffs(
-        startpoints: List[List[int]], endpoints: List[List[int]]
-) -> List[float]:
+def _get_perspective_coeffs(startpoints: List[List[int]], endpoints: List[List[int]]) -> List[float]:
     """Helper function to get the coefficients (a, b, c, d, e, f, g, h) for the perspective transforms.

     In Perspective Transform each pixel (x, y) in the original image gets transformed as,
@@ -605,18 +610,18 @@ def _get_perspective_coeffs(
         a_matrix[2 * i + 1, :] = torch.tensor([0, 0, 0, p1[0], p1[1], 1, -p2[1] * p1[0], -p2[1] * p1[1]])

     b_matrix = torch.tensor(startpoints, dtype=torch.float).view(8)
-    res = torch.linalg.lstsq(a_matrix, b_matrix, driver='gels').solution
+    res = torch.linalg.lstsq(a_matrix, b_matrix, driver="gels").solution

     output: List[float] = res.tolist()
     return output


 def perspective(
     img: Tensor,
     startpoints: List[List[int]],
     endpoints: List[List[int]],
     interpolation: InterpolationMode = InterpolationMode.BILINEAR,
-    fill: Optional[List[float]] = None
+    fill: Optional[List[float]] = None,
 ) -> Tensor:
     """Perform perspective transform of the given image.
     If the image is torch Tensor, it is expected
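[Editor's note, not part of the diff: each of the four point pairs contributes two rows to the 8x8 system solved by `torch.linalg.lstsq` above, yielding the eight perspective coefficients. A usage sketch of the public `perspective`, with illustrative corner coordinates:]

import torch
from torchvision.transforms.functional import perspective

img = torch.rand(3, 100, 100)
startpoints = [[0, 0], [99, 0], [99, 99], [0, 99]]  # the four source corners
endpoints = [[5, 3], [94, 0], [99, 99], [0, 95]]    # where they should land
out = perspective(img, startpoints, endpoints)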
@@ -892,7 +897,7 @@ def adjust_gamma(img: Tensor, gamma: float, gain: float = 1) -> Tensor:
 def _get_inverse_affine_matrix(
     center: List[float], angle: float, translate: List[float], scale: float, shear: List[float]
 ) -> List[float]:
     # Helper method to compute inverse matrix for affine transformation
@@ -942,9 +947,13 @@ def _get_inverse_affine_matrix(


 def rotate(
-    img: Tensor, angle: float, interpolation: InterpolationMode = InterpolationMode.NEAREST,
-    expand: bool = False, center: Optional[List[int]] = None,
-    fill: Optional[List[float]] = None, resample: Optional[int] = None
+    img: Tensor,
+    angle: float,
+    interpolation: InterpolationMode = InterpolationMode.NEAREST,
+    expand: bool = False,
+    center: Optional[List[int]] = None,
+    fill: Optional[List[float]] = None,
+    resample: Optional[int] = None,
 ) -> Tensor:
     """Rotate the image by angle.
     If the image is torch Tensor, it is expected
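[Editor's note, not part of the diff: a usage sketch of `rotate` as reformatted above; angle and fill values are illustrative.]

import torch
from torchvision.transforms.functional import rotate

img = torch.rand(3, 64, 64)
# expand=True grows the output canvas so the rotated image is not clipped.
out = rotate(img, angle=30.0, expand=True, fill=[0.0])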
@@ -1016,9 +1025,15 @@ def rotate(


 def affine(
-    img: Tensor, angle: float, translate: List[int], scale: float, shear: List[float],
-    interpolation: InterpolationMode = InterpolationMode.NEAREST, fill: Optional[List[float]] = None,
-    resample: Optional[int] = None, fillcolor: Optional[List[float]] = None
+    img: Tensor,
+    angle: float,
+    translate: List[int],
+    scale: float,
+    shear: List[float],
+    interpolation: InterpolationMode = InterpolationMode.NEAREST,
+    fill: Optional[List[float]] = None,
+    resample: Optional[int] = None,
+    fillcolor: Optional[List[float]] = None,
 ) -> Tensor:
     """Apply affine transformation on the image keeping image center invariant.
     If the image is torch Tensor, it is expected
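[Editor's note, not part of the diff: a usage sketch of `affine` as reformatted above, using the non-deprecated `fill` path; all parameter values are illustrative.]

import torch
from torchvision.transforms.functional import affine

img = torch.rand(3, 64, 64)
# 15 degree rotation, 5 px right / 3 px down shift, 10% zoom, no shear.
out = affine(img, angle=15.0, translate=[5, 3], scale=1.1, shear=[0.0, 0.0])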
@@ -1065,9 +1080,7 @@ def affine(
         interpolation = _interpolation_modes_from_int(interpolation)

     if fillcolor is not None:
-        warnings.warn(
-            "Argument fillcolor is deprecated and will be removed since v0.10.0. Please, use fill instead"
-        )
+        warnings.warn("Argument fillcolor is deprecated and will be removed since v0.10.0. Please, use fill instead")
         fill = fillcolor

     if not isinstance(angle, (int, float)):
@@ -1168,7 +1181,7 @@ def rgb_to_grayscale(img: Tensor, num_output_channels: int = 1) -> Tensor:


 def erase(img: Tensor, i: int, j: int, h: int, w: int, v: Tensor, inplace: bool = False) -> Tensor:
-    """ Erase the input Tensor Image with given value.
+    """Erase the input Tensor Image with given value.
     This transform does not support PIL Image.

     Args:
@@ -1184,12 +1197,12 @@ def erase(img: Tensor, i: int, j: int, h: int, w: int, v: Tensor, inplace: bool
         Tensor Image: Erased image.
     """
     if not isinstance(img, torch.Tensor):
-        raise TypeError('img should be Tensor Image. Got {}'.format(type(img)))
+        raise TypeError("img should be Tensor Image. Got {}".format(type(img)))

     if not inplace:
         img = img.clone()

-    img[..., i:i + h, j:j + w] = v
+    img[..., i : i + h, j : j + w] = v
     return img
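[Editor's note, not part of the diff: a usage sketch of `erase`; the patch location and the scalar value tensor are illustrative, and `v` only needs to broadcast against the (h, w) region.]

import torch
from torchvision.transforms.functional import erase

img = torch.rand(3, 32, 32)
# Overwrite the 8x10 patch whose top-left corner is at row i=4, column j=6.
out = erase(img, i=4, j=6, h=8, w=10, v=torch.tensor(0.0))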
@@ -1220,34 +1233,34 @@ def gaussian_blur(img: Tensor, kernel_size: List[int], sigma: Optional[List[floa
         PIL Image or Tensor: Gaussian Blurred version of the image.
     """
     if not isinstance(kernel_size, (int, list, tuple)):
-        raise TypeError('kernel_size should be int or a sequence of integers. Got {}'.format(type(kernel_size)))
+        raise TypeError("kernel_size should be int or a sequence of integers. Got {}".format(type(kernel_size)))
     if isinstance(kernel_size, int):
         kernel_size = [kernel_size, kernel_size]
     if len(kernel_size) != 2:
-        raise ValueError('If kernel_size is a sequence its length should be 2. Got {}'.format(len(kernel_size)))
+        raise ValueError("If kernel_size is a sequence its length should be 2. Got {}".format(len(kernel_size)))
     for ksize in kernel_size:
         if ksize % 2 == 0 or ksize < 0:
-            raise ValueError('kernel_size should have odd and positive integers. Got {}'.format(kernel_size))
+            raise ValueError("kernel_size should have odd and positive integers. Got {}".format(kernel_size))

     if sigma is None:
         sigma = [ksize * 0.15 + 0.35 for ksize in kernel_size]

     if sigma is not None and not isinstance(sigma, (int, float, list, tuple)):
-        raise TypeError('sigma should be either float or sequence of floats. Got {}'.format(type(sigma)))
+        raise TypeError("sigma should be either float or sequence of floats. Got {}".format(type(sigma)))
     if isinstance(sigma, (int, float)):
         sigma = [float(sigma), float(sigma)]
     if isinstance(sigma, (list, tuple)) and len(sigma) == 1:
         sigma = [sigma[0], sigma[0]]
     if len(sigma) != 2:
-        raise ValueError('If sigma is a sequence, its length should be 2. Got {}'.format(len(sigma)))
+        raise ValueError("If sigma is a sequence, its length should be 2. Got {}".format(len(sigma)))
     for s in sigma:
-        if s <= 0.:
-            raise ValueError('sigma should have positive values. Got {}'.format(sigma))
+        if s <= 0.0:
+            raise ValueError("sigma should have positive values. Got {}".format(sigma))

     t_img = img
     if not isinstance(img, torch.Tensor):
         if not F_pil._is_pil_image(img):
-            raise TypeError('img should be PIL Image or Tensor. Got {}'.format(type(img)))
+            raise TypeError("img should be PIL Image or Tensor. Got {}".format(type(img)))

         t_img = to_tensor(img)
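[Editor's note, not part of the diff: a usage sketch of `gaussian_blur` with the default sigma, whose formula is visible in the validation code above; the input shape is illustrative.]

import torch
from torchvision.transforms.functional import gaussian_blur

img = torch.rand(3, 64, 64)
out = gaussian_blur(img, kernel_size=[5, 5])
# With sigma=None, each dimension gets sigma = ksize * 0.15 + 0.35 = 1.1 for ksize=5.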
@@ -1290,7 +1303,7 @@ def posterize(img: Tensor, bits: int) -> Tensor:
         PIL Image or Tensor: Posterized image.
     """
     if not (0 <= bits <= 8):
-        raise ValueError('The number of bits should be between 0 and 8. Got {}'.format(bits))
+        raise ValueError("The number of bits should be between 0 and 8. Got {}".format(bits))
     if not isinstance(img, torch.Tensor):
         return F_pil.posterize(img, bits)
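[Editor's note, not part of the diff: a usage sketch of `posterize`; for tensor input it expects a uint8 image, and the values below are illustrative.]

import torch
from torchvision.transforms.functional import posterize

img = (torch.rand(3, 16, 16) * 255).to(torch.uint8)
out = posterize(img, bits=3)  # keep only the 3 most significant bits per channel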
@@ -29,14 +29,14 @@ def get_image_size(img: Any) -> List[int]:
 @torch.jit.unused
 def get_image_num_channels(img: Any) -> int:
     if _is_pil_image(img):
-        return 1 if img.mode == 'L' else 3
+        return 1 if img.mode == "L" else 3
     raise TypeError("Unexpected type {}".format(type(img)))


 @torch.jit.unused
 def hflip(img: Image.Image) -> Image.Image:
     if not _is_pil_image(img):
-        raise TypeError('img should be PIL Image. Got {}'.format(type(img)))
+        raise TypeError("img should be PIL Image. Got {}".format(type(img)))

     return img.transpose(Image.FLIP_LEFT_RIGHT)
@@ -44,7 +44,7 @@ def hflip(img: Image.Image) -> Image.Image:
 @torch.jit.unused
 def vflip(img: Image.Image) -> Image.Image:
     if not _is_pil_image(img):
-        raise TypeError('img should be PIL Image. Got {}'.format(type(img)))
+        raise TypeError("img should be PIL Image. Got {}".format(type(img)))

     return img.transpose(Image.FLIP_TOP_BOTTOM)
@@ -52,7 +52,7 @@ def vflip(img: Image.Image) -> Image.Image:
 @torch.jit.unused
 def adjust_brightness(img: Image.Image, brightness_factor: float) -> Image.Image:
     if not _is_pil_image(img):
-        raise TypeError('img should be PIL Image. Got {}'.format(type(img)))
+        raise TypeError("img should be PIL Image. Got {}".format(type(img)))

     enhancer = ImageEnhance.Brightness(img)
     img = enhancer.enhance(brightness_factor)
@@ -62,7 +62,7 @@ def adjust_brightness(img: Image.Image, brightness_factor: float) -> Image.Image
 @torch.jit.unused
 def adjust_contrast(img: Image.Image, contrast_factor: float) -> Image.Image:
     if not _is_pil_image(img):
-        raise TypeError('img should be PIL Image. Got {}'.format(type(img)))
+        raise TypeError("img should be PIL Image. Got {}".format(type(img)))

     enhancer = ImageEnhance.Contrast(img)
     img = enhancer.enhance(contrast_factor)
@@ -72,7 +72,7 @@ def adjust_contrast(img: Image.Image, contrast_factor: float) -> Image.Image:
 @torch.jit.unused
 def adjust_saturation(img: Image.Image, saturation_factor: float) -> Image.Image:
     if not _is_pil_image(img):
-        raise TypeError('img should be PIL Image. Got {}'.format(type(img)))
+        raise TypeError("img should be PIL Image. Got {}".format(type(img)))

     enhancer = ImageEnhance.Color(img)
     img = enhancer.enhance(saturation_factor)
@@ -81,25 +81,25 @@ def adjust_saturation(img: Image.Image, saturation_factor: float) -> Image.Image
 @torch.jit.unused
 def adjust_hue(img: Image.Image, hue_factor: float) -> Image.Image:
-    if not(-0.5 <= hue_factor <= 0.5):
-        raise ValueError('hue_factor ({}) is not in [-0.5, 0.5].'.format(hue_factor))
+    if not (-0.5 <= hue_factor <= 0.5):
+        raise ValueError("hue_factor ({}) is not in [-0.5, 0.5].".format(hue_factor))

     if not _is_pil_image(img):
-        raise TypeError('img should be PIL Image. Got {}'.format(type(img)))
+        raise TypeError("img should be PIL Image. Got {}".format(type(img)))

     input_mode = img.mode
-    if input_mode in {'L', '1', 'I', 'F'}:
+    if input_mode in {"L", "1", "I", "F"}:
         return img

-    h, s, v = img.convert('HSV').split()
+    h, s, v = img.convert("HSV").split()

     np_h = np.array(h, dtype=np.uint8)
     # uint8 addition takes care of rotation across boundaries
-    with np.errstate(over='ignore'):
+    with np.errstate(over="ignore"):
         np_h += np.uint8(hue_factor * 255)
-    h = Image.fromarray(np_h, 'L')
+    h = Image.fromarray(np_h, "L")

-    img = Image.merge('HSV', (h, s, v)).convert(input_mode)
+    img = Image.merge("HSV", (h, s, v)).convert(input_mode)
     return img
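[Editor's note, not part of the diff: the uint8 overflow used above is what makes the hue shift cyclic, since hue lives on a circle. A minimal sketch of just that wrap-around, with illustrative values:]

import numpy as np

h = np.array([250], dtype=np.uint8)
with np.errstate(over="ignore"):
    h += np.uint8(0.1 * 255)  # hue_factor=0.1 adds 25 to the H channel
print(h)  # [19], since 250 + 25 = 275 wraps modulo 256 to 19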
@@ -111,14 +111,14 @@ def adjust_gamma(
 ) -> Image.Image:
     if not _is_pil_image(img):
-        raise TypeError('img should be PIL Image. Got {}'.format(type(img)))
+        raise TypeError("img should be PIL Image. Got {}".format(type(img)))

     if gamma < 0:
-        raise ValueError('Gamma should be a non-negative real number')
+        raise ValueError("Gamma should be a non-negative real number")

     input_mode = img.mode
-    img = img.convert('RGB')
-    gamma_map = [(255 + 1 - 1e-3) * gain * pow(ele / 255., gamma) for ele in range(256)] * 3
+    img = img.convert("RGB")
+    gamma_map = [(255 + 1 - 1e-3) * gain * pow(ele / 255.0, gamma) for ele in range(256)] * 3
     img = img.point(gamma_map)  # use PIL's point-function to accelerate this part

     img = img.convert(input_mode)
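[Editor's note, not part of the diff: the lookup table above evaluates out = (256 - 1e-3) * gain * (x / 255) ** gamma per channel. A worked value, with an illustrative gamma:]

gain, gamma = 1.0, 2.2
gamma_map = [(255 + 1 - 1e-3) * gain * pow(ele / 255.0, gamma) for ele in range(256)]
print(round(gamma_map[128]))  # mid-gray darkens to roughly 56 at gamma 2.2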
@@ -147,8 +147,9 @@ def pad(
         padding = tuple(padding)

     if isinstance(padding, tuple) and len(padding) not in [1, 2, 4]:
-        raise ValueError("Padding must be an int or a 1, 2, or 4 element tuple, not a " +
-                         "{} element tuple".format(len(padding)))
+        raise ValueError(
+            "Padding must be an int or a 1, 2, or 4 element tuple, not a " + "{} element tuple".format(len(padding))
+        )

     if isinstance(padding, tuple) and len(padding) == 1:
         # Compatibility with `functional_tensor.pad`
@@ -187,7 +188,7 @@ def pad(
         pad_left, pad_top, pad_right, pad_bottom = np.maximum(p, 0)

-        if img.mode == 'P':
+        if img.mode == "P":
             palette = img.getpalette()
             img = np.asarray(img)
             img = np.pad(img, ((pad_top, pad_bottom), (pad_left, pad_right)), padding_mode)
@@ -216,7 +217,7 @@ def crop(
 ) -> Image.Image:
     if not _is_pil_image(img):
-        raise TypeError('img should be PIL Image. Got {}'.format(type(img)))
+        raise TypeError("img should be PIL Image. Got {}".format(type(img)))

     return img.crop((left, top, left + width, top + height))
@@ -230,9 +231,9 @@ def resize(
 ) -> Image.Image:
     if not _is_pil_image(img):
-        raise TypeError('img should be PIL Image. Got {}'.format(type(img)))
+        raise TypeError("img should be PIL Image. Got {}".format(type(img)))
     if not (isinstance(size, int) or (isinstance(size, Sequence) and len(size) in (1, 2))):
-        raise TypeError('Got inappropriate size arg: {}'.format(size))
+        raise TypeError("Got inappropriate size arg: {}".format(size))

     if isinstance(size, Sequence) and len(size) == 1:
         size = size[0]
@@ -280,8 +281,7 @@ def _parse_fill(
         fill = tuple([fill] * num_bands)

     if isinstance(fill, (list, tuple)):
         if len(fill) != num_bands:
-            msg = ("The number of elements in 'fill' does not match the number of "
-                   "bands of the image ({} != {})")
+            msg = "The number of elements in 'fill' does not match the number of " "bands of the image ({} != {})"
             raise ValueError(msg.format(len(fill), num_bands))

         fill = tuple(fill)
@@ -298,7 +298,7 @@ def affine(
 ) -> Image.Image:
     if not _is_pil_image(img):
-        raise TypeError('img should be PIL Image. Got {}'.format(type(img)))
+        raise TypeError("img should be PIL Image. Got {}".format(type(img)))

     output_size = img.size
     opts = _parse_fill(fill, img)
@@ -331,7 +331,7 @@ def perspective(
 ) -> Image.Image:
     if not _is_pil_image(img):
-        raise TypeError('img should be PIL Image. Got {}'.format(type(img)))
+        raise TypeError("img should be PIL Image. Got {}".format(type(img)))

     opts = _parse_fill(fill, img)
@@ -341,17 +341,17 @@ def perspective(
 @torch.jit.unused
 def to_grayscale(img: Image.Image, num_output_channels: int) -> Image.Image:
     if not _is_pil_image(img):
-        raise TypeError('img should be PIL Image. Got {}'.format(type(img)))
+        raise TypeError("img should be PIL Image. Got {}".format(type(img)))

     if num_output_channels == 1:
-        img = img.convert('L')
+        img = img.convert("L")
     elif num_output_channels == 3:
-        img = img.convert('L')
+        img = img.convert("L")
         np_img = np.array(img, dtype=np.uint8)
         np_img = np.dstack([np_img, np_img, np_img])
-        img = Image.fromarray(np_img, 'RGB')
+        img = Image.fromarray(np_img, "RGB")
     else:
-        raise ValueError('num_output_channels should be either 1 or 3')
+        raise ValueError("num_output_channels should be either 1 or 3")

     return img
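[Editor's note, not part of the diff: a usage sketch of the public `to_grayscale`; the input image is illustrative. With num_output_channels=3 the single luma channel is stacked into three identical bands, as the np.dstack call above shows.]

import numpy as np
from PIL import Image
from torchvision.transforms.functional import to_grayscale

img = Image.fromarray(np.zeros((8, 8, 3), dtype=np.uint8))
out = to_grayscale(img, num_output_channels=3)
print(out.mode)  # "RGB", but R, G and B are identical copies of the luma channel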
@@ -359,28 +359,28 @@ def to_grayscale(img: Image.Image, num_output_channels: int) -> Image.Image:
 @torch.jit.unused
 def invert(img: Image.Image) -> Image.Image:
     if not _is_pil_image(img):
-        raise TypeError('img should be PIL Image. Got {}'.format(type(img)))
+        raise TypeError("img should be PIL Image. Got {}".format(type(img)))
     return ImageOps.invert(img)


 @torch.jit.unused
 def posterize(img: Image.Image, bits: int) -> Image.Image:
     if not _is_pil_image(img):
-        raise TypeError('img should be PIL Image. Got {}'.format(type(img)))
+        raise TypeError("img should be PIL Image. Got {}".format(type(img)))
     return ImageOps.posterize(img, bits)


 @torch.jit.unused
 def solarize(img: Image.Image, threshold: int) -> Image.Image:
     if not _is_pil_image(img):
-        raise TypeError('img should be PIL Image. Got {}'.format(type(img)))
+        raise TypeError("img should be PIL Image. Got {}".format(type(img)))
     return ImageOps.solarize(img, threshold)


 @torch.jit.unused
 def adjust_sharpness(img: Image.Image, sharpness_factor: float) -> Image.Image:
     if not _is_pil_image(img):
-        raise TypeError('img should be PIL Image. Got {}'.format(type(img)))
+        raise TypeError("img should be PIL Image. Got {}".format(type(img)))

     enhancer = ImageEnhance.Sharpness(img)
     img = enhancer.enhance(sharpness_factor)

@@ -390,12 +390,12 @@ def adjust_sharpness(img: Image.Image, sharpness_factor: float) -> Image.Image:
 @torch.jit.unused
 def autocontrast(img: Image.Image) -> Image.Image:
     if not _is_pil_image(img):
-        raise TypeError('img should be PIL Image. Got {}'.format(type(img)))
+        raise TypeError("img should be PIL Image. Got {}".format(type(img)))
     return ImageOps.autocontrast(img)


 @torch.jit.unused
 def equalize(img: Image.Image) -> Image.Image:
     if not _is_pil_image(img):
-        raise TypeError('img should be PIL Image. Got {}'.format(type(img)))
+        raise TypeError("img should be PIL Image. Got {}".format(type(img)))
     return ImageOps.equalize(img)
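[Editor's note, not part of the diff: the PIL wrappers above delegate directly to ImageOps; a usage sketch chaining two of them through the public functional API, with an illustrative gradient image.]

import numpy as np
from PIL import Image
from torchvision.transforms.functional import autocontrast, solarize

img = Image.fromarray(np.arange(256, dtype=np.uint8).reshape(16, 16))
out = solarize(img, threshold=128)  # invert every pixel at or above the threshold
out = autocontrast(out)             # then stretch the histogram back to [0, 255]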