Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
vision
Commits
9ebf10af
Unverified
Commit
9ebf10af
authored
Aug 03, 2023
by
Nicolas Hug
Committed by
GitHub
Aug 03, 2023
Browse files
Allow register_kernel() to take dispatcher name as input (#7796)
parent
f3c89cc6
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
44 additions
and
0 deletions
+44
-0
test/test_transforms_v2_refactored.py
test/test_transforms_v2_refactored.py
+33
-0
torchvision/transforms/v2/functional/_utils.py
torchvision/transforms/v2/functional/_utils.py
+11
-0
No files found.
test/test_transforms_v2_refactored.py
View file @
9ebf10af
...
@@ -2181,3 +2181,36 @@ class TestShapeGetters:
...
@@ -2181,3 +2181,36 @@ class TestShapeGetters:
with
pytest
.
raises
(
TypeError
,
match
=
re
.
escape
(
str
(
type
(
input
)))):
with
pytest
.
raises
(
TypeError
,
match
=
re
.
escape
(
str
(
type
(
input
)))):
dispatcher
(
input
)
dispatcher
(
input
)
class
TestRegisterKernel
:
@
pytest
.
mark
.
parametrize
(
"dispatcher"
,
(
F
.
resize
,
"resize"
))
def
test_register_kernel
(
self
,
dispatcher
):
class
CustomDatapoint
(
datapoints
.
Datapoint
):
pass
kernel_was_called
=
False
@
F
.
register_kernel
(
dispatcher
,
CustomDatapoint
)
def
new_resize
(
dp
,
*
args
,
**
kwargs
):
nonlocal
kernel_was_called
kernel_was_called
=
True
return
dp
t
=
transforms
.
Resize
(
size
=
(
224
,
224
),
antialias
=
True
)
my_dp
=
CustomDatapoint
(
torch
.
rand
(
3
,
10
,
10
))
out
=
t
(
my_dp
)
assert
out
is
my_dp
assert
kernel_was_called
# Sanity check to make sure we didn't override the kernel of other types
t
(
torch
.
rand
(
3
,
10
,
10
)).
shape
==
(
3
,
224
,
224
)
t
(
datapoints
.
Image
(
torch
.
rand
(
3
,
10
,
10
))).
shape
==
(
3
,
224
,
224
)
def
test_bad_disaptcher_name
(
self
):
class
CustomDatapoint
(
datapoints
.
Datapoint
):
pass
with
pytest
.
raises
(
ValueError
,
match
=
"Could not find dispatcher with name"
):
F
.
register_kernel
(
"bad_name"
,
CustomDatapoint
)
torchvision/transforms/v2/functional/_utils.py
View file @
9ebf10af
...
@@ -37,7 +37,18 @@ def _register_kernel_internal(dispatcher, datapoint_cls, *, datapoint_wrapper=Tr
...
@@ -37,7 +37,18 @@ def _register_kernel_internal(dispatcher, datapoint_cls, *, datapoint_wrapper=Tr
return
decorator
return
decorator
def
_name_to_dispatcher
(
name
):
import
torchvision.transforms.v2.functional
# noqa
try
:
return
getattr
(
torchvision
.
transforms
.
v2
.
functional
,
name
)
except
AttributeError
:
raise
ValueError
(
f
"Could not find dispatcher with name '
{
name
}
'."
)
from
None
def
register_kernel
(
dispatcher
,
datapoint_cls
):
def
register_kernel
(
dispatcher
,
datapoint_cls
):
if
isinstance
(
dispatcher
,
str
):
dispatcher
=
_name_to_dispatcher
(
name
=
dispatcher
)
return
_register_kernel_internal
(
dispatcher
,
datapoint_cls
,
datapoint_wrapper
=
False
)
return
_register_kernel_internal
(
dispatcher
,
datapoint_cls
,
datapoint_wrapper
=
False
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment