Unverified Commit 511924c1 authored by Philip Meier's avatar Philip Meier Committed by GitHub
Browse files

add usage logging to prototype dispatchers / kernels (#7012)

parent c65d57a5
...@@ -57,6 +57,9 @@ class KernelInfo(InfoBase): ...@@ -57,6 +57,9 @@ class KernelInfo(InfoBase):
# structure, but with adapted parameters. This is useful in case a parameter value is closely tied to the input # structure, but with adapted parameters. This is useful in case a parameter value is closely tied to the input
# dtype. # dtype.
float32_vs_uint8=False, float32_vs_uint8=False,
# Some kernels don't have dispatchers that would handle logging the usage. Thus, the kernel has to do it
# manually. If set, triggers a test that makes sure this happens.
logs_usage=False,
# See InfoBase # See InfoBase
test_marks=None, test_marks=None,
# See InfoBase # See InfoBase
...@@ -71,6 +74,7 @@ class KernelInfo(InfoBase): ...@@ -71,6 +74,7 @@ class KernelInfo(InfoBase):
if float32_vs_uint8 and not callable(float32_vs_uint8): if float32_vs_uint8 and not callable(float32_vs_uint8):
float32_vs_uint8 = lambda other_args, kwargs: (other_args, kwargs) # noqa: E731 float32_vs_uint8 = lambda other_args, kwargs: (other_args, kwargs) # noqa: E731
self.float32_vs_uint8 = float32_vs_uint8 self.float32_vs_uint8 = float32_vs_uint8
self.logs_usage = logs_usage
def _pixel_difference_closeness_kwargs(uint8_atol, *, dtype=torch.uint8, mae=False): def _pixel_difference_closeness_kwargs(uint8_atol, *, dtype=torch.uint8, mae=False):
...@@ -675,6 +679,7 @@ KERNEL_INFOS.append( ...@@ -675,6 +679,7 @@ KERNEL_INFOS.append(
sample_inputs_fn=sample_inputs_convert_format_bounding_box, sample_inputs_fn=sample_inputs_convert_format_bounding_box,
reference_fn=reference_convert_format_bounding_box, reference_fn=reference_convert_format_bounding_box,
reference_inputs_fn=reference_inputs_convert_format_bounding_box, reference_inputs_fn=reference_inputs_convert_format_bounding_box,
logs_usage=True,
), ),
) )
...@@ -2100,6 +2105,7 @@ KERNEL_INFOS.append( ...@@ -2100,6 +2105,7 @@ KERNEL_INFOS.append(
KernelInfo( KernelInfo(
F.clamp_bounding_box, F.clamp_bounding_box,
sample_inputs_fn=sample_inputs_clamp_bounding_box, sample_inputs_fn=sample_inputs_clamp_bounding_box,
logs_usage=True,
) )
) )
......
...@@ -108,6 +108,19 @@ class TestKernels: ...@@ -108,6 +108,19 @@ class TestKernels:
args_kwargs_fn=lambda info: info.reference_inputs_fn(), args_kwargs_fn=lambda info: info.reference_inputs_fn(),
) )
@make_info_args_kwargs_parametrization(
[info for info in KERNEL_INFOS if info.logs_usage],
args_kwargs_fn=lambda info: info.sample_inputs_fn(),
)
@pytest.mark.parametrize("device", cpu_and_gpu())
def test_logging(self, spy_on, info, args_kwargs, device):
spy = spy_on(torch._C._log_api_usage_once)
args, kwargs = args_kwargs.load(device)
info.kernel(*args, **kwargs)
spy.assert_any_call(f"{info.kernel.__module__}.{info.id}")
@ignore_jit_warning_no_profile @ignore_jit_warning_no_profile
@sample_inputs @sample_inputs
@pytest.mark.parametrize("device", cpu_and_gpu()) @pytest.mark.parametrize("device", cpu_and_gpu())
...@@ -291,6 +304,19 @@ class TestDispatchers: ...@@ -291,6 +304,19 @@ class TestDispatchers:
args_kwargs_fn=lambda info: info.sample_inputs(datapoints.Image), args_kwargs_fn=lambda info: info.sample_inputs(datapoints.Image),
) )
@make_info_args_kwargs_parametrization(
DISPATCHER_INFOS,
args_kwargs_fn=lambda info: info.sample_inputs(),
)
@pytest.mark.parametrize("device", cpu_and_gpu())
def test_logging(self, spy_on, info, args_kwargs, device):
spy = spy_on(torch._C._log_api_usage_once)
args, kwargs = args_kwargs.load(device)
info.dispatcher(*args, **kwargs)
spy.assert_any_call(f"{info.dispatcher.__module__}.{info.id}")
@ignore_jit_warning_no_profile @ignore_jit_warning_no_profile
@image_sample_inputs @image_sample_inputs
@pytest.mark.parametrize("device", cpu_and_gpu()) @pytest.mark.parametrize("device", cpu_and_gpu())
......
...@@ -5,6 +5,7 @@ import PIL.Image ...@@ -5,6 +5,7 @@ import PIL.Image
import torch import torch
from torchvision.prototype import datapoints from torchvision.prototype import datapoints
from torchvision.transforms.functional import pil_to_tensor, to_pil_image from torchvision.transforms.functional import pil_to_tensor, to_pil_image
from torchvision.utils import _log_api_usage_once
def erase_image_tensor( def erase_image_tensor(
...@@ -41,6 +42,9 @@ def erase( ...@@ -41,6 +42,9 @@ def erase(
v: torch.Tensor, v: torch.Tensor,
inplace: bool = False, inplace: bool = False,
) -> Union[datapoints.ImageTypeJIT, datapoints.VideoTypeJIT]: ) -> Union[datapoints.ImageTypeJIT, datapoints.VideoTypeJIT]:
if not torch.jit.is_scripting():
_log_api_usage_once(erase)
if isinstance(inpt, torch.Tensor) and ( if isinstance(inpt, torch.Tensor) and (
torch.jit.is_scripting() or not isinstance(inpt, (datapoints.Image, datapoints.Video)) torch.jit.is_scripting() or not isinstance(inpt, (datapoints.Image, datapoints.Video))
): ):
......
...@@ -5,6 +5,8 @@ from torchvision.prototype import datapoints ...@@ -5,6 +5,8 @@ from torchvision.prototype import datapoints
from torchvision.transforms import functional_pil as _FP from torchvision.transforms import functional_pil as _FP
from torchvision.transforms.functional_tensor import _max_value from torchvision.transforms.functional_tensor import _max_value
from torchvision.utils import _log_api_usage_once
from ._meta import _num_value_bits, _rgb_to_gray, convert_dtype_image_tensor from ._meta import _num_value_bits, _rgb_to_gray, convert_dtype_image_tensor
...@@ -38,6 +40,9 @@ def adjust_brightness_video(video: torch.Tensor, brightness_factor: float) -> to ...@@ -38,6 +40,9 @@ def adjust_brightness_video(video: torch.Tensor, brightness_factor: float) -> to
def adjust_brightness(inpt: datapoints.InputTypeJIT, brightness_factor: float) -> datapoints.InputTypeJIT: def adjust_brightness(inpt: datapoints.InputTypeJIT, brightness_factor: float) -> datapoints.InputTypeJIT:
if not torch.jit.is_scripting():
_log_api_usage_once(adjust_brightness)
if isinstance(inpt, torch.Tensor) and ( if isinstance(inpt, torch.Tensor) and (
torch.jit.is_scripting() or not isinstance(inpt, datapoints._datapoint.Datapoint) torch.jit.is_scripting() or not isinstance(inpt, datapoints._datapoint.Datapoint)
): ):
...@@ -79,6 +84,9 @@ def adjust_saturation_video(video: torch.Tensor, saturation_factor: float) -> to ...@@ -79,6 +84,9 @@ def adjust_saturation_video(video: torch.Tensor, saturation_factor: float) -> to
def adjust_saturation(inpt: datapoints.InputTypeJIT, saturation_factor: float) -> datapoints.InputTypeJIT: def adjust_saturation(inpt: datapoints.InputTypeJIT, saturation_factor: float) -> datapoints.InputTypeJIT:
if not torch.jit.is_scripting():
_log_api_usage_once(adjust_saturation)
if isinstance(inpt, torch.Tensor) and ( if isinstance(inpt, torch.Tensor) and (
torch.jit.is_scripting() or not isinstance(inpt, datapoints._datapoint.Datapoint) torch.jit.is_scripting() or not isinstance(inpt, datapoints._datapoint.Datapoint)
): ):
...@@ -120,6 +128,9 @@ def adjust_contrast_video(video: torch.Tensor, contrast_factor: float) -> torch. ...@@ -120,6 +128,9 @@ def adjust_contrast_video(video: torch.Tensor, contrast_factor: float) -> torch.
def adjust_contrast(inpt: datapoints.InputTypeJIT, contrast_factor: float) -> datapoints.InputTypeJIT: def adjust_contrast(inpt: datapoints.InputTypeJIT, contrast_factor: float) -> datapoints.InputTypeJIT:
if not torch.jit.is_scripting():
_log_api_usage_once(adjust_contrast)
if isinstance(inpt, torch.Tensor) and ( if isinstance(inpt, torch.Tensor) and (
torch.jit.is_scripting() or not isinstance(inpt, datapoints._datapoint.Datapoint) torch.jit.is_scripting() or not isinstance(inpt, datapoints._datapoint.Datapoint)
): ):
...@@ -195,6 +206,9 @@ def adjust_sharpness_video(video: torch.Tensor, sharpness_factor: float) -> torc ...@@ -195,6 +206,9 @@ def adjust_sharpness_video(video: torch.Tensor, sharpness_factor: float) -> torc
def adjust_sharpness(inpt: datapoints.InputTypeJIT, sharpness_factor: float) -> datapoints.InputTypeJIT: def adjust_sharpness(inpt: datapoints.InputTypeJIT, sharpness_factor: float) -> datapoints.InputTypeJIT:
if not torch.jit.is_scripting():
_log_api_usage_once(adjust_sharpness)
if isinstance(inpt, torch.Tensor) and ( if isinstance(inpt, torch.Tensor) and (
torch.jit.is_scripting() or not isinstance(inpt, datapoints._datapoint.Datapoint) torch.jit.is_scripting() or not isinstance(inpt, datapoints._datapoint.Datapoint)
): ):
...@@ -309,6 +323,9 @@ def adjust_hue_video(video: torch.Tensor, hue_factor: float) -> torch.Tensor: ...@@ -309,6 +323,9 @@ def adjust_hue_video(video: torch.Tensor, hue_factor: float) -> torch.Tensor:
def adjust_hue(inpt: datapoints.InputTypeJIT, hue_factor: float) -> datapoints.InputTypeJIT: def adjust_hue(inpt: datapoints.InputTypeJIT, hue_factor: float) -> datapoints.InputTypeJIT:
if not torch.jit.is_scripting():
_log_api_usage_once(adjust_hue)
if isinstance(inpt, torch.Tensor) and ( if isinstance(inpt, torch.Tensor) and (
torch.jit.is_scripting() or not isinstance(inpt, datapoints._datapoint.Datapoint) torch.jit.is_scripting() or not isinstance(inpt, datapoints._datapoint.Datapoint)
): ):
...@@ -351,6 +368,9 @@ def adjust_gamma_video(video: torch.Tensor, gamma: float, gain: float = 1) -> to ...@@ -351,6 +368,9 @@ def adjust_gamma_video(video: torch.Tensor, gamma: float, gain: float = 1) -> to
def adjust_gamma(inpt: datapoints.InputTypeJIT, gamma: float, gain: float = 1) -> datapoints.InputTypeJIT: def adjust_gamma(inpt: datapoints.InputTypeJIT, gamma: float, gain: float = 1) -> datapoints.InputTypeJIT:
if not torch.jit.is_scripting():
_log_api_usage_once(adjust_gamma)
if isinstance(inpt, torch.Tensor) and ( if isinstance(inpt, torch.Tensor) and (
torch.jit.is_scripting() or not isinstance(inpt, datapoints._datapoint.Datapoint) torch.jit.is_scripting() or not isinstance(inpt, datapoints._datapoint.Datapoint)
): ):
...@@ -387,6 +407,9 @@ def posterize_video(video: torch.Tensor, bits: int) -> torch.Tensor: ...@@ -387,6 +407,9 @@ def posterize_video(video: torch.Tensor, bits: int) -> torch.Tensor:
def posterize(inpt: datapoints.InputTypeJIT, bits: int) -> datapoints.InputTypeJIT: def posterize(inpt: datapoints.InputTypeJIT, bits: int) -> datapoints.InputTypeJIT:
if not torch.jit.is_scripting():
_log_api_usage_once(posterize)
if isinstance(inpt, torch.Tensor) and ( if isinstance(inpt, torch.Tensor) and (
torch.jit.is_scripting() or not isinstance(inpt, datapoints._datapoint.Datapoint) torch.jit.is_scripting() or not isinstance(inpt, datapoints._datapoint.Datapoint)
): ):
...@@ -417,6 +440,9 @@ def solarize_video(video: torch.Tensor, threshold: float) -> torch.Tensor: ...@@ -417,6 +440,9 @@ def solarize_video(video: torch.Tensor, threshold: float) -> torch.Tensor:
def solarize(inpt: datapoints.InputTypeJIT, threshold: float) -> datapoints.InputTypeJIT: def solarize(inpt: datapoints.InputTypeJIT, threshold: float) -> datapoints.InputTypeJIT:
if not torch.jit.is_scripting():
_log_api_usage_once(solarize)
if isinstance(inpt, torch.Tensor) and ( if isinstance(inpt, torch.Tensor) and (
torch.jit.is_scripting() or not isinstance(inpt, datapoints._datapoint.Datapoint) torch.jit.is_scripting() or not isinstance(inpt, datapoints._datapoint.Datapoint)
): ):
...@@ -469,6 +495,9 @@ def autocontrast_video(video: torch.Tensor) -> torch.Tensor: ...@@ -469,6 +495,9 @@ def autocontrast_video(video: torch.Tensor) -> torch.Tensor:
def autocontrast(inpt: datapoints.InputTypeJIT) -> datapoints.InputTypeJIT: def autocontrast(inpt: datapoints.InputTypeJIT) -> datapoints.InputTypeJIT:
if not torch.jit.is_scripting():
_log_api_usage_once(autocontrast)
if isinstance(inpt, torch.Tensor) and ( if isinstance(inpt, torch.Tensor) and (
torch.jit.is_scripting() or not isinstance(inpt, datapoints._datapoint.Datapoint) torch.jit.is_scripting() or not isinstance(inpt, datapoints._datapoint.Datapoint)
): ):
...@@ -561,6 +590,9 @@ def equalize_video(video: torch.Tensor) -> torch.Tensor: ...@@ -561,6 +590,9 @@ def equalize_video(video: torch.Tensor) -> torch.Tensor:
def equalize(inpt: datapoints.InputTypeJIT) -> datapoints.InputTypeJIT: def equalize(inpt: datapoints.InputTypeJIT) -> datapoints.InputTypeJIT:
if not torch.jit.is_scripting():
_log_api_usage_once(equalize)
if isinstance(inpt, torch.Tensor) and ( if isinstance(inpt, torch.Tensor) and (
torch.jit.is_scripting() or not isinstance(inpt, datapoints._datapoint.Datapoint) torch.jit.is_scripting() or not isinstance(inpt, datapoints._datapoint.Datapoint)
): ):
...@@ -594,6 +626,9 @@ def invert_video(video: torch.Tensor) -> torch.Tensor: ...@@ -594,6 +626,9 @@ def invert_video(video: torch.Tensor) -> torch.Tensor:
def invert(inpt: datapoints.InputTypeJIT) -> datapoints.InputTypeJIT: def invert(inpt: datapoints.InputTypeJIT) -> datapoints.InputTypeJIT:
if not torch.jit.is_scripting():
_log_api_usage_once(invert)
if isinstance(inpt, torch.Tensor) and ( if isinstance(inpt, torch.Tensor) and (
torch.jit.is_scripting() or not isinstance(inpt, datapoints._datapoint.Datapoint) torch.jit.is_scripting() or not isinstance(inpt, datapoints._datapoint.Datapoint)
): ):
......
...@@ -19,6 +19,8 @@ from torchvision.transforms.functional import ( ...@@ -19,6 +19,8 @@ from torchvision.transforms.functional import (
) )
from torchvision.transforms.functional_tensor import _pad_symmetric from torchvision.transforms.functional_tensor import _pad_symmetric
from torchvision.utils import _log_api_usage_once
from ._meta import convert_format_bounding_box, get_spatial_size_image_pil from ._meta import convert_format_bounding_box, get_spatial_size_image_pil
...@@ -55,6 +57,9 @@ def horizontal_flip_video(video: torch.Tensor) -> torch.Tensor: ...@@ -55,6 +57,9 @@ def horizontal_flip_video(video: torch.Tensor) -> torch.Tensor:
def horizontal_flip(inpt: datapoints.InputTypeJIT) -> datapoints.InputTypeJIT: def horizontal_flip(inpt: datapoints.InputTypeJIT) -> datapoints.InputTypeJIT:
if not torch.jit.is_scripting():
_log_api_usage_once(horizontal_flip)
if isinstance(inpt, torch.Tensor) and ( if isinstance(inpt, torch.Tensor) and (
torch.jit.is_scripting() or not isinstance(inpt, datapoints._datapoint.Datapoint) torch.jit.is_scripting() or not isinstance(inpt, datapoints._datapoint.Datapoint)
): ):
...@@ -103,6 +108,9 @@ def vertical_flip_video(video: torch.Tensor) -> torch.Tensor: ...@@ -103,6 +108,9 @@ def vertical_flip_video(video: torch.Tensor) -> torch.Tensor:
def vertical_flip(inpt: datapoints.InputTypeJIT) -> datapoints.InputTypeJIT: def vertical_flip(inpt: datapoints.InputTypeJIT) -> datapoints.InputTypeJIT:
if not torch.jit.is_scripting():
_log_api_usage_once(vertical_flip)
if isinstance(inpt, torch.Tensor) and ( if isinstance(inpt, torch.Tensor) and (
torch.jit.is_scripting() or not isinstance(inpt, datapoints._datapoint.Datapoint) torch.jit.is_scripting() or not isinstance(inpt, datapoints._datapoint.Datapoint)
): ):
...@@ -231,6 +239,8 @@ def resize( ...@@ -231,6 +239,8 @@ def resize(
max_size: Optional[int] = None, max_size: Optional[int] = None,
antialias: Optional[bool] = None, antialias: Optional[bool] = None,
) -> datapoints.InputTypeJIT: ) -> datapoints.InputTypeJIT:
if not torch.jit.is_scripting():
_log_api_usage_once(resize)
if isinstance(inpt, torch.Tensor) and ( if isinstance(inpt, torch.Tensor) and (
torch.jit.is_scripting() or not isinstance(inpt, datapoints._datapoint.Datapoint) torch.jit.is_scripting() or not isinstance(inpt, datapoints._datapoint.Datapoint)
): ):
...@@ -730,6 +740,9 @@ def affine( ...@@ -730,6 +740,9 @@ def affine(
fill: datapoints.FillTypeJIT = None, fill: datapoints.FillTypeJIT = None,
center: Optional[List[float]] = None, center: Optional[List[float]] = None,
) -> datapoints.InputTypeJIT: ) -> datapoints.InputTypeJIT:
if not torch.jit.is_scripting():
_log_api_usage_once(affine)
# TODO: consider deprecating integers from angle and shear on the future # TODO: consider deprecating integers from angle and shear on the future
if isinstance(inpt, torch.Tensor) and ( if isinstance(inpt, torch.Tensor) and (
torch.jit.is_scripting() or not isinstance(inpt, datapoints._datapoint.Datapoint) torch.jit.is_scripting() or not isinstance(inpt, datapoints._datapoint.Datapoint)
...@@ -913,6 +926,9 @@ def rotate( ...@@ -913,6 +926,9 @@ def rotate(
center: Optional[List[float]] = None, center: Optional[List[float]] = None,
fill: datapoints.FillTypeJIT = None, fill: datapoints.FillTypeJIT = None,
) -> datapoints.InputTypeJIT: ) -> datapoints.InputTypeJIT:
if not torch.jit.is_scripting():
_log_api_usage_once(rotate)
if isinstance(inpt, torch.Tensor) and ( if isinstance(inpt, torch.Tensor) and (
torch.jit.is_scripting() or not isinstance(inpt, datapoints._datapoint.Datapoint) torch.jit.is_scripting() or not isinstance(inpt, datapoints._datapoint.Datapoint)
): ):
...@@ -1120,6 +1136,9 @@ def pad( ...@@ -1120,6 +1136,9 @@ def pad(
fill: datapoints.FillTypeJIT = None, fill: datapoints.FillTypeJIT = None,
padding_mode: str = "constant", padding_mode: str = "constant",
) -> datapoints.InputTypeJIT: ) -> datapoints.InputTypeJIT:
if not torch.jit.is_scripting():
_log_api_usage_once(pad)
if isinstance(inpt, torch.Tensor) and ( if isinstance(inpt, torch.Tensor) and (
torch.jit.is_scripting() or not isinstance(inpt, datapoints._datapoint.Datapoint) torch.jit.is_scripting() or not isinstance(inpt, datapoints._datapoint.Datapoint)
): ):
...@@ -1197,6 +1216,9 @@ def crop_video(video: torch.Tensor, top: int, left: int, height: int, width: int ...@@ -1197,6 +1216,9 @@ def crop_video(video: torch.Tensor, top: int, left: int, height: int, width: int
def crop(inpt: datapoints.InputTypeJIT, top: int, left: int, height: int, width: int) -> datapoints.InputTypeJIT: def crop(inpt: datapoints.InputTypeJIT, top: int, left: int, height: int, width: int) -> datapoints.InputTypeJIT:
if not torch.jit.is_scripting():
_log_api_usage_once(crop)
if isinstance(inpt, torch.Tensor) and ( if isinstance(inpt, torch.Tensor) and (
torch.jit.is_scripting() or not isinstance(inpt, datapoints._datapoint.Datapoint) torch.jit.is_scripting() or not isinstance(inpt, datapoints._datapoint.Datapoint)
): ):
...@@ -1452,6 +1474,8 @@ def perspective( ...@@ -1452,6 +1474,8 @@ def perspective(
fill: datapoints.FillTypeJIT = None, fill: datapoints.FillTypeJIT = None,
coefficients: Optional[List[float]] = None, coefficients: Optional[List[float]] = None,
) -> datapoints.InputTypeJIT: ) -> datapoints.InputTypeJIT:
if not torch.jit.is_scripting():
_log_api_usage_once(perspective)
if isinstance(inpt, torch.Tensor) and ( if isinstance(inpt, torch.Tensor) and (
torch.jit.is_scripting() or not isinstance(inpt, datapoints._datapoint.Datapoint) torch.jit.is_scripting() or not isinstance(inpt, datapoints._datapoint.Datapoint)
): ):
...@@ -1612,6 +1636,9 @@ def elastic( ...@@ -1612,6 +1636,9 @@ def elastic(
interpolation: InterpolationMode = InterpolationMode.BILINEAR, interpolation: InterpolationMode = InterpolationMode.BILINEAR,
fill: datapoints.FillTypeJIT = None, fill: datapoints.FillTypeJIT = None,
) -> datapoints.InputTypeJIT: ) -> datapoints.InputTypeJIT:
if not torch.jit.is_scripting():
_log_api_usage_once(elastic)
if isinstance(inpt, torch.Tensor) and ( if isinstance(inpt, torch.Tensor) and (
torch.jit.is_scripting() or not isinstance(inpt, datapoints._datapoint.Datapoint) torch.jit.is_scripting() or not isinstance(inpt, datapoints._datapoint.Datapoint)
): ):
...@@ -1724,6 +1751,9 @@ def center_crop_video(video: torch.Tensor, output_size: List[int]) -> torch.Tens ...@@ -1724,6 +1751,9 @@ def center_crop_video(video: torch.Tensor, output_size: List[int]) -> torch.Tens
def center_crop(inpt: datapoints.InputTypeJIT, output_size: List[int]) -> datapoints.InputTypeJIT: def center_crop(inpt: datapoints.InputTypeJIT, output_size: List[int]) -> datapoints.InputTypeJIT:
if not torch.jit.is_scripting():
_log_api_usage_once(center_crop)
if isinstance(inpt, torch.Tensor) and ( if isinstance(inpt, torch.Tensor) and (
torch.jit.is_scripting() or not isinstance(inpt, datapoints._datapoint.Datapoint) torch.jit.is_scripting() or not isinstance(inpt, datapoints._datapoint.Datapoint)
): ):
...@@ -1817,6 +1847,9 @@ def resized_crop( ...@@ -1817,6 +1847,9 @@ def resized_crop(
interpolation: InterpolationMode = InterpolationMode.BILINEAR, interpolation: InterpolationMode = InterpolationMode.BILINEAR,
antialias: Optional[bool] = None, antialias: Optional[bool] = None,
) -> datapoints.InputTypeJIT: ) -> datapoints.InputTypeJIT:
if not torch.jit.is_scripting():
_log_api_usage_once(resized_crop)
if isinstance(inpt, torch.Tensor) and ( if isinstance(inpt, torch.Tensor) and (
torch.jit.is_scripting() or not isinstance(inpt, datapoints._datapoint.Datapoint) torch.jit.is_scripting() or not isinstance(inpt, datapoints._datapoint.Datapoint)
): ):
...@@ -1897,6 +1930,9 @@ ImageOrVideoTypeJIT = Union[datapoints.ImageTypeJIT, datapoints.VideoTypeJIT] ...@@ -1897,6 +1930,9 @@ ImageOrVideoTypeJIT = Union[datapoints.ImageTypeJIT, datapoints.VideoTypeJIT]
def five_crop( def five_crop(
inpt: ImageOrVideoTypeJIT, size: List[int] inpt: ImageOrVideoTypeJIT, size: List[int]
) -> Tuple[ImageOrVideoTypeJIT, ImageOrVideoTypeJIT, ImageOrVideoTypeJIT, ImageOrVideoTypeJIT, ImageOrVideoTypeJIT]: ) -> Tuple[ImageOrVideoTypeJIT, ImageOrVideoTypeJIT, ImageOrVideoTypeJIT, ImageOrVideoTypeJIT, ImageOrVideoTypeJIT]:
if not torch.jit.is_scripting():
_log_api_usage_once(five_crop)
# TODO: consider breaking BC here to return List[datapoints.ImageTypeJIT/VideoTypeJIT] to align this op with # TODO: consider breaking BC here to return List[datapoints.ImageTypeJIT/VideoTypeJIT] to align this op with
# `ten_crop` # `ten_crop`
if isinstance(inpt, torch.Tensor) and ( if isinstance(inpt, torch.Tensor) and (
...@@ -1952,6 +1988,9 @@ def ten_crop_video(video: torch.Tensor, size: List[int], vertical_flip: bool = F ...@@ -1952,6 +1988,9 @@ def ten_crop_video(video: torch.Tensor, size: List[int], vertical_flip: bool = F
def ten_crop( def ten_crop(
inpt: Union[datapoints.ImageTypeJIT, datapoints.VideoTypeJIT], size: List[int], vertical_flip: bool = False inpt: Union[datapoints.ImageTypeJIT, datapoints.VideoTypeJIT], size: List[int], vertical_flip: bool = False
) -> Union[List[datapoints.ImageTypeJIT], List[datapoints.VideoTypeJIT]]: ) -> Union[List[datapoints.ImageTypeJIT], List[datapoints.VideoTypeJIT]]:
if not torch.jit.is_scripting():
_log_api_usage_once(ten_crop)
if isinstance(inpt, torch.Tensor) and ( if isinstance(inpt, torch.Tensor) and (
torch.jit.is_scripting() or not isinstance(inpt, (datapoints.Image, datapoints.Video)) torch.jit.is_scripting() or not isinstance(inpt, (datapoints.Image, datapoints.Video))
): ):
......
...@@ -7,6 +7,8 @@ from torchvision.prototype.datapoints import BoundingBoxFormat, ColorSpace ...@@ -7,6 +7,8 @@ from torchvision.prototype.datapoints import BoundingBoxFormat, ColorSpace
from torchvision.transforms import functional_pil as _FP from torchvision.transforms import functional_pil as _FP
from torchvision.transforms.functional_tensor import _max_value from torchvision.transforms.functional_tensor import _max_value
from torchvision.utils import _log_api_usage_once
def get_dimensions_image_tensor(image: torch.Tensor) -> List[int]: def get_dimensions_image_tensor(image: torch.Tensor) -> List[int]:
chw = list(image.shape[-3:]) chw = list(image.shape[-3:])
...@@ -24,6 +26,9 @@ get_dimensions_image_pil = _FP.get_dimensions ...@@ -24,6 +26,9 @@ get_dimensions_image_pil = _FP.get_dimensions
def get_dimensions(inpt: Union[datapoints.ImageTypeJIT, datapoints.VideoTypeJIT]) -> List[int]: def get_dimensions(inpt: Union[datapoints.ImageTypeJIT, datapoints.VideoTypeJIT]) -> List[int]:
if not torch.jit.is_scripting():
_log_api_usage_once(get_dimensions)
if isinstance(inpt, torch.Tensor) and ( if isinstance(inpt, torch.Tensor) and (
torch.jit.is_scripting() or not isinstance(inpt, (datapoints.Image, datapoints.Video)) torch.jit.is_scripting() or not isinstance(inpt, (datapoints.Image, datapoints.Video))
): ):
...@@ -60,6 +65,9 @@ def get_num_channels_video(video: torch.Tensor) -> int: ...@@ -60,6 +65,9 @@ def get_num_channels_video(video: torch.Tensor) -> int:
def get_num_channels(inpt: Union[datapoints.ImageTypeJIT, datapoints.VideoTypeJIT]) -> int: def get_num_channels(inpt: Union[datapoints.ImageTypeJIT, datapoints.VideoTypeJIT]) -> int:
if not torch.jit.is_scripting():
_log_api_usage_once(get_num_channels)
if isinstance(inpt, torch.Tensor) and ( if isinstance(inpt, torch.Tensor) and (
torch.jit.is_scripting() or not isinstance(inpt, (datapoints.Image, datapoints.Video)) torch.jit.is_scripting() or not isinstance(inpt, (datapoints.Image, datapoints.Video))
): ):
...@@ -109,6 +117,9 @@ def get_spatial_size_bounding_box(bounding_box: datapoints.BoundingBox) -> List[ ...@@ -109,6 +117,9 @@ def get_spatial_size_bounding_box(bounding_box: datapoints.BoundingBox) -> List[
def get_spatial_size(inpt: datapoints.InputTypeJIT) -> List[int]: def get_spatial_size(inpt: datapoints.InputTypeJIT) -> List[int]:
if not torch.jit.is_scripting():
_log_api_usage_once(get_spatial_size)
if isinstance(inpt, torch.Tensor) and ( if isinstance(inpt, torch.Tensor) and (
torch.jit.is_scripting() or not isinstance(inpt, datapoints._datapoint.Datapoint) torch.jit.is_scripting() or not isinstance(inpt, datapoints._datapoint.Datapoint)
): ):
...@@ -129,6 +140,9 @@ def get_num_frames_video(video: torch.Tensor) -> int: ...@@ -129,6 +140,9 @@ def get_num_frames_video(video: torch.Tensor) -> int:
def get_num_frames(inpt: datapoints.VideoTypeJIT) -> int: def get_num_frames(inpt: datapoints.VideoTypeJIT) -> int:
if not torch.jit.is_scripting():
_log_api_usage_once(get_num_frames)
if isinstance(inpt, torch.Tensor) and (torch.jit.is_scripting() or not isinstance(inpt, datapoints.Video)): if isinstance(inpt, torch.Tensor) and (torch.jit.is_scripting() or not isinstance(inpt, datapoints.Video)):
return get_num_frames_video(inpt) return get_num_frames_video(inpt)
elif isinstance(inpt, datapoints.Video): elif isinstance(inpt, datapoints.Video):
...@@ -179,6 +193,9 @@ def _xyxy_to_cxcywh(xyxy: torch.Tensor, inplace: bool) -> torch.Tensor: ...@@ -179,6 +193,9 @@ def _xyxy_to_cxcywh(xyxy: torch.Tensor, inplace: bool) -> torch.Tensor:
def convert_format_bounding_box( def convert_format_bounding_box(
bounding_box: torch.Tensor, old_format: BoundingBoxFormat, new_format: BoundingBoxFormat, inplace: bool = False bounding_box: torch.Tensor, old_format: BoundingBoxFormat, new_format: BoundingBoxFormat, inplace: bool = False
) -> torch.Tensor: ) -> torch.Tensor:
if not torch.jit.is_scripting():
_log_api_usage_once(convert_format_bounding_box)
if new_format == old_format: if new_format == old_format:
return bounding_box return bounding_box
...@@ -199,6 +216,9 @@ def convert_format_bounding_box( ...@@ -199,6 +216,9 @@ def convert_format_bounding_box(
def clamp_bounding_box( def clamp_bounding_box(
bounding_box: torch.Tensor, format: BoundingBoxFormat, spatial_size: Tuple[int, int] bounding_box: torch.Tensor, format: BoundingBoxFormat, spatial_size: Tuple[int, int]
) -> torch.Tensor: ) -> torch.Tensor:
if not torch.jit.is_scripting():
_log_api_usage_once(clamp_bounding_box)
# TODO: Investigate if it makes sense from a performance perspective to have an implementation for every # TODO: Investigate if it makes sense from a performance perspective to have an implementation for every
# BoundingBoxFormat instead of converting back and forth # BoundingBoxFormat instead of converting back and forth
xyxy_boxes = convert_format_bounding_box( xyxy_boxes = convert_format_bounding_box(
...@@ -313,6 +333,9 @@ def convert_color_space( ...@@ -313,6 +333,9 @@ def convert_color_space(
color_space: ColorSpace, color_space: ColorSpace,
old_color_space: Optional[ColorSpace] = None, old_color_space: Optional[ColorSpace] = None,
) -> Union[datapoints.ImageTypeJIT, datapoints.VideoTypeJIT]: ) -> Union[datapoints.ImageTypeJIT, datapoints.VideoTypeJIT]:
if not torch.jit.is_scripting():
_log_api_usage_once(convert_color_space)
if isinstance(inpt, torch.Tensor) and ( if isinstance(inpt, torch.Tensor) and (
torch.jit.is_scripting() or not isinstance(inpt, (datapoints.Image, datapoints.Video)) torch.jit.is_scripting() or not isinstance(inpt, (datapoints.Image, datapoints.Video))
): ):
...@@ -417,6 +440,9 @@ def convert_dtype_video(video: torch.Tensor, dtype: torch.dtype = torch.float) - ...@@ -417,6 +440,9 @@ def convert_dtype_video(video: torch.Tensor, dtype: torch.dtype = torch.float) -
def convert_dtype( def convert_dtype(
inpt: Union[datapoints.ImageTypeJIT, datapoints.VideoTypeJIT], dtype: torch.dtype = torch.float inpt: Union[datapoints.ImageTypeJIT, datapoints.VideoTypeJIT], dtype: torch.dtype = torch.float
) -> torch.Tensor: ) -> torch.Tensor:
if not torch.jit.is_scripting():
_log_api_usage_once(convert_dtype)
if isinstance(inpt, torch.Tensor) and ( if isinstance(inpt, torch.Tensor) and (
torch.jit.is_scripting() or not isinstance(inpt, (datapoints.Image, datapoints.Video)) torch.jit.is_scripting() or not isinstance(inpt, (datapoints.Image, datapoints.Video))
): ):
......
...@@ -8,6 +8,8 @@ from torch.nn.functional import conv2d, pad as torch_pad ...@@ -8,6 +8,8 @@ from torch.nn.functional import conv2d, pad as torch_pad
from torchvision.prototype import datapoints from torchvision.prototype import datapoints
from torchvision.transforms.functional import pil_to_tensor, to_pil_image from torchvision.transforms.functional import pil_to_tensor, to_pil_image
from torchvision.utils import _log_api_usage_once
from ..utils import is_simple_tensor from ..utils import is_simple_tensor
...@@ -57,6 +59,8 @@ def normalize( ...@@ -57,6 +59,8 @@ def normalize(
inplace: bool = False, inplace: bool = False,
) -> torch.Tensor: ) -> torch.Tensor:
if not torch.jit.is_scripting(): if not torch.jit.is_scripting():
_log_api_usage_once(normalize)
if is_simple_tensor(inpt) or isinstance(inpt, (datapoints.Image, datapoints.Video)): if is_simple_tensor(inpt) or isinstance(inpt, (datapoints.Image, datapoints.Video)):
inpt = inpt.as_subclass(torch.Tensor) inpt = inpt.as_subclass(torch.Tensor)
else: else:
...@@ -168,6 +172,9 @@ def gaussian_blur_video( ...@@ -168,6 +172,9 @@ def gaussian_blur_video(
def gaussian_blur( def gaussian_blur(
inpt: datapoints.InputTypeJIT, kernel_size: List[int], sigma: Optional[List[float]] = None inpt: datapoints.InputTypeJIT, kernel_size: List[int], sigma: Optional[List[float]] = None
) -> datapoints.InputTypeJIT: ) -> datapoints.InputTypeJIT:
if not torch.jit.is_scripting():
_log_api_usage_once(gaussian_blur)
if isinstance(inpt, torch.Tensor) and ( if isinstance(inpt, torch.Tensor) and (
torch.jit.is_scripting() or not isinstance(inpt, datapoints._datapoint.Datapoint) torch.jit.is_scripting() or not isinstance(inpt, datapoints._datapoint.Datapoint)
): ):
......
...@@ -2,6 +2,8 @@ import torch ...@@ -2,6 +2,8 @@ import torch
from torchvision.prototype import datapoints from torchvision.prototype import datapoints
from torchvision.utils import _log_api_usage_once
def uniform_temporal_subsample_video(video: torch.Tensor, num_samples: int, temporal_dim: int = -4) -> torch.Tensor: def uniform_temporal_subsample_video(video: torch.Tensor, num_samples: int, temporal_dim: int = -4) -> torch.Tensor:
# Reference: https://github.com/facebookresearch/pytorchvideo/blob/a0a131e/pytorchvideo/transforms/functional.py#L19 # Reference: https://github.com/facebookresearch/pytorchvideo/blob/a0a131e/pytorchvideo/transforms/functional.py#L19
...@@ -13,6 +15,9 @@ def uniform_temporal_subsample_video(video: torch.Tensor, num_samples: int, temp ...@@ -13,6 +15,9 @@ def uniform_temporal_subsample_video(video: torch.Tensor, num_samples: int, temp
def uniform_temporal_subsample( def uniform_temporal_subsample(
inpt: datapoints.VideoTypeJIT, num_samples: int, temporal_dim: int = -4 inpt: datapoints.VideoTypeJIT, num_samples: int, temporal_dim: int = -4
) -> datapoints.VideoTypeJIT: ) -> datapoints.VideoTypeJIT:
if not torch.jit.is_scripting():
_log_api_usage_once(uniform_temporal_subsample)
if isinstance(inpt, torch.Tensor) and (torch.jit.is_scripting() or not isinstance(inpt, datapoints.Video)): if isinstance(inpt, torch.Tensor) and (torch.jit.is_scripting() or not isinstance(inpt, datapoints.Video)):
return uniform_temporal_subsample_video(inpt, num_samples, temporal_dim=temporal_dim) return uniform_temporal_subsample_video(inpt, num_samples, temporal_dim=temporal_dim)
elif isinstance(inpt, datapoints.Video): elif isinstance(inpt, datapoints.Video):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment