Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
vision
Commits
511924c1
"vscode:/vscode.git/clone" did not exist on "b4ad59d77f3fff30f148a0391b4cdfc6ef19915c"
Unverified
Commit
511924c1
authored
Dec 06, 2022
by
Philip Meier
Committed by
GitHub
Dec 06, 2022
Browse files
add usage logging to prototype dispatchers / kernels (#7012)
parent
c65d57a5
Changes
8
Hide whitespace changes
Inline
Side-by-side
Showing
8 changed files
with
148 additions
and
0 deletions
+148
-0
test/prototype_transforms_kernel_infos.py
test/prototype_transforms_kernel_infos.py
+6
-0
test/test_prototype_transforms_functional.py
test/test_prototype_transforms_functional.py
+26
-0
torchvision/prototype/transforms/functional/_augment.py
torchvision/prototype/transforms/functional/_augment.py
+4
-0
torchvision/prototype/transforms/functional/_color.py
torchvision/prototype/transforms/functional/_color.py
+35
-0
torchvision/prototype/transforms/functional/_geometry.py
torchvision/prototype/transforms/functional/_geometry.py
+39
-0
torchvision/prototype/transforms/functional/_meta.py
torchvision/prototype/transforms/functional/_meta.py
+26
-0
torchvision/prototype/transforms/functional/_misc.py
torchvision/prototype/transforms/functional/_misc.py
+7
-0
torchvision/prototype/transforms/functional/_temporal.py
torchvision/prototype/transforms/functional/_temporal.py
+5
-0
No files found.
test/prototype_transforms_kernel_infos.py
View file @
511924c1
...
...
@@ -57,6 +57,9 @@ class KernelInfo(InfoBase):
# structure, but with adapted parameters. This is useful in case a parameter value is closely tied to the input
# dtype.
float32_vs_uint8
=
False
,
# Some kernels don't have dispatchers that would handle logging the usage. Thus, the kernel has to do it
# manually. If set, triggers a test that makes sure this happens.
logs_usage
=
False
,
# See InfoBase
test_marks
=
None
,
# See InfoBase
...
...
@@ -71,6 +74,7 @@ class KernelInfo(InfoBase):
if
float32_vs_uint8
and
not
callable
(
float32_vs_uint8
):
float32_vs_uint8
=
lambda
other_args
,
kwargs
:
(
other_args
,
kwargs
)
# noqa: E731
self
.
float32_vs_uint8
=
float32_vs_uint8
self
.
logs_usage
=
logs_usage
def
_pixel_difference_closeness_kwargs
(
uint8_atol
,
*
,
dtype
=
torch
.
uint8
,
mae
=
False
):
...
...
@@ -675,6 +679,7 @@ KERNEL_INFOS.append(
sample_inputs_fn
=
sample_inputs_convert_format_bounding_box
,
reference_fn
=
reference_convert_format_bounding_box
,
reference_inputs_fn
=
reference_inputs_convert_format_bounding_box
,
logs_usage
=
True
,
),
)
...
...
@@ -2100,6 +2105,7 @@ KERNEL_INFOS.append(
KernelInfo
(
F
.
clamp_bounding_box
,
sample_inputs_fn
=
sample_inputs_clamp_bounding_box
,
logs_usage
=
True
,
)
)
...
...
test/test_prototype_transforms_functional.py
View file @
511924c1
...
...
@@ -108,6 +108,19 @@ class TestKernels:
args_kwargs_fn
=
lambda
info
:
info
.
reference_inputs_fn
(),
)
@
make_info_args_kwargs_parametrization
(
[
info
for
info
in
KERNEL_INFOS
if
info
.
logs_usage
],
args_kwargs_fn
=
lambda
info
:
info
.
sample_inputs_fn
(),
)
@
pytest
.
mark
.
parametrize
(
"device"
,
cpu_and_gpu
())
def
test_logging
(
self
,
spy_on
,
info
,
args_kwargs
,
device
):
spy
=
spy_on
(
torch
.
_C
.
_log_api_usage_once
)
args
,
kwargs
=
args_kwargs
.
load
(
device
)
info
.
kernel
(
*
args
,
**
kwargs
)
spy
.
assert_any_call
(
f
"
{
info
.
kernel
.
__module__
}
.
{
info
.
id
}
"
)
@
ignore_jit_warning_no_profile
@
sample_inputs
@
pytest
.
mark
.
parametrize
(
"device"
,
cpu_and_gpu
())
...
...
@@ -291,6 +304,19 @@ class TestDispatchers:
args_kwargs_fn
=
lambda
info
:
info
.
sample_inputs
(
datapoints
.
Image
),
)
@
make_info_args_kwargs_parametrization
(
DISPATCHER_INFOS
,
args_kwargs_fn
=
lambda
info
:
info
.
sample_inputs
(),
)
@
pytest
.
mark
.
parametrize
(
"device"
,
cpu_and_gpu
())
def
test_logging
(
self
,
spy_on
,
info
,
args_kwargs
,
device
):
spy
=
spy_on
(
torch
.
_C
.
_log_api_usage_once
)
args
,
kwargs
=
args_kwargs
.
load
(
device
)
info
.
dispatcher
(
*
args
,
**
kwargs
)
spy
.
assert_any_call
(
f
"
{
info
.
dispatcher
.
__module__
}
.
{
info
.
id
}
"
)
@
ignore_jit_warning_no_profile
@
image_sample_inputs
@
pytest
.
mark
.
parametrize
(
"device"
,
cpu_and_gpu
())
...
...
torchvision/prototype/transforms/functional/_augment.py
View file @
511924c1
...
...
@@ -5,6 +5,7 @@ import PIL.Image
import
torch
from
torchvision.prototype
import
datapoints
from
torchvision.transforms.functional
import
pil_to_tensor
,
to_pil_image
from
torchvision.utils
import
_log_api_usage_once
def
erase_image_tensor
(
...
...
@@ -41,6 +42,9 @@ def erase(
v
:
torch
.
Tensor
,
inplace
:
bool
=
False
,
)
->
Union
[
datapoints
.
ImageTypeJIT
,
datapoints
.
VideoTypeJIT
]:
if
not
torch
.
jit
.
is_scripting
():
_log_api_usage_once
(
erase
)
if
isinstance
(
inpt
,
torch
.
Tensor
)
and
(
torch
.
jit
.
is_scripting
()
or
not
isinstance
(
inpt
,
(
datapoints
.
Image
,
datapoints
.
Video
))
):
...
...
torchvision/prototype/transforms/functional/_color.py
View file @
511924c1
...
...
@@ -5,6 +5,8 @@ from torchvision.prototype import datapoints
from
torchvision.transforms
import
functional_pil
as
_FP
from
torchvision.transforms.functional_tensor
import
_max_value
from
torchvision.utils
import
_log_api_usage_once
from
._meta
import
_num_value_bits
,
_rgb_to_gray
,
convert_dtype_image_tensor
...
...
@@ -38,6 +40,9 @@ def adjust_brightness_video(video: torch.Tensor, brightness_factor: float) -> to
def
adjust_brightness
(
inpt
:
datapoints
.
InputTypeJIT
,
brightness_factor
:
float
)
->
datapoints
.
InputTypeJIT
:
if
not
torch
.
jit
.
is_scripting
():
_log_api_usage_once
(
adjust_brightness
)
if
isinstance
(
inpt
,
torch
.
Tensor
)
and
(
torch
.
jit
.
is_scripting
()
or
not
isinstance
(
inpt
,
datapoints
.
_datapoint
.
Datapoint
)
):
...
...
@@ -79,6 +84,9 @@ def adjust_saturation_video(video: torch.Tensor, saturation_factor: float) -> to
def
adjust_saturation
(
inpt
:
datapoints
.
InputTypeJIT
,
saturation_factor
:
float
)
->
datapoints
.
InputTypeJIT
:
if
not
torch
.
jit
.
is_scripting
():
_log_api_usage_once
(
adjust_saturation
)
if
isinstance
(
inpt
,
torch
.
Tensor
)
and
(
torch
.
jit
.
is_scripting
()
or
not
isinstance
(
inpt
,
datapoints
.
_datapoint
.
Datapoint
)
):
...
...
@@ -120,6 +128,9 @@ def adjust_contrast_video(video: torch.Tensor, contrast_factor: float) -> torch.
def
adjust_contrast
(
inpt
:
datapoints
.
InputTypeJIT
,
contrast_factor
:
float
)
->
datapoints
.
InputTypeJIT
:
if
not
torch
.
jit
.
is_scripting
():
_log_api_usage_once
(
adjust_contrast
)
if
isinstance
(
inpt
,
torch
.
Tensor
)
and
(
torch
.
jit
.
is_scripting
()
or
not
isinstance
(
inpt
,
datapoints
.
_datapoint
.
Datapoint
)
):
...
...
@@ -195,6 +206,9 @@ def adjust_sharpness_video(video: torch.Tensor, sharpness_factor: float) -> torc
def
adjust_sharpness
(
inpt
:
datapoints
.
InputTypeJIT
,
sharpness_factor
:
float
)
->
datapoints
.
InputTypeJIT
:
if
not
torch
.
jit
.
is_scripting
():
_log_api_usage_once
(
adjust_sharpness
)
if
isinstance
(
inpt
,
torch
.
Tensor
)
and
(
torch
.
jit
.
is_scripting
()
or
not
isinstance
(
inpt
,
datapoints
.
_datapoint
.
Datapoint
)
):
...
...
@@ -309,6 +323,9 @@ def adjust_hue_video(video: torch.Tensor, hue_factor: float) -> torch.Tensor:
def
adjust_hue
(
inpt
:
datapoints
.
InputTypeJIT
,
hue_factor
:
float
)
->
datapoints
.
InputTypeJIT
:
if
not
torch
.
jit
.
is_scripting
():
_log_api_usage_once
(
adjust_hue
)
if
isinstance
(
inpt
,
torch
.
Tensor
)
and
(
torch
.
jit
.
is_scripting
()
or
not
isinstance
(
inpt
,
datapoints
.
_datapoint
.
Datapoint
)
):
...
...
@@ -351,6 +368,9 @@ def adjust_gamma_video(video: torch.Tensor, gamma: float, gain: float = 1) -> to
def
adjust_gamma
(
inpt
:
datapoints
.
InputTypeJIT
,
gamma
:
float
,
gain
:
float
=
1
)
->
datapoints
.
InputTypeJIT
:
if
not
torch
.
jit
.
is_scripting
():
_log_api_usage_once
(
adjust_gamma
)
if
isinstance
(
inpt
,
torch
.
Tensor
)
and
(
torch
.
jit
.
is_scripting
()
or
not
isinstance
(
inpt
,
datapoints
.
_datapoint
.
Datapoint
)
):
...
...
@@ -387,6 +407,9 @@ def posterize_video(video: torch.Tensor, bits: int) -> torch.Tensor:
def
posterize
(
inpt
:
datapoints
.
InputTypeJIT
,
bits
:
int
)
->
datapoints
.
InputTypeJIT
:
if
not
torch
.
jit
.
is_scripting
():
_log_api_usage_once
(
posterize
)
if
isinstance
(
inpt
,
torch
.
Tensor
)
and
(
torch
.
jit
.
is_scripting
()
or
not
isinstance
(
inpt
,
datapoints
.
_datapoint
.
Datapoint
)
):
...
...
@@ -417,6 +440,9 @@ def solarize_video(video: torch.Tensor, threshold: float) -> torch.Tensor:
def
solarize
(
inpt
:
datapoints
.
InputTypeJIT
,
threshold
:
float
)
->
datapoints
.
InputTypeJIT
:
if
not
torch
.
jit
.
is_scripting
():
_log_api_usage_once
(
solarize
)
if
isinstance
(
inpt
,
torch
.
Tensor
)
and
(
torch
.
jit
.
is_scripting
()
or
not
isinstance
(
inpt
,
datapoints
.
_datapoint
.
Datapoint
)
):
...
...
@@ -469,6 +495,9 @@ def autocontrast_video(video: torch.Tensor) -> torch.Tensor:
def
autocontrast
(
inpt
:
datapoints
.
InputTypeJIT
)
->
datapoints
.
InputTypeJIT
:
if
not
torch
.
jit
.
is_scripting
():
_log_api_usage_once
(
autocontrast
)
if
isinstance
(
inpt
,
torch
.
Tensor
)
and
(
torch
.
jit
.
is_scripting
()
or
not
isinstance
(
inpt
,
datapoints
.
_datapoint
.
Datapoint
)
):
...
...
@@ -561,6 +590,9 @@ def equalize_video(video: torch.Tensor) -> torch.Tensor:
def
equalize
(
inpt
:
datapoints
.
InputTypeJIT
)
->
datapoints
.
InputTypeJIT
:
if
not
torch
.
jit
.
is_scripting
():
_log_api_usage_once
(
equalize
)
if
isinstance
(
inpt
,
torch
.
Tensor
)
and
(
torch
.
jit
.
is_scripting
()
or
not
isinstance
(
inpt
,
datapoints
.
_datapoint
.
Datapoint
)
):
...
...
@@ -594,6 +626,9 @@ def invert_video(video: torch.Tensor) -> torch.Tensor:
def
invert
(
inpt
:
datapoints
.
InputTypeJIT
)
->
datapoints
.
InputTypeJIT
:
if
not
torch
.
jit
.
is_scripting
():
_log_api_usage_once
(
invert
)
if
isinstance
(
inpt
,
torch
.
Tensor
)
and
(
torch
.
jit
.
is_scripting
()
or
not
isinstance
(
inpt
,
datapoints
.
_datapoint
.
Datapoint
)
):
...
...
torchvision/prototype/transforms/functional/_geometry.py
View file @
511924c1
...
...
@@ -19,6 +19,8 @@ from torchvision.transforms.functional import (
)
from
torchvision.transforms.functional_tensor
import
_pad_symmetric
from
torchvision.utils
import
_log_api_usage_once
from
._meta
import
convert_format_bounding_box
,
get_spatial_size_image_pil
...
...
@@ -55,6 +57,9 @@ def horizontal_flip_video(video: torch.Tensor) -> torch.Tensor:
def
horizontal_flip
(
inpt
:
datapoints
.
InputTypeJIT
)
->
datapoints
.
InputTypeJIT
:
if
not
torch
.
jit
.
is_scripting
():
_log_api_usage_once
(
horizontal_flip
)
if
isinstance
(
inpt
,
torch
.
Tensor
)
and
(
torch
.
jit
.
is_scripting
()
or
not
isinstance
(
inpt
,
datapoints
.
_datapoint
.
Datapoint
)
):
...
...
@@ -103,6 +108,9 @@ def vertical_flip_video(video: torch.Tensor) -> torch.Tensor:
def
vertical_flip
(
inpt
:
datapoints
.
InputTypeJIT
)
->
datapoints
.
InputTypeJIT
:
if
not
torch
.
jit
.
is_scripting
():
_log_api_usage_once
(
vertical_flip
)
if
isinstance
(
inpt
,
torch
.
Tensor
)
and
(
torch
.
jit
.
is_scripting
()
or
not
isinstance
(
inpt
,
datapoints
.
_datapoint
.
Datapoint
)
):
...
...
@@ -231,6 +239,8 @@ def resize(
max_size
:
Optional
[
int
]
=
None
,
antialias
:
Optional
[
bool
]
=
None
,
)
->
datapoints
.
InputTypeJIT
:
if
not
torch
.
jit
.
is_scripting
():
_log_api_usage_once
(
resize
)
if
isinstance
(
inpt
,
torch
.
Tensor
)
and
(
torch
.
jit
.
is_scripting
()
or
not
isinstance
(
inpt
,
datapoints
.
_datapoint
.
Datapoint
)
):
...
...
@@ -730,6 +740,9 @@ def affine(
fill
:
datapoints
.
FillTypeJIT
=
None
,
center
:
Optional
[
List
[
float
]]
=
None
,
)
->
datapoints
.
InputTypeJIT
:
if
not
torch
.
jit
.
is_scripting
():
_log_api_usage_once
(
affine
)
# TODO: consider deprecating integers from angle and shear on the future
if
isinstance
(
inpt
,
torch
.
Tensor
)
and
(
torch
.
jit
.
is_scripting
()
or
not
isinstance
(
inpt
,
datapoints
.
_datapoint
.
Datapoint
)
...
...
@@ -913,6 +926,9 @@ def rotate(
center
:
Optional
[
List
[
float
]]
=
None
,
fill
:
datapoints
.
FillTypeJIT
=
None
,
)
->
datapoints
.
InputTypeJIT
:
if
not
torch
.
jit
.
is_scripting
():
_log_api_usage_once
(
rotate
)
if
isinstance
(
inpt
,
torch
.
Tensor
)
and
(
torch
.
jit
.
is_scripting
()
or
not
isinstance
(
inpt
,
datapoints
.
_datapoint
.
Datapoint
)
):
...
...
@@ -1120,6 +1136,9 @@ def pad(
fill
:
datapoints
.
FillTypeJIT
=
None
,
padding_mode
:
str
=
"constant"
,
)
->
datapoints
.
InputTypeJIT
:
if
not
torch
.
jit
.
is_scripting
():
_log_api_usage_once
(
pad
)
if
isinstance
(
inpt
,
torch
.
Tensor
)
and
(
torch
.
jit
.
is_scripting
()
or
not
isinstance
(
inpt
,
datapoints
.
_datapoint
.
Datapoint
)
):
...
...
@@ -1197,6 +1216,9 @@ def crop_video(video: torch.Tensor, top: int, left: int, height: int, width: int
def
crop
(
inpt
:
datapoints
.
InputTypeJIT
,
top
:
int
,
left
:
int
,
height
:
int
,
width
:
int
)
->
datapoints
.
InputTypeJIT
:
if
not
torch
.
jit
.
is_scripting
():
_log_api_usage_once
(
crop
)
if
isinstance
(
inpt
,
torch
.
Tensor
)
and
(
torch
.
jit
.
is_scripting
()
or
not
isinstance
(
inpt
,
datapoints
.
_datapoint
.
Datapoint
)
):
...
...
@@ -1452,6 +1474,8 @@ def perspective(
fill
:
datapoints
.
FillTypeJIT
=
None
,
coefficients
:
Optional
[
List
[
float
]]
=
None
,
)
->
datapoints
.
InputTypeJIT
:
if
not
torch
.
jit
.
is_scripting
():
_log_api_usage_once
(
perspective
)
if
isinstance
(
inpt
,
torch
.
Tensor
)
and
(
torch
.
jit
.
is_scripting
()
or
not
isinstance
(
inpt
,
datapoints
.
_datapoint
.
Datapoint
)
):
...
...
@@ -1612,6 +1636,9 @@ def elastic(
interpolation
:
InterpolationMode
=
InterpolationMode
.
BILINEAR
,
fill
:
datapoints
.
FillTypeJIT
=
None
,
)
->
datapoints
.
InputTypeJIT
:
if
not
torch
.
jit
.
is_scripting
():
_log_api_usage_once
(
elastic
)
if
isinstance
(
inpt
,
torch
.
Tensor
)
and
(
torch
.
jit
.
is_scripting
()
or
not
isinstance
(
inpt
,
datapoints
.
_datapoint
.
Datapoint
)
):
...
...
@@ -1724,6 +1751,9 @@ def center_crop_video(video: torch.Tensor, output_size: List[int]) -> torch.Tens
def
center_crop
(
inpt
:
datapoints
.
InputTypeJIT
,
output_size
:
List
[
int
])
->
datapoints
.
InputTypeJIT
:
if
not
torch
.
jit
.
is_scripting
():
_log_api_usage_once
(
center_crop
)
if
isinstance
(
inpt
,
torch
.
Tensor
)
and
(
torch
.
jit
.
is_scripting
()
or
not
isinstance
(
inpt
,
datapoints
.
_datapoint
.
Datapoint
)
):
...
...
@@ -1817,6 +1847,9 @@ def resized_crop(
interpolation
:
InterpolationMode
=
InterpolationMode
.
BILINEAR
,
antialias
:
Optional
[
bool
]
=
None
,
)
->
datapoints
.
InputTypeJIT
:
if
not
torch
.
jit
.
is_scripting
():
_log_api_usage_once
(
resized_crop
)
if
isinstance
(
inpt
,
torch
.
Tensor
)
and
(
torch
.
jit
.
is_scripting
()
or
not
isinstance
(
inpt
,
datapoints
.
_datapoint
.
Datapoint
)
):
...
...
@@ -1897,6 +1930,9 @@ ImageOrVideoTypeJIT = Union[datapoints.ImageTypeJIT, datapoints.VideoTypeJIT]
def
five_crop
(
inpt
:
ImageOrVideoTypeJIT
,
size
:
List
[
int
]
)
->
Tuple
[
ImageOrVideoTypeJIT
,
ImageOrVideoTypeJIT
,
ImageOrVideoTypeJIT
,
ImageOrVideoTypeJIT
,
ImageOrVideoTypeJIT
]:
if
not
torch
.
jit
.
is_scripting
():
_log_api_usage_once
(
five_crop
)
# TODO: consider breaking BC here to return List[datapoints.ImageTypeJIT/VideoTypeJIT] to align this op with
# `ten_crop`
if
isinstance
(
inpt
,
torch
.
Tensor
)
and
(
...
...
@@ -1952,6 +1988,9 @@ def ten_crop_video(video: torch.Tensor, size: List[int], vertical_flip: bool = F
def
ten_crop
(
inpt
:
Union
[
datapoints
.
ImageTypeJIT
,
datapoints
.
VideoTypeJIT
],
size
:
List
[
int
],
vertical_flip
:
bool
=
False
)
->
Union
[
List
[
datapoints
.
ImageTypeJIT
],
List
[
datapoints
.
VideoTypeJIT
]]:
if
not
torch
.
jit
.
is_scripting
():
_log_api_usage_once
(
ten_crop
)
if
isinstance
(
inpt
,
torch
.
Tensor
)
and
(
torch
.
jit
.
is_scripting
()
or
not
isinstance
(
inpt
,
(
datapoints
.
Image
,
datapoints
.
Video
))
):
...
...
torchvision/prototype/transforms/functional/_meta.py
View file @
511924c1
...
...
@@ -7,6 +7,8 @@ from torchvision.prototype.datapoints import BoundingBoxFormat, ColorSpace
from
torchvision.transforms
import
functional_pil
as
_FP
from
torchvision.transforms.functional_tensor
import
_max_value
from
torchvision.utils
import
_log_api_usage_once
def
get_dimensions_image_tensor
(
image
:
torch
.
Tensor
)
->
List
[
int
]:
chw
=
list
(
image
.
shape
[
-
3
:])
...
...
@@ -24,6 +26,9 @@ get_dimensions_image_pil = _FP.get_dimensions
def
get_dimensions
(
inpt
:
Union
[
datapoints
.
ImageTypeJIT
,
datapoints
.
VideoTypeJIT
])
->
List
[
int
]:
if
not
torch
.
jit
.
is_scripting
():
_log_api_usage_once
(
get_dimensions
)
if
isinstance
(
inpt
,
torch
.
Tensor
)
and
(
torch
.
jit
.
is_scripting
()
or
not
isinstance
(
inpt
,
(
datapoints
.
Image
,
datapoints
.
Video
))
):
...
...
@@ -60,6 +65,9 @@ def get_num_channels_video(video: torch.Tensor) -> int:
def
get_num_channels
(
inpt
:
Union
[
datapoints
.
ImageTypeJIT
,
datapoints
.
VideoTypeJIT
])
->
int
:
if
not
torch
.
jit
.
is_scripting
():
_log_api_usage_once
(
get_num_channels
)
if
isinstance
(
inpt
,
torch
.
Tensor
)
and
(
torch
.
jit
.
is_scripting
()
or
not
isinstance
(
inpt
,
(
datapoints
.
Image
,
datapoints
.
Video
))
):
...
...
@@ -109,6 +117,9 @@ def get_spatial_size_bounding_box(bounding_box: datapoints.BoundingBox) -> List[
def
get_spatial_size
(
inpt
:
datapoints
.
InputTypeJIT
)
->
List
[
int
]:
if
not
torch
.
jit
.
is_scripting
():
_log_api_usage_once
(
get_spatial_size
)
if
isinstance
(
inpt
,
torch
.
Tensor
)
and
(
torch
.
jit
.
is_scripting
()
or
not
isinstance
(
inpt
,
datapoints
.
_datapoint
.
Datapoint
)
):
...
...
@@ -129,6 +140,9 @@ def get_num_frames_video(video: torch.Tensor) -> int:
def
get_num_frames
(
inpt
:
datapoints
.
VideoTypeJIT
)
->
int
:
if
not
torch
.
jit
.
is_scripting
():
_log_api_usage_once
(
get_num_frames
)
if
isinstance
(
inpt
,
torch
.
Tensor
)
and
(
torch
.
jit
.
is_scripting
()
or
not
isinstance
(
inpt
,
datapoints
.
Video
)):
return
get_num_frames_video
(
inpt
)
elif
isinstance
(
inpt
,
datapoints
.
Video
):
...
...
@@ -179,6 +193,9 @@ def _xyxy_to_cxcywh(xyxy: torch.Tensor, inplace: bool) -> torch.Tensor:
def
convert_format_bounding_box
(
bounding_box
:
torch
.
Tensor
,
old_format
:
BoundingBoxFormat
,
new_format
:
BoundingBoxFormat
,
inplace
:
bool
=
False
)
->
torch
.
Tensor
:
if
not
torch
.
jit
.
is_scripting
():
_log_api_usage_once
(
convert_format_bounding_box
)
if
new_format
==
old_format
:
return
bounding_box
...
...
@@ -199,6 +216,9 @@ def convert_format_bounding_box(
def
clamp_bounding_box
(
bounding_box
:
torch
.
Tensor
,
format
:
BoundingBoxFormat
,
spatial_size
:
Tuple
[
int
,
int
]
)
->
torch
.
Tensor
:
if
not
torch
.
jit
.
is_scripting
():
_log_api_usage_once
(
clamp_bounding_box
)
# TODO: Investigate if it makes sense from a performance perspective to have an implementation for every
# BoundingBoxFormat instead of converting back and forth
xyxy_boxes
=
convert_format_bounding_box
(
...
...
@@ -313,6 +333,9 @@ def convert_color_space(
color_space
:
ColorSpace
,
old_color_space
:
Optional
[
ColorSpace
]
=
None
,
)
->
Union
[
datapoints
.
ImageTypeJIT
,
datapoints
.
VideoTypeJIT
]:
if
not
torch
.
jit
.
is_scripting
():
_log_api_usage_once
(
convert_color_space
)
if
isinstance
(
inpt
,
torch
.
Tensor
)
and
(
torch
.
jit
.
is_scripting
()
or
not
isinstance
(
inpt
,
(
datapoints
.
Image
,
datapoints
.
Video
))
):
...
...
@@ -417,6 +440,9 @@ def convert_dtype_video(video: torch.Tensor, dtype: torch.dtype = torch.float) -
def
convert_dtype
(
inpt
:
Union
[
datapoints
.
ImageTypeJIT
,
datapoints
.
VideoTypeJIT
],
dtype
:
torch
.
dtype
=
torch
.
float
)
->
torch
.
Tensor
:
if
not
torch
.
jit
.
is_scripting
():
_log_api_usage_once
(
convert_dtype
)
if
isinstance
(
inpt
,
torch
.
Tensor
)
and
(
torch
.
jit
.
is_scripting
()
or
not
isinstance
(
inpt
,
(
datapoints
.
Image
,
datapoints
.
Video
))
):
...
...
torchvision/prototype/transforms/functional/_misc.py
View file @
511924c1
...
...
@@ -8,6 +8,8 @@ from torch.nn.functional import conv2d, pad as torch_pad
from
torchvision.prototype
import
datapoints
from
torchvision.transforms.functional
import
pil_to_tensor
,
to_pil_image
from
torchvision.utils
import
_log_api_usage_once
from
..utils
import
is_simple_tensor
...
...
@@ -57,6 +59,8 @@ def normalize(
inplace
:
bool
=
False
,
)
->
torch
.
Tensor
:
if
not
torch
.
jit
.
is_scripting
():
_log_api_usage_once
(
normalize
)
if
is_simple_tensor
(
inpt
)
or
isinstance
(
inpt
,
(
datapoints
.
Image
,
datapoints
.
Video
)):
inpt
=
inpt
.
as_subclass
(
torch
.
Tensor
)
else
:
...
...
@@ -168,6 +172,9 @@ def gaussian_blur_video(
def
gaussian_blur
(
inpt
:
datapoints
.
InputTypeJIT
,
kernel_size
:
List
[
int
],
sigma
:
Optional
[
List
[
float
]]
=
None
)
->
datapoints
.
InputTypeJIT
:
if
not
torch
.
jit
.
is_scripting
():
_log_api_usage_once
(
gaussian_blur
)
if
isinstance
(
inpt
,
torch
.
Tensor
)
and
(
torch
.
jit
.
is_scripting
()
or
not
isinstance
(
inpt
,
datapoints
.
_datapoint
.
Datapoint
)
):
...
...
torchvision/prototype/transforms/functional/_temporal.py
View file @
511924c1
...
...
@@ -2,6 +2,8 @@ import torch
from
torchvision.prototype
import
datapoints
from
torchvision.utils
import
_log_api_usage_once
def
uniform_temporal_subsample_video
(
video
:
torch
.
Tensor
,
num_samples
:
int
,
temporal_dim
:
int
=
-
4
)
->
torch
.
Tensor
:
# Reference: https://github.com/facebookresearch/pytorchvideo/blob/a0a131e/pytorchvideo/transforms/functional.py#L19
...
...
@@ -13,6 +15,9 @@ def uniform_temporal_subsample_video(video: torch.Tensor, num_samples: int, temp
def
uniform_temporal_subsample
(
inpt
:
datapoints
.
VideoTypeJIT
,
num_samples
:
int
,
temporal_dim
:
int
=
-
4
)
->
datapoints
.
VideoTypeJIT
:
if
not
torch
.
jit
.
is_scripting
():
_log_api_usage_once
(
uniform_temporal_subsample
)
if
isinstance
(
inpt
,
torch
.
Tensor
)
and
(
torch
.
jit
.
is_scripting
()
or
not
isinstance
(
inpt
,
datapoints
.
Video
)):
return
uniform_temporal_subsample_video
(
inpt
,
num_samples
,
temporal_dim
=
temporal_dim
)
elif
isinstance
(
inpt
,
datapoints
.
Video
):
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment