Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
vision
Commits
e0e6f7e2
Unverified
Commit
e0e6f7e2
authored
Aug 16, 2023
by
Philip Meier
Committed by
GitHub
Aug 16, 2023
Browse files
allow dispatch to PIL image subclasses (#7835)
parent
c1592f96
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
31 additions
and
25 deletions
+31
-25
test/test_transforms_v2_refactored.py
test/test_transforms_v2_refactored.py
+23
-10
torchvision/transforms/v2/functional/_utils.py
torchvision/transforms/v2/functional/_utils.py
+8
-15
No files found.
test/test_transforms_v2_refactored.py
View file @
e0e6f7e2
...
@@ -3,6 +3,7 @@ import decimal
...
@@ -3,6 +3,7 @@ import decimal
import
inspect
import
inspect
import
math
import
math
import
re
import
re
from
pathlib
import
Path
from
unittest
import
mock
from
unittest
import
mock
import
numpy
as
np
import
numpy
as
np
...
@@ -2126,16 +2127,10 @@ class TestGetKernel:
...
@@ -2126,16 +2127,10 @@ class TestGetKernel:
datapoints
.
Video
:
F
.
resize_video
,
datapoints
.
Video
:
F
.
resize_video
,
}
}
def
test_unsupported_types
(
self
):
@
pytest
.
mark
.
parametrize
(
"input_type"
,
[
str
,
int
,
object
])
class
MyTensor
(
torch
.
Tensor
):
def
test_unsupported_types
(
self
,
input_type
):
pass
with
pytest
.
raises
(
TypeError
,
match
=
"supports inputs of type"
):
_get_kernel
(
F
.
resize
,
input_type
)
class
MyPILImage
(
PIL
.
Image
.
Image
):
pass
for
input_type
in
[
str
,
int
,
object
,
MyTensor
,
MyPILImage
]:
with
pytest
.
raises
(
TypeError
,
match
=
"supports inputs of type"
):
_get_kernel
(
F
.
resize
,
input_type
)
def
test_exact_match
(
self
):
def
test_exact_match
(
self
):
# We cannot use F.resize together with self.KERNELS mapping here directly here, since this is only the
# We cannot use F.resize together with self.KERNELS mapping here directly here, since this is only the
...
@@ -2197,6 +2192,24 @@ class TestGetKernel:
...
@@ -2197,6 +2192,24 @@ class TestGetKernel:
assert
_get_kernel
(
F
.
resize
,
MyDatapoint
)
is
resize_my_datapoint
assert
_get_kernel
(
F
.
resize
,
MyDatapoint
)
is
resize_my_datapoint
def
test_pil_image_subclass
(
self
):
opened_image
=
PIL
.
Image
.
open
(
Path
(
__file__
).
parent
/
"assets"
/
"encode_jpeg"
/
"grace_hopper_517x606.jpg"
)
loaded_image
=
opened_image
.
convert
(
"RGB"
)
# check the assumptions
assert
isinstance
(
opened_image
,
PIL
.
Image
.
Image
)
assert
type
(
opened_image
)
is
not
PIL
.
Image
.
Image
assert
type
(
loaded_image
)
is
PIL
.
Image
.
Image
size
=
[
17
,
11
]
for
image
in
[
opened_image
,
loaded_image
]:
kernel
=
_get_kernel
(
F
.
resize
,
type
(
image
))
output
=
kernel
(
image
,
size
=
size
)
assert
F
.
get_size
(
output
)
==
size
class
TestPermuteChannels
:
class
TestPermuteChannels
:
_DEFAULT_PERMUTATION
=
[
2
,
0
,
1
]
_DEFAULT_PERMUTATION
=
[
2
,
0
,
1
]
...
...
torchvision/transforms/v2/functional/_utils.py
View file @
e0e6f7e2
...
@@ -100,21 +100,14 @@ def _get_kernel(functional, input_type, *, allow_passthrough=False):
...
@@ -100,21 +100,14 @@ def _get_kernel(functional, input_type, *, allow_passthrough=False):
if
not
registry
:
if
not
registry
:
raise
ValueError
(
f
"No kernel registered for functional
{
functional
.
__name__
}
."
)
raise
ValueError
(
f
"No kernel registered for functional
{
functional
.
__name__
}
."
)
# In case we have an exact type match, we take a shortcut.
for
cls
in
input_type
.
__mro__
:
if
input_type
in
registry
:
if
cls
in
registry
:
return
registry
[
input_type
]
return
registry
[
cls
]
elif
cls
is
datapoints
.
Datapoint
:
# In case of datapoints, we check if we have a kernel for a superclass registered
# We don't want user-defined datapoints to dispatch to the pure Tensor kernels, so we explicit stop the
if
issubclass
(
input_type
,
datapoints
.
Datapoint
):
# MRO traversal before hitting torch.Tensor. We can even stop at datapoints.Datapoint, since we don't
# Since we have already checked for an exact match above, we can start the traversal at the superclass.
# allow kernels to be registered for datapoints.Datapoint anyway.
for
cls
in
input_type
.
__mro__
[
1
:]:
break
if
cls
is
datapoints
.
Datapoint
:
# We don't want user-defined datapoints to dispatch to the pure Tensor kernels, so we explicit stop the
# MRO traversal before hitting torch.Tensor. We can even stop at datapoints.Datapoint, since we don't
# allow kernels to be registered for datapoints.Datapoint anyway.
break
elif
cls
in
registry
:
return
registry
[
cls
]
if
allow_passthrough
:
if
allow_passthrough
:
return
lambda
inpt
,
*
args
,
**
kwargs
:
inpt
return
lambda
inpt
,
*
args
,
**
kwargs
:
inpt
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment