Unverified Commit c7617e48 authored by Bowen Bao's avatar Bowen Bao Committed by GitHub
Browse files

Register BaseOutput subclasses as supported torch.utils._pytree nodes (#5459)



* Register BaseOutput subclasses as supported torch.utils._pytree nodes

* lint

---------
Co-authored-by: default avatarDhruv Nair <dhruv.nair@gmail.com>
parent 77241c48
...@@ -51,6 +51,21 @@ class BaseOutput(OrderedDict): ...@@ -51,6 +51,21 @@ class BaseOutput(OrderedDict):
</Tip> </Tip>
""" """
def __init_subclass__(cls) -> None:
"""Register subclasses as pytree nodes.
This is necessary to synchronize gradients when using `torch.nn.parallel.DistributedDataParallel` with
`static_graph=True` with modules that output `ModelOutput` subclasses.
"""
if is_torch_available():
import torch.utils._pytree
torch.utils._pytree._register_pytree_node(
cls,
torch.utils._pytree._dict_flatten,
lambda values, context: cls(**torch.utils._pytree._dict_unflatten(values, context)),
)
def __post_init__(self): def __post_init__(self):
class_fields = fields(self) class_fields = fields(self)
......
...@@ -7,6 +7,7 @@ import numpy as np ...@@ -7,6 +7,7 @@ import numpy as np
import PIL.Image import PIL.Image
from diffusers.utils.outputs import BaseOutput from diffusers.utils.outputs import BaseOutput
from diffusers.utils.testing_utils import require_torch
@dataclass @dataclass
...@@ -69,3 +70,24 @@ class ConfigTester(unittest.TestCase): ...@@ -69,3 +70,24 @@ class ConfigTester(unittest.TestCase):
assert dir(outputs_orig) == dir(outputs_copy) assert dir(outputs_orig) == dir(outputs_copy)
assert dict(outputs_orig) == dict(outputs_copy) assert dict(outputs_orig) == dict(outputs_copy)
assert vars(outputs_orig) == vars(outputs_copy) assert vars(outputs_orig) == vars(outputs_copy)
@require_torch
def test_torch_pytree(self):
# ensure torch.utils._pytree treats ModelOutput subclasses as nodes (and not leaves)
# this is important for DistributedDataParallel gradient synchronization with static_graph=True
import torch
import torch.utils._pytree
data = np.random.rand(1, 3, 4, 4)
x = CustomOutput(images=data)
self.assertFalse(torch.utils._pytree._is_leaf(x))
expected_flat_outs = [data]
expected_tree_spec = torch.utils._pytree.TreeSpec(CustomOutput, ["images"], [torch.utils._pytree.LeafSpec()])
actual_flat_outs, actual_tree_spec = torch.utils._pytree.tree_flatten(x)
self.assertEqual(expected_flat_outs, actual_flat_outs)
self.assertEqual(expected_tree_spec, actual_tree_spec)
unflattened_x = torch.utils._pytree.tree_unflatten(actual_flat_outs, actual_tree_spec)
self.assertEqual(x, unflattened_x)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment