OpenDAS / vision / Commits / f244e27e

Commit f244e27e (unverified), authored Aug 15, 2023 by Nicolas Hug; committed by GitHub, Aug 15, 2023

Dispatcher -> Functional (#7829)

parent 6ab8a96f
parent
6ab8a96f
Changes
7
Hide whitespace changes
Inline
Side-by-side
Showing
7 changed files
with
99 additions
and
99 deletions
+99
-99
gallery/plot_custom_datapoints.py
gallery/plot_custom_datapoints.py
+3
-3
test/test_transforms_v2_refactored.py
test/test_transforms_v2_refactored.py
+64
-64
torchvision/transforms/v2/_augment.py
torchvision/transforms/v2/_augment.py
+2
-2
torchvision/transforms/v2/_geometry.py
torchvision/transforms/v2/_geometry.py
+4
-4
torchvision/transforms/v2/_transform.py
torchvision/transforms/v2/_transform.py
+2
-2
torchvision/transforms/v2/functional/_meta.py
torchvision/transforms/v2/functional/_meta.py
+1
-1
torchvision/transforms/v2/functional/_utils.py
torchvision/transforms/v2/functional/_utils.py
+23
-23
No files found.
gallery/plot_custom_datapoints.py

@@ -49,7 +49,7 @@ my_dp
 from torchvision.transforms.v2 import functional as F
 
-@F.register_kernel(dispatcher="hflip", datapoint_cls=MyDatapoint)
+@F.register_kernel(functional="hflip", datapoint_cls=MyDatapoint)
 def hflip_my_datapoint(my_dp, *args, **kwargs):
     print("Flipping!")
     out = my_dp.flip(-1)
@@ -64,9 +64,9 @@ def hflip_my_datapoint(my_dp, *args, **kwargs):
 # .. note::
 #
 #    In our call to ``register_kernel`` above we used a string
-#    ``dispatcher="hflip"`` to refer to the functional we want to hook into. We
+#    ``functional="hflip"`` to refer to the functional we want to hook into. We
 #    could also have used the functional *itself*, i.e.
-#    ``@register_kernel(dispatcher=F.hflip, ...)``.
+#    ``@register_kernel(functional=F.hflip, ...)``.
 #
 # The functionals that you can hook into are the ones in
 # ``torchvision.transforms.v2.functional`` and they are documented in
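Note: a minimal sketch of the renamed registration API, assuming the 0.16-era datapoints API from the gallery example above; only the keyword changed from dispatcher= to functional=:

    import torch
    from torchvision import datapoints
    from torchvision.transforms.v2 import functional as F

    class MyDatapoint(datapoints.Datapoint):
        pass

    # String form; passing the functional itself (functional=F.hflip) works too.
    @F.register_kernel(functional="hflip", datapoint_cls=MyDatapoint)
    def hflip_my_datapoint(my_dp, *args, **kwargs):
        print("Flipping!")
        return my_dp.flip(-1)

    # F.hflip(my_dp) now routes to hflip_my_datapoint for MyDatapoint inputs.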
test/test_transforms_v2_refactored.py

@@ -163,25 +163,25 @@ def check_kernel(
     _check_kernel_batched_vs_unbatched(kernel, input, *args, **kwargs, **_to_tolerances(check_batched_vs_unbatched))
 
 
-def _check_dispatcher_scripted_smoke(dispatcher, input, *args, **kwargs):
-    """Checks if the dispatcher can be scripted and the scripted version can be called without error."""
+def _check_functional_scripted_smoke(functional, input, *args, **kwargs):
+    """Checks if the functional can be scripted and the scripted version can be called without error."""
     if not isinstance(input, datapoints.Image):
         return
 
-    dispatcher_scripted = _script(dispatcher)
+    functional_scripted = _script(functional)
     with ignore_jit_no_profile_information_warning():
-        dispatcher_scripted(input.as_subclass(torch.Tensor), *args, **kwargs)
+        functional_scripted(input.as_subclass(torch.Tensor), *args, **kwargs)
 
 
-def check_dispatcher(dispatcher, input, *args, check_scripted_smoke=True, **kwargs):
+def check_functional(functional, input, *args, check_scripted_smoke=True, **kwargs):
     unknown_input = object()
     with pytest.raises(TypeError, match=re.escape(str(type(unknown_input)))):
-        dispatcher(unknown_input, *args, **kwargs)
+        functional(unknown_input, *args, **kwargs)
 
     with mock.patch("torch._C._log_api_usage_once", wraps=torch._C._log_api_usage_once) as spy:
-        output = dispatcher(input, *args, **kwargs)
+        output = functional(input, *args, **kwargs)
 
-        spy.assert_any_call(f"{dispatcher.__module__}.{dispatcher.__name__}")
+        spy.assert_any_call(f"{functional.__module__}.{functional.__name__}")
 
     assert isinstance(output, type(input))

@@ -189,41 +189,41 @@ def check_dispatcher(dispatcher, input, *args, check_scripted_smoke=True, **kwargs):
         assert output.format == input.format
 
     if check_scripted_smoke:
-        _check_dispatcher_scripted_smoke(dispatcher, input, *args, **kwargs)
+        _check_functional_scripted_smoke(functional, input, *args, **kwargs)
 
 
-def check_dispatcher_kernel_signature_match(dispatcher, *, kernel, input_type):
-    """Checks if the signature of the dispatcher matches the kernel signature."""
-    dispatcher_params = list(inspect.signature(dispatcher).parameters.values())[1:]
+def check_functional_kernel_signature_match(functional, *, kernel, input_type):
+    """Checks if the signature of the functional matches the kernel signature."""
+    functional_params = list(inspect.signature(functional).parameters.values())[1:]
     kernel_params = list(inspect.signature(kernel).parameters.values())[1:]
 
     if issubclass(input_type, datapoints.Datapoint):
-        # We filter out metadata that is implicitly passed to the dispatcher through the input datapoint, but has to be
+        # We filter out metadata that is implicitly passed to the functional through the input datapoint, but has to be
        # explicitly passed to the kernel.
         explicit_metadata = {
             datapoints.BoundingBoxes: {"format", "canvas_size"},
         }
         kernel_params = [param for param in kernel_params if param.name not in explicit_metadata.get(input_type, set())]
 
-    dispatcher_params = iter(dispatcher_params)
-    for dispatcher_param, kernel_param in zip(dispatcher_params, kernel_params):
+    functional_params = iter(functional_params)
+    for functional_param, kernel_param in zip(functional_params, kernel_params):
         try:
-            # In general, the dispatcher parameters are a superset of the kernel parameters. Thus, we filter out
-            # dispatcher parameters that have no kernel equivalent while keeping the order intact.
-            while dispatcher_param.name != kernel_param.name:
-                dispatcher_param = next(dispatcher_params)
+            # In general, the functional parameters are a superset of the kernel parameters. Thus, we filter out
+            # functional parameters that have no kernel equivalent while keeping the order intact.
+            while functional_param.name != kernel_param.name:
+                functional_param = next(functional_params)
         except StopIteration:
             raise AssertionError(
                 f"Parameter `{kernel_param.name}` of kernel `{kernel.__name__}` "
-                f"has no corresponding parameter on the dispatcher `{dispatcher.__name__}`."
+                f"has no corresponding parameter on the functional `{functional.__name__}`."
             ) from None
 
         if issubclass(input_type, PIL.Image.Image):
             # PIL kernels often have more correct annotations, since they are not limited by JIT. Thus, we don't check
             # them in the first place.
-            dispatcher_param._annotation = kernel_param._annotation = inspect.Parameter.empty
+            functional_param._annotation = kernel_param._annotation = inspect.Parameter.empty
 
-        assert dispatcher_param == kernel_param
+        assert functional_param == kernel_param
 
 
 def _check_transform_v1_compatibility(transform, input):
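Note: the parameter-alignment loop above is hard to read in diff form. A self-contained toy sketch of the same idea, with hypothetical stand-ins toy_functional/toy_kernel (the functional's parameters are a superset of the kernel's, so functional-only parameters are skipped while order is preserved):

    import inspect

    def toy_functional(inpt, size, interpolation=None, antialias=True):  # hypothetical
        pass

    def toy_kernel(image, size, antialias=True):  # hypothetical
        pass

    # Drop the leading input parameter on both sides, as the helper above does.
    functional_params = iter(list(inspect.signature(toy_functional).parameters.values())[1:])
    kernel_params = list(inspect.signature(toy_kernel).parameters.values())[1:]

    for functional_param, kernel_param in zip(functional_params, kernel_params):
        # Skip functional-only parameters (here: `interpolation`) while keeping order.
        while functional_param.name != kernel_param.name:
            functional_param = next(functional_params)
        assert functional_param == kernel_param  # same name, kind, default, annotation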
@@ -482,8 +482,8 @@ class TestResize:
         "make_input",
         [make_image_tensor, make_image_pil, make_image, make_bounding_box, make_segmentation_mask, make_video],
     )
-    def test_dispatcher(self, size, make_input):
-        check_dispatcher(
+    def test_functional(self, size, make_input):
+        check_functional(
             F.resize,
             make_input(self.INPUT_SIZE),
             size=size,
@@ -502,8 +502,8 @@ class TestResize:
             (F.resize_video, datapoints.Video),
         ],
     )
-    def test_dispatcher_signature(self, kernel, input_type):
-        check_dispatcher_kernel_signature_match(F.resize, kernel=kernel, input_type=input_type)
+    def test_functional_signature(self, kernel, input_type):
+        check_functional_kernel_signature_match(F.resize, kernel=kernel, input_type=input_type)
 
     @pytest.mark.parametrize("size", OUTPUT_SIZES)
     @pytest.mark.parametrize("device", cpu_and_cuda())
@@ -608,7 +608,7 @@ class TestResize:
             interpolation=interpolation,
         )
 
-    def test_dispatcher_pil_antialias_warning(self):
+    def test_functional_pil_antialias_warning(self):
         with pytest.warns(UserWarning, match="Anti-alias option is always applied for PIL Image input"):
             F.resize(make_image_pil(self.INPUT_SIZE), size=self.OUTPUT_SIZES[0], antialias=False)
@@ -763,8 +763,8 @@ class TestHorizontalFlip:
         "make_input",
         [make_image_tensor, make_image_pil, make_image, make_bounding_box, make_segmentation_mask, make_video],
     )
-    def test_dispatcher(self, make_input):
-        check_dispatcher(F.horizontal_flip, make_input())
+    def test_functional(self, make_input):
+        check_functional(F.horizontal_flip, make_input())
 
     @pytest.mark.parametrize(
         ("kernel", "input_type"),
@@ -777,8 +777,8 @@ class TestHorizontalFlip:
             (F.horizontal_flip_video, datapoints.Video),
         ],
     )
-    def test_dispatcher_signature(self, kernel, input_type):
-        check_dispatcher_kernel_signature_match(F.horizontal_flip, kernel=kernel, input_type=input_type)
+    def test_functional_signature(self, kernel, input_type):
+        check_functional_kernel_signature_match(F.horizontal_flip, kernel=kernel, input_type=input_type)
 
     @pytest.mark.parametrize(
         "make_input",
@@ -939,8 +939,8 @@ class TestAffine:
         "make_input",
         [make_image_tensor, make_image_pil, make_image, make_bounding_box, make_segmentation_mask, make_video],
     )
-    def test_dispatcher(self, make_input):
-        check_dispatcher(F.affine, make_input(), **self._MINIMAL_AFFINE_KWARGS)
+    def test_functional(self, make_input):
+        check_functional(F.affine, make_input(), **self._MINIMAL_AFFINE_KWARGS)
 
     @pytest.mark.parametrize(
         ("kernel", "input_type"),
@@ -953,8 +953,8 @@ class TestAffine:
             (F.affine_video, datapoints.Video),
         ],
     )
-    def test_dispatcher_signature(self, kernel, input_type):
-        check_dispatcher_kernel_signature_match(F.affine, kernel=kernel, input_type=input_type)
+    def test_functional_signature(self, kernel, input_type):
+        check_functional_kernel_signature_match(F.affine, kernel=kernel, input_type=input_type)
 
     @pytest.mark.parametrize(
         "make_input",
@@ -1228,8 +1228,8 @@ class TestVerticalFlip:
         "make_input",
         [make_image_tensor, make_image_pil, make_image, make_bounding_box, make_segmentation_mask, make_video],
     )
-    def test_dispatcher(self, make_input):
-        check_dispatcher(F.vertical_flip, make_input())
+    def test_functional(self, make_input):
+        check_functional(F.vertical_flip, make_input())
 
     @pytest.mark.parametrize(
         ("kernel", "input_type"),
@@ -1242,8 +1242,8 @@ class TestVerticalFlip:
             (F.vertical_flip_video, datapoints.Video),
         ],
     )
-    def test_dispatcher_signature(self, kernel, input_type):
-        check_dispatcher_kernel_signature_match(F.vertical_flip, kernel=kernel, input_type=input_type)
+    def test_functional_signature(self, kernel, input_type):
+        check_functional_kernel_signature_match(F.vertical_flip, kernel=kernel, input_type=input_type)
 
     @pytest.mark.parametrize(
         "make_input",
@@ -1378,8 +1378,8 @@ class TestRotate:
         "make_input",
         [make_image_tensor, make_image_pil, make_image, make_bounding_box, make_segmentation_mask, make_video],
     )
-    def test_dispatcher(self, make_input):
-        check_dispatcher(F.rotate, make_input(), **self._MINIMAL_AFFINE_KWARGS)
+    def test_functional(self, make_input):
+        check_functional(F.rotate, make_input(), **self._MINIMAL_AFFINE_KWARGS)
 
     @pytest.mark.parametrize(
         ("kernel", "input_type"),
@@ -1392,8 +1392,8 @@ class TestRotate:
             (F.rotate_video, datapoints.Video),
         ],
     )
-    def test_dispatcher_signature(self, kernel, input_type):
-        check_dispatcher_kernel_signature_match(F.rotate, kernel=kernel, input_type=input_type)
+    def test_functional_signature(self, kernel, input_type):
+        check_functional_kernel_signature_match(F.rotate, kernel=kernel, input_type=input_type)
 
     @pytest.mark.parametrize(
         "make_input",
@@ -1643,8 +1643,8 @@ class TestToDtype:
     @pytest.mark.parametrize("output_dtype", [torch.float32, torch.float64, torch.uint8])
     @pytest.mark.parametrize("device", cpu_and_cuda())
     @pytest.mark.parametrize("scale", (True, False))
-    def test_dispatcher(self, make_input, input_dtype, output_dtype, device, scale):
-        check_dispatcher(
+    def test_functional(self, make_input, input_dtype, output_dtype, device, scale):
+        check_functional(
             F.to_dtype,
             make_input(dtype=input_dtype, device=device),
             dtype=output_dtype,
@@ -1810,8 +1810,8 @@ class TestAdjustBrightness:
         check_kernel(kernel, make_input(dtype=dtype, device=device), brightness_factor=self._DEFAULT_BRIGHTNESS_FACTOR)
 
     @pytest.mark.parametrize("make_input", [make_image_tensor, make_image_pil, make_image, make_video])
-    def test_dispatcher(self, make_input):
-        check_dispatcher(F.adjust_brightness, make_input(), brightness_factor=self._DEFAULT_BRIGHTNESS_FACTOR)
+    def test_functional(self, make_input):
+        check_functional(F.adjust_brightness, make_input(), brightness_factor=self._DEFAULT_BRIGHTNESS_FACTOR)
 
     @pytest.mark.parametrize(
         ("kernel", "input_type"),
@@ -1822,8 +1822,8 @@ class TestAdjustBrightness:
             (F.adjust_brightness_video, datapoints.Video),
         ],
     )
-    def test_dispatcher_signature(self, kernel, input_type):
-        check_dispatcher_kernel_signature_match(F.adjust_brightness, kernel=kernel, input_type=input_type)
+    def test_functional_signature(self, kernel, input_type):
+        check_functional_kernel_signature_match(F.adjust_brightness, kernel=kernel, input_type=input_type)
 
     @pytest.mark.parametrize("brightness_factor", _CORRECTNESS_BRIGHTNESS_FACTORS)
     def test_image_correctness(self, brightness_factor):
@@ -2042,7 +2042,7 @@ class TestShapeGetters:
         assert kernel(input) == F.get_num_frames(input) == num_frames
 
     @pytest.mark.parametrize(
-        ("dispatcher", "make_input"),
+        ("functional", "make_input"),
         [
             (F.get_dimensions, make_bounding_box),
             (F.get_dimensions, make_detection_mask),
@@ -2057,22 +2057,22 @@ class TestShapeGetters:
             (F.get_num_frames, make_segmentation_mask),
         ],
     )
-    def test_unsupported_types(self, dispatcher, make_input):
+    def test_unsupported_types(self, functional, make_input):
         input = make_input()
         with pytest.raises(TypeError, match=re.escape(str(type(input)))):
-            dispatcher(input)
+            functional(input)
 
 
 class TestRegisterKernel:
-    @pytest.mark.parametrize("dispatcher", (F.resize, "resize"))
-    def test_register_kernel(self, dispatcher):
+    @pytest.mark.parametrize("functional", (F.resize, "resize"))
+    def test_register_kernel(self, functional):
         class CustomDatapoint(datapoints.Datapoint):
             pass
 
         kernel_was_called = False
 
-        @F.register_kernel(dispatcher, CustomDatapoint)
+        @F.register_kernel(functional, CustomDatapoint)
         def new_resize(dp, *args, **kwargs):
             nonlocal kernel_was_called
             kernel_was_called = True
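Note: outside the test suite, the behavior exercised here looks roughly like this sketch (CustomDatapoint mirrors the test; the functional can be passed as F.resize or as the string "resize"):

    import torch
    from torchvision import datapoints
    from torchvision.transforms.v2 import functional as F

    class CustomDatapoint(datapoints.Datapoint):
        pass

    @F.register_kernel(F.resize, CustomDatapoint)  # the string "resize" works as well
    def new_resize(dp, *args, **kwargs):
        # Kernels registered via register_kernel are not wrapped (datapoint_wrapper=False
        # in _utils.py below), so their return value is passed through as-is.
        return dp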
@@ -2090,10 +2090,10 @@ class TestRegisterKernel:
         t(datapoints.Image(torch.rand(3, 10, 10))).shape == (3, 224, 224)
 
     def test_errors(self):
-        with pytest.raises(ValueError, match="Could not find dispatcher with name"):
+        with pytest.raises(ValueError, match="Could not find functional with name"):
             F.register_kernel("bad_name", datapoints.Image)
 
-        with pytest.raises(ValueError, match="Kernels can only be registered on dispatchers"):
+        with pytest.raises(ValueError, match="Kernels can only be registered on functionals"):
             F.register_kernel(datapoints.Image, F.resize)
 
         with pytest.raises(ValueError, match="Kernels can only be registered for subclasses"):
@@ -2115,7 +2115,7 @@ class TestRegisterKernel:
 
 
 class TestGetKernel:
-    # We are using F.resize as dispatcher and the kernels below as proxy. Any other dispatcher / kernels combination
+    # We are using F.resize as functional and the kernels below as proxy. Any other functional / kernels combination
     # would also be fine
     KERNELS = {
         torch.Tensor: F.resize_image_tensor,
@@ -2139,7 +2139,7 @@ class TestGetKernel:
     def test_exact_match(self):
         # We cannot use F.resize together with self.KERNELS mapping directly here, since this is only the
-        # ideal wrapping. Practically, we have an intermediate wrapper layer. Thus, we create a new resize dispatcher
+        # ideal wrapping. Practically, we have an intermediate wrapper layer. Thus, we create a new resize functional
         # here, register the kernels without wrapper, and check the exact matching afterwards.
         def resize_with_pure_kernels():
             pass
@@ -2151,7 +2151,7 @@ class TestGetKernel:
     def test_builtin_datapoint_subclass(self):
         # We cannot use F.resize together with self.KERNELS mapping directly here, since this is only the
-        # ideal wrapping. Practically, we have an intermediate wrapper layer. Thus, we create a new resize dispatcher
+        # ideal wrapping. Practically, we have an intermediate wrapper layer. Thus, we create a new resize functional
         # here, register the kernels without wrapper, and check if subclasses of our builtin datapoints get dispatched
         # to the kernel of the corresponding superclass
         def resize_with_pure_kernels():
@@ -2217,8 +2217,8 @@ class TestPermuteChannels:
         check_kernel(kernel, make_input(dtype=dtype, device=device), permutation=self._DEFAULT_PERMUTATION)
 
     @pytest.mark.parametrize("make_input", [make_image_tensor, make_image_pil, make_image, make_video])
-    def test_dispatcher(self, make_input):
-        check_dispatcher(F.permute_channels, make_input(), permutation=self._DEFAULT_PERMUTATION)
+    def test_functional(self, make_input):
+        check_functional(F.permute_channels, make_input(), permutation=self._DEFAULT_PERMUTATION)
 
     @pytest.mark.parametrize(
         ("kernel", "input_type"),
@@ -2229,8 +2229,8 @@ class TestPermuteChannels:
             (F.permute_channels_video, datapoints.Video),
         ],
     )
-    def test_dispatcher_signature(self, kernel, input_type):
-        check_dispatcher_kernel_signature_match(F.permute_channels, kernel=kernel, input_type=input_type)
+    def test_functional_signature(self, kernel, input_type):
+        check_functional_kernel_signature_match(F.permute_channels, kernel=kernel, input_type=input_type)
 
     def reference_image_correctness(self, image, permutation):
         channel_images = image.split(1, dim=-3)
torchvision/transforms/v2/_augment.py

@@ -91,13 +91,13 @@ class RandomErasing(_RandomApplyTransform):
 
         self._log_ratio = torch.log(torch.tensor(self.ratio))
 
-    def _call_kernel(self, dispatcher: Callable, inpt: Any, *args: Any, **kwargs: Any) -> Any:
+    def _call_kernel(self, functional: Callable, inpt: Any, *args: Any, **kwargs: Any) -> Any:
         if isinstance(inpt, (datapoints.BoundingBoxes, datapoints.Mask)):
             warnings.warn(
                 f"{type(self).__name__}() is currently passing through inputs of type "
                 f"datapoints.{type(inpt).__name__}. This will likely change in the future."
             )
-        return super()._call_kernel(dispatcher, inpt, *args, **kwargs)
+        return super()._call_kernel(functional, inpt, *args, **kwargs)
 
     def _get_params(self, flat_inputs: List[Any]) -> Dict[str, Any]:
         img_c, img_h, img_w = query_chw(flat_inputs)
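Note: the override above only renames the parameter; the pass-through behavior is unchanged. A hedged usage sketch, assuming the 0.16-era datapoints API:

    import torch
    from torchvision import datapoints
    from torchvision.transforms import v2

    img = datapoints.Image(torch.rand(3, 10, 10))
    boxes = datapoints.BoundingBoxes(
        torch.tensor([[0, 0, 5, 5]]), format="XYXY", canvas_size=(10, 10)
    )

    # The image is erased; the boxes trigger the UserWarning and come back unchanged.
    out_img, out_boxes = v2.RandomErasing(p=1.0)(img, boxes)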
torchvision/transforms/v2/_geometry.py

@@ -358,13 +358,13 @@ class FiveCrop(Transform):
         super().__init__()
         self.size = _setup_size(size, error_msg="Please provide only two dimensions (h, w) for size.")
 
-    def _call_kernel(self, dispatcher: Callable, inpt: Any, *args: Any, **kwargs: Any) -> Any:
+    def _call_kernel(self, functional: Callable, inpt: Any, *args: Any, **kwargs: Any) -> Any:
         if isinstance(inpt, (datapoints.BoundingBoxes, datapoints.Mask)):
             warnings.warn(
                 f"{type(self).__name__}() is currently passing through inputs of type "
                 f"datapoints.{type(inpt).__name__}. This will likely change in the future."
             )
-        return super()._call_kernel(dispatcher, inpt, *args, **kwargs)
+        return super()._call_kernel(functional, inpt, *args, **kwargs)
 
     def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
         return self._call_kernel(F.five_crop, inpt, self.size)

@@ -405,13 +405,13 @@ class TenCrop(Transform):
         self.size = _setup_size(size, error_msg="Please provide only two dimensions (h, w) for size.")
         self.vertical_flip = vertical_flip
 
-    def _call_kernel(self, dispatcher: Callable, inpt: Any, *args: Any, **kwargs: Any) -> Any:
+    def _call_kernel(self, functional: Callable, inpt: Any, *args: Any, **kwargs: Any) -> Any:
         if isinstance(inpt, (datapoints.BoundingBoxes, datapoints.Mask)):
             warnings.warn(
                 f"{type(self).__name__}() is currently passing through inputs of type "
                 f"datapoints.{type(inpt).__name__}. This will likely change in the future."
             )
-        return super()._call_kernel(dispatcher, inpt, *args, **kwargs)
+        return super()._call_kernel(functional, inpt, *args, **kwargs)
 
     def _check_inputs(self, flat_inputs: List[Any]) -> None:
         if has_any(flat_inputs, datapoints.BoundingBoxes, datapoints.Mask):
torchvision/transforms/v2/_transform.py

@@ -30,8 +30,8 @@ class Transform(nn.Module):
     def _get_params(self, flat_inputs: List[Any]) -> Dict[str, Any]:
         return dict()
 
-    def _call_kernel(self, dispatcher: Callable, inpt: Any, *args: Any, **kwargs: Any) -> Any:
-        kernel = _get_kernel(dispatcher, type(inpt), allow_passthrough=True)
+    def _call_kernel(self, functional: Callable, inpt: Any, *args: Any, **kwargs: Any) -> Any:
+        kernel = _get_kernel(functional, type(inpt), allow_passthrough=True)
         return kernel(inpt, *args, **kwargs)
 
     def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
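Note: for context, a sketch of how a concrete transform uses this hook, mirroring the FiveCrop/TenCrop pattern from this commit (the subclass is hypothetical):

    from typing import Any, Dict

    from torchvision.transforms.v2 import Transform, functional as F

    class HorizontalFlipLike(Transform):  # hypothetical example transform
        def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
            # The functional (F.horizontal_flip) only serves as the registry key;
            # _call_kernel resolves the kernel registered for type(inpt) and calls it.
            return self._call_kernel(F.horizontal_flip, inpt)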
torchvision/transforms/v2/functional/_meta.py

@@ -203,7 +203,7 @@ def convert_format_bounding_boxes(
     new_format: Optional[BoundingBoxFormat] = None,
     inplace: bool = False,
 ) -> torch.Tensor:
-    # This being a kernel / dispatcher hybrid, we need an option to pass `old_format` explicitly for simple tensor
+    # This being a kernel / functional hybrid, we need an option to pass `old_format` explicitly for simple tensor
     # inputs as well as extract it from `datapoints.BoundingBoxes` inputs. However, putting a default value on
     # `old_format` means we also need to put one on `new_format` to have syntactically correct Python. Here we mimic the
     # default error that would be thrown if `new_format` had no default value.
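Note: the trick the comment describes, in isolation (hypothetical names, not torchvision code): both parameters get defaults so the signature stays valid Python, and the missing-argument error is raised by hand:

    def convert(inpt, old_format=None, new_format=None):  # hypothetical
        if new_format is None:
            # Mimic the TypeError Python would raise if `new_format` had no default.
            raise TypeError("convert() missing 1 required argument: 'new_format'")
        ...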
torchvision/transforms/v2/functional/_utils.py

@@ -12,7 +12,7 @@ def is_simple_tensor(inpt: Any) -> bool:
     return isinstance(inpt, torch.Tensor) and not isinstance(inpt, datapoints.Datapoint)
 
 
-# {dispatcher: {input_type: type_specific_kernel}}
+# {functional: {input_type: type_specific_kernel}}
 _KERNEL_REGISTRY: Dict[Callable, Dict[Type, Callable]] = {}
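Note: as an illustration (hypothetical entries), the registry might look like this after a built-in and a custom registration:

    # _KERNEL_REGISTRY = {
    #     F.resize: {
    #         torch.Tensor: resize_image_tensor,
    #         datapoints.Image: resize_image_tensor,  # wrapped to re-wrap the output
    #         CustomDatapoint: new_resize,            # added via F.register_kernel
    #     },
    # }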
@@ -27,10 +27,10 @@ def _kernel_datapoint_wrapper(kernel):
     return wrapper
 
 
-def _register_kernel_internal(dispatcher, input_type, *, datapoint_wrapper=True):
-    registry = _KERNEL_REGISTRY.setdefault(dispatcher, {})
+def _register_kernel_internal(functional, input_type, *, datapoint_wrapper=True):
+    registry = _KERNEL_REGISTRY.setdefault(functional, {})
     if input_type in registry:
-        raise ValueError(f"Dispatcher {dispatcher} already has a kernel registered for type {input_type}.")
+        raise ValueError(f"Functional {functional} already has a kernel registered for type {input_type}.")
 
     def decorator(kernel):
         registry[input_type] = (
@@ -43,14 +43,14 @@ def _register_kernel_internal(dispatcher, input_type, *, datapoint_wrapper=True):
     return decorator
 
 
-def _name_to_dispatcher(name):
+def _name_to_functional(name):
     import torchvision.transforms.v2.functional  # noqa
 
     try:
         return getattr(torchvision.transforms.v2.functional, name)
     except AttributeError:
         raise ValueError(
-            f"Could not find dispatcher with name '{name}' in torchvision.transforms.v2.functional."
+            f"Could not find functional with name '{name}' in torchvision.transforms.v2.functional."
         ) from None
@@ -59,21 +59,21 @@ _BUILTIN_DATAPOINT_TYPES = {
 }
 
 
-def register_kernel(dispatcher, datapoint_cls):
-    """Decorate a kernel to register it for a dispatcher and a (custom) datapoint type.
+def register_kernel(functional, datapoint_cls):
+    """Decorate a kernel to register it for a functional and a (custom) datapoint type.
 
     See :ref:`sphx_glr_auto_examples_plot_custom_datapoints.py` for usage
     details.
     """
-    if isinstance(dispatcher, str):
-        dispatcher = _name_to_dispatcher(name=dispatcher)
+    if isinstance(functional, str):
+        functional = _name_to_functional(name=functional)
     elif not (
-        callable(dispatcher)
-        and getattr(dispatcher, "__module__", "").startswith("torchvision.transforms.v2.functional")
+        callable(functional)
+        and getattr(functional, "__module__", "").startswith("torchvision.transforms.v2.functional")
     ):
         raise ValueError(
-            f"Kernels can only be registered on dispatchers from the torchvision.transforms.v2.functional namespace, "
-            f"but got {dispatcher}."
+            f"Kernels can only be registered on functionals from the torchvision.transforms.v2.functional namespace, "
+            f"but got {functional}."
         )
 
     if not (isinstance(datapoint_cls, type) and issubclass(datapoint_cls, datapoints.Datapoint)):
@@ -85,13 +85,13 @@ def register_kernel(dispatcher, datapoint_cls):
     if datapoint_cls in _BUILTIN_DATAPOINT_TYPES:
         raise ValueError(f"Kernels cannot be registered for the builtin datapoint classes, but got {datapoint_cls}")
 
-    return _register_kernel_internal(dispatcher, datapoint_cls, datapoint_wrapper=False)
+    return _register_kernel_internal(functional, datapoint_cls, datapoint_wrapper=False)
 
 
-def _get_kernel(dispatcher, input_type, *, allow_passthrough=False):
-    registry = _KERNEL_REGISTRY.get(dispatcher)
+def _get_kernel(functional, input_type, *, allow_passthrough=False):
+    registry = _KERNEL_REGISTRY.get(functional)
     if not registry:
-        raise ValueError(f"No kernel registered for dispatcher {dispatcher.__name__}.")
+        raise ValueError(f"No kernel registered for functional {functional.__name__}.")
 
     # In case we have an exact type match, we take a shortcut.
     if input_type in registry:
@@ -113,17 +113,17 @@ def _get_kernel(dispatcher, input_type, *, allow_passthrough=False):
         return lambda inpt, *args, **kwargs: inpt
 
     raise TypeError(
-        f"Dispatcher F.{dispatcher.__name__} supports inputs of type {registry.keys()}, "
+        f"Functional F.{functional.__name__} supports inputs of type {registry.keys()}, "
         f"but got {input_type} instead."
     )
 
 
 # This basically replicates _register_kernel_internal, but with a specialized wrapper for five_crop / ten_crop
-# We could get rid of this by letting _register_kernel_internal take arbitrary dispatchers rather than wrap_kernel: bool
-def _register_five_ten_crop_kernel_internal(dispatcher, input_type):
-    registry = _KERNEL_REGISTRY.setdefault(dispatcher, {})
+# We could get rid of this by letting _register_kernel_internal take arbitrary functionals rather than wrap_kernel: bool
+def _register_five_ten_crop_kernel_internal(functional, input_type):
+    registry = _KERNEL_REGISTRY.setdefault(functional, {})
     if input_type in registry:
-        raise TypeError(f"Dispatcher '{dispatcher}' already has a kernel registered for type '{input_type}'.")
+        raise TypeError(f"Functional '{functional}' already has a kernel registered for type '{input_type}'.")
 
     def wrap(kernel):
         @functools.wraps(kernel)
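Note: to summarize the lookup semantics this file implements after the rename, a self-contained toy re-implementation (not torchvision's actual code; the subclass walk is simplified to an MRO scan):

    from typing import Any, Callable, Dict, Type

    _TOY_REGISTRY: Dict[Callable, Dict[Type, Callable]] = {}

    def toy_get_kernel(functional, input_type, *, allow_passthrough=False):
        registry = _TOY_REGISTRY.get(functional)
        if not registry:
            raise ValueError(f"No kernel registered for functional {functional.__name__}.")
        for cls in input_type.__mro__:  # exact match first, then superclasses
            if cls in registry:
                return registry[cls]
        if allow_passthrough:
            return lambda inpt, *args, **kwargs: inpt  # unsupported types pass through
        raise TypeError(
            f"Functional F.{functional.__name__} supports inputs of type {registry.keys()}, "
            f"but got {input_type} instead."
        )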