Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
fairscale
Commits
eeb6684e
Unverified
Commit
eeb6684e
authored
Sep 12, 2022
by
Min Xu
Committed by
GitHub
Sep 12, 2022
Browse files
[feat] support namedtuple in container.py (#1069)
parent
73bf5964
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
22 additions
and
5 deletions
+22
-5
fairscale/internal/containers.py
fairscale/internal/containers.py
+12
-3
tests/utils/test_containers.py
tests/utils/test_containers.py
+10
-2
No files found.
fairscale/internal/containers.py
View file @
eeb6684e
...
...
@@ -4,7 +4,7 @@
# LICENSE file in the root directory of this source tree.
from
collections
import
OrderedDict
from
typing
import
Any
,
Callable
,
Dict
,
List
,
Optional
,
Set
,
Tuple
,
Union
from
typing
import
Any
,
Callable
,
Dict
,
List
,
NamedTuple
,
Optional
,
Set
,
Tuple
,
Union
,
cast
import
numpy
as
np
import
torch
...
...
@@ -14,7 +14,7 @@ from torch.nn.utils.rnn import PackedSequence
def
apply_to_type
(
type_fn
:
Callable
,
fn
:
Callable
,
container
:
Union
[
torch
.
Tensor
,
np
.
ndarray
,
Dict
,
List
,
Tuple
,
Set
]
type_fn
:
Callable
,
fn
:
Callable
,
container
:
Union
[
torch
.
Tensor
,
np
.
ndarray
,
Dict
,
List
,
Tuple
,
Set
,
NamedTuple
]
)
->
Any
:
"""Recursively apply to all objects in different kinds of container types that matches a type function."""
...
...
@@ -34,7 +34,16 @@ def apply_to_type(
elif
isinstance
(
x
,
list
):
return
[
_apply
(
x
)
for
x
in
x
]
elif
isinstance
(
x
,
tuple
):
return
tuple
(
_apply
(
x
)
for
x
in
x
)
f
=
getattr
(
x
,
"_fields"
,
None
)
if
f
is
None
:
return
tuple
(
_apply
(
x
)
for
x
in
x
)
else
:
assert
isinstance
(
f
,
tuple
),
"This needs to be a namedtuple"
# convert the namedtuple to a dict and _apply().
x
=
cast
(
NamedTuple
,
x
)
_dict
:
Dict
[
str
,
Any
]
=
x
.
_asdict
()
_dict
=
{
key
:
_apply
(
value
)
for
key
,
value
in
_dict
.
items
()}
return
type
(
x
)(
**
_dict
)
# make a copy of the namedtuple
elif
isinstance
(
x
,
set
):
return
{
_apply
(
x
)
for
x
in
x
}
else
:
...
...
tests/utils/test_containers.py
View file @
eeb6684e
...
...
@@ -9,7 +9,7 @@
""" Test utility classes from containers.py. """
from
collections
import
OrderedDict
from
collections
import
OrderedDict
,
namedtuple
import
random
import
pytest
...
...
@@ -42,13 +42,21 @@ def test_apply_to_tensors(devices):
return
t
# create a mixed bag of data.
data
=
[
1
,
"str"
]
data
=
[
1
,
"str"
]
# list
# dict
data
.
append
({
"key1"
:
get_a_tensor
(),
"key2"
:
{
1
:
get_a_tensor
()},
"key3"
:
3
})
# set
data
.
insert
(
0
,
set
([
"x"
,
get_a_tensor
(),
get_a_tensor
()]))
# tuple
data
.
append
(([
1
],
get_a_tensor
(),
1
,
[
get_a_tensor
()],
set
((
1
,
2
))))
# OrderedDict
od
=
OrderedDict
()
od
[
"k"
]
=
"value"
data
.
append
(
od
)
# namedtuple
NT
=
namedtuple
(
"NT"
,
[
"key1"
,
"key2"
])
nt
=
NT
(
key1
=
1
,
key2
=
get_a_tensor
())
data
.
append
(
nt
)
total
=
0
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment