Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
vision
Commits
2030d208
Unverified
Commit
2030d208
authored
Aug 07, 2023
by
Philip Meier
Committed by
GitHub
Aug 07, 2023
Browse files
register tensor and PIL kernel the same way as datapoints (#7797)
Co-authored-by:
Nicolas Hug
<
contact@nicolas-hug.com
>
parent
84db2ac4
Changes
9
Hide whitespace changes
Inline
Side-by-side
Showing
9 changed files
with
552 additions
and
687 deletions
+552
-687
test/test_transforms_v2_functional.py
test/test_transforms_v2_functional.py
+3
-55
test/test_transforms_v2_refactored.py
test/test_transforms_v2_refactored.py
+117
-51
torchvision/transforms/v2/functional/_augment.py
torchvision/transforms/v2/functional/_augment.py
+9
-16
torchvision/transforms/v2/functional/_color.py
torchvision/transforms/v2/functional/_color.py
+103
-187
torchvision/transforms/v2/functional/_geometry.py
torchvision/transforms/v2/functional/_geometry.py
+187
-236
torchvision/transforms/v2/functional/_meta.py
torchvision/transforms/v2/functional/_meta.py
+31
-57
torchvision/transforms/v2/functional/_misc.py
torchvision/transforms/v2/functional/_misc.py
+25
-44
torchvision/transforms/v2/functional/_temporal.py
torchvision/transforms/v2/functional/_temporal.py
+8
-12
torchvision/transforms/v2/functional/_utils.py
torchvision/transforms/v2/functional/_utils.py
+69
-29
No files found.
test/test_transforms_v2_functional.py
View file @
2030d208
...
...
@@ -2,7 +2,6 @@ import inspect
import
math
import
os
import
re
from
unittest
import
mock
import
numpy
as
np
import
PIL.Image
...
...
@@ -25,7 +24,6 @@ from torchvision.transforms.functional import _get_perspective_coeffs
from
torchvision.transforms.v2
import
functional
as
F
from
torchvision.transforms.v2.functional._geometry
import
_center_crop_compute_padding
from
torchvision.transforms.v2.functional._meta
import
clamp_bounding_boxes
,
convert_format_bounding_boxes
from
torchvision.transforms.v2.functional._utils
import
_KERNEL_REGISTRY
from
torchvision.transforms.v2.utils
import
is_simple_tensor
from
transforms_v2_dispatcher_infos
import
DISPATCHER_INFOS
from
transforms_v2_kernel_infos
import
KERNEL_INFOS
...
...
@@ -359,18 +357,6 @@ class TestDispatchers:
def
test_scriptable
(
self
,
dispatcher
):
script
(
dispatcher
)
@
image_sample_inputs
def
test_dispatch_simple_tensor
(
self
,
info
,
args_kwargs
,
spy_on
):
(
image_datapoint
,
*
other_args
),
kwargs
=
args_kwargs
.
load
()
image_simple_tensor
=
torch
.
Tensor
(
image_datapoint
)
kernel_info
=
info
.
kernel_infos
[
datapoints
.
Image
]
spy
=
spy_on
(
kernel_info
.
kernel
,
module
=
info
.
dispatcher
.
__module__
,
name
=
kernel_info
.
id
)
info
.
dispatcher
(
image_simple_tensor
,
*
other_args
,
**
kwargs
)
spy
.
assert_called_once
()
@
image_sample_inputs
def
test_simple_tensor_output_type
(
self
,
info
,
args_kwargs
):
(
image_datapoint
,
*
other_args
),
kwargs
=
args_kwargs
.
load
()
...
...
@@ -381,25 +367,6 @@ class TestDispatchers:
# We cannot use `isinstance` here since all datapoints are instances of `torch.Tensor` as well
assert
type
(
output
)
is
torch
.
Tensor
@
make_info_args_kwargs_parametrization
(
[
info
for
info
in
DISPATCHER_INFOS
if
info
.
pil_kernel_info
is
not
None
],
args_kwargs_fn
=
lambda
info
:
info
.
sample_inputs
(
datapoints
.
Image
),
)
def
test_dispatch_pil
(
self
,
info
,
args_kwargs
,
spy_on
):
(
image_datapoint
,
*
other_args
),
kwargs
=
args_kwargs
.
load
()
if
image_datapoint
.
ndim
>
3
:
pytest
.
skip
(
"Input is batched"
)
image_pil
=
F
.
to_image_pil
(
image_datapoint
)
pil_kernel_info
=
info
.
pil_kernel_info
spy
=
spy_on
(
pil_kernel_info
.
kernel
,
module
=
info
.
dispatcher
.
__module__
,
name
=
pil_kernel_info
.
id
)
info
.
dispatcher
(
image_pil
,
*
other_args
,
**
kwargs
)
spy
.
assert_called_once
()
@
make_info_args_kwargs_parametrization
(
[
info
for
info
in
DISPATCHER_INFOS
if
info
.
pil_kernel_info
is
not
None
],
args_kwargs_fn
=
lambda
info
:
info
.
sample_inputs
(
datapoints
.
Image
),
...
...
@@ -416,28 +383,6 @@ class TestDispatchers:
assert
isinstance
(
output
,
PIL
.
Image
.
Image
)
@
make_info_args_kwargs_parametrization
(
DISPATCHER_INFOS
,
args_kwargs_fn
=
lambda
info
:
info
.
sample_inputs
(),
)
def
test_dispatch_datapoint
(
self
,
info
,
args_kwargs
,
spy_on
):
(
datapoint
,
*
other_args
),
kwargs
=
args_kwargs
.
load
()
input_type
=
type
(
datapoint
)
wrapped_kernel
=
_KERNEL_REGISTRY
[
info
.
dispatcher
][
input_type
]
# In case the wrapper was decorated with @functools.wraps, we can make the check more strict and test if the
# proper kernel was wrapped
if
hasattr
(
wrapped_kernel
,
"__wrapped__"
):
assert
wrapped_kernel
.
__wrapped__
is
info
.
kernels
[
input_type
]
spy
=
mock
.
MagicMock
(
wraps
=
wrapped_kernel
,
name
=
wrapped_kernel
.
__name__
)
with
mock
.
patch
.
dict
(
_KERNEL_REGISTRY
[
info
.
dispatcher
],
values
=
{
input_type
:
spy
}):
info
.
dispatcher
(
datapoint
,
*
other_args
,
**
kwargs
)
spy
.
assert_called_once
()
@
make_info_args_kwargs_parametrization
(
DISPATCHER_INFOS
,
args_kwargs_fn
=
lambda
info
:
info
.
sample_inputs
(),
...
...
@@ -449,6 +394,9 @@ class TestDispatchers:
assert
isinstance
(
output
,
type
(
datapoint
))
if
isinstance
(
datapoint
,
datapoints
.
BoundingBoxes
)
and
info
.
dispatcher
is
not
F
.
convert_format_bounding_boxes
:
assert
output
.
format
==
datapoint
.
format
@
pytest
.
mark
.
parametrize
(
(
"dispatcher_info"
,
"datapoint_type"
,
"kernel_info"
),
[
...
...
test/test_transforms_v2_refactored.py
View file @
2030d208
...
...
@@ -39,7 +39,7 @@ from torchvision import datapoints
from
torchvision.transforms._functional_tensor
import
_max_value
as
get_max_value
from
torchvision.transforms.functional
import
pil_modes_mapping
from
torchvision.transforms.v2
import
functional
as
F
from
torchvision.transforms.v2.functional._utils
import
_KERNEL_REGISTRY
from
torchvision.transforms.v2.functional._utils
import
_get_kernel
,
_KERNEL_REGISTRY
,
_noop
,
_register_kernel_internal
@
pytest
.
fixture
(
autouse
=
True
)
...
...
@@ -173,59 +173,32 @@ def _check_dispatcher_scripted_smoke(dispatcher, input, *args, **kwargs):
dispatcher_scripted
(
input
.
as_subclass
(
torch
.
Tensor
),
*
args
,
**
kwargs
)
def
_check_dispatcher_dispatch
(
dispatcher
,
kernel
,
input
,
*
args
,
**
kwargs
):
"""Checks if the dispatcher correctly dispatches the input to the corresponding kernel and that the input type is
preserved in doing so. For bounding boxes also checks that the format is preserved.
"""
input_type
=
type
(
input
)
if
isinstance
(
input
,
datapoints
.
Datapoint
):
wrapped_kernel
=
_KERNEL_REGISTRY
[
dispatcher
][
input_type
]
# In case the wrapper was decorated with @functools.wraps, we can make the check more strict and test if the
# proper kernel was wrapped
if
hasattr
(
wrapped_kernel
,
"__wrapped__"
):
assert
wrapped_kernel
.
__wrapped__
is
kernel
spy
=
mock
.
MagicMock
(
wraps
=
wrapped_kernel
,
name
=
wrapped_kernel
.
__name__
)
with
mock
.
patch
.
dict
(
_KERNEL_REGISTRY
[
dispatcher
],
values
=
{
input_type
:
spy
}):
output
=
dispatcher
(
input
,
*
args
,
**
kwargs
)
spy
.
assert_called_once
()
else
:
with
mock
.
patch
(
f
"
{
dispatcher
.
__module__
}
.
{
kernel
.
__name__
}
"
,
wraps
=
kernel
)
as
spy
:
output
=
dispatcher
(
input
,
*
args
,
**
kwargs
)
spy
.
assert_called_once
()
assert
isinstance
(
output
,
input_type
)
if
isinstance
(
input
,
datapoints
.
BoundingBoxes
):
assert
output
.
format
==
input
.
format
def
check_dispatcher
(
dispatcher
,
# TODO: remove this parameter
kernel
,
input
,
*
args
,
check_scripted_smoke
=
True
,
check_dispatch
=
True
,
**
kwargs
,
):
unknown_input
=
object
()
with
pytest
.
raises
(
TypeError
,
match
=
re
.
escape
(
str
(
type
(
unknown_input
)))):
dispatcher
(
unknown_input
,
*
args
,
**
kwargs
)
with
mock
.
patch
(
"torch._C._log_api_usage_once"
,
wraps
=
torch
.
_C
.
_log_api_usage_once
)
as
spy
:
with
pytest
.
raises
(
TypeError
,
match
=
re
.
escape
(
str
(
type
(
unknown_input
)))):
dispatcher
(
unknown_input
,
*
args
,
**
kwargs
)
output
=
dispatcher
(
input
,
*
args
,
**
kwargs
)
spy
.
assert_any_call
(
f
"
{
dispatcher
.
__module__
}
.
{
dispatcher
.
__name__
}
"
)
assert
isinstance
(
output
,
type
(
input
))
if
isinstance
(
input
,
datapoints
.
BoundingBoxes
):
assert
output
.
format
==
input
.
format
if
check_scripted_smoke
:
_check_dispatcher_scripted_smoke
(
dispatcher
,
input
,
*
args
,
**
kwargs
)
if
check_dispatch
:
_check_dispatcher_dispatch
(
dispatcher
,
kernel
,
input
,
*
args
,
**
kwargs
)
def
check_dispatcher_kernel_signature_match
(
dispatcher
,
*
,
kernel
,
input_type
):
"""Checks if the signature of the dispatcher matches the kernel signature."""
...
...
@@ -412,18 +385,20 @@ def reference_affine_bounding_boxes_helper(bounding_boxes, *, format, canvas_siz
@
pytest
.
mark
.
parametrize
(
(
"dispatcher"
,
"registered_
datapoint_cls
s"
),
(
"dispatcher"
,
"registered_
input_type
s"
),
[(
dispatcher
,
set
(
registry
.
keys
()))
for
dispatcher
,
registry
in
_KERNEL_REGISTRY
.
items
()],
)
def
test_exhaustive_kernel_registration
(
dispatcher
,
registered_
datapoint_cls
s
):
def
test_exhaustive_kernel_registration
(
dispatcher
,
registered_
input_type
s
):
missing
=
{
torch
.
Tensor
,
PIL
.
Image
.
Image
,
datapoints
.
Image
,
datapoints
.
BoundingBoxes
,
datapoints
.
Mask
,
datapoints
.
Video
,
}
-
registered_
datapoint_cls
s
}
-
registered_
input_type
s
if
missing
:
names
=
sorted
(
f
"datapoints.
{
cls
.
__name__
}
"
for
cls
in
missing
)
names
=
sorted
(
str
(
t
)
for
t
in
missing
)
raise
AssertionError
(
"
\n
"
.
join
(
[
...
...
@@ -1753,11 +1728,6 @@ class TestToDtype:
F
.
to_dtype
,
kernel
,
make_input
(
dtype
=
input_dtype
,
device
=
device
),
# TODO: we could leave check_dispatch to True but it currently fails
# in _check_dispatcher_dispatch because there is no to_dtype() method on the datapoints.
# We should be able to put this back if we change the dispatch
# mechanism e.g. via https://github.com/pytorch/vision/pull/7733
check_dispatch
=
False
,
dtype
=
output_dtype
,
scale
=
scale
,
)
...
...
@@ -2208,9 +2178,105 @@ class TestRegisterKernel:
t
(
torch
.
rand
(
3
,
10
,
10
)).
shape
==
(
3
,
224
,
224
)
t
(
datapoints
.
Image
(
torch
.
rand
(
3
,
10
,
10
))).
shape
==
(
3
,
224
,
224
)
def
test_bad_disaptcher_name
(
self
):
class
CustomDatapoint
(
datapoints
.
Datapoint
):
def
test_errors
(
self
):
with
pytest
.
raises
(
ValueError
,
match
=
"Could not find dispatcher with name"
):
F
.
register_kernel
(
"bad_name"
,
datapoints
.
Image
)
with
pytest
.
raises
(
ValueError
,
match
=
"Kernels can only be registered on dispatchers"
):
F
.
register_kernel
(
datapoints
.
Image
,
F
.
resize
)
with
pytest
.
raises
(
ValueError
,
match
=
"Kernels can only be registered for subclasses"
):
F
.
register_kernel
(
F
.
resize
,
object
)
with
pytest
.
raises
(
ValueError
,
match
=
"already has a kernel registered for type"
):
F
.
register_kernel
(
F
.
resize
,
datapoints
.
Image
)(
F
.
resize_image_tensor
)
class
TestGetKernel
:
# We are using F.resize as dispatcher and the kernels below as proxy. Any other dispatcher / kernels combination
# would also be fine
KERNELS
=
{
torch
.
Tensor
:
F
.
resize_image_tensor
,
PIL
.
Image
.
Image
:
F
.
resize_image_pil
,
datapoints
.
Image
:
F
.
resize_image_tensor
,
datapoints
.
BoundingBoxes
:
F
.
resize_bounding_boxes
,
datapoints
.
Mask
:
F
.
resize_mask
,
datapoints
.
Video
:
F
.
resize_video
,
}
def
test_unsupported_types
(
self
):
class
MyTensor
(
torch
.
Tensor
):
pass
with
pytest
.
raises
(
ValueError
,
match
=
"Could not find dispatcher with name"
):
F
.
register_kernel
(
"bad_name"
,
CustomDatapoint
)
class
MyPILImage
(
PIL
.
Image
.
Image
):
pass
for
input_type
in
[
str
,
int
,
object
,
MyTensor
,
MyPILImage
]:
with
pytest
.
raises
(
TypeError
,
match
=
(
"supports inputs of type torch.Tensor, PIL.Image.Image, "
"and subclasses of torchvision.datapoints.Datapoint"
),
):
_get_kernel
(
F
.
resize
,
input_type
)
def
test_exact_match
(
self
):
# We cannot use F.resize together with self.KERNELS mapping here directly here, since this is only the
# ideal wrapping. Practically, we have an intermediate wrapper layer. Thus, we create a new resize dispatcher
# here, register the kernels without wrapper, and check the exact matching afterwards.
def
resize_with_pure_kernels
():
pass
for
input_type
,
kernel
in
self
.
KERNELS
.
items
():
_register_kernel_internal
(
resize_with_pure_kernels
,
input_type
,
datapoint_wrapper
=
False
)(
kernel
)
assert
_get_kernel
(
resize_with_pure_kernels
,
input_type
)
is
kernel
def
test_builtin_datapoint_subclass
(
self
):
# We cannot use F.resize together with self.KERNELS mapping here directly here, since this is only the
# ideal wrapping. Practically, we have an intermediate wrapper layer. Thus, we create a new resize dispatcher
# here, register the kernels without wrapper, and check if subclasses of our builtin datapoints get dispatched
# to the kernel of the corresponding superclass
def
resize_with_pure_kernels
():
pass
class
MyImage
(
datapoints
.
Image
):
pass
class
MyBoundingBoxes
(
datapoints
.
BoundingBoxes
):
pass
class
MyMask
(
datapoints
.
Mask
):
pass
class
MyVideo
(
datapoints
.
Video
):
pass
for
custom_datapoint_subclass
in
[
MyImage
,
MyBoundingBoxes
,
MyMask
,
MyVideo
,
]:
builtin_datapoint_class
=
custom_datapoint_subclass
.
__mro__
[
1
]
builtin_datapoint_kernel
=
self
.
KERNELS
[
builtin_datapoint_class
]
_register_kernel_internal
(
resize_with_pure_kernels
,
builtin_datapoint_class
,
datapoint_wrapper
=
False
)(
builtin_datapoint_kernel
)
assert
_get_kernel
(
resize_with_pure_kernels
,
custom_datapoint_subclass
)
is
builtin_datapoint_kernel
def
test_datapoint_subclass
(
self
):
class
MyDatapoint
(
datapoints
.
Datapoint
):
pass
# Note that this will be an error in the future
assert
_get_kernel
(
F
.
resize
,
MyDatapoint
)
is
_noop
def
resize_my_datapoint
():
pass
_register_kernel_internal
(
F
.
resize
,
MyDatapoint
,
datapoint_wrapper
=
False
)(
resize_my_datapoint
)
assert
_get_kernel
(
F
.
resize
,
MyDatapoint
)
is
resize_my_datapoint
torchvision/transforms/v2/functional/_augment.py
View file @
2030d208
...
...
@@ -7,7 +7,7 @@ from torchvision import datapoints
from
torchvision.transforms.functional
import
pil_to_tensor
,
to_pil_image
from
torchvision.utils
import
_log_api_usage_once
from
._utils
import
_get_kernel
,
_register_explicit_noop
,
_register_kernel_internal
,
is_simple_tensor
from
._utils
import
_get_kernel
,
_register_explicit_noop
,
_register_kernel_internal
@
_register_explicit_noop
(
datapoints
.
Mask
,
datapoints
.
BoundingBoxes
,
warn_passthrough
=
True
)
...
...
@@ -20,23 +20,16 @@ def erase(
v
:
torch
.
Tensor
,
inplace
:
bool
=
False
,
)
->
Union
[
datapoints
.
_ImageTypeJIT
,
datapoints
.
_VideoTypeJIT
]:
if
not
torch
.
jit
.
is_scripting
():
_log_api_usage_once
(
erase
)
if
torch
.
jit
.
is_scripting
()
or
is_simple_tensor
(
inpt
):
if
torch
.
jit
.
is_scripting
():
return
erase_image_tensor
(
inpt
,
i
=
i
,
j
=
j
,
h
=
h
,
w
=
w
,
v
=
v
,
inplace
=
inplace
)
elif
isinstance
(
inpt
,
datapoints
.
Datapoint
):
kernel
=
_get_kernel
(
erase
,
type
(
inpt
))
return
kernel
(
inpt
,
i
=
i
,
j
=
j
,
h
=
h
,
w
=
w
,
v
=
v
,
inplace
=
inplace
)
elif
isinstance
(
inpt
,
PIL
.
Image
.
Image
):
return
erase_image_pil
(
inpt
,
i
=
i
,
j
=
j
,
h
=
h
,
w
=
w
,
v
=
v
,
inplace
=
inplace
)
else
:
raise
TypeError
(
f
"Input can either be a plain tensor, any TorchVision datapoint, or a PIL image, "
f
"but got
{
type
(
inpt
)
}
instead."
)
_log_api_usage_once
(
erase
)
kernel
=
_get_kernel
(
erase
,
type
(
inpt
))
return
kernel
(
inpt
,
i
=
i
,
j
=
j
,
h
=
h
,
w
=
w
,
v
=
v
,
inplace
=
inplace
)
@
_register_kernel_internal
(
erase
,
torch
.
Tensor
)
@
_register_kernel_internal
(
erase
,
datapoints
.
Image
)
def
erase_image_tensor
(
image
:
torch
.
Tensor
,
i
:
int
,
j
:
int
,
h
:
int
,
w
:
int
,
v
:
torch
.
Tensor
,
inplace
:
bool
=
False
...
...
@@ -48,7 +41,7 @@ def erase_image_tensor(
return
image
@
torch
.
jit
.
unused
@
_register_kernel_internal
(
erase
,
PIL
.
Image
.
Image
)
def
erase_image_pil
(
image
:
PIL
.
Image
.
Image
,
i
:
int
,
j
:
int
,
h
:
int
,
w
:
int
,
v
:
torch
.
Tensor
,
inplace
:
bool
=
False
)
->
PIL
.
Image
.
Image
:
...
...
torchvision/transforms/v2/functional/_color.py
View file @
2030d208
...
...
@@ -10,29 +10,20 @@ from torchvision.transforms._functional_tensor import _max_value
from
torchvision.utils
import
_log_api_usage_once
from
._misc
import
_num_value_bits
,
to_dtype_image_tensor
from
._utils
import
_get_kernel
,
_register_explicit_noop
,
_register_kernel_internal
,
is_simple_tensor
from
._utils
import
_get_kernel
,
_register_explicit_noop
,
_register_kernel_internal
@
_register_explicit_noop
(
datapoints
.
BoundingBoxes
,
datapoints
.
Mask
,
datapoints
.
Video
)
def
rgb_to_grayscale
(
inpt
:
Union
[
datapoints
.
_ImageTypeJIT
,
datapoints
.
_VideoTypeJIT
],
num_output_channels
:
int
=
1
)
->
Union
[
datapoints
.
_ImageTypeJIT
,
datapoints
.
_VideoTypeJIT
]:
if
not
torch
.
jit
.
is_scripting
():
_log_api_usage_once
(
rgb_to_grayscale
)
if
num_output_channels
not
in
(
1
,
3
):
raise
ValueError
(
f
"num_output_channels must be 1 or 3, got
{
num_output_channels
}
."
)
if
torch
.
jit
.
is_scripting
()
or
is_simple_tensor
(
inpt
):
if
torch
.
jit
.
is_scripting
():
return
rgb_to_grayscale_image_tensor
(
inpt
,
num_output_channels
=
num_output_channels
)
elif
isinstance
(
inpt
,
datapoints
.
Datapoint
):
kernel
=
_get_kernel
(
rgb_to_grayscale
,
type
(
inpt
))
return
kernel
(
inpt
,
num_output_channels
=
num_output_channels
)
elif
isinstance
(
inpt
,
PIL
.
Image
.
Image
):
return
rgb_to_grayscale_image_pil
(
inpt
,
num_output_channels
=
num_output_channels
)
else
:
raise
TypeError
(
f
"Input can either be a plain tensor, any TorchVision datapoint, or a PIL image, "
f
"but got
{
type
(
inpt
)
}
instead."
)
_log_api_usage_once
(
rgb_to_grayscale
)
kernel
=
_get_kernel
(
rgb_to_grayscale
,
type
(
inpt
))
return
kernel
(
inpt
,
num_output_channels
=
num_output_channels
)
# `to_grayscale` actually predates `rgb_to_grayscale` in v1, but only handles PIL images. Since `rgb_to_grayscale` is a
...
...
@@ -56,12 +47,19 @@ def _rgb_to_grayscale_image_tensor(
return
l_img
@
_register_kernel_internal
(
rgb_to_grayscale
,
torch
.
Tensor
)
@
_register_kernel_internal
(
rgb_to_grayscale
,
datapoints
.
Image
)
def
rgb_to_grayscale_image_tensor
(
image
:
torch
.
Tensor
,
num_output_channels
:
int
=
1
)
->
torch
.
Tensor
:
if
num_output_channels
not
in
(
1
,
3
):
raise
ValueError
(
f
"num_output_channels must be 1 or 3, got
{
num_output_channels
}
."
)
return
_rgb_to_grayscale_image_tensor
(
image
,
num_output_channels
=
num_output_channels
,
preserve_dtype
=
True
)
rgb_to_grayscale_image_pil
=
_FP
.
to_grayscale
@
_register_kernel_internal
(
rgb_to_grayscale
,
PIL
.
Image
.
Image
)
def
rgb_to_grayscale_image_pil
(
image
:
PIL
.
Image
.
Image
,
num_output_channels
:
int
=
1
)
->
PIL
.
Image
.
Image
:
if
num_output_channels
not
in
(
1
,
3
):
raise
ValueError
(
f
"num_output_channels must be 1 or 3, got
{
num_output_channels
}
."
)
return
_FP
.
to_grayscale
(
image
,
num_output_channels
=
num_output_channels
)
def
_blend
(
image1
:
torch
.
Tensor
,
image2
:
torch
.
Tensor
,
ratio
:
float
)
->
torch
.
Tensor
:
...
...
@@ -74,23 +72,16 @@ def _blend(image1: torch.Tensor, image2: torch.Tensor, ratio: float) -> torch.Te
@
_register_explicit_noop
(
datapoints
.
BoundingBoxes
,
datapoints
.
Mask
)
def
adjust_brightness
(
inpt
:
datapoints
.
_InputTypeJIT
,
brightness_factor
:
float
)
->
datapoints
.
_InputTypeJIT
:
if
not
torch
.
jit
.
is_scripting
():
_log_api_usage_once
(
adjust_brightness
)
if
torch
.
jit
.
is_scripting
()
or
is_simple_tensor
(
inpt
):
if
torch
.
jit
.
is_scripting
():
return
adjust_brightness_image_tensor
(
inpt
,
brightness_factor
=
brightness_factor
)
elif
isinstance
(
inpt
,
datapoints
.
Datapoint
):
kernel
=
_get_kernel
(
adjust_brightness
,
type
(
inpt
))
return
kernel
(
inpt
,
brightness_factor
=
brightness_factor
)
elif
isinstance
(
inpt
,
PIL
.
Image
.
Image
):
return
adjust_brightness_image_pil
(
inpt
,
brightness_factor
=
brightness_factor
)
else
:
raise
TypeError
(
f
"Input can either be a plain tensor, any TorchVision datapoint, or a PIL image, "
f
"but got
{
type
(
inpt
)
}
instead."
)
_log_api_usage_once
(
adjust_brightness
)
kernel
=
_get_kernel
(
adjust_brightness
,
type
(
inpt
))
return
kernel
(
inpt
,
brightness_factor
=
brightness_factor
)
@
_register_kernel_internal
(
adjust_brightness
,
torch
.
Tensor
)
@
_register_kernel_internal
(
adjust_brightness
,
datapoints
.
Image
)
def
adjust_brightness_image_tensor
(
image
:
torch
.
Tensor
,
brightness_factor
:
float
)
->
torch
.
Tensor
:
if
brightness_factor
<
0
:
...
...
@@ -106,6 +97,7 @@ def adjust_brightness_image_tensor(image: torch.Tensor, brightness_factor: float
return
output
if
fp
else
output
.
to
(
image
.
dtype
)
@
_register_kernel_internal
(
adjust_brightness
,
PIL
.
Image
.
Image
)
def
adjust_brightness_image_pil
(
image
:
PIL
.
Image
.
Image
,
brightness_factor
:
float
)
->
PIL
.
Image
.
Image
:
return
_FP
.
adjust_brightness
(
image
,
brightness_factor
=
brightness_factor
)
...
...
@@ -117,23 +109,16 @@ def adjust_brightness_video(video: torch.Tensor, brightness_factor: float) -> to
@
_register_explicit_noop
(
datapoints
.
BoundingBoxes
,
datapoints
.
Mask
)
def
adjust_saturation
(
inpt
:
datapoints
.
_InputTypeJIT
,
saturation_factor
:
float
)
->
datapoints
.
_InputTypeJIT
:
if
not
torch
.
jit
.
is_scripting
():
_log_api_usage_once
(
adjust_saturation
)
if
isinstance
(
inpt
,
torch
.
Tensor
)
and
(
torch
.
jit
.
is_scripting
()
or
not
isinstance
(
inpt
,
datapoints
.
Datapoint
)):
if
torch
.
jit
.
is_scripting
():
return
adjust_saturation_image_tensor
(
inpt
,
saturation_factor
=
saturation_factor
)
elif
isinstance
(
inpt
,
datapoints
.
Datapoint
):
kernel
=
_get_kernel
(
adjust_saturation
,
type
(
inpt
))
return
kernel
(
inpt
,
saturation_factor
=
saturation_factor
)
elif
isinstance
(
inpt
,
PIL
.
Image
.
Image
):
return
adjust_saturation_image_pil
(
inpt
,
saturation_factor
=
saturation_factor
)
else
:
raise
TypeError
(
f
"Input can either be a plain tensor, any TorchVision datapoint, or a PIL image, "
f
"but got
{
type
(
inpt
)
}
instead."
)
_log_api_usage_once
(
adjust_saturation
)
kernel
=
_get_kernel
(
adjust_saturation
,
type
(
inpt
))
return
kernel
(
inpt
,
saturation_factor
=
saturation_factor
)
@
_register_kernel_internal
(
adjust_saturation
,
torch
.
Tensor
)
@
_register_kernel_internal
(
adjust_saturation
,
datapoints
.
Image
)
def
adjust_saturation_image_tensor
(
image
:
torch
.
Tensor
,
saturation_factor
:
float
)
->
torch
.
Tensor
:
if
saturation_factor
<
0
:
...
...
@@ -153,7 +138,7 @@ def adjust_saturation_image_tensor(image: torch.Tensor, saturation_factor: float
return
_blend
(
image
,
grayscale_image
,
saturation_factor
)
adjust_saturation_image_pil
=
_FP
.
adjust_saturation
adjust_saturation_image_pil
=
_register_kernel_internal
(
adjust_saturation
,
PIL
.
Image
.
Image
)(
_FP
.
adjust_saturation
)
@
_register_kernel_internal
(
adjust_saturation
,
datapoints
.
Video
)
...
...
@@ -163,23 +148,16 @@ def adjust_saturation_video(video: torch.Tensor, saturation_factor: float) -> to
@
_register_explicit_noop
(
datapoints
.
BoundingBoxes
,
datapoints
.
Mask
)
def
adjust_contrast
(
inpt
:
datapoints
.
_InputTypeJIT
,
contrast_factor
:
float
)
->
datapoints
.
_InputTypeJIT
:
if
not
torch
.
jit
.
is_scripting
():
_log_api_usage_once
(
adjust_contrast
)
if
torch
.
jit
.
is_scripting
()
or
is_simple_tensor
(
inpt
):
if
torch
.
jit
.
is_scripting
():
return
adjust_contrast_image_tensor
(
inpt
,
contrast_factor
=
contrast_factor
)
elif
isinstance
(
inpt
,
datapoints
.
Datapoint
):
kernel
=
_get_kernel
(
adjust_contrast
,
type
(
inpt
))
return
kernel
(
inpt
,
contrast_factor
=
contrast_factor
)
elif
isinstance
(
inpt
,
PIL
.
Image
.
Image
):
return
adjust_contrast_image_pil
(
inpt
,
contrast_factor
=
contrast_factor
)
else
:
raise
TypeError
(
f
"Input can either be a plain tensor, any TorchVision datapoint, or a PIL image, "
f
"but got
{
type
(
inpt
)
}
instead."
)
_log_api_usage_once
(
adjust_contrast
)
kernel
=
_get_kernel
(
adjust_contrast
,
type
(
inpt
))
return
kernel
(
inpt
,
contrast_factor
=
contrast_factor
)
@
_register_kernel_internal
(
adjust_contrast
,
torch
.
Tensor
)
@
_register_kernel_internal
(
adjust_contrast
,
datapoints
.
Image
)
def
adjust_contrast_image_tensor
(
image
:
torch
.
Tensor
,
contrast_factor
:
float
)
->
torch
.
Tensor
:
if
contrast_factor
<
0
:
...
...
@@ -199,7 +177,7 @@ def adjust_contrast_image_tensor(image: torch.Tensor, contrast_factor: float) ->
return
_blend
(
image
,
mean
,
contrast_factor
)
adjust_contrast_image_pil
=
_FP
.
adjust_contrast
adjust_contrast_image_pil
=
_register_kernel_internal
(
adjust_contrast
,
PIL
.
Image
.
Image
)(
_FP
.
adjust_contrast
)
@
_register_kernel_internal
(
adjust_contrast
,
datapoints
.
Video
)
...
...
@@ -209,23 +187,16 @@ def adjust_contrast_video(video: torch.Tensor, contrast_factor: float) -> torch.
@
_register_explicit_noop
(
datapoints
.
BoundingBoxes
,
datapoints
.
Mask
)
def
adjust_sharpness
(
inpt
:
datapoints
.
_InputTypeJIT
,
sharpness_factor
:
float
)
->
datapoints
.
_InputTypeJIT
:
if
not
torch
.
jit
.
is_scripting
():
_log_api_usage_once
(
adjust_sharpness
)
if
isinstance
(
inpt
,
torch
.
Tensor
)
and
(
torch
.
jit
.
is_scripting
()
or
not
isinstance
(
inpt
,
datapoints
.
Datapoint
)):
if
torch
.
jit
.
is_scripting
():
return
adjust_sharpness_image_tensor
(
inpt
,
sharpness_factor
=
sharpness_factor
)
elif
isinstance
(
inpt
,
datapoints
.
Datapoint
):
kernel
=
_get_kernel
(
adjust_sharpness
,
type
(
inpt
))
return
kernel
(
inpt
,
sharpness_factor
=
sharpness_factor
)
elif
isinstance
(
inpt
,
PIL
.
Image
.
Image
):
return
adjust_sharpness_image_pil
(
inpt
,
sharpness_factor
=
sharpness_factor
)
else
:
raise
TypeError
(
f
"Input can either be a plain tensor, any TorchVision datapoint, or a PIL image, "
f
"but got
{
type
(
inpt
)
}
instead."
)
_log_api_usage_once
(
adjust_sharpness
)
kernel
=
_get_kernel
(
adjust_sharpness
,
type
(
inpt
))
return
kernel
(
inpt
,
sharpness_factor
=
sharpness_factor
)
@
_register_kernel_internal
(
adjust_sharpness
,
torch
.
Tensor
)
@
_register_kernel_internal
(
adjust_sharpness
,
datapoints
.
Image
)
def
adjust_sharpness_image_tensor
(
image
:
torch
.
Tensor
,
sharpness_factor
:
float
)
->
torch
.
Tensor
:
num_channels
,
height
,
width
=
image
.
shape
[
-
3
:]
...
...
@@ -279,7 +250,7 @@ def adjust_sharpness_image_tensor(image: torch.Tensor, sharpness_factor: float)
return
output
adjust_sharpness_image_pil
=
_FP
.
adjust_sharpness
adjust_sharpness_image_pil
=
_register_kernel_internal
(
adjust_sharpness
,
PIL
.
Image
.
Image
)(
_FP
.
adjust_sharpness
)
@
_register_kernel_internal
(
adjust_sharpness
,
datapoints
.
Video
)
...
...
@@ -289,21 +260,13 @@ def adjust_sharpness_video(video: torch.Tensor, sharpness_factor: float) -> torc
@
_register_explicit_noop
(
datapoints
.
BoundingBoxes
,
datapoints
.
Mask
)
def
adjust_hue
(
inpt
:
datapoints
.
_InputTypeJIT
,
hue_factor
:
float
)
->
datapoints
.
_InputTypeJIT
:
if
not
torch
.
jit
.
is_scripting
():
_log_api_usage_once
(
adjust_hue
)
if
torch
.
jit
.
is_scripting
()
or
is_simple_tensor
(
inpt
):
if
torch
.
jit
.
is_scripting
():
return
adjust_hue_image_tensor
(
inpt
,
hue_factor
=
hue_factor
)
elif
isinstance
(
inpt
,
datapoints
.
Datapoint
):
kernel
=
_get_kernel
(
adjust_hue
,
type
(
inpt
))
return
kernel
(
inpt
,
hue_factor
=
hue_factor
)
elif
isinstance
(
inpt
,
PIL
.
Image
.
Image
):
return
adjust_hue_image_pil
(
inpt
,
hue_factor
=
hue_factor
)
else
:
raise
TypeError
(
f
"Input can either be a plain tensor, any TorchVision datapoint, or a PIL image, "
f
"but got
{
type
(
inpt
)
}
instead."
)
_log_api_usage_once
(
adjust_hue
)
kernel
=
_get_kernel
(
adjust_hue
,
type
(
inpt
))
return
kernel
(
inpt
,
hue_factor
=
hue_factor
)
def
_rgb_to_hsv
(
image
:
torch
.
Tensor
)
->
torch
.
Tensor
:
...
...
@@ -370,6 +333,7 @@ def _hsv_to_rgb(img: torch.Tensor) -> torch.Tensor:
return
(
a4
.
mul_
(
mask
.
unsqueeze
(
dim
=-
4
))).
sum
(
dim
=-
3
)
@
_register_kernel_internal
(
adjust_hue
,
torch
.
Tensor
)
@
_register_kernel_internal
(
adjust_hue
,
datapoints
.
Image
)
def
adjust_hue_image_tensor
(
image
:
torch
.
Tensor
,
hue_factor
:
float
)
->
torch
.
Tensor
:
if
not
(
-
0.5
<=
hue_factor
<=
0.5
):
...
...
@@ -398,7 +362,7 @@ def adjust_hue_image_tensor(image: torch.Tensor, hue_factor: float) -> torch.Ten
return
to_dtype_image_tensor
(
image_hue_adj
,
orig_dtype
,
scale
=
True
)
adjust_hue_image_pil
=
_FP
.
adjust_hue
adjust_hue_image_pil
=
_register_kernel_internal
(
adjust_hue
,
PIL
.
Image
.
Image
)(
_FP
.
adjust_hue
)
@
_register_kernel_internal
(
adjust_hue
,
datapoints
.
Video
)
...
...
@@ -408,23 +372,16 @@ def adjust_hue_video(video: torch.Tensor, hue_factor: float) -> torch.Tensor:
@
_register_explicit_noop
(
datapoints
.
BoundingBoxes
,
datapoints
.
Mask
)
def
adjust_gamma
(
inpt
:
datapoints
.
_InputTypeJIT
,
gamma
:
float
,
gain
:
float
=
1
)
->
datapoints
.
_InputTypeJIT
:
if
not
torch
.
jit
.
is_scripting
():
_log_api_usage_once
(
adjust_gamma
)
if
torch
.
jit
.
is_scripting
()
or
is_simple_tensor
(
inpt
):
if
torch
.
jit
.
is_scripting
():
return
adjust_gamma_image_tensor
(
inpt
,
gamma
=
gamma
,
gain
=
gain
)
elif
isinstance
(
inpt
,
datapoints
.
Datapoint
):
kernel
=
_get_kernel
(
adjust_gamma
,
type
(
inpt
))
return
kernel
(
inpt
,
gamma
=
gamma
,
gain
=
gain
)
elif
isinstance
(
inpt
,
PIL
.
Image
.
Image
):
return
adjust_gamma_image_pil
(
inpt
,
gamma
=
gamma
,
gain
=
gain
)
else
:
raise
TypeError
(
f
"Input can either be a plain tensor, any TorchVision datapoint, or a PIL image, "
f
"but got
{
type
(
inpt
)
}
instead."
)
_log_api_usage_once
(
adjust_gamma
)
kernel
=
_get_kernel
(
adjust_gamma
,
type
(
inpt
))
return
kernel
(
inpt
,
gamma
=
gamma
,
gain
=
gain
)
@
_register_kernel_internal
(
adjust_gamma
,
torch
.
Tensor
)
@
_register_kernel_internal
(
adjust_gamma
,
datapoints
.
Image
)
def
adjust_gamma_image_tensor
(
image
:
torch
.
Tensor
,
gamma
:
float
,
gain
:
float
=
1.0
)
->
torch
.
Tensor
:
if
gamma
<
0
:
...
...
@@ -445,7 +402,7 @@ def adjust_gamma_image_tensor(image: torch.Tensor, gamma: float, gain: float = 1
return
to_dtype_image_tensor
(
output
,
image
.
dtype
,
scale
=
True
)
adjust_gamma_image_pil
=
_FP
.
adjust_gamma
adjust_gamma_image_pil
=
_register_kernel_internal
(
adjust_gamma
,
PIL
.
Image
.
Image
)(
_FP
.
adjust_gamma
)
@
_register_kernel_internal
(
adjust_gamma
,
datapoints
.
Video
)
...
...
@@ -455,23 +412,16 @@ def adjust_gamma_video(video: torch.Tensor, gamma: float, gain: float = 1) -> to
@
_register_explicit_noop
(
datapoints
.
BoundingBoxes
,
datapoints
.
Mask
)
def
posterize
(
inpt
:
datapoints
.
_InputTypeJIT
,
bits
:
int
)
->
datapoints
.
_InputTypeJIT
:
if
not
torch
.
jit
.
is_scripting
():
_log_api_usage_once
(
posterize
)
if
torch
.
jit
.
is_scripting
()
or
is_simple_tensor
(
inpt
):
if
torch
.
jit
.
is_scripting
():
return
posterize_image_tensor
(
inpt
,
bits
=
bits
)
elif
isinstance
(
inpt
,
datapoints
.
Datapoint
):
kernel
=
_get_kernel
(
posterize
,
type
(
inpt
))
return
kernel
(
inpt
,
bits
=
bits
)
elif
isinstance
(
inpt
,
PIL
.
Image
.
Image
):
return
posterize_image_pil
(
inpt
,
bits
=
bits
)
else
:
raise
TypeError
(
f
"Input can either be a plain tensor, any TorchVision datapoint, or a PIL image, "
f
"but got
{
type
(
inpt
)
}
instead."
)
_log_api_usage_once
(
posterize
)
kernel
=
_get_kernel
(
posterize
,
type
(
inpt
))
return
kernel
(
inpt
,
bits
=
bits
)
@
_register_kernel_internal
(
posterize
,
torch
.
Tensor
)
@
_register_kernel_internal
(
posterize
,
datapoints
.
Image
)
def
posterize_image_tensor
(
image
:
torch
.
Tensor
,
bits
:
int
)
->
torch
.
Tensor
:
if
image
.
is_floating_point
():
...
...
@@ -486,7 +436,7 @@ def posterize_image_tensor(image: torch.Tensor, bits: int) -> torch.Tensor:
return
image
&
mask
# Register the v1 PIL posterize kernel for PIL.Image.Image inputs, mirroring
# how datapoint kernels are registered.
posterize_image_pil = _register_kernel_internal(posterize, PIL.Image.Image)(_FP.posterize)
@
_register_kernel_internal
(
posterize
,
datapoints
.
Video
)
...
...
@@ -496,23 +446,16 @@ def posterize_video(video: torch.Tensor, bits: int) -> torch.Tensor:
@_register_explicit_noop(datapoints.BoundingBoxes, datapoints.Mask)
def solarize(inpt: datapoints._InputTypeJIT, threshold: float) -> datapoints._InputTypeJIT:
    """Solarize ``inpt``: invert all values at or above ``threshold``.

    Dispatch entry point: forwards to the kernel registered for
    ``type(inpt)``. Bounding boxes and masks are explicit no-ops.
    """
    # Scripted code cannot use the dynamic registry; fall back to the
    # plain-tensor kernel.
    if torch.jit.is_scripting():
        return solarize_image_tensor(inpt, threshold=threshold)

    _log_api_usage_once(solarize)

    kernel = _get_kernel(solarize, type(inpt))
    return kernel(inpt, threshold=threshold)
@
_register_kernel_internal
(
solarize
,
torch
.
Tensor
)
@
_register_kernel_internal
(
solarize
,
datapoints
.
Image
)
def
solarize_image_tensor
(
image
:
torch
.
Tensor
,
threshold
:
float
)
->
torch
.
Tensor
:
if
threshold
>
_max_value
(
image
.
dtype
):
...
...
@@ -521,7 +464,7 @@ def solarize_image_tensor(image: torch.Tensor, threshold: float) -> torch.Tensor
return
torch
.
where
(
image
>=
threshold
,
invert_image_tensor
(
image
),
image
)
# Register the v1 PIL solarize kernel under the dispatcher's registry for
# PIL.Image.Image inputs.
solarize_image_pil = _register_kernel_internal(solarize, PIL.Image.Image)(_FP.solarize)
@
_register_kernel_internal
(
solarize
,
datapoints
.
Video
)
...
...
@@ -531,25 +474,16 @@ def solarize_video(video: torch.Tensor, threshold: float) -> torch.Tensor:
@_register_explicit_noop(datapoints.BoundingBoxes, datapoints.Mask)
def autocontrast(inpt: datapoints._InputTypeJIT) -> datapoints._InputTypeJIT:
    """Maximize the contrast of ``inpt`` via the type-registered kernel.

    Bounding boxes and masks are registered as explicit no-ops.
    """
    # Under TorchScript only the plain-tensor path is reachable.
    if torch.jit.is_scripting():
        return autocontrast_image_tensor(inpt)

    _log_api_usage_once(autocontrast)

    kernel = _get_kernel(autocontrast, type(inpt))
    return kernel(inpt)
@
_register_kernel_internal
(
autocontrast
,
torch
.
Tensor
)
@
_register_kernel_internal
(
autocontrast
,
datapoints
.
Image
)
def
autocontrast_image_tensor
(
image
:
torch
.
Tensor
)
->
torch
.
Tensor
:
c
=
image
.
shape
[
-
3
]
...
...
@@ -580,7 +514,7 @@ def autocontrast_image_tensor(image: torch.Tensor) -> torch.Tensor:
return
diff
.
div_
(
inv_scale
).
clamp_
(
0
,
bound
).
to
(
image
.
dtype
)
# Reuse the v1 PIL autocontrast implementation and register it for
# PIL.Image.Image inputs.
autocontrast_image_pil = _register_kernel_internal(autocontrast, PIL.Image.Image)(_FP.autocontrast)
@
_register_kernel_internal
(
autocontrast
,
datapoints
.
Video
)
...
...
@@ -590,25 +524,16 @@ def autocontrast_video(video: torch.Tensor) -> torch.Tensor:
@_register_explicit_noop(datapoints.BoundingBoxes, datapoints.Mask)
def equalize(inpt: datapoints._InputTypeJIT) -> datapoints._InputTypeJIT:
    """Histogram-equalize ``inpt`` via the kernel registered for its type.

    Bounding boxes and masks are registered as explicit no-ops.
    """
    # Scripted code bypasses the registry and uses the tensor kernel.
    if torch.jit.is_scripting():
        return equalize_image_tensor(inpt)

    _log_api_usage_once(equalize)

    kernel = _get_kernel(equalize, type(inpt))
    return kernel(inpt)
@
_register_kernel_internal
(
equalize
,
torch
.
Tensor
)
@
_register_kernel_internal
(
equalize
,
datapoints
.
Image
)
def
equalize_image_tensor
(
image
:
torch
.
Tensor
)
->
torch
.
Tensor
:
if
image
.
numel
()
==
0
:
...
...
@@ -679,7 +604,7 @@ def equalize_image_tensor(image: torch.Tensor) -> torch.Tensor:
return
to_dtype_image_tensor
(
output
,
output_dtype
,
scale
=
True
)
# Register the v1 PIL equalize kernel for PIL.Image.Image inputs.
equalize_image_pil = _register_kernel_internal(equalize, PIL.Image.Image)(_FP.equalize)
@
_register_kernel_internal
(
equalize
,
datapoints
.
Video
)
...
...
@@ -689,25 +614,16 @@ def equalize_video(video: torch.Tensor) -> torch.Tensor:
@_register_explicit_noop(datapoints.BoundingBoxes, datapoints.Mask)
def invert(inpt: datapoints._InputTypeJIT) -> datapoints._InputTypeJIT:
    """Invert the values of ``inpt`` via the kernel registered for its type.

    Bounding boxes and masks are registered as explicit no-ops.
    """
    # TorchScript cannot use the dynamic registry; take the tensor kernel.
    if torch.jit.is_scripting():
        return invert_image_tensor(inpt)

    _log_api_usage_once(invert)

    kernel = _get_kernel(invert, type(inpt))
    return kernel(inpt)
@
_register_kernel_internal
(
invert
,
torch
.
Tensor
)
@
_register_kernel_internal
(
invert
,
datapoints
.
Image
)
def
invert_image_tensor
(
image
:
torch
.
Tensor
)
->
torch
.
Tensor
:
if
image
.
is_floating_point
():
...
...
@@ -719,7 +635,7 @@ def invert_image_tensor(image: torch.Tensor) -> torch.Tensor:
return
image
.
bitwise_xor
((
1
<<
_num_value_bits
(
image
.
dtype
))
-
1
)
# Register the v1 PIL invert kernel for PIL.Image.Image inputs.
invert_image_pil = _register_kernel_internal(invert, PIL.Image.Image)(_FP.invert)
@
_register_kernel_internal
(
invert
,
datapoints
.
Video
)
...
...
torchvision/transforms/v2/functional/_geometry.py
View file @
2030d208
...
...
@@ -25,13 +25,7 @@ from torchvision.utils import _log_api_usage_once
from
._meta
import
clamp_bounding_boxes
,
convert_format_bounding_boxes
,
get_size_image_pil
from
._utils
import
(
_get_kernel
,
_register_explicit_noop
,
_register_five_ten_crop_kernel
,
_register_kernel_internal
,
is_simple_tensor
,
)
from
._utils
import
_get_kernel
,
_register_explicit_noop
,
_register_five_ten_crop_kernel
,
_register_kernel_internal
def
_check_interpolation
(
interpolation
:
Union
[
InterpolationMode
,
int
])
->
InterpolationMode
:
...
...
@@ -46,30 +40,22 @@ def _check_interpolation(interpolation: Union[InterpolationMode, int]) -> Interp
def horizontal_flip(inpt: datapoints._InputTypeJIT) -> datapoints._InputTypeJIT:
    """Mirror ``inpt`` left-to-right via the kernel registered for its type."""
    # Scripted code cannot follow the dynamic registry; use the tensor kernel.
    if torch.jit.is_scripting():
        return horizontal_flip_image_tensor(inpt)

    _log_api_usage_once(horizontal_flip)

    kernel = _get_kernel(horizontal_flip, type(inpt))
    return kernel(inpt)
@_register_kernel_internal(horizontal_flip, torch.Tensor)
@_register_kernel_internal(horizontal_flip, datapoints.Image)
def horizontal_flip_image_tensor(image: torch.Tensor) -> torch.Tensor:
    """Mirror an image tensor along its last (width) dimension."""
    return image.flip(-1)
@_register_kernel_internal(horizontal_flip, PIL.Image.Image)
def horizontal_flip_image_pil(image: PIL.Image.Image) -> PIL.Image.Image:
    """Mirror a PIL image left-to-right by delegating to the v1 kernel."""
    return _FP.hflip(image)
...
...
@@ -110,30 +96,22 @@ def horizontal_flip_video(video: torch.Tensor) -> torch.Tensor:
def vertical_flip(inpt: datapoints._InputTypeJIT) -> datapoints._InputTypeJIT:
    """Flip ``inpt`` top-to-bottom via the kernel registered for its type."""
    # Only the plain-tensor kernel is reachable from TorchScript.
    if torch.jit.is_scripting():
        return vertical_flip_image_tensor(inpt)

    _log_api_usage_once(vertical_flip)

    kernel = _get_kernel(vertical_flip, type(inpt))
    return kernel(inpt)
@_register_kernel_internal(vertical_flip, torch.Tensor)
@_register_kernel_internal(vertical_flip, datapoints.Image)
def vertical_flip_image_tensor(image: torch.Tensor) -> torch.Tensor:
    """Flip an image tensor along its second-to-last (height) dimension."""
    return image.flip(-2)
@_register_kernel_internal(vertical_flip, PIL.Image.Image)
def vertical_flip_image_pil(image: PIL.Image.Image) -> PIL.Image.Image:
    """Flip a PIL image top-to-bottom by delegating to the v1 kernel.

    Fix: annotate with the class ``PIL.Image.Image`` instead of the module
    ``PIL.Image``, matching ``horizontal_flip_image_pil``. Runtime behavior
    is unchanged.
    """
    return _FP.vflip(image)
...
...
@@ -199,24 +177,16 @@ def resize(
max_size
:
Optional
[
int
]
=
None
,
antialias
:
Optional
[
Union
[
str
,
bool
]]
=
"warn"
,
)
->
datapoints
.
_InputTypeJIT
:
if
not
torch
.
jit
.
is_scripting
():
_log_api_usage_once
(
resize
)
if
torch
.
jit
.
is_scripting
()
or
is_simple_tensor
(
inpt
):
return
resize_image_tensor
(
inpt
,
size
,
interpolation
=
interpolation
,
max_size
=
max_size
,
antialias
=
antialias
)
elif
isinstance
(
inpt
,
datapoints
.
Datapoint
):
kernel
=
_get_kernel
(
resize
,
type
(
inpt
))
return
kernel
(
inpt
,
size
,
interpolation
=
interpolation
,
max_size
=
max_size
,
antialias
=
antialias
)
elif
isinstance
(
inpt
,
PIL
.
Image
.
Image
):
if
antialias
is
False
:
warnings
.
warn
(
"Anti-alias option is always applied for PIL Image input. Argument antialias is ignored."
)
return
resize_image_pil
(
inpt
,
size
,
interpolation
=
interpolation
,
max_size
=
max_size
)
else
:
raise
TypeError
(
f
"Input can either be a plain tensor, any TorchVision datapoint, or a PIL image, "
f
"but got
{
type
(
inpt
)
}
instead."
)
if
torch
.
jit
.
is_scripting
():
return
resize_image_tensor
(
inpt
,
size
=
size
,
interpolation
=
interpolation
,
max_size
=
max_size
,
antialias
=
antialias
)
_log_api_usage_once
(
resize
)
kernel
=
_get_kernel
(
resize
,
type
(
inpt
))
return
kernel
(
inpt
,
size
=
size
,
interpolation
=
interpolation
,
max_size
=
max_size
,
antialias
=
antialias
)
@
_register_kernel_internal
(
resize
,
torch
.
Tensor
)
@
_register_kernel_internal
(
resize
,
datapoints
.
Image
)
def
resize_image_tensor
(
image
:
torch
.
Tensor
,
...
...
@@ -297,7 +267,6 @@ def resize_image_tensor(
return
image
.
reshape
(
shape
[:
-
3
]
+
(
num_channels
,
new_height
,
new_width
))
@
torch
.
jit
.
unused
def
resize_image_pil
(
image
:
PIL
.
Image
.
Image
,
size
:
Union
[
Sequence
[
int
],
int
],
...
...
@@ -319,6 +288,19 @@ def resize_image_pil(
return
image
.
resize
((
new_width
,
new_height
),
resample
=
pil_modes_mapping
[
interpolation
])
@_register_kernel_internal(resize, PIL.Image.Image)
def _resize_image_pil_dispatch(
    image: PIL.Image.Image,
    size: Union[Sequence[int], int],
    interpolation: Union[InterpolationMode, int] = InterpolationMode.BILINEAR,
    max_size: Optional[int] = None,
    antialias: Optional[Union[str, bool]] = "warn",
) -> PIL.Image.Image:
    """PIL adapter for ``resize``: accepts the full dispatcher signature.

    ``antialias`` exists only for signature compatibility with the tensor
    kernel; PIL resizing is always anti-aliased, so an explicit ``False``
    triggers a warning and is otherwise ignored.
    """
    if antialias is False:
        warnings.warn("Anti-alias option is always applied for PIL Image input. Argument antialias is ignored.")
    return resize_image_pil(image, size=size, interpolation=interpolation, max_size=max_size)
def
resize_mask
(
mask
:
torch
.
Tensor
,
size
:
List
[
int
],
max_size
:
Optional
[
int
]
=
None
)
->
torch
.
Tensor
:
if
mask
.
ndim
<
3
:
mask
=
mask
.
unsqueeze
(
0
)
...
...
@@ -391,26 +373,10 @@ def affine(
fill
:
datapoints
.
_FillTypeJIT
=
None
,
center
:
Optional
[
List
[
float
]]
=
None
,
)
->
datapoints
.
_InputTypeJIT
:
if
not
torch
.
jit
.
is_scripting
():
_log_api_usage_once
(
affine
)
# TODO: consider deprecating integers from angle and shear on the future
if
torch
.
jit
.
is_scripting
()
or
is_simple_tensor
(
inpt
):
if
torch
.
jit
.
is_scripting
():
return
affine_image_tensor
(
inpt
,
angle
,
translate
=
translate
,
scale
=
scale
,
shear
=
shear
,
interpolation
=
interpolation
,
fill
=
fill
,
center
=
center
,
)
elif
isinstance
(
inpt
,
datapoints
.
Datapoint
):
kernel
=
_get_kernel
(
affine
,
type
(
inpt
))
return
kernel
(
inpt
,
angle
,
angle
=
angle
,
translate
=
translate
,
scale
=
scale
,
shear
=
shear
,
...
...
@@ -418,22 +384,20 @@ def affine(
fill
=
fill
,
center
=
center
,
)
elif
isinstance
(
inpt
,
PIL
.
Image
.
Image
):
return
affine_image_pil
(
inpt
,
angle
,
translate
=
translate
,
scale
=
scale
,
shear
=
shear
,
interpolation
=
interpolation
,
fill
=
fill
,
center
=
center
,
)
else
:
raise
TypeError
(
f
"Input can either be a plain tensor, any TorchVision datapoint, or a PIL image, "
f
"but got
{
type
(
inpt
)
}
instead."
)
_log_api_usage_once
(
affine
)
kernel
=
_get_kernel
(
affine
,
type
(
inpt
))
return
kernel
(
inpt
,
angle
=
angle
,
translate
=
translate
,
scale
=
scale
,
shear
=
shear
,
interpolation
=
interpolation
,
fill
=
fill
,
center
=
center
,
)
def
_affine_parse_args
(
...
...
@@ -684,6 +648,7 @@ def _affine_grid(
return
output_grid
.
view
(
1
,
oh
,
ow
,
2
)
@
_register_kernel_internal
(
affine
,
torch
.
Tensor
)
@
_register_kernel_internal
(
affine
,
datapoints
.
Image
)
def
affine_image_tensor
(
image
:
torch
.
Tensor
,
...
...
@@ -736,7 +701,7 @@ def affine_image_tensor(
return
output
@
torch
.
jit
.
unused
@
_register_kernel_internal
(
affine
,
PIL
.
Image
.
Image
)
def
affine_image_pil
(
image
:
PIL
.
Image
.
Image
,
angle
:
Union
[
int
,
float
],
...
...
@@ -983,23 +948,18 @@ def rotate(
center
:
Optional
[
List
[
float
]]
=
None
,
fill
:
datapoints
.
_FillTypeJIT
=
None
,
)
->
datapoints
.
_InputTypeJIT
:
if
not
torch
.
jit
.
is_scripting
():
_log_api_usage_once
(
rotate
)
if
torch
.
jit
.
is_scripting
()
or
is_simple_tensor
(
inpt
):
return
rotate_image_tensor
(
inpt
,
angle
,
interpolation
=
interpolation
,
expand
=
expand
,
fill
=
fill
,
center
=
center
)
elif
isinstance
(
inpt
,
datapoints
.
Datapoint
):
kernel
=
_get_kernel
(
rotate
,
type
(
inpt
))
return
kernel
(
inpt
,
angle
,
interpolation
=
interpolation
,
expand
=
expand
,
fill
=
fill
,
center
=
center
)
elif
isinstance
(
inpt
,
PIL
.
Image
.
Image
):
return
rotate_image_pil
(
inpt
,
angle
,
interpolation
=
interpolation
,
expand
=
expand
,
fill
=
fill
,
center
=
center
)
else
:
raise
TypeError
(
f
"Input can either be a plain tensor, any TorchVision datapoint, or a PIL image, "
f
"but got
{
type
(
inpt
)
}
instead."
if
torch
.
jit
.
is_scripting
():
return
rotate_image_tensor
(
inpt
,
angle
=
angle
,
interpolation
=
interpolation
,
expand
=
expand
,
fill
=
fill
,
center
=
center
)
_log_api_usage_once
(
rotate
)
kernel
=
_get_kernel
(
rotate
,
type
(
inpt
))
return
kernel
(
inpt
,
angle
=
angle
,
interpolation
=
interpolation
,
expand
=
expand
,
fill
=
fill
,
center
=
center
)
@
_register_kernel_internal
(
rotate
,
torch
.
Tensor
)
@
_register_kernel_internal
(
rotate
,
datapoints
.
Image
)
def
rotate_image_tensor
(
image
:
torch
.
Tensor
,
...
...
@@ -1045,7 +1005,7 @@ def rotate_image_tensor(
return
output
.
reshape
(
shape
[:
-
3
]
+
(
num_channels
,
new_height
,
new_width
))
@
torch
.
jit
.
unused
@
_register_kernel_internal
(
rotate
,
PIL
.
Image
.
Image
)
def
rotate_image_pil
(
image
:
PIL
.
Image
.
Image
,
angle
:
float
,
...
...
@@ -1162,22 +1122,13 @@ def pad(
fill
:
Optional
[
Union
[
int
,
float
,
List
[
float
]]]
=
None
,
padding_mode
:
str
=
"constant"
,
)
->
datapoints
.
_InputTypeJIT
:
if
not
torch
.
jit
.
is_scripting
():
_log_api_usage_once
(
pad
)
if
torch
.
jit
.
is_scripting
():
return
pad_image_tensor
(
inpt
,
padding
=
padding
,
fill
=
fill
,
padding_mode
=
padding_mode
)
if
torch
.
jit
.
is_scripting
()
or
is_simple_tensor
(
inpt
):
return
pad_image_tensor
(
inpt
,
padding
,
fill
=
fill
,
padding_mode
=
padding_mode
)
_log_api_usage_once
(
pad
)
elif
isinstance
(
inpt
,
datapoints
.
Datapoint
):
kernel
=
_get_kernel
(
pad
,
type
(
inpt
))
return
kernel
(
inpt
,
padding
,
fill
=
fill
,
padding_mode
=
padding_mode
)
elif
isinstance
(
inpt
,
PIL
.
Image
.
Image
):
return
pad_image_pil
(
inpt
,
padding
,
fill
=
fill
,
padding_mode
=
padding_mode
)
else
:
raise
TypeError
(
f
"Input can either be a plain tensor, any TorchVision datapoint, or a PIL image, "
f
"but got
{
type
(
inpt
)
}
instead."
)
kernel
=
_get_kernel
(
pad
,
type
(
inpt
))
return
kernel
(
inpt
,
padding
=
padding
,
fill
=
fill
,
padding_mode
=
padding_mode
)
def
_parse_pad_padding
(
padding
:
Union
[
int
,
List
[
int
]])
->
List
[
int
]:
...
...
@@ -1204,6 +1155,7 @@ def _parse_pad_padding(padding: Union[int, List[int]]) -> List[int]:
return
[
pad_left
,
pad_right
,
pad_top
,
pad_bottom
]
@
_register_kernel_internal
(
pad
,
torch
.
Tensor
)
@
_register_kernel_internal
(
pad
,
datapoints
.
Image
)
def
pad_image_tensor
(
image
:
torch
.
Tensor
,
...
...
@@ -1303,7 +1255,7 @@ def _pad_with_vector_fill(
return
output
pad_image_pil
=
_FP
.
pad
pad_image_pil
=
_register_kernel_internal
(
pad
,
PIL
.
Image
.
Image
)(
_FP
.
pad
)
@
_register_kernel_internal
(
pad
,
datapoints
.
Mask
)
...
...
@@ -1385,23 +1337,16 @@ def pad_video(
def crop(inpt: datapoints._InputTypeJIT, top: int, left: int, height: int, width: int) -> datapoints._InputTypeJIT:
    """Crop ``inpt`` at ``(top, left)`` to ``(height, width)``.

    Dispatch entry point: forwards to the kernel registered for
    ``type(inpt)``.
    """
    # TorchScript cannot use the dynamic registry; take the tensor kernel.
    if torch.jit.is_scripting():
        return crop_image_tensor(inpt, top=top, left=left, height=height, width=width)

    _log_api_usage_once(crop)

    kernel = _get_kernel(crop, type(inpt))
    return kernel(inpt, top=top, left=left, height=height, width=width)
@
_register_kernel_internal
(
crop
,
torch
.
Tensor
)
@
_register_kernel_internal
(
crop
,
datapoints
.
Image
)
def
crop_image_tensor
(
image
:
torch
.
Tensor
,
top
:
int
,
left
:
int
,
height
:
int
,
width
:
int
)
->
torch
.
Tensor
:
h
,
w
=
image
.
shape
[
-
2
:]
...
...
@@ -1422,6 +1367,7 @@ def crop_image_tensor(image: torch.Tensor, top: int, left: int, height: int, wid
# Alias the v1 PIL crop kernel and register it for PIL.Image.Image inputs.
crop_image_pil = _FP.crop
_register_kernel_internal(crop, PIL.Image.Image)(crop_image_pil)
def
crop_bounding_boxes
(
...
...
@@ -1484,25 +1430,28 @@ def perspective(
fill
:
datapoints
.
_FillTypeJIT
=
None
,
coefficients
:
Optional
[
List
[
float
]]
=
None
,
)
->
datapoints
.
_InputTypeJIT
:
if
not
torch
.
jit
.
is_scripting
():
_log_api_usage_once
(
perspective
)
if
torch
.
jit
.
is_scripting
()
or
is_simple_tensor
(
inpt
):
if
torch
.
jit
.
is_scripting
():
return
perspective_image_tensor
(
inpt
,
startpoints
,
endpoints
,
interpolation
=
interpolation
,
fill
=
fill
,
coefficients
=
coefficients
)
elif
isinstance
(
inpt
,
datapoints
.
Datapoint
):
kernel
=
_get_kernel
(
perspective
,
type
(
inpt
))
return
kernel
(
inpt
,
startpoints
,
endpoints
,
interpolation
=
interpolation
,
fill
=
fill
,
coefficients
=
coefficients
)
elif
isinstance
(
inpt
,
PIL
.
Image
.
Image
):
return
perspective_image_pil
(
inpt
,
startpoints
,
endpoints
,
interpolation
=
interpolation
,
fill
=
fill
,
coefficients
=
coefficients
)
else
:
raise
TypeError
(
f
"Input can either be a plain tensor, any TorchVision datapoint, or a PIL image, "
f
"but got
{
type
(
inpt
)
}
instead."
inpt
,
startpoints
=
startpoints
,
endpoints
=
endpoints
,
interpolation
=
interpolation
,
fill
=
fill
,
coefficients
=
coefficients
,
)
_log_api_usage_once
(
perspective
)
kernel
=
_get_kernel
(
perspective
,
type
(
inpt
))
return
kernel
(
inpt
,
startpoints
=
startpoints
,
endpoints
=
endpoints
,
interpolation
=
interpolation
,
fill
=
fill
,
coefficients
=
coefficients
,
)
def
_perspective_grid
(
coeffs
:
List
[
float
],
ow
:
int
,
oh
:
int
,
dtype
:
torch
.
dtype
,
device
:
torch
.
device
)
->
torch
.
Tensor
:
# https://github.com/python-pillow/Pillow/blob/4634eafe3c695a014267eefdce830b4a825beed7/
...
...
@@ -1551,6 +1500,7 @@ def _perspective_coefficients(
raise
ValueError
(
"Either the startpoints/endpoints or the coefficients must have non `None` values."
)
@
_register_kernel_internal
(
perspective
,
torch
.
Tensor
)
@
_register_kernel_internal
(
perspective
,
datapoints
.
Image
)
def
perspective_image_tensor
(
image
:
torch
.
Tensor
,
...
...
@@ -1598,7 +1548,7 @@ def perspective_image_tensor(
return
output
@
torch
.
jit
.
unused
@
_register_kernel_internal
(
perspective
,
PIL
.
Image
.
Image
)
def
perspective_image_pil
(
image
:
PIL
.
Image
.
Image
,
startpoints
:
Optional
[
List
[
List
[
int
]]],
...
...
@@ -1787,29 +1737,19 @@ def elastic(
interpolation
:
Union
[
InterpolationMode
,
int
]
=
InterpolationMode
.
BILINEAR
,
fill
:
datapoints
.
_FillTypeJIT
=
None
,
)
->
datapoints
.
_InputTypeJIT
:
if
not
torch
.
jit
.
is_scripting
():
_log_api_usage_once
(
elastic
)
if
not
isinstance
(
displacement
,
torch
.
Tensor
):
raise
TypeError
(
"Argument displacement should be a Tensor"
)
if
torch
.
jit
.
is_scripting
()
or
is_simple_tensor
(
inpt
):
return
elastic_image_tensor
(
inpt
,
displacement
,
interpolation
=
interpolation
,
fill
=
fill
)
elif
isinstance
(
inpt
,
datapoints
.
Datapoint
):
kernel
=
_get_kernel
(
elastic
,
type
(
inpt
))
return
kernel
(
inpt
,
displacement
,
interpolation
=
interpolation
,
fill
=
fill
)
elif
isinstance
(
inpt
,
PIL
.
Image
.
Image
):
return
elastic_image_pil
(
inpt
,
displacement
,
interpolation
=
interpolation
,
fill
=
fill
)
else
:
raise
TypeError
(
f
"Input can either be a plain tensor, any TorchVision datapoint, or a PIL image, "
f
"but got
{
type
(
inpt
)
}
instead."
)
if
torch
.
jit
.
is_scripting
():
return
elastic_image_tensor
(
inpt
,
displacement
=
displacement
,
interpolation
=
interpolation
,
fill
=
fill
)
_log_api_usage_once
(
elastic
)
kernel
=
_get_kernel
(
elastic
,
type
(
inpt
))
return
kernel
(
inpt
,
displacement
=
displacement
,
interpolation
=
interpolation
,
fill
=
fill
)
# Public alias for ``elastic`` — presumably kept for name compatibility with
# the class-based ElasticTransform API; confirm before removing.
elastic_transform = elastic
@
_register_kernel_internal
(
elastic
,
torch
.
Tensor
)
@
_register_kernel_internal
(
elastic
,
datapoints
.
Image
)
def
elastic_image_tensor
(
image
:
torch
.
Tensor
,
...
...
@@ -1867,7 +1807,7 @@ def elastic_image_tensor(
return
output
@
torch
.
jit
.
unused
@
_register_kernel_internal
(
elastic
,
PIL
.
Image
.
Image
)
def
elastic_image_pil
(
image
:
PIL
.
Image
.
Image
,
displacement
:
torch
.
Tensor
,
...
...
@@ -1990,21 +1930,13 @@ def elastic_video(
def center_crop(inpt: datapoints._InputTypeJIT, output_size: List[int]) -> datapoints._InputTypeJIT:
    """Center-crop ``inpt`` to ``output_size`` via the type-registered kernel."""
    # Only the plain-tensor kernel is reachable from TorchScript.
    if torch.jit.is_scripting():
        return center_crop_image_tensor(inpt, output_size=output_size)

    _log_api_usage_once(center_crop)

    kernel = _get_kernel(center_crop, type(inpt))
    return kernel(inpt, output_size=output_size)
def
_center_crop_parse_output_size
(
output_size
:
List
[
int
])
->
List
[
int
]:
...
...
@@ -2034,6 +1966,7 @@ def _center_crop_compute_crop_anchor(
return
crop_top
,
crop_left
@
_register_kernel_internal
(
center_crop
,
torch
.
Tensor
)
@
_register_kernel_internal
(
center_crop
,
datapoints
.
Image
)
def
center_crop_image_tensor
(
image
:
torch
.
Tensor
,
output_size
:
List
[
int
])
->
torch
.
Tensor
:
crop_height
,
crop_width
=
_center_crop_parse_output_size
(
output_size
)
...
...
@@ -2054,7 +1987,7 @@ def center_crop_image_tensor(image: torch.Tensor, output_size: List[int]) -> tor
return
image
[...,
crop_top
:
(
crop_top
+
crop_height
),
crop_left
:
(
crop_left
+
crop_width
)]
@
torch
.
jit
.
unused
@
_register_kernel_internal
(
center_crop
,
PIL
.
Image
.
Image
)
def
center_crop_image_pil
(
image
:
PIL
.
Image
.
Image
,
output_size
:
List
[
int
])
->
PIL
.
Image
.
Image
:
crop_height
,
crop_width
=
_center_crop_parse_output_size
(
output_size
)
image_height
,
image_width
=
get_size_image_pil
(
image
)
...
...
@@ -2125,25 +2058,34 @@ def resized_crop(
interpolation
:
Union
[
InterpolationMode
,
int
]
=
InterpolationMode
.
BILINEAR
,
antialias
:
Optional
[
Union
[
str
,
bool
]]
=
"warn"
,
)
->
datapoints
.
_InputTypeJIT
:
if
not
torch
.
jit
.
is_scripting
():
_log_api_usage_once
(
resized_crop
)
if
torch
.
jit
.
is_scripting
()
or
is_simple_tensor
(
inpt
):
if
torch
.
jit
.
is_scripting
():
return
resized_crop_image_tensor
(
inpt
,
top
,
left
,
height
,
width
,
antialias
=
antialias
,
size
=
size
,
interpolation
=
interpolation
)
elif
isinstance
(
inpt
,
datapoints
.
Datapoint
):
kernel
=
_get_kernel
(
resized_crop
,
type
(
inpt
))
return
kernel
(
inpt
,
top
,
left
,
height
,
width
,
antialias
=
antialias
,
size
=
size
,
interpolation
=
interpolation
)
elif
isinstance
(
inpt
,
PIL
.
Image
.
Image
):
return
resized_crop_image_pil
(
inpt
,
top
,
left
,
height
,
width
,
size
=
size
,
interpolation
=
interpolation
)
else
:
raise
TypeError
(
f
"Input can either be a plain tensor, any TorchVision datapoint, or a PIL image, "
f
"but got
{
type
(
inpt
)
}
instead."
inpt
,
top
=
top
,
left
=
left
,
height
=
height
,
width
=
width
,
size
=
size
,
interpolation
=
interpolation
,
antialias
=
antialias
,
)
_log_api_usage_once
(
resized_crop
)
kernel
=
_get_kernel
(
resized_crop
,
type
(
inpt
))
return
kernel
(
inpt
,
top
=
top
,
left
=
left
,
height
=
height
,
width
=
width
,
size
=
size
,
interpolation
=
interpolation
,
antialias
=
antialias
,
)
@
_register_kernel_internal
(
resized_crop
,
torch
.
Tensor
)
@
_register_kernel_internal
(
resized_crop
,
datapoints
.
Image
)
def
resized_crop_image_tensor
(
image
:
torch
.
Tensor
,
...
...
@@ -2159,7 +2101,6 @@ def resized_crop_image_tensor(
return
resize_image_tensor
(
image
,
size
,
interpolation
=
interpolation
,
antialias
=
antialias
)
@
torch
.
jit
.
unused
def
resized_crop_image_pil
(
image
:
PIL
.
Image
.
Image
,
top
:
int
,
...
...
@@ -2173,6 +2114,30 @@ def resized_crop_image_pil(
return
resize_image_pil
(
image
,
size
,
interpolation
=
interpolation
)
@_register_kernel_internal(resized_crop, PIL.Image.Image)
def resized_crop_image_pil_dispatch(
    image: PIL.Image.Image,
    top: int,
    left: int,
    height: int,
    width: int,
    size: List[int],
    interpolation: Union[InterpolationMode, int] = InterpolationMode.BILINEAR,
    antialias: Optional[Union[str, bool]] = "warn",
) -> PIL.Image.Image:
    """PIL adapter for ``resized_crop`` with the full dispatcher signature.

    NOTE(review): name lacks the leading underscore used by
    ``_resize_image_pil_dispatch`` — kept as-is for interface stability.
    ``antialias`` is accepted only for signature parity; PIL resizing is
    always anti-aliased, so an explicit ``False`` merely warns.
    """
    if antialias is False:
        warnings.warn("Anti-alias option is always applied for PIL Image input. Argument antialias is ignored.")
    return resized_crop_image_pil(
        image, top=top, left=left, height=height, width=width, size=size, interpolation=interpolation
    )
def
resized_crop_bounding_boxes
(
bounding_boxes
:
torch
.
Tensor
,
format
:
datapoints
.
BoundingBoxFormat
,
...
...
@@ -2244,21 +2209,13 @@ def five_crop(
datapoints
.
_InputTypeJIT
,
datapoints
.
_InputTypeJIT
,
]:
if
not
torch
.
jit
.
is_scripting
():
_log_api_usage_once
(
five_crop
)
if
torch
.
jit
.
is_scripting
()
or
is_simple_tensor
(
inpt
):
return
five_crop_image_tensor
(
inpt
,
size
)
elif
isinstance
(
inpt
,
datapoints
.
Datapoint
):
kernel
=
_get_kernel
(
five_crop
,
type
(
inpt
))
return
kernel
(
inpt
,
size
)
elif
isinstance
(
inpt
,
PIL
.
Image
.
Image
):
return
five_crop_image_pil
(
inpt
,
size
)
else
:
raise
TypeError
(
f
"Input can either be a plain tensor, any TorchVision datapoint, or a PIL image, "
f
"but got
{
type
(
inpt
)
}
instead."
)
if
torch
.
jit
.
is_scripting
():
return
five_crop_image_tensor
(
inpt
,
size
=
size
)
_log_api_usage_once
(
five_crop
)
kernel
=
_get_kernel
(
five_crop
,
type
(
inpt
))
return
kernel
(
inpt
,
size
=
size
)
def
_parse_five_crop_size
(
size
:
List
[
int
])
->
List
[
int
]:
...
...
@@ -2275,6 +2232,7 @@ def _parse_five_crop_size(size: List[int]) -> List[int]:
return
size
@
_register_five_ten_crop_kernel
(
five_crop
,
torch
.
Tensor
)
@
_register_five_ten_crop_kernel
(
five_crop
,
datapoints
.
Image
)
def
five_crop_image_tensor
(
image
:
torch
.
Tensor
,
size
:
List
[
int
]
...
...
@@ -2294,7 +2252,7 @@ def five_crop_image_tensor(
return
tl
,
tr
,
bl
,
br
,
center
@
torch
.
jit
.
unused
@
_register_five_ten_crop_kernel
(
five_crop
,
PIL
.
Image
.
Image
)
def
five_crop_image_pil
(
image
:
PIL
.
Image
.
Image
,
size
:
List
[
int
]
)
->
Tuple
[
PIL
.
Image
.
Image
,
PIL
.
Image
.
Image
,
PIL
.
Image
.
Image
,
PIL
.
Image
.
Image
,
PIL
.
Image
.
Image
]:
...
...
@@ -2335,23 +2293,16 @@ def ten_crop(
datapoints
.
_InputTypeJIT
,
datapoints
.
_InputTypeJIT
,
]:
if
not
torch
.
jit
.
is_scripting
():
_log_api_usage_once
(
ten_crop
)
if
torch
.
jit
.
is_scripting
()
or
is_simple_tensor
(
inpt
):
return
ten_crop_image_tensor
(
inpt
,
size
,
vertical_flip
=
vertical_flip
)
elif
isinstance
(
inpt
,
datapoints
.
Datapoint
):
kernel
=
_get_kernel
(
ten_crop
,
type
(
inpt
))
return
kernel
(
inpt
,
size
,
vertical_flip
=
vertical_flip
)
elif
isinstance
(
inpt
,
PIL
.
Image
.
Image
):
return
ten_crop_image_pil
(
inpt
,
size
,
vertical_flip
=
vertical_flip
)
else
:
raise
TypeError
(
f
"Input can either be a plain tensor, any TorchVision datapoint, or a PIL image, "
f
"but got
{
type
(
inpt
)
}
instead."
)
if
torch
.
jit
.
is_scripting
():
return
ten_crop_image_tensor
(
inpt
,
size
=
size
,
vertical_flip
=
vertical_flip
)
_log_api_usage_once
(
ten_crop
)
kernel
=
_get_kernel
(
ten_crop
,
type
(
inpt
))
return
kernel
(
inpt
,
size
=
size
,
vertical_flip
=
vertical_flip
)
@
_register_five_ten_crop_kernel
(
ten_crop
,
torch
.
Tensor
)
@
_register_five_ten_crop_kernel
(
ten_crop
,
datapoints
.
Image
)
def
ten_crop_image_tensor
(
image
:
torch
.
Tensor
,
size
:
List
[
int
],
vertical_flip
:
bool
=
False
...
...
@@ -2379,7 +2330,7 @@ def ten_crop_image_tensor(
return
non_flipped
+
flipped
@
torch
.
jit
.
unused
@
_register_five_ten_crop_kernel
(
ten_crop
,
PIL
.
Image
.
Image
)
def
ten_crop_image_pil
(
image
:
PIL
.
Image
.
Image
,
size
:
List
[
int
],
vertical_flip
:
bool
=
False
)
->
Tuple
[
...
...
torchvision/transforms/v2/functional/_meta.py
View file @
2030d208
...
...
@@ -13,23 +13,16 @@ from ._utils import _get_kernel, _register_kernel_internal, _register_unsupporte
@_register_unsupported_type(datapoints.BoundingBoxes, datapoints.Mask)
def get_dimensions(inpt: Union[datapoints._ImageTypeJIT, datapoints._VideoTypeJIT]) -> List[int]:
    """Return ``[channels, height, width]`` of an image or video input.

    Bounding boxes and masks are rejected via the decorator above. All other
    inputs are routed through the kernel registry keyed on ``type(inpt)``.
    """
    # TorchScript cannot follow the dynamic registry lookup, so scripted code
    # is pinned to the pure-tensor kernel.
    if torch.jit.is_scripting():
        return get_dimensions_image_tensor(inpt)

    _log_api_usage_once(get_dimensions)

    kernel = _get_kernel(get_dimensions, type(inpt))
    return kernel(inpt)
@
_register_kernel_internal
(
get_dimensions
,
torch
.
Tensor
)
@
_register_kernel_internal
(
get_dimensions
,
datapoints
.
Image
,
datapoint_wrapper
=
False
)
def
get_dimensions_image_tensor
(
image
:
torch
.
Tensor
)
->
List
[
int
]:
chw
=
list
(
image
.
shape
[
-
3
:])
...
...
@@ -43,7 +36,7 @@ def get_dimensions_image_tensor(image: torch.Tensor) -> List[int]:
raise
TypeError
(
f
"Input tensor should have at least two dimensions, but got
{
ndims
}
"
)
# Reuse the v1 PIL helper as the PIL kernel, registering it on the dispatcher
# and exposing it under the v2 naming convention in a single step.
get_dimensions_image_pil = _register_kernel_internal(get_dimensions, PIL.Image.Image)(_FP.get_dimensions)
@
_register_kernel_internal
(
get_dimensions
,
datapoints
.
Video
,
datapoint_wrapper
=
False
)
...
...
@@ -53,23 +46,16 @@ def get_dimensions_video(video: torch.Tensor) -> List[int]:
@_register_unsupported_type(datapoints.BoundingBoxes, datapoints.Mask)
def get_num_channels(inpt: Union[datapoints._ImageTypeJIT, datapoints._VideoTypeJIT]) -> int:
    """Return the number of channels of an image or video input.

    Bounding boxes and masks are rejected via the decorator above; everything
    else dispatches through the kernel registry keyed on ``type(inpt)``.
    """
    # Scripted code bypasses the registry, which TorchScript cannot trace.
    if torch.jit.is_scripting():
        return get_num_channels_image_tensor(inpt)

    _log_api_usage_once(get_num_channels)

    kernel = _get_kernel(get_num_channels, type(inpt))
    return kernel(inpt)
@
_register_kernel_internal
(
get_num_channels
,
torch
.
Tensor
)
@
_register_kernel_internal
(
get_num_channels
,
datapoints
.
Image
,
datapoint_wrapper
=
False
)
def
get_num_channels_image_tensor
(
image
:
torch
.
Tensor
)
->
int
:
chw
=
image
.
shape
[
-
3
:]
...
...
@@ -82,7 +68,7 @@ def get_num_channels_image_tensor(image: torch.Tensor) -> int:
raise
TypeError
(
f
"Input tensor should have at least two dimensions, but got
{
ndims
}
"
)
# Reuse the v1 PIL helper as the PIL kernel, registering it on the dispatcher
# and exposing it under the v2 naming convention in a single step.
get_num_channels_image_pil = _register_kernel_internal(get_num_channels, PIL.Image.Image)(_FP.get_image_num_channels)
@
_register_kernel_internal
(
get_num_channels
,
datapoints
.
Video
,
datapoint_wrapper
=
False
)
...
...
@@ -96,23 +82,16 @@ get_image_num_channels = get_num_channels
def get_size(inpt: datapoints._InputTypeJIT) -> List[int]:
    """Return the spatial size ``[height, width]`` of the input.

    Dispatches through the kernel registry keyed on ``type(inpt)``; unlike
    ``get_dimensions``, every datapoint type can provide a size kernel.
    """
    # TorchScript is limited to the pure-tensor kernel since the registry
    # lookup is dynamic.
    if torch.jit.is_scripting():
        return get_size_image_tensor(inpt)

    _log_api_usage_once(get_size)

    kernel = _get_kernel(get_size, type(inpt))
    return kernel(inpt)
@
_register_kernel_internal
(
get_size
,
torch
.
Tensor
)
@
_register_kernel_internal
(
get_size
,
datapoints
.
Image
,
datapoint_wrapper
=
False
)
def
get_size_image_tensor
(
image
:
torch
.
Tensor
)
->
List
[
int
]:
hw
=
list
(
image
.
shape
[
-
2
:])
...
...
@@ -123,7 +102,7 @@ def get_size_image_tensor(image: torch.Tensor) -> List[int]:
raise
TypeError
(
f
"Input tensor should have at least two dimensions, but got
{
ndims
}
"
)
@torch.jit.unused
@_register_kernel_internal(get_size, PIL.Image.Image)
def get_size_image_pil(image: PIL.Image.Image) -> List[int]:
    """Return ``[height, width]`` of a PIL image.

    The v1 helper yields ``(width, height)``, so the pair is flipped here to
    match the ``[height, width]`` convention of the v2 API.
    """
    width, height = _FP.get_image_size(image)
    return [height, width]
...
...
@@ -146,21 +125,16 @@ def get_size_bounding_boxes(bounding_box: datapoints.BoundingBoxes) -> List[int]
@_register_unsupported_type(PIL.Image.Image, datapoints.Image, datapoints.BoundingBoxes, datapoints.Mask)
def get_num_frames(inpt: datapoints._VideoTypeJIT) -> int:
    """Return the number of frames of a video input.

    Only videos (plain tensors or ``datapoints.Video``) are supported; all
    image-like and geometry types are rejected via the decorator above.
    """
    # Scripted code is pinned to the video kernel; the registry lookup below
    # is not scriptable.
    if torch.jit.is_scripting():
        return get_num_frames_video(inpt)

    _log_api_usage_once(get_num_frames)

    kernel = _get_kernel(get_num_frames, type(inpt))
    return kernel(inpt)
@_register_kernel_internal(get_num_frames, torch.Tensor)
@_register_kernel_internal(get_num_frames, datapoints.Video, datapoint_wrapper=False)
def get_num_frames_video(video: torch.Tensor) -> int:
    """Return the size of dimension ``-4`` of *video*.

    NOTE(review): assumes the torchvision ``(..., T, C, H, W)`` video layout,
    i.e. frames are the fourth dimension from the end.
    """
    return video.shape[-4]
...
...
torchvision/transforms/v2/functional/_misc.py
View file @
2030d208
...
...
@@ -11,13 +11,7 @@ from torchvision.transforms.functional import pil_to_tensor, to_pil_image
from
torchvision.utils
import
_log_api_usage_once
from
._utils
import
(
_get_kernel
,
_register_explicit_noop
,
_register_kernel_internal
,
_register_unsupported_type
,
is_simple_tensor
,
)
from
._utils
import
_get_kernel
,
_register_explicit_noop
,
_register_kernel_internal
,
_register_unsupported_type
@
_register_explicit_noop
(
datapoints
.
BoundingBoxes
,
datapoints
.
Mask
)
...
...
@@ -28,19 +22,16 @@ def normalize(
std
:
List
[
float
],
inplace
:
bool
=
False
,
)
->
torch
.
Tensor
:
if
not
torch
.
jit
.
is_scripting
():
_log_api_usage_once
(
normalize
)
if
torch
.
jit
.
is_scripting
()
or
is_simple_tensor
(
inpt
):
if
torch
.
jit
.
is_scripting
():
return
normalize_image_tensor
(
inpt
,
mean
=
mean
,
std
=
std
,
inplace
=
inplace
)
elif
isinstance
(
inpt
,
datapoints
.
Datapoint
):
kernel
=
_get_kernel
(
normalize
,
type
(
inpt
))
return
kernel
(
inpt
,
mean
=
mean
,
std
=
std
,
inplace
=
inplace
)
else
:
raise
TypeError
(
f
"Input can either be a plain tensor or any TorchVision datapoint, but got
{
type
(
inpt
)
}
instead."
)
_log_api_usage_once
(
normalize
)
kernel
=
_get_kernel
(
normalize
,
type
(
inpt
))
return
kernel
(
inpt
,
mean
=
mean
,
std
=
std
,
inplace
=
inplace
)
@
_register_kernel_internal
(
normalize
,
torch
.
Tensor
)
@
_register_kernel_internal
(
normalize
,
datapoints
.
Image
)
def
normalize_image_tensor
(
image
:
torch
.
Tensor
,
mean
:
List
[
float
],
std
:
List
[
float
],
inplace
:
bool
=
False
...
...
@@ -86,21 +77,13 @@ def normalize_video(video: torch.Tensor, mean: List[float], std: List[float], in
def gaussian_blur(
    inpt: datapoints._InputTypeJIT, kernel_size: List[int], sigma: Optional[List[float]] = None
) -> datapoints._InputTypeJIT:
    """Blur the input with a Gaussian kernel of the given size and sigma.

    Args:
        inpt: plain tensor, datapoint, or PIL image to blur.
        kernel_size: size of the Gaussian kernel per spatial dimension.
        sigma: standard deviation(s); the kernel derives a default from
            ``kernel_size`` when ``None`` (see the tensor kernel).
    """
    # The dynamic registry lookup is not scriptable, so TorchScript always
    # runs the pure-tensor kernel.
    if torch.jit.is_scripting():
        return gaussian_blur_image_tensor(inpt, kernel_size=kernel_size, sigma=sigma)

    _log_api_usage_once(gaussian_blur)

    kernel = _get_kernel(gaussian_blur, type(inpt))
    return kernel(inpt, kernel_size=kernel_size, sigma=sigma)
def
_get_gaussian_kernel1d
(
kernel_size
:
int
,
sigma
:
float
,
dtype
:
torch
.
dtype
,
device
:
torch
.
device
)
->
torch
.
Tensor
:
...
...
@@ -119,6 +102,7 @@ def _get_gaussian_kernel2d(
return
kernel2d
@
_register_kernel_internal
(
gaussian_blur
,
torch
.
Tensor
)
@
_register_kernel_internal
(
gaussian_blur
,
datapoints
.
Image
)
def
gaussian_blur_image_tensor
(
image
:
torch
.
Tensor
,
kernel_size
:
List
[
int
],
sigma
:
Optional
[
List
[
float
]]
=
None
...
...
@@ -184,7 +168,7 @@ def gaussian_blur_image_tensor(
return
output
@
torch
.
jit
.
unused
@
_register_kernel_internal
(
gaussian_blur
,
PIL
.
Image
.
Image
)
def
gaussian_blur_image_pil
(
image
:
PIL
.
Image
.
Image
,
kernel_size
:
List
[
int
],
sigma
:
Optional
[
List
[
float
]]
=
None
)
->
PIL
.
Image
.
Image
:
...
...
@@ -200,21 +184,17 @@ def gaussian_blur_video(
return
gaussian_blur_image_tensor
(
video
,
kernel_size
,
sigma
)
@_register_unsupported_type(PIL.Image.Image)
def to_dtype(inpt: datapoints._InputTypeJIT, dtype: torch.dtype = torch.float, scale: bool = False) -> datapoints._InputTypeJIT:
    """Convert the input tensor/datapoint to *dtype*.

    Args:
        inpt: plain tensor or datapoint; PIL images are rejected via the
            decorator above.
        dtype: target dtype.
        scale: if ``True``, rescale values to the range of *dtype* instead of
            merely casting (see the image kernel).
    """
    # Scripted code uses the image-tensor kernel directly; the registry is
    # only consulted in eager mode.
    if torch.jit.is_scripting():
        return to_dtype_image_tensor(inpt, dtype=dtype, scale=scale)

    _log_api_usage_once(to_dtype)

    kernel = _get_kernel(to_dtype, type(inpt))
    return kernel(inpt, dtype=dtype, scale=scale)
def
_num_value_bits
(
dtype
:
torch
.
dtype
)
->
int
:
...
...
@@ -232,6 +212,7 @@ def _num_value_bits(dtype: torch.dtype) -> int:
raise
TypeError
(
f
"Number of value bits is only defined for integer dtypes, but got
{
dtype
}
."
)
@
_register_kernel_internal
(
to_dtype
,
torch
.
Tensor
)
@
_register_kernel_internal
(
to_dtype
,
datapoints
.
Image
)
def
to_dtype_image_tensor
(
image
:
torch
.
Tensor
,
dtype
:
torch
.
dtype
=
torch
.
float
,
scale
:
bool
=
False
)
->
torch
.
Tensor
:
...
...
torchvision/transforms/v2/functional/_temporal.py
View file @
2030d208
...
...
@@ -5,27 +5,23 @@ from torchvision import datapoints
from
torchvision.utils
import
_log_api_usage_once
from
._utils
import
_get_kernel
,
_register_explicit_noop
,
_register_kernel_internal
,
is_simple_tensor
from
._utils
import
_get_kernel
,
_register_explicit_noop
,
_register_kernel_internal
@_register_explicit_noop(
    PIL.Image.Image, datapoints.Image, datapoints.BoundingBoxes, datapoints.Mask, warn_passthrough=True
)
def uniform_temporal_subsample(inpt: datapoints._VideoTypeJIT, num_samples: int) -> datapoints._VideoTypeJIT:
    """Uniformly subsample ``num_samples`` frames from a video input.

    Non-video inputs (images, bounding boxes, masks) pass through unchanged,
    with a warning, via the explicit no-op registration above.
    """
    # TorchScript goes straight to the video kernel; the registry lookup is
    # eager-only.
    if torch.jit.is_scripting():
        return uniform_temporal_subsample_video(inpt, num_samples=num_samples)

    _log_api_usage_once(uniform_temporal_subsample)

    kernel = _get_kernel(uniform_temporal_subsample, type(inpt))
    return kernel(inpt, num_samples=num_samples)
@
_register_kernel_internal
(
uniform_temporal_subsample
,
torch
.
Tensor
)
@
_register_kernel_internal
(
uniform_temporal_subsample
,
datapoints
.
Video
)
def
uniform_temporal_subsample_video
(
video
:
torch
.
Tensor
,
num_samples
:
int
)
->
torch
.
Tensor
:
# Reference: https://github.com/facebookresearch/pytorchvideo/blob/a0a131e/pytorchvideo/transforms/functional.py#L19
...
...
torchvision/transforms/v2/functional/_utils.py
View file @
2030d208
...
...
@@ -23,15 +23,17 @@ def _kernel_datapoint_wrapper(kernel):
return
wrapper
def _register_kernel_internal(dispatcher, input_type, *, datapoint_wrapper=True):
    """Return a decorator that registers a kernel for ``(dispatcher, input_type)``.

    ``input_type`` may be any class: plain ``torch.Tensor``, ``PIL.Image.Image``,
    or a datapoint subclass. Datapoint kernels are wrapped so that their plain
    tensor result is re-wrapped into the datapoint type, unless
    ``datapoint_wrapper=False`` (e.g. for kernels that return plain metadata).
    Registering the same type twice is an error.
    """
    registry = _KERNEL_REGISTRY.setdefault(dispatcher, {})
    if input_type in registry:
        raise ValueError(f"Dispatcher {dispatcher} already has a kernel registered for type {input_type}.")

    def decorator(kernel):
        # Only datapoint inputs ever need the re-wrapping shim.
        needs_wrapper = issubclass(input_type, datapoints.Datapoint) and datapoint_wrapper
        registry[input_type] = _kernel_datapoint_wrapper(kernel) if needs_wrapper else kernel
        return kernel

    return decorator
...
...
@@ -43,7 +45,9 @@ def _name_to_dispatcher(name):
try
:
return
getattr
(
torchvision
.
transforms
.
v2
.
functional
,
name
)
except
AttributeError
:
raise
ValueError
(
f
"Could not find dispatcher with name '
{
name
}
'."
)
from
None
raise
ValueError
(
f
"Could not find dispatcher with name '
{
name
}
' in torchvision.transforms.v2.functional."
)
from
None
def
register_kernel
(
dispatcher
,
datapoint_cls
):
...
...
@@ -54,22 +58,57 @@ def register_kernel(dispatcher, datapoint_cls):
"""
if
isinstance
(
dispatcher
,
str
):
dispatcher
=
_name_to_dispatcher
(
name
=
dispatcher
)
elif
not
(
callable
(
dispatcher
)
and
getattr
(
dispatcher
,
"__module__"
,
""
).
startswith
(
"torchvision.transforms.v2.functional"
)
):
raise
ValueError
(
f
"Kernels can only be registered on dispatchers from the torchvision.transforms.v2.functional namespace, "
f
"but got
{
dispatcher
}
."
)
if
not
(
isinstance
(
datapoint_cls
,
type
)
and
issubclass
(
datapoint_cls
,
datapoints
.
Datapoint
)
and
datapoint_cls
is
not
datapoints
.
Datapoint
):
raise
ValueError
(
f
"Kernels can only be registered for subclasses of torchvision.datapoints.Datapoint, "
f
"but got
{
datapoint_cls
}
."
)
return
_register_kernel_internal
(
dispatcher
,
datapoint_cls
,
datapoint_wrapper
=
False
)
def _get_kernel(dispatcher, input_type):
    """Look up the kernel registered on *dispatcher* for *input_type*.

    Resolution order: exact type match first, then — for datapoints only —
    the MRO up to (but excluding) ``datapoints.Datapoint``. Unregistered
    datapoint subclasses currently fall back to a no-op passthrough.
    """
    registry = _KERNEL_REGISTRY.get(dispatcher)
    if not registry:
        raise ValueError(f"No kernel registered for dispatcher {dispatcher.__name__}.")

    # Exact type match is the common case; take the shortcut.
    if input_type in registry:
        return registry[input_type]

    # Datapoints may inherit a kernel from a registered superclass.
    if issubclass(input_type, datapoints.Datapoint):
        # The exact match was handled above, so start at the first superclass.
        for cls in input_type.__mro__[1:]:
            if cls is datapoints.Datapoint:
                # Stop before torch.Tensor so user-defined datapoints never
                # silently dispatch to the pure-tensor kernel; Datapoint itself
                # cannot carry registrations anyway.
                break
            elif cls in registry:
                return registry[cls]

        # Note that in the future we are not going to return a noop here, but rather raise the error below
        return _noop

    raise TypeError(
        f"Dispatcher {dispatcher} supports inputs of type torch.Tensor, PIL.Image.Image, "
        f"and subclasses of torchvision.datapoints.Datapoint, "
        f"but got {input_type} instead."
    )
# Everything below this block is stuff that we need right now, since it looks like we need to release in an intermediate
...
...
@@ -101,7 +140,9 @@ def _register_explicit_noop(*datapoints_classes, warn_passthrough=False):
f
"F.
{
dispatcher
.
__name__
}
is currently passing through inputs of type datapoints.
{
cls
.
__name__
}
. "
f
"This will likely change in the future."
)
register_kernel
(
dispatcher
,
cls
)(
functools
.
partial
(
_noop
,
__msg__
=
msg
if
warn_passthrough
else
None
))
_register_kernel_internal
(
dispatcher
,
cls
,
datapoint_wrapper
=
False
)(
functools
.
partial
(
_noop
,
__msg__
=
msg
if
warn_passthrough
else
None
)
)
return
dispatcher
return
decorator
...
...
@@ -115,13 +156,15 @@ def _noop(inpt, *args, __msg__=None, **kwargs):
# TODO: we only need this, since our default behavior in case no kernel is found is passthrough. When we change that
# to error later, this decorator can be removed, since the error will be raised by _get_kernel
def _register_unsupported_type(*input_types):
    """Return a dispatcher decorator that rejects the given input types.

    Each listed type gets a kernel that raises ``TypeError`` on call, so the
    unsupported combination fails loudly instead of passing through.
    """
    def kernel(inpt, *args, __dispatcher_name__, **kwargs):
        raise TypeError(f"F.{__dispatcher_name__} does not support inputs of type {type(inpt)}.")

    def decorator(dispatcher):
        rejecting_kernel = functools.partial(kernel, __dispatcher_name__=dispatcher.__name__)
        for input_type in input_types:
            # No datapoint wrapper: the kernel never returns a value to wrap.
            _register_kernel_internal(dispatcher, input_type, datapoint_wrapper=False)(rejecting_kernel)
        return dispatcher

    return decorator
...
...
@@ -129,13 +172,10 @@ def _register_unsupported_type(*datapoints_classes):
# This basically replicates _register_kernel_internal, but with a specialized wrapper for five_crop / ten_crop
# We could get rid of this by letting _register_kernel_internal take arbitrary dispatchers rather than wrap_kernel: bool
# TODO: decide if we want that
def
_register_five_ten_crop_kernel
(
dispatcher
,
datapoint_cls
):
def
_register_five_ten_crop_kernel
(
dispatcher
,
input_type
):
registry
=
_KERNEL_REGISTRY
.
setdefault
(
dispatcher
,
{})
if
datapoint_cls
in
registry
:
raise
TypeError
(
f
"Dispatcher '
{
dispatcher
.
__name__
}
' already has a kernel registered for type '
{
datapoint_cls
.
__name__
}
'."
)
if
input_type
in
registry
:
raise
TypeError
(
f
"Dispatcher '
{
dispatcher
}
' already has a kernel registered for type '
{
input_type
}
'."
)
def
wrap
(
kernel
):
@
functools
.
wraps
(
kernel
)
...
...
@@ -147,7 +187,7 @@ def _register_five_ten_crop_kernel(dispatcher, datapoint_cls):
return
wrapper
def
decorator
(
kernel
):
registry
[
datapoint_cls
]
=
wrap
(
kernel
)
registry
[
input_type
]
=
wrap
(
kernel
)
if
issubclass
(
input_type
,
datapoints
.
Datapoint
)
else
kernel
return
kernel
return
decorator
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment