Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
vision
Commits
f244e27e
Unverified
Commit
f244e27e
authored
Aug 15, 2023
by
Nicolas Hug
Committed by
GitHub
Aug 15, 2023
Browse files
Dispatcher -> Functional (#7829)
parent
6ab8a96f
Changes
7
Hide whitespace changes
Inline
Side-by-side
Showing
7 changed files
with
99 additions
and
99 deletions
+99
-99
gallery/plot_custom_datapoints.py
gallery/plot_custom_datapoints.py
+3
-3
test/test_transforms_v2_refactored.py
test/test_transforms_v2_refactored.py
+64
-64
torchvision/transforms/v2/_augment.py
torchvision/transforms/v2/_augment.py
+2
-2
torchvision/transforms/v2/_geometry.py
torchvision/transforms/v2/_geometry.py
+4
-4
torchvision/transforms/v2/_transform.py
torchvision/transforms/v2/_transform.py
+2
-2
torchvision/transforms/v2/functional/_meta.py
torchvision/transforms/v2/functional/_meta.py
+1
-1
torchvision/transforms/v2/functional/_utils.py
torchvision/transforms/v2/functional/_utils.py
+23
-23
No files found.
gallery/plot_custom_datapoints.py
View file @
f244e27e
...
@@ -49,7 +49,7 @@ my_dp
...
@@ -49,7 +49,7 @@ my_dp
from
torchvision.transforms.v2
import
functional
as
F
from
torchvision.transforms.v2
import
functional
as
F
@
F
.
register_kernel
(
dispatcher
=
"hflip"
,
datapoint_cls
=
MyDatapoint
)
@
F
.
register_kernel
(
functional
=
"hflip"
,
datapoint_cls
=
MyDatapoint
)
def
hflip_my_datapoint
(
my_dp
,
*
args
,
**
kwargs
):
def
hflip_my_datapoint
(
my_dp
,
*
args
,
**
kwargs
):
print
(
"Flipping!"
)
print
(
"Flipping!"
)
out
=
my_dp
.
flip
(
-
1
)
out
=
my_dp
.
flip
(
-
1
)
...
@@ -64,9 +64,9 @@ def hflip_my_datapoint(my_dp, *args, **kwargs):
...
@@ -64,9 +64,9 @@ def hflip_my_datapoint(my_dp, *args, **kwargs):
# .. note::
# .. note::
#
#
# In our call to ``register_kernel`` above we used a string
# In our call to ``register_kernel`` above we used a string
# ``
dispatcher
="hflip"`` to refer to the functional we want to hook into. We
# ``
functional
="hflip"`` to refer to the functional we want to hook into. We
# could also have used the functional *itself*, i.e.
# could also have used the functional *itself*, i.e.
# ``@register_kernel(
dispatcher
=F.hflip, ...)``.
# ``@register_kernel(
functional
=F.hflip, ...)``.
#
#
# The functionals that you can be hooked into are the ones in
# The functionals that you can be hooked into are the ones in
# ``torchvision.transforms.v2.functional`` and they are documented in
# ``torchvision.transforms.v2.functional`` and they are documented in
...
...
test/test_transforms_v2_refactored.py
View file @
f244e27e
...
@@ -163,25 +163,25 @@ def check_kernel(
...
@@ -163,25 +163,25 @@ def check_kernel(
_check_kernel_batched_vs_unbatched
(
kernel
,
input
,
*
args
,
**
kwargs
,
**
_to_tolerances
(
check_batched_vs_unbatched
))
_check_kernel_batched_vs_unbatched
(
kernel
,
input
,
*
args
,
**
kwargs
,
**
_to_tolerances
(
check_batched_vs_unbatched
))
def
_check_
dispatcher
_scripted_smoke
(
dispatcher
,
input
,
*
args
,
**
kwargs
):
def
_check_
functional
_scripted_smoke
(
functional
,
input
,
*
args
,
**
kwargs
):
"""Checks if the
dispatcher
can be scripted and the scripted version can be called without error."""
"""Checks if the
functional
can be scripted and the scripted version can be called without error."""
if
not
isinstance
(
input
,
datapoints
.
Image
):
if
not
isinstance
(
input
,
datapoints
.
Image
):
return
return
dispatcher
_scripted
=
_script
(
dispatcher
)
functional
_scripted
=
_script
(
functional
)
with
ignore_jit_no_profile_information_warning
():
with
ignore_jit_no_profile_information_warning
():
dispatcher
_scripted
(
input
.
as_subclass
(
torch
.
Tensor
),
*
args
,
**
kwargs
)
functional
_scripted
(
input
.
as_subclass
(
torch
.
Tensor
),
*
args
,
**
kwargs
)
def
check_
dispatcher
(
dispatcher
,
input
,
*
args
,
check_scripted_smoke
=
True
,
**
kwargs
):
def
check_
functional
(
functional
,
input
,
*
args
,
check_scripted_smoke
=
True
,
**
kwargs
):
unknown_input
=
object
()
unknown_input
=
object
()
with
pytest
.
raises
(
TypeError
,
match
=
re
.
escape
(
str
(
type
(
unknown_input
)))):
with
pytest
.
raises
(
TypeError
,
match
=
re
.
escape
(
str
(
type
(
unknown_input
)))):
dispatcher
(
unknown_input
,
*
args
,
**
kwargs
)
functional
(
unknown_input
,
*
args
,
**
kwargs
)
with
mock
.
patch
(
"torch._C._log_api_usage_once"
,
wraps
=
torch
.
_C
.
_log_api_usage_once
)
as
spy
:
with
mock
.
patch
(
"torch._C._log_api_usage_once"
,
wraps
=
torch
.
_C
.
_log_api_usage_once
)
as
spy
:
output
=
dispatcher
(
input
,
*
args
,
**
kwargs
)
output
=
functional
(
input
,
*
args
,
**
kwargs
)
spy
.
assert_any_call
(
f
"
{
dispatcher
.
__module__
}
.
{
dispatcher
.
__name__
}
"
)
spy
.
assert_any_call
(
f
"
{
functional
.
__module__
}
.
{
functional
.
__name__
}
"
)
assert
isinstance
(
output
,
type
(
input
))
assert
isinstance
(
output
,
type
(
input
))
...
@@ -189,41 +189,41 @@ def check_dispatcher(dispatcher, input, *args, check_scripted_smoke=True, **kwar
...
@@ -189,41 +189,41 @@ def check_dispatcher(dispatcher, input, *args, check_scripted_smoke=True, **kwar
assert
output
.
format
==
input
.
format
assert
output
.
format
==
input
.
format
if
check_scripted_smoke
:
if
check_scripted_smoke
:
_check_
dispatcher
_scripted_smoke
(
dispatcher
,
input
,
*
args
,
**
kwargs
)
_check_
functional
_scripted_smoke
(
functional
,
input
,
*
args
,
**
kwargs
)
def
check_
dispatcher
_kernel_signature_match
(
dispatcher
,
*
,
kernel
,
input_type
):
def
check_
functional
_kernel_signature_match
(
functional
,
*
,
kernel
,
input_type
):
"""Checks if the signature of the
dispatcher
matches the kernel signature."""
"""Checks if the signature of the
functional
matches the kernel signature."""
dispatcher
_params
=
list
(
inspect
.
signature
(
dispatcher
).
parameters
.
values
())[
1
:]
functional
_params
=
list
(
inspect
.
signature
(
functional
).
parameters
.
values
())[
1
:]
kernel_params
=
list
(
inspect
.
signature
(
kernel
).
parameters
.
values
())[
1
:]
kernel_params
=
list
(
inspect
.
signature
(
kernel
).
parameters
.
values
())[
1
:]
if
issubclass
(
input_type
,
datapoints
.
Datapoint
):
if
issubclass
(
input_type
,
datapoints
.
Datapoint
):
# We filter out metadata that is implicitly passed to the
dispatcher
through the input datapoint, but has to be
# We filter out metadata that is implicitly passed to the
functional
through the input datapoint, but has to be
# explicitly passed to the kernel.
# explicitly passed to the kernel.
explicit_metadata
=
{
explicit_metadata
=
{
datapoints
.
BoundingBoxes
:
{
"format"
,
"canvas_size"
},
datapoints
.
BoundingBoxes
:
{
"format"
,
"canvas_size"
},
}
}
kernel_params
=
[
param
for
param
in
kernel_params
if
param
.
name
not
in
explicit_metadata
.
get
(
input_type
,
set
())]
kernel_params
=
[
param
for
param
in
kernel_params
if
param
.
name
not
in
explicit_metadata
.
get
(
input_type
,
set
())]
dispatcher
_params
=
iter
(
dispatcher
_params
)
functional
_params
=
iter
(
functional
_params
)
for
dispatcher
_param
,
kernel_param
in
zip
(
dispatcher
_params
,
kernel_params
):
for
functional
_param
,
kernel_param
in
zip
(
functional
_params
,
kernel_params
):
try
:
try
:
# In general, the
dispatcher
parameters are a superset of the kernel parameters. Thus, we filter out
# In general, the
functional
parameters are a superset of the kernel parameters. Thus, we filter out
#
dispatcher
parameters that have no kernel equivalent while keeping the order intact.
#
functional
parameters that have no kernel equivalent while keeping the order intact.
while
dispatcher
_param
.
name
!=
kernel_param
.
name
:
while
functional
_param
.
name
!=
kernel_param
.
name
:
dispatcher
_param
=
next
(
dispatcher
_params
)
functional
_param
=
next
(
functional
_params
)
except
StopIteration
:
except
StopIteration
:
raise
AssertionError
(
raise
AssertionError
(
f
"Parameter `
{
kernel_param
.
name
}
` of kernel `
{
kernel
.
__name__
}
` "
f
"Parameter `
{
kernel_param
.
name
}
` of kernel `
{
kernel
.
__name__
}
` "
f
"has no corresponding parameter on the
dispatcher `
{
dispatcher
.
__name__
}
`."
f
"has no corresponding parameter on the
functional `
{
functional
.
__name__
}
`."
)
from
None
)
from
None
if
issubclass
(
input_type
,
PIL
.
Image
.
Image
):
if
issubclass
(
input_type
,
PIL
.
Image
.
Image
):
# PIL kernels often have more correct annotations, since they are not limited by JIT. Thus, we don't check
# PIL kernels often have more correct annotations, since they are not limited by JIT. Thus, we don't check
# them in the first place.
# them in the first place.
dispatcher
_param
.
_annotation
=
kernel_param
.
_annotation
=
inspect
.
Parameter
.
empty
functional
_param
.
_annotation
=
kernel_param
.
_annotation
=
inspect
.
Parameter
.
empty
assert
dispatcher
_param
==
kernel_param
assert
functional
_param
==
kernel_param
def
_check_transform_v1_compatibility
(
transform
,
input
):
def
_check_transform_v1_compatibility
(
transform
,
input
):
...
@@ -482,8 +482,8 @@ class TestResize:
...
@@ -482,8 +482,8 @@ class TestResize:
"make_input"
,
"make_input"
,
[
make_image_tensor
,
make_image_pil
,
make_image
,
make_bounding_box
,
make_segmentation_mask
,
make_video
],
[
make_image_tensor
,
make_image_pil
,
make_image
,
make_bounding_box
,
make_segmentation_mask
,
make_video
],
)
)
def
test_
dispatcher
(
self
,
size
,
make_input
):
def
test_
functional
(
self
,
size
,
make_input
):
check_
dispatcher
(
check_
functional
(
F
.
resize
,
F
.
resize
,
make_input
(
self
.
INPUT_SIZE
),
make_input
(
self
.
INPUT_SIZE
),
size
=
size
,
size
=
size
,
...
@@ -502,8 +502,8 @@ class TestResize:
...
@@ -502,8 +502,8 @@ class TestResize:
(
F
.
resize_video
,
datapoints
.
Video
),
(
F
.
resize_video
,
datapoints
.
Video
),
],
],
)
)
def
test_
dispatcher
_signature
(
self
,
kernel
,
input_type
):
def
test_
functional
_signature
(
self
,
kernel
,
input_type
):
check_
dispatcher
_kernel_signature_match
(
F
.
resize
,
kernel
=
kernel
,
input_type
=
input_type
)
check_
functional
_kernel_signature_match
(
F
.
resize
,
kernel
=
kernel
,
input_type
=
input_type
)
@
pytest
.
mark
.
parametrize
(
"size"
,
OUTPUT_SIZES
)
@
pytest
.
mark
.
parametrize
(
"size"
,
OUTPUT_SIZES
)
@
pytest
.
mark
.
parametrize
(
"device"
,
cpu_and_cuda
())
@
pytest
.
mark
.
parametrize
(
"device"
,
cpu_and_cuda
())
...
@@ -608,7 +608,7 @@ class TestResize:
...
@@ -608,7 +608,7 @@ class TestResize:
interpolation
=
interpolation
,
interpolation
=
interpolation
,
)
)
def
test_
dispatcher
_pil_antialias_warning
(
self
):
def
test_
functional
_pil_antialias_warning
(
self
):
with
pytest
.
warns
(
UserWarning
,
match
=
"Anti-alias option is always applied for PIL Image input"
):
with
pytest
.
warns
(
UserWarning
,
match
=
"Anti-alias option is always applied for PIL Image input"
):
F
.
resize
(
make_image_pil
(
self
.
INPUT_SIZE
),
size
=
self
.
OUTPUT_SIZES
[
0
],
antialias
=
False
)
F
.
resize
(
make_image_pil
(
self
.
INPUT_SIZE
),
size
=
self
.
OUTPUT_SIZES
[
0
],
antialias
=
False
)
...
@@ -763,8 +763,8 @@ class TestHorizontalFlip:
...
@@ -763,8 +763,8 @@ class TestHorizontalFlip:
"make_input"
,
"make_input"
,
[
make_image_tensor
,
make_image_pil
,
make_image
,
make_bounding_box
,
make_segmentation_mask
,
make_video
],
[
make_image_tensor
,
make_image_pil
,
make_image
,
make_bounding_box
,
make_segmentation_mask
,
make_video
],
)
)
def
test_
dispatcher
(
self
,
make_input
):
def
test_
functional
(
self
,
make_input
):
check_
dispatcher
(
F
.
horizontal_flip
,
make_input
())
check_
functional
(
F
.
horizontal_flip
,
make_input
())
@
pytest
.
mark
.
parametrize
(
@
pytest
.
mark
.
parametrize
(
(
"kernel"
,
"input_type"
),
(
"kernel"
,
"input_type"
),
...
@@ -777,8 +777,8 @@ class TestHorizontalFlip:
...
@@ -777,8 +777,8 @@ class TestHorizontalFlip:
(
F
.
horizontal_flip_video
,
datapoints
.
Video
),
(
F
.
horizontal_flip_video
,
datapoints
.
Video
),
],
],
)
)
def
test_
dispatcher
_signature
(
self
,
kernel
,
input_type
):
def
test_
functional
_signature
(
self
,
kernel
,
input_type
):
check_
dispatcher
_kernel_signature_match
(
F
.
horizontal_flip
,
kernel
=
kernel
,
input_type
=
input_type
)
check_
functional
_kernel_signature_match
(
F
.
horizontal_flip
,
kernel
=
kernel
,
input_type
=
input_type
)
@
pytest
.
mark
.
parametrize
(
@
pytest
.
mark
.
parametrize
(
"make_input"
,
"make_input"
,
...
@@ -939,8 +939,8 @@ class TestAffine:
...
@@ -939,8 +939,8 @@ class TestAffine:
"make_input"
,
"make_input"
,
[
make_image_tensor
,
make_image_pil
,
make_image
,
make_bounding_box
,
make_segmentation_mask
,
make_video
],
[
make_image_tensor
,
make_image_pil
,
make_image
,
make_bounding_box
,
make_segmentation_mask
,
make_video
],
)
)
def
test_
dispatcher
(
self
,
make_input
):
def
test_
functional
(
self
,
make_input
):
check_
dispatcher
(
F
.
affine
,
make_input
(),
**
self
.
_MINIMAL_AFFINE_KWARGS
)
check_
functional
(
F
.
affine
,
make_input
(),
**
self
.
_MINIMAL_AFFINE_KWARGS
)
@
pytest
.
mark
.
parametrize
(
@
pytest
.
mark
.
parametrize
(
(
"kernel"
,
"input_type"
),
(
"kernel"
,
"input_type"
),
...
@@ -953,8 +953,8 @@ class TestAffine:
...
@@ -953,8 +953,8 @@ class TestAffine:
(
F
.
affine_video
,
datapoints
.
Video
),
(
F
.
affine_video
,
datapoints
.
Video
),
],
],
)
)
def
test_
dispatcher
_signature
(
self
,
kernel
,
input_type
):
def
test_
functional
_signature
(
self
,
kernel
,
input_type
):
check_
dispatcher
_kernel_signature_match
(
F
.
affine
,
kernel
=
kernel
,
input_type
=
input_type
)
check_
functional
_kernel_signature_match
(
F
.
affine
,
kernel
=
kernel
,
input_type
=
input_type
)
@
pytest
.
mark
.
parametrize
(
@
pytest
.
mark
.
parametrize
(
"make_input"
,
"make_input"
,
...
@@ -1228,8 +1228,8 @@ class TestVerticalFlip:
...
@@ -1228,8 +1228,8 @@ class TestVerticalFlip:
"make_input"
,
"make_input"
,
[
make_image_tensor
,
make_image_pil
,
make_image
,
make_bounding_box
,
make_segmentation_mask
,
make_video
],
[
make_image_tensor
,
make_image_pil
,
make_image
,
make_bounding_box
,
make_segmentation_mask
,
make_video
],
)
)
def
test_
dispatcher
(
self
,
make_input
):
def
test_
functional
(
self
,
make_input
):
check_
dispatcher
(
F
.
vertical_flip
,
make_input
())
check_
functional
(
F
.
vertical_flip
,
make_input
())
@
pytest
.
mark
.
parametrize
(
@
pytest
.
mark
.
parametrize
(
(
"kernel"
,
"input_type"
),
(
"kernel"
,
"input_type"
),
...
@@ -1242,8 +1242,8 @@ class TestVerticalFlip:
...
@@ -1242,8 +1242,8 @@ class TestVerticalFlip:
(
F
.
vertical_flip_video
,
datapoints
.
Video
),
(
F
.
vertical_flip_video
,
datapoints
.
Video
),
],
],
)
)
def
test_
dispatcher
_signature
(
self
,
kernel
,
input_type
):
def
test_
functional
_signature
(
self
,
kernel
,
input_type
):
check_
dispatcher
_kernel_signature_match
(
F
.
vertical_flip
,
kernel
=
kernel
,
input_type
=
input_type
)
check_
functional
_kernel_signature_match
(
F
.
vertical_flip
,
kernel
=
kernel
,
input_type
=
input_type
)
@
pytest
.
mark
.
parametrize
(
@
pytest
.
mark
.
parametrize
(
"make_input"
,
"make_input"
,
...
@@ -1378,8 +1378,8 @@ class TestRotate:
...
@@ -1378,8 +1378,8 @@ class TestRotate:
"make_input"
,
"make_input"
,
[
make_image_tensor
,
make_image_pil
,
make_image
,
make_bounding_box
,
make_segmentation_mask
,
make_video
],
[
make_image_tensor
,
make_image_pil
,
make_image
,
make_bounding_box
,
make_segmentation_mask
,
make_video
],
)
)
def
test_
dispatcher
(
self
,
make_input
):
def
test_
functional
(
self
,
make_input
):
check_
dispatcher
(
F
.
rotate
,
make_input
(),
**
self
.
_MINIMAL_AFFINE_KWARGS
)
check_
functional
(
F
.
rotate
,
make_input
(),
**
self
.
_MINIMAL_AFFINE_KWARGS
)
@
pytest
.
mark
.
parametrize
(
@
pytest
.
mark
.
parametrize
(
(
"kernel"
,
"input_type"
),
(
"kernel"
,
"input_type"
),
...
@@ -1392,8 +1392,8 @@ class TestRotate:
...
@@ -1392,8 +1392,8 @@ class TestRotate:
(
F
.
rotate_video
,
datapoints
.
Video
),
(
F
.
rotate_video
,
datapoints
.
Video
),
],
],
)
)
def
test_
dispatcher
_signature
(
self
,
kernel
,
input_type
):
def
test_
functional
_signature
(
self
,
kernel
,
input_type
):
check_
dispatcher
_kernel_signature_match
(
F
.
rotate
,
kernel
=
kernel
,
input_type
=
input_type
)
check_
functional
_kernel_signature_match
(
F
.
rotate
,
kernel
=
kernel
,
input_type
=
input_type
)
@
pytest
.
mark
.
parametrize
(
@
pytest
.
mark
.
parametrize
(
"make_input"
,
"make_input"
,
...
@@ -1643,8 +1643,8 @@ class TestToDtype:
...
@@ -1643,8 +1643,8 @@ class TestToDtype:
@
pytest
.
mark
.
parametrize
(
"output_dtype"
,
[
torch
.
float32
,
torch
.
float64
,
torch
.
uint8
])
@
pytest
.
mark
.
parametrize
(
"output_dtype"
,
[
torch
.
float32
,
torch
.
float64
,
torch
.
uint8
])
@
pytest
.
mark
.
parametrize
(
"device"
,
cpu_and_cuda
())
@
pytest
.
mark
.
parametrize
(
"device"
,
cpu_and_cuda
())
@
pytest
.
mark
.
parametrize
(
"scale"
,
(
True
,
False
))
@
pytest
.
mark
.
parametrize
(
"scale"
,
(
True
,
False
))
def
test_
dispatcher
(
self
,
make_input
,
input_dtype
,
output_dtype
,
device
,
scale
):
def
test_
functional
(
self
,
make_input
,
input_dtype
,
output_dtype
,
device
,
scale
):
check_
dispatcher
(
check_
functional
(
F
.
to_dtype
,
F
.
to_dtype
,
make_input
(
dtype
=
input_dtype
,
device
=
device
),
make_input
(
dtype
=
input_dtype
,
device
=
device
),
dtype
=
output_dtype
,
dtype
=
output_dtype
,
...
@@ -1810,8 +1810,8 @@ class TestAdjustBrightness:
...
@@ -1810,8 +1810,8 @@ class TestAdjustBrightness:
check_kernel
(
kernel
,
make_input
(
dtype
=
dtype
,
device
=
device
),
brightness_factor
=
self
.
_DEFAULT_BRIGHTNESS_FACTOR
)
check_kernel
(
kernel
,
make_input
(
dtype
=
dtype
,
device
=
device
),
brightness_factor
=
self
.
_DEFAULT_BRIGHTNESS_FACTOR
)
@
pytest
.
mark
.
parametrize
(
"make_input"
,
[
make_image_tensor
,
make_image_pil
,
make_image
,
make_video
])
@
pytest
.
mark
.
parametrize
(
"make_input"
,
[
make_image_tensor
,
make_image_pil
,
make_image
,
make_video
])
def
test_
dispatcher
(
self
,
make_input
):
def
test_
functional
(
self
,
make_input
):
check_
dispatcher
(
F
.
adjust_brightness
,
make_input
(),
brightness_factor
=
self
.
_DEFAULT_BRIGHTNESS_FACTOR
)
check_
functional
(
F
.
adjust_brightness
,
make_input
(),
brightness_factor
=
self
.
_DEFAULT_BRIGHTNESS_FACTOR
)
@
pytest
.
mark
.
parametrize
(
@
pytest
.
mark
.
parametrize
(
(
"kernel"
,
"input_type"
),
(
"kernel"
,
"input_type"
),
...
@@ -1822,8 +1822,8 @@ class TestAdjustBrightness:
...
@@ -1822,8 +1822,8 @@ class TestAdjustBrightness:
(
F
.
adjust_brightness_video
,
datapoints
.
Video
),
(
F
.
adjust_brightness_video
,
datapoints
.
Video
),
],
],
)
)
def
test_
dispatcher
_signature
(
self
,
kernel
,
input_type
):
def
test_
functional
_signature
(
self
,
kernel
,
input_type
):
check_
dispatcher
_kernel_signature_match
(
F
.
adjust_brightness
,
kernel
=
kernel
,
input_type
=
input_type
)
check_
functional
_kernel_signature_match
(
F
.
adjust_brightness
,
kernel
=
kernel
,
input_type
=
input_type
)
@
pytest
.
mark
.
parametrize
(
"brightness_factor"
,
_CORRECTNESS_BRIGHTNESS_FACTORS
)
@
pytest
.
mark
.
parametrize
(
"brightness_factor"
,
_CORRECTNESS_BRIGHTNESS_FACTORS
)
def
test_image_correctness
(
self
,
brightness_factor
):
def
test_image_correctness
(
self
,
brightness_factor
):
...
@@ -2042,7 +2042,7 @@ class TestShapeGetters:
...
@@ -2042,7 +2042,7 @@ class TestShapeGetters:
assert
kernel
(
input
)
==
F
.
get_num_frames
(
input
)
==
num_frames
assert
kernel
(
input
)
==
F
.
get_num_frames
(
input
)
==
num_frames
@
pytest
.
mark
.
parametrize
(
@
pytest
.
mark
.
parametrize
(
(
"
dispatcher
"
,
"make_input"
),
(
"
functional
"
,
"make_input"
),
[
[
(
F
.
get_dimensions
,
make_bounding_box
),
(
F
.
get_dimensions
,
make_bounding_box
),
(
F
.
get_dimensions
,
make_detection_mask
),
(
F
.
get_dimensions
,
make_detection_mask
),
...
@@ -2057,22 +2057,22 @@ class TestShapeGetters:
...
@@ -2057,22 +2057,22 @@ class TestShapeGetters:
(
F
.
get_num_frames
,
make_segmentation_mask
),
(
F
.
get_num_frames
,
make_segmentation_mask
),
],
],
)
)
def
test_unsupported_types
(
self
,
dispatcher
,
make_input
):
def
test_unsupported_types
(
self
,
functional
,
make_input
):
input
=
make_input
()
input
=
make_input
()
with
pytest
.
raises
(
TypeError
,
match
=
re
.
escape
(
str
(
type
(
input
)))):
with
pytest
.
raises
(
TypeError
,
match
=
re
.
escape
(
str
(
type
(
input
)))):
dispatcher
(
input
)
functional
(
input
)
class
TestRegisterKernel
:
class
TestRegisterKernel
:
@
pytest
.
mark
.
parametrize
(
"
dispatcher
"
,
(
F
.
resize
,
"resize"
))
@
pytest
.
mark
.
parametrize
(
"
functional
"
,
(
F
.
resize
,
"resize"
))
def
test_register_kernel
(
self
,
dispatcher
):
def
test_register_kernel
(
self
,
functional
):
class
CustomDatapoint
(
datapoints
.
Datapoint
):
class
CustomDatapoint
(
datapoints
.
Datapoint
):
pass
pass
kernel_was_called
=
False
kernel_was_called
=
False
@
F
.
register_kernel
(
dispatcher
,
CustomDatapoint
)
@
F
.
register_kernel
(
functional
,
CustomDatapoint
)
def
new_resize
(
dp
,
*
args
,
**
kwargs
):
def
new_resize
(
dp
,
*
args
,
**
kwargs
):
nonlocal
kernel_was_called
nonlocal
kernel_was_called
kernel_was_called
=
True
kernel_was_called
=
True
...
@@ -2090,10 +2090,10 @@ class TestRegisterKernel:
...
@@ -2090,10 +2090,10 @@ class TestRegisterKernel:
t
(
datapoints
.
Image
(
torch
.
rand
(
3
,
10
,
10
))).
shape
==
(
3
,
224
,
224
)
t
(
datapoints
.
Image
(
torch
.
rand
(
3
,
10
,
10
))).
shape
==
(
3
,
224
,
224
)
def
test_errors
(
self
):
def
test_errors
(
self
):
with
pytest
.
raises
(
ValueError
,
match
=
"Could not find
dispatcher
with name"
):
with
pytest
.
raises
(
ValueError
,
match
=
"Could not find
functional
with name"
):
F
.
register_kernel
(
"bad_name"
,
datapoints
.
Image
)
F
.
register_kernel
(
"bad_name"
,
datapoints
.
Image
)
with
pytest
.
raises
(
ValueError
,
match
=
"Kernels can only be registered on
dispatcher
s"
):
with
pytest
.
raises
(
ValueError
,
match
=
"Kernels can only be registered on
functional
s"
):
F
.
register_kernel
(
datapoints
.
Image
,
F
.
resize
)
F
.
register_kernel
(
datapoints
.
Image
,
F
.
resize
)
with
pytest
.
raises
(
ValueError
,
match
=
"Kernels can only be registered for subclasses"
):
with
pytest
.
raises
(
ValueError
,
match
=
"Kernels can only be registered for subclasses"
):
...
@@ -2115,7 +2115,7 @@ class TestRegisterKernel:
...
@@ -2115,7 +2115,7 @@ class TestRegisterKernel:
class
TestGetKernel
:
class
TestGetKernel
:
# We are using F.resize as
dispatcher
and the kernels below as proxy. Any other
dispatcher
/ kernels combination
# We are using F.resize as
functional
and the kernels below as proxy. Any other
functional
/ kernels combination
# would also be fine
# would also be fine
KERNELS
=
{
KERNELS
=
{
torch
.
Tensor
:
F
.
resize_image_tensor
,
torch
.
Tensor
:
F
.
resize_image_tensor
,
...
@@ -2139,7 +2139,7 @@ class TestGetKernel:
...
@@ -2139,7 +2139,7 @@ class TestGetKernel:
def
test_exact_match
(
self
):
def
test_exact_match
(
self
):
# We cannot use F.resize together with self.KERNELS mapping here directly here, since this is only the
# We cannot use F.resize together with self.KERNELS mapping here directly here, since this is only the
# ideal wrapping. Practically, we have an intermediate wrapper layer. Thus, we create a new resize
dispatcher
# ideal wrapping. Practically, we have an intermediate wrapper layer. Thus, we create a new resize
functional
# here, register the kernels without wrapper, and check the exact matching afterwards.
# here, register the kernels without wrapper, and check the exact matching afterwards.
def
resize_with_pure_kernels
():
def
resize_with_pure_kernels
():
pass
pass
...
@@ -2151,7 +2151,7 @@ class TestGetKernel:
...
@@ -2151,7 +2151,7 @@ class TestGetKernel:
def
test_builtin_datapoint_subclass
(
self
):
def
test_builtin_datapoint_subclass
(
self
):
# We cannot use F.resize together with self.KERNELS mapping here directly here, since this is only the
# We cannot use F.resize together with self.KERNELS mapping here directly here, since this is only the
# ideal wrapping. Practically, we have an intermediate wrapper layer. Thus, we create a new resize
dispatcher
# ideal wrapping. Practically, we have an intermediate wrapper layer. Thus, we create a new resize
functional
# here, register the kernels without wrapper, and check if subclasses of our builtin datapoints get dispatched
# here, register the kernels without wrapper, and check if subclasses of our builtin datapoints get dispatched
# to the kernel of the corresponding superclass
# to the kernel of the corresponding superclass
def
resize_with_pure_kernels
():
def
resize_with_pure_kernels
():
...
@@ -2217,8 +2217,8 @@ class TestPermuteChannels:
...
@@ -2217,8 +2217,8 @@ class TestPermuteChannels:
check_kernel
(
kernel
,
make_input
(
dtype
=
dtype
,
device
=
device
),
permutation
=
self
.
_DEFAULT_PERMUTATION
)
check_kernel
(
kernel
,
make_input
(
dtype
=
dtype
,
device
=
device
),
permutation
=
self
.
_DEFAULT_PERMUTATION
)
@
pytest
.
mark
.
parametrize
(
"make_input"
,
[
make_image_tensor
,
make_image_pil
,
make_image
,
make_video
])
@
pytest
.
mark
.
parametrize
(
"make_input"
,
[
make_image_tensor
,
make_image_pil
,
make_image
,
make_video
])
def
test_
dispatcher
(
self
,
make_input
):
def
test_
functional
(
self
,
make_input
):
check_
dispatcher
(
F
.
permute_channels
,
make_input
(),
permutation
=
self
.
_DEFAULT_PERMUTATION
)
check_
functional
(
F
.
permute_channels
,
make_input
(),
permutation
=
self
.
_DEFAULT_PERMUTATION
)
@
pytest
.
mark
.
parametrize
(
@
pytest
.
mark
.
parametrize
(
(
"kernel"
,
"input_type"
),
(
"kernel"
,
"input_type"
),
...
@@ -2229,8 +2229,8 @@ class TestPermuteChannels:
...
@@ -2229,8 +2229,8 @@ class TestPermuteChannels:
(
F
.
permute_channels_video
,
datapoints
.
Video
),
(
F
.
permute_channels_video
,
datapoints
.
Video
),
],
],
)
)
def
test_
dispatcher
_signature
(
self
,
kernel
,
input_type
):
def
test_
functional
_signature
(
self
,
kernel
,
input_type
):
check_
dispatcher
_kernel_signature_match
(
F
.
permute_channels
,
kernel
=
kernel
,
input_type
=
input_type
)
check_
functional
_kernel_signature_match
(
F
.
permute_channels
,
kernel
=
kernel
,
input_type
=
input_type
)
def
reference_image_correctness
(
self
,
image
,
permutation
):
def
reference_image_correctness
(
self
,
image
,
permutation
):
channel_images
=
image
.
split
(
1
,
dim
=-
3
)
channel_images
=
image
.
split
(
1
,
dim
=-
3
)
...
...
torchvision/transforms/v2/_augment.py
View file @
f244e27e
...
@@ -91,13 +91,13 @@ class RandomErasing(_RandomApplyTransform):
...
@@ -91,13 +91,13 @@ class RandomErasing(_RandomApplyTransform):
self
.
_log_ratio
=
torch
.
log
(
torch
.
tensor
(
self
.
ratio
))
self
.
_log_ratio
=
torch
.
log
(
torch
.
tensor
(
self
.
ratio
))
def
_call_kernel
(
self
,
dispatcher
:
Callable
,
inpt
:
Any
,
*
args
:
Any
,
**
kwargs
:
Any
)
->
Any
:
def
_call_kernel
(
self
,
functional
:
Callable
,
inpt
:
Any
,
*
args
:
Any
,
**
kwargs
:
Any
)
->
Any
:
if
isinstance
(
inpt
,
(
datapoints
.
BoundingBoxes
,
datapoints
.
Mask
)):
if
isinstance
(
inpt
,
(
datapoints
.
BoundingBoxes
,
datapoints
.
Mask
)):
warnings
.
warn
(
warnings
.
warn
(
f
"
{
type
(
self
).
__name__
}
() is currently passing through inputs of type "
f
"
{
type
(
self
).
__name__
}
() is currently passing through inputs of type "
f
"datapoints.
{
type
(
inpt
).
__name__
}
. This will likely change in the future."
f
"datapoints.
{
type
(
inpt
).
__name__
}
. This will likely change in the future."
)
)
return
super
().
_call_kernel
(
dispatcher
,
inpt
,
*
args
,
**
kwargs
)
return
super
().
_call_kernel
(
functional
,
inpt
,
*
args
,
**
kwargs
)
def
_get_params
(
self
,
flat_inputs
:
List
[
Any
])
->
Dict
[
str
,
Any
]:
def
_get_params
(
self
,
flat_inputs
:
List
[
Any
])
->
Dict
[
str
,
Any
]:
img_c
,
img_h
,
img_w
=
query_chw
(
flat_inputs
)
img_c
,
img_h
,
img_w
=
query_chw
(
flat_inputs
)
...
...
torchvision/transforms/v2/_geometry.py
View file @
f244e27e
...
@@ -358,13 +358,13 @@ class FiveCrop(Transform):
...
@@ -358,13 +358,13 @@ class FiveCrop(Transform):
super
().
__init__
()
super
().
__init__
()
self
.
size
=
_setup_size
(
size
,
error_msg
=
"Please provide only two dimensions (h, w) for size."
)
self
.
size
=
_setup_size
(
size
,
error_msg
=
"Please provide only two dimensions (h, w) for size."
)
def
_call_kernel
(
self
,
dispatcher
:
Callable
,
inpt
:
Any
,
*
args
:
Any
,
**
kwargs
:
Any
)
->
Any
:
def
_call_kernel
(
self
,
functional
:
Callable
,
inpt
:
Any
,
*
args
:
Any
,
**
kwargs
:
Any
)
->
Any
:
if
isinstance
(
inpt
,
(
datapoints
.
BoundingBoxes
,
datapoints
.
Mask
)):
if
isinstance
(
inpt
,
(
datapoints
.
BoundingBoxes
,
datapoints
.
Mask
)):
warnings
.
warn
(
warnings
.
warn
(
f
"
{
type
(
self
).
__name__
}
() is currently passing through inputs of type "
f
"
{
type
(
self
).
__name__
}
() is currently passing through inputs of type "
f
"datapoints.
{
type
(
inpt
).
__name__
}
. This will likely change in the future."
f
"datapoints.
{
type
(
inpt
).
__name__
}
. This will likely change in the future."
)
)
return
super
().
_call_kernel
(
dispatcher
,
inpt
,
*
args
,
**
kwargs
)
return
super
().
_call_kernel
(
functional
,
inpt
,
*
args
,
**
kwargs
)
def
_transform
(
self
,
inpt
:
Any
,
params
:
Dict
[
str
,
Any
])
->
Any
:
def
_transform
(
self
,
inpt
:
Any
,
params
:
Dict
[
str
,
Any
])
->
Any
:
return
self
.
_call_kernel
(
F
.
five_crop
,
inpt
,
self
.
size
)
return
self
.
_call_kernel
(
F
.
five_crop
,
inpt
,
self
.
size
)
...
@@ -405,13 +405,13 @@ class TenCrop(Transform):
...
@@ -405,13 +405,13 @@ class TenCrop(Transform):
self
.
size
=
_setup_size
(
size
,
error_msg
=
"Please provide only two dimensions (h, w) for size."
)
self
.
size
=
_setup_size
(
size
,
error_msg
=
"Please provide only two dimensions (h, w) for size."
)
self
.
vertical_flip
=
vertical_flip
self
.
vertical_flip
=
vertical_flip
def
_call_kernel
(
self
,
dispatcher
:
Callable
,
inpt
:
Any
,
*
args
:
Any
,
**
kwargs
:
Any
)
->
Any
:
def
_call_kernel
(
self
,
functional
:
Callable
,
inpt
:
Any
,
*
args
:
Any
,
**
kwargs
:
Any
)
->
Any
:
if
isinstance
(
inpt
,
(
datapoints
.
BoundingBoxes
,
datapoints
.
Mask
)):
if
isinstance
(
inpt
,
(
datapoints
.
BoundingBoxes
,
datapoints
.
Mask
)):
warnings
.
warn
(
warnings
.
warn
(
f
"
{
type
(
self
).
__name__
}
() is currently passing through inputs of type "
f
"
{
type
(
self
).
__name__
}
() is currently passing through inputs of type "
f
"datapoints.
{
type
(
inpt
).
__name__
}
. This will likely change in the future."
f
"datapoints.
{
type
(
inpt
).
__name__
}
. This will likely change in the future."
)
)
return
super
().
_call_kernel
(
dispatcher
,
inpt
,
*
args
,
**
kwargs
)
return
super
().
_call_kernel
(
functional
,
inpt
,
*
args
,
**
kwargs
)
def
_check_inputs
(
self
,
flat_inputs
:
List
[
Any
])
->
None
:
def
_check_inputs
(
self
,
flat_inputs
:
List
[
Any
])
->
None
:
if
has_any
(
flat_inputs
,
datapoints
.
BoundingBoxes
,
datapoints
.
Mask
):
if
has_any
(
flat_inputs
,
datapoints
.
BoundingBoxes
,
datapoints
.
Mask
):
...
...
torchvision/transforms/v2/_transform.py
View file @
f244e27e
...
@@ -30,8 +30,8 @@ class Transform(nn.Module):
...
@@ -30,8 +30,8 @@ class Transform(nn.Module):
def
_get_params
(
self
,
flat_inputs
:
List
[
Any
])
->
Dict
[
str
,
Any
]:
def
_get_params
(
self
,
flat_inputs
:
List
[
Any
])
->
Dict
[
str
,
Any
]:
return
dict
()
return
dict
()
def
_call_kernel
(
self
,
dispatcher
:
Callable
,
inpt
:
Any
,
*
args
:
Any
,
**
kwargs
:
Any
)
->
Any
:
def
_call_kernel
(
self
,
functional
:
Callable
,
inpt
:
Any
,
*
args
:
Any
,
**
kwargs
:
Any
)
->
Any
:
kernel
=
_get_kernel
(
dispatcher
,
type
(
inpt
),
allow_passthrough
=
True
)
kernel
=
_get_kernel
(
functional
,
type
(
inpt
),
allow_passthrough
=
True
)
return
kernel
(
inpt
,
*
args
,
**
kwargs
)
return
kernel
(
inpt
,
*
args
,
**
kwargs
)
def
_transform
(
self
,
inpt
:
Any
,
params
:
Dict
[
str
,
Any
])
->
Any
:
def
_transform
(
self
,
inpt
:
Any
,
params
:
Dict
[
str
,
Any
])
->
Any
:
...
...
torchvision/transforms/v2/functional/_meta.py
View file @
f244e27e
...
@@ -203,7 +203,7 @@ def convert_format_bounding_boxes(
...
@@ -203,7 +203,7 @@ def convert_format_bounding_boxes(
new_format
:
Optional
[
BoundingBoxFormat
]
=
None
,
new_format
:
Optional
[
BoundingBoxFormat
]
=
None
,
inplace
:
bool
=
False
,
inplace
:
bool
=
False
,
)
->
torch
.
Tensor
:
)
->
torch
.
Tensor
:
# This being a kernel /
dispatcher
hybrid, we need an option to pass `old_format` explicitly for simple tensor
# This being a kernel /
functional
hybrid, we need an option to pass `old_format` explicitly for simple tensor
# inputs as well as extract it from `datapoints.BoundingBoxes` inputs. However, putting a default value on
# inputs as well as extract it from `datapoints.BoundingBoxes` inputs. However, putting a default value on
# `old_format` means we also need to put one on `new_format` to have syntactically correct Python. Here we mimic the
# `old_format` means we also need to put one on `new_format` to have syntactically correct Python. Here we mimic the
# default error that would be thrown if `new_format` had no default value.
# default error that would be thrown if `new_format` had no default value.
...
...
torchvision/transforms/v2/functional/_utils.py
View file @
f244e27e
...
@@ -12,7 +12,7 @@ def is_simple_tensor(inpt: Any) -> bool:
...
@@ -12,7 +12,7 @@ def is_simple_tensor(inpt: Any) -> bool:
return
isinstance
(
inpt
,
torch
.
Tensor
)
and
not
isinstance
(
inpt
,
datapoints
.
Datapoint
)
return
isinstance
(
inpt
,
torch
.
Tensor
)
and
not
isinstance
(
inpt
,
datapoints
.
Datapoint
)
# {
dispatcher
: {input_type: type_specific_kernel}}
# {
functional
: {input_type: type_specific_kernel}}
_KERNEL_REGISTRY
:
Dict
[
Callable
,
Dict
[
Type
,
Callable
]]
=
{}
_KERNEL_REGISTRY
:
Dict
[
Callable
,
Dict
[
Type
,
Callable
]]
=
{}
...
@@ -27,10 +27,10 @@ def _kernel_datapoint_wrapper(kernel):
...
@@ -27,10 +27,10 @@ def _kernel_datapoint_wrapper(kernel):
return
wrapper
return
wrapper
def
_register_kernel_internal
(
dispatcher
,
input_type
,
*
,
datapoint_wrapper
=
True
):
def
_register_kernel_internal
(
functional
,
input_type
,
*
,
datapoint_wrapper
=
True
):
registry
=
_KERNEL_REGISTRY
.
setdefault
(
dispatcher
,
{})
registry
=
_KERNEL_REGISTRY
.
setdefault
(
functional
,
{})
if
input_type
in
registry
:
if
input_type
in
registry
:
raise
ValueError
(
f
"
Dispatcher
{
dispatcher
}
already has a kernel registered for type
{
input_type
}
."
)
raise
ValueError
(
f
"
Functional
{
functional
}
already has a kernel registered for type
{
input_type
}
."
)
def
decorator
(
kernel
):
def
decorator
(
kernel
):
registry
[
input_type
]
=
(
registry
[
input_type
]
=
(
...
@@ -43,14 +43,14 @@ def _register_kernel_internal(dispatcher, input_type, *, datapoint_wrapper=True)
...
@@ -43,14 +43,14 @@ def _register_kernel_internal(dispatcher, input_type, *, datapoint_wrapper=True)
return
decorator
return
decorator
def
_name_to_
dispatcher
(
name
):
def
_name_to_
functional
(
name
):
import
torchvision.transforms.v2.functional
# noqa
import
torchvision.transforms.v2.functional
# noqa
try
:
try
:
return
getattr
(
torchvision
.
transforms
.
v2
.
functional
,
name
)
return
getattr
(
torchvision
.
transforms
.
v2
.
functional
,
name
)
except
AttributeError
:
except
AttributeError
:
raise
ValueError
(
raise
ValueError
(
f
"Could not find
dispatcher
with name '
{
name
}
' in torchvision.transforms.v2.functional."
f
"Could not find
functional
with name '
{
name
}
' in torchvision.transforms.v2.functional."
)
from
None
)
from
None
...
@@ -59,21 +59,21 @@ _BUILTIN_DATAPOINT_TYPES = {
...
@@ -59,21 +59,21 @@ _BUILTIN_DATAPOINT_TYPES = {
}
}
def
register_kernel
(
dispatcher
,
datapoint_cls
):
def
register_kernel
(
functional
,
datapoint_cls
):
"""Decorate a kernel to register it for a
dispatcher
and a (custom) datapoint type.
"""Decorate a kernel to register it for a
functional
and a (custom) datapoint type.
See :ref:`sphx_glr_auto_examples_plot_custom_datapoints.py` for usage
See :ref:`sphx_glr_auto_examples_plot_custom_datapoints.py` for usage
details.
details.
"""
"""
if
isinstance
(
dispatcher
,
str
):
if
isinstance
(
functional
,
str
):
dispatcher
=
_name_to_
dispatcher
(
name
=
dispatcher
)
functional
=
_name_to_
functional
(
name
=
functional
)
elif
not
(
elif
not
(
callable
(
dispatcher
)
callable
(
functional
)
and
getattr
(
dispatcher
,
"__module__"
,
""
).
startswith
(
"torchvision.transforms.v2.functional"
)
and
getattr
(
functional
,
"__module__"
,
""
).
startswith
(
"torchvision.transforms.v2.functional"
)
):
):
raise
ValueError
(
raise
ValueError
(
f
"Kernels can only be registered on
dispatcher
s from the torchvision.transforms.v2.functional namespace, "
f
"Kernels can only be registered on
functional
s from the torchvision.transforms.v2.functional namespace, "
f
"but got
{
dispatcher
}
."
f
"but got
{
functional
}
."
)
)
if
not
(
isinstance
(
datapoint_cls
,
type
)
and
issubclass
(
datapoint_cls
,
datapoints
.
Datapoint
)):
if
not
(
isinstance
(
datapoint_cls
,
type
)
and
issubclass
(
datapoint_cls
,
datapoints
.
Datapoint
)):
...
@@ -85,13 +85,13 @@ def register_kernel(dispatcher, datapoint_cls):
...
@@ -85,13 +85,13 @@ def register_kernel(dispatcher, datapoint_cls):
if
datapoint_cls
in
_BUILTIN_DATAPOINT_TYPES
:
if
datapoint_cls
in
_BUILTIN_DATAPOINT_TYPES
:
raise
ValueError
(
f
"Kernels cannot be registered for the builtin datapoint classes, but got
{
datapoint_cls
}
"
)
raise
ValueError
(
f
"Kernels cannot be registered for the builtin datapoint classes, but got
{
datapoint_cls
}
"
)
return
_register_kernel_internal
(
dispatcher
,
datapoint_cls
,
datapoint_wrapper
=
False
)
return
_register_kernel_internal
(
functional
,
datapoint_cls
,
datapoint_wrapper
=
False
)
def
_get_kernel
(
dispatcher
,
input_type
,
*
,
allow_passthrough
=
False
):
def
_get_kernel
(
functional
,
input_type
,
*
,
allow_passthrough
=
False
):
registry
=
_KERNEL_REGISTRY
.
get
(
dispatcher
)
registry
=
_KERNEL_REGISTRY
.
get
(
functional
)
if
not
registry
:
if
not
registry
:
raise
ValueError
(
f
"No kernel registered for
dispatcher
{
dispatcher
.
__name__
}
."
)
raise
ValueError
(
f
"No kernel registered for
functional
{
functional
.
__name__
}
."
)
# In case we have an exact type match, we take a shortcut.
# In case we have an exact type match, we take a shortcut.
if
input_type
in
registry
:
if
input_type
in
registry
:
...
@@ -113,17 +113,17 @@ def _get_kernel(dispatcher, input_type, *, allow_passthrough=False):
...
@@ -113,17 +113,17 @@ def _get_kernel(dispatcher, input_type, *, allow_passthrough=False):
return
lambda
inpt
,
*
args
,
**
kwargs
:
inpt
return
lambda
inpt
,
*
args
,
**
kwargs
:
inpt
raise
TypeError
(
raise
TypeError
(
f
"
Dispatcher F.
{
dispatcher
.
__name__
}
supports inputs of type
{
registry
.
keys
()
}
, "
f
"
Functional F.
{
functional
.
__name__
}
supports inputs of type
{
registry
.
keys
()
}
, "
f
"but got
{
input_type
}
instead."
f
"but got
{
input_type
}
instead."
)
)
# This basically replicates _register_kernel_internal, but with a specialized wrapper for five_crop / ten_crop
# This basically replicates _register_kernel_internal, but with a specialized wrapper for five_crop / ten_crop
# We could get rid of this by letting _register_kernel_internal take arbitrary
dispatcher
s rather than wrap_kernel: bool
# We could get rid of this by letting _register_kernel_internal take arbitrary
functional
s rather than wrap_kernel: bool
def
_register_five_ten_crop_kernel_internal
(
dispatcher
,
input_type
):
def
_register_five_ten_crop_kernel_internal
(
functional
,
input_type
):
registry
=
_KERNEL_REGISTRY
.
setdefault
(
dispatcher
,
{})
registry
=
_KERNEL_REGISTRY
.
setdefault
(
functional
,
{})
if
input_type
in
registry
:
if
input_type
in
registry
:
raise
TypeError
(
f
"
Dispatcher '
{
dispatcher
}
' already has a kernel registered for type '
{
input_type
}
'."
)
raise
TypeError
(
f
"
Functional '
{
functional
}
' already has a kernel registered for type '
{
input_type
}
'."
)
def
wrap
(
kernel
):
def
wrap
(
kernel
):
@
functools
.
wraps
(
kernel
)
@
functools
.
wraps
(
kernel
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment