Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
vision
Commits
9d15f2e6
Unverified
Commit
9d15f2e6
authored
Nov 22, 2021
by
Philip Meier
Committed by
GitHub
Nov 22, 2021
Browse files
fix prototype transforms for non features (#4942)
Co-authored-by:
Francisco Massa
<
fvsmassa@gmail.com
>
parent
a96de03d
Changes
2
Show whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
16 additions
and
4 deletions
+16
-4
test/test_prototype_transforms.py
test/test_prototype_transforms.py
+11
-1
torchvision/prototype/transforms/_transform.py
torchvision/prototype/transforms/_transform.py
+5
-3
No files found.
test/test_prototype_transforms.py
View file @
9d15f2e6
...
...
@@ -38,7 +38,7 @@ def test_feature_type_support():
[
transform_type
for
transform_type
in
TRANSFORM_TYPES
if
transform_type
is
not
transforms
.
Identity
],
ids
=
lambda
transform_type
:
transform_type
.
__name__
,
)
def
test_
no_op
(
transform_type
):
def
test_
feature_no_op_coverage
(
transform_type
):
unsupported_features
=
(
FEATURE_TYPES
-
transform_type
.
supported_feature_types
()
-
set
(
transform_type
.
NO_OP_FEATURE_TYPES
)
)
...
...
@@ -49,3 +49,13 @@ def test_no_op(transform_type):
f
"no-op for transform `
{
transform_type
.
__name__
}
`. Please either implement a feature transform for them, "
f
"or add them to the the `
{
transform_type
.
__name__
}
.NO_OP_FEATURE_TYPES` collection."
)
def
test_non_feature_no_op
():
class
TestTransform
(
transforms
.
Transform
):
@
staticmethod
def
image
(
input
):
return
input
no_op_sample
=
dict
(
int
=
0
,
float
=
0.0
,
bool
=
False
,
str
=
"str"
)
assert
TestTransform
()(
no_op_sample
)
==
no_op_sample
torchvision/prototype/transforms/_transform.py
View file @
9d15f2e6
...
...
@@ -351,14 +351,16 @@ class Transform(nn.Module):
sample: Sample.
params: Parameter dictionary ``params`` that will be passed to ``feature_transform(input, **params)``.
"""
if
isinstance
(
sample
,
collections
.
abc
.
Sequence
):
# We explicitly exclude str's here since they are self-referential and would cause an infinite recursion loop:
# "a" == "a"[0][0]...
if
isinstance
(
sample
,
collections
.
abc
.
Sequence
)
and
not
isinstance
(
sample
,
str
):
return
[
self
.
_transform_recursively
(
item
,
params
=
params
)
for
item
in
sample
]
elif
isinstance
(
sample
,
collections
.
abc
.
Mapping
):
return
{
name
:
self
.
_transform_recursively
(
item
,
params
=
params
)
for
name
,
item
in
sample
.
items
()}
else
:
feature_type
=
type
(
sample
)
if
not
self
.
supports
(
feature_type
):
if
feature_type
in
self
.
NO_OP_FEATURE_TYPES
:
if
not
issubclass
(
feature_type
,
features
.
Feature
)
or
feature_type
in
self
.
NO_OP_FEATURE_TYPES
:
return
sample
raise
TypeError
(
...
...
@@ -366,7 +368,7 @@ class Transform(nn.Module):
f
"If you want it to be a no-op, add the feature type to
{
type
(
self
).
__name__
}
.NO_OP_FEATURE_TYPES."
)
return
self
.
transform
(
sample
,
**
params
)
return
self
.
transform
(
cast
(
Union
[
torch
.
Tensor
,
features
.
Feature
],
sample
)
,
**
params
)
def
get_params
(
self
,
sample
:
Any
)
->
Dict
[
str
,
Any
]:
"""Returns the parameter dictionary used to transform the current sample.
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment