Unverified Commit 55f7faf3 authored by Kai Zhang's avatar Kai Zhang Committed by GitHub
Browse files

Add api usage log to transforms (#5007)

* add api usage log to functional transforms

* add log to transforms

* fix for scriptablity

* skip Compose for scriptability

* follow the new policy

* torchscriptbility

* adopt new API

* make Compose scriptable

* move from __call__ to __init__
parent da7680f0
......@@ -14,6 +14,7 @@ try:
except ImportError:
accimage = None
from ..utils import _log_api_usage_once
from . import functional_pil as F_pil
from . import functional_tensor as F_t
......@@ -67,6 +68,8 @@ def get_image_size(img: Tensor) -> List[int]:
Returns:
List[int]: The image size.
"""
if not torch.jit.is_scripting() and not torch.jit.is_tracing():
_log_api_usage_once(get_image_size)
if isinstance(img, torch.Tensor):
return F_t.get_image_size(img)
......@@ -82,6 +85,8 @@ def get_image_num_channels(img: Tensor) -> int:
Returns:
int: The number of channels.
"""
if not torch.jit.is_scripting() and not torch.jit.is_tracing():
_log_api_usage_once(get_image_num_channels)
if isinstance(img, torch.Tensor):
return F_t.get_image_num_channels(img)
......@@ -110,6 +115,8 @@ def to_tensor(pic):
Returns:
Tensor: Converted image.
"""
if not torch.jit.is_scripting() and not torch.jit.is_tracing():
_log_api_usage_once(to_tensor)
if not (F_pil._is_pil_image(pic) or _is_numpy(pic)):
raise TypeError(f"pic should be PIL Image or ndarray. Got {type(pic)}")
......@@ -166,6 +173,8 @@ def pil_to_tensor(pic):
Returns:
Tensor: Converted image.
"""
if not torch.jit.is_scripting() and not torch.jit.is_tracing():
_log_api_usage_once(pil_to_tensor)
if not F_pil._is_pil_image(pic):
raise TypeError(f"pic should be PIL Image. Got {type(pic)}")
......@@ -205,6 +214,8 @@ def convert_image_dtype(image: torch.Tensor, dtype: torch.dtype = torch.float) -
overflow errors since the floating point ``dtype`` cannot store consecutive integers over the whole range
of the integer ``dtype``.
"""
if not torch.jit.is_scripting() and not torch.jit.is_tracing():
_log_api_usage_once(convert_image_dtype)
if not isinstance(image, torch.Tensor):
raise TypeError("Input img should be Tensor Image")
......@@ -225,6 +236,8 @@ def to_pil_image(pic, mode=None):
Returns:
PIL Image: Image converted to PIL Image.
"""
if not torch.jit.is_scripting() and not torch.jit.is_tracing():
_log_api_usage_once(to_pil_image)
if not (isinstance(pic, torch.Tensor) or isinstance(pic, np.ndarray)):
raise TypeError(f"pic should be Tensor or ndarray. Got {type(pic)}.")
......@@ -322,6 +335,8 @@ def normalize(tensor: Tensor, mean: List[float], std: List[float], inplace: bool
Returns:
Tensor: Normalized Tensor image.
"""
if not torch.jit.is_scripting() and not torch.jit.is_tracing():
_log_api_usage_once(normalize)
if not isinstance(tensor, torch.Tensor):
raise TypeError(f"Input tensor should be a torch tensor. Got {type(tensor)}.")
......@@ -401,6 +416,8 @@ def resize(
Returns:
PIL Image or Tensor: Resized image.
"""
if not torch.jit.is_scripting() and not torch.jit.is_tracing():
_log_api_usage_once(resize)
# Backward compatibility with integer value
if isinstance(interpolation, int):
warnings.warn(
......@@ -422,6 +439,8 @@ def resize(
def scale(*args, **kwargs):
if not torch.jit.is_scripting() and not torch.jit.is_tracing():
_log_api_usage_once(scale)
warnings.warn("The use of the transforms.Scale transform is deprecated, please use transforms.Resize instead.")
return resize(*args, **kwargs)
......@@ -467,6 +486,8 @@ def pad(img: Tensor, padding: List[int], fill: int = 0, padding_mode: str = "con
Returns:
PIL Image or Tensor: Padded image.
"""
if not torch.jit.is_scripting() and not torch.jit.is_tracing():
_log_api_usage_once(pad)
if not isinstance(img, torch.Tensor):
return F_pil.pad(img, padding=padding, fill=fill, padding_mode=padding_mode)
......@@ -490,6 +511,8 @@ def crop(img: Tensor, top: int, left: int, height: int, width: int) -> Tensor:
PIL Image or Tensor: Cropped image.
"""
if not torch.jit.is_scripting() and not torch.jit.is_tracing():
_log_api_usage_once(crop)
if not isinstance(img, torch.Tensor):
return F_pil.crop(img, top, left, height, width)
......@@ -510,6 +533,8 @@ def center_crop(img: Tensor, output_size: List[int]) -> Tensor:
Returns:
PIL Image or Tensor: Cropped image.
"""
if not torch.jit.is_scripting() and not torch.jit.is_tracing():
_log_api_usage_once(center_crop)
if isinstance(output_size, numbers.Number):
output_size = (int(output_size), int(output_size))
elif isinstance(output_size, (tuple, list)) and len(output_size) == 1:
......@@ -566,6 +591,8 @@ def resized_crop(
Returns:
PIL Image or Tensor: Cropped image.
"""
if not torch.jit.is_scripting() and not torch.jit.is_tracing():
_log_api_usage_once(resized_crop)
img = crop(img, top, left, height, width)
img = resize(img, size, interpolation)
return img
......@@ -583,6 +610,8 @@ def hflip(img: Tensor) -> Tensor:
Returns:
PIL Image or Tensor: Horizontally flipped image.
"""
if not torch.jit.is_scripting() and not torch.jit.is_tracing():
_log_api_usage_once(hflip)
if not isinstance(img, torch.Tensor):
return F_pil.hflip(img)
......@@ -648,6 +677,8 @@ def perspective(
Returns:
PIL Image or Tensor: transformed Image.
"""
if not torch.jit.is_scripting() and not torch.jit.is_tracing():
_log_api_usage_once(perspective)
coeffs = _get_perspective_coeffs(startpoints, endpoints)
......@@ -681,6 +712,8 @@ def vflip(img: Tensor) -> Tensor:
Returns:
PIL Image or Tensor: Vertically flipped image.
"""
if not torch.jit.is_scripting() and not torch.jit.is_tracing():
_log_api_usage_once(vflip)
if not isinstance(img, torch.Tensor):
return F_pil.vflip(img)
......@@ -706,6 +739,8 @@ def five_crop(img: Tensor, size: List[int]) -> Tuple[Tensor, Tensor, Tensor, Ten
tuple: tuple (tl, tr, bl, br, center)
Corresponding top left, top right, bottom left, bottom right and center crop.
"""
if not torch.jit.is_scripting() and not torch.jit.is_tracing():
_log_api_usage_once(five_crop)
if isinstance(size, numbers.Number):
size = (int(size), int(size))
elif isinstance(size, (tuple, list)) and len(size) == 1:
......@@ -753,6 +788,8 @@ def ten_crop(img: Tensor, size: List[int], vertical_flip: bool = False) -> List[
Corresponding top left, top right, bottom left, bottom right and
center crop and same for the flipped image.
"""
if not torch.jit.is_scripting() and not torch.jit.is_tracing():
_log_api_usage_once(ten_crop)
if isinstance(size, numbers.Number):
size = (int(size), int(size))
elif isinstance(size, (tuple, list)) and len(size) == 1:
......@@ -786,6 +823,8 @@ def adjust_brightness(img: Tensor, brightness_factor: float) -> Tensor:
Returns:
PIL Image or Tensor: Brightness adjusted image.
"""
if not torch.jit.is_scripting() and not torch.jit.is_tracing():
_log_api_usage_once(adjust_brightness)
if not isinstance(img, torch.Tensor):
return F_pil.adjust_brightness(img, brightness_factor)
......@@ -806,6 +845,8 @@ def adjust_contrast(img: Tensor, contrast_factor: float) -> Tensor:
Returns:
PIL Image or Tensor: Contrast adjusted image.
"""
if not torch.jit.is_scripting() and not torch.jit.is_tracing():
_log_api_usage_once(adjust_contrast)
if not isinstance(img, torch.Tensor):
return F_pil.adjust_contrast(img, contrast_factor)
......@@ -826,6 +867,8 @@ def adjust_saturation(img: Tensor, saturation_factor: float) -> Tensor:
Returns:
PIL Image or Tensor: Saturation adjusted image.
"""
if not torch.jit.is_scripting() and not torch.jit.is_tracing():
_log_api_usage_once(adjust_saturation)
if not isinstance(img, torch.Tensor):
return F_pil.adjust_saturation(img, saturation_factor)
......@@ -860,6 +903,8 @@ def adjust_hue(img: Tensor, hue_factor: float) -> Tensor:
Returns:
PIL Image or Tensor: Hue adjusted image.
"""
if not torch.jit.is_scripting() and not torch.jit.is_tracing():
_log_api_usage_once(adjust_hue)
if not isinstance(img, torch.Tensor):
return F_pil.adjust_hue(img, hue_factor)
......@@ -891,6 +936,8 @@ def adjust_gamma(img: Tensor, gamma: float, gain: float = 1) -> Tensor:
Returns:
PIL Image or Tensor: Gamma correction adjusted image.
"""
if not torch.jit.is_scripting() and not torch.jit.is_tracing():
_log_api_usage_once(adjust_gamma)
if not isinstance(img, torch.Tensor):
return F_pil.adjust_gamma(img, gamma, gain)
......@@ -987,6 +1034,8 @@ def rotate(
.. _filters: https://pillow.readthedocs.io/en/latest/handbook/concepts.html#filters
"""
if not torch.jit.is_scripting() and not torch.jit.is_tracing():
_log_api_usage_once(rotate)
if resample is not None:
warnings.warn(
"Argument resample is deprecated and will be removed since v0.10.0. Please, use interpolation instead"
......@@ -1067,6 +1116,8 @@ def affine(
Returns:
PIL Image or Tensor: Transformed image.
"""
if not torch.jit.is_scripting() and not torch.jit.is_tracing():
_log_api_usage_once(affine)
if resample is not None:
warnings.warn(
"Argument resample is deprecated and will be removed since v0.10.0. Please, use interpolation instead"
......@@ -1151,6 +1202,8 @@ def to_grayscale(img, num_output_channels=1):
- if num_output_channels = 1 : returned image is single channel
- if num_output_channels = 3 : returned image is 3 channel with r = g = b
"""
if not torch.jit.is_scripting() and not torch.jit.is_tracing():
_log_api_usage_once(to_grayscale)
if isinstance(img, Image.Image):
return F_pil.to_grayscale(img, num_output_channels)
......@@ -1176,6 +1229,8 @@ def rgb_to_grayscale(img: Tensor, num_output_channels: int = 1) -> Tensor:
- if num_output_channels = 1 : returned image is single channel
- if num_output_channels = 3 : returned image is 3 channel with r = g = b
"""
if not torch.jit.is_scripting() and not torch.jit.is_tracing():
_log_api_usage_once(rgb_to_grayscale)
if not isinstance(img, torch.Tensor):
return F_pil.to_grayscale(img, num_output_channels)
......@@ -1198,6 +1253,8 @@ def erase(img: Tensor, i: int, j: int, h: int, w: int, v: Tensor, inplace: bool
Returns:
Tensor Image: Erased image.
"""
if not torch.jit.is_scripting() and not torch.jit.is_tracing():
_log_api_usage_once(erase)
if not isinstance(img, torch.Tensor):
raise TypeError(f"img should be Tensor Image. Got {type(img)}")
......@@ -1234,6 +1291,8 @@ def gaussian_blur(img: Tensor, kernel_size: List[int], sigma: Optional[List[floa
Returns:
PIL Image or Tensor: Gaussian Blurred version of the image.
"""
if not torch.jit.is_scripting() and not torch.jit.is_tracing():
_log_api_usage_once(gaussian_blur)
if not isinstance(kernel_size, (int, list, tuple)):
raise TypeError(f"kernel_size should be int or a sequence of integers. Got {type(kernel_size)}")
if isinstance(kernel_size, int):
......@@ -1285,6 +1344,8 @@ def invert(img: Tensor) -> Tensor:
Returns:
PIL Image or Tensor: Color inverted image.
"""
if not torch.jit.is_scripting() and not torch.jit.is_tracing():
_log_api_usage_once(invert)
if not isinstance(img, torch.Tensor):
return F_pil.invert(img)
......@@ -1304,6 +1365,8 @@ def posterize(img: Tensor, bits: int) -> Tensor:
Returns:
PIL Image or Tensor: Posterized image.
"""
if not torch.jit.is_scripting() and not torch.jit.is_tracing():
_log_api_usage_once(posterize)
if not (0 <= bits <= 8):
raise ValueError(f"The number if bits should be between 0 and 8. Got {bits}")
......@@ -1325,6 +1388,8 @@ def solarize(img: Tensor, threshold: float) -> Tensor:
Returns:
PIL Image or Tensor: Solarized image.
"""
if not torch.jit.is_scripting() and not torch.jit.is_tracing():
_log_api_usage_once(solarize)
if not isinstance(img, torch.Tensor):
return F_pil.solarize(img, threshold)
......@@ -1345,6 +1410,8 @@ def adjust_sharpness(img: Tensor, sharpness_factor: float) -> Tensor:
Returns:
PIL Image or Tensor: Sharpness adjusted image.
"""
if not torch.jit.is_scripting() and not torch.jit.is_tracing():
_log_api_usage_once(adjust_sharpness)
if not isinstance(img, torch.Tensor):
return F_pil.adjust_sharpness(img, sharpness_factor)
......@@ -1365,6 +1432,8 @@ def autocontrast(img: Tensor) -> Tensor:
Returns:
PIL Image or Tensor: An image that was autocontrasted.
"""
if not torch.jit.is_scripting() and not torch.jit.is_tracing():
_log_api_usage_once(autocontrast)
if not isinstance(img, torch.Tensor):
return F_pil.autocontrast(img)
......@@ -1386,6 +1455,8 @@ def equalize(img: Tensor) -> Tensor:
Returns:
PIL Image or Tensor: An image that was equalized.
"""
if not torch.jit.is_scripting() and not torch.jit.is_tracing():
_log_api_usage_once(equalize)
if not isinstance(img, torch.Tensor):
return F_pil.equalize(img)
......
......@@ -13,6 +13,7 @@ try:
except ImportError:
accimage = None
from ..utils import _log_api_usage_once
from . import functional as F
from .functional import InterpolationMode, _interpolation_modes_from_int
......@@ -87,6 +88,8 @@ class Compose:
"""
def __init__(self, transforms):
if not torch.jit.is_scripting() and not torch.jit.is_tracing():
_log_api_usage_once(self)
self.transforms = transforms
def __call__(self, img):
......@@ -120,6 +123,9 @@ class ToTensor:
.. _references: https://github.com/pytorch/vision/tree/main/references/segmentation
"""
def __init__(self) -> None:
_log_api_usage_once(self)
def __call__(self, pic):
"""
Args:
......@@ -140,6 +146,9 @@ class PILToTensor:
Converts a PIL Image (H x W x C) to a Tensor of shape (C x H x W).
"""
def __init__(self) -> None:
_log_api_usage_once(self)
def __call__(self, pic):
"""
.. note::
......@@ -179,6 +188,7 @@ class ConvertImageDtype(torch.nn.Module):
def __init__(self, dtype: torch.dtype) -> None:
super().__init__()
_log_api_usage_once(self)
self.dtype = dtype
def forward(self, image):
......@@ -204,6 +214,7 @@ class ToPILImage:
"""
def __init__(self, mode=None):
_log_api_usage_once(self)
self.mode = mode
def __call__(self, pic):
......@@ -245,6 +256,7 @@ class Normalize(torch.nn.Module):
def __init__(self, mean, std, inplace=False):
super().__init__()
_log_api_usage_once(self)
self.mean = mean
self.std = std
self.inplace = inplace
......@@ -309,6 +321,7 @@ class Resize(torch.nn.Module):
def __init__(self, size, interpolation=InterpolationMode.BILINEAR, max_size=None, antialias=None):
super().__init__()
_log_api_usage_once(self)
if not isinstance(size, (int, Sequence)):
raise TypeError(f"Size should be int or sequence. Got {type(size)}")
if isinstance(size, Sequence) and len(size) not in (1, 2):
......@@ -350,6 +363,7 @@ class Scale(Resize):
def __init__(self, *args, **kwargs):
warnings.warn("The use of the transforms.Scale transform is deprecated, please use transforms.Resize instead.")
super().__init__(*args, **kwargs)
_log_api_usage_once(self)
class CenterCrop(torch.nn.Module):
......@@ -366,6 +380,7 @@ class CenterCrop(torch.nn.Module):
def __init__(self, size):
super().__init__()
_log_api_usage_once(self)
self.size = _setup_size(size, error_msg="Please provide only two dimensions (h, w) for size.")
def forward(self, img):
......@@ -422,6 +437,7 @@ class Pad(torch.nn.Module):
def __init__(self, padding, fill=0, padding_mode="constant"):
super().__init__()
_log_api_usage_once(self)
if not isinstance(padding, (numbers.Number, tuple, list)):
raise TypeError("Got inappropriate padding arg")
......@@ -462,6 +478,7 @@ class Lambda:
"""
def __init__(self, lambd):
_log_api_usage_once(self)
if not callable(lambd):
raise TypeError(f"Argument lambd should be callable, got {repr(type(lambd).__name__)}")
self.lambd = lambd
......@@ -481,6 +498,7 @@ class RandomTransforms:
"""
def __init__(self, transforms):
_log_api_usage_once(self)
if not isinstance(transforms, Sequence):
raise TypeError("Argument transforms should be a sequence")
self.transforms = transforms
......@@ -519,6 +537,7 @@ class RandomApply(torch.nn.Module):
def __init__(self, transforms, p=0.5):
super().__init__()
_log_api_usage_once(self)
self.transforms = transforms
self.p = p
......@@ -639,6 +658,7 @@ class RandomCrop(torch.nn.Module):
def __init__(self, size, padding=None, pad_if_needed=False, fill=0, padding_mode="constant"):
super().__init__()
_log_api_usage_once(self)
self.size = tuple(_setup_size(size, error_msg="Please provide only two dimensions (h, w) for size."))
......@@ -688,6 +708,7 @@ class RandomHorizontalFlip(torch.nn.Module):
def __init__(self, p=0.5):
super().__init__()
_log_api_usage_once(self)
self.p = p
def forward(self, img):
......@@ -718,6 +739,7 @@ class RandomVerticalFlip(torch.nn.Module):
def __init__(self, p=0.5):
super().__init__()
_log_api_usage_once(self)
self.p = p
def forward(self, img):
......@@ -755,6 +777,7 @@ class RandomPerspective(torch.nn.Module):
def __init__(self, distortion_scale=0.5, p=0.5, interpolation=InterpolationMode.BILINEAR, fill=0):
super().__init__()
_log_api_usage_once(self)
self.p = p
# Backward compatibility with integer value
......@@ -867,6 +890,7 @@ class RandomResizedCrop(torch.nn.Module):
def __init__(self, size, scale=(0.08, 1.0), ratio=(3.0 / 4.0, 4.0 / 3.0), interpolation=InterpolationMode.BILINEAR):
super().__init__()
_log_api_usage_once(self)
self.size = _setup_size(size, error_msg="Please provide only two dimensions (h, w) for size.")
if not isinstance(scale, Sequence):
......@@ -963,6 +987,7 @@ class RandomSizedCrop(RandomResizedCrop):
+ "please use transforms.RandomResizedCrop instead."
)
super().__init__(*args, **kwargs)
_log_api_usage_once(self)
class FiveCrop(torch.nn.Module):
......@@ -995,6 +1020,7 @@ class FiveCrop(torch.nn.Module):
def __init__(self, size):
super().__init__()
_log_api_usage_once(self)
self.size = _setup_size(size, error_msg="Please provide only two dimensions (h, w) for size.")
def forward(self, img):
......@@ -1043,6 +1069,7 @@ class TenCrop(torch.nn.Module):
def __init__(self, size, vertical_flip=False):
super().__init__()
_log_api_usage_once(self)
self.size = _setup_size(size, error_msg="Please provide only two dimensions (h, w) for size.")
self.vertical_flip = vertical_flip
......@@ -1081,6 +1108,7 @@ class LinearTransformation(torch.nn.Module):
def __init__(self, transformation_matrix, mean_vector):
super().__init__()
_log_api_usage_once(self)
if transformation_matrix.size(0) != transformation_matrix.size(1):
raise ValueError(
"transformation_matrix should be square. Got "
......@@ -1159,6 +1187,7 @@ class ColorJitter(torch.nn.Module):
def __init__(self, brightness=0, contrast=0, saturation=0, hue=0):
super().__init__()
_log_api_usage_once(self)
self.brightness = self._check_input(brightness, "brightness")
self.contrast = self._check_input(contrast, "contrast")
self.saturation = self._check_input(saturation, "saturation")
......@@ -1281,6 +1310,7 @@ class RandomRotation(torch.nn.Module):
self, degrees, interpolation=InterpolationMode.NEAREST, expand=False, center=None, fill=0, resample=None
):
super().__init__()
_log_api_usage_once(self)
if resample is not None:
warnings.warn(
"Argument resample is deprecated and will be removed since v0.10.0. Please, use interpolation instead"
......@@ -1401,6 +1431,7 @@ class RandomAffine(torch.nn.Module):
resample=None,
):
super().__init__()
_log_api_usage_once(self)
if resample is not None:
warnings.warn(
"Argument resample is deprecated and will be removed since v0.10.0. Please, use interpolation instead"
......@@ -1545,6 +1576,7 @@ class Grayscale(torch.nn.Module):
def __init__(self, num_output_channels=1):
super().__init__()
_log_api_usage_once(self)
self.num_output_channels = num_output_channels
def forward(self, img):
......@@ -1579,6 +1611,7 @@ class RandomGrayscale(torch.nn.Module):
def __init__(self, p=0.1):
super().__init__()
_log_api_usage_once(self)
self.p = p
def forward(self, img):
......@@ -1628,6 +1661,7 @@ class RandomErasing(torch.nn.Module):
def __init__(self, p=0.5, scale=(0.02, 0.33), ratio=(0.3, 3.3), value=0, inplace=False):
super().__init__()
_log_api_usage_once(self)
if not isinstance(value, (numbers.Number, str, tuple, list)):
raise TypeError("Argument value should be either a number or str or a sequence")
if isinstance(value, str) and value != "random":
......@@ -1751,6 +1785,7 @@ class GaussianBlur(torch.nn.Module):
def __init__(self, kernel_size, sigma=(0.1, 2.0)):
super().__init__()
_log_api_usage_once(self)
self.kernel_size = _setup_size(kernel_size, "Kernel size should be a tuple/list of two integers")
for ks in self.kernel_size:
if ks <= 0 or ks % 2 == 0:
......@@ -1842,6 +1877,7 @@ class RandomInvert(torch.nn.Module):
def __init__(self, p=0.5):
super().__init__()
_log_api_usage_once(self)
self.p = p
def forward(self, img):
......@@ -1873,6 +1909,7 @@ class RandomPosterize(torch.nn.Module):
def __init__(self, bits, p=0.5):
super().__init__()
_log_api_usage_once(self)
self.bits = bits
self.p = p
......@@ -1905,6 +1942,7 @@ class RandomSolarize(torch.nn.Module):
def __init__(self, threshold, p=0.5):
super().__init__()
_log_api_usage_once(self)
self.threshold = threshold
self.p = p
......@@ -1937,6 +1975,7 @@ class RandomAdjustSharpness(torch.nn.Module):
def __init__(self, sharpness_factor, p=0.5):
super().__init__()
_log_api_usage_once(self)
self.sharpness_factor = sharpness_factor
self.p = p
......@@ -1968,6 +2007,7 @@ class RandomAutocontrast(torch.nn.Module):
def __init__(self, p=0.5):
super().__init__()
_log_api_usage_once(self)
self.p = p
def forward(self, img):
......@@ -1998,6 +2038,7 @@ class RandomEqualize(torch.nn.Module):
def __init__(self, p=0.5):
super().__init__()
_log_api_usage_once(self)
self.p = p
def forward(self, img):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment