Unverified Commit cc7803c0 authored by Xuehai Pan's avatar Xuehai Pan Committed by GitHub
Browse files

Register ModelOutput as supported torch pytree nodes (#26618)

* Register ModelOutput as supported torch pytree nodes

* Test ModelOutput as supported torch pytree nodes

* Update type hints for pytree unflatten functions
parent ede051f1
...@@ -22,7 +22,7 @@ from collections.abc import MutableMapping ...@@ -22,7 +22,7 @@ from collections.abc import MutableMapping
from contextlib import ExitStack, contextmanager from contextlib import ExitStack, contextmanager
from dataclasses import fields, is_dataclass from dataclasses import fields, is_dataclass
from enum import Enum from enum import Enum
from typing import Any, ContextManager, List, Tuple from typing import Any, ContextManager, Iterable, List, Tuple
import numpy as np import numpy as np
...@@ -306,12 +306,10 @@ class ModelOutput(OrderedDict): ...@@ -306,12 +306,10 @@ class ModelOutput(OrderedDict):
`static_graph=True` with modules that output `ModelOutput` subclasses. `static_graph=True` with modules that output `ModelOutput` subclasses.
""" """
if is_torch_available(): if is_torch_available():
import torch.utils._pytree _torch_pytree._register_pytree_node(
torch.utils._pytree._register_pytree_node(
cls, cls,
torch.utils._pytree._dict_flatten, _model_output_flatten,
lambda values, context: cls(**torch.utils._pytree._dict_unflatten(values, context)), _model_output_unflatten,
) )
def __init__(self, *args, **kwargs): def __init__(self, *args, **kwargs):
...@@ -430,6 +428,23 @@ class ModelOutput(OrderedDict): ...@@ -430,6 +428,23 @@ class ModelOutput(OrderedDict):
return tuple(self[k] for k in self.keys()) return tuple(self[k] for k in self.keys())
if is_torch_available():
    import torch.utils._pytree as _torch_pytree

    def _model_output_flatten(output: ModelOutput) -> Tuple[List[Any], "_torch_pytree.Context"]:
        """Flatten a ``ModelOutput`` into its values plus a context holding its concrete type and key order."""
        keys = list(output.keys())
        return list(output.values()), (type(output), keys)

    def _model_output_unflatten(values: Iterable[Any], context: "_torch_pytree.Context") -> ModelOutput:
        """Rebuild a ``ModelOutput`` subclass instance from flattened values and the stored (type, keys) context."""
        output_type, keys = context
        kwargs = {key: value for key, value in zip(keys, values)}
        return output_type(**kwargs)

    # Register the base class itself so plain ModelOutput instances are treated
    # as pytree nodes (subclasses register themselves in __init_subclass__).
    _torch_pytree._register_pytree_node(
        ModelOutput,
        _model_output_flatten,
        _model_output_unflatten,
    )
class ExplicitEnum(str, Enum): class ExplicitEnum(str, Enum):
""" """
Enum with more explicit error message for missing values. Enum with more explicit error message for missing values.
......
...@@ -126,22 +126,24 @@ class ModelOutputTester(unittest.TestCase): ...@@ -126,22 +126,24 @@ class ModelOutputTester(unittest.TestCase):
def test_torch_pytree(self):
    # torch.utils._pytree must treat ModelOutput (and its subclasses) as tree
    # nodes rather than leaves; DistributedDataParallel with static_graph=True
    # relies on this to synchronize gradients correctly.
    import torch.utils._pytree as pytree

    # The base class itself is a registered node, not a leaf.
    base_output = ModelOutput({"a": 1.0, "c": 2.0})
    self.assertFalse(pytree._is_leaf(base_output))

    # Subclasses register themselves via __init_subclass__.
    x = ModelOutputTest(a=1.0, c=2.0)
    self.assertFalse(pytree._is_leaf(x))

    # Flattening yields the values in key order, with a (type, keys) context.
    flat, spec = pytree.tree_flatten(x)
    self.assertEqual(flat, [1.0, 2.0])
    self.assertEqual(
        spec,
        pytree.TreeSpec(
            ModelOutputTest, (ModelOutputTest, ["a", "c"]), [pytree.LeafSpec(), pytree.LeafSpec()]
        ),
    )

    # Round-trip: unflattening restores an equal ModelOutputTest instance.
    self.assertEqual(x, pytree.tree_unflatten(flat, spec))
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment