Unverified Commit eeb6684e authored by Min Xu's avatar Min Xu Committed by GitHub
Browse files

[feat] support namedtuple in container.py (#1069)

parent 73bf5964
......@@ -4,7 +4,7 @@
# LICENSE file in the root directory of this source tree.
from collections import OrderedDict
from typing import Any, Callable, Dict, List, Optional, Set, Tuple, Union
from typing import Any, Callable, Dict, List, NamedTuple, Optional, Set, Tuple, Union, cast
import numpy as np
import torch
......@@ -14,7 +14,7 @@ from torch.nn.utils.rnn import PackedSequence
def apply_to_type(
type_fn: Callable, fn: Callable, container: Union[torch.Tensor, np.ndarray, Dict, List, Tuple, Set]
type_fn: Callable, fn: Callable, container: Union[torch.Tensor, np.ndarray, Dict, List, Tuple, Set, NamedTuple]
) -> Any:
"""Recursively apply to all objects in different kinds of container types that matches a type function."""
......@@ -34,7 +34,16 @@ def apply_to_type(
elif isinstance(x, list):
return [_apply(x) for x in x]
elif isinstance(x, tuple):
return tuple(_apply(x) for x in x)
f = getattr(x, "_fields", None)
if f is None:
return tuple(_apply(x) for x in x)
else:
assert isinstance(f, tuple), "This needs to be a namedtuple"
# convert the namedtuple to a dict and _apply().
x = cast(NamedTuple, x)
_dict: Dict[str, Any] = x._asdict()
_dict = {key: _apply(value) for key, value in _dict.items()}
return type(x)(**_dict) # make a copy of the namedtuple
elif isinstance(x, set):
return {_apply(x) for x in x}
else:
......
......@@ -9,7 +9,7 @@
""" Test utility classes from containers.py. """
from collections import OrderedDict
from collections import OrderedDict, namedtuple
import random
import pytest
......@@ -42,13 +42,21 @@ def test_apply_to_tensors(devices):
return t
# create a mixed bag of data.
data = [1, "str"]
data = [1, "str"] # list
# dict
data.append({"key1": get_a_tensor(), "key2": {1: get_a_tensor()}, "key3": 3})
# set
data.insert(0, set(["x", get_a_tensor(), get_a_tensor()]))
# tuple
data.append(([1], get_a_tensor(), 1, [get_a_tensor()], set((1, 2))))
# OrderedDict
od = OrderedDict()
od["k"] = "value"
data.append(od)
# namedtuple
NT = namedtuple("NT", ["key1", "key2"])
nt = NT(key1=1, key2=get_a_tensor())
data.append(nt)
total = 0
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment