Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
MMCV
Commits
3b494a13
Commit
3b494a13
authored
May 19, 2022
by
gongtao.vendor
Committed by
zhouzaida
Jul 19, 2022
Browse files
Support broadcasting all keys for TransformBroadcaster
parent
88f3cc3f
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
30 additions
and
5 deletions
+30
-5
mmcv/transforms/wrappers.py
mmcv/transforms/wrappers.py
+12
-5
tests/test_transforms/test_transforms_wrapper.py
tests/test_transforms/test_transforms_wrapper.py
+18
-0
No files found.
mmcv/transforms/wrappers.py
View file @
3b494a13
...
...
@@ -265,8 +265,9 @@ class KeyMapper(BaseTransform):
return
_map
(
data
,
remapping
)
def
transform
(
self
,
results
:
Dict
)
->
Dict
:
inputs
=
self
.
map_input
(
results
,
self
.
mapping
)
inputs
=
results
if
self
.
mapping
:
inputs
=
self
.
map_input
(
inputs
,
self
.
mapping
)
outputs
=
self
.
transforms
(
inputs
)
if
self
.
remapping
:
...
...
@@ -368,8 +369,12 @@ class TransformBroadcaster(KeyMapper):
# infer split number from input
seq_len
=
None
key_rep
=
None
for
key
in
self
.
mapping
:
if
self
.
mapping
:
keys
=
self
.
mapping
.
keys
()
else
:
keys
=
data
.
keys
()
for
key
in
keys
:
assert
isinstance
(
data
[
key
],
Sequence
)
if
seq_len
is
not
None
:
if
len
(
data
[
key
])
!=
seq_len
:
...
...
@@ -383,14 +388,16 @@ class TransformBroadcaster(KeyMapper):
scatters
=
[]
for
i
in
range
(
seq_len
):
scatter
=
data
.
copy
()
for
key
in
self
.
mapping
:
for
key
in
keys
:
scatter
[
key
]
=
data
[
key
][
i
]
scatters
.
append
(
scatter
)
return
scatters
def
transform
(
self
,
results
:
Dict
):
# Apply input remapping
inputs
=
self
.
map_input
(
results
,
self
.
mapping
)
inputs
=
results
if
self
.
mapping
:
inputs
=
self
.
map_input
(
inputs
,
self
.
mapping
)
# Scatter sequential inputs into a list
inputs
=
self
.
scatter_sequence
(
inputs
)
...
...
tests/test_transforms/test_transforms_wrapper.py
View file @
3b494a13
...
...
@@ -138,6 +138,15 @@ def test_cache_random_parameters():
def
test_key_mapper
():
# Case 0: only remap
pipeline
=
KeyMapper
(
transforms
=
[
AddToValue
(
addend
=
1
)],
remapping
=
{
'value'
:
'v_out'
})
results
=
dict
(
value
=
0
)
results
=
pipeline
(
results
)
np
.
testing
.
assert_equal
(
results
[
'value'
],
0
)
# should be unchanged
np
.
testing
.
assert_equal
(
results
[
'v_out'
],
1
)
# Case 1: simple remap
pipeline
=
KeyMapper
(
...
...
@@ -313,6 +322,15 @@ def test_transform_broadcaster():
np
.
testing
.
assert_equal
(
results
[
'a'
],
3
)
np
.
testing
.
assert_equal
(
results
[
'b'
],
7
)
# Case 3: apply to all keys
pipeline
=
TransformBroadcaster
(
transforms
=
[
SumTwoValues
()],
mapping
=
None
,
remapping
=
None
)
results
=
dict
(
num_1
=
[
1
,
2
,
3
],
num_2
=
[
4
,
5
,
6
])
results
=
pipeline
(
results
)
np
.
testing
.
assert_equal
(
results
[
'sum'
],
[
5
,
7
,
9
])
# Case 4: inconsistent sequence length
with
pytest
.
raises
(
ValueError
):
pipeline
=
TransformBroadcaster
(
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment