Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
vision
Commits
511924c1
Unverified
Commit
511924c1
authored
Dec 06, 2022
by
Philip Meier
Committed by
GitHub
Dec 06, 2022
Browse files
add usage logging to prototype dispatchers / kernels (#7012)
parent
c65d57a5
Changes
8
Hide whitespace changes
Inline
Side-by-side
Showing
8 changed files
with
148 additions
and
0 deletions
+148
-0
test/prototype_transforms_kernel_infos.py
test/prototype_transforms_kernel_infos.py
+6
-0
test/test_prototype_transforms_functional.py
test/test_prototype_transforms_functional.py
+26
-0
torchvision/prototype/transforms/functional/_augment.py
torchvision/prototype/transforms/functional/_augment.py
+4
-0
torchvision/prototype/transforms/functional/_color.py
torchvision/prototype/transforms/functional/_color.py
+35
-0
torchvision/prototype/transforms/functional/_geometry.py
torchvision/prototype/transforms/functional/_geometry.py
+39
-0
torchvision/prototype/transforms/functional/_meta.py
torchvision/prototype/transforms/functional/_meta.py
+26
-0
torchvision/prototype/transforms/functional/_misc.py
torchvision/prototype/transforms/functional/_misc.py
+7
-0
torchvision/prototype/transforms/functional/_temporal.py
torchvision/prototype/transforms/functional/_temporal.py
+5
-0
No files found.
test/prototype_transforms_kernel_infos.py
View file @
511924c1
...
...
@@ -57,6 +57,9 @@ class KernelInfo(InfoBase):
# structure, but with adapted parameters. This is useful in case a parameter value is closely tied to the input
# dtype.
float32_vs_uint8
=
False
,
# Some kernels don't have dispatchers that would handle logging the usage. Thus, the kernel has to do it
# manually. If set, triggers a test that makes sure this happens.
logs_usage
=
False
,
# See InfoBase
test_marks
=
None
,
# See InfoBase
...
...
@@ -71,6 +74,7 @@ class KernelInfo(InfoBase):
if
float32_vs_uint8
and
not
callable
(
float32_vs_uint8
):
float32_vs_uint8
=
lambda
other_args
,
kwargs
:
(
other_args
,
kwargs
)
# noqa: E731
self
.
float32_vs_uint8
=
float32_vs_uint8
self
.
logs_usage
=
logs_usage
def
_pixel_difference_closeness_kwargs
(
uint8_atol
,
*
,
dtype
=
torch
.
uint8
,
mae
=
False
):
...
...
@@ -675,6 +679,7 @@ KERNEL_INFOS.append(
sample_inputs_fn
=
sample_inputs_convert_format_bounding_box
,
reference_fn
=
reference_convert_format_bounding_box
,
reference_inputs_fn
=
reference_inputs_convert_format_bounding_box
,
logs_usage
=
True
,
),
)
...
...
@@ -2100,6 +2105,7 @@ KERNEL_INFOS.append(
KernelInfo
(
F
.
clamp_bounding_box
,
sample_inputs_fn
=
sample_inputs_clamp_bounding_box
,
logs_usage
=
True
,
)
)
...
...
test/test_prototype_transforms_functional.py
View file @
511924c1
...
...
@@ -108,6 +108,19 @@ class TestKernels:
args_kwargs_fn
=
lambda
info
:
info
.
reference_inputs_fn
(),
)
@
make_info_args_kwargs_parametrization
(
[
info
for
info
in
KERNEL_INFOS
if
info
.
logs_usage
],
args_kwargs_fn
=
lambda
info
:
info
.
sample_inputs_fn
(),
)
@
pytest
.
mark
.
parametrize
(
"device"
,
cpu_and_gpu
())
def
test_logging
(
self
,
spy_on
,
info
,
args_kwargs
,
device
):
spy
=
spy_on
(
torch
.
_C
.
_log_api_usage_once
)
args
,
kwargs
=
args_kwargs
.
load
(
device
)
info
.
kernel
(
*
args
,
**
kwargs
)
spy
.
assert_any_call
(
f
"
{
info
.
kernel
.
__module__
}
.
{
info
.
id
}
"
)
@
ignore_jit_warning_no_profile
@
sample_inputs
@
pytest
.
mark
.
parametrize
(
"device"
,
cpu_and_gpu
())
...
...
@@ -291,6 +304,19 @@ class TestDispatchers:
args_kwargs_fn
=
lambda
info
:
info
.
sample_inputs
(
datapoints
.
Image
),
)
@
make_info_args_kwargs_parametrization
(
DISPATCHER_INFOS
,
args_kwargs_fn
=
lambda
info
:
info
.
sample_inputs
(),
)
@
pytest
.
mark
.
parametrize
(
"device"
,
cpu_and_gpu
())
def
test_logging
(
self
,
spy_on
,
info
,
args_kwargs
,
device
):
spy
=
spy_on
(
torch
.
_C
.
_log_api_usage_once
)
args
,
kwargs
=
args_kwargs
.
load
(
device
)
info
.
dispatcher
(
*
args
,
**
kwargs
)
spy
.
assert_any_call
(
f
"
{
info
.
dispatcher
.
__module__
}
.
{
info
.
id
}
"
)
@
ignore_jit_warning_no_profile
@
image_sample_inputs
@
pytest
.
mark
.
parametrize
(
"device"
,
cpu_and_gpu
())
...
...
torchvision/prototype/transforms/functional/_augment.py
View file @
511924c1
...
...
@@ -5,6 +5,7 @@ import PIL.Image
import
torch
from
torchvision.prototype
import
datapoints
from
torchvision.transforms.functional
import
pil_to_tensor
,
to_pil_image
from
torchvision.utils
import
_log_api_usage_once
def
erase_image_tensor
(
...
...
@@ -41,6 +42,9 @@ def erase(
v
:
torch
.
Tensor
,
inplace
:
bool
=
False
,
)
->
Union
[
datapoints
.
ImageTypeJIT
,
datapoints
.
VideoTypeJIT
]:
if
not
torch
.
jit
.
is_scripting
():
_log_api_usage_once
(
erase
)
if
isinstance
(
inpt
,
torch
.
Tensor
)
and
(
torch
.
jit
.
is_scripting
()
or
not
isinstance
(
inpt
,
(
datapoints
.
Image
,
datapoints
.
Video
))
):
...
...
torchvision/prototype/transforms/functional/_color.py
View file @
511924c1
...
...
@@ -5,6 +5,8 @@ from torchvision.prototype import datapoints
from
torchvision.transforms
import
functional_pil
as
_FP
from
torchvision.transforms.functional_tensor
import
_max_value
from
torchvision.utils
import
_log_api_usage_once
from
._meta
import
_num_value_bits
,
_rgb_to_gray
,
convert_dtype_image_tensor
...
...
@@ -38,6 +40,9 @@ def adjust_brightness_video(video: torch.Tensor, brightness_factor: float) -> to
def
adjust_brightness
(
inpt
:
datapoints
.
InputTypeJIT
,
brightness_factor
:
float
)
->
datapoints
.
InputTypeJIT
:
if
not
torch
.
jit
.
is_scripting
():
_log_api_usage_once
(
adjust_brightness
)
if
isinstance
(
inpt
,
torch
.
Tensor
)
and
(
torch
.
jit
.
is_scripting
()
or
not
isinstance
(
inpt
,
datapoints
.
_datapoint
.
Datapoint
)
):
...
...
@@ -79,6 +84,9 @@ def adjust_saturation_video(video: torch.Tensor, saturation_factor: float) -> to
def
adjust_saturation
(
inpt
:
datapoints
.
InputTypeJIT
,
saturation_factor
:
float
)
->
datapoints
.
InputTypeJIT
:
if
not
torch
.
jit
.
is_scripting
():
_log_api_usage_once
(
adjust_saturation
)
if
isinstance
(
inpt
,
torch
.
Tensor
)
and
(
torch
.
jit
.
is_scripting
()
or
not
isinstance
(
inpt
,
datapoints
.
_datapoint
.
Datapoint
)
):
...
...
@@ -120,6 +128,9 @@ def adjust_contrast_video(video: torch.Tensor, contrast_factor: float) -> torch.
def
adjust_contrast
(
inpt
:
datapoints
.
InputTypeJIT
,
contrast_factor
:
float
)
->
datapoints
.
InputTypeJIT
:
if
not
torch
.
jit
.
is_scripting
():
_log_api_usage_once
(
adjust_contrast
)
if
isinstance
(
inpt
,
torch
.
Tensor
)
and
(
torch
.
jit
.
is_scripting
()
or
not
isinstance
(
inpt
,
datapoints
.
_datapoint
.
Datapoint
)
):
...
...
@@ -195,6 +206,9 @@ def adjust_sharpness_video(video: torch.Tensor, sharpness_factor: float) -> torc
def
adjust_sharpness
(
inpt
:
datapoints
.
InputTypeJIT
,
sharpness_factor
:
float
)
->
datapoints
.
InputTypeJIT
:
if
not
torch
.
jit
.
is_scripting
():
_log_api_usage_once
(
adjust_sharpness
)
if
isinstance
(
inpt
,
torch
.
Tensor
)
and
(
torch
.
jit
.
is_scripting
()
or
not
isinstance
(
inpt
,
datapoints
.
_datapoint
.
Datapoint
)
):
...
...
@@ -309,6 +323,9 @@ def adjust_hue_video(video: torch.Tensor, hue_factor: float) -> torch.Tensor:
def
adjust_hue
(
inpt
:
datapoints
.
InputTypeJIT
,
hue_factor
:
float
)
->
datapoints
.
InputTypeJIT
:
if
not
torch
.
jit
.
is_scripting
():
_log_api_usage_once
(
adjust_hue
)
if
isinstance
(
inpt
,
torch
.
Tensor
)
and
(
torch
.
jit
.
is_scripting
()
or
not
isinstance
(
inpt
,
datapoints
.
_datapoint
.
Datapoint
)
):
...
...
@@ -351,6 +368,9 @@ def adjust_gamma_video(video: torch.Tensor, gamma: float, gain: float = 1) -> to
def
adjust_gamma
(
inpt
:
datapoints
.
InputTypeJIT
,
gamma
:
float
,
gain
:
float
=
1
)
->
datapoints
.
InputTypeJIT
:
if
not
torch
.
jit
.
is_scripting
():
_log_api_usage_once
(
adjust_gamma
)
if
isinstance
(
inpt
,
torch
.
Tensor
)
and
(
torch
.
jit
.
is_scripting
()
or
not
isinstance
(
inpt
,
datapoints
.
_datapoint
.
Datapoint
)
):
...
...
@@ -387,6 +407,9 @@ def posterize_video(video: torch.Tensor, bits: int) -> torch.Tensor:
def
posterize
(
inpt
:
datapoints
.
InputTypeJIT
,
bits
:
int
)
->
datapoints
.
InputTypeJIT
:
if
not
torch
.
jit
.
is_scripting
():
_log_api_usage_once
(
posterize
)
if
isinstance
(
inpt
,
torch
.
Tensor
)
and
(
torch
.
jit
.
is_scripting
()
or
not
isinstance
(
inpt
,
datapoints
.
_datapoint
.
Datapoint
)
):
...
...
@@ -417,6 +440,9 @@ def solarize_video(video: torch.Tensor, threshold: float) -> torch.Tensor:
def
solarize
(
inpt
:
datapoints
.
InputTypeJIT
,
threshold
:
float
)
->
datapoints
.
InputTypeJIT
:
if
not
torch
.
jit
.
is_scripting
():
_log_api_usage_once
(
solarize
)
if
isinstance
(
inpt
,
torch
.
Tensor
)
and
(
torch
.
jit
.
is_scripting
()
or
not
isinstance
(
inpt
,
datapoints
.
_datapoint
.
Datapoint
)
):
...
...
@@ -469,6 +495,9 @@ def autocontrast_video(video: torch.Tensor) -> torch.Tensor:
def
autocontrast
(
inpt
:
datapoints
.
InputTypeJIT
)
->
datapoints
.
InputTypeJIT
:
if
not
torch
.
jit
.
is_scripting
():
_log_api_usage_once
(
autocontrast
)
if
isinstance
(
inpt
,
torch
.
Tensor
)
and
(
torch
.
jit
.
is_scripting
()
or
not
isinstance
(
inpt
,
datapoints
.
_datapoint
.
Datapoint
)
):
...
...
@@ -561,6 +590,9 @@ def equalize_video(video: torch.Tensor) -> torch.Tensor:
def
equalize
(
inpt
:
datapoints
.
InputTypeJIT
)
->
datapoints
.
InputTypeJIT
:
if
not
torch
.
jit
.
is_scripting
():
_log_api_usage_once
(
equalize
)
if
isinstance
(
inpt
,
torch
.
Tensor
)
and
(
torch
.
jit
.
is_scripting
()
or
not
isinstance
(
inpt
,
datapoints
.
_datapoint
.
Datapoint
)
):
...
...
@@ -594,6 +626,9 @@ def invert_video(video: torch.Tensor) -> torch.Tensor:
def
invert
(
inpt
:
datapoints
.
InputTypeJIT
)
->
datapoints
.
InputTypeJIT
:
if
not
torch
.
jit
.
is_scripting
():
_log_api_usage_once
(
invert
)
if
isinstance
(
inpt
,
torch
.
Tensor
)
and
(
torch
.
jit
.
is_scripting
()
or
not
isinstance
(
inpt
,
datapoints
.
_datapoint
.
Datapoint
)
):
...
...
torchvision/prototype/transforms/functional/_geometry.py
View file @
511924c1
...
...
@@ -19,6 +19,8 @@ from torchvision.transforms.functional import (
)
from
torchvision.transforms.functional_tensor
import
_pad_symmetric
from
torchvision.utils
import
_log_api_usage_once
from
._meta
import
convert_format_bounding_box
,
get_spatial_size_image_pil
...
...
@@ -55,6 +57,9 @@ def horizontal_flip_video(video: torch.Tensor) -> torch.Tensor:
def
horizontal_flip
(
inpt
:
datapoints
.
InputTypeJIT
)
->
datapoints
.
InputTypeJIT
:
if
not
torch
.
jit
.
is_scripting
():
_log_api_usage_once
(
horizontal_flip
)
if
isinstance
(
inpt
,
torch
.
Tensor
)
and
(
torch
.
jit
.
is_scripting
()
or
not
isinstance
(
inpt
,
datapoints
.
_datapoint
.
Datapoint
)
):
...
...
@@ -103,6 +108,9 @@ def vertical_flip_video(video: torch.Tensor) -> torch.Tensor:
def
vertical_flip
(
inpt
:
datapoints
.
InputTypeJIT
)
->
datapoints
.
InputTypeJIT
:
if
not
torch
.
jit
.
is_scripting
():
_log_api_usage_once
(
vertical_flip
)
if
isinstance
(
inpt
,
torch
.
Tensor
)
and
(
torch
.
jit
.
is_scripting
()
or
not
isinstance
(
inpt
,
datapoints
.
_datapoint
.
Datapoint
)
):
...
...
@@ -231,6 +239,8 @@ def resize(
max_size
:
Optional
[
int
]
=
None
,
antialias
:
Optional
[
bool
]
=
None
,
)
->
datapoints
.
InputTypeJIT
:
if
not
torch
.
jit
.
is_scripting
():
_log_api_usage_once
(
resize
)
if
isinstance
(
inpt
,
torch
.
Tensor
)
and
(
torch
.
jit
.
is_scripting
()
or
not
isinstance
(
inpt
,
datapoints
.
_datapoint
.
Datapoint
)
):
...
...
@@ -730,6 +740,9 @@ def affine(
fill
:
datapoints
.
FillTypeJIT
=
None
,
center
:
Optional
[
List
[
float
]]
=
None
,
)
->
datapoints
.
InputTypeJIT
:
if
not
torch
.
jit
.
is_scripting
():
_log_api_usage_once
(
affine
)
# TODO: consider deprecating integers from angle and shear on the future
if
isinstance
(
inpt
,
torch
.
Tensor
)
and
(
torch
.
jit
.
is_scripting
()
or
not
isinstance
(
inpt
,
datapoints
.
_datapoint
.
Datapoint
)
...
...
@@ -913,6 +926,9 @@ def rotate(
center
:
Optional
[
List
[
float
]]
=
None
,
fill
:
datapoints
.
FillTypeJIT
=
None
,
)
->
datapoints
.
InputTypeJIT
:
if
not
torch
.
jit
.
is_scripting
():
_log_api_usage_once
(
rotate
)
if
isinstance
(
inpt
,
torch
.
Tensor
)
and
(
torch
.
jit
.
is_scripting
()
or
not
isinstance
(
inpt
,
datapoints
.
_datapoint
.
Datapoint
)
):
...
...
@@ -1120,6 +1136,9 @@ def pad(
fill
:
datapoints
.
FillTypeJIT
=
None
,
padding_mode
:
str
=
"constant"
,
)
->
datapoints
.
InputTypeJIT
:
if
not
torch
.
jit
.
is_scripting
():
_log_api_usage_once
(
pad
)
if
isinstance
(
inpt
,
torch
.
Tensor
)
and
(
torch
.
jit
.
is_scripting
()
or
not
isinstance
(
inpt
,
datapoints
.
_datapoint
.
Datapoint
)
):
...
...
@@ -1197,6 +1216,9 @@ def crop_video(video: torch.Tensor, top: int, left: int, height: int, width: int
def
crop
(
inpt
:
datapoints
.
InputTypeJIT
,
top
:
int
,
left
:
int
,
height
:
int
,
width
:
int
)
->
datapoints
.
InputTypeJIT
:
if
not
torch
.
jit
.
is_scripting
():
_log_api_usage_once
(
crop
)
if
isinstance
(
inpt
,
torch
.
Tensor
)
and
(
torch
.
jit
.
is_scripting
()
or
not
isinstance
(
inpt
,
datapoints
.
_datapoint
.
Datapoint
)
):
...
...
@@ -1452,6 +1474,8 @@ def perspective(
fill
:
datapoints
.
FillTypeJIT
=
None
,
coefficients
:
Optional
[
List
[
float
]]
=
None
,
)
->
datapoints
.
InputTypeJIT
:
if
not
torch
.
jit
.
is_scripting
():
_log_api_usage_once
(
perspective
)
if
isinstance
(
inpt
,
torch
.
Tensor
)
and
(
torch
.
jit
.
is_scripting
()
or
not
isinstance
(
inpt
,
datapoints
.
_datapoint
.
Datapoint
)
):
...
...
@@ -1612,6 +1636,9 @@ def elastic(
interpolation
:
InterpolationMode
=
InterpolationMode
.
BILINEAR
,
fill
:
datapoints
.
FillTypeJIT
=
None
,
)
->
datapoints
.
InputTypeJIT
:
if
not
torch
.
jit
.
is_scripting
():
_log_api_usage_once
(
elastic
)
if
isinstance
(
inpt
,
torch
.
Tensor
)
and
(
torch
.
jit
.
is_scripting
()
or
not
isinstance
(
inpt
,
datapoints
.
_datapoint
.
Datapoint
)
):
...
...
@@ -1724,6 +1751,9 @@ def center_crop_video(video: torch.Tensor, output_size: List[int]) -> torch.Tens
def
center_crop
(
inpt
:
datapoints
.
InputTypeJIT
,
output_size
:
List
[
int
])
->
datapoints
.
InputTypeJIT
:
if
not
torch
.
jit
.
is_scripting
():
_log_api_usage_once
(
center_crop
)
if
isinstance
(
inpt
,
torch
.
Tensor
)
and
(
torch
.
jit
.
is_scripting
()
or
not
isinstance
(
inpt
,
datapoints
.
_datapoint
.
Datapoint
)
):
...
...
@@ -1817,6 +1847,9 @@ def resized_crop(
interpolation
:
InterpolationMode
=
InterpolationMode
.
BILINEAR
,
antialias
:
Optional
[
bool
]
=
None
,
)
->
datapoints
.
InputTypeJIT
:
if
not
torch
.
jit
.
is_scripting
():
_log_api_usage_once
(
resized_crop
)
if
isinstance
(
inpt
,
torch
.
Tensor
)
and
(
torch
.
jit
.
is_scripting
()
or
not
isinstance
(
inpt
,
datapoints
.
_datapoint
.
Datapoint
)
):
...
...
@@ -1897,6 +1930,9 @@ ImageOrVideoTypeJIT = Union[datapoints.ImageTypeJIT, datapoints.VideoTypeJIT]
def
five_crop
(
inpt
:
ImageOrVideoTypeJIT
,
size
:
List
[
int
]
)
->
Tuple
[
ImageOrVideoTypeJIT
,
ImageOrVideoTypeJIT
,
ImageOrVideoTypeJIT
,
ImageOrVideoTypeJIT
,
ImageOrVideoTypeJIT
]:
if
not
torch
.
jit
.
is_scripting
():
_log_api_usage_once
(
five_crop
)
# TODO: consider breaking BC here to return List[datapoints.ImageTypeJIT/VideoTypeJIT] to align this op with
# `ten_crop`
if
isinstance
(
inpt
,
torch
.
Tensor
)
and
(
...
...
@@ -1952,6 +1988,9 @@ def ten_crop_video(video: torch.Tensor, size: List[int], vertical_flip: bool = F
def
ten_crop
(
inpt
:
Union
[
datapoints
.
ImageTypeJIT
,
datapoints
.
VideoTypeJIT
],
size
:
List
[
int
],
vertical_flip
:
bool
=
False
)
->
Union
[
List
[
datapoints
.
ImageTypeJIT
],
List
[
datapoints
.
VideoTypeJIT
]]:
if
not
torch
.
jit
.
is_scripting
():
_log_api_usage_once
(
ten_crop
)
if
isinstance
(
inpt
,
torch
.
Tensor
)
and
(
torch
.
jit
.
is_scripting
()
or
not
isinstance
(
inpt
,
(
datapoints
.
Image
,
datapoints
.
Video
))
):
...
...
torchvision/prototype/transforms/functional/_meta.py
View file @
511924c1
...
...
@@ -7,6 +7,8 @@ from torchvision.prototype.datapoints import BoundingBoxFormat, ColorSpace
from
torchvision.transforms
import
functional_pil
as
_FP
from
torchvision.transforms.functional_tensor
import
_max_value
from
torchvision.utils
import
_log_api_usage_once
def
get_dimensions_image_tensor
(
image
:
torch
.
Tensor
)
->
List
[
int
]:
chw
=
list
(
image
.
shape
[
-
3
:])
...
...
@@ -24,6 +26,9 @@ get_dimensions_image_pil = _FP.get_dimensions
def
get_dimensions
(
inpt
:
Union
[
datapoints
.
ImageTypeJIT
,
datapoints
.
VideoTypeJIT
])
->
List
[
int
]:
if
not
torch
.
jit
.
is_scripting
():
_log_api_usage_once
(
get_dimensions
)
if
isinstance
(
inpt
,
torch
.
Tensor
)
and
(
torch
.
jit
.
is_scripting
()
or
not
isinstance
(
inpt
,
(
datapoints
.
Image
,
datapoints
.
Video
))
):
...
...
@@ -60,6 +65,9 @@ def get_num_channels_video(video: torch.Tensor) -> int:
def
get_num_channels
(
inpt
:
Union
[
datapoints
.
ImageTypeJIT
,
datapoints
.
VideoTypeJIT
])
->
int
:
if
not
torch
.
jit
.
is_scripting
():
_log_api_usage_once
(
get_num_channels
)
if
isinstance
(
inpt
,
torch
.
Tensor
)
and
(
torch
.
jit
.
is_scripting
()
or
not
isinstance
(
inpt
,
(
datapoints
.
Image
,
datapoints
.
Video
))
):
...
...
@@ -109,6 +117,9 @@ def get_spatial_size_bounding_box(bounding_box: datapoints.BoundingBox) -> List[
def
get_spatial_size
(
inpt
:
datapoints
.
InputTypeJIT
)
->
List
[
int
]:
if
not
torch
.
jit
.
is_scripting
():
_log_api_usage_once
(
get_spatial_size
)
if
isinstance
(
inpt
,
torch
.
Tensor
)
and
(
torch
.
jit
.
is_scripting
()
or
not
isinstance
(
inpt
,
datapoints
.
_datapoint
.
Datapoint
)
):
...
...
@@ -129,6 +140,9 @@ def get_num_frames_video(video: torch.Tensor) -> int:
def
get_num_frames
(
inpt
:
datapoints
.
VideoTypeJIT
)
->
int
:
if
not
torch
.
jit
.
is_scripting
():
_log_api_usage_once
(
get_num_frames
)
if
isinstance
(
inpt
,
torch
.
Tensor
)
and
(
torch
.
jit
.
is_scripting
()
or
not
isinstance
(
inpt
,
datapoints
.
Video
)):
return
get_num_frames_video
(
inpt
)
elif
isinstance
(
inpt
,
datapoints
.
Video
):
...
...
@@ -179,6 +193,9 @@ def _xyxy_to_cxcywh(xyxy: torch.Tensor, inplace: bool) -> torch.Tensor:
def
convert_format_bounding_box
(
bounding_box
:
torch
.
Tensor
,
old_format
:
BoundingBoxFormat
,
new_format
:
BoundingBoxFormat
,
inplace
:
bool
=
False
)
->
torch
.
Tensor
:
if
not
torch
.
jit
.
is_scripting
():
_log_api_usage_once
(
convert_format_bounding_box
)
if
new_format
==
old_format
:
return
bounding_box
...
...
@@ -199,6 +216,9 @@ def convert_format_bounding_box(
def
clamp_bounding_box
(
bounding_box
:
torch
.
Tensor
,
format
:
BoundingBoxFormat
,
spatial_size
:
Tuple
[
int
,
int
]
)
->
torch
.
Tensor
:
if
not
torch
.
jit
.
is_scripting
():
_log_api_usage_once
(
clamp_bounding_box
)
# TODO: Investigate if it makes sense from a performance perspective to have an implementation for every
# BoundingBoxFormat instead of converting back and forth
xyxy_boxes
=
convert_format_bounding_box
(
...
...
@@ -313,6 +333,9 @@ def convert_color_space(
color_space
:
ColorSpace
,
old_color_space
:
Optional
[
ColorSpace
]
=
None
,
)
->
Union
[
datapoints
.
ImageTypeJIT
,
datapoints
.
VideoTypeJIT
]:
if
not
torch
.
jit
.
is_scripting
():
_log_api_usage_once
(
convert_color_space
)
if
isinstance
(
inpt
,
torch
.
Tensor
)
and
(
torch
.
jit
.
is_scripting
()
or
not
isinstance
(
inpt
,
(
datapoints
.
Image
,
datapoints
.
Video
))
):
...
...
@@ -417,6 +440,9 @@ def convert_dtype_video(video: torch.Tensor, dtype: torch.dtype = torch.float) -
def
convert_dtype
(
inpt
:
Union
[
datapoints
.
ImageTypeJIT
,
datapoints
.
VideoTypeJIT
],
dtype
:
torch
.
dtype
=
torch
.
float
)
->
torch
.
Tensor
:
if
not
torch
.
jit
.
is_scripting
():
_log_api_usage_once
(
convert_dtype
)
if
isinstance
(
inpt
,
torch
.
Tensor
)
and
(
torch
.
jit
.
is_scripting
()
or
not
isinstance
(
inpt
,
(
datapoints
.
Image
,
datapoints
.
Video
))
):
...
...
torchvision/prototype/transforms/functional/_misc.py
View file @
511924c1
...
...
@@ -8,6 +8,8 @@ from torch.nn.functional import conv2d, pad as torch_pad
from
torchvision.prototype
import
datapoints
from
torchvision.transforms.functional
import
pil_to_tensor
,
to_pil_image
from
torchvision.utils
import
_log_api_usage_once
from
..utils
import
is_simple_tensor
...
...
@@ -57,6 +59,8 @@ def normalize(
inplace
:
bool
=
False
,
)
->
torch
.
Tensor
:
if
not
torch
.
jit
.
is_scripting
():
_log_api_usage_once
(
normalize
)
if
is_simple_tensor
(
inpt
)
or
isinstance
(
inpt
,
(
datapoints
.
Image
,
datapoints
.
Video
)):
inpt
=
inpt
.
as_subclass
(
torch
.
Tensor
)
else
:
...
...
@@ -168,6 +172,9 @@ def gaussian_blur_video(
def
gaussian_blur
(
inpt
:
datapoints
.
InputTypeJIT
,
kernel_size
:
List
[
int
],
sigma
:
Optional
[
List
[
float
]]
=
None
)
->
datapoints
.
InputTypeJIT
:
if
not
torch
.
jit
.
is_scripting
():
_log_api_usage_once
(
gaussian_blur
)
if
isinstance
(
inpt
,
torch
.
Tensor
)
and
(
torch
.
jit
.
is_scripting
()
or
not
isinstance
(
inpt
,
datapoints
.
_datapoint
.
Datapoint
)
):
...
...
torchvision/prototype/transforms/functional/_temporal.py
View file @
511924c1
...
...
@@ -2,6 +2,8 @@ import torch
from
torchvision.prototype
import
datapoints
from
torchvision.utils
import
_log_api_usage_once
def
uniform_temporal_subsample_video
(
video
:
torch
.
Tensor
,
num_samples
:
int
,
temporal_dim
:
int
=
-
4
)
->
torch
.
Tensor
:
# Reference: https://github.com/facebookresearch/pytorchvideo/blob/a0a131e/pytorchvideo/transforms/functional.py#L19
...
...
@@ -13,6 +15,9 @@ def uniform_temporal_subsample_video(video: torch.Tensor, num_samples: int, temp
def
uniform_temporal_subsample
(
inpt
:
datapoints
.
VideoTypeJIT
,
num_samples
:
int
,
temporal_dim
:
int
=
-
4
)
->
datapoints
.
VideoTypeJIT
:
if
not
torch
.
jit
.
is_scripting
():
_log_api_usage_once
(
uniform_temporal_subsample
)
if
isinstance
(
inpt
,
torch
.
Tensor
)
and
(
torch
.
jit
.
is_scripting
()
or
not
isinstance
(
inpt
,
datapoints
.
Video
)):
return
uniform_temporal_subsample_video
(
inpt
,
num_samples
,
temporal_dim
=
temporal_dim
)
elif
isinstance
(
inpt
,
datapoints
.
Video
):
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment