Unverified commit 511924c1 authored by Philip Meier, committed by GitHub

add usage logging to prototype dispatchers / kernels (#7012)

parent c65d57a5
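
The change applies one pattern throughout: every public dispatcher in the prototype functional API, plus the two bounding-box kernels that have no dispatcher, records usage through torchvision.utils._log_api_usage_once, guarded so that the call only runs in eager mode. A minimal sketch of the pattern (my_op is a hypothetical op used for illustration, not part of this change):

import torch
from torchvision.utils import _log_api_usage_once


def my_op(inpt: torch.Tensor) -> torch.Tensor:  # hypothetical example op
    # torch.jit.is_scripting() is True under TorchScript, so this branch is
    # dead code there and the unscriptable logging call never enters the
    # scripted graph; in eager mode it records the usage of the op.
    if not torch.jit.is_scripting():
        _log_api_usage_once(my_op)
    return inpt.flip(-1)  # stand-in for dispatching to the real kernels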
@@ -57,6 +57,9 @@ class KernelInfo(InfoBase):
         # structure, but with adapted parameters. This is useful in case a parameter value is closely tied to the input
         # dtype.
         float32_vs_uint8=False,
+        # Some kernels don't have dispatchers that would handle logging the usage. Thus, the kernel has to do it
+        # manually. If set, triggers a test that makes sure this happens.
+        logs_usage=False,
         # See InfoBase
         test_marks=None,
         # See InfoBase
@@ -71,6 +74,7 @@ class KernelInfo(InfoBase):
         if float32_vs_uint8 and not callable(float32_vs_uint8):
             float32_vs_uint8 = lambda other_args, kwargs: (other_args, kwargs)  # noqa: E731
         self.float32_vs_uint8 = float32_vs_uint8
+        self.logs_usage = logs_usage


 def _pixel_difference_closeness_kwargs(uint8_atol, *, dtype=torch.uint8, mae=False):
@@ -675,6 +679,7 @@ KERNEL_INFOS.append(
         sample_inputs_fn=sample_inputs_convert_format_bounding_box,
         reference_fn=reference_convert_format_bounding_box,
         reference_inputs_fn=reference_inputs_convert_format_bounding_box,
+        logs_usage=True,
     ),
 )
@@ -2100,6 +2105,7 @@ KERNEL_INFOS.append(
     KernelInfo(
         F.clamp_bounding_box,
         sample_inputs_fn=sample_inputs_clamp_bounding_box,
+        logs_usage=True,
     )
 )
@@ -108,6 +108,19 @@ class TestKernels:
         args_kwargs_fn=lambda info: info.reference_inputs_fn(),
     )

+    @make_info_args_kwargs_parametrization(
+        [info for info in KERNEL_INFOS if info.logs_usage],
+        args_kwargs_fn=lambda info: info.sample_inputs_fn(),
+    )
+    @pytest.mark.parametrize("device", cpu_and_gpu())
+    def test_logging(self, spy_on, info, args_kwargs, device):
+        spy = spy_on(torch._C._log_api_usage_once)
+
+        args, kwargs = args_kwargs.load(device)
+        info.kernel(*args, **kwargs)
+
+        spy.assert_any_call(f"{info.kernel.__module__}.{info.id}")
+
     @ignore_jit_warning_no_profile
     @sample_inputs
     @pytest.mark.parametrize("device", cpu_and_gpu())
@@ -291,6 +304,19 @@ class TestDispatchers:
         args_kwargs_fn=lambda info: info.sample_inputs(datapoints.Image),
     )

+    @make_info_args_kwargs_parametrization(
+        DISPATCHER_INFOS,
+        args_kwargs_fn=lambda info: info.sample_inputs(),
+    )
+    @pytest.mark.parametrize("device", cpu_and_gpu())
+    def test_logging(self, spy_on, info, args_kwargs, device):
+        spy = spy_on(torch._C._log_api_usage_once)
+
+        args, kwargs = args_kwargs.load(device)
+        info.dispatcher(*args, **kwargs)
+
+        spy.assert_any_call(f"{info.dispatcher.__module__}.{info.id}")
+
     @ignore_jit_warning_no_profile
     @image_sample_inputs
     @pytest.mark.parametrize("device", cpu_and_gpu())
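Both test_logging tests rely on a spy_on fixture defined elsewhere in the test suite. A rough sketch of what such a fixture could look like (the real implementation lives in the tests' conftest and may differ), built on unittest.mock and pytest's monkeypatch:

import importlib
from unittest import mock

import pytest


@pytest.fixture
def spy_on(monkeypatch):
    def factory(target):
        # Wrap the target in a MagicMock that records calls while still
        # delegating to the original, and patch it onto its defining module.
        module = importlib.import_module(target.__module__)
        spy = mock.MagicMock(wraps=target)
        monkeypatch.setattr(module, target.__name__, spy)
        return spy

    return factory

The assertion in both tests then matches the f"{module}.{name}" string that torchvision.utils._log_api_usage_once forwards to torch._C._log_api_usage_once.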
@@ -5,6 +5,7 @@ import PIL.Image
 import torch
 from torchvision.prototype import datapoints
 from torchvision.transforms.functional import pil_to_tensor, to_pil_image
+from torchvision.utils import _log_api_usage_once


 def erase_image_tensor(
@@ -41,6 +42,9 @@ def erase(
     v: torch.Tensor,
     inplace: bool = False,
 ) -> Union[datapoints.ImageTypeJIT, datapoints.VideoTypeJIT]:
+    if not torch.jit.is_scripting():
+        _log_api_usage_once(erase)
+
     if isinstance(inpt, torch.Tensor) and (
         torch.jit.is_scripting() or not isinstance(inpt, (datapoints.Image, datapoints.Video))
     ):
@@ -5,6 +5,8 @@ from torchvision.prototype import datapoints
 from torchvision.transforms import functional_pil as _FP
 from torchvision.transforms.functional_tensor import _max_value
+from torchvision.utils import _log_api_usage_once
+
 from ._meta import _num_value_bits, _rgb_to_gray, convert_dtype_image_tensor
@@ -38,6 +40,9 @@ def adjust_brightness_video(video: torch.Tensor, brightness_factor: float) -> torch.Tensor:
 def adjust_brightness(inpt: datapoints.InputTypeJIT, brightness_factor: float) -> datapoints.InputTypeJIT:
+    if not torch.jit.is_scripting():
+        _log_api_usage_once(adjust_brightness)
+
     if isinstance(inpt, torch.Tensor) and (
         torch.jit.is_scripting() or not isinstance(inpt, datapoints._datapoint.Datapoint)
     ):
@@ -79,6 +84,9 @@ def adjust_saturation_video(video: torch.Tensor, saturation_factor: float) -> torch.Tensor:
 def adjust_saturation(inpt: datapoints.InputTypeJIT, saturation_factor: float) -> datapoints.InputTypeJIT:
+    if not torch.jit.is_scripting():
+        _log_api_usage_once(adjust_saturation)
+
     if isinstance(inpt, torch.Tensor) and (
         torch.jit.is_scripting() or not isinstance(inpt, datapoints._datapoint.Datapoint)
     ):
@@ -120,6 +128,9 @@ def adjust_contrast_video(video: torch.Tensor, contrast_factor: float) -> torch.Tensor:
 def adjust_contrast(inpt: datapoints.InputTypeJIT, contrast_factor: float) -> datapoints.InputTypeJIT:
+    if not torch.jit.is_scripting():
+        _log_api_usage_once(adjust_contrast)
+
     if isinstance(inpt, torch.Tensor) and (
         torch.jit.is_scripting() or not isinstance(inpt, datapoints._datapoint.Datapoint)
     ):
@@ -195,6 +206,9 @@ def adjust_sharpness_video(video: torch.Tensor, sharpness_factor: float) -> torch.Tensor:
 def adjust_sharpness(inpt: datapoints.InputTypeJIT, sharpness_factor: float) -> datapoints.InputTypeJIT:
+    if not torch.jit.is_scripting():
+        _log_api_usage_once(adjust_sharpness)
+
     if isinstance(inpt, torch.Tensor) and (
         torch.jit.is_scripting() or not isinstance(inpt, datapoints._datapoint.Datapoint)
     ):
@@ -309,6 +323,9 @@ def adjust_hue_video(video: torch.Tensor, hue_factor: float) -> torch.Tensor:
 def adjust_hue(inpt: datapoints.InputTypeJIT, hue_factor: float) -> datapoints.InputTypeJIT:
+    if not torch.jit.is_scripting():
+        _log_api_usage_once(adjust_hue)
+
     if isinstance(inpt, torch.Tensor) and (
         torch.jit.is_scripting() or not isinstance(inpt, datapoints._datapoint.Datapoint)
     ):
@@ -351,6 +368,9 @@ def adjust_gamma_video(video: torch.Tensor, gamma: float, gain: float = 1) -> torch.Tensor:
 def adjust_gamma(inpt: datapoints.InputTypeJIT, gamma: float, gain: float = 1) -> datapoints.InputTypeJIT:
+    if not torch.jit.is_scripting():
+        _log_api_usage_once(adjust_gamma)
+
     if isinstance(inpt, torch.Tensor) and (
         torch.jit.is_scripting() or not isinstance(inpt, datapoints._datapoint.Datapoint)
     ):
@@ -387,6 +407,9 @@ def posterize_video(video: torch.Tensor, bits: int) -> torch.Tensor:
 def posterize(inpt: datapoints.InputTypeJIT, bits: int) -> datapoints.InputTypeJIT:
+    if not torch.jit.is_scripting():
+        _log_api_usage_once(posterize)
+
     if isinstance(inpt, torch.Tensor) and (
         torch.jit.is_scripting() or not isinstance(inpt, datapoints._datapoint.Datapoint)
     ):
@@ -417,6 +440,9 @@ def solarize_video(video: torch.Tensor, threshold: float) -> torch.Tensor:
 def solarize(inpt: datapoints.InputTypeJIT, threshold: float) -> datapoints.InputTypeJIT:
+    if not torch.jit.is_scripting():
+        _log_api_usage_once(solarize)
+
     if isinstance(inpt, torch.Tensor) and (
         torch.jit.is_scripting() or not isinstance(inpt, datapoints._datapoint.Datapoint)
     ):
@@ -469,6 +495,9 @@ def autocontrast_video(video: torch.Tensor) -> torch.Tensor:
 def autocontrast(inpt: datapoints.InputTypeJIT) -> datapoints.InputTypeJIT:
+    if not torch.jit.is_scripting():
+        _log_api_usage_once(autocontrast)
+
     if isinstance(inpt, torch.Tensor) and (
         torch.jit.is_scripting() or not isinstance(inpt, datapoints._datapoint.Datapoint)
     ):
@@ -561,6 +590,9 @@ def equalize_video(video: torch.Tensor) -> torch.Tensor:
 def equalize(inpt: datapoints.InputTypeJIT) -> datapoints.InputTypeJIT:
+    if not torch.jit.is_scripting():
+        _log_api_usage_once(equalize)
+
     if isinstance(inpt, torch.Tensor) and (
         torch.jit.is_scripting() or not isinstance(inpt, datapoints._datapoint.Datapoint)
     ):
@@ -594,6 +626,9 @@ def invert_video(video: torch.Tensor) -> torch.Tensor:
 def invert(inpt: datapoints.InputTypeJIT) -> datapoints.InputTypeJIT:
+    if not torch.jit.is_scripting():
+        _log_api_usage_once(invert)
+
     if isinstance(inpt, torch.Tensor) and (
         torch.jit.is_scripting() or not isinstance(inpt, datapoints._datapoint.Datapoint)
     ):
@@ -19,6 +19,8 @@ from torchvision.transforms.functional import (
 )
 from torchvision.transforms.functional_tensor import _pad_symmetric
+from torchvision.utils import _log_api_usage_once
+
 from ._meta import convert_format_bounding_box, get_spatial_size_image_pil
@@ -55,6 +57,9 @@ def horizontal_flip_video(video: torch.Tensor) -> torch.Tensor:
 def horizontal_flip(inpt: datapoints.InputTypeJIT) -> datapoints.InputTypeJIT:
+    if not torch.jit.is_scripting():
+        _log_api_usage_once(horizontal_flip)
+
     if isinstance(inpt, torch.Tensor) and (
         torch.jit.is_scripting() or not isinstance(inpt, datapoints._datapoint.Datapoint)
     ):
@@ -103,6 +108,9 @@ def vertical_flip_video(video: torch.Tensor) -> torch.Tensor:
 def vertical_flip(inpt: datapoints.InputTypeJIT) -> datapoints.InputTypeJIT:
+    if not torch.jit.is_scripting():
+        _log_api_usage_once(vertical_flip)
+
     if isinstance(inpt, torch.Tensor) and (
         torch.jit.is_scripting() or not isinstance(inpt, datapoints._datapoint.Datapoint)
     ):
@@ -231,6 +239,8 @@ def resize(
     max_size: Optional[int] = None,
     antialias: Optional[bool] = None,
 ) -> datapoints.InputTypeJIT:
+    if not torch.jit.is_scripting():
+        _log_api_usage_once(resize)
     if isinstance(inpt, torch.Tensor) and (
         torch.jit.is_scripting() or not isinstance(inpt, datapoints._datapoint.Datapoint)
     ):
@@ -730,6 +740,9 @@ def affine(
     fill: datapoints.FillTypeJIT = None,
     center: Optional[List[float]] = None,
 ) -> datapoints.InputTypeJIT:
+    if not torch.jit.is_scripting():
+        _log_api_usage_once(affine)
+
     # TODO: consider deprecating integers from angle and shear on the future
     if isinstance(inpt, torch.Tensor) and (
         torch.jit.is_scripting() or not isinstance(inpt, datapoints._datapoint.Datapoint)
@@ -913,6 +926,9 @@ def rotate(
     center: Optional[List[float]] = None,
     fill: datapoints.FillTypeJIT = None,
 ) -> datapoints.InputTypeJIT:
+    if not torch.jit.is_scripting():
+        _log_api_usage_once(rotate)
+
     if isinstance(inpt, torch.Tensor) and (
         torch.jit.is_scripting() or not isinstance(inpt, datapoints._datapoint.Datapoint)
     ):
@@ -1120,6 +1136,9 @@ def pad(
     fill: datapoints.FillTypeJIT = None,
     padding_mode: str = "constant",
 ) -> datapoints.InputTypeJIT:
+    if not torch.jit.is_scripting():
+        _log_api_usage_once(pad)
+
     if isinstance(inpt, torch.Tensor) and (
         torch.jit.is_scripting() or not isinstance(inpt, datapoints._datapoint.Datapoint)
     ):
@@ -1197,6 +1216,9 @@ def crop_video(video: torch.Tensor, top: int, left: int, height: int, width: int) -> torch.Tensor:
 def crop(inpt: datapoints.InputTypeJIT, top: int, left: int, height: int, width: int) -> datapoints.InputTypeJIT:
+    if not torch.jit.is_scripting():
+        _log_api_usage_once(crop)
+
     if isinstance(inpt, torch.Tensor) and (
         torch.jit.is_scripting() or not isinstance(inpt, datapoints._datapoint.Datapoint)
     ):
@@ -1452,6 +1474,8 @@ def perspective(
     fill: datapoints.FillTypeJIT = None,
     coefficients: Optional[List[float]] = None,
 ) -> datapoints.InputTypeJIT:
+    if not torch.jit.is_scripting():
+        _log_api_usage_once(perspective)
     if isinstance(inpt, torch.Tensor) and (
         torch.jit.is_scripting() or not isinstance(inpt, datapoints._datapoint.Datapoint)
     ):
@@ -1612,6 +1636,9 @@ def elastic(
     interpolation: InterpolationMode = InterpolationMode.BILINEAR,
     fill: datapoints.FillTypeJIT = None,
 ) -> datapoints.InputTypeJIT:
+    if not torch.jit.is_scripting():
+        _log_api_usage_once(elastic)
+
     if isinstance(inpt, torch.Tensor) and (
         torch.jit.is_scripting() or not isinstance(inpt, datapoints._datapoint.Datapoint)
     ):
@@ -1724,6 +1751,9 @@ def center_crop_video(video: torch.Tensor, output_size: List[int]) -> torch.Tensor:
 def center_crop(inpt: datapoints.InputTypeJIT, output_size: List[int]) -> datapoints.InputTypeJIT:
+    if not torch.jit.is_scripting():
+        _log_api_usage_once(center_crop)
+
     if isinstance(inpt, torch.Tensor) and (
         torch.jit.is_scripting() or not isinstance(inpt, datapoints._datapoint.Datapoint)
     ):
@@ -1817,6 +1847,9 @@ def resized_crop(
     interpolation: InterpolationMode = InterpolationMode.BILINEAR,
     antialias: Optional[bool] = None,
 ) -> datapoints.InputTypeJIT:
+    if not torch.jit.is_scripting():
+        _log_api_usage_once(resized_crop)
+
     if isinstance(inpt, torch.Tensor) and (
         torch.jit.is_scripting() or not isinstance(inpt, datapoints._datapoint.Datapoint)
     ):
@@ -1897,6 +1930,9 @@ ImageOrVideoTypeJIT = Union[datapoints.ImageTypeJIT, datapoints.VideoTypeJIT]
 def five_crop(
     inpt: ImageOrVideoTypeJIT, size: List[int]
 ) -> Tuple[ImageOrVideoTypeJIT, ImageOrVideoTypeJIT, ImageOrVideoTypeJIT, ImageOrVideoTypeJIT, ImageOrVideoTypeJIT]:
+    if not torch.jit.is_scripting():
+        _log_api_usage_once(five_crop)
+
     # TODO: consider breaking BC here to return List[datapoints.ImageTypeJIT/VideoTypeJIT] to align this op with
     # `ten_crop`
     if isinstance(inpt, torch.Tensor) and (
@@ -1952,6 +1988,9 @@ def ten_crop_video(video: torch.Tensor, size: List[int], vertical_flip: bool = False) -> List[torch.Tensor]:
 def ten_crop(
     inpt: Union[datapoints.ImageTypeJIT, datapoints.VideoTypeJIT], size: List[int], vertical_flip: bool = False
 ) -> Union[List[datapoints.ImageTypeJIT], List[datapoints.VideoTypeJIT]]:
+    if not torch.jit.is_scripting():
+        _log_api_usage_once(ten_crop)
+
     if isinstance(inpt, torch.Tensor) and (
         torch.jit.is_scripting() or not isinstance(inpt, (datapoints.Image, datapoints.Video))
     ):
@@ -7,6 +7,8 @@ from torchvision.prototype.datapoints import BoundingBoxFormat, ColorSpace
 from torchvision.transforms import functional_pil as _FP
 from torchvision.transforms.functional_tensor import _max_value
+from torchvision.utils import _log_api_usage_once
+

 def get_dimensions_image_tensor(image: torch.Tensor) -> List[int]:
     chw = list(image.shape[-3:])
@@ -24,6 +26,9 @@ get_dimensions_image_pil = _FP.get_dimensions
 def get_dimensions(inpt: Union[datapoints.ImageTypeJIT, datapoints.VideoTypeJIT]) -> List[int]:
+    if not torch.jit.is_scripting():
+        _log_api_usage_once(get_dimensions)
+
     if isinstance(inpt, torch.Tensor) and (
         torch.jit.is_scripting() or not isinstance(inpt, (datapoints.Image, datapoints.Video))
     ):
@@ -60,6 +65,9 @@ def get_num_channels_video(video: torch.Tensor) -> int:
 def get_num_channels(inpt: Union[datapoints.ImageTypeJIT, datapoints.VideoTypeJIT]) -> int:
+    if not torch.jit.is_scripting():
+        _log_api_usage_once(get_num_channels)
+
     if isinstance(inpt, torch.Tensor) and (
         torch.jit.is_scripting() or not isinstance(inpt, (datapoints.Image, datapoints.Video))
     ):
@@ -109,6 +117,9 @@ def get_spatial_size_bounding_box(bounding_box: datapoints.BoundingBox) -> List[int]:
 def get_spatial_size(inpt: datapoints.InputTypeJIT) -> List[int]:
+    if not torch.jit.is_scripting():
+        _log_api_usage_once(get_spatial_size)
+
     if isinstance(inpt, torch.Tensor) and (
         torch.jit.is_scripting() or not isinstance(inpt, datapoints._datapoint.Datapoint)
     ):
@@ -129,6 +140,9 @@ def get_num_frames_video(video: torch.Tensor) -> int:
 def get_num_frames(inpt: datapoints.VideoTypeJIT) -> int:
+    if not torch.jit.is_scripting():
+        _log_api_usage_once(get_num_frames)
+
     if isinstance(inpt, torch.Tensor) and (torch.jit.is_scripting() or not isinstance(inpt, datapoints.Video)):
         return get_num_frames_video(inpt)
     elif isinstance(inpt, datapoints.Video):
@@ -179,6 +193,9 @@ def _xyxy_to_cxcywh(xyxy: torch.Tensor, inplace: bool) -> torch.Tensor:
 def convert_format_bounding_box(
     bounding_box: torch.Tensor, old_format: BoundingBoxFormat, new_format: BoundingBoxFormat, inplace: bool = False
 ) -> torch.Tensor:
+    if not torch.jit.is_scripting():
+        _log_api_usage_once(convert_format_bounding_box)
+
     if new_format == old_format:
         return bounding_box
@@ -199,6 +216,9 @@ def convert_format_bounding_box(
 def clamp_bounding_box(
     bounding_box: torch.Tensor, format: BoundingBoxFormat, spatial_size: Tuple[int, int]
 ) -> torch.Tensor:
+    if not torch.jit.is_scripting():
+        _log_api_usage_once(clamp_bounding_box)
+
     # TODO: Investigate if it makes sense from a performance perspective to have an implementation for every
     # BoundingBoxFormat instead of converting back and forth
     xyxy_boxes = convert_format_bounding_box(
@@ -313,6 +333,9 @@ def convert_color_space(
     color_space: ColorSpace,
     old_color_space: Optional[ColorSpace] = None,
 ) -> Union[datapoints.ImageTypeJIT, datapoints.VideoTypeJIT]:
+    if not torch.jit.is_scripting():
+        _log_api_usage_once(convert_color_space)
+
     if isinstance(inpt, torch.Tensor) and (
         torch.jit.is_scripting() or not isinstance(inpt, (datapoints.Image, datapoints.Video))
     ):
@@ -417,6 +440,9 @@ def convert_dtype_video(video: torch.Tensor, dtype: torch.dtype = torch.float) -> torch.Tensor:
 def convert_dtype(
     inpt: Union[datapoints.ImageTypeJIT, datapoints.VideoTypeJIT], dtype: torch.dtype = torch.float
 ) -> torch.Tensor:
+    if not torch.jit.is_scripting():
+        _log_api_usage_once(convert_dtype)
+
     if isinstance(inpt, torch.Tensor) and (
         torch.jit.is_scripting() or not isinstance(inpt, (datapoints.Image, datapoints.Video))
     ):
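convert_format_bounding_box and clamp_bounding_box are the two kernels flagged with logs_usage=True earlier: they operate on plain tensors and have no dispatcher to log for them, so they call _log_api_usage_once themselves. A quick eager-mode check of the effect (a sketch only; it assumes the prototype namespace torchvision.prototype.transforms.functional used on this branch):

from unittest import mock

import torch
from torchvision.prototype.datapoints import BoundingBoxFormat
from torchvision.prototype.transforms import functional as F

# Intercept the low-level logging hook, then call the kernel directly.
with mock.patch.object(torch._C, "_log_api_usage_once") as spy:
    boxes = torch.tensor([[0.0, 0.0, 10.0, 10.0]])
    F.convert_format_bounding_box(boxes, BoundingBoxFormat.XYXY, BoundingBoxFormat.XYWH)

# The logged key is the kernel's fully qualified name.
spy.assert_any_call(f"{F.convert_format_bounding_box.__module__}.convert_format_bounding_box")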
@@ -8,6 +8,8 @@ from torch.nn.functional import conv2d, pad as torch_pad
 from torchvision.prototype import datapoints
 from torchvision.transforms.functional import pil_to_tensor, to_pil_image
+from torchvision.utils import _log_api_usage_once
+
 from ..utils import is_simple_tensor
@@ -57,6 +59,8 @@ def normalize(
     inplace: bool = False,
 ) -> torch.Tensor:
+    if not torch.jit.is_scripting():
+        _log_api_usage_once(normalize)
     if is_simple_tensor(inpt) or isinstance(inpt, (datapoints.Image, datapoints.Video)):
         inpt = inpt.as_subclass(torch.Tensor)
     else:
@@ -168,6 +172,9 @@ def gaussian_blur_video(
 def gaussian_blur(
     inpt: datapoints.InputTypeJIT, kernel_size: List[int], sigma: Optional[List[float]] = None
 ) -> datapoints.InputTypeJIT:
+    if not torch.jit.is_scripting():
+        _log_api_usage_once(gaussian_blur)
+
     if isinstance(inpt, torch.Tensor) and (
         torch.jit.is_scripting() or not isinstance(inpt, datapoints._datapoint.Datapoint)
     ):
@@ -2,6 +2,8 @@ import torch
 from torchvision.prototype import datapoints
+from torchvision.utils import _log_api_usage_once
+

 def uniform_temporal_subsample_video(video: torch.Tensor, num_samples: int, temporal_dim: int = -4) -> torch.Tensor:
     # Reference: https://github.com/facebookresearch/pytorchvideo/blob/a0a131e/pytorchvideo/transforms/functional.py#L19
@@ -13,6 +15,9 @@ def uniform_temporal_subsample_video(video: torch.Tensor, num_samples: int, temporal_dim: int = -4) -> torch.Tensor:
 def uniform_temporal_subsample(
     inpt: datapoints.VideoTypeJIT, num_samples: int, temporal_dim: int = -4
 ) -> datapoints.VideoTypeJIT:
+    if not torch.jit.is_scripting():
+        _log_api_usage_once(uniform_temporal_subsample)
+
     if isinstance(inpt, torch.Tensor) and (torch.jit.is_scripting() or not isinstance(inpt, datapoints.Video)):
         return uniform_temporal_subsample_video(inpt, num_samples, temporal_dim=temporal_dim)
     elif isinstance(inpt, datapoints.Video):