OpenDAS / vision / Commits / a893f313

Unverified commit a893f313, authored Aug 02, 2023 by Philip Meier, committed by GitHub on Aug 02, 2023

refactor Datapoint dispatch mechanism (#7747)

Co-authored-by: Nicolas Hug <contact@nicolas-hug.com>

parent 16d62e30
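In broad strokes, the refactor drops the per-class dispatch methods on `Image`, `Video`, `Mask`, and `BoundingBoxes` (deleted wholesale in the `torchvision/datapoints/*` diffs below) in favor of a central kernel registry: each functional dispatcher owns a mapping from input type to kernel, and kernels are attached with registration decorators (`_register_kernel_internal` / `_register_explicit_noop`, both named in the new tests). The following is a minimal sketch of that pattern with illustrative bodies and a toy `Image` stand-in; it is not the actual torchvision implementation:

    from collections import defaultdict

    # dispatcher -> {input type -> kernel}, mirroring the _KERNEL_REGISTRY used in the tests below
    _KERNEL_REGISTRY = defaultdict(dict)

    def _register_kernel_internal(dispatcher, input_type):
        """Attach the decorated kernel to `dispatcher` for inputs of `input_type`."""
        def decorator(kernel):
            _KERNEL_REGISTRY[dispatcher][input_type] = kernel
            return kernel
        return decorator

    class Image:  # toy stand-in for datapoints.Image
        pass

    def horizontal_flip(inpt, *args, **kwargs):
        # The dispatcher looks up the kernel registered for the concrete input type and delegates.
        kernel = _KERNEL_REGISTRY[horizontal_flip].get(type(inpt))
        if kernel is None:
            raise TypeError(f"horizontal_flip() got unsupported input of type {type(inpt)}")
        return kernel(inpt, *args, **kwargs)

    @_register_kernel_internal(horizontal_flip, Image)
    def horizontal_flip_image(inpt):
        return inpt  # a real kernel would flip the pixel data

One payoff of the registry, visible in the new test_exhaustive_kernel_registration below, is that the set of supported types per dispatcher becomes plain data that can be enumerated and checked.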
Showing 20 changed files with 928 additions and 1303 deletions (+928 −1303)
test/common_utils.py                                 +4    -0
test/datasets_utils.py                               +4    -2
test/test_transforms_v2.py                           +5    -5
test/test_transforms_v2_functional.py                +19   -33
test/test_transforms_v2_refactored.py                +200  -55
test/transforms_v2_dispatcher_infos.py               +9    -16
test/transforms_v2_kernel_infos.py                   +0    -40
torchvision/datapoints/__init__.py                   +1    -1
torchvision/datapoints/_bounding_box.py              +2    -141
torchvision/datapoints/_datapoint.py                 +0    -138
torchvision/datapoints/_image.py                     +2    -192
torchvision/datapoints/_mask.py                      +2    -105
torchvision/datapoints/_video.py                     +2    -188
torchvision/transforms/v2/_augment.py                +2    -6
torchvision/transforms/v2/_geometry.py               +2    -31
torchvision/transforms/v2/_temporal.py               +2    -3
torchvision/transforms/v2/functional/__init__.py     +2    -1
torchvision/transforms/v2/functional/_augment.py     +30   -30
torchvision/transforms/v2/functional/_color.py       +160  -110
torchvision/transforms/v2/functional/_geometry.py    +480  -206
test/common_utils.py
@@ -829,6 +829,10 @@ def make_video(size=DEFAULT_SIZE, *, num_frames=3, batch_dims=(), **kwargs):
     return datapoints.Video(make_image(size, batch_dims=(*batch_dims, num_frames), **kwargs))
 
 
+def make_video_tensor(*args, **kwargs):
+    return make_video(*args, **kwargs).as_subclass(torch.Tensor)
+
+
 def make_video_loader(
     size=DEFAULT_PORTRAIT_SPATIAL_SIZE,
     *,
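For readers unfamiliar with it, `Tensor.as_subclass` re-views the same storage under a different tensor class; it is how the new `make_video_tensor` helper (and the kernels themselves) strip the datapoint wrapper. A standalone illustration, not part of the diff:

    import torch

    class MyTensor(torch.Tensor):
        pass

    t = torch.rand(3).as_subclass(MyTensor)
    assert type(t) is MyTensor

    plain = t.as_subclass(torch.Tensor)  # plain Tensor view of the same storage
    assert type(plain) is torch.Tensor
    assert plain.data_ptr() == t.data_ptr()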
test/datasets_utils.py
@@ -567,7 +567,7 @@ class DatasetTestCase(unittest.TestCase):
     @test_all_configs
     def test_transforms_v2_wrapper(self, config):
-        from torchvision.datapoints._datapoint import Datapoint
+        from torchvision import datapoints
         from torchvision.datasets import wrap_dataset_for_transforms_v2
 
         try:
@@ -588,7 +588,9 @@ class DatasetTestCase(unittest.TestCase):
                 assert len(wrapped_dataset) == info["num_examples"]
 
                 wrapped_sample = wrapped_dataset[0]
-                assert tree_any(lambda item: isinstance(item, (Datapoint, PIL.Image.Image)), wrapped_sample)
+                assert tree_any(
+                    lambda item: isinstance(item, (datapoints.Datapoint, PIL.Image.Image)), wrapped_sample
+                )
         except TypeError as error:
             msg = f"No wrapper exists for dataset class {type(dataset).__name__}"
             if str(error).startswith(msg):
test/test_transforms_v2.py
@@ -1344,12 +1344,12 @@ def test_antialias_warning():
         transforms.RandomResize(10, 20)(tensor_img)
 
     with pytest.warns(UserWarning, match=match):
-        datapoints.Image(tensor_img).resized_crop(0, 0, 10, 10, (20, 20))
+        F.resized_crop(datapoints.Image(tensor_img), 0, 0, 10, 10, (20, 20))
 
     with pytest.warns(UserWarning, match=match):
-        datapoints.Video(tensor_video).resize((20, 20))
+        F.resize(datapoints.Video(tensor_video), (20, 20))
 
     with pytest.warns(UserWarning, match=match):
-        datapoints.Video(tensor_video).resized_crop(0, 0, 10, 10, (20, 20))
+        F.resized_crop(datapoints.Video(tensor_video), 0, 0, 10, 10, (20, 20))
 
     with warnings.catch_warnings():
         warnings.simplefilter("error")
@@ -1363,8 +1363,8 @@ def test_antialias_warning():
         transforms.RandomShortestSize((20, 20), antialias=True)(tensor_img)
         transforms.RandomResize(10, 20, antialias=True)(tensor_img)
 
-        datapoints.Image(tensor_img).resized_crop(0, 0, 10, 10, (20, 20), antialias=True)
-        datapoints.Video(tensor_video).resized_crop(0, 0, 10, 10, (20, 20), antialias=True)
+        F.resized_crop(datapoints.Image(tensor_img), 0, 0, 10, 10, (20, 20), antialias=True)
+        F.resized_crop(datapoints.Video(tensor_video), 0, 0, 10, 10, (20, 20), antialias=True)
 
 
 @pytest.mark.parametrize("image_type", (PIL.Image, torch.Tensor, datapoints.Image))
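These test changes track the user-facing shift of the commit: operations are no longer invoked as methods on datapoint instances but through the `transforms.v2.functional` dispatchers, which preserve the input type. Schematically (the tensor is a placeholder; the calls mirror the updated tests above):

    import torch
    from torchvision import datapoints
    from torchvision.transforms.v2 import functional as F

    tensor_img = torch.randint(0, 256, (3, 10, 10), dtype=torch.uint8)

    # before this commit: datapoints.Image(tensor_img).resize((20, 20))
    out = F.resize(datapoints.Image(tensor_img), (20, 20), antialias=True)
    assert isinstance(out, datapoints.Image)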
test/test_transforms_v2_functional.py
@@ -2,13 +2,11 @@ import inspect
 import math
 import os
 import re
-from typing import get_type_hints
 from unittest import mock
 
 import numpy as np
 import PIL.Image
 import pytest
 import torch
 
 from common_utils import (
@@ -27,6 +25,7 @@ from torchvision.transforms.functional import _get_perspective_coeffs
 from torchvision.transforms.v2 import functional as F
 from torchvision.transforms.v2.functional._geometry import _center_crop_compute_padding
 from torchvision.transforms.v2.functional._meta import clamp_bounding_boxes, convert_format_bounding_boxes
+from torchvision.transforms.v2.functional._utils import _KERNEL_REGISTRY
 from torchvision.transforms.v2.utils import is_simple_tensor
 from transforms_v2_dispatcher_infos import DISPATCHER_INFOS
 from transforms_v2_kernel_infos import KERNEL_INFOS
@@ -424,12 +423,18 @@ class TestDispatchers:
     def test_dispatch_datapoint(self, info, args_kwargs, spy_on):
         (datapoint, *other_args), kwargs = args_kwargs.load()
 
-        method_name = info.id
-        method = getattr(datapoint, method_name)
-        datapoint_type = type(datapoint)
-        spy = spy_on(method, module=datapoint_type.__module__, name=f"{datapoint_type.__name__}.{method_name}")
+        input_type = type(datapoint)
+        wrapped_kernel = _KERNEL_REGISTRY[info.dispatcher][input_type]
 
-        info.dispatcher(datapoint, *other_args, **kwargs)
+        # In case the wrapper was decorated with @functools.wraps, we can make the check more strict and test if the
+        # proper kernel was wrapped
+        if hasattr(wrapped_kernel, "__wrapped__"):
+            assert wrapped_kernel.__wrapped__ is info.kernels[input_type]
+
+        spy = mock.MagicMock(wraps=wrapped_kernel, name=wrapped_kernel.__name__)
+        with mock.patch.dict(_KERNEL_REGISTRY[info.dispatcher], values={input_type: spy}):
+            info.dispatcher(datapoint, *other_args, **kwargs)
 
         spy.assert_called_once()
@@ -462,9 +467,12 @@ class TestDispatchers:
         kernel_params = list(kernel_signature.parameters.values())[1:]
 
         # We filter out metadata that is implicitly passed to the dispatcher through the input datapoint, but has to be
-        # explicit passed to the kernel.
-        datapoint_type_metadata = datapoint_type.__annotations__.keys()
-        kernel_params = [param for param in kernel_params if param.name not in datapoint_type_metadata]
+        # explicitly passed to the kernel.
+        input_type = {v: k for k, v in dispatcher_info.kernels.items()}.get(kernel_info.kernel)
+        explicit_metadata = {
+            datapoints.BoundingBoxes: {"format", "canvas_size"},
+        }
+        kernel_params = [param for param in kernel_params if param.name not in explicit_metadata.get(input_type, set())]
 
         dispatcher_params = iter(dispatcher_params)
         for dispatcher_param, kernel_param in zip(dispatcher_params, kernel_params):
@@ -481,28 +489,6 @@ class TestDispatchers:
             assert dispatcher_param == kernel_param
 
-    @pytest.mark.parametrize("info", DISPATCHER_INFOS, ids=lambda info: info.id)
-    def test_dispatcher_datapoint_signatures_consistency(self, info):
-        try:
-            datapoint_method = getattr(datapoints._datapoint.Datapoint, info.id)
-        except AttributeError:
-            pytest.skip("Dispatcher doesn't support arbitrary datapoint dispatch.")
-
-        dispatcher_signature = inspect.signature(info.dispatcher)
-        dispatcher_params = list(dispatcher_signature.parameters.values())[1:]
-
-        datapoint_signature = inspect.signature(datapoint_method)
-        datapoint_params = list(datapoint_signature.parameters.values())[1:]
-
-        # Because we use `from __future__ import annotations` inside the module where `datapoints._datapoint` is
-        # defined, the annotations are stored as strings. This makes them concrete again, so they can be compared to the
-        # natively concrete dispatcher annotations.
-        datapoint_annotations = get_type_hints(datapoint_method)
-        for param in datapoint_params:
-            param._annotation = datapoint_annotations[param.name]
-
-        assert dispatcher_params == datapoint_params
-
     @pytest.mark.parametrize("info", DISPATCHER_INFOS, ids=lambda info: info.id)
     def test_unkown_type(self, info):
         unkown_input = object()
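The spying strategy above generalizes: because the registry is a plain dict, `mock.patch.dict` can temporarily swap one entry for a `MagicMock` that wraps the real kernel, and the test then asserts the dispatcher routed through it. A self-contained sketch with a toy registry and kernel (not torchvision code):

    from unittest import mock

    registry = {"image": lambda x: x + 1}  # toy registry: input kind -> kernel

    def dispatcher(kind, x):
        return registry[kind](x)

    kernel = registry["image"]
    spy = mock.MagicMock(wraps=kernel)
    with mock.patch.dict(registry, values={"image": spy}):
        assert dispatcher("image", 1) == 2  # the wrapped kernel still runs

    spy.assert_called_once_with(1)
    assert registry["image"] is kernel  # patch.dict restores the original entry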
test/test_transforms_v2_refactored.py
@@ -3,7 +3,6 @@ import decimal
 import inspect
 import math
 import re
-from typing import get_type_hints
 from unittest import mock
 
 import numpy as np
@@ -26,6 +25,7 @@ from common_utils import (
     make_image_tensor,
     make_segmentation_mask,
     make_video,
+    make_video_tensor,
     needs_cuda,
     set_rng_seed,
 )
@@ -39,6 +39,7 @@ from torchvision import datapoints
 from torchvision.transforms._functional_tensor import _max_value as get_max_value
 from torchvision.transforms.functional import pil_modes_mapping
 from torchvision.transforms.v2 import functional as F
+from torchvision.transforms.v2.functional._utils import _KERNEL_REGISTRY
 
 
 @pytest.fixture(autouse=True)
@@ -176,16 +177,19 @@ def _check_dispatcher_dispatch(dispatcher, kernel, input, *args, **kwargs):
     """Checks if the dispatcher correctly dispatches the input to the corresponding kernel and that the input type is
     preserved in doing so. For bounding boxes also checks that the format is preserved.
     """
-    if isinstance(input, datapoints._datapoint.Datapoint):
-        # Due to our complex dispatch architecture for datapoints, we cannot spy on the kernel directly,
-        # but rather have to patch the `Datapoint.__F` attribute to contain the spied on kernel.
-        spy = mock.MagicMock(wraps=kernel, name=kernel.__name__)
-        with mock.patch.object(F, kernel.__name__, spy):
-            # Due to Python's name mangling, the `Datapoint.__F` attribute is only accessible from inside the class.
-            # Since that is not the case here, we need to prefix f"_{cls.__name__}"
-            # See https://docs.python.org/3/tutorial/classes.html#private-variables for details
-            with mock.patch.object(datapoints._datapoint.Datapoint, "_Datapoint__F", new=F):
-                output = dispatcher(input, *args, **kwargs)
+    input_type = type(input)
+
+    if isinstance(input, datapoints.Datapoint):
+        wrapped_kernel = _KERNEL_REGISTRY[dispatcher][input_type]
+
+        # In case the wrapper was decorated with @functools.wraps, we can make the check more strict and test if the
+        # proper kernel was wrapped
+        if hasattr(wrapped_kernel, "__wrapped__"):
+            assert wrapped_kernel.__wrapped__ is kernel
+
+        spy = mock.MagicMock(wraps=wrapped_kernel, name=wrapped_kernel.__name__)
+        with mock.patch.dict(_KERNEL_REGISTRY[dispatcher], values={input_type: spy}):
+            output = dispatcher(input, *args, **kwargs)
 
         spy.assert_called_once()
     else:
@@ -194,7 +198,7 @@ def _check_dispatcher_dispatch(dispatcher, kernel, input, *args, **kwargs):
         spy.assert_called_once()
 
-    assert isinstance(output, type(input))
+    assert isinstance(output, input_type)
 
     if isinstance(input, datapoints.BoundingBoxes):
         assert output.format == input.format
@@ -209,15 +213,13 @@ def check_dispatcher(
     check_dispatch=True,
     **kwargs,
 ):
-    unknown_input = object()
     with mock.patch("torch._C._log_api_usage_once", wraps=torch._C._log_api_usage_once) as spy:
         dispatcher(input, *args, **kwargs)
 
-        with pytest.raises(TypeError, match=re.escape(str(type(unknown_input)))):
-            dispatcher(unknown_input, *args, **kwargs)
-
         spy.assert_any_call(f"{dispatcher.__module__}.{dispatcher.__name__}")
 
+    unknown_input = object()
+    with pytest.raises(TypeError, match=re.escape(str(type(unknown_input)))):
+        dispatcher(unknown_input, *args, **kwargs)
+
     if check_scripted_smoke:
         _check_dispatcher_scripted_smoke(dispatcher, input, *args, **kwargs)
@@ -225,18 +227,18 @@ def check_dispatcher(
         _check_dispatcher_dispatch(dispatcher, kernel, input, *args, **kwargs)
 
 
-def _check_dispatcher_kernel_signature_match(dispatcher, *, kernel, input_type):
+def check_dispatcher_kernel_signature_match(dispatcher, *, kernel, input_type):
     """Checks if the signature of the dispatcher matches the kernel signature."""
-    dispatcher_signature = inspect.signature(dispatcher)
-    dispatcher_params = list(dispatcher_signature.parameters.values())[1:]
-
-    kernel_signature = inspect.signature(kernel)
-    kernel_params = list(kernel_signature.parameters.values())[1:]
+    dispatcher_params = list(inspect.signature(dispatcher).parameters.values())[1:]
+    kernel_params = list(inspect.signature(kernel).parameters.values())[1:]
 
-    if issubclass(input_type, datapoints._datapoint.Datapoint):
+    if issubclass(input_type, datapoints.Datapoint):
         # We filter out metadata that is implicitly passed to the dispatcher through the input datapoint, but has to be
         # explicitly passed to the kernel.
-        kernel_params = [param for param in kernel_params if param.name not in input_type.__annotations__.keys()]
+        explicit_metadata = {
+            datapoints.BoundingBoxes: {"format", "canvas_size"},
+        }
+        kernel_params = [param for param in kernel_params if param.name not in explicit_metadata.get(input_type, set())]
 
     dispatcher_params = iter(dispatcher_params)
     for dispatcher_param, kernel_param in zip(dispatcher_params, kernel_params):
@@ -259,30 +261,6 @@ def _check_dispatcher_kernel_signature_match(dispatcher, *, kernel, input_type):
         assert dispatcher_param == kernel_param
 
 
-def _check_dispatcher_datapoint_signature_match(dispatcher):
-    """Checks if the signature of the dispatcher matches the corresponding method signature on the Datapoint class."""
-    dispatcher_signature = inspect.signature(dispatcher)
-    dispatcher_params = list(dispatcher_signature.parameters.values())[1:]
-
-    datapoint_method = getattr(datapoints._datapoint.Datapoint, dispatcher.__name__)
-    datapoint_signature = inspect.signature(datapoint_method)
-    datapoint_params = list(datapoint_signature.parameters.values())[1:]
-
-    # Some annotations in the `datapoints._datapoint` module are stored as strings. The block below makes them
-    # concrete again (non-strings), so they can be compared to the natively concrete dispatcher annotations.
-    datapoint_annotations = get_type_hints(datapoint_method)
-    for param in datapoint_params:
-        param._annotation = datapoint_annotations[param.name]
-
-    assert dispatcher_params == datapoint_params
-
-
-def check_dispatcher_signatures_match(dispatcher, *, kernel, input_type):
-    _check_dispatcher_kernel_signature_match(dispatcher, kernel=kernel, input_type=input_type)
-    _check_dispatcher_datapoint_signature_match(dispatcher)
-
-
 def _check_transform_v1_compatibility(transform, input):
     """If the transform defines the ``_v1_transform_cls`` attribute, checks if the transform has a public, static
     ``get_params`` method, is scriptable, and the scripted version can be called without error."""
@@ -433,6 +411,33 @@ def reference_affine_bounding_boxes_helper(bounding_boxes, *, format, canvas_size
     return torch.stack([transform(b) for b in bounding_boxes.reshape(-1, 4).unbind()]).reshape(bounding_boxes.shape)
 
 
+@pytest.mark.parametrize(
+    ("dispatcher", "registered_datapoint_clss"),
+    [(dispatcher, set(registry.keys())) for dispatcher, registry in _KERNEL_REGISTRY.items()],
+)
+def test_exhaustive_kernel_registration(dispatcher, registered_datapoint_clss):
+    missing = {
+        datapoints.Image,
+        datapoints.BoundingBoxes,
+        datapoints.Mask,
+        datapoints.Video,
+    } - registered_datapoint_clss
+    if missing:
+        names = sorted(f"datapoints.{cls.__name__}" for cls in missing)
+        raise AssertionError(
+            "\n".join(
+                [
+                    f"The dispatcher '{dispatcher.__name__}' has no kernel registered for",
+                    "",
+                    *[f"- {name}" for name in names],
+                    "",
+                    f"If available, register the kernels with @_register_kernel_internal({dispatcher.__name__}, ...).",
+                    f"If not, register explicit no-ops with @_register_explicit_noop({', '.join(names)})",
                ]
            )
        )
+
+
 class TestResize:
     INPUT_SIZE = (17, 11)
     OUTPUT_SIZES = [17, [17], (17,), [12, 13], (12, 13)]
@@ -568,7 +573,7 @@ class TestResize:
         ],
     )
     def test_dispatcher_signature(self, kernel, input_type):
-        check_dispatcher_signatures_match(F.resize, kernel=kernel, input_type=input_type)
+        check_dispatcher_kernel_signature_match(F.resize, kernel=kernel, input_type=input_type)
 
     @pytest.mark.parametrize("size", OUTPUT_SIZES)
     @pytest.mark.parametrize("device", cpu_and_cuda())
@@ -766,7 +771,7 @@ class TestResize:
         # This identity check is not a requirement. It is here to avoid breaking the behavior by accident. If there
         # is a good reason to break this, feel free to downgrade to an equality check.
-        if isinstance(input, datapoints._datapoint.Datapoint):
+        if isinstance(input, datapoints.Datapoint):
             # We can't test identity directly, since that checks for the identity of the Python object. Since all
             # datapoints unwrap before a kernel and wrap again afterwards, the Python object changes. Thus, we check
             # that the underlying storage is the same
@@ -850,7 +855,7 @@ class TestHorizontalFlip:
         ],
    )
     def test_dispatcher_signature(self, kernel, input_type):
-        check_dispatcher_signatures_match(F.horizontal_flip, kernel=kernel, input_type=input_type)
+        check_dispatcher_kernel_signature_match(F.horizontal_flip, kernel=kernel, input_type=input_type)
 
     @pytest.mark.parametrize(
         "make_input",
@@ -1033,7 +1038,7 @@ class TestAffine:
         ],
    )
     def test_dispatcher_signature(self, kernel, input_type):
-        check_dispatcher_signatures_match(F.affine, kernel=kernel, input_type=input_type)
+        check_dispatcher_kernel_signature_match(F.affine, kernel=kernel, input_type=input_type)
 
     @pytest.mark.parametrize(
         "make_input",
@@ -1329,7 +1334,7 @@ class TestVerticalFlip:
         ],
    )
     def test_dispatcher_signature(self, kernel, input_type):
-        check_dispatcher_signatures_match(F.vertical_flip, kernel=kernel, input_type=input_type)
+        check_dispatcher_kernel_signature_match(F.vertical_flip, kernel=kernel, input_type=input_type)
 
     @pytest.mark.parametrize(
         "make_input",
@@ -1486,7 +1491,7 @@ class TestRotate:
         ],
    )
     def test_dispatcher_signature(self, kernel, input_type):
-        check_dispatcher_signatures_match(F.rotate, kernel=kernel, input_type=input_type)
+        check_dispatcher_kernel_signature_match(F.rotate, kernel=kernel, input_type=input_type)
 
     @pytest.mark.parametrize(
         "make_input",
@@ -1899,6 +1904,56 @@ class TestToDtype:
         assert out["mask"].dtype == mask_dtype
 
 
+class TestAdjustBrightness:
+    _CORRECTNESS_BRIGHTNESS_FACTORS = [0.5, 0.0, 1.0, 5.0]
+    _DEFAULT_BRIGHTNESS_FACTOR = _CORRECTNESS_BRIGHTNESS_FACTORS[0]
+
+    @pytest.mark.parametrize(
+        ("kernel", "make_input"),
+        [
+            (F.adjust_brightness_image_tensor, make_image),
+            (F.adjust_brightness_video, make_video),
+        ],
+    )
+    @pytest.mark.parametrize("dtype", [torch.float32, torch.uint8])
+    @pytest.mark.parametrize("device", cpu_and_cuda())
+    def test_kernel(self, kernel, make_input, dtype, device):
+        check_kernel(kernel, make_input(dtype=dtype, device=device), brightness_factor=self._DEFAULT_BRIGHTNESS_FACTOR)
+
+    @pytest.mark.parametrize(
+        ("kernel", "make_input"),
+        [
+            (F.adjust_brightness_image_tensor, make_image_tensor),
+            (F.adjust_brightness_image_pil, make_image_pil),
+            (F.adjust_brightness_image_tensor, make_image),
+            (F.adjust_brightness_video, make_video),
+        ],
+    )
+    def test_dispatcher(self, kernel, make_input):
+        check_dispatcher(F.adjust_brightness, kernel, make_input(), brightness_factor=self._DEFAULT_BRIGHTNESS_FACTOR)
+
+    @pytest.mark.parametrize(
+        ("kernel", "input_type"),
+        [
+            (F.adjust_brightness_image_tensor, torch.Tensor),
+            (F.adjust_brightness_image_pil, PIL.Image.Image),
+            (F.adjust_brightness_image_tensor, datapoints.Image),
+            (F.adjust_brightness_video, datapoints.Video),
+        ],
+    )
+    def test_dispatcher_signature(self, kernel, input_type):
+        check_dispatcher_kernel_signature_match(F.adjust_brightness, kernel=kernel, input_type=input_type)
+
+    @pytest.mark.parametrize("brightness_factor", _CORRECTNESS_BRIGHTNESS_FACTORS)
+    def test_image_correctness(self, brightness_factor):
+        image = make_image(dtype=torch.uint8, device="cpu")
+
+        actual = F.adjust_brightness(image, brightness_factor=brightness_factor)
+        expected = F.to_image_tensor(F.adjust_brightness(F.to_image_pil(image), brightness_factor=brightness_factor))
+
+        torch.testing.assert_close(actual, expected)
+
+
 class TestCutMixMixUp:
     class DummyDataset:
         def __init__(self, size, num_classes):
@@ -2036,3 +2091,93 @@ def test_labels_getter_default_heuristic(key, sample_type):
     # it takes precedence over other keys which would otherwise be a match
     d = {key: "something_else", "labels": labels}
     assert transforms._utils._find_labels_default_heuristic(d) is labels
+
+
+class TestShapeGetters:
+    @pytest.mark.parametrize(
+        ("kernel", "make_input"),
+        [
+            (F.get_dimensions_image_tensor, make_image_tensor),
+            (F.get_dimensions_image_pil, make_image_pil),
+            (F.get_dimensions_image_tensor, make_image),
+            (F.get_dimensions_video, make_video),
+        ],
+    )
+    def test_get_dimensions(self, kernel, make_input):
+        size = (10, 10)
+        color_space, num_channels = "RGB", 3
+
+        input = make_input(size, color_space=color_space)
+
+        assert kernel(input) == F.get_dimensions(input) == [num_channels, *size]
+
+    @pytest.mark.parametrize(
+        ("kernel", "make_input"),
+        [
+            (F.get_num_channels_image_tensor, make_image_tensor),
+            (F.get_num_channels_image_pil, make_image_pil),
+            (F.get_num_channels_image_tensor, make_image),
+            (F.get_num_channels_video, make_video),
+        ],
+    )
+    def test_get_num_channels(self, kernel, make_input):
+        color_space, num_channels = "RGB", 3
+
+        input = make_input(color_space=color_space)
+
+        assert kernel(input) == F.get_num_channels(input) == num_channels
+
+    @pytest.mark.parametrize(
+        ("kernel", "make_input"),
+        [
+            (F.get_size_image_tensor, make_image_tensor),
+            (F.get_size_image_pil, make_image_pil),
+            (F.get_size_image_tensor, make_image),
+            (F.get_size_bounding_boxes, make_bounding_box),
+            (F.get_size_mask, make_detection_mask),
+            (F.get_size_mask, make_segmentation_mask),
+            (F.get_size_video, make_video),
+        ],
+    )
+    def test_get_size(self, kernel, make_input):
+        size = (10, 10)
+
+        input = make_input(size)
+
+        assert kernel(input) == F.get_size(input) == list(size)
+
+    @pytest.mark.parametrize(
+        ("kernel", "make_input"),
+        [
+            (F.get_num_frames_video, make_video_tensor),
+            (F.get_num_frames_video, make_video),
+        ],
+    )
+    def test_get_num_frames(self, kernel, make_input):
+        num_frames = 4
+
+        input = make_input(num_frames=num_frames)
+
+        assert kernel(input) == F.get_num_frames(input) == num_frames
+
+    @pytest.mark.parametrize(
+        ("dispatcher", "make_input"),
+        [
+            (F.get_dimensions, make_bounding_box),
+            (F.get_dimensions, make_detection_mask),
+            (F.get_dimensions, make_segmentation_mask),
+            (F.get_num_channels, make_bounding_box),
+            (F.get_num_channels, make_detection_mask),
+            (F.get_num_channels, make_segmentation_mask),
+            (F.get_num_frames, make_image_pil),
+            (F.get_num_frames, make_image),
+            (F.get_num_frames, make_bounding_box),
+            (F.get_num_frames, make_detection_mask),
+            (F.get_num_frames, make_segmentation_mask),
+        ],
+    )
+    def test_unsupported_types(self, dispatcher, make_input):
+        input = make_input()
+
+        with pytest.raises(TypeError, match=re.escape(str(type(input)))):
+            dispatcher(input)
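Several of the new checks tighten themselves only when the registered entry exposes `__wrapped__`. That attribute is set automatically by `functools.wraps`, so a registry wrapper decorated with it can be traced back to the kernel it wraps. A minimal illustration:

    import functools

    def kernel(x):
        return x * 2

    @functools.wraps(kernel)
    def wrapper(x):
        # e.g. unwrap a datapoint, call the kernel, re-wrap the result
        return kernel(x)

    assert wrapper.__wrapped__ is kernel  # what the tests assert for registry entries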
test/transforms_v2_dispatcher_infos.py
@@ -69,14 +69,15 @@ class DispatcherInfo(InfoBase):
             import itertools
 
             for args_kwargs in sample_inputs:
-                for name in itertools.chain(
-                    datapoint_type.__annotations__.keys(),
-                    # FIXME: this seems ok for conversion dispatchers, but we should probably handle this on a
-                    # per-dispatcher level. However, so far there is no option for that.
-                    (f"old_{name}" for name in datapoint_type.__annotations__.keys()),
-                ):
-                    if name in args_kwargs.kwargs:
-                        del args_kwargs.kwargs[name]
+                if hasattr(datapoint_type, "__annotations__"):
+                    for name in itertools.chain(
+                        datapoint_type.__annotations__.keys(),
+                        # FIXME: this seems ok for conversion dispatchers, but we should probably handle this on a
+                        # per-dispatcher level. However, so far there is no option for that.
+                        (f"old_{name}" for name in datapoint_type.__annotations__.keys()),
+                    ):
+                        if name in args_kwargs.kwargs:
+                            del args_kwargs.kwargs[name]
 
                 yield args_kwargs
@@ -289,14 +290,6 @@ DISPATCHER_INFOS = [
             skip_dispatch_datapoint,
         ],
    ),
-    DispatcherInfo(
-        F.adjust_brightness,
-        kernels={
-            datapoints.Image: F.adjust_brightness_image_tensor,
-            datapoints.Video: F.adjust_brightness_video,
-        },
-        pil_kernel_info=PILKernelInfo(F.adjust_brightness_image_pil, kernel_name="adjust_brightness_image_pil"),
-    ),
     DispatcherInfo(
         F.adjust_contrast,
         kernels={
test/transforms_v2_kernel_infos.py
@@ -1259,46 +1259,6 @@ KERNEL_INFOS.extend(
     ]
 )
 
-_ADJUST_BRIGHTNESS_FACTORS = [0.1, 0.5]
-
-
-def sample_inputs_adjust_brightness_image_tensor():
-    for image_loader in make_image_loaders(sizes=[DEFAULT_PORTRAIT_SPATIAL_SIZE], color_spaces=("GRAY", "RGB")):
-        yield ArgsKwargs(image_loader, brightness_factor=_ADJUST_BRIGHTNESS_FACTORS[0])
-
-
-def reference_inputs_adjust_brightness_image_tensor():
-    for image_loader, brightness_factor in itertools.product(
-        make_image_loaders(color_spaces=("GRAY", "RGB"), extra_dims=[()], dtypes=[torch.uint8]),
-        _ADJUST_BRIGHTNESS_FACTORS,
-    ):
-        yield ArgsKwargs(image_loader, brightness_factor=brightness_factor)
-
-
-def sample_inputs_adjust_brightness_video():
-    for video_loader in make_video_loaders(sizes=[DEFAULT_PORTRAIT_SPATIAL_SIZE], num_frames=[3]):
-        yield ArgsKwargs(video_loader, brightness_factor=_ADJUST_BRIGHTNESS_FACTORS[0])
-
-
-KERNEL_INFOS.extend(
-    [
-        KernelInfo(
-            F.adjust_brightness_image_tensor,
-            kernel_name="adjust_brightness_image_tensor",
-            sample_inputs_fn=sample_inputs_adjust_brightness_image_tensor,
-            reference_fn=pil_reference_wrapper(F.adjust_brightness_image_pil),
-            reference_inputs_fn=reference_inputs_adjust_brightness_image_tensor,
-            float32_vs_uint8=True,
-            closeness_kwargs=float32_vs_uint8_pixel_difference(),
-        ),
-        KernelInfo(
-            F.adjust_brightness_video,
-            sample_inputs_fn=sample_inputs_adjust_brightness_video,
-        ),
-    ]
-)
-
 _ADJUST_CONTRAST_FACTORS = [0.1, 0.5]
torchvision/datapoints/__init__.py
 from torchvision import _BETA_TRANSFORMS_WARNING, _WARN_ABOUT_BETA_TRANSFORMS
 
 from ._bounding_box import BoundingBoxes, BoundingBoxFormat
-from ._datapoint import _FillType, _FillTypeJIT, _InputType, _InputTypeJIT
+from ._datapoint import _FillType, _FillTypeJIT, _InputType, _InputTypeJIT, Datapoint
 from ._image import _ImageType, _ImageTypeJIT, _TensorImageType, _TensorImageTypeJIT, Image
 from ._mask import Mask
 from ._video import _TensorVideoType, _TensorVideoTypeJIT, _VideoType, _VideoTypeJIT, Video
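With `Datapoint` re-exported here, callers no longer need to reach into the private `_datapoint` module; the `isinstance` check updated in `test/datasets_utils.py` above uses the public path. A usage sketch, assuming the `BoundingBoxes(data, format=..., canvas_size=...)` constructor that the rest of this diff implies:

    import torch
    from torchvision import datapoints

    box = datapoints.BoundingBoxes(
        torch.tensor([[0.0, 0.0, 5.0, 5.0]]),
        format=datapoints.BoundingBoxFormat.XYXY,
        canvas_size=(10, 10),
    )
    assert isinstance(box, datapoints.Datapoint)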
torchvision/datapoints/_bounding_box.py
 from __future__ import annotations
 
 from enum import Enum
-from typing import Any, List, Optional, Sequence, Tuple, Union
+from typing import Any, Optional, Tuple, Union
 
 import torch
 from torchvision.transforms import InterpolationMode  # TODO: this needs to be moved out of transforms
 
-from ._datapoint import _FillTypeJIT, Datapoint
+from ._datapoint import Datapoint
 
 
 class BoundingBoxFormat(Enum):
@@ -97,141 +96,3 @@ class BoundingBoxes(Datapoint):
     def __repr__(self, *, tensor_contents: Any = None) -> str:  # type: ignore[override]
         return self._make_repr(format=self.format, canvas_size=self.canvas_size)
-
-    def horizontal_flip(self) -> BoundingBoxes:
-        output = self._F.horizontal_flip_bounding_boxes(
-            self.as_subclass(torch.Tensor), format=self.format, canvas_size=self.canvas_size
-        )
-        return BoundingBoxes.wrap_like(self, output)
-
-    def vertical_flip(self) -> BoundingBoxes:
-        output = self._F.vertical_flip_bounding_boxes(
-            self.as_subclass(torch.Tensor), format=self.format, canvas_size=self.canvas_size
-        )
-        return BoundingBoxes.wrap_like(self, output)
-
-    def resize(  # type: ignore[override]
-        self, size: List[int], interpolation: Union[InterpolationMode, int] = InterpolationMode.BILINEAR,
-        max_size: Optional[int] = None, antialias: Optional[Union[str, bool]] = "warn",
-    ) -> BoundingBoxes:
-        output, canvas_size = self._F.resize_bounding_boxes(
-            self.as_subclass(torch.Tensor), canvas_size=self.canvas_size, size=size, max_size=max_size,
-        )
-        return BoundingBoxes.wrap_like(self, output, canvas_size=canvas_size)
-
-    def crop(self, top: int, left: int, height: int, width: int) -> BoundingBoxes:
-        output, canvas_size = self._F.crop_bounding_boxes(
-            self.as_subclass(torch.Tensor), self.format, top=top, left=left, height=height, width=width
-        )
-        return BoundingBoxes.wrap_like(self, output, canvas_size=canvas_size)
-
-    def center_crop(self, output_size: List[int]) -> BoundingBoxes:
-        output, canvas_size = self._F.center_crop_bounding_boxes(
-            self.as_subclass(torch.Tensor), format=self.format, canvas_size=self.canvas_size, output_size=output_size
-        )
-        return BoundingBoxes.wrap_like(self, output, canvas_size=canvas_size)
-
-    def resized_crop(
-        self, top: int, left: int, height: int, width: int, size: List[int],
-        interpolation: Union[InterpolationMode, int] = InterpolationMode.BILINEAR,
-        antialias: Optional[Union[str, bool]] = "warn",
-    ) -> BoundingBoxes:
-        output, canvas_size = self._F.resized_crop_bounding_boxes(
-            self.as_subclass(torch.Tensor), self.format, top, left, height, width, size=size
-        )
-        return BoundingBoxes.wrap_like(self, output, canvas_size=canvas_size)
-
-    def pad(
-        self, padding: Union[int, Sequence[int]], fill: Optional[Union[int, float, List[float]]] = None,
-        padding_mode: str = "constant",
-    ) -> BoundingBoxes:
-        output, canvas_size = self._F.pad_bounding_boxes(
-            self.as_subclass(torch.Tensor), format=self.format, canvas_size=self.canvas_size,
-            padding=padding, padding_mode=padding_mode,
-        )
-        return BoundingBoxes.wrap_like(self, output, canvas_size=canvas_size)
-
-    def rotate(
-        self, angle: float, interpolation: Union[InterpolationMode, int] = InterpolationMode.NEAREST,
-        expand: bool = False, center: Optional[List[float]] = None, fill: _FillTypeJIT = None,
-    ) -> BoundingBoxes:
-        output, canvas_size = self._F.rotate_bounding_boxes(
-            self.as_subclass(torch.Tensor), format=self.format, canvas_size=self.canvas_size,
-            angle=angle, expand=expand, center=center,
-        )
-        return BoundingBoxes.wrap_like(self, output, canvas_size=canvas_size)
-
-    def affine(
-        self, angle: Union[int, float], translate: List[float], scale: float, shear: List[float],
-        interpolation: Union[InterpolationMode, int] = InterpolationMode.NEAREST,
-        fill: _FillTypeJIT = None, center: Optional[List[float]] = None,
-    ) -> BoundingBoxes:
-        output = self._F.affine_bounding_boxes(
-            self.as_subclass(torch.Tensor), self.format, self.canvas_size, angle,
-            translate=translate, scale=scale, shear=shear, center=center,
-        )
-        return BoundingBoxes.wrap_like(self, output)
-
-    def perspective(
-        self, startpoints: Optional[List[List[int]]], endpoints: Optional[List[List[int]]],
-        interpolation: Union[InterpolationMode, int] = InterpolationMode.BILINEAR,
-        fill: _FillTypeJIT = None, coefficients: Optional[List[float]] = None,
-    ) -> BoundingBoxes:
-        output = self._F.perspective_bounding_boxes(
-            self.as_subclass(torch.Tensor), format=self.format, canvas_size=self.canvas_size,
-            startpoints=startpoints, endpoints=endpoints, coefficients=coefficients,
-        )
-        return BoundingBoxes.wrap_like(self, output)
-
-    def elastic(
-        self, displacement: torch.Tensor,
-        interpolation: Union[InterpolationMode, int] = InterpolationMode.BILINEAR,
-        fill: _FillTypeJIT = None,
-    ) -> BoundingBoxes:
-        output = self._F.elastic_bounding_boxes(
-            self.as_subclass(torch.Tensor), self.format, self.canvas_size, displacement=displacement
-        )
-        return BoundingBoxes.wrap_like(self, output)
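Note what the deletion implies for bounding boxes: a `BoundingBoxes` instance carries `format` and `canvas_size`, and that metadata now has to be passed to the kernels explicitly, which is exactly the `explicit_metadata = {datapoints.BoundingBoxes: {"format", "canvas_size"}}` filter in the signature tests above. A sketch of calling a box kernel directly, mirroring the removed `resize` method body (argument names as in the diff; treat it as illustrative rather than canonical API documentation):

    import torch
    from torchvision import datapoints
    from torchvision.transforms.v2 import functional as F

    boxes = datapoints.BoundingBoxes(
        torch.tensor([[2.0, 2.0, 8.0, 8.0]]),
        format=datapoints.BoundingBoxFormat.XYXY,
        canvas_size=(10, 10),
    )

    # The kernel takes the metadata spelled out and returns the new canvas size alongside the boxes.
    output, canvas_size = F.resize_bounding_boxes(
        boxes.as_subclass(torch.Tensor), canvas_size=boxes.canvas_size, size=[20, 20]
    )
    resized = datapoints.BoundingBoxes.wrap_like(boxes, output, canvas_size=canvas_size)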
torchvision/datapoints/_datapoint.py
 from __future__ import annotations
 
-from types import ModuleType
 from typing import Any, Callable, Dict, List, Mapping, Optional, Sequence, Tuple, Type, TypeVar, Union
 
 import PIL.Image
 import torch
 from torch._C import DisableTorchFunctionSubclass
 from torch.types import _device, _dtype, _size
-from torchvision.transforms import InterpolationMode
 
 
 D = TypeVar("D", bound="Datapoint")
@@ -16,8 +14,6 @@ _FillTypeJIT = Optional[List[float]]
 class Datapoint(torch.Tensor):
-    __F: Optional[ModuleType] = None
-
     @staticmethod
     def _to_tensor(
         data: Any,
@@ -99,18 +95,6 @@ class Datapoint(torch.Tensor):
         extra_repr = ", ".join(f"{key}={value}" for key, value in kwargs.items())
         return f"{super().__repr__()[:-1]}, {extra_repr})"
 
-    @property
-    def _F(self) -> ModuleType:
-        # This implements a lazy import of the functional to get around the cyclic import. This import is deferred
-        # until the first time we need reference to the functional module and it's shared across all instances of
-        # the class. This approach avoids the DataLoader issue described at
-        # https://github.com/pytorch/vision/pull/6476#discussion_r953588621
-        if Datapoint.__F is None:
-            from ..transforms.v2 import functional
-
-            Datapoint.__F = functional
-        return Datapoint.__F
-
     # Add properties for common attributes like shape, dtype, device, ndim etc
     # this way we return the result without passing into __torch_function__
     @property
@@ -142,128 +126,6 @@ class Datapoint(torch.Tensor):
         # `BoundingBoxes.clone()`.
         return self.detach().clone().requires_grad_(self.requires_grad)  # type: ignore[return-value]
 
-    def horizontal_flip(self) -> Datapoint:
-        return self
-
-    def vertical_flip(self) -> Datapoint:
-        return self
-
-    # TODO: We have to ignore override mypy error as there is torch.Tensor built-in deprecated op: Tensor.resize
-    # https://github.com/pytorch/pytorch/blob/e8727994eb7cdb2ab642749d6549bc497563aa06/torch/_tensor.py#L588-L593
-    def resize(  # type: ignore[override]
-        self, size: List[int], interpolation: Union[InterpolationMode, int] = InterpolationMode.BILINEAR,
-        max_size: Optional[int] = None, antialias: Optional[Union[str, bool]] = "warn",
-    ) -> Datapoint:
-        return self
-
-    def crop(self, top: int, left: int, height: int, width: int) -> Datapoint:
-        return self
-
-    def center_crop(self, output_size: List[int]) -> Datapoint:
-        return self
-
-    def resized_crop(
-        self, top: int, left: int, height: int, width: int, size: List[int],
-        interpolation: Union[InterpolationMode, int] = InterpolationMode.BILINEAR,
-        antialias: Optional[Union[str, bool]] = "warn",
-    ) -> Datapoint:
-        return self
-
-    def pad(
-        self, padding: List[int], fill: Optional[Union[int, float, List[float]]] = None,
-        padding_mode: str = "constant",
-    ) -> Datapoint:
-        return self
-
-    def rotate(
-        self, angle: float, interpolation: Union[InterpolationMode, int] = InterpolationMode.NEAREST,
-        expand: bool = False, center: Optional[List[float]] = None, fill: _FillTypeJIT = None,
-    ) -> Datapoint:
-        return self
-
-    def affine(
-        self, angle: Union[int, float], translate: List[float], scale: float, shear: List[float],
-        interpolation: Union[InterpolationMode, int] = InterpolationMode.NEAREST,
-        fill: _FillTypeJIT = None, center: Optional[List[float]] = None,
-    ) -> Datapoint:
-        return self
-
-    def perspective(
-        self, startpoints: Optional[List[List[int]]], endpoints: Optional[List[List[int]]],
-        interpolation: Union[InterpolationMode, int] = InterpolationMode.BILINEAR,
-        fill: _FillTypeJIT = None, coefficients: Optional[List[float]] = None,
-    ) -> Datapoint:
-        return self
-
-    def elastic(
-        self, displacement: torch.Tensor,
-        interpolation: Union[InterpolationMode, int] = InterpolationMode.BILINEAR,
-        fill: _FillTypeJIT = None,
-    ) -> Datapoint:
-        return self
-
-    def rgb_to_grayscale(self, num_output_channels: int = 1) -> Datapoint:
-        return self
-
-    def adjust_brightness(self, brightness_factor: float) -> Datapoint:
-        return self
-
-    def adjust_saturation(self, saturation_factor: float) -> Datapoint:
-        return self
-
-    def adjust_contrast(self, contrast_factor: float) -> Datapoint:
-        return self
-
-    def adjust_sharpness(self, sharpness_factor: float) -> Datapoint:
-        return self
-
-    def adjust_hue(self, hue_factor: float) -> Datapoint:
-        return self
-
-    def adjust_gamma(self, gamma: float, gain: float = 1) -> Datapoint:
-        return self
-
-    def posterize(self, bits: int) -> Datapoint:
-        return self
-
-    def solarize(self, threshold: float) -> Datapoint:
-        return self
-
-    def autocontrast(self) -> Datapoint:
-        return self
-
-    def equalize(self) -> Datapoint:
-        return self
-
-    def invert(self) -> Datapoint:
-        return self
-
-    def gaussian_blur(self, kernel_size: List[int], sigma: Optional[List[float]] = None) -> Datapoint:
-        return self
-
 
 _InputType = Union[torch.Tensor, PIL.Image.Image, Datapoint]
 _InputTypeJIT = torch.Tensor
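The deleted `_F` property was a lazy, class-level import of the functional module cached under the name-mangled attribute `_Datapoint__F`, which is why the removed test in test_transforms_v2_refactored.py had to patch `"_Datapoint__F"` rather than `__F`. The mangling itself is plain Python:

    class Datapoint:
        __F = None  # class attribute; Python stores it as _Datapoint__F

    assert not hasattr(Datapoint, "__F")         # the unmangled name is not visible from outside
    assert hasattr(Datapoint, "_Datapoint__F")   # this is the name the old test patched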
torchvision/datapoints/_image.py
 from __future__ import annotations
 
-from typing import Any, List, Optional, Union
+from typing import Any, Optional, Union
 
 import PIL.Image
 import torch
-from torchvision.transforms.functional import InterpolationMode
 
-from ._datapoint import _FillTypeJIT, Datapoint
+from ._datapoint import Datapoint
 
 
 class Image(Datapoint):
@@ -56,195 +55,6 @@ class Image(Datapoint):
     def __repr__(self, *, tensor_contents: Any = None) -> str:  # type: ignore[override]
         return self._make_repr()
-
-    def horizontal_flip(self) -> Image:
-        output = self._F.horizontal_flip_image_tensor(self.as_subclass(torch.Tensor))
-        return Image.wrap_like(self, output)
-
-    def vertical_flip(self) -> Image:
-        output = self._F.vertical_flip_image_tensor(self.as_subclass(torch.Tensor))
-        return Image.wrap_like(self, output)
-
-    def resize(  # type: ignore[override]
-        self, size: List[int], interpolation: Union[InterpolationMode, int] = InterpolationMode.BILINEAR,
-        max_size: Optional[int] = None, antialias: Optional[Union[str, bool]] = "warn",
-    ) -> Image:
-        output = self._F.resize_image_tensor(
-            self.as_subclass(torch.Tensor), size, interpolation=interpolation, max_size=max_size, antialias=antialias
-        )
-        return Image.wrap_like(self, output)
-
-    def crop(self, top: int, left: int, height: int, width: int) -> Image:
-        output = self._F.crop_image_tensor(self.as_subclass(torch.Tensor), top, left, height, width)
-        return Image.wrap_like(self, output)
-
-    def center_crop(self, output_size: List[int]) -> Image:
-        output = self._F.center_crop_image_tensor(self.as_subclass(torch.Tensor), output_size=output_size)
-        return Image.wrap_like(self, output)
-
-    def resized_crop(
-        self, top: int, left: int, height: int, width: int, size: List[int],
-        interpolation: Union[InterpolationMode, int] = InterpolationMode.BILINEAR,
-        antialias: Optional[Union[str, bool]] = "warn",
-    ) -> Image:
-        output = self._F.resized_crop_image_tensor(
-            self.as_subclass(torch.Tensor), top, left, height, width,
-            size=list(size), interpolation=interpolation, antialias=antialias,
-        )
-        return Image.wrap_like(self, output)
-
-    def pad(
-        self, padding: List[int], fill: Optional[Union[int, float, List[float]]] = None,
-        padding_mode: str = "constant",
-    ) -> Image:
-        output = self._F.pad_image_tensor(self.as_subclass(torch.Tensor), padding, fill=fill, padding_mode=padding_mode)
-        return Image.wrap_like(self, output)
-
-    def rotate(
-        self, angle: float, interpolation: Union[InterpolationMode, int] = InterpolationMode.NEAREST,
-        expand: bool = False, center: Optional[List[float]] = None, fill: _FillTypeJIT = None,
-    ) -> Image:
-        output = self._F.rotate_image_tensor(
-            self.as_subclass(torch.Tensor), angle, interpolation=interpolation, expand=expand, fill=fill, center=center
-        )
-        return Image.wrap_like(self, output)
-
-    def affine(
-        self, angle: Union[int, float], translate: List[float], scale: float, shear: List[float],
-        interpolation: Union[InterpolationMode, int] = InterpolationMode.NEAREST,
-        fill: _FillTypeJIT = None, center: Optional[List[float]] = None,
-    ) -> Image:
-        output = self._F.affine_image_tensor(
-            self.as_subclass(torch.Tensor), angle, translate=translate, scale=scale, shear=shear,
-            interpolation=interpolation, fill=fill, center=center,
-        )
-        return Image.wrap_like(self, output)
-
-    def perspective(
-        self, startpoints: Optional[List[List[int]]], endpoints: Optional[List[List[int]]],
-        interpolation: Union[InterpolationMode, int] = InterpolationMode.BILINEAR,
-        fill: _FillTypeJIT = None, coefficients: Optional[List[float]] = None,
-    ) -> Image:
-        output = self._F.perspective_image_tensor(
-            self.as_subclass(torch.Tensor), startpoints, endpoints,
-            interpolation=interpolation, fill=fill, coefficients=coefficients,
-        )
-        return Image.wrap_like(self, output)
-
-    def elastic(
-        self, displacement: torch.Tensor,
-        interpolation: Union[InterpolationMode, int] = InterpolationMode.BILINEAR,
-        fill: _FillTypeJIT = None,
-    ) -> Image:
-        output = self._F.elastic_image_tensor(
-            self.as_subclass(torch.Tensor), displacement, interpolation=interpolation, fill=fill
-        )
-        return Image.wrap_like(self, output)
-
-    def rgb_to_grayscale(self, num_output_channels: int = 1) -> Image:
-        output = self._F.rgb_to_grayscale_image_tensor(
-            self.as_subclass(torch.Tensor), num_output_channels=num_output_channels
-        )
-        return Image.wrap_like(self, output)
-
-    def adjust_brightness(self, brightness_factor: float) -> Image:
-        output = self._F.adjust_brightness_image_tensor(
-            self.as_subclass(torch.Tensor), brightness_factor=brightness_factor
-        )
-        return Image.wrap_like(self, output)
-
-    def adjust_saturation(self, saturation_factor: float) -> Image:
-        output = self._F.adjust_saturation_image_tensor(
-            self.as_subclass(torch.Tensor), saturation_factor=saturation_factor
-        )
-        return Image.wrap_like(self, output)
-
-    def adjust_contrast(self, contrast_factor: float) -> Image:
-        output = self._F.adjust_contrast_image_tensor(self.as_subclass(torch.Tensor), contrast_factor=contrast_factor)
-        return Image.wrap_like(self, output)
-
-    def adjust_sharpness(self, sharpness_factor: float) -> Image:
-        output = self._F.adjust_sharpness_image_tensor(
-            self.as_subclass(torch.Tensor), sharpness_factor=sharpness_factor
-        )
-        return Image.wrap_like(self, output)
-
-    def adjust_hue(self, hue_factor: float) -> Image:
-        output = self._F.adjust_hue_image_tensor(self.as_subclass(torch.Tensor), hue_factor=hue_factor)
-        return Image.wrap_like(self, output)
-
-    def adjust_gamma(self, gamma: float, gain: float = 1) -> Image:
-        output = self._F.adjust_gamma_image_tensor(self.as_subclass(torch.Tensor), gamma=gamma, gain=gain)
-        return Image.wrap_like(self, output)
-
-    def posterize(self, bits: int) -> Image:
-        output = self._F.posterize_image_tensor(self.as_subclass(torch.Tensor), bits=bits)
-        return Image.wrap_like(self, output)
-
-    def solarize(self, threshold: float) -> Image:
-        output = self._F.solarize_image_tensor(self.as_subclass(torch.Tensor), threshold=threshold)
-        return Image.wrap_like(self, output)
-
-    def autocontrast(self) -> Image:
-        output = self._F.autocontrast_image_tensor(self.as_subclass(torch.Tensor))
-        return Image.wrap_like(self, output)
-
-    def equalize(self) -> Image:
-        output = self._F.equalize_image_tensor(self.as_subclass(torch.Tensor))
-        return Image.wrap_like(self, output)
-
-    def invert(self) -> Image:
-        output = self._F.invert_image_tensor(self.as_subclass(torch.Tensor))
-        return Image.wrap_like(self, output)
-
-    def gaussian_blur(self, kernel_size: List[int], sigma: Optional[List[float]] = None) -> Image:
-        output = self._F.gaussian_blur_image_tensor(
-            self.as_subclass(torch.Tensor), kernel_size=kernel_size, sigma=sigma
-        )
-        return Image.wrap_like(self, output)
-
-    def normalize(self, mean: List[float], std: List[float], inplace: bool = False) -> Image:
-        output = self._F.normalize_image_tensor(self.as_subclass(torch.Tensor), mean=mean, std=std, inplace=inplace)
-        return Image.wrap_like(self, output)
 
 
 _ImageType = Union[torch.Tensor, PIL.Image.Image, Image]
 _ImageTypeJIT = torch.Tensor
torchvision/datapoints/_mask.py
 from __future__ import annotations
 
-from typing import Any, List, Optional, Union
+from typing import Any, Optional, Union
 
 import PIL.Image
 import torch
-from torchvision.transforms import InterpolationMode
 
-from ._datapoint import _FillTypeJIT, Datapoint
+from ._datapoint import Datapoint
 
 
 class Mask(Datapoint):
@@ -50,105 +49,3 @@ class Mask(Datapoint):
         tensor: torch.Tensor,
     ) -> Mask:
         return cls._wrap(tensor)
-
-    def horizontal_flip(self) -> Mask:
-        output = self._F.horizontal_flip_mask(self.as_subclass(torch.Tensor))
-        return Mask.wrap_like(self, output)
-
-    def vertical_flip(self) -> Mask:
-        output = self._F.vertical_flip_mask(self.as_subclass(torch.Tensor))
-        return Mask.wrap_like(self, output)
-
-    def resize(  # type: ignore[override]
-        self, size: List[int], interpolation: Union[InterpolationMode, int] = InterpolationMode.NEAREST,
-        max_size: Optional[int] = None, antialias: Optional[Union[str, bool]] = "warn",
-    ) -> Mask:
-        output = self._F.resize_mask(self.as_subclass(torch.Tensor), size, max_size=max_size)
-        return Mask.wrap_like(self, output)
-
-    def crop(self, top: int, left: int, height: int, width: int) -> Mask:
-        output = self._F.crop_mask(self.as_subclass(torch.Tensor), top, left, height, width)
-        return Mask.wrap_like(self, output)
-
-    def center_crop(self, output_size: List[int]) -> Mask:
-        output = self._F.center_crop_mask(self.as_subclass(torch.Tensor), output_size=output_size)
-        return Mask.wrap_like(self, output)
-
-    def resized_crop(
-        self, top: int, left: int, height: int, width: int, size: List[int],
-        interpolation: Union[InterpolationMode, int] = InterpolationMode.NEAREST,
-        antialias: Optional[Union[str, bool]] = "warn",
-    ) -> Mask:
-        output = self._F.resized_crop_mask(self.as_subclass(torch.Tensor), top, left, height, width, size=size)
-        return Mask.wrap_like(self, output)
-
-    def pad(
-        self, padding: List[int], fill: Optional[Union[int, float, List[float]]] = None,
-        padding_mode: str = "constant",
-    ) -> Mask:
-        output = self._F.pad_mask(self.as_subclass(torch.Tensor), padding, padding_mode=padding_mode, fill=fill)
-        return Mask.wrap_like(self, output)
-
-    def rotate(
-        self, angle: float, interpolation: Union[InterpolationMode, int] = InterpolationMode.NEAREST,
-        expand: bool = False, center: Optional[List[float]] = None, fill: _FillTypeJIT = None,
-    ) -> Mask:
-        output = self._F.rotate_mask(self.as_subclass(torch.Tensor), angle, expand=expand, center=center, fill=fill)
-        return Mask.wrap_like(self, output)
-
-    def affine(
-        self, angle: Union[int, float], translate: List[float], scale: float, shear: List[float],
-        interpolation: Union[InterpolationMode, int] = InterpolationMode.NEAREST,
-        fill: _FillTypeJIT = None, center: Optional[List[float]] = None,
-    ) -> Mask:
-        output = self._F.affine_mask(
-            self.as_subclass(torch.Tensor), angle, translate=translate, scale=scale, shear=shear,
-            fill=fill, center=center,
-        )
-        return Mask.wrap_like(self, output)
-
-    def perspective(
-        self, startpoints: Optional[List[List[int]]], endpoints: Optional[List[List[int]]],
-        interpolation: Union[InterpolationMode, int] = InterpolationMode.NEAREST,
-        fill: _FillTypeJIT = None, coefficients: Optional[List[float]] = None,
-    ) -> Mask:
-        output = self._F.perspective_mask(
-            self.as_subclass(torch.Tensor), startpoints, endpoints, fill=fill, coefficients=coefficients
-        )
-        return Mask.wrap_like(self, output)
-
-    def elastic(
-        self, displacement: torch.Tensor,
-        interpolation: Union[InterpolationMode, int] = InterpolationMode.NEAREST,
-        fill: _FillTypeJIT = None,
-    ) -> Mask:
-        output = self._F.elastic_mask(self.as_subclass(torch.Tensor), displacement, fill=fill)
-        return Mask.wrap_like(self, output)
torchvision/datapoints/_video.py
View file @
a893f313
from
__future__
import
annotations
from
typing
import
Any
,
List
,
Optional
,
Union
from
typing
import
Any
,
Optional
,
Union
import
torch
from
torchvision.transforms.functional
import
InterpolationMode
from
._datapoint
import
_FillTypeJIT
,
Datapoint
from
._datapoint
import
Datapoint
class
Video
(
Datapoint
):
...
...
@@ -46,191 +45,6 @@ class Video(Datapoint):
def
__repr__
(
self
,
*
,
tensor_contents
:
Any
=
None
)
->
str
:
# type: ignore[override]
return
self
.
_make_repr
()
def
horizontal_flip
(
self
)
->
Video
:
output
=
self
.
_F
.
horizontal_flip_video
(
self
.
as_subclass
(
torch
.
Tensor
))
return
Video
.
wrap_like
(
self
,
output
)
def
vertical_flip
(
self
)
->
Video
:
output
=
self
.
_F
.
vertical_flip_video
(
self
.
as_subclass
(
torch
.
Tensor
))
return
Video
.
wrap_like
(
self
,
output
)
def
resize
(
# type: ignore[override]
self
,
size
:
List
[
int
],
interpolation
:
Union
[
InterpolationMode
,
int
]
=
InterpolationMode
.
BILINEAR
,
max_size
:
Optional
[
int
]
=
None
,
antialias
:
Optional
[
Union
[
str
,
bool
]]
=
"warn"
,
)
->
Video
:
output
=
self
.
_F
.
resize_video
(
self
.
as_subclass
(
torch
.
Tensor
),
size
,
interpolation
=
interpolation
,
max_size
=
max_size
,
antialias
=
antialias
,
)
return
Video
.
wrap_like
(
self
,
output
)
def
crop
(
self
,
top
:
int
,
left
:
int
,
height
:
int
,
width
:
int
)
->
Video
:
output
=
self
.
_F
.
crop_video
(
self
.
as_subclass
(
torch
.
Tensor
),
top
,
left
,
height
,
width
)
return
Video
.
wrap_like
(
self
,
output
)
def
center_crop
(
self
,
output_size
:
List
[
int
])
->
Video
:
output
=
self
.
_F
.
center_crop_video
(
self
.
as_subclass
(
torch
.
Tensor
),
output_size
=
output_size
)
return
Video
.
wrap_like
(
self
,
output
)
def
resized_crop
(
self
,
top
:
int
,
left
:
int
,
height
:
int
,
width
:
int
,
size
:
List
[
int
],
interpolation
:
Union
[
InterpolationMode
,
int
]
=
InterpolationMode
.
BILINEAR
,
antialias
:
Optional
[
Union
[
str
,
bool
]]
=
"warn"
,
)
->
Video
:
output
=
self
.
_F
.
resized_crop_video
(
self
.
as_subclass
(
torch
.
Tensor
),
top
,
left
,
height
,
width
,
size
=
list
(
size
),
interpolation
=
interpolation
,
antialias
=
antialias
,
)
return
Video
.
wrap_like
(
self
,
output
)
def
pad
(
self
,
padding
:
List
[
int
],
fill
:
Optional
[
Union
[int, float, List[float]]] = None,
        padding_mode: str = "constant",
    ) -> Video:
        output = self._F.pad_video(self.as_subclass(torch.Tensor), padding, fill=fill, padding_mode=padding_mode)
        return Video.wrap_like(self, output)

    def rotate(
        self,
        angle: float,
        interpolation: Union[InterpolationMode, int] = InterpolationMode.NEAREST,
        expand: bool = False,
        center: Optional[List[float]] = None,
        fill: _FillTypeJIT = None,
    ) -> Video:
        output = self._F.rotate_video(
            self.as_subclass(torch.Tensor), angle, interpolation=interpolation, expand=expand, fill=fill, center=center
        )
        return Video.wrap_like(self, output)

    def affine(
        self,
        angle: Union[int, float],
        translate: List[float],
        scale: float,
        shear: List[float],
        interpolation: Union[InterpolationMode, int] = InterpolationMode.NEAREST,
        fill: _FillTypeJIT = None,
        center: Optional[List[float]] = None,
    ) -> Video:
        output = self._F.affine_video(
            self.as_subclass(torch.Tensor),
            angle,
            translate=translate,
            scale=scale,
            shear=shear,
            interpolation=interpolation,
            fill=fill,
            center=center,
        )
        return Video.wrap_like(self, output)

    def perspective(
        self,
        startpoints: Optional[List[List[int]]],
        endpoints: Optional[List[List[int]]],
        interpolation: Union[InterpolationMode, int] = InterpolationMode.BILINEAR,
        fill: _FillTypeJIT = None,
        coefficients: Optional[List[float]] = None,
    ) -> Video:
        output = self._F.perspective_video(
            self.as_subclass(torch.Tensor),
            startpoints,
            endpoints,
            interpolation=interpolation,
            fill=fill,
            coefficients=coefficients,
        )
        return Video.wrap_like(self, output)

    def elastic(
        self,
        displacement: torch.Tensor,
        interpolation: Union[InterpolationMode, int] = InterpolationMode.BILINEAR,
        fill: _FillTypeJIT = None,
    ) -> Video:
        output = self._F.elastic_video(
            self.as_subclass(torch.Tensor), displacement, interpolation=interpolation, fill=fill
        )
        return Video.wrap_like(self, output)

    def rgb_to_grayscale(self, num_output_channels: int = 1) -> Video:
        output = self._F.rgb_to_grayscale_image_tensor(
            self.as_subclass(torch.Tensor), num_output_channels=num_output_channels
        )
        return Video.wrap_like(self, output)

    def adjust_brightness(self, brightness_factor: float) -> Video:
        output = self._F.adjust_brightness_video(self.as_subclass(torch.Tensor), brightness_factor=brightness_factor)
        return Video.wrap_like(self, output)

    def adjust_saturation(self, saturation_factor: float) -> Video:
        output = self._F.adjust_saturation_video(self.as_subclass(torch.Tensor), saturation_factor=saturation_factor)
        return Video.wrap_like(self, output)

    def adjust_contrast(self, contrast_factor: float) -> Video:
        output = self._F.adjust_contrast_video(self.as_subclass(torch.Tensor), contrast_factor=contrast_factor)
        return Video.wrap_like(self, output)

    def adjust_sharpness(self, sharpness_factor: float) -> Video:
        output = self._F.adjust_sharpness_video(self.as_subclass(torch.Tensor), sharpness_factor=sharpness_factor)
        return Video.wrap_like(self, output)

    def adjust_hue(self, hue_factor: float) -> Video:
        output = self._F.adjust_hue_video(self.as_subclass(torch.Tensor), hue_factor=hue_factor)
        return Video.wrap_like(self, output)

    def adjust_gamma(self, gamma: float, gain: float = 1) -> Video:
        output = self._F.adjust_gamma_video(self.as_subclass(torch.Tensor), gamma=gamma, gain=gain)
        return Video.wrap_like(self, output)

    def posterize(self, bits: int) -> Video:
        output = self._F.posterize_video(self.as_subclass(torch.Tensor), bits=bits)
        return Video.wrap_like(self, output)

    def solarize(self, threshold: float) -> Video:
        output = self._F.solarize_video(self.as_subclass(torch.Tensor), threshold=threshold)
        return Video.wrap_like(self, output)

    def autocontrast(self) -> Video:
        output = self._F.autocontrast_video(self.as_subclass(torch.Tensor))
        return Video.wrap_like(self, output)

    def equalize(self) -> Video:
        output = self._F.equalize_video(self.as_subclass(torch.Tensor))
        return Video.wrap_like(self, output)

    def invert(self) -> Video:
        output = self._F.invert_video(self.as_subclass(torch.Tensor))
        return Video.wrap_like(self, output)

    def gaussian_blur(self, kernel_size: List[int], sigma: Optional[List[float]] = None) -> Video:
        output = self._F.gaussian_blur_video(self.as_subclass(torch.Tensor), kernel_size=kernel_size, sigma=sigma)
        return Video.wrap_like(self, output)

    def normalize(self, mean: List[float], std: List[float], inplace: bool = False) -> Video:
        output = self._F.normalize_video(self.as_subclass(torch.Tensor), mean=mean, std=std, inplace=inplace)
        return Video.wrap_like(self, output)


_VideoType = Union[torch.Tensor, Video]
_VideoTypeJIT = torch.Tensor
...
...
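Note: the block above is the bulk of the 188 lines deleted from torchvision/datapoints/_video.py; every Video method was a thin wrapper around a *_video kernel plus wrap_like. After this refactor, callers reach the same kernels through the functional namespace. A minimal sketch (my illustration, not part of the diff; assumes a torchvision build at this commit):

import torch
from torchvision import datapoints
from torchvision.transforms.v2 import functional as F

video = datapoints.Video(torch.randint(0, 256, (4, 3, 32, 32), dtype=torch.uint8))

# Previously this routed through Video.pad -> self._F.pad_video -> Video.wrap_like.
# Now F.pad looks the kernel up by input type and re-wraps the result itself.
padded = F.pad(video, padding=[2, 2])
print(type(padded).__name__)  # expected: Video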
torchvision/transforms/v2/_augment.py
View file @ a893f313
import math
import numbers
import warnings
-from typing import Any, Dict, List, Tuple, Union
+from typing import Any, Dict, List, Tuple

import PIL.Image
import torch
...
...
@@ -56,8 +56,6 @@ class RandomErasing(_RandomApplyTransform):
            value="random" if self.value is None else self.value,
        )

-    _transformed_types = (is_simple_tensor, datapoints.Image, PIL.Image.Image, datapoints.Video)
-
    def __init__(
        self,
        p: float = 0.5,
...
...
@@ -131,9 +129,7 @@ class RandomErasing(_RandomApplyTransform):
        return dict(i=i, j=j, h=h, w=w, v=v)

-    def _transform(
-        self, inpt: Union[datapoints._ImageType, datapoints._VideoType], params: Dict[str, Any]
-    ) -> Union[datapoints._ImageType, datapoints._VideoType]:
+    def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
        if params["v"] is not None:
            inpt = F.erase(inpt, **params, inplace=self.inplace)
...
...
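The signature change above is the transform-side half of the refactor: RandomErasing._transform no longer enumerates supported types, because F.erase now dispatches by input type (and no-ops with a warning for Mask/BoundingBoxes, see _augment.py further down). Usage is unchanged; a sketch assuming this commit:

import torch
from torchvision import datapoints
from torchvision.transforms import v2

transform = v2.RandomErasing(p=1.0)
img = datapoints.Image(torch.rand(3, 32, 32))
out = transform(img)  # type handling happens inside F.erase, keyed on type(img)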
torchvision/transforms/v2/_geometry.py
View file @ a893f313
...
...
@@ -355,20 +355,11 @@ class FiveCrop(Transform):
    _v1_transform_cls = _transforms.FiveCrop

-    _transformed_types = (
-        datapoints.Image,
-        PIL.Image.Image,
-        is_simple_tensor,
-        datapoints.Video,
-    )
-
    def __init__(self, size: Union[int, Sequence[int]]) -> None:
        super().__init__()
        self.size = _setup_size(size, error_msg="Please provide only two dimensions (h, w) for size.")

-    def _transform(
-        self, inpt: ImageOrVideoTypeJIT, params: Dict[str, Any]
-    ) -> Tuple[ImageOrVideoTypeJIT, ImageOrVideoTypeJIT, ImageOrVideoTypeJIT, ImageOrVideoTypeJIT, ImageOrVideoTypeJIT]:
+    def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
        return F.five_crop(inpt, self.size)

    def _check_inputs(self, flat_inputs: List[Any]) -> None:
...
...
@@ -402,13 +393,6 @@ class TenCrop(Transform):
    _v1_transform_cls = _transforms.TenCrop

-    _transformed_types = (
-        datapoints.Image,
-        PIL.Image.Image,
-        is_simple_tensor,
-        datapoints.Video,
-    )
-
    def __init__(self, size: Union[int, Sequence[int]], vertical_flip: bool = False) -> None:
        super().__init__()
        self.size = _setup_size(size, error_msg="Please provide only two dimensions (h, w) for size.")
...
...
@@ -418,20 +402,7 @@ class TenCrop(Transform):
        if has_any(flat_inputs, datapoints.BoundingBoxes, datapoints.Mask):
            raise TypeError(f"BoundingBoxes'es and Mask's are not supported by {type(self).__name__}()")

-    def _transform(
-        self, inpt: Union[datapoints._ImageType, datapoints._VideoType], params: Dict[str, Any]
-    ) -> Tuple[
-        ImageOrVideoTypeJIT,
-        ImageOrVideoTypeJIT,
-        ImageOrVideoTypeJIT,
-        ImageOrVideoTypeJIT,
-        ImageOrVideoTypeJIT,
-        ImageOrVideoTypeJIT,
-        ImageOrVideoTypeJIT,
-        ImageOrVideoTypeJIT,
-        ImageOrVideoTypeJIT,
-        ImageOrVideoTypeJIT,
-    ]:
+    def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
        return F.ten_crop(inpt, self.size, vertical_flip=self.vertical_flip)
...
...
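Same pattern for FiveCrop/TenCrop: only the explicit _check_inputs guard against BoundingBoxes and Mask remains, while image/video handling moves into the functional. Illustrative only:

import torch
from torchvision.transforms import v2

five = v2.FiveCrop(size=(8, 8))
crops = five(torch.rand(3, 16, 16))  # plain tensors are fine; returns a 5-tuple of crops
assert len(crops) == 5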
torchvision/transforms/v2/_temporal.py
View file @ a893f313
from typing import Any, Dict

+import torch
-from torchvision import datapoints
from torchvision.transforms.v2 import functional as F, Transform
-from torchvision.transforms.v2.utils import is_simple_tensor


class UniformTemporalSubsample(Transform):
    """[BETA] Uniformly subsample ``num_samples`` indices from the temporal dimension of the video.
...
...
@@ -20,7 +19,7 @@ class UniformTemporalSubsample(Transform):
        num_samples (int): The number of equispaced samples to be selected
    """

-    _transformed_types = (is_simple_tensor, datapoints.Video)
+    _transformed_types = (torch.Tensor,)

    def __init__(self, num_samples: int):
        super().__init__()
...
...
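With _transformed_types = (torch.Tensor,), UniformTemporalSubsample accepts both plain tensors and Video datapoints and leaves the per-type decision to the functional. Sketch:

import torch
from torchvision.transforms import v2

subsample = v2.UniformTemporalSubsample(num_samples=4)
clip = torch.rand(8, 3, 16, 16)  # T x C x H x W; a simple tensor now passes the type filter
out = subsample(clip)
assert out.shape[0] == 4  # 4 equispaced frames kept along the temporal dimension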
torchvision/transforms/v2/functional/__init__.py
View file @ a893f313
from torchvision.transforms import InterpolationMode  # usort: skip

-from ._utils import is_simple_tensor  # usort: skip
+from ._utils import is_simple_tensor, register_kernel  # usort: skip

from ._meta import (
    clamp_bounding_boxes,
    convert_format_bounding_boxes,
    get_dimensions_image_tensor,
    get_dimensions_image_pil,
    get_dimensions_video,
    get_dimensions,
    get_num_frames_video,
    get_num_frames,
...
...
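The one-line import change above is the user-facing piece of the refactor: register_kernel is now re-exported from transforms.v2.functional. A hedged sketch of how it is meant to be used (the decorator-factory call form is my assumption, mirroring _register_kernel_internal below):

import torch
from torchvision import datapoints
from torchvision.transforms.v2 import functional as F

class MyDatapoint(datapoints.Datapoint):  # hypothetical custom datapoint
    pass

@F.register_kernel(F.horizontal_flip, MyDatapoint)  # assumed call form
def my_horizontal_flip(inpt, *args, **kwargs):
    # custom behavior; here we just reuse the plain-tensor semantics
    return inpt.flip(-1)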
torchvision/transforms/v2/functional/_augment.py
View file @ a893f313
...
...
@@ -7,9 +7,37 @@ from torchvision import datapoints
from torchvision.transforms.functional import pil_to_tensor, to_pil_image
from torchvision.utils import _log_api_usage_once

-from ._utils import is_simple_tensor
+from ._utils import _get_kernel, _register_explicit_noop, _register_kernel_internal, is_simple_tensor
+
+
+@_register_explicit_noop(datapoints.Mask, datapoints.BoundingBoxes, warn_passthrough=True)
+def erase(
+    inpt: Union[datapoints._ImageTypeJIT, datapoints._VideoTypeJIT],
+    i: int,
+    j: int,
+    h: int,
+    w: int,
+    v: torch.Tensor,
+    inplace: bool = False,
+) -> Union[datapoints._ImageTypeJIT, datapoints._VideoTypeJIT]:
+    if not torch.jit.is_scripting():
+        _log_api_usage_once(erase)
+
+    if torch.jit.is_scripting() or is_simple_tensor(inpt):
+        return erase_image_tensor(inpt, i=i, j=j, h=h, w=w, v=v, inplace=inplace)
+    elif isinstance(inpt, datapoints.Datapoint):
+        kernel = _get_kernel(erase, type(inpt))
+        return kernel(inpt, i=i, j=j, h=h, w=w, v=v, inplace=inplace)
+    elif isinstance(inpt, PIL.Image.Image):
+        return erase_image_pil(inpt, i=i, j=j, h=h, w=w, v=v, inplace=inplace)
+    else:
+        raise TypeError(
+            f"Input can either be a plain tensor, any TorchVision datapoint, or a PIL image, "
+            f"but got {type(inpt)} instead."
+        )
+
+
+@_register_kernel_internal(erase, datapoints.Image)
def erase_image_tensor(
    image: torch.Tensor, i: int, j: int, h: int, w: int, v: torch.Tensor, inplace: bool = False
) -> torch.Tensor:
...
...
@@ -29,36 +57,8 @@ def erase_image_pil(
    return to_pil_image(output, mode=image.mode)


+@_register_kernel_internal(erase, datapoints.Video)
def erase_video(
    video: torch.Tensor, i: int, j: int, h: int, w: int, v: torch.Tensor, inplace: bool = False
) -> torch.Tensor:
    return erase_image_tensor(video, i=i, j=j, h=h, w=w, v=v, inplace=inplace)
-
-
-def erase(
-    inpt: Union[datapoints._ImageTypeJIT, datapoints._VideoTypeJIT],
-    i: int,
-    j: int,
-    h: int,
-    w: int,
-    v: torch.Tensor,
-    inplace: bool = False,
-) -> Union[datapoints._ImageTypeJIT, datapoints._VideoTypeJIT]:
-    if not torch.jit.is_scripting():
-        _log_api_usage_once(erase)
-
-    if torch.jit.is_scripting() or is_simple_tensor(inpt):
-        return erase_image_tensor(inpt, i=i, j=j, h=h, w=w, v=v, inplace=inplace)
-    elif isinstance(inpt, datapoints.Image):
-        output = erase_image_tensor(inpt.as_subclass(torch.Tensor), i=i, j=j, h=h, w=w, v=v, inplace=inplace)
-        return datapoints.Image.wrap_like(inpt, output)
-    elif isinstance(inpt, datapoints.Video):
-        output = erase_video(inpt.as_subclass(torch.Tensor), i=i, j=j, h=h, w=w, v=v, inplace=inplace)
-        return datapoints.Video.wrap_like(inpt, output)
-    elif isinstance(inpt, PIL.Image.Image):
-        return erase_image_pil(inpt, i=i, j=j, h=h, w=w, v=v, inplace=inplace)
-    else:
-        raise TypeError(
-            f"Input can either be a plain tensor, an `Image` or `Video` datapoint, or a PIL image, "
-            f"but got {type(inpt)} instead."
-        )
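For intuition, the registry machinery used above can be reduced to a few lines. This is an illustrative re-implementation, not the code in ._utils; the real helpers also handle wrap_like re-wrapping and the warn-on-passthrough no-ops:

_KERNEL_REGISTRY = {}  # {dispatcher: {datapoint_type: kernel}}

def _register_kernel_internal(dispatcher, datapoint_cls):
    def decorator(kernel):
        _KERNEL_REGISTRY.setdefault(dispatcher, {})[datapoint_cls] = kernel
        return kernel
    return decorator

def _get_kernel(dispatcher, datapoint_type):
    registry = _KERNEL_REGISTRY[dispatcher]
    for cls in datapoint_type.__mro__:  # lets subclasses inherit registered kernels
        if cls in registry:
            return registry[cls]
    raise TypeError(f"No kernel registered for {datapoint_type.__name__}")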
torchvision/transforms/v2/functional/_color.py
View file @ a893f313
...
...
@@ -10,7 +10,34 @@ from torchvision.transforms._functional_tensor import _max_value
from torchvision.utils import _log_api_usage_once

from ._misc import _num_value_bits, to_dtype_image_tensor
-from ._utils import is_simple_tensor
+from ._utils import _get_kernel, _register_explicit_noop, _register_kernel_internal, is_simple_tensor
+
+
+@_register_explicit_noop(datapoints.BoundingBoxes, datapoints.Mask, datapoints.Video)
+def rgb_to_grayscale(
+    inpt: Union[datapoints._ImageTypeJIT, datapoints._VideoTypeJIT], num_output_channels: int = 1
+) -> Union[datapoints._ImageTypeJIT, datapoints._VideoTypeJIT]:
+    if not torch.jit.is_scripting():
+        _log_api_usage_once(rgb_to_grayscale)
+
+    if num_output_channels not in (1, 3):
+        raise ValueError(f"num_output_channels must be 1 or 3, got {num_output_channels}.")
+
+    if torch.jit.is_scripting() or is_simple_tensor(inpt):
+        return rgb_to_grayscale_image_tensor(inpt, num_output_channels=num_output_channels)
+    elif isinstance(inpt, datapoints.Datapoint):
+        kernel = _get_kernel(rgb_to_grayscale, type(inpt))
+        return kernel(inpt, num_output_channels=num_output_channels)
+    elif isinstance(inpt, PIL.Image.Image):
+        return rgb_to_grayscale_image_pil(inpt, num_output_channels=num_output_channels)
+    else:
+        raise TypeError(
+            f"Input can either be a plain tensor, any TorchVision datapoint, or a PIL image, "
+            f"but got {type(inpt)} instead."
+        )
+
+
+# `to_grayscale` actually predates `rgb_to_grayscale` in v1, but only handles PIL images. Since `rgb_to_grayscale` is a
+# superset in terms of functionality and has the same signature, we alias here to avoid disruption.
+to_grayscale = rgb_to_grayscale


def _rgb_to_grayscale_image_tensor(
...
...
@@ -29,6 +56,7 @@ def _rgb_to_grayscale_image_tensor(
    return l_img


+@_register_kernel_internal(rgb_to_grayscale, datapoints.Image)
def rgb_to_grayscale_image_tensor(image: torch.Tensor, num_output_channels: int = 1) -> torch.Tensor:
    return _rgb_to_grayscale_image_tensor(image, num_output_channels=num_output_channels, preserve_dtype=True)
...
...
@@ -36,19 +64,26 @@ def rgb_to_grayscale_image_tensor(image: torch.Tensor, num_output_channels: int
rgb_to_grayscale_image_pil = _FP.to_grayscale


-def rgb_to_grayscale(
-    inpt: Union[datapoints._ImageTypeJIT, datapoints._VideoTypeJIT], num_output_channels: int = 1
-) -> Union[datapoints._ImageTypeJIT, datapoints._VideoTypeJIT]:
+def _blend(image1: torch.Tensor, image2: torch.Tensor, ratio: float) -> torch.Tensor:
+    ratio = float(ratio)
+    fp = image1.is_floating_point()
+    bound = _max_value(image1.dtype)
+    output = image1.mul(ratio).add_(image2, alpha=(1.0 - ratio)).clamp_(0, bound)
+    return output if fp else output.to(image1.dtype)
+
+
+@_register_explicit_noop(datapoints.BoundingBoxes, datapoints.Mask)
+def adjust_brightness(inpt: datapoints._InputTypeJIT, brightness_factor: float) -> datapoints._InputTypeJIT:
    if not torch.jit.is_scripting():
-        _log_api_usage_once(rgb_to_grayscale)
-
-    if num_output_channels not in (1, 3):
-        raise ValueError(f"num_output_channels must be 1 or 3, got {num_output_channels}.")
+        _log_api_usage_once(adjust_brightness)

    if torch.jit.is_scripting() or is_simple_tensor(inpt):
-        return rgb_to_grayscale_image_tensor(inpt, num_output_channels=num_output_channels)
-    elif isinstance(inpt, datapoints._datapoint.Datapoint):
-        return inpt.rgb_to_grayscale(num_output_channels=num_output_channels)
+        return adjust_brightness_image_tensor(inpt, brightness_factor=brightness_factor)
+    elif isinstance(inpt, datapoints.Datapoint):
+        kernel = _get_kernel(adjust_brightness, type(inpt))
+        return kernel(inpt, brightness_factor=brightness_factor)
    elif isinstance(inpt, PIL.Image.Image):
-        return rgb_to_grayscale_image_pil(inpt, num_output_channels=num_output_channels)
+        return adjust_brightness_image_pil(inpt, brightness_factor=brightness_factor)
    else:
        raise TypeError(
            f"Input can either be a plain tensor, any TorchVision datapoint, or a PIL image, "
...
...
@@ -56,19 +91,7 @@ def rgb_to_grayscale(
        )


-# `to_grayscale` actually predates `rgb_to_grayscale` in v1, but only handles PIL images. Since `rgb_to_grayscale` is a
-# superset in terms of functionality and has the same signature, we alias here to avoid disruption.
-to_grayscale = rgb_to_grayscale
-
-
-def _blend(image1: torch.Tensor, image2: torch.Tensor, ratio: float) -> torch.Tensor:
-    ratio = float(ratio)
-    fp = image1.is_floating_point()
-    bound = _max_value(image1.dtype)
-    output = image1.mul(ratio).add_(image2, alpha=(1.0 - ratio)).clamp_(0, bound)
-    return output if fp else output.to(image1.dtype)
-
-
+@_register_kernel_internal(adjust_brightness, datapoints.Image)
def adjust_brightness_image_tensor(image: torch.Tensor, brightness_factor: float) -> torch.Tensor:
    if brightness_factor < 0:
        raise ValueError(f"brightness_factor ({brightness_factor}) is not non-negative.")
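The @_register_explicit_noop(datapoints.BoundingBoxes, datapoints.Mask) decorator above makes the pass-through behavior explicit instead of relying on type tuples inside each transform. Sketch of the expected effect (illustrative; the exact pass-through semantics live in ._utils):

import torch
from torchvision import datapoints
from torchvision.transforms.v2 import functional as F

boxes = datapoints.BoundingBoxes(
    torch.tensor([[0, 0, 5, 5]]), format="XYXY", canvas_size=(10, 10)
)
out = F.adjust_brightness(boxes, brightness_factor=2.0)
# expected: `out` is the boxes unchanged -- brightness has no meaning for coordinates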
...
...
@@ -83,23 +106,27 @@ def adjust_brightness_image_tensor(image: torch.Tensor, brightness_factor: float
    return output if fp else output.to(image.dtype)


-adjust_brightness_image_pil = _FP.adjust_brightness
+def adjust_brightness_image_pil(image: PIL.Image.Image, brightness_factor: float) -> PIL.Image.Image:
+    return _FP.adjust_brightness(image, brightness_factor=brightness_factor)


+@_register_kernel_internal(adjust_brightness, datapoints.Video)
def adjust_brightness_video(video: torch.Tensor, brightness_factor: float) -> torch.Tensor:
    return adjust_brightness_image_tensor(video, brightness_factor=brightness_factor)


-def adjust_brightness(inpt: datapoints._InputTypeJIT, brightness_factor: float) -> datapoints._InputTypeJIT:
+@_register_explicit_noop(datapoints.BoundingBoxes, datapoints.Mask)
+def adjust_saturation(inpt: datapoints._InputTypeJIT, saturation_factor: float) -> datapoints._InputTypeJIT:
    if not torch.jit.is_scripting():
-        _log_api_usage_once(adjust_brightness)
+        _log_api_usage_once(adjust_saturation)

-    if torch.jit.is_scripting() or is_simple_tensor(inpt):
-        return adjust_brightness_image_tensor(inpt, brightness_factor=brightness_factor)
-    elif isinstance(inpt, datapoints._datapoint.Datapoint):
-        return inpt.adjust_brightness(brightness_factor=brightness_factor)
+    if isinstance(inpt, torch.Tensor) and (torch.jit.is_scripting() or not isinstance(inpt, datapoints.Datapoint)):
+        return adjust_saturation_image_tensor(inpt, saturation_factor=saturation_factor)
+    elif isinstance(inpt, datapoints.Datapoint):
+        kernel = _get_kernel(adjust_saturation, type(inpt))
+        return kernel(inpt, saturation_factor=saturation_factor)
    elif isinstance(inpt, PIL.Image.Image):
-        return adjust_brightness_image_pil(inpt, brightness_factor=brightness_factor)
+        return adjust_saturation_image_pil(inpt, saturation_factor=saturation_factor)
    else:
        raise TypeError(
            f"Input can either be a plain tensor, any TorchVision datapoint, or a PIL image, "
...
...
@@ -107,6 +134,7 @@ def adjust_brightness(inpt: datapoints._InputTypeJIT, brightness_factor: float)
        )


+@_register_kernel_internal(adjust_saturation, datapoints.Image)
def adjust_saturation_image_tensor(image: torch.Tensor, saturation_factor: float) -> torch.Tensor:
    if saturation_factor < 0:
        raise ValueError(f"saturation_factor ({saturation_factor}) is not non-negative.")
...
...
@@ -128,22 +156,23 @@ def adjust_saturation_image_tensor(image: torch.Tensor, saturation_factor: float
adjust_saturation_image_pil = _FP.adjust_saturation


+@_register_kernel_internal(adjust_saturation, datapoints.Video)
def adjust_saturation_video(video: torch.Tensor, saturation_factor: float) -> torch.Tensor:
    return adjust_saturation_image_tensor(video, saturation_factor=saturation_factor)


-def adjust_saturation(inpt: datapoints._InputTypeJIT, saturation_factor: float) -> datapoints._InputTypeJIT:
+@_register_explicit_noop(datapoints.BoundingBoxes, datapoints.Mask)
+def adjust_contrast(inpt: datapoints._InputTypeJIT, contrast_factor: float) -> datapoints._InputTypeJIT:
    if not torch.jit.is_scripting():
-        _log_api_usage_once(adjust_saturation)
+        _log_api_usage_once(adjust_contrast)

-    if isinstance(inpt, torch.Tensor) and (
-        torch.jit.is_scripting() or not isinstance(inpt, datapoints._datapoint.Datapoint)
-    ):
-        return adjust_saturation_image_tensor(inpt, saturation_factor=saturation_factor)
-    elif isinstance(inpt, datapoints._datapoint.Datapoint):
-        return inpt.adjust_saturation(saturation_factor=saturation_factor)
+    if torch.jit.is_scripting() or is_simple_tensor(inpt):
+        return adjust_contrast_image_tensor(inpt, contrast_factor=contrast_factor)
+    elif isinstance(inpt, datapoints.Datapoint):
+        kernel = _get_kernel(adjust_contrast, type(inpt))
+        return kernel(inpt, contrast_factor=contrast_factor)
    elif isinstance(inpt, PIL.Image.Image):
-        return adjust_saturation_image_pil(inpt, saturation_factor=saturation_factor)
+        return adjust_contrast_image_pil(inpt, contrast_factor=contrast_factor)
    else:
        raise TypeError(
            f"Input can either be a plain tensor, any TorchVision datapoint, or a PIL image, "
...
...
@@ -151,6 +180,7 @@ def adjust_saturation(inpt: datapoints._InputTypeJIT, saturation_factor: float)
        )


+@_register_kernel_internal(adjust_contrast, datapoints.Image)
def adjust_contrast_image_tensor(image: torch.Tensor, contrast_factor: float) -> torch.Tensor:
    if contrast_factor < 0:
        raise ValueError(f"contrast_factor ({contrast_factor}) is not non-negative.")
...
...
@@ -172,20 +202,23 @@ def adjust_contrast_image_tensor(image: torch.Tensor, contrast_factor: float) ->
adjust_contrast_image_pil = _FP.adjust_contrast


+@_register_kernel_internal(adjust_contrast, datapoints.Video)
def adjust_contrast_video(video: torch.Tensor, contrast_factor: float) -> torch.Tensor:
    return adjust_contrast_image_tensor(video, contrast_factor=contrast_factor)


-def adjust_contrast(inpt: datapoints._InputTypeJIT, contrast_factor: float) -> datapoints._InputTypeJIT:
+@_register_explicit_noop(datapoints.BoundingBoxes, datapoints.Mask)
+def adjust_sharpness(inpt: datapoints._InputTypeJIT, sharpness_factor: float) -> datapoints._InputTypeJIT:
    if not torch.jit.is_scripting():
-        _log_api_usage_once(adjust_contrast)
+        _log_api_usage_once(adjust_sharpness)

-    if torch.jit.is_scripting() or is_simple_tensor(inpt):
-        return adjust_contrast_image_tensor(inpt, contrast_factor=contrast_factor)
-    elif isinstance(inpt, datapoints._datapoint.Datapoint):
-        return inpt.adjust_contrast(contrast_factor=contrast_factor)
+    if isinstance(inpt, torch.Tensor) and (torch.jit.is_scripting() or not isinstance(inpt, datapoints.Datapoint)):
+        return adjust_sharpness_image_tensor(inpt, sharpness_factor=sharpness_factor)
+    elif isinstance(inpt, datapoints.Datapoint):
+        kernel = _get_kernel(adjust_sharpness, type(inpt))
+        return kernel(inpt, sharpness_factor=sharpness_factor)
    elif isinstance(inpt, PIL.Image.Image):
-        return adjust_contrast_image_pil(inpt, contrast_factor=contrast_factor)
+        return adjust_sharpness_image_pil(inpt, sharpness_factor=sharpness_factor)
    else:
        raise TypeError(
            f"Input can either be a plain tensor, any TorchVision datapoint, or a PIL image, "
...
...
@@ -193,6 +226,7 @@ def adjust_contrast(inpt: datapoints._InputTypeJIT, contrast_factor: float) -> d
        )


+@_register_kernel_internal(adjust_sharpness, datapoints.Image)
def adjust_sharpness_image_tensor(image: torch.Tensor, sharpness_factor: float) -> torch.Tensor:
    num_channels, height, width = image.shape[-3:]
    if num_channels not in (1, 3):
...
...
@@ -248,22 +282,23 @@ def adjust_sharpness_image_tensor(image: torch.Tensor, sharpness_factor: float)
adjust_sharpness_image_pil = _FP.adjust_sharpness


+@_register_kernel_internal(adjust_sharpness, datapoints.Video)
def adjust_sharpness_video(video: torch.Tensor, sharpness_factor: float) -> torch.Tensor:
    return adjust_sharpness_image_tensor(video, sharpness_factor=sharpness_factor)


-def adjust_sharpness(inpt: datapoints._InputTypeJIT, sharpness_factor: float) -> datapoints._InputTypeJIT:
+@_register_explicit_noop(datapoints.BoundingBoxes, datapoints.Mask)
+def adjust_hue(inpt: datapoints._InputTypeJIT, hue_factor: float) -> datapoints._InputTypeJIT:
    if not torch.jit.is_scripting():
-        _log_api_usage_once(adjust_sharpness)
+        _log_api_usage_once(adjust_hue)

-    if isinstance(inpt, torch.Tensor) and (
-        torch.jit.is_scripting() or not isinstance(inpt, datapoints._datapoint.Datapoint)
-    ):
-        return adjust_sharpness_image_tensor(inpt, sharpness_factor=sharpness_factor)
-    elif isinstance(inpt, datapoints._datapoint.Datapoint):
-        return inpt.adjust_sharpness(sharpness_factor=sharpness_factor)
+    if torch.jit.is_scripting() or is_simple_tensor(inpt):
+        return adjust_hue_image_tensor(inpt, hue_factor=hue_factor)
+    elif isinstance(inpt, datapoints.Datapoint):
+        kernel = _get_kernel(adjust_hue, type(inpt))
+        return kernel(inpt, hue_factor=hue_factor)
    elif isinstance(inpt, PIL.Image.Image):
-        return adjust_sharpness_image_pil(inpt, sharpness_factor=sharpness_factor)
+        return adjust_hue_image_pil(inpt, hue_factor=hue_factor)
    else:
        raise TypeError(
            f"Input can either be a plain tensor, any TorchVision datapoint, or a PIL image, "
...
...
@@ -335,6 +370,7 @@ def _hsv_to_rgb(img: torch.Tensor) -> torch.Tensor:
    return (a4.mul_(mask.unsqueeze(dim=-4))).sum(dim=-3)


+@_register_kernel_internal(adjust_hue, datapoints.Image)
def adjust_hue_image_tensor(image: torch.Tensor, hue_factor: float) -> torch.Tensor:
    if not (-0.5 <= hue_factor <= 0.5):
        raise ValueError(f"hue_factor ({hue_factor}) is not in [-0.5, 0.5].")
...
...
@@ -365,20 +401,23 @@ def adjust_hue_image_tensor(image: torch.Tensor, hue_factor: float) -> torch.Ten
adjust_hue_image_pil = _FP.adjust_hue


+@_register_kernel_internal(adjust_hue, datapoints.Video)
def adjust_hue_video(video: torch.Tensor, hue_factor: float) -> torch.Tensor:
    return adjust_hue_image_tensor(video, hue_factor=hue_factor)


-def adjust_hue(inpt: datapoints._InputTypeJIT, hue_factor: float) -> datapoints._InputTypeJIT:
+@_register_explicit_noop(datapoints.BoundingBoxes, datapoints.Mask)
+def adjust_gamma(inpt: datapoints._InputTypeJIT, gamma: float, gain: float = 1) -> datapoints._InputTypeJIT:
    if not torch.jit.is_scripting():
-        _log_api_usage_once(adjust_hue)
+        _log_api_usage_once(adjust_gamma)

    if torch.jit.is_scripting() or is_simple_tensor(inpt):
-        return adjust_hue_image_tensor(inpt, hue_factor=hue_factor)
-    elif isinstance(inpt, datapoints._datapoint.Datapoint):
-        return inpt.adjust_hue(hue_factor=hue_factor)
+        return adjust_gamma_image_tensor(inpt, gamma=gamma, gain=gain)
+    elif isinstance(inpt, datapoints.Datapoint):
+        kernel = _get_kernel(adjust_gamma, type(inpt))
+        return kernel(inpt, gamma=gamma, gain=gain)
    elif isinstance(inpt, PIL.Image.Image):
-        return adjust_hue_image_pil(inpt, hue_factor=hue_factor)
+        return adjust_gamma_image_pil(inpt, gamma=gamma, gain=gain)
    else:
        raise TypeError(
            f"Input can either be a plain tensor, any TorchVision datapoint, or a PIL image, "
...
...
@@ -386,6 +425,7 @@ def adjust_hue(inpt: datapoints._InputTypeJIT, hue_factor: float) -> datapoints.
        )


+@_register_kernel_internal(adjust_gamma, datapoints.Image)
def adjust_gamma_image_tensor(image: torch.Tensor, gamma: float, gain: float = 1.0) -> torch.Tensor:
    if gamma < 0:
        raise ValueError("Gamma should be a non-negative real number")
...
...
@@ -408,20 +448,23 @@ def adjust_gamma_image_tensor(image: torch.Tensor, gamma: float, gain: float = 1
adjust_gamma_image_pil = _FP.adjust_gamma


+@_register_kernel_internal(adjust_gamma, datapoints.Video)
def adjust_gamma_video(video: torch.Tensor, gamma: float, gain: float = 1) -> torch.Tensor:
    return adjust_gamma_image_tensor(video, gamma=gamma, gain=gain)


-def adjust_gamma(inpt: datapoints._InputTypeJIT, gamma: float, gain: float = 1) -> datapoints._InputTypeJIT:
+@_register_explicit_noop(datapoints.BoundingBoxes, datapoints.Mask)
+def posterize(inpt: datapoints._InputTypeJIT, bits: int) -> datapoints._InputTypeJIT:
    if not torch.jit.is_scripting():
-        _log_api_usage_once(adjust_gamma)
+        _log_api_usage_once(posterize)

    if torch.jit.is_scripting() or is_simple_tensor(inpt):
-        return adjust_gamma_image_tensor(inpt, gamma=gamma, gain=gain)
-    elif isinstance(inpt, datapoints._datapoint.Datapoint):
-        return inpt.adjust_gamma(gamma=gamma, gain=gain)
+        return posterize_image_tensor(inpt, bits=bits)
+    elif isinstance(inpt, datapoints.Datapoint):
+        kernel = _get_kernel(posterize, type(inpt))
+        return kernel(inpt, bits=bits)
    elif isinstance(inpt, PIL.Image.Image):
-        return adjust_gamma_image_pil(inpt, gamma=gamma, gain=gain)
+        return posterize_image_pil(inpt, bits=bits)
    else:
        raise TypeError(
            f"Input can either be a plain tensor, any TorchVision datapoint, or a PIL image, "
...
...
@@ -429,6 +472,7 @@ def adjust_gamma(inpt: datapoints._InputTypeJIT, gamma: float, gain: float = 1)
        )


+@_register_kernel_internal(posterize, datapoints.Image)
def posterize_image_tensor(image: torch.Tensor, bits: int) -> torch.Tensor:
    if image.is_floating_point():
        levels = 1 << bits
...
...
@@ -445,20 +489,23 @@ def posterize_image_tensor(image: torch.Tensor, bits: int) -> torch.Tensor:
posterize_image_pil = _FP.posterize


+@_register_kernel_internal(posterize, datapoints.Video)
def posterize_video(video: torch.Tensor, bits: int) -> torch.Tensor:
    return posterize_image_tensor(video, bits=bits)


-def posterize(inpt: datapoints._InputTypeJIT, bits: int) -> datapoints._InputTypeJIT:
+@_register_explicit_noop(datapoints.BoundingBoxes, datapoints.Mask)
+def solarize(inpt: datapoints._InputTypeJIT, threshold: float) -> datapoints._InputTypeJIT:
    if not torch.jit.is_scripting():
-        _log_api_usage_once(posterize)
+        _log_api_usage_once(solarize)

    if torch.jit.is_scripting() or is_simple_tensor(inpt):
-        return posterize_image_tensor(inpt, bits=bits)
-    elif isinstance(inpt, datapoints._datapoint.Datapoint):
-        return inpt.posterize(bits=bits)
+        return solarize_image_tensor(inpt, threshold=threshold)
+    elif isinstance(inpt, datapoints.Datapoint):
+        kernel = _get_kernel(solarize, type(inpt))
+        return kernel(inpt, threshold=threshold)
    elif isinstance(inpt, PIL.Image.Image):
-        return posterize_image_pil(inpt, bits=bits)
+        return solarize_image_pil(inpt, threshold=threshold)
    else:
        raise TypeError(
            f"Input can either be a plain tensor, any TorchVision datapoint, or a PIL image, "
...
...
@@ -466,6 +513,7 @@ def posterize(inpt: datapoints._InputTypeJIT, bits: int) -> datapoints._InputTyp
        )


+@_register_kernel_internal(solarize, datapoints.Image)
def solarize_image_tensor(image: torch.Tensor, threshold: float) -> torch.Tensor:
    if threshold > _max_value(image.dtype):
        raise TypeError(f"Threshold should be less or equal the maximum value of the dtype, but got {threshold}")
...
...
@@ -476,20 +524,25 @@ def solarize_image_tensor(image: torch.Tensor, threshold: float) -> torch.Tensor
solarize_image_pil = _FP.solarize


+@_register_kernel_internal(solarize, datapoints.Video)
def solarize_video(video: torch.Tensor, threshold: float) -> torch.Tensor:
    return solarize_image_tensor(video, threshold=threshold)


-def solarize(inpt: datapoints._InputTypeJIT, threshold: float) -> datapoints._InputTypeJIT:
+@_register_explicit_noop(datapoints.BoundingBoxes, datapoints.Mask)
+def autocontrast(inpt: datapoints._InputTypeJIT) -> datapoints._InputTypeJIT:
    if not torch.jit.is_scripting():
-        _log_api_usage_once(solarize)
+        _log_api_usage_once(autocontrast)

    if torch.jit.is_scripting() or is_simple_tensor(inpt):
-        return solarize_image_tensor(inpt, threshold=threshold)
-    elif isinstance(inpt, datapoints._datapoint.Datapoint):
-        return inpt.solarize(threshold=threshold)
+        return autocontrast_image_tensor(inpt)
+    elif isinstance(inpt, datapoints.Datapoint):
+        kernel = _get_kernel(autocontrast, type(inpt))
+        return kernel(inpt)
    elif isinstance(inpt, PIL.Image.Image):
-        return solarize_image_pil(inpt, threshold=threshold)
+        return autocontrast_image_pil(inpt)
    else:
        raise TypeError(
            f"Input can either be a plain tensor, any TorchVision datapoint, or a PIL image, "
...
...
@@ -497,6 +550,7 @@ def solarize(inpt: datapoints._InputTypeJIT, threshold: float) -> datapoints._In
        )


+@_register_kernel_internal(autocontrast, datapoints.Image)
def autocontrast_image_tensor(image: torch.Tensor) -> torch.Tensor:
    c = image.shape[-3]
    if c not in [1, 3]:
...
@@ -529,20 +583,25 @@ def autocontrast_image_tensor(image: torch.Tensor) -> torch.Tensor:
autocontrast_image_pil
=
_FP
.
autocontrast
@
_register_kernel_internal
(
autocontrast
,
datapoints
.
Video
)
def
autocontrast_video
(
video
:
torch
.
Tensor
)
->
torch
.
Tensor
:
return
autocontrast_image_tensor
(
video
)
def
autocontrast
(
inpt
:
datapoints
.
_InputTypeJIT
)
->
datapoints
.
_InputTypeJIT
:
@
_register_explicit_noop
(
datapoints
.
BoundingBoxes
,
datapoints
.
Mask
)
def
equalize
(
inpt
:
datapoints
.
_InputTypeJIT
)
->
datapoints
.
_InputTypeJIT
:
if
not
torch
.
jit
.
is_scripting
():
_log_api_usage_once
(
autocontrast
)
_log_api_usage_once
(
equalize
)
if
torch
.
jit
.
is_scripting
()
or
is_simple_tensor
(
inpt
):
return
autocontrast_image_tensor
(
inpt
)
elif
isinstance
(
inpt
,
datapoints
.
_datapoint
.
Datapoint
):
return
inpt
.
autocontrast
()
return
equalize_image_tensor
(
inpt
)
elif
isinstance
(
inpt
,
datapoints
.
Datapoint
):
kernel
=
_get_kernel
(
equalize
,
type
(
inpt
))
return
kernel
(
inpt
,
)
elif
isinstance
(
inpt
,
PIL
.
Image
.
Image
):
return
autocontrast
_image_pil
(
inpt
)
return
equalize
_image_pil
(
inpt
)
else
:
raise
TypeError
(
f
"Input can either be a plain tensor, any TorchVision datapoint, or a PIL image, "
...
...
@@ -550,6 +609,7 @@ def autocontrast(inpt: datapoints._InputTypeJIT) -> datapoints._InputTypeJIT:
        )


+@_register_kernel_internal(equalize, datapoints.Image)
def equalize_image_tensor(image: torch.Tensor) -> torch.Tensor:
    if image.numel() == 0:
        return image
...
@@ -622,20 +682,25 @@ def equalize_image_tensor(image: torch.Tensor) -> torch.Tensor:
equalize_image_pil
=
_FP
.
equalize
@
_register_kernel_internal
(
equalize
,
datapoints
.
Video
)
def
equalize_video
(
video
:
torch
.
Tensor
)
->
torch
.
Tensor
:
return
equalize_image_tensor
(
video
)
def
equalize
(
inpt
:
datapoints
.
_InputTypeJIT
)
->
datapoints
.
_InputTypeJIT
:
@
_register_explicit_noop
(
datapoints
.
BoundingBoxes
,
datapoints
.
Mask
)
def
invert
(
inpt
:
datapoints
.
_InputTypeJIT
)
->
datapoints
.
_InputTypeJIT
:
if
not
torch
.
jit
.
is_scripting
():
_log_api_usage_once
(
equalize
)
_log_api_usage_once
(
invert
)
if
torch
.
jit
.
is_scripting
()
or
is_simple_tensor
(
inpt
):
return
equalize_image_tensor
(
inpt
)
elif
isinstance
(
inpt
,
datapoints
.
_datapoint
.
Datapoint
):
return
inpt
.
equalize
()
return
invert_image_tensor
(
inpt
)
elif
isinstance
(
inpt
,
datapoints
.
Datapoint
):
kernel
=
_get_kernel
(
invert
,
type
(
inpt
))
return
kernel
(
inpt
,
)
elif
isinstance
(
inpt
,
PIL
.
Image
.
Image
):
return
equalize
_image_pil
(
inpt
)
return
invert
_image_pil
(
inpt
)
else
:
raise
TypeError
(
f
"Input can either be a plain tensor, any TorchVision datapoint, or a PIL image, "
...
...
@@ -643,6 +708,7 @@ def equalize(inpt: datapoints._InputTypeJIT) -> datapoints._InputTypeJIT:
        )


+@_register_kernel_internal(invert, datapoints.Image)
def invert_image_tensor(image: torch.Tensor) -> torch.Tensor:
    if image.is_floating_point():
        return 1.0 - image
...
@@ -656,22 +722,6 @@ def invert_image_tensor(image: torch.Tensor) -> torch.Tensor:
invert_image_pil
=
_FP
.
invert
@
_register_kernel_internal
(
invert
,
datapoints
.
Video
)
def
invert_video
(
video
:
torch
.
Tensor
)
->
torch
.
Tensor
:
return
invert_image_tensor
(
video
)
def
invert
(
inpt
:
datapoints
.
_InputTypeJIT
)
->
datapoints
.
_InputTypeJIT
:
if
not
torch
.
jit
.
is_scripting
():
_log_api_usage_once
(
invert
)
if
torch
.
jit
.
is_scripting
()
or
is_simple_tensor
(
inpt
):
return
invert_image_tensor
(
inpt
)
elif
isinstance
(
inpt
,
datapoints
.
_datapoint
.
Datapoint
):
return
inpt
.
invert
()
elif
isinstance
(
inpt
,
PIL
.
Image
.
Image
):
return
invert_image_pil
(
inpt
)
else
:
raise
TypeError
(
f
"Input can either be a plain tensor, any TorchVision datapoint, or a PIL image, "
f
"but got
{
type
(
inpt
)
}
instead."
)
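Pattern worth noting in this file: every *_video kernel forwards to the corresponding image kernel, since these color ops are pointwise over the extra leading time dimension; the decorator only adds the registry entry. Illustrative check, assuming a build at this commit:

import torch
from torchvision import datapoints
from torchvision.transforms.v2 import functional as F

video = datapoints.Video(torch.rand(8, 3, 16, 16))
out = F.invert(video)  # resolved to invert_video via the registry, re-wrapped as Video
ref = F.invert_video(video.as_subclass(torch.Tensor))
assert torch.equal(out, ref)  # same math either way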
torchvision/transforms/v2/functional/_geometry.py
View file @ a893f313
import math
import numbers
import warnings
-from typing import List, Optional, Sequence, Tuple, Union
+from typing import Any, List, Optional, Sequence, Tuple, Union

import PIL.Image
import torch
...
...
@@ -25,7 +25,13 @@ from torchvision.utils import _log_api_usage_once
from ._meta import clamp_bounding_boxes, convert_format_bounding_boxes, get_size_image_pil
-from ._utils import is_simple_tensor
+from ._utils import (
+    _get_kernel,
+    _register_explicit_noop,
+    _register_five_ten_crop_kernel,
+    _register_kernel_internal,
+    is_simple_tensor,
+)


def _check_interpolation(interpolation: Union[InterpolationMode, int]) -> InterpolationMode:
...
...
@@ -39,6 +45,27 @@ def _check_interpolation(interpolation: Union[InterpolationMode, int]) -> Interp
    return interpolation


+def horizontal_flip(inpt: datapoints._InputTypeJIT) -> datapoints._InputTypeJIT:
+    if not torch.jit.is_scripting():
+        _log_api_usage_once(horizontal_flip)
+
+    if torch.jit.is_scripting() or is_simple_tensor(inpt):
+        return horizontal_flip_image_tensor(inpt)
+    elif isinstance(inpt, datapoints.Datapoint):
+        kernel = _get_kernel(horizontal_flip, type(inpt))
+        return kernel(inpt)
+    elif isinstance(inpt, PIL.Image.Image):
+        return horizontal_flip_image_pil(inpt)
+    else:
+        raise TypeError(
+            f"Input can either be a plain tensor, any TorchVision datapoint, or a PIL image, "
+            f"but got {type(inpt)} instead."
+        )
+
+
+@_register_kernel_internal(horizontal_flip, datapoints.Image)
def horizontal_flip_image_tensor(image: torch.Tensor) -> torch.Tensor:
    return image.flip(-1)
...
...
@@ -47,6 +74,7 @@ def horizontal_flip_image_pil(image: PIL.Image.Image) -> PIL.Image.Image:
    return _FP.hflip(image)


+@_register_kernel_internal(horizontal_flip, datapoints.Mask)
def horizontal_flip_mask(mask: torch.Tensor) -> torch.Tensor:
    return horizontal_flip_image_tensor(mask)
...
...
@@ -68,20 +96,32 @@ def horizontal_flip_bounding_boxes(
    return bounding_boxes.reshape(shape)


+@_register_kernel_internal(horizontal_flip, datapoints.BoundingBoxes, datapoint_wrapper=False)
+def _horizontal_flip_bounding_boxes_dispatch(inpt: datapoints.BoundingBoxes) -> datapoints.BoundingBoxes:
+    output = horizontal_flip_bounding_boxes(
+        inpt.as_subclass(torch.Tensor), format=inpt.format, canvas_size=inpt.canvas_size
+    )
+    return datapoints.BoundingBoxes.wrap_like(inpt, output)
+
+
+@_register_kernel_internal(horizontal_flip, datapoints.Video)
def horizontal_flip_video(video: torch.Tensor) -> torch.Tensor:
    return horizontal_flip_image_tensor(video)


-def horizontal_flip(inpt: datapoints._InputTypeJIT) -> datapoints._InputTypeJIT:
+def vertical_flip(inpt: datapoints._InputTypeJIT) -> datapoints._InputTypeJIT:
    if not torch.jit.is_scripting():
-        _log_api_usage_once(horizontal_flip)
+        _log_api_usage_once(vertical_flip)

    if torch.jit.is_scripting() or is_simple_tensor(inpt):
-        return horizontal_flip_image_tensor(inpt)
-    elif isinstance(inpt, datapoints._datapoint.Datapoint):
-        return inpt.horizontal_flip()
+        return vertical_flip_image_tensor(inpt)
+    elif isinstance(inpt, datapoints.Datapoint):
+        kernel = _get_kernel(vertical_flip, type(inpt))
+        return kernel(inpt)
    elif isinstance(inpt, PIL.Image.Image):
-        return horizontal_flip_image_pil(inpt)
+        return vertical_flip_image_pil(inpt)
    else:
        raise TypeError(
            f"Input can either be a plain tensor, any TorchVision datapoint, or a PIL image, "
...
...
@@ -89,6 +129,7 @@ def horizontal_flip(inpt: datapoints._InputTypeJIT) -> datapoints._InputTypeJIT:
        )


+@_register_kernel_internal(vertical_flip, datapoints.Image)
def vertical_flip_image_tensor(image: torch.Tensor) -> torch.Tensor:
    return image.flip(-2)
...
...
@@ -97,6 +138,7 @@ def vertical_flip_image_pil(image: PIL.Image) -> PIL.Image:
    return _FP.vflip(image)


+@_register_kernel_internal(vertical_flip, datapoints.Mask)
def vertical_flip_mask(mask: torch.Tensor) -> torch.Tensor:
    return vertical_flip_image_tensor(mask)
...
...
@@ -118,25 +160,17 @@ def vertical_flip_bounding_boxes(
    return bounding_boxes.reshape(shape)


-def vertical_flip_video(video: torch.Tensor) -> torch.Tensor:
-    return vertical_flip_image_tensor(video)
+@_register_kernel_internal(vertical_flip, datapoints.BoundingBoxes, datapoint_wrapper=False)
+def _vertical_flip_bounding_boxes_dispatch(inpt: datapoints.BoundingBoxes) -> datapoints.BoundingBoxes:
+    output = vertical_flip_bounding_boxes(
+        inpt.as_subclass(torch.Tensor), format=inpt.format, canvas_size=inpt.canvas_size
+    )
+    return datapoints.BoundingBoxes.wrap_like(inpt, output)
-
-
-def vertical_flip(inpt: datapoints._InputTypeJIT) -> datapoints._InputTypeJIT:
-    if not torch.jit.is_scripting():
-        _log_api_usage_once(vertical_flip)
-
-    if torch.jit.is_scripting() or is_simple_tensor(inpt):
-        return vertical_flip_image_tensor(inpt)
-    elif isinstance(inpt, datapoints._datapoint.Datapoint):
-        return inpt.vertical_flip()
-    elif isinstance(inpt, PIL.Image.Image):
-        return vertical_flip_image_pil(inpt)
-    else:
-        raise TypeError(
-            f"Input can either be a plain tensor, any TorchVision datapoint, or a PIL image, "
-            f"but got {type(inpt)} instead."
-        )
+
+
+@_register_kernel_internal(vertical_flip, datapoints.Video)
def vertical_flip_video(video: torch.Tensor) -> torch.Tensor:
    return vertical_flip_image_tensor(video)
# We changed the names to align them with the transforms, i.e. `RandomHorizontalFlip`. Still, `hflip` and `vflip` are
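Why datapoint_wrapper=False on the bounding-box registrations above: box kernels need metadata that a plain-tensor signature does not carry (format, canvas_size), so a small hand-written dispatch wrapper unwraps the datapoint, calls the kernel with that metadata, and re-wraps. Illustrative usage, assuming the registrations from this file:

import torch
from torchvision import datapoints
from torchvision.transforms.v2 import functional as F

boxes = datapoints.BoundingBoxes(
    torch.tensor([[1, 2, 3, 4]]), format="XYXY", canvas_size=(10, 10)
)
flipped = F.horizontal_flip(boxes)  # routed to _horizontal_flip_bounding_boxes_dispatch
assert isinstance(flipped, datapoints.BoundingBoxes)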
...
...
@@ -158,6 +192,32 @@ def _compute_resized_output_size(
    return __compute_resized_output_size(canvas_size, size=size, max_size=max_size)


+def resize(
+    inpt: datapoints._InputTypeJIT,
+    size: List[int],
+    interpolation: Union[InterpolationMode, int] = InterpolationMode.BILINEAR,
+    max_size: Optional[int] = None,
+    antialias: Optional[Union[str, bool]] = "warn",
+) -> datapoints._InputTypeJIT:
+    if not torch.jit.is_scripting():
+        _log_api_usage_once(resize)
+
+    if torch.jit.is_scripting() or is_simple_tensor(inpt):
+        return resize_image_tensor(inpt, size, interpolation=interpolation, max_size=max_size, antialias=antialias)
+    elif isinstance(inpt, datapoints.Datapoint):
+        kernel = _get_kernel(resize, type(inpt))
+        return kernel(inpt, size, interpolation=interpolation, max_size=max_size, antialias=antialias)
+    elif isinstance(inpt, PIL.Image.Image):
+        if antialias is False:
+            warnings.warn("Anti-alias option is always applied for PIL Image input. Argument antialias is ignored.")
+        return resize_image_pil(inpt, size, interpolation=interpolation, max_size=max_size)
+    else:
+        raise TypeError(
+            f"Input can either be a plain tensor, any TorchVision datapoint, or a PIL image, "
+            f"but got {type(inpt)} instead."
+        )
+
+
+@_register_kernel_internal(resize, datapoints.Image)
def resize_image_tensor(
    image: torch.Tensor,
    size: List[int],
...
...
@@ -274,6 +334,14 @@ def resize_mask(mask: torch.Tensor, size: List[int], max_size: Optional[int] = N
    return output


+@_register_kernel_internal(resize, datapoints.Mask, datapoint_wrapper=False)
+def _resize_mask_dispatch(
+    inpt: datapoints.Mask, size: List[int], max_size: Optional[int] = None, **kwargs: Any
+) -> datapoints.Mask:
+    output = resize_mask(inpt.as_subclass(torch.Tensor), size, max_size=max_size)
+    return datapoints.Mask.wrap_like(inpt, output)
+
+
def resize_bounding_boxes(
    bounding_boxes: torch.Tensor, canvas_size: Tuple[int, int], size: List[int], max_size: Optional[int] = None
) -> Tuple[torch.Tensor, Tuple[int, int]]:
...
...
@@ -292,6 +360,17 @@ def resize_bounding_boxes(
    )


+@_register_kernel_internal(resize, datapoints.BoundingBoxes, datapoint_wrapper=False)
+def _resize_bounding_boxes_dispatch(
+    inpt: datapoints.BoundingBoxes, size: List[int], max_size: Optional[int] = None, **kwargs: Any
+) -> datapoints.BoundingBoxes:
+    output, canvas_size = resize_bounding_boxes(
+        inpt.as_subclass(torch.Tensor), inpt.canvas_size, size, max_size=max_size
+    )
+    return datapoints.BoundingBoxes.wrap_like(inpt, output, canvas_size=canvas_size)
+
+
+@_register_kernel_internal(resize, datapoints.Video)
def resize_video(
    video: torch.Tensor,
    size: List[int],
...
...
@@ -302,23 +381,54 @@ def resize_video(
    return resize_image_tensor(video, size=size, interpolation=interpolation, max_size=max_size, antialias=antialias)


-def resize(
+def affine(
    inpt: datapoints._InputTypeJIT,
-    size: List[int],
-    interpolation: Union[InterpolationMode, int] = InterpolationMode.BILINEAR,
-    max_size: Optional[int] = None,
-    antialias: Optional[Union[str, bool]] = "warn",
+    angle: Union[int, float],
+    translate: List[float],
+    scale: float,
+    shear: List[float],
+    interpolation: Union[InterpolationMode, int] = InterpolationMode.NEAREST,
+    fill: datapoints._FillTypeJIT = None,
+    center: Optional[List[float]] = None,
) -> datapoints._InputTypeJIT:
    if not torch.jit.is_scripting():
-        _log_api_usage_once(resize)
+        _log_api_usage_once(affine)

+    # TODO: consider deprecating integers from angle and shear on the future
    if torch.jit.is_scripting() or is_simple_tensor(inpt):
-        return resize_image_tensor(inpt, size, interpolation=interpolation, max_size=max_size, antialias=antialias)
-    elif isinstance(inpt, datapoints._datapoint.Datapoint):
-        return inpt.resize(size, interpolation=interpolation, max_size=max_size, antialias=antialias)
+        return affine_image_tensor(
+            inpt,
+            angle,
+            translate=translate,
+            scale=scale,
+            shear=shear,
+            interpolation=interpolation,
+            fill=fill,
+            center=center,
+        )
+    elif isinstance(inpt, datapoints.Datapoint):
+        kernel = _get_kernel(affine, type(inpt))
+        return kernel(
+            inpt,
+            angle,
+            translate=translate,
+            scale=scale,
+            shear=shear,
+            interpolation=interpolation,
+            fill=fill,
+            center=center,
+        )
    elif isinstance(inpt, PIL.Image.Image):
-        if antialias is False:
-            warnings.warn("Anti-alias option is always applied for PIL Image input. Argument antialias is ignored.")
-        return resize_image_pil(inpt, size, interpolation=interpolation, max_size=max_size)
+        return affine_image_pil(
+            inpt,
+            angle,
+            translate=translate,
+            scale=scale,
+            shear=shear,
+            interpolation=interpolation,
+            fill=fill,
+            center=center,
+        )
    else:
        raise TypeError(
            f"Input can either be a plain tensor, any TorchVision datapoint, or a PIL image, "
...
...
@@ -574,6 +684,7 @@ def _affine_grid(
    return output_grid.view(1, oh, ow, 2)


+@_register_kernel_internal(affine, datapoints.Image)
def affine_image_tensor(
    image: torch.Tensor,
    angle: Union[int, float],
...
...
@@ -763,6 +874,29 @@ def affine_bounding_boxes(
    return out_box


+@_register_kernel_internal(affine, datapoints.BoundingBoxes, datapoint_wrapper=False)
+def _affine_bounding_boxes_dispatch(
+    inpt: datapoints.BoundingBoxes,
+    angle: Union[int, float],
+    translate: List[float],
+    scale: float,
+    shear: List[float],
+    center: Optional[List[float]] = None,
+    **kwargs,
+) -> datapoints.BoundingBoxes:
+    output = affine_bounding_boxes(
+        inpt.as_subclass(torch.Tensor),
+        format=inpt.format,
+        canvas_size=inpt.canvas_size,
+        angle=angle,
+        translate=translate,
+        scale=scale,
+        shear=shear,
+        center=center,
+    )
+    return datapoints.BoundingBoxes.wrap_like(inpt, output)
+
+
def affine_mask(
    mask: torch.Tensor,
    angle: Union[int, float],
...
...
@@ -795,6 +929,30 @@ def affine_mask(
    return output


+@_register_kernel_internal(affine, datapoints.Mask, datapoint_wrapper=False)
+def _affine_mask_dispatch(
+    inpt: datapoints.Mask,
+    angle: Union[int, float],
+    translate: List[float],
+    scale: float,
+    shear: List[float],
+    fill: datapoints._FillTypeJIT = None,
+    center: Optional[List[float]] = None,
+    **kwargs,
+) -> datapoints.Mask:
+    output = affine_mask(
+        inpt.as_subclass(torch.Tensor),
+        angle=angle,
+        translate=translate,
+        scale=scale,
+        shear=shear,
+        fill=fill,
+        center=center,
+    )
+    return datapoints.Mask.wrap_like(inpt, output)
+
+
+@_register_kernel_internal(affine, datapoints.Video)
def affine_video(
    video: torch.Tensor,
    angle: Union[int, float],
...
...
@@ -817,46 +975,24 @@ def affine_video(
    )


-def affine(
+def rotate(
    inpt: datapoints._InputTypeJIT,
-    angle: Union[int, float],
-    translate: List[float],
-    scale: float,
-    shear: List[float],
+    angle: float,
    interpolation: Union[InterpolationMode, int] = InterpolationMode.NEAREST,
-    fill: datapoints._FillTypeJIT = None,
+    expand: bool = False,
    center: Optional[List[float]] = None,
+    fill: datapoints._FillTypeJIT = None,
) -> datapoints._InputTypeJIT:
    if not torch.jit.is_scripting():
-        _log_api_usage_once(affine)
+        _log_api_usage_once(rotate)

-    # TODO: consider deprecating integers from angle and shear on the future
    if torch.jit.is_scripting() or is_simple_tensor(inpt):
-        return affine_image_tensor(
-            inpt,
-            angle,
-            translate=translate,
-            scale=scale,
-            shear=shear,
-            interpolation=interpolation,
-            fill=fill,
-            center=center,
-        )
-    elif isinstance(inpt, datapoints._datapoint.Datapoint):
-        return inpt.affine(
-            angle, translate=translate, scale=scale, shear=shear, interpolation=interpolation, fill=fill, center=center
-        )
+        return rotate_image_tensor(inpt, angle, interpolation=interpolation, expand=expand, fill=fill, center=center)
+    elif isinstance(inpt, datapoints.Datapoint):
+        kernel = _get_kernel(rotate, type(inpt))
+        return kernel(inpt, angle, interpolation=interpolation, expand=expand, fill=fill, center=center)
    elif isinstance(inpt, PIL.Image.Image):
-        return affine_image_pil(
-            inpt,
-            angle,
-            translate=translate,
-            scale=scale,
-            shear=shear,
-            interpolation=interpolation,
-            fill=fill,
-            center=center,
-        )
+        return rotate_image_pil(inpt, angle, interpolation=interpolation, expand=expand, fill=fill, center=center)
    else:
        raise TypeError(
            f"Input can either be a plain tensor, any TorchVision datapoint, or a PIL image, "
...
...
@@ -864,6 +1000,7 @@ def affine(
        )


+@_register_kernel_internal(rotate, datapoints.Image)
def rotate_image_tensor(
    image: torch.Tensor,
    angle: float,
...
...
@@ -951,6 +1088,21 @@ def rotate_bounding_boxes(
    )


+@_register_kernel_internal(rotate, datapoints.BoundingBoxes, datapoint_wrapper=False)
+def _rotate_bounding_boxes_dispatch(
+    inpt: datapoints.BoundingBoxes, angle: float, expand: bool = False, center: Optional[List[float]] = None, **kwargs
+) -> datapoints.BoundingBoxes:
+    output, canvas_size = rotate_bounding_boxes(
+        inpt.as_subclass(torch.Tensor),
+        format=inpt.format,
+        canvas_size=inpt.canvas_size,
+        angle=angle,
+        expand=expand,
+        center=center,
+    )
+    return datapoints.BoundingBoxes.wrap_like(inpt, output, canvas_size=canvas_size)
+
+
def rotate_mask(
    mask: torch.Tensor,
    angle: float,
...
...
@@ -979,6 +1131,20 @@ def rotate_mask(
    return output


+@_register_kernel_internal(rotate, datapoints.Mask, datapoint_wrapper=False)
+def _rotate_mask_dispatch(
+    inpt: datapoints.Mask,
+    angle: float,
+    expand: bool = False,
+    center: Optional[List[float]] = None,
+    fill: datapoints._FillTypeJIT = None,
+    **kwargs,
+) -> datapoints.Mask:
+    output = rotate_mask(inpt.as_subclass(torch.Tensor), angle=angle, expand=expand, fill=fill, center=center)
+    return datapoints.Mask.wrap_like(inpt, output)
+
+
+@_register_kernel_internal(rotate, datapoints.Video)
def rotate_video(
    video: torch.Tensor,
    angle: float,
...
...
@@ -990,23 +1156,23 @@ def rotate_video(
    return rotate_image_tensor(video, angle, interpolation=interpolation, expand=expand, fill=fill, center=center)


-def rotate(
+def pad(
    inpt: datapoints._InputTypeJIT,
-    angle: float,
-    interpolation: Union[InterpolationMode, int] = InterpolationMode.NEAREST,
-    expand: bool = False,
-    center: Optional[List[float]] = None,
-    fill: datapoints._FillTypeJIT = None,
+    padding: List[int],
+    fill: Optional[Union[int, float, List[float]]] = None,
+    padding_mode: str = "constant",
) -> datapoints._InputTypeJIT:
    if not torch.jit.is_scripting():
-        _log_api_usage_once(rotate)
+        _log_api_usage_once(pad)

    if torch.jit.is_scripting() or is_simple_tensor(inpt):
-        return rotate_image_tensor(inpt, angle, interpolation=interpolation, expand=expand, fill=fill, center=center)
-    elif isinstance(inpt, datapoints._datapoint.Datapoint):
-        return inpt.rotate(angle, interpolation=interpolation, expand=expand, fill=fill, center=center)
+        return pad_image_tensor(inpt, padding, fill=fill, padding_mode=padding_mode)
+    elif isinstance(inpt, datapoints.Datapoint):
+        kernel = _get_kernel(pad, type(inpt))
+        return kernel(inpt, padding, fill=fill, padding_mode=padding_mode)
    elif isinstance(inpt, PIL.Image.Image):
-        return rotate_image_pil(inpt, angle, interpolation=interpolation, expand=expand, fill=fill, center=center)
+        return pad_image_pil(inpt, padding, fill=fill, padding_mode=padding_mode)
    else:
        raise TypeError(
            f"Input can either be a plain tensor, any TorchVision datapoint, or a PIL image, "
...
...
@@ -1038,6 +1204,7 @@ def _parse_pad_padding(padding: Union[int, List[int]]) -> List[int]:
    return [pad_left, pad_right, pad_top, pad_bottom]


+@_register_kernel_internal(pad, datapoints.Image)
def pad_image_tensor(
    image: torch.Tensor,
    padding: List[int],
...
...
@@ -1139,6 +1306,7 @@ def _pad_with_vector_fill(
pad_image_pil = _FP.pad


+@_register_kernel_internal(pad, datapoints.Mask)
def pad_mask(
    mask: torch.Tensor,
    padding: List[int],
...
...
@@ -1192,6 +1360,21 @@ def pad_bounding_boxes(
    return clamp_bounding_boxes(bounding_boxes, format=format, canvas_size=canvas_size), canvas_size


+@_register_kernel_internal(pad, datapoints.BoundingBoxes, datapoint_wrapper=False)
+def _pad_bounding_boxes_dispatch(
+    inpt: datapoints.BoundingBoxes, padding: List[int], padding_mode: str = "constant", **kwargs
+) -> datapoints.BoundingBoxes:
+    output, canvas_size = pad_bounding_boxes(
+        inpt.as_subclass(torch.Tensor),
+        format=inpt.format,
+        canvas_size=inpt.canvas_size,
+        padding=padding,
+        padding_mode=padding_mode,
+    )
+    return datapoints.BoundingBoxes.wrap_like(inpt, output, canvas_size=canvas_size)
+
+
+@_register_kernel_internal(pad, datapoints.Video)
def pad_video(
    video: torch.Tensor,
    padding: List[int],
...
...
@@ -1201,22 +1384,17 @@ def pad_video(
    return pad_image_tensor(video, padding, fill=fill, padding_mode=padding_mode)


-def pad(
-    inpt: datapoints._InputTypeJIT,
-    padding: List[int],
-    fill: Optional[Union[int, float, List[float]]] = None,
-    padding_mode: str = "constant",
-) -> datapoints._InputTypeJIT:
+def crop(inpt: datapoints._InputTypeJIT, top: int, left: int, height: int, width: int) -> datapoints._InputTypeJIT:
    if not torch.jit.is_scripting():
-        _log_api_usage_once(pad)
+        _log_api_usage_once(crop)

    if torch.jit.is_scripting() or is_simple_tensor(inpt):
-        return pad_image_tensor(inpt, padding, fill=fill, padding_mode=padding_mode)
-    elif isinstance(inpt, datapoints._datapoint.Datapoint):
-        return inpt.pad(padding, fill=fill, padding_mode=padding_mode)
+        return crop_image_tensor(inpt, top, left, height, width)
+    elif isinstance(inpt, datapoints.Datapoint):
+        kernel = _get_kernel(crop, type(inpt))
+        return kernel(inpt, top, left, height, width)
    elif isinstance(inpt, PIL.Image.Image):
-        return pad_image_pil(inpt, padding, fill=fill, padding_mode=padding_mode)
+        return crop_image_pil(inpt, top, left, height, width)
    else:
        raise TypeError(
            f"Input can either be a plain tensor, any TorchVision datapoint, or a PIL image, "
...
...
@@ -1224,6 +1402,7 @@ def pad(
        )


+@_register_kernel_internal(crop, datapoints.Image)
def crop_image_tensor(image: torch.Tensor, top: int, left: int, height: int, width: int) -> torch.Tensor:
    h, w = image.shape[-2:]
...
...
@@ -1266,6 +1445,17 @@ def crop_bounding_boxes(
    return clamp_bounding_boxes(bounding_boxes, format=format, canvas_size=canvas_size), canvas_size


+@_register_kernel_internal(crop, datapoints.BoundingBoxes, datapoint_wrapper=False)
+def _crop_bounding_boxes_dispatch(
+    inpt: datapoints.BoundingBoxes, top: int, left: int, height: int, width: int
+) -> datapoints.BoundingBoxes:
+    output, canvas_size = crop_bounding_boxes(
+        inpt.as_subclass(torch.Tensor), format=inpt.format, top=top, left=left, height=height, width=width
+    )
+    return datapoints.BoundingBoxes.wrap_like(inpt, output, canvas_size=canvas_size)
+
+
+@_register_kernel_internal(crop, datapoints.Mask)
def crop_mask(mask: torch.Tensor, top: int, left: int, height: int, width: int) -> torch.Tensor:
    if mask.ndim < 3:
        mask = mask.unsqueeze(0)
...
...
@@ -1281,20 +1471,32 @@ def crop_mask(mask: torch.Tensor, top: int, left: int, height: int, width: int)
    return output


+@_register_kernel_internal(crop, datapoints.Video)
def crop_video(video: torch.Tensor, top: int, left: int, height: int, width: int) -> torch.Tensor:
    return crop_image_tensor(video, top, left, height, width)


-def crop(inpt: datapoints._InputTypeJIT, top: int, left: int, height: int, width: int) -> datapoints._InputTypeJIT:
+def perspective(
+    inpt: datapoints._InputTypeJIT,
+    startpoints: Optional[List[List[int]]],
+    endpoints: Optional[List[List[int]]],
+    interpolation: Union[InterpolationMode, int] = InterpolationMode.BILINEAR,
+    fill: datapoints._FillTypeJIT = None,
+    coefficients: Optional[List[float]] = None,
+) -> datapoints._InputTypeJIT:
    if not torch.jit.is_scripting():
-        _log_api_usage_once(crop)
+        _log_api_usage_once(perspective)

    if torch.jit.is_scripting() or is_simple_tensor(inpt):
-        return crop_image_tensor(inpt, top, left, height, width)
-    elif isinstance(inpt, datapoints._datapoint.Datapoint):
-        return inpt.crop(top, left, height, width)
+        return perspective_image_tensor(
+            inpt, startpoints, endpoints, interpolation=interpolation, fill=fill, coefficients=coefficients
+        )
+    elif isinstance(inpt, datapoints.Datapoint):
+        kernel = _get_kernel(perspective, type(inpt))
+        return kernel(inpt, startpoints, endpoints, interpolation=interpolation, fill=fill, coefficients=coefficients)
    elif isinstance(inpt, PIL.Image.Image):
-        return crop_image_pil(inpt, top, left, height, width)
+        return perspective_image_pil(
+            inpt, startpoints, endpoints, interpolation=interpolation, fill=fill, coefficients=coefficients
+        )
    else:
        raise TypeError(
            f"Input can either be a plain tensor, any TorchVision datapoint, or a PIL image, "
...
...
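A usage sketch for the rewritten `perspective` entry point on an `Image` datapoint; the corner points below are arbitrary values chosen for illustration:

    import torch
    from torchvision import datapoints
    import torchvision.transforms.v2.functional as F

    img = datapoints.Image(torch.randint(0, 256, (3, 32, 32), dtype=torch.uint8))
    startpoints = [[0, 0], [31, 0], [31, 31], [0, 31]]
    endpoints = [[2, 1], [30, 3], [28, 31], [1, 29]]
    out = F.perspective(img, startpoints, endpoints)
    print(type(out).__name__, tuple(out.shape))  # Image (3, 32, 32)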
@@ -1349,6 +1551,7 @@ def _perspective_coefficients(
         raise ValueError("Either the startpoints/endpoints or the coefficients must have non `None` values.")


+@_register_kernel_internal(perspective, datapoints.Image)
 def perspective_image_tensor(
     image: torch.Tensor,
     startpoints: Optional[List[List[int]]],
...
@@ -1503,6 +1706,25 @@ def perspective_bounding_boxes(
     ).reshape(original_shape)


+@_register_kernel_internal(perspective, datapoints.BoundingBoxes, datapoint_wrapper=False)
+def _perspective_bounding_boxes_dispatch(
+    inpt: datapoints.BoundingBoxes,
+    startpoints: Optional[List[List[int]]],
+    endpoints: Optional[List[List[int]]],
+    coefficients: Optional[List[float]] = None,
+    **kwargs,
+) -> datapoints.BoundingBoxes:
+    output = perspective_bounding_boxes(
+        inpt.as_subclass(torch.Tensor),
+        format=inpt.format,
+        canvas_size=inpt.canvas_size,
+        startpoints=startpoints,
+        endpoints=endpoints,
+        coefficients=coefficients,
+    )
+    return datapoints.BoundingBoxes.wrap_like(inpt, output)
 def perspective_mask(
     mask: torch.Tensor,
     startpoints: Optional[List[List[int]]],
...
@@ -1526,6 +1748,26 @@ def perspective_mask(
     return output


+@_register_kernel_internal(perspective, datapoints.Mask, datapoint_wrapper=False)
+def _perspective_mask_dispatch(
+    inpt: datapoints.Mask,
+    startpoints: Optional[List[List[int]]],
+    endpoints: Optional[List[List[int]]],
+    fill: datapoints._FillTypeJIT = None,
+    coefficients: Optional[List[float]] = None,
+    **kwargs,
+) -> datapoints.Mask:
+    output = perspective_mask(
+        inpt.as_subclass(torch.Tensor),
+        startpoints=startpoints,
+        endpoints=endpoints,
+        fill=fill,
+        coefficients=coefficients,
+    )
+    return datapoints.Mask.wrap_like(inpt, output)
+@_register_kernel_internal(perspective, datapoints.Video)
 def perspective_video(
     video: torch.Tensor,
     startpoints: Optional[List[List[int]]],
...
@@ -1539,28 +1781,25 @@ def perspective_video(
     )


-def perspective(
+def elastic(
     inpt: datapoints._InputTypeJIT,
-    startpoints: Optional[List[List[int]]],
-    endpoints: Optional[List[List[int]]],
+    displacement: torch.Tensor,
     interpolation: Union[InterpolationMode, int] = InterpolationMode.BILINEAR,
     fill: datapoints._FillTypeJIT = None,
-    coefficients: Optional[List[float]] = None,
 ) -> datapoints._InputTypeJIT:
     if not torch.jit.is_scripting():
-        _log_api_usage_once(perspective)
+        _log_api_usage_once(elastic)
+
+    if not isinstance(displacement, torch.Tensor):
+        raise TypeError("Argument displacement should be a Tensor")
+
     if torch.jit.is_scripting() or is_simple_tensor(inpt):
-        return perspective_image_tensor(
-            inpt, startpoints, endpoints, interpolation=interpolation, fill=fill, coefficients=coefficients
-        )
-    elif isinstance(inpt, datapoints._datapoint.Datapoint):
-        return inpt.perspective(startpoints, endpoints, interpolation=interpolation, fill=fill, coefficients=coefficients)
+        return elastic_image_tensor(inpt, displacement, interpolation=interpolation, fill=fill)
+    elif isinstance(inpt, datapoints.Datapoint):
+        kernel = _get_kernel(elastic, type(inpt))
+        return kernel(inpt, displacement, interpolation=interpolation, fill=fill)
     elif isinstance(inpt, PIL.Image.Image):
-        return perspective_image_pil(
-            inpt, startpoints, endpoints, interpolation=interpolation, fill=fill, coefficients=coefficients
-        )
+        return elastic_image_pil(inpt, displacement, interpolation=interpolation, fill=fill)
     else:
         raise TypeError(
             f"Input can either be a plain tensor, any TorchVision datapoint, or a PIL image, "
...
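A sketch for `elastic`. The displacement field below uses the `(1, H, W, 2)` shape that `ElasticTransform` generates; that shape is an assumption here, since this hunk does not show the kernel's expectations:

    import torch
    from torchvision import datapoints
    import torchvision.transforms.v2.functional as F

    img = datapoints.Image(torch.rand(3, 16, 16))
    displacement = torch.zeros(1, 16, 16, 2)  # an (approximately) no-op field
    out = F.elastic(img, displacement)
    print(type(out).__name__, tuple(out.shape))  # Image (3, 16, 16)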
@@ -1568,6 +1807,10 @@ def perspective(
         )


+elastic_transform = elastic


+@_register_kernel_internal(elastic, datapoints.Image)
 def elastic_image_tensor(
     image: torch.Tensor,
     displacement: torch.Tensor,
...
@@ -1699,6 +1942,16 @@ def elastic_bounding_boxes(
     ).reshape(original_shape)


+@_register_kernel_internal(elastic, datapoints.BoundingBoxes, datapoint_wrapper=False)
+def _elastic_bounding_boxes_dispatch(
+    inpt: datapoints.BoundingBoxes, displacement: torch.Tensor, **kwargs
+) -> datapoints.BoundingBoxes:
+    output = elastic_bounding_boxes(
+        inpt.as_subclass(torch.Tensor), format=inpt.format, canvas_size=inpt.canvas_size, displacement=displacement
+    )
+    return datapoints.BoundingBoxes.wrap_like(inpt, output)
 def elastic_mask(
     mask: torch.Tensor,
     displacement: torch.Tensor,
...
@@ -1718,6 +1971,15 @@ def elastic_mask(
     return output


+@_register_kernel_internal(elastic, datapoints.Mask, datapoint_wrapper=False)
+def _elastic_mask_dispatch(
+    inpt: datapoints.Mask, displacement: torch.Tensor, fill: datapoints._FillTypeJIT = None, **kwargs
+) -> datapoints.Mask:
+    output = elastic_mask(inpt.as_subclass(torch.Tensor), displacement=displacement, fill=fill)
+    return datapoints.Mask.wrap_like(inpt, output)
+@_register_kernel_internal(elastic, datapoints.Video)
 def elastic_video(
     video: torch.Tensor,
     displacement: torch.Tensor,
...
@@ -1727,24 +1989,17 @@ def elastic_video(
     return elastic_image_tensor(video, displacement, interpolation=interpolation, fill=fill)


-def elastic(
-    inpt: datapoints._InputTypeJIT,
-    displacement: torch.Tensor,
-    interpolation: Union[InterpolationMode, int] = InterpolationMode.BILINEAR,
-    fill: datapoints._FillTypeJIT = None,
-) -> datapoints._InputTypeJIT:
+def center_crop(inpt: datapoints._InputTypeJIT, output_size: List[int]) -> datapoints._InputTypeJIT:
     if not torch.jit.is_scripting():
-        _log_api_usage_once(elastic)
-    if not isinstance(displacement, torch.Tensor):
-        raise TypeError("Argument displacement should be a Tensor")
+        _log_api_usage_once(center_crop)
     if torch.jit.is_scripting() or is_simple_tensor(inpt):
-        return elastic_image_tensor(inpt, displacement, interpolation=interpolation, fill=fill)
-    elif isinstance(inpt, datapoints._datapoint.Datapoint):
-        return inpt.elastic(displacement, interpolation=interpolation, fill=fill)
+        return center_crop_image_tensor(inpt, output_size)
+    elif isinstance(inpt, datapoints.Datapoint):
+        kernel = _get_kernel(center_crop, type(inpt))
+        return kernel(inpt, output_size)
     elif isinstance(inpt, PIL.Image.Image):
-        return elastic_image_pil(inpt, displacement, interpolation=interpolation, fill=fill)
+        return center_crop_image_pil(inpt, output_size)
     else:
         raise TypeError(
             f"Input can either be a plain tensor, any TorchVision datapoint, or a PIL image, "
...
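Because image-like and box kernels now register under the same public `center_crop`, one call per input keeps paired data geometrically consistent. Sketch, same assumptions as the earlier examples:

    import torch
    from torchvision import datapoints
    import torchvision.transforms.v2.functional as F

    img = datapoints.Image(torch.rand(3, 64, 64))
    boxes = datapoints.BoundingBoxes(
        torch.tensor([[8, 8, 24, 24]]),
        format=datapoints.BoundingBoxFormat.XYXY,
        canvas_size=(64, 64),
    )
    img_out = F.center_crop(img, output_size=[32, 32])
    box_out = F.center_crop(boxes, output_size=[32, 32])
    print(tuple(img_out.shape[-2:]), box_out.canvas_size)  # (32, 32) (32, 32)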
@@ -1752,9 +2007,6 @@ def elastic(
         )


-elastic_transform = elastic


 def _center_crop_parse_output_size(output_size: List[int]) -> List[int]:
     if isinstance(output_size, numbers.Number):
         s = int(output_size)
...
@@ -1782,6 +2034,7 @@ def _center_crop_compute_crop_anchor(
     return crop_top, crop_left


+@_register_kernel_internal(center_crop, datapoints.Image)
 def center_crop_image_tensor(image: torch.Tensor, output_size: List[int]) -> torch.Tensor:
     crop_height, crop_width = _center_crop_parse_output_size(output_size)
     shape = image.shape
...
@@ -1831,6 +2084,17 @@ def center_crop_bounding_boxes(
     )


+@_register_kernel_internal(center_crop, datapoints.BoundingBoxes, datapoint_wrapper=False)
+def _center_crop_bounding_boxes_dispatch(
+    inpt: datapoints.BoundingBoxes, output_size: List[int]
+) -> datapoints.BoundingBoxes:
+    output, canvas_size = center_crop_bounding_boxes(
+        inpt.as_subclass(torch.Tensor), format=inpt.format, canvas_size=inpt.canvas_size, output_size=output_size
+    )
+    return datapoints.BoundingBoxes.wrap_like(inpt, output, canvas_size=canvas_size)
)
@
_register_kernel_internal
(
center_crop
,
datapoints
.
Mask
)
def
center_crop_mask
(
mask
:
torch
.
Tensor
,
output_size
:
List
[
int
])
->
torch
.
Tensor
:
if
mask
.
ndim
<
3
:
mask
=
mask
.
unsqueeze
(
0
)
...
...
@@ -1846,20 +2110,33 @@ def center_crop_mask(mask: torch.Tensor, output_size: List[int]) -> torch.Tensor
     return output


+@_register_kernel_internal(center_crop, datapoints.Video)
 def center_crop_video(video: torch.Tensor, output_size: List[int]) -> torch.Tensor:
     return center_crop_image_tensor(video, output_size)


-def center_crop(inpt: datapoints._InputTypeJIT, output_size: List[int]) -> datapoints._InputTypeJIT:
+def resized_crop(
+    inpt: datapoints._InputTypeJIT,
+    top: int,
+    left: int,
+    height: int,
+    width: int,
+    size: List[int],
+    interpolation: Union[InterpolationMode, int] = InterpolationMode.BILINEAR,
+    antialias: Optional[Union[str, bool]] = "warn",
+) -> datapoints._InputTypeJIT:
     if not torch.jit.is_scripting():
-        _log_api_usage_once(center_crop)
+        _log_api_usage_once(resized_crop)
     if torch.jit.is_scripting() or is_simple_tensor(inpt):
-        return center_crop_image_tensor(inpt, output_size)
-    elif isinstance(inpt, datapoints._datapoint.Datapoint):
-        return inpt.center_crop(output_size)
+        return resized_crop_image_tensor(
+            inpt, top, left, height, width, antialias=antialias, size=size, interpolation=interpolation
+        )
+    elif isinstance(inpt, datapoints.Datapoint):
+        kernel = _get_kernel(resized_crop, type(inpt))
+        return kernel(inpt, top, left, height, width, antialias=antialias, size=size, interpolation=interpolation)
     elif isinstance(inpt, PIL.Image.Image):
-        return center_crop_image_pil(inpt, output_size)
+        return resized_crop_image_pil(inpt, top, left, height, width, size=size, interpolation=interpolation)
     else:
         raise TypeError(
             f"Input can either be a plain tensor, any TorchVision datapoint, or a PIL image, "
...
@@ -1867,6 +2144,7 @@ def center_crop(inpt: datapoints._InputTypeJIT, output_size: List[int]) -> datap
         )


+@_register_kernel_internal(resized_crop, datapoints.Image)
 def resized_crop_image_tensor(
     image: torch.Tensor,
     top: int,
...
@@ -1904,8 +2182,18 @@ def resized_crop_bounding_boxes(
     width: int,
     size: List[int],
 ) -> Tuple[torch.Tensor, Tuple[int, int]]:
-    bounding_boxes, _ = crop_bounding_boxes(bounding_boxes, format, top, left, height, width)
-    return resize_bounding_boxes(bounding_boxes, canvas_size=(height, width), size=size)
+    bounding_boxes, canvas_size = crop_bounding_boxes(bounding_boxes, format, top, left, height, width)
+    return resize_bounding_boxes(bounding_boxes, canvas_size=canvas_size, size=size)


+@_register_kernel_internal(resized_crop, datapoints.BoundingBoxes, datapoint_wrapper=False)
+def _resized_crop_bounding_boxes_dispatch(
+    inpt: datapoints.BoundingBoxes, top: int, left: int, height: int, width: int, size: List[int], **kwargs
+) -> datapoints.BoundingBoxes:
+    output, canvas_size = resized_crop_bounding_boxes(
+        inpt.as_subclass(torch.Tensor), format=inpt.format, top=top, left=left, height=height, width=width, size=size
+    )
+    return datapoints.BoundingBoxes.wrap_like(inpt, output, canvas_size=canvas_size)
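The two-line fix above stops hand-building the intermediate canvas as `(height, width)` and instead threads through the canvas size that `crop_bounding_boxes` actually returned. Sketch of the public entry point, same assumptions as the earlier examples:

    import torch
    from torchvision import datapoints
    import torchvision.transforms.v2.functional as F

    boxes = datapoints.BoundingBoxes(
        torch.tensor([[4, 4, 12, 12]]),
        format=datapoints.BoundingBoxFormat.XYXY,
        canvas_size=(32, 32),
    )
    # Crop to 16x16, then resize back to 32x32.
    out = F.resized_crop(boxes, top=0, left=0, height=16, width=16, size=[32, 32])
    print(out.canvas_size)  # (32, 32)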
 def resized_crop_mask(
...
@@ -1920,6 +2208,17 @@ def resized_crop_mask(
     return resize_mask(mask, size)


+@_register_kernel_internal(resized_crop, datapoints.Mask, datapoint_wrapper=False)
+def _resized_crop_mask_dispatch(
+    inpt: datapoints.Mask, top: int, left: int, height: int, width: int, size: List[int], **kwargs
+) -> datapoints.Mask:
+    output = resized_crop_mask(
+        inpt.as_subclass(torch.Tensor), top=top, left=left, height=height, width=width, size=size
+    )
+    return datapoints.Mask.wrap_like(inpt, output)
+@_register_kernel_internal(resized_crop, datapoints.Video)
 def resized_crop_video(
     video: torch.Tensor,
     top: int,
...
@@ -1935,27 +2234,26 @@ def resized_crop_video(
     )


-def resized_crop(
-    inpt: datapoints._InputTypeJIT,
-    top: int,
-    left: int,
-    height: int,
-    width: int,
-    size: List[int],
-    interpolation: Union[InterpolationMode, int] = InterpolationMode.BILINEAR,
-    antialias: Optional[Union[str, bool]] = "warn",
-) -> datapoints._InputTypeJIT:
+@_register_explicit_noop(datapoints.BoundingBoxes, datapoints.Mask, warn_passthrough=True)
+def five_crop(
+    inpt: datapoints._InputTypeJIT, size: List[int]
+) -> Tuple[
+    datapoints._InputTypeJIT,
+    datapoints._InputTypeJIT,
+    datapoints._InputTypeJIT,
+    datapoints._InputTypeJIT,
+    datapoints._InputTypeJIT,
+]:
     if not torch.jit.is_scripting():
-        _log_api_usage_once(resized_crop)
+        _log_api_usage_once(five_crop)
     if torch.jit.is_scripting() or is_simple_tensor(inpt):
-        return resized_crop_image_tensor(
-            inpt, top, left, height, width, antialias=antialias, size=size, interpolation=interpolation
-        )
-    elif isinstance(inpt, datapoints._datapoint.Datapoint):
-        return inpt.resized_crop(top, left, height, width, antialias=antialias, size=size, interpolation=interpolation)
+        return five_crop_image_tensor(inpt, size)
+    elif isinstance(inpt, datapoints.Datapoint):
+        kernel = _get_kernel(five_crop, type(inpt))
+        return kernel(inpt, size)
     elif isinstance(inpt, PIL.Image.Image):
-        return resized_crop_image_pil(inpt, top, left, height, width, size=size, interpolation=interpolation)
+        return five_crop_image_pil(inpt, size)
     else:
         raise TypeError(
             f"Input can either be a plain tensor, any TorchVision datapoint, or a PIL image, "
...
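`five_crop` returns a tuple rather than a single output, so `BoundingBoxes` and `Mask` are registered as explicit no-ops with a pass-through warning instead of real kernels. Usage sketch:

    import torch
    from torchvision import datapoints
    import torchvision.transforms.v2.functional as F

    img = datapoints.Image(torch.rand(3, 64, 64))
    tl, tr, bl, br, center = F.five_crop(img, size=[32, 32])
    print(type(center).__name__, tuple(center.shape))  # Image (3, 32, 32)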
@@ -1977,6 +2275,7 @@ def _parse_five_crop_size(size: List[int]) -> List[int]:
     return size


+@_register_five_ten_crop_kernel(five_crop, datapoints.Image)
 def five_crop_image_tensor(
     image: torch.Tensor, size: List[int]
 ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
...
@@ -2014,38 +2313,46 @@ def five_crop_image_pil(
     return tl, tr, bl, br, center


+@_register_five_ten_crop_kernel(five_crop, datapoints.Video)
 def five_crop_video(
     video: torch.Tensor, size: List[int]
 ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
     return five_crop_image_tensor(video, size)


-ImageOrVideoTypeJIT = Union[datapoints._ImageTypeJIT, datapoints._VideoTypeJIT]
-
-
-def five_crop(
-    inpt: ImageOrVideoTypeJIT, size: List[int]
-) -> Tuple[ImageOrVideoTypeJIT, ImageOrVideoTypeJIT, ImageOrVideoTypeJIT, ImageOrVideoTypeJIT, ImageOrVideoTypeJIT]:
+@_register_explicit_noop(datapoints.BoundingBoxes, datapoints.Mask, warn_passthrough=True)
+def ten_crop(
+    inpt: Union[datapoints._ImageTypeJIT, datapoints._VideoTypeJIT], size: List[int], vertical_flip: bool = False
+) -> Tuple[
+    datapoints._InputTypeJIT,
+    datapoints._InputTypeJIT,
+    datapoints._InputTypeJIT,
+    datapoints._InputTypeJIT,
+    datapoints._InputTypeJIT,
+    datapoints._InputTypeJIT,
+    datapoints._InputTypeJIT,
+    datapoints._InputTypeJIT,
+    datapoints._InputTypeJIT,
+    datapoints._InputTypeJIT,
+]:
     if not torch.jit.is_scripting():
-        _log_api_usage_once(five_crop)
+        _log_api_usage_once(ten_crop)
     if torch.jit.is_scripting() or is_simple_tensor(inpt):
-        return five_crop_image_tensor(inpt, size)
-    elif isinstance(inpt, datapoints.Image):
-        output = five_crop_image_tensor(inpt.as_subclass(torch.Tensor), size)
-        return tuple(datapoints.Image.wrap_like(inpt, item) for item in output)  # type: ignore[return-value]
-    elif isinstance(inpt, datapoints.Video):
-        output = five_crop_video(inpt.as_subclass(torch.Tensor), size)
-        return tuple(datapoints.Video.wrap_like(inpt, item) for item in output)  # type: ignore[return-value]
+        return ten_crop_image_tensor(inpt, size, vertical_flip=vertical_flip)
+    elif isinstance(inpt, datapoints.Datapoint):
+        kernel = _get_kernel(ten_crop, type(inpt))
+        return kernel(inpt, size, vertical_flip=vertical_flip)
     elif isinstance(inpt, PIL.Image.Image):
-        return five_crop_image_pil(inpt, size)
+        return ten_crop_image_pil(inpt, size, vertical_flip=vertical_flip)
     else:
         raise TypeError(
-            f"Input can either be a plain tensor, an `Image` or `Video` datapoint, or a PIL image, "
+            f"Input can either be a plain tensor, any TorchVision datapoint, or a PIL image, "
             f"but got {type(inpt)} instead."
         )
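With the registry in place, the hand-written `ten_crop` dispatcher further down in this diff is deleted outright; the surviving entry point dispatches through the registry like everything else. Sketch:

    import torch
    from torchvision import datapoints
    import torchvision.transforms.v2.functional as F

    img = datapoints.Image(torch.rand(3, 64, 64))
    crops = F.ten_crop(img, size=[32, 32])
    print(len(crops))  # 10: five crops plus their horizontally flipped variants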
+@_register_five_ten_crop_kernel(ten_crop, datapoints.Image)
 def ten_crop_image_tensor(
     image: torch.Tensor, size: List[int], vertical_flip: bool = False
 ) -> Tuple[
...
@@ -2099,6 +2406,7 @@ def ten_crop_image_pil(
     return non_flipped + flipped


+@_register_five_ten_crop_kernel(ten_crop, datapoints.Video)
 def ten_crop_video(
     video: torch.Tensor, size: List[int], vertical_flip: bool = False
 ) -> Tuple[
...
@@ -2114,37 +2422,3 @@ def ten_crop_video(
     torch.Tensor,
 ]:
     return ten_crop_image_tensor(video, size, vertical_flip=vertical_flip)
-
-
-def ten_crop(
-    inpt: Union[datapoints._ImageTypeJIT, datapoints._VideoTypeJIT], size: List[int], vertical_flip: bool = False
-) -> Tuple[
-    ImageOrVideoTypeJIT,
-    ImageOrVideoTypeJIT,
-    ImageOrVideoTypeJIT,
-    ImageOrVideoTypeJIT,
-    ImageOrVideoTypeJIT,
-    ImageOrVideoTypeJIT,
-    ImageOrVideoTypeJIT,
-    ImageOrVideoTypeJIT,
-    ImageOrVideoTypeJIT,
-    ImageOrVideoTypeJIT,
-]:
-    if not torch.jit.is_scripting():
-        _log_api_usage_once(ten_crop)
-    if torch.jit.is_scripting() or is_simple_tensor(inpt):
-        return ten_crop_image_tensor(inpt, size, vertical_flip=vertical_flip)
-    elif isinstance(inpt, datapoints.Image):
-        output = ten_crop_image_tensor(inpt.as_subclass(torch.Tensor), size, vertical_flip=vertical_flip)
-        return tuple(datapoints.Image.wrap_like(inpt, item) for item in output)  # type: ignore[return-value]
-    elif isinstance(inpt, datapoints.Video):
-        output = ten_crop_video(inpt.as_subclass(torch.Tensor), size, vertical_flip=vertical_flip)
-        return tuple(datapoints.Video.wrap_like(inpt, item) for item in output)  # type: ignore[return-value]
-    elif isinstance(inpt, PIL.Image.Image):
-        return ten_crop_image_pil(inpt, size, vertical_flip=vertical_flip)
-    else:
-        raise TypeError(
-            f"Input can either be a plain tensor, an `Image` or `Video` datapoint, or a PIL image, "
-            f"but got {type(inpt)} instead."
-        )