Unverified commit 55f7faf3, authored by Kai Zhang, committed by GitHub

Add api usage log to transforms (#5007)

* add api usage log to functional transforms

* add log to transforms

* fix for scriptability

* skip Compose for scriptability

* follow the new policy

* torchscriptability

* adopt new API

* make Compose scriptable

* move from __call__ to __init__
parent da7680f0
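
The first file in the diff is torchvision/transforms/functional.py. Every public functional transform gains the same two lines right after its docstring: a call to _log_api_usage_once, guarded on the TorchScript state so that scripted and traced graphs never reach the plain-Python logging helper and the functions stay scriptable (the "fix for scriptability" bullets above). A minimal sketch of the pattern as applied here; my_transform is a hypothetical stand-in for any of the functions below, not part of this commit:

import torch
from torch import Tensor

from torchvision.utils import _log_api_usage_once


def my_transform(img: Tensor) -> Tensor:
    # Log only in eager mode. Under torch.jit.script, is_scripting()
    # constant-folds to True, so this branch becomes dead code the
    # compiler can drop; under tracing, is_tracing() skips it at runtime.
    if not torch.jit.is_scripting() and not torch.jit.is_tracing():
        _log_api_usage_once(my_transform)
    return img
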
@@ -14,6 +14,7 @@ try:
 except ImportError:
     accimage = None
+from ..utils import _log_api_usage_once
 from . import functional_pil as F_pil
 from . import functional_tensor as F_t
@@ -67,6 +68,8 @@ def get_image_size(img: Tensor) -> List[int]:
     Returns:
         List[int]: The image size.
     """
+    if not torch.jit.is_scripting() and not torch.jit.is_tracing():
+        _log_api_usage_once(get_image_size)
     if isinstance(img, torch.Tensor):
         return F_t.get_image_size(img)
@@ -82,6 +85,8 @@ def get_image_num_channels(img: Tensor) -> int:
     Returns:
         int: The number of channels.
     """
+    if not torch.jit.is_scripting() and not torch.jit.is_tracing():
+        _log_api_usage_once(get_image_num_channels)
     if isinstance(img, torch.Tensor):
         return F_t.get_image_num_channels(img)
@@ -110,6 +115,8 @@ def to_tensor(pic):
     Returns:
         Tensor: Converted image.
     """
+    if not torch.jit.is_scripting() and not torch.jit.is_tracing():
+        _log_api_usage_once(to_tensor)
     if not (F_pil._is_pil_image(pic) or _is_numpy(pic)):
         raise TypeError(f"pic should be PIL Image or ndarray. Got {type(pic)}")
@@ -166,6 +173,8 @@ def pil_to_tensor(pic):
     Returns:
         Tensor: Converted image.
     """
+    if not torch.jit.is_scripting() and not torch.jit.is_tracing():
+        _log_api_usage_once(pil_to_tensor)
     if not F_pil._is_pil_image(pic):
         raise TypeError(f"pic should be PIL Image. Got {type(pic)}")
@@ -205,6 +214,8 @@ def convert_image_dtype(image: torch.Tensor, dtype: torch.dtype = torch.float) -
         overflow errors since the floating point ``dtype`` cannot store consecutive integers over the whole range
         of the integer ``dtype``.
     """
+    if not torch.jit.is_scripting() and not torch.jit.is_tracing():
+        _log_api_usage_once(convert_image_dtype)
     if not isinstance(image, torch.Tensor):
         raise TypeError("Input img should be Tensor Image")
@@ -225,6 +236,8 @@ def to_pil_image(pic, mode=None):
     Returns:
         PIL Image: Image converted to PIL Image.
     """
+    if not torch.jit.is_scripting() and not torch.jit.is_tracing():
+        _log_api_usage_once(to_pil_image)
     if not (isinstance(pic, torch.Tensor) or isinstance(pic, np.ndarray)):
         raise TypeError(f"pic should be Tensor or ndarray. Got {type(pic)}.")
@@ -322,6 +335,8 @@ def normalize(tensor: Tensor, mean: List[float], std: List[float], inplace: bool
     Returns:
         Tensor: Normalized Tensor image.
     """
+    if not torch.jit.is_scripting() and not torch.jit.is_tracing():
+        _log_api_usage_once(normalize)
     if not isinstance(tensor, torch.Tensor):
         raise TypeError(f"Input tensor should be a torch tensor. Got {type(tensor)}.")
@@ -401,6 +416,8 @@ def resize(
     Returns:
         PIL Image or Tensor: Resized image.
     """
+    if not torch.jit.is_scripting() and not torch.jit.is_tracing():
+        _log_api_usage_once(resize)
     # Backward compatibility with integer value
     if isinstance(interpolation, int):
         warnings.warn(
@@ -422,6 +439,8 @@ def resize(
 def scale(*args, **kwargs):
+    if not torch.jit.is_scripting() and not torch.jit.is_tracing():
+        _log_api_usage_once(scale)
     warnings.warn("The use of the transforms.Scale transform is deprecated, please use transforms.Resize instead.")
     return resize(*args, **kwargs)
@@ -467,6 +486,8 @@ def pad(img: Tensor, padding: List[int], fill: int = 0, padding_mode: str = "con
     Returns:
         PIL Image or Tensor: Padded image.
     """
+    if not torch.jit.is_scripting() and not torch.jit.is_tracing():
+        _log_api_usage_once(pad)
     if not isinstance(img, torch.Tensor):
         return F_pil.pad(img, padding=padding, fill=fill, padding_mode=padding_mode)
@@ -490,6 +511,8 @@ def crop(img: Tensor, top: int, left: int, height: int, width: int) -> Tensor:
         PIL Image or Tensor: Cropped image.
     """
+    if not torch.jit.is_scripting() and not torch.jit.is_tracing():
+        _log_api_usage_once(crop)
     if not isinstance(img, torch.Tensor):
         return F_pil.crop(img, top, left, height, width)
@@ -510,6 +533,8 @@ def center_crop(img: Tensor, output_size: List[int]) -> Tensor:
     Returns:
         PIL Image or Tensor: Cropped image.
     """
+    if not torch.jit.is_scripting() and not torch.jit.is_tracing():
+        _log_api_usage_once(center_crop)
     if isinstance(output_size, numbers.Number):
         output_size = (int(output_size), int(output_size))
     elif isinstance(output_size, (tuple, list)) and len(output_size) == 1:
@@ -566,6 +591,8 @@ def resized_crop(
     Returns:
         PIL Image or Tensor: Cropped image.
     """
+    if not torch.jit.is_scripting() and not torch.jit.is_tracing():
+        _log_api_usage_once(resized_crop)
     img = crop(img, top, left, height, width)
     img = resize(img, size, interpolation)
     return img
@@ -583,6 +610,8 @@ def hflip(img: Tensor) -> Tensor:
     Returns:
         PIL Image or Tensor: Horizontally flipped image.
     """
+    if not torch.jit.is_scripting() and not torch.jit.is_tracing():
+        _log_api_usage_once(hflip)
     if not isinstance(img, torch.Tensor):
         return F_pil.hflip(img)
@@ -648,6 +677,8 @@ def perspective(
     Returns:
         PIL Image or Tensor: transformed Image.
     """
+    if not torch.jit.is_scripting() and not torch.jit.is_tracing():
+        _log_api_usage_once(perspective)
     coeffs = _get_perspective_coeffs(startpoints, endpoints)
@@ -681,6 +712,8 @@ def vflip(img: Tensor) -> Tensor:
     Returns:
         PIL Image or Tensor: Vertically flipped image.
     """
+    if not torch.jit.is_scripting() and not torch.jit.is_tracing():
+        _log_api_usage_once(vflip)
     if not isinstance(img, torch.Tensor):
         return F_pil.vflip(img)
@@ -706,6 +739,8 @@ def five_crop(img: Tensor, size: List[int]) -> Tuple[Tensor, Tensor, Tensor, Ten
         tuple: tuple (tl, tr, bl, br, center)
             Corresponding top left, top right, bottom left, bottom right and center crop.
     """
+    if not torch.jit.is_scripting() and not torch.jit.is_tracing():
+        _log_api_usage_once(five_crop)
     if isinstance(size, numbers.Number):
         size = (int(size), int(size))
     elif isinstance(size, (tuple, list)) and len(size) == 1:
@@ -753,6 +788,8 @@ def ten_crop(img: Tensor, size: List[int], vertical_flip: bool = False) -> List[
            Corresponding top left, top right, bottom left, bottom right and
            center crop and same for the flipped image.
     """
+    if not torch.jit.is_scripting() and not torch.jit.is_tracing():
+        _log_api_usage_once(ten_crop)
     if isinstance(size, numbers.Number):
         size = (int(size), int(size))
     elif isinstance(size, (tuple, list)) and len(size) == 1:
@@ -786,6 +823,8 @@ def adjust_brightness(img: Tensor, brightness_factor: float) -> Tensor:
     Returns:
         PIL Image or Tensor: Brightness adjusted image.
     """
+    if not torch.jit.is_scripting() and not torch.jit.is_tracing():
+        _log_api_usage_once(adjust_brightness)
     if not isinstance(img, torch.Tensor):
         return F_pil.adjust_brightness(img, brightness_factor)
@@ -806,6 +845,8 @@ def adjust_contrast(img: Tensor, contrast_factor: float) -> Tensor:
     Returns:
         PIL Image or Tensor: Contrast adjusted image.
     """
+    if not torch.jit.is_scripting() and not torch.jit.is_tracing():
+        _log_api_usage_once(adjust_contrast)
     if not isinstance(img, torch.Tensor):
         return F_pil.adjust_contrast(img, contrast_factor)
@@ -826,6 +867,8 @@ def adjust_saturation(img: Tensor, saturation_factor: float) -> Tensor:
     Returns:
         PIL Image or Tensor: Saturation adjusted image.
     """
+    if not torch.jit.is_scripting() and not torch.jit.is_tracing():
+        _log_api_usage_once(adjust_saturation)
     if not isinstance(img, torch.Tensor):
         return F_pil.adjust_saturation(img, saturation_factor)
@@ -860,6 +903,8 @@ def adjust_hue(img: Tensor, hue_factor: float) -> Tensor:
     Returns:
         PIL Image or Tensor: Hue adjusted image.
     """
+    if not torch.jit.is_scripting() and not torch.jit.is_tracing():
+        _log_api_usage_once(adjust_hue)
     if not isinstance(img, torch.Tensor):
         return F_pil.adjust_hue(img, hue_factor)
@@ -891,6 +936,8 @@ def adjust_gamma(img: Tensor, gamma: float, gain: float = 1) -> Tensor:
     Returns:
         PIL Image or Tensor: Gamma correction adjusted image.
     """
+    if not torch.jit.is_scripting() and not torch.jit.is_tracing():
+        _log_api_usage_once(adjust_gamma)
     if not isinstance(img, torch.Tensor):
         return F_pil.adjust_gamma(img, gamma, gain)
@@ -987,6 +1034,8 @@ def rotate(
     .. _filters: https://pillow.readthedocs.io/en/latest/handbook/concepts.html#filters
     """
+    if not torch.jit.is_scripting() and not torch.jit.is_tracing():
+        _log_api_usage_once(rotate)
     if resample is not None:
         warnings.warn(
             "Argument resample is deprecated and will be removed since v0.10.0. Please, use interpolation instead"
@@ -1067,6 +1116,8 @@ def affine(
     Returns:
         PIL Image or Tensor: Transformed image.
     """
+    if not torch.jit.is_scripting() and not torch.jit.is_tracing():
+        _log_api_usage_once(affine)
     if resample is not None:
         warnings.warn(
             "Argument resample is deprecated and will be removed since v0.10.0. Please, use interpolation instead"
@@ -1151,6 +1202,8 @@ def to_grayscale(img, num_output_channels=1):
        - if num_output_channels = 1 : returned image is single channel
        - if num_output_channels = 3 : returned image is 3 channel with r = g = b
     """
+    if not torch.jit.is_scripting() and not torch.jit.is_tracing():
+        _log_api_usage_once(to_grayscale)
     if isinstance(img, Image.Image):
         return F_pil.to_grayscale(img, num_output_channels)
@@ -1176,6 +1229,8 @@ def rgb_to_grayscale(img: Tensor, num_output_channels: int = 1) -> Tensor:
        - if num_output_channels = 1 : returned image is single channel
        - if num_output_channels = 3 : returned image is 3 channel with r = g = b
     """
+    if not torch.jit.is_scripting() and not torch.jit.is_tracing():
+        _log_api_usage_once(rgb_to_grayscale)
     if not isinstance(img, torch.Tensor):
         return F_pil.to_grayscale(img, num_output_channels)
@@ -1198,6 +1253,8 @@ def erase(img: Tensor, i: int, j: int, h: int, w: int, v: Tensor, inplace: bool
     Returns:
         Tensor Image: Erased image.
     """
+    if not torch.jit.is_scripting() and not torch.jit.is_tracing():
+        _log_api_usage_once(erase)
     if not isinstance(img, torch.Tensor):
         raise TypeError(f"img should be Tensor Image. Got {type(img)}")
@@ -1234,6 +1291,8 @@ def gaussian_blur(img: Tensor, kernel_size: List[int], sigma: Optional[List[floa
     Returns:
         PIL Image or Tensor: Gaussian Blurred version of the image.
     """
+    if not torch.jit.is_scripting() and not torch.jit.is_tracing():
+        _log_api_usage_once(gaussian_blur)
     if not isinstance(kernel_size, (int, list, tuple)):
         raise TypeError(f"kernel_size should be int or a sequence of integers. Got {type(kernel_size)}")
     if isinstance(kernel_size, int):
@@ -1285,6 +1344,8 @@ def invert(img: Tensor) -> Tensor:
     Returns:
         PIL Image or Tensor: Color inverted image.
     """
+    if not torch.jit.is_scripting() and not torch.jit.is_tracing():
+        _log_api_usage_once(invert)
     if not isinstance(img, torch.Tensor):
         return F_pil.invert(img)
@@ -1304,6 +1365,8 @@ def posterize(img: Tensor, bits: int) -> Tensor:
     Returns:
         PIL Image or Tensor: Posterized image.
     """
+    if not torch.jit.is_scripting() and not torch.jit.is_tracing():
+        _log_api_usage_once(posterize)
     if not (0 <= bits <= 8):
         raise ValueError(f"The number of bits should be between 0 and 8. Got {bits}")
@@ -1325,6 +1388,8 @@ def solarize(img: Tensor, threshold: float) -> Tensor:
     Returns:
         PIL Image or Tensor: Solarized image.
     """
+    if not torch.jit.is_scripting() and not torch.jit.is_tracing():
+        _log_api_usage_once(solarize)
     if not isinstance(img, torch.Tensor):
         return F_pil.solarize(img, threshold)
@@ -1345,6 +1410,8 @@ def adjust_sharpness(img: Tensor, sharpness_factor: float) -> Tensor:
     Returns:
         PIL Image or Tensor: Sharpness adjusted image.
     """
+    if not torch.jit.is_scripting() and not torch.jit.is_tracing():
+        _log_api_usage_once(adjust_sharpness)
     if not isinstance(img, torch.Tensor):
         return F_pil.adjust_sharpness(img, sharpness_factor)
@@ -1365,6 +1432,8 @@ def autocontrast(img: Tensor) -> Tensor:
     Returns:
         PIL Image or Tensor: An image that was autocontrasted.
     """
+    if not torch.jit.is_scripting() and not torch.jit.is_tracing():
+        _log_api_usage_once(autocontrast)
     if not isinstance(img, torch.Tensor):
         return F_pil.autocontrast(img)
@@ -1386,6 +1455,8 @@ def equalize(img: Tensor) -> Tensor:
     Returns:
         PIL Image or Tensor: An image that was equalized.
     """
+    if not torch.jit.is_scripting() and not torch.jit.is_tracing():
+        _log_api_usage_once(equalize)
     if not isinstance(img, torch.Tensor):
         return F_pil.equalize(img)
...
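
The second file is torchvision/transforms/transforms.py, which holds the class-based transforms. Here the logging moved into __init__ (the "move from __call__ to __init__" bullet), so each transform instance is recorded once at construction rather than on every invocation. The nn.Module subclasses log unconditionally, since their constructors only ever run in eager Python; Compose keeps the JIT guard so it can still be compiled (the "make Compose scriptable" bullet). A minimal sketch of both variants, with hypothetical class names:

import torch

from torchvision.utils import _log_api_usage_once


class MyTransform(torch.nn.Module):
    # nn.Module pattern: scripting compiles forward(), never __init__,
    # so the logging call needs no guard here.
    def __init__(self, p: float = 0.5) -> None:
        super().__init__()
        _log_api_usage_once(self)
        self.p = p


class MyContainer:
    # Compose pattern: a plain class whose __init__ may be compiled when
    # the class is used from scripted code, so the logging call stays
    # behind the same jit guard as in the functional API.
    def __init__(self, transforms) -> None:
        if not torch.jit.is_scripting() and not torch.jit.is_tracing():
            _log_api_usage_once(self)
        self.transforms = transforms
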
@@ -13,6 +13,7 @@ try:
 except ImportError:
     accimage = None
+from ..utils import _log_api_usage_once
 from . import functional as F
 from .functional import InterpolationMode, _interpolation_modes_from_int
@@ -87,6 +88,8 @@ class Compose:
     """
     def __init__(self, transforms):
+        if not torch.jit.is_scripting() and not torch.jit.is_tracing():
+            _log_api_usage_once(self)
         self.transforms = transforms
     def __call__(self, img):
@@ -120,6 +123,9 @@ class ToTensor:
     .. _references: https://github.com/pytorch/vision/tree/main/references/segmentation
     """
+    def __init__(self) -> None:
+        _log_api_usage_once(self)
     def __call__(self, pic):
         """
         Args:
@@ -140,6 +146,9 @@ class PILToTensor:
     Converts a PIL Image (H x W x C) to a Tensor of shape (C x H x W).
     """
+    def __init__(self) -> None:
+        _log_api_usage_once(self)
     def __call__(self, pic):
         """
         .. note::
@@ -179,6 +188,7 @@ class ConvertImageDtype(torch.nn.Module):
     def __init__(self, dtype: torch.dtype) -> None:
         super().__init__()
+        _log_api_usage_once(self)
         self.dtype = dtype
     def forward(self, image):
@@ -204,6 +214,7 @@ class ToPILImage:
     """
     def __init__(self, mode=None):
+        _log_api_usage_once(self)
         self.mode = mode
     def __call__(self, pic):
@@ -245,6 +256,7 @@ class Normalize(torch.nn.Module):
     def __init__(self, mean, std, inplace=False):
         super().__init__()
+        _log_api_usage_once(self)
         self.mean = mean
         self.std = std
         self.inplace = inplace
@@ -309,6 +321,7 @@ class Resize(torch.nn.Module):
     def __init__(self, size, interpolation=InterpolationMode.BILINEAR, max_size=None, antialias=None):
         super().__init__()
+        _log_api_usage_once(self)
         if not isinstance(size, (int, Sequence)):
             raise TypeError(f"Size should be int or sequence. Got {type(size)}")
         if isinstance(size, Sequence) and len(size) not in (1, 2):
@@ -350,6 +363,7 @@ class Scale(Resize):
     def __init__(self, *args, **kwargs):
         warnings.warn("The use of the transforms.Scale transform is deprecated, please use transforms.Resize instead.")
         super().__init__(*args, **kwargs)
+        _log_api_usage_once(self)
 class CenterCrop(torch.nn.Module):
@@ -366,6 +380,7 @@ class CenterCrop(torch.nn.Module):
     def __init__(self, size):
         super().__init__()
+        _log_api_usage_once(self)
         self.size = _setup_size(size, error_msg="Please provide only two dimensions (h, w) for size.")
     def forward(self, img):
@@ -422,6 +437,7 @@ class Pad(torch.nn.Module):
     def __init__(self, padding, fill=0, padding_mode="constant"):
         super().__init__()
+        _log_api_usage_once(self)
         if not isinstance(padding, (numbers.Number, tuple, list)):
             raise TypeError("Got inappropriate padding arg")
@@ -462,6 +478,7 @@ class Lambda:
     """
     def __init__(self, lambd):
+        _log_api_usage_once(self)
         if not callable(lambd):
             raise TypeError(f"Argument lambd should be callable, got {repr(type(lambd).__name__)}")
         self.lambd = lambd
@@ -481,6 +498,7 @@ class RandomTransforms:
     """
     def __init__(self, transforms):
+        _log_api_usage_once(self)
         if not isinstance(transforms, Sequence):
             raise TypeError("Argument transforms should be a sequence")
         self.transforms = transforms
@@ -519,6 +537,7 @@ class RandomApply(torch.nn.Module):
     def __init__(self, transforms, p=0.5):
         super().__init__()
+        _log_api_usage_once(self)
         self.transforms = transforms
         self.p = p
@@ -639,6 +658,7 @@ class RandomCrop(torch.nn.Module):
     def __init__(self, size, padding=None, pad_if_needed=False, fill=0, padding_mode="constant"):
         super().__init__()
+        _log_api_usage_once(self)
         self.size = tuple(_setup_size(size, error_msg="Please provide only two dimensions (h, w) for size."))
@@ -688,6 +708,7 @@ class RandomHorizontalFlip(torch.nn.Module):
     def __init__(self, p=0.5):
         super().__init__()
+        _log_api_usage_once(self)
         self.p = p
     def forward(self, img):
@@ -718,6 +739,7 @@ class RandomVerticalFlip(torch.nn.Module):
     def __init__(self, p=0.5):
         super().__init__()
+        _log_api_usage_once(self)
         self.p = p
     def forward(self, img):
@@ -755,6 +777,7 @@ class RandomPerspective(torch.nn.Module):
     def __init__(self, distortion_scale=0.5, p=0.5, interpolation=InterpolationMode.BILINEAR, fill=0):
         super().__init__()
+        _log_api_usage_once(self)
         self.p = p
         # Backward compatibility with integer value
@@ -867,6 +890,7 @@ class RandomResizedCrop(torch.nn.Module):
     def __init__(self, size, scale=(0.08, 1.0), ratio=(3.0 / 4.0, 4.0 / 3.0), interpolation=InterpolationMode.BILINEAR):
         super().__init__()
+        _log_api_usage_once(self)
         self.size = _setup_size(size, error_msg="Please provide only two dimensions (h, w) for size.")
         if not isinstance(scale, Sequence):
@@ -963,6 +987,7 @@ class RandomSizedCrop(RandomResizedCrop):
             + "please use transforms.RandomResizedCrop instead."
         )
         super().__init__(*args, **kwargs)
+        _log_api_usage_once(self)
 class FiveCrop(torch.nn.Module):
@@ -995,6 +1020,7 @@ class FiveCrop(torch.nn.Module):
     def __init__(self, size):
         super().__init__()
+        _log_api_usage_once(self)
         self.size = _setup_size(size, error_msg="Please provide only two dimensions (h, w) for size.")
     def forward(self, img):
@@ -1043,6 +1069,7 @@ class TenCrop(torch.nn.Module):
     def __init__(self, size, vertical_flip=False):
         super().__init__()
+        _log_api_usage_once(self)
         self.size = _setup_size(size, error_msg="Please provide only two dimensions (h, w) for size.")
         self.vertical_flip = vertical_flip
@@ -1081,6 +1108,7 @@ class LinearTransformation(torch.nn.Module):
     def __init__(self, transformation_matrix, mean_vector):
         super().__init__()
+        _log_api_usage_once(self)
         if transformation_matrix.size(0) != transformation_matrix.size(1):
             raise ValueError(
                 "transformation_matrix should be square. Got "
@@ -1159,6 +1187,7 @@ class ColorJitter(torch.nn.Module):
     def __init__(self, brightness=0, contrast=0, saturation=0, hue=0):
         super().__init__()
+        _log_api_usage_once(self)
         self.brightness = self._check_input(brightness, "brightness")
         self.contrast = self._check_input(contrast, "contrast")
         self.saturation = self._check_input(saturation, "saturation")
@@ -1281,6 +1310,7 @@ class RandomRotation(torch.nn.Module):
         self, degrees, interpolation=InterpolationMode.NEAREST, expand=False, center=None, fill=0, resample=None
     ):
         super().__init__()
+        _log_api_usage_once(self)
         if resample is not None:
             warnings.warn(
                 "Argument resample is deprecated and will be removed since v0.10.0. Please, use interpolation instead"
@@ -1401,6 +1431,7 @@ class RandomAffine(torch.nn.Module):
         resample=None,
     ):
         super().__init__()
+        _log_api_usage_once(self)
         if resample is not None:
             warnings.warn(
                 "Argument resample is deprecated and will be removed since v0.10.0. Please, use interpolation instead"
@@ -1545,6 +1576,7 @@ class Grayscale(torch.nn.Module):
     def __init__(self, num_output_channels=1):
         super().__init__()
+        _log_api_usage_once(self)
         self.num_output_channels = num_output_channels
     def forward(self, img):
@@ -1579,6 +1611,7 @@ class RandomGrayscale(torch.nn.Module):
     def __init__(self, p=0.1):
         super().__init__()
+        _log_api_usage_once(self)
         self.p = p
     def forward(self, img):
@@ -1628,6 +1661,7 @@ class RandomErasing(torch.nn.Module):
     def __init__(self, p=0.5, scale=(0.02, 0.33), ratio=(0.3, 3.3), value=0, inplace=False):
         super().__init__()
+        _log_api_usage_once(self)
         if not isinstance(value, (numbers.Number, str, tuple, list)):
             raise TypeError("Argument value should be either a number or str or a sequence")
         if isinstance(value, str) and value != "random":
@@ -1751,6 +1785,7 @@ class GaussianBlur(torch.nn.Module):
     def __init__(self, kernel_size, sigma=(0.1, 2.0)):
         super().__init__()
+        _log_api_usage_once(self)
         self.kernel_size = _setup_size(kernel_size, "Kernel size should be a tuple/list of two integers")
         for ks in self.kernel_size:
             if ks <= 0 or ks % 2 == 0:
@@ -1842,6 +1877,7 @@ class RandomInvert(torch.nn.Module):
     def __init__(self, p=0.5):
         super().__init__()
+        _log_api_usage_once(self)
         self.p = p
     def forward(self, img):
@@ -1873,6 +1909,7 @@ class RandomPosterize(torch.nn.Module):
     def __init__(self, bits, p=0.5):
         super().__init__()
+        _log_api_usage_once(self)
         self.bits = bits
         self.p = p
@@ -1905,6 +1942,7 @@ class RandomSolarize(torch.nn.Module):
     def __init__(self, threshold, p=0.5):
         super().__init__()
+        _log_api_usage_once(self)
         self.threshold = threshold
         self.p = p
@@ -1937,6 +1975,7 @@ class RandomAdjustSharpness(torch.nn.Module):
     def __init__(self, sharpness_factor, p=0.5):
         super().__init__()
+        _log_api_usage_once(self)
         self.sharpness_factor = sharpness_factor
         self.p = p
@@ -1968,6 +2007,7 @@ class RandomAutocontrast(torch.nn.Module):
     def __init__(self, p=0.5):
         super().__init__()
+        _log_api_usage_once(self)
         self.p = p
     def forward(self, img):
@@ -1998,6 +2038,7 @@ class RandomEqualize(torch.nn.Module):
     def __init__(self, p=0.5):
         super().__init__()
+        _log_api_usage_once(self)
         self.p = p
     def forward(self, img):
...
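
For reference, _log_api_usage_once itself lives in torchvision.utils and is a thin wrapper over PyTorch's built-in API-usage hook, torch._C._log_api_usage_once, which is a no-op unless the embedding application registers a usage handler. A rough approximation of what such a wrapper does (the real helper in torchvision/utils.py may differ in its details):

from types import FunctionType
from typing import Any

import torch


def _log_api_usage_once(obj: Any) -> None:
    # Derive a qualified name such as "torchvision.transforms.functional.resize"
    # (functions are passed directly) or "torchvision.transforms.transforms.Resize"
    # (class-based transforms pass self), then emit it through PyTorch's hook.
    name = obj.__name__ if isinstance(obj, FunctionType) else type(obj).__name__
    torch._C._log_api_usage_once(f"{obj.__module__}.{name}")
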