Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
vision
Commits
511924c1
Unverified
Commit
511924c1
authored
Dec 06, 2022
by
Philip Meier
Committed by
GitHub
Dec 06, 2022
Browse files
add usage logging to prototype dispatchers / kernels (#7012)
parent
c65d57a5
Changes
8
Hide whitespace changes
Inline
Side-by-side
Showing
8 changed files
with
148 additions
and
0 deletions
+148
-0
test/prototype_transforms_kernel_infos.py
test/prototype_transforms_kernel_infos.py
+6
-0
test/test_prototype_transforms_functional.py
test/test_prototype_transforms_functional.py
+26
-0
torchvision/prototype/transforms/functional/_augment.py
torchvision/prototype/transforms/functional/_augment.py
+4
-0
torchvision/prototype/transforms/functional/_color.py
torchvision/prototype/transforms/functional/_color.py
+35
-0
torchvision/prototype/transforms/functional/_geometry.py
torchvision/prototype/transforms/functional/_geometry.py
+39
-0
torchvision/prototype/transforms/functional/_meta.py
torchvision/prototype/transforms/functional/_meta.py
+26
-0
torchvision/prototype/transforms/functional/_misc.py
torchvision/prototype/transforms/functional/_misc.py
+7
-0
torchvision/prototype/transforms/functional/_temporal.py
torchvision/prototype/transforms/functional/_temporal.py
+5
-0
No files found.
test/prototype_transforms_kernel_infos.py
View file @
511924c1
...
@@ -57,6 +57,9 @@ class KernelInfo(InfoBase):
...
@@ -57,6 +57,9 @@ class KernelInfo(InfoBase):
# structure, but with adapted parameters. This is useful in case a parameter value is closely tied to the input
# structure, but with adapted parameters. This is useful in case a parameter value is closely tied to the input
# dtype.
# dtype.
float32_vs_uint8
=
False
,
float32_vs_uint8
=
False
,
# Some kernels don't have dispatchers that would handle logging the usage. Thus, the kernel has to do it
# manually. If set, triggers a test that makes sure this happens.
logs_usage
=
False
,
# See InfoBase
# See InfoBase
test_marks
=
None
,
test_marks
=
None
,
# See InfoBase
# See InfoBase
...
@@ -71,6 +74,7 @@ class KernelInfo(InfoBase):
...
@@ -71,6 +74,7 @@ class KernelInfo(InfoBase):
if
float32_vs_uint8
and
not
callable
(
float32_vs_uint8
):
if
float32_vs_uint8
and
not
callable
(
float32_vs_uint8
):
float32_vs_uint8
=
lambda
other_args
,
kwargs
:
(
other_args
,
kwargs
)
# noqa: E731
float32_vs_uint8
=
lambda
other_args
,
kwargs
:
(
other_args
,
kwargs
)
# noqa: E731
self
.
float32_vs_uint8
=
float32_vs_uint8
self
.
float32_vs_uint8
=
float32_vs_uint8
self
.
logs_usage
=
logs_usage
def
_pixel_difference_closeness_kwargs
(
uint8_atol
,
*
,
dtype
=
torch
.
uint8
,
mae
=
False
):
def
_pixel_difference_closeness_kwargs
(
uint8_atol
,
*
,
dtype
=
torch
.
uint8
,
mae
=
False
):
...
@@ -675,6 +679,7 @@ KERNEL_INFOS.append(
...
@@ -675,6 +679,7 @@ KERNEL_INFOS.append(
sample_inputs_fn
=
sample_inputs_convert_format_bounding_box
,
sample_inputs_fn
=
sample_inputs_convert_format_bounding_box
,
reference_fn
=
reference_convert_format_bounding_box
,
reference_fn
=
reference_convert_format_bounding_box
,
reference_inputs_fn
=
reference_inputs_convert_format_bounding_box
,
reference_inputs_fn
=
reference_inputs_convert_format_bounding_box
,
logs_usage
=
True
,
),
),
)
)
...
@@ -2100,6 +2105,7 @@ KERNEL_INFOS.append(
...
@@ -2100,6 +2105,7 @@ KERNEL_INFOS.append(
KernelInfo
(
KernelInfo
(
F
.
clamp_bounding_box
,
F
.
clamp_bounding_box
,
sample_inputs_fn
=
sample_inputs_clamp_bounding_box
,
sample_inputs_fn
=
sample_inputs_clamp_bounding_box
,
logs_usage
=
True
,
)
)
)
)
...
...
test/test_prototype_transforms_functional.py
View file @
511924c1
...
@@ -108,6 +108,19 @@ class TestKernels:
...
@@ -108,6 +108,19 @@ class TestKernels:
args_kwargs_fn
=
lambda
info
:
info
.
reference_inputs_fn
(),
args_kwargs_fn
=
lambda
info
:
info
.
reference_inputs_fn
(),
)
)
@
make_info_args_kwargs_parametrization
(
[
info
for
info
in
KERNEL_INFOS
if
info
.
logs_usage
],
args_kwargs_fn
=
lambda
info
:
info
.
sample_inputs_fn
(),
)
@
pytest
.
mark
.
parametrize
(
"device"
,
cpu_and_gpu
())
def
test_logging
(
self
,
spy_on
,
info
,
args_kwargs
,
device
):
spy
=
spy_on
(
torch
.
_C
.
_log_api_usage_once
)
args
,
kwargs
=
args_kwargs
.
load
(
device
)
info
.
kernel
(
*
args
,
**
kwargs
)
spy
.
assert_any_call
(
f
"
{
info
.
kernel
.
__module__
}
.
{
info
.
id
}
"
)
@
ignore_jit_warning_no_profile
@
ignore_jit_warning_no_profile
@
sample_inputs
@
sample_inputs
@
pytest
.
mark
.
parametrize
(
"device"
,
cpu_and_gpu
())
@
pytest
.
mark
.
parametrize
(
"device"
,
cpu_and_gpu
())
...
@@ -291,6 +304,19 @@ class TestDispatchers:
...
@@ -291,6 +304,19 @@ class TestDispatchers:
args_kwargs_fn
=
lambda
info
:
info
.
sample_inputs
(
datapoints
.
Image
),
args_kwargs_fn
=
lambda
info
:
info
.
sample_inputs
(
datapoints
.
Image
),
)
)
@
make_info_args_kwargs_parametrization
(
DISPATCHER_INFOS
,
args_kwargs_fn
=
lambda
info
:
info
.
sample_inputs
(),
)
@
pytest
.
mark
.
parametrize
(
"device"
,
cpu_and_gpu
())
def
test_logging
(
self
,
spy_on
,
info
,
args_kwargs
,
device
):
spy
=
spy_on
(
torch
.
_C
.
_log_api_usage_once
)
args
,
kwargs
=
args_kwargs
.
load
(
device
)
info
.
dispatcher
(
*
args
,
**
kwargs
)
spy
.
assert_any_call
(
f
"
{
info
.
dispatcher
.
__module__
}
.
{
info
.
id
}
"
)
@
ignore_jit_warning_no_profile
@
ignore_jit_warning_no_profile
@
image_sample_inputs
@
image_sample_inputs
@
pytest
.
mark
.
parametrize
(
"device"
,
cpu_and_gpu
())
@
pytest
.
mark
.
parametrize
(
"device"
,
cpu_and_gpu
())
...
...
torchvision/prototype/transforms/functional/_augment.py
View file @
511924c1
...
@@ -5,6 +5,7 @@ import PIL.Image
...
@@ -5,6 +5,7 @@ import PIL.Image
import
torch
import
torch
from
torchvision.prototype
import
datapoints
from
torchvision.prototype
import
datapoints
from
torchvision.transforms.functional
import
pil_to_tensor
,
to_pil_image
from
torchvision.transforms.functional
import
pil_to_tensor
,
to_pil_image
from
torchvision.utils
import
_log_api_usage_once
def
erase_image_tensor
(
def
erase_image_tensor
(
...
@@ -41,6 +42,9 @@ def erase(
...
@@ -41,6 +42,9 @@ def erase(
v
:
torch
.
Tensor
,
v
:
torch
.
Tensor
,
inplace
:
bool
=
False
,
inplace
:
bool
=
False
,
)
->
Union
[
datapoints
.
ImageTypeJIT
,
datapoints
.
VideoTypeJIT
]:
)
->
Union
[
datapoints
.
ImageTypeJIT
,
datapoints
.
VideoTypeJIT
]:
if
not
torch
.
jit
.
is_scripting
():
_log_api_usage_once
(
erase
)
if
isinstance
(
inpt
,
torch
.
Tensor
)
and
(
if
isinstance
(
inpt
,
torch
.
Tensor
)
and
(
torch
.
jit
.
is_scripting
()
or
not
isinstance
(
inpt
,
(
datapoints
.
Image
,
datapoints
.
Video
))
torch
.
jit
.
is_scripting
()
or
not
isinstance
(
inpt
,
(
datapoints
.
Image
,
datapoints
.
Video
))
):
):
...
...
torchvision/prototype/transforms/functional/_color.py
View file @
511924c1
...
@@ -5,6 +5,8 @@ from torchvision.prototype import datapoints
...
@@ -5,6 +5,8 @@ from torchvision.prototype import datapoints
from
torchvision.transforms
import
functional_pil
as
_FP
from
torchvision.transforms
import
functional_pil
as
_FP
from
torchvision.transforms.functional_tensor
import
_max_value
from
torchvision.transforms.functional_tensor
import
_max_value
from
torchvision.utils
import
_log_api_usage_once
from
._meta
import
_num_value_bits
,
_rgb_to_gray
,
convert_dtype_image_tensor
from
._meta
import
_num_value_bits
,
_rgb_to_gray
,
convert_dtype_image_tensor
...
@@ -38,6 +40,9 @@ def adjust_brightness_video(video: torch.Tensor, brightness_factor: float) -> to
...
@@ -38,6 +40,9 @@ def adjust_brightness_video(video: torch.Tensor, brightness_factor: float) -> to
def
adjust_brightness
(
inpt
:
datapoints
.
InputTypeJIT
,
brightness_factor
:
float
)
->
datapoints
.
InputTypeJIT
:
def
adjust_brightness
(
inpt
:
datapoints
.
InputTypeJIT
,
brightness_factor
:
float
)
->
datapoints
.
InputTypeJIT
:
if
not
torch
.
jit
.
is_scripting
():
_log_api_usage_once
(
adjust_brightness
)
if
isinstance
(
inpt
,
torch
.
Tensor
)
and
(
if
isinstance
(
inpt
,
torch
.
Tensor
)
and
(
torch
.
jit
.
is_scripting
()
or
not
isinstance
(
inpt
,
datapoints
.
_datapoint
.
Datapoint
)
torch
.
jit
.
is_scripting
()
or
not
isinstance
(
inpt
,
datapoints
.
_datapoint
.
Datapoint
)
):
):
...
@@ -79,6 +84,9 @@ def adjust_saturation_video(video: torch.Tensor, saturation_factor: float) -> to
...
@@ -79,6 +84,9 @@ def adjust_saturation_video(video: torch.Tensor, saturation_factor: float) -> to
def
adjust_saturation
(
inpt
:
datapoints
.
InputTypeJIT
,
saturation_factor
:
float
)
->
datapoints
.
InputTypeJIT
:
def
adjust_saturation
(
inpt
:
datapoints
.
InputTypeJIT
,
saturation_factor
:
float
)
->
datapoints
.
InputTypeJIT
:
if
not
torch
.
jit
.
is_scripting
():
_log_api_usage_once
(
adjust_saturation
)
if
isinstance
(
inpt
,
torch
.
Tensor
)
and
(
if
isinstance
(
inpt
,
torch
.
Tensor
)
and
(
torch
.
jit
.
is_scripting
()
or
not
isinstance
(
inpt
,
datapoints
.
_datapoint
.
Datapoint
)
torch
.
jit
.
is_scripting
()
or
not
isinstance
(
inpt
,
datapoints
.
_datapoint
.
Datapoint
)
):
):
...
@@ -120,6 +128,9 @@ def adjust_contrast_video(video: torch.Tensor, contrast_factor: float) -> torch.
...
@@ -120,6 +128,9 @@ def adjust_contrast_video(video: torch.Tensor, contrast_factor: float) -> torch.
def
adjust_contrast
(
inpt
:
datapoints
.
InputTypeJIT
,
contrast_factor
:
float
)
->
datapoints
.
InputTypeJIT
:
def
adjust_contrast
(
inpt
:
datapoints
.
InputTypeJIT
,
contrast_factor
:
float
)
->
datapoints
.
InputTypeJIT
:
if
not
torch
.
jit
.
is_scripting
():
_log_api_usage_once
(
adjust_contrast
)
if
isinstance
(
inpt
,
torch
.
Tensor
)
and
(
if
isinstance
(
inpt
,
torch
.
Tensor
)
and
(
torch
.
jit
.
is_scripting
()
or
not
isinstance
(
inpt
,
datapoints
.
_datapoint
.
Datapoint
)
torch
.
jit
.
is_scripting
()
or
not
isinstance
(
inpt
,
datapoints
.
_datapoint
.
Datapoint
)
):
):
...
@@ -195,6 +206,9 @@ def adjust_sharpness_video(video: torch.Tensor, sharpness_factor: float) -> torc
...
@@ -195,6 +206,9 @@ def adjust_sharpness_video(video: torch.Tensor, sharpness_factor: float) -> torc
def
adjust_sharpness
(
inpt
:
datapoints
.
InputTypeJIT
,
sharpness_factor
:
float
)
->
datapoints
.
InputTypeJIT
:
def
adjust_sharpness
(
inpt
:
datapoints
.
InputTypeJIT
,
sharpness_factor
:
float
)
->
datapoints
.
InputTypeJIT
:
if
not
torch
.
jit
.
is_scripting
():
_log_api_usage_once
(
adjust_sharpness
)
if
isinstance
(
inpt
,
torch
.
Tensor
)
and
(
if
isinstance
(
inpt
,
torch
.
Tensor
)
and
(
torch
.
jit
.
is_scripting
()
or
not
isinstance
(
inpt
,
datapoints
.
_datapoint
.
Datapoint
)
torch
.
jit
.
is_scripting
()
or
not
isinstance
(
inpt
,
datapoints
.
_datapoint
.
Datapoint
)
):
):
...
@@ -309,6 +323,9 @@ def adjust_hue_video(video: torch.Tensor, hue_factor: float) -> torch.Tensor:
...
@@ -309,6 +323,9 @@ def adjust_hue_video(video: torch.Tensor, hue_factor: float) -> torch.Tensor:
def
adjust_hue
(
inpt
:
datapoints
.
InputTypeJIT
,
hue_factor
:
float
)
->
datapoints
.
InputTypeJIT
:
def
adjust_hue
(
inpt
:
datapoints
.
InputTypeJIT
,
hue_factor
:
float
)
->
datapoints
.
InputTypeJIT
:
if
not
torch
.
jit
.
is_scripting
():
_log_api_usage_once
(
adjust_hue
)
if
isinstance
(
inpt
,
torch
.
Tensor
)
and
(
if
isinstance
(
inpt
,
torch
.
Tensor
)
and
(
torch
.
jit
.
is_scripting
()
or
not
isinstance
(
inpt
,
datapoints
.
_datapoint
.
Datapoint
)
torch
.
jit
.
is_scripting
()
or
not
isinstance
(
inpt
,
datapoints
.
_datapoint
.
Datapoint
)
):
):
...
@@ -351,6 +368,9 @@ def adjust_gamma_video(video: torch.Tensor, gamma: float, gain: float = 1) -> to
...
@@ -351,6 +368,9 @@ def adjust_gamma_video(video: torch.Tensor, gamma: float, gain: float = 1) -> to
def
adjust_gamma
(
inpt
:
datapoints
.
InputTypeJIT
,
gamma
:
float
,
gain
:
float
=
1
)
->
datapoints
.
InputTypeJIT
:
def
adjust_gamma
(
inpt
:
datapoints
.
InputTypeJIT
,
gamma
:
float
,
gain
:
float
=
1
)
->
datapoints
.
InputTypeJIT
:
if
not
torch
.
jit
.
is_scripting
():
_log_api_usage_once
(
adjust_gamma
)
if
isinstance
(
inpt
,
torch
.
Tensor
)
and
(
if
isinstance
(
inpt
,
torch
.
Tensor
)
and
(
torch
.
jit
.
is_scripting
()
or
not
isinstance
(
inpt
,
datapoints
.
_datapoint
.
Datapoint
)
torch
.
jit
.
is_scripting
()
or
not
isinstance
(
inpt
,
datapoints
.
_datapoint
.
Datapoint
)
):
):
...
@@ -387,6 +407,9 @@ def posterize_video(video: torch.Tensor, bits: int) -> torch.Tensor:
...
@@ -387,6 +407,9 @@ def posterize_video(video: torch.Tensor, bits: int) -> torch.Tensor:
def
posterize
(
inpt
:
datapoints
.
InputTypeJIT
,
bits
:
int
)
->
datapoints
.
InputTypeJIT
:
def
posterize
(
inpt
:
datapoints
.
InputTypeJIT
,
bits
:
int
)
->
datapoints
.
InputTypeJIT
:
if
not
torch
.
jit
.
is_scripting
():
_log_api_usage_once
(
posterize
)
if
isinstance
(
inpt
,
torch
.
Tensor
)
and
(
if
isinstance
(
inpt
,
torch
.
Tensor
)
and
(
torch
.
jit
.
is_scripting
()
or
not
isinstance
(
inpt
,
datapoints
.
_datapoint
.
Datapoint
)
torch
.
jit
.
is_scripting
()
or
not
isinstance
(
inpt
,
datapoints
.
_datapoint
.
Datapoint
)
):
):
...
@@ -417,6 +440,9 @@ def solarize_video(video: torch.Tensor, threshold: float) -> torch.Tensor:
...
@@ -417,6 +440,9 @@ def solarize_video(video: torch.Tensor, threshold: float) -> torch.Tensor:
def
solarize
(
inpt
:
datapoints
.
InputTypeJIT
,
threshold
:
float
)
->
datapoints
.
InputTypeJIT
:
def
solarize
(
inpt
:
datapoints
.
InputTypeJIT
,
threshold
:
float
)
->
datapoints
.
InputTypeJIT
:
if
not
torch
.
jit
.
is_scripting
():
_log_api_usage_once
(
solarize
)
if
isinstance
(
inpt
,
torch
.
Tensor
)
and
(
if
isinstance
(
inpt
,
torch
.
Tensor
)
and
(
torch
.
jit
.
is_scripting
()
or
not
isinstance
(
inpt
,
datapoints
.
_datapoint
.
Datapoint
)
torch
.
jit
.
is_scripting
()
or
not
isinstance
(
inpt
,
datapoints
.
_datapoint
.
Datapoint
)
):
):
...
@@ -469,6 +495,9 @@ def autocontrast_video(video: torch.Tensor) -> torch.Tensor:
...
@@ -469,6 +495,9 @@ def autocontrast_video(video: torch.Tensor) -> torch.Tensor:
def
autocontrast
(
inpt
:
datapoints
.
InputTypeJIT
)
->
datapoints
.
InputTypeJIT
:
def
autocontrast
(
inpt
:
datapoints
.
InputTypeJIT
)
->
datapoints
.
InputTypeJIT
:
if
not
torch
.
jit
.
is_scripting
():
_log_api_usage_once
(
autocontrast
)
if
isinstance
(
inpt
,
torch
.
Tensor
)
and
(
if
isinstance
(
inpt
,
torch
.
Tensor
)
and
(
torch
.
jit
.
is_scripting
()
or
not
isinstance
(
inpt
,
datapoints
.
_datapoint
.
Datapoint
)
torch
.
jit
.
is_scripting
()
or
not
isinstance
(
inpt
,
datapoints
.
_datapoint
.
Datapoint
)
):
):
...
@@ -561,6 +590,9 @@ def equalize_video(video: torch.Tensor) -> torch.Tensor:
...
@@ -561,6 +590,9 @@ def equalize_video(video: torch.Tensor) -> torch.Tensor:
def
equalize
(
inpt
:
datapoints
.
InputTypeJIT
)
->
datapoints
.
InputTypeJIT
:
def
equalize
(
inpt
:
datapoints
.
InputTypeJIT
)
->
datapoints
.
InputTypeJIT
:
if
not
torch
.
jit
.
is_scripting
():
_log_api_usage_once
(
equalize
)
if
isinstance
(
inpt
,
torch
.
Tensor
)
and
(
if
isinstance
(
inpt
,
torch
.
Tensor
)
and
(
torch
.
jit
.
is_scripting
()
or
not
isinstance
(
inpt
,
datapoints
.
_datapoint
.
Datapoint
)
torch
.
jit
.
is_scripting
()
or
not
isinstance
(
inpt
,
datapoints
.
_datapoint
.
Datapoint
)
):
):
...
@@ -594,6 +626,9 @@ def invert_video(video: torch.Tensor) -> torch.Tensor:
...
@@ -594,6 +626,9 @@ def invert_video(video: torch.Tensor) -> torch.Tensor:
def
invert
(
inpt
:
datapoints
.
InputTypeJIT
)
->
datapoints
.
InputTypeJIT
:
def
invert
(
inpt
:
datapoints
.
InputTypeJIT
)
->
datapoints
.
InputTypeJIT
:
if
not
torch
.
jit
.
is_scripting
():
_log_api_usage_once
(
invert
)
if
isinstance
(
inpt
,
torch
.
Tensor
)
and
(
if
isinstance
(
inpt
,
torch
.
Tensor
)
and
(
torch
.
jit
.
is_scripting
()
or
not
isinstance
(
inpt
,
datapoints
.
_datapoint
.
Datapoint
)
torch
.
jit
.
is_scripting
()
or
not
isinstance
(
inpt
,
datapoints
.
_datapoint
.
Datapoint
)
):
):
...
...
torchvision/prototype/transforms/functional/_geometry.py
View file @
511924c1
...
@@ -19,6 +19,8 @@ from torchvision.transforms.functional import (
...
@@ -19,6 +19,8 @@ from torchvision.transforms.functional import (
)
)
from
torchvision.transforms.functional_tensor
import
_pad_symmetric
from
torchvision.transforms.functional_tensor
import
_pad_symmetric
from
torchvision.utils
import
_log_api_usage_once
from
._meta
import
convert_format_bounding_box
,
get_spatial_size_image_pil
from
._meta
import
convert_format_bounding_box
,
get_spatial_size_image_pil
...
@@ -55,6 +57,9 @@ def horizontal_flip_video(video: torch.Tensor) -> torch.Tensor:
...
@@ -55,6 +57,9 @@ def horizontal_flip_video(video: torch.Tensor) -> torch.Tensor:
def
horizontal_flip
(
inpt
:
datapoints
.
InputTypeJIT
)
->
datapoints
.
InputTypeJIT
:
def
horizontal_flip
(
inpt
:
datapoints
.
InputTypeJIT
)
->
datapoints
.
InputTypeJIT
:
if
not
torch
.
jit
.
is_scripting
():
_log_api_usage_once
(
horizontal_flip
)
if
isinstance
(
inpt
,
torch
.
Tensor
)
and
(
if
isinstance
(
inpt
,
torch
.
Tensor
)
and
(
torch
.
jit
.
is_scripting
()
or
not
isinstance
(
inpt
,
datapoints
.
_datapoint
.
Datapoint
)
torch
.
jit
.
is_scripting
()
or
not
isinstance
(
inpt
,
datapoints
.
_datapoint
.
Datapoint
)
):
):
...
@@ -103,6 +108,9 @@ def vertical_flip_video(video: torch.Tensor) -> torch.Tensor:
...
@@ -103,6 +108,9 @@ def vertical_flip_video(video: torch.Tensor) -> torch.Tensor:
def
vertical_flip
(
inpt
:
datapoints
.
InputTypeJIT
)
->
datapoints
.
InputTypeJIT
:
def
vertical_flip
(
inpt
:
datapoints
.
InputTypeJIT
)
->
datapoints
.
InputTypeJIT
:
if
not
torch
.
jit
.
is_scripting
():
_log_api_usage_once
(
vertical_flip
)
if
isinstance
(
inpt
,
torch
.
Tensor
)
and
(
if
isinstance
(
inpt
,
torch
.
Tensor
)
and
(
torch
.
jit
.
is_scripting
()
or
not
isinstance
(
inpt
,
datapoints
.
_datapoint
.
Datapoint
)
torch
.
jit
.
is_scripting
()
or
not
isinstance
(
inpt
,
datapoints
.
_datapoint
.
Datapoint
)
):
):
...
@@ -231,6 +239,8 @@ def resize(
...
@@ -231,6 +239,8 @@ def resize(
max_size
:
Optional
[
int
]
=
None
,
max_size
:
Optional
[
int
]
=
None
,
antialias
:
Optional
[
bool
]
=
None
,
antialias
:
Optional
[
bool
]
=
None
,
)
->
datapoints
.
InputTypeJIT
:
)
->
datapoints
.
InputTypeJIT
:
if
not
torch
.
jit
.
is_scripting
():
_log_api_usage_once
(
resize
)
if
isinstance
(
inpt
,
torch
.
Tensor
)
and
(
if
isinstance
(
inpt
,
torch
.
Tensor
)
and
(
torch
.
jit
.
is_scripting
()
or
not
isinstance
(
inpt
,
datapoints
.
_datapoint
.
Datapoint
)
torch
.
jit
.
is_scripting
()
or
not
isinstance
(
inpt
,
datapoints
.
_datapoint
.
Datapoint
)
):
):
...
@@ -730,6 +740,9 @@ def affine(
...
@@ -730,6 +740,9 @@ def affine(
fill
:
datapoints
.
FillTypeJIT
=
None
,
fill
:
datapoints
.
FillTypeJIT
=
None
,
center
:
Optional
[
List
[
float
]]
=
None
,
center
:
Optional
[
List
[
float
]]
=
None
,
)
->
datapoints
.
InputTypeJIT
:
)
->
datapoints
.
InputTypeJIT
:
if
not
torch
.
jit
.
is_scripting
():
_log_api_usage_once
(
affine
)
# TODO: consider deprecating integers from angle and shear on the future
# TODO: consider deprecating integers from angle and shear on the future
if
isinstance
(
inpt
,
torch
.
Tensor
)
and
(
if
isinstance
(
inpt
,
torch
.
Tensor
)
and
(
torch
.
jit
.
is_scripting
()
or
not
isinstance
(
inpt
,
datapoints
.
_datapoint
.
Datapoint
)
torch
.
jit
.
is_scripting
()
or
not
isinstance
(
inpt
,
datapoints
.
_datapoint
.
Datapoint
)
...
@@ -913,6 +926,9 @@ def rotate(
...
@@ -913,6 +926,9 @@ def rotate(
center
:
Optional
[
List
[
float
]]
=
None
,
center
:
Optional
[
List
[
float
]]
=
None
,
fill
:
datapoints
.
FillTypeJIT
=
None
,
fill
:
datapoints
.
FillTypeJIT
=
None
,
)
->
datapoints
.
InputTypeJIT
:
)
->
datapoints
.
InputTypeJIT
:
if
not
torch
.
jit
.
is_scripting
():
_log_api_usage_once
(
rotate
)
if
isinstance
(
inpt
,
torch
.
Tensor
)
and
(
if
isinstance
(
inpt
,
torch
.
Tensor
)
and
(
torch
.
jit
.
is_scripting
()
or
not
isinstance
(
inpt
,
datapoints
.
_datapoint
.
Datapoint
)
torch
.
jit
.
is_scripting
()
or
not
isinstance
(
inpt
,
datapoints
.
_datapoint
.
Datapoint
)
):
):
...
@@ -1120,6 +1136,9 @@ def pad(
...
@@ -1120,6 +1136,9 @@ def pad(
fill
:
datapoints
.
FillTypeJIT
=
None
,
fill
:
datapoints
.
FillTypeJIT
=
None
,
padding_mode
:
str
=
"constant"
,
padding_mode
:
str
=
"constant"
,
)
->
datapoints
.
InputTypeJIT
:
)
->
datapoints
.
InputTypeJIT
:
if
not
torch
.
jit
.
is_scripting
():
_log_api_usage_once
(
pad
)
if
isinstance
(
inpt
,
torch
.
Tensor
)
and
(
if
isinstance
(
inpt
,
torch
.
Tensor
)
and
(
torch
.
jit
.
is_scripting
()
or
not
isinstance
(
inpt
,
datapoints
.
_datapoint
.
Datapoint
)
torch
.
jit
.
is_scripting
()
or
not
isinstance
(
inpt
,
datapoints
.
_datapoint
.
Datapoint
)
):
):
...
@@ -1197,6 +1216,9 @@ def crop_video(video: torch.Tensor, top: int, left: int, height: int, width: int
...
@@ -1197,6 +1216,9 @@ def crop_video(video: torch.Tensor, top: int, left: int, height: int, width: int
def
crop
(
inpt
:
datapoints
.
InputTypeJIT
,
top
:
int
,
left
:
int
,
height
:
int
,
width
:
int
)
->
datapoints
.
InputTypeJIT
:
def
crop
(
inpt
:
datapoints
.
InputTypeJIT
,
top
:
int
,
left
:
int
,
height
:
int
,
width
:
int
)
->
datapoints
.
InputTypeJIT
:
if
not
torch
.
jit
.
is_scripting
():
_log_api_usage_once
(
crop
)
if
isinstance
(
inpt
,
torch
.
Tensor
)
and
(
if
isinstance
(
inpt
,
torch
.
Tensor
)
and
(
torch
.
jit
.
is_scripting
()
or
not
isinstance
(
inpt
,
datapoints
.
_datapoint
.
Datapoint
)
torch
.
jit
.
is_scripting
()
or
not
isinstance
(
inpt
,
datapoints
.
_datapoint
.
Datapoint
)
):
):
...
@@ -1452,6 +1474,8 @@ def perspective(
...
@@ -1452,6 +1474,8 @@ def perspective(
fill
:
datapoints
.
FillTypeJIT
=
None
,
fill
:
datapoints
.
FillTypeJIT
=
None
,
coefficients
:
Optional
[
List
[
float
]]
=
None
,
coefficients
:
Optional
[
List
[
float
]]
=
None
,
)
->
datapoints
.
InputTypeJIT
:
)
->
datapoints
.
InputTypeJIT
:
if
not
torch
.
jit
.
is_scripting
():
_log_api_usage_once
(
perspective
)
if
isinstance
(
inpt
,
torch
.
Tensor
)
and
(
if
isinstance
(
inpt
,
torch
.
Tensor
)
and
(
torch
.
jit
.
is_scripting
()
or
not
isinstance
(
inpt
,
datapoints
.
_datapoint
.
Datapoint
)
torch
.
jit
.
is_scripting
()
or
not
isinstance
(
inpt
,
datapoints
.
_datapoint
.
Datapoint
)
):
):
...
@@ -1612,6 +1636,9 @@ def elastic(
...
@@ -1612,6 +1636,9 @@ def elastic(
interpolation
:
InterpolationMode
=
InterpolationMode
.
BILINEAR
,
interpolation
:
InterpolationMode
=
InterpolationMode
.
BILINEAR
,
fill
:
datapoints
.
FillTypeJIT
=
None
,
fill
:
datapoints
.
FillTypeJIT
=
None
,
)
->
datapoints
.
InputTypeJIT
:
)
->
datapoints
.
InputTypeJIT
:
if
not
torch
.
jit
.
is_scripting
():
_log_api_usage_once
(
elastic
)
if
isinstance
(
inpt
,
torch
.
Tensor
)
and
(
if
isinstance
(
inpt
,
torch
.
Tensor
)
and
(
torch
.
jit
.
is_scripting
()
or
not
isinstance
(
inpt
,
datapoints
.
_datapoint
.
Datapoint
)
torch
.
jit
.
is_scripting
()
or
not
isinstance
(
inpt
,
datapoints
.
_datapoint
.
Datapoint
)
):
):
...
@@ -1724,6 +1751,9 @@ def center_crop_video(video: torch.Tensor, output_size: List[int]) -> torch.Tens
...
@@ -1724,6 +1751,9 @@ def center_crop_video(video: torch.Tensor, output_size: List[int]) -> torch.Tens
def
center_crop
(
inpt
:
datapoints
.
InputTypeJIT
,
output_size
:
List
[
int
])
->
datapoints
.
InputTypeJIT
:
def
center_crop
(
inpt
:
datapoints
.
InputTypeJIT
,
output_size
:
List
[
int
])
->
datapoints
.
InputTypeJIT
:
if
not
torch
.
jit
.
is_scripting
():
_log_api_usage_once
(
center_crop
)
if
isinstance
(
inpt
,
torch
.
Tensor
)
and
(
if
isinstance
(
inpt
,
torch
.
Tensor
)
and
(
torch
.
jit
.
is_scripting
()
or
not
isinstance
(
inpt
,
datapoints
.
_datapoint
.
Datapoint
)
torch
.
jit
.
is_scripting
()
or
not
isinstance
(
inpt
,
datapoints
.
_datapoint
.
Datapoint
)
):
):
...
@@ -1817,6 +1847,9 @@ def resized_crop(
...
@@ -1817,6 +1847,9 @@ def resized_crop(
interpolation
:
InterpolationMode
=
InterpolationMode
.
BILINEAR
,
interpolation
:
InterpolationMode
=
InterpolationMode
.
BILINEAR
,
antialias
:
Optional
[
bool
]
=
None
,
antialias
:
Optional
[
bool
]
=
None
,
)
->
datapoints
.
InputTypeJIT
:
)
->
datapoints
.
InputTypeJIT
:
if
not
torch
.
jit
.
is_scripting
():
_log_api_usage_once
(
resized_crop
)
if
isinstance
(
inpt
,
torch
.
Tensor
)
and
(
if
isinstance
(
inpt
,
torch
.
Tensor
)
and
(
torch
.
jit
.
is_scripting
()
or
not
isinstance
(
inpt
,
datapoints
.
_datapoint
.
Datapoint
)
torch
.
jit
.
is_scripting
()
or
not
isinstance
(
inpt
,
datapoints
.
_datapoint
.
Datapoint
)
):
):
...
@@ -1897,6 +1930,9 @@ ImageOrVideoTypeJIT = Union[datapoints.ImageTypeJIT, datapoints.VideoTypeJIT]
...
@@ -1897,6 +1930,9 @@ ImageOrVideoTypeJIT = Union[datapoints.ImageTypeJIT, datapoints.VideoTypeJIT]
def
five_crop
(
def
five_crop
(
inpt
:
ImageOrVideoTypeJIT
,
size
:
List
[
int
]
inpt
:
ImageOrVideoTypeJIT
,
size
:
List
[
int
]
)
->
Tuple
[
ImageOrVideoTypeJIT
,
ImageOrVideoTypeJIT
,
ImageOrVideoTypeJIT
,
ImageOrVideoTypeJIT
,
ImageOrVideoTypeJIT
]:
)
->
Tuple
[
ImageOrVideoTypeJIT
,
ImageOrVideoTypeJIT
,
ImageOrVideoTypeJIT
,
ImageOrVideoTypeJIT
,
ImageOrVideoTypeJIT
]:
if
not
torch
.
jit
.
is_scripting
():
_log_api_usage_once
(
five_crop
)
# TODO: consider breaking BC here to return List[datapoints.ImageTypeJIT/VideoTypeJIT] to align this op with
# TODO: consider breaking BC here to return List[datapoints.ImageTypeJIT/VideoTypeJIT] to align this op with
# `ten_crop`
# `ten_crop`
if
isinstance
(
inpt
,
torch
.
Tensor
)
and
(
if
isinstance
(
inpt
,
torch
.
Tensor
)
and
(
...
@@ -1952,6 +1988,9 @@ def ten_crop_video(video: torch.Tensor, size: List[int], vertical_flip: bool = F
...
@@ -1952,6 +1988,9 @@ def ten_crop_video(video: torch.Tensor, size: List[int], vertical_flip: bool = F
def
ten_crop
(
def
ten_crop
(
inpt
:
Union
[
datapoints
.
ImageTypeJIT
,
datapoints
.
VideoTypeJIT
],
size
:
List
[
int
],
vertical_flip
:
bool
=
False
inpt
:
Union
[
datapoints
.
ImageTypeJIT
,
datapoints
.
VideoTypeJIT
],
size
:
List
[
int
],
vertical_flip
:
bool
=
False
)
->
Union
[
List
[
datapoints
.
ImageTypeJIT
],
List
[
datapoints
.
VideoTypeJIT
]]:
)
->
Union
[
List
[
datapoints
.
ImageTypeJIT
],
List
[
datapoints
.
VideoTypeJIT
]]:
if
not
torch
.
jit
.
is_scripting
():
_log_api_usage_once
(
ten_crop
)
if
isinstance
(
inpt
,
torch
.
Tensor
)
and
(
if
isinstance
(
inpt
,
torch
.
Tensor
)
and
(
torch
.
jit
.
is_scripting
()
or
not
isinstance
(
inpt
,
(
datapoints
.
Image
,
datapoints
.
Video
))
torch
.
jit
.
is_scripting
()
or
not
isinstance
(
inpt
,
(
datapoints
.
Image
,
datapoints
.
Video
))
):
):
...
...
torchvision/prototype/transforms/functional/_meta.py
View file @
511924c1
...
@@ -7,6 +7,8 @@ from torchvision.prototype.datapoints import BoundingBoxFormat, ColorSpace
...
@@ -7,6 +7,8 @@ from torchvision.prototype.datapoints import BoundingBoxFormat, ColorSpace
from
torchvision.transforms
import
functional_pil
as
_FP
from
torchvision.transforms
import
functional_pil
as
_FP
from
torchvision.transforms.functional_tensor
import
_max_value
from
torchvision.transforms.functional_tensor
import
_max_value
from
torchvision.utils
import
_log_api_usage_once
def
get_dimensions_image_tensor
(
image
:
torch
.
Tensor
)
->
List
[
int
]:
def
get_dimensions_image_tensor
(
image
:
torch
.
Tensor
)
->
List
[
int
]:
chw
=
list
(
image
.
shape
[
-
3
:])
chw
=
list
(
image
.
shape
[
-
3
:])
...
@@ -24,6 +26,9 @@ get_dimensions_image_pil = _FP.get_dimensions
...
@@ -24,6 +26,9 @@ get_dimensions_image_pil = _FP.get_dimensions
def
get_dimensions
(
inpt
:
Union
[
datapoints
.
ImageTypeJIT
,
datapoints
.
VideoTypeJIT
])
->
List
[
int
]:
def
get_dimensions
(
inpt
:
Union
[
datapoints
.
ImageTypeJIT
,
datapoints
.
VideoTypeJIT
])
->
List
[
int
]:
if
not
torch
.
jit
.
is_scripting
():
_log_api_usage_once
(
get_dimensions
)
if
isinstance
(
inpt
,
torch
.
Tensor
)
and
(
if
isinstance
(
inpt
,
torch
.
Tensor
)
and
(
torch
.
jit
.
is_scripting
()
or
not
isinstance
(
inpt
,
(
datapoints
.
Image
,
datapoints
.
Video
))
torch
.
jit
.
is_scripting
()
or
not
isinstance
(
inpt
,
(
datapoints
.
Image
,
datapoints
.
Video
))
):
):
...
@@ -60,6 +65,9 @@ def get_num_channels_video(video: torch.Tensor) -> int:
...
@@ -60,6 +65,9 @@ def get_num_channels_video(video: torch.Tensor) -> int:
def
get_num_channels
(
inpt
:
Union
[
datapoints
.
ImageTypeJIT
,
datapoints
.
VideoTypeJIT
])
->
int
:
def
get_num_channels
(
inpt
:
Union
[
datapoints
.
ImageTypeJIT
,
datapoints
.
VideoTypeJIT
])
->
int
:
if
not
torch
.
jit
.
is_scripting
():
_log_api_usage_once
(
get_num_channels
)
if
isinstance
(
inpt
,
torch
.
Tensor
)
and
(
if
isinstance
(
inpt
,
torch
.
Tensor
)
and
(
torch
.
jit
.
is_scripting
()
or
not
isinstance
(
inpt
,
(
datapoints
.
Image
,
datapoints
.
Video
))
torch
.
jit
.
is_scripting
()
or
not
isinstance
(
inpt
,
(
datapoints
.
Image
,
datapoints
.
Video
))
):
):
...
@@ -109,6 +117,9 @@ def get_spatial_size_bounding_box(bounding_box: datapoints.BoundingBox) -> List[
...
@@ -109,6 +117,9 @@ def get_spatial_size_bounding_box(bounding_box: datapoints.BoundingBox) -> List[
def
get_spatial_size
(
inpt
:
datapoints
.
InputTypeJIT
)
->
List
[
int
]:
def
get_spatial_size
(
inpt
:
datapoints
.
InputTypeJIT
)
->
List
[
int
]:
if
not
torch
.
jit
.
is_scripting
():
_log_api_usage_once
(
get_spatial_size
)
if
isinstance
(
inpt
,
torch
.
Tensor
)
and
(
if
isinstance
(
inpt
,
torch
.
Tensor
)
and
(
torch
.
jit
.
is_scripting
()
or
not
isinstance
(
inpt
,
datapoints
.
_datapoint
.
Datapoint
)
torch
.
jit
.
is_scripting
()
or
not
isinstance
(
inpt
,
datapoints
.
_datapoint
.
Datapoint
)
):
):
...
@@ -129,6 +140,9 @@ def get_num_frames_video(video: torch.Tensor) -> int:
...
@@ -129,6 +140,9 @@ def get_num_frames_video(video: torch.Tensor) -> int:
def
get_num_frames
(
inpt
:
datapoints
.
VideoTypeJIT
)
->
int
:
def
get_num_frames
(
inpt
:
datapoints
.
VideoTypeJIT
)
->
int
:
if
not
torch
.
jit
.
is_scripting
():
_log_api_usage_once
(
get_num_frames
)
if
isinstance
(
inpt
,
torch
.
Tensor
)
and
(
torch
.
jit
.
is_scripting
()
or
not
isinstance
(
inpt
,
datapoints
.
Video
)):
if
isinstance
(
inpt
,
torch
.
Tensor
)
and
(
torch
.
jit
.
is_scripting
()
or
not
isinstance
(
inpt
,
datapoints
.
Video
)):
return
get_num_frames_video
(
inpt
)
return
get_num_frames_video
(
inpt
)
elif
isinstance
(
inpt
,
datapoints
.
Video
):
elif
isinstance
(
inpt
,
datapoints
.
Video
):
...
@@ -179,6 +193,9 @@ def _xyxy_to_cxcywh(xyxy: torch.Tensor, inplace: bool) -> torch.Tensor:
...
@@ -179,6 +193,9 @@ def _xyxy_to_cxcywh(xyxy: torch.Tensor, inplace: bool) -> torch.Tensor:
def
convert_format_bounding_box
(
def
convert_format_bounding_box
(
bounding_box
:
torch
.
Tensor
,
old_format
:
BoundingBoxFormat
,
new_format
:
BoundingBoxFormat
,
inplace
:
bool
=
False
bounding_box
:
torch
.
Tensor
,
old_format
:
BoundingBoxFormat
,
new_format
:
BoundingBoxFormat
,
inplace
:
bool
=
False
)
->
torch
.
Tensor
:
)
->
torch
.
Tensor
:
if
not
torch
.
jit
.
is_scripting
():
_log_api_usage_once
(
convert_format_bounding_box
)
if
new_format
==
old_format
:
if
new_format
==
old_format
:
return
bounding_box
return
bounding_box
...
@@ -199,6 +216,9 @@ def convert_format_bounding_box(
...
@@ -199,6 +216,9 @@ def convert_format_bounding_box(
def
clamp_bounding_box
(
def
clamp_bounding_box
(
bounding_box
:
torch
.
Tensor
,
format
:
BoundingBoxFormat
,
spatial_size
:
Tuple
[
int
,
int
]
bounding_box
:
torch
.
Tensor
,
format
:
BoundingBoxFormat
,
spatial_size
:
Tuple
[
int
,
int
]
)
->
torch
.
Tensor
:
)
->
torch
.
Tensor
:
if
not
torch
.
jit
.
is_scripting
():
_log_api_usage_once
(
clamp_bounding_box
)
# TODO: Investigate if it makes sense from a performance perspective to have an implementation for every
# TODO: Investigate if it makes sense from a performance perspective to have an implementation for every
# BoundingBoxFormat instead of converting back and forth
# BoundingBoxFormat instead of converting back and forth
xyxy_boxes
=
convert_format_bounding_box
(
xyxy_boxes
=
convert_format_bounding_box
(
...
@@ -313,6 +333,9 @@ def convert_color_space(
...
@@ -313,6 +333,9 @@ def convert_color_space(
color_space
:
ColorSpace
,
color_space
:
ColorSpace
,
old_color_space
:
Optional
[
ColorSpace
]
=
None
,
old_color_space
:
Optional
[
ColorSpace
]
=
None
,
)
->
Union
[
datapoints
.
ImageTypeJIT
,
datapoints
.
VideoTypeJIT
]:
)
->
Union
[
datapoints
.
ImageTypeJIT
,
datapoints
.
VideoTypeJIT
]:
if
not
torch
.
jit
.
is_scripting
():
_log_api_usage_once
(
convert_color_space
)
if
isinstance
(
inpt
,
torch
.
Tensor
)
and
(
if
isinstance
(
inpt
,
torch
.
Tensor
)
and
(
torch
.
jit
.
is_scripting
()
or
not
isinstance
(
inpt
,
(
datapoints
.
Image
,
datapoints
.
Video
))
torch
.
jit
.
is_scripting
()
or
not
isinstance
(
inpt
,
(
datapoints
.
Image
,
datapoints
.
Video
))
):
):
...
@@ -417,6 +440,9 @@ def convert_dtype_video(video: torch.Tensor, dtype: torch.dtype = torch.float) -
...
@@ -417,6 +440,9 @@ def convert_dtype_video(video: torch.Tensor, dtype: torch.dtype = torch.float) -
def
convert_dtype
(
def
convert_dtype
(
inpt
:
Union
[
datapoints
.
ImageTypeJIT
,
datapoints
.
VideoTypeJIT
],
dtype
:
torch
.
dtype
=
torch
.
float
inpt
:
Union
[
datapoints
.
ImageTypeJIT
,
datapoints
.
VideoTypeJIT
],
dtype
:
torch
.
dtype
=
torch
.
float
)
->
torch
.
Tensor
:
)
->
torch
.
Tensor
:
if
not
torch
.
jit
.
is_scripting
():
_log_api_usage_once
(
convert_dtype
)
if
isinstance
(
inpt
,
torch
.
Tensor
)
and
(
if
isinstance
(
inpt
,
torch
.
Tensor
)
and
(
torch
.
jit
.
is_scripting
()
or
not
isinstance
(
inpt
,
(
datapoints
.
Image
,
datapoints
.
Video
))
torch
.
jit
.
is_scripting
()
or
not
isinstance
(
inpt
,
(
datapoints
.
Image
,
datapoints
.
Video
))
):
):
...
...
torchvision/prototype/transforms/functional/_misc.py
View file @
511924c1
...
@@ -8,6 +8,8 @@ from torch.nn.functional import conv2d, pad as torch_pad
...
@@ -8,6 +8,8 @@ from torch.nn.functional import conv2d, pad as torch_pad
from
torchvision.prototype
import
datapoints
from
torchvision.prototype
import
datapoints
from
torchvision.transforms.functional
import
pil_to_tensor
,
to_pil_image
from
torchvision.transforms.functional
import
pil_to_tensor
,
to_pil_image
from
torchvision.utils
import
_log_api_usage_once
from
..utils
import
is_simple_tensor
from
..utils
import
is_simple_tensor
...
@@ -57,6 +59,8 @@ def normalize(
...
@@ -57,6 +59,8 @@ def normalize(
inplace
:
bool
=
False
,
inplace
:
bool
=
False
,
)
->
torch
.
Tensor
:
)
->
torch
.
Tensor
:
if
not
torch
.
jit
.
is_scripting
():
if
not
torch
.
jit
.
is_scripting
():
_log_api_usage_once
(
normalize
)
if
is_simple_tensor
(
inpt
)
or
isinstance
(
inpt
,
(
datapoints
.
Image
,
datapoints
.
Video
)):
if
is_simple_tensor
(
inpt
)
or
isinstance
(
inpt
,
(
datapoints
.
Image
,
datapoints
.
Video
)):
inpt
=
inpt
.
as_subclass
(
torch
.
Tensor
)
inpt
=
inpt
.
as_subclass
(
torch
.
Tensor
)
else
:
else
:
...
@@ -168,6 +172,9 @@ def gaussian_blur_video(
...
@@ -168,6 +172,9 @@ def gaussian_blur_video(
def
gaussian_blur
(
def
gaussian_blur
(
inpt
:
datapoints
.
InputTypeJIT
,
kernel_size
:
List
[
int
],
sigma
:
Optional
[
List
[
float
]]
=
None
inpt
:
datapoints
.
InputTypeJIT
,
kernel_size
:
List
[
int
],
sigma
:
Optional
[
List
[
float
]]
=
None
)
->
datapoints
.
InputTypeJIT
:
)
->
datapoints
.
InputTypeJIT
:
if
not
torch
.
jit
.
is_scripting
():
_log_api_usage_once
(
gaussian_blur
)
if
isinstance
(
inpt
,
torch
.
Tensor
)
and
(
if
isinstance
(
inpt
,
torch
.
Tensor
)
and
(
torch
.
jit
.
is_scripting
()
or
not
isinstance
(
inpt
,
datapoints
.
_datapoint
.
Datapoint
)
torch
.
jit
.
is_scripting
()
or
not
isinstance
(
inpt
,
datapoints
.
_datapoint
.
Datapoint
)
):
):
...
...
torchvision/prototype/transforms/functional/_temporal.py
View file @
511924c1
...
@@ -2,6 +2,8 @@ import torch
...
@@ -2,6 +2,8 @@ import torch
from
torchvision.prototype
import
datapoints
from
torchvision.prototype
import
datapoints
from
torchvision.utils
import
_log_api_usage_once
def
uniform_temporal_subsample_video
(
video
:
torch
.
Tensor
,
num_samples
:
int
,
temporal_dim
:
int
=
-
4
)
->
torch
.
Tensor
:
def
uniform_temporal_subsample_video
(
video
:
torch
.
Tensor
,
num_samples
:
int
,
temporal_dim
:
int
=
-
4
)
->
torch
.
Tensor
:
# Reference: https://github.com/facebookresearch/pytorchvideo/blob/a0a131e/pytorchvideo/transforms/functional.py#L19
# Reference: https://github.com/facebookresearch/pytorchvideo/blob/a0a131e/pytorchvideo/transforms/functional.py#L19
...
@@ -13,6 +15,9 @@ def uniform_temporal_subsample_video(video: torch.Tensor, num_samples: int, temp
...
@@ -13,6 +15,9 @@ def uniform_temporal_subsample_video(video: torch.Tensor, num_samples: int, temp
def
uniform_temporal_subsample
(
def
uniform_temporal_subsample
(
inpt
:
datapoints
.
VideoTypeJIT
,
num_samples
:
int
,
temporal_dim
:
int
=
-
4
inpt
:
datapoints
.
VideoTypeJIT
,
num_samples
:
int
,
temporal_dim
:
int
=
-
4
)
->
datapoints
.
VideoTypeJIT
:
)
->
datapoints
.
VideoTypeJIT
:
if
not
torch
.
jit
.
is_scripting
():
_log_api_usage_once
(
uniform_temporal_subsample
)
if
isinstance
(
inpt
,
torch
.
Tensor
)
and
(
torch
.
jit
.
is_scripting
()
or
not
isinstance
(
inpt
,
datapoints
.
Video
)):
if
isinstance
(
inpt
,
torch
.
Tensor
)
and
(
torch
.
jit
.
is_scripting
()
or
not
isinstance
(
inpt
,
datapoints
.
Video
)):
return
uniform_temporal_subsample_video
(
inpt
,
num_samples
,
temporal_dim
=
temporal_dim
)
return
uniform_temporal_subsample_video
(
inpt
,
num_samples
,
temporal_dim
=
temporal_dim
)
elif
isinstance
(
inpt
,
datapoints
.
Video
):
elif
isinstance
(
inpt
,
datapoints
.
Video
):
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment