Unverified Commit 3e2547c3 authored by Benjamin Lefaudeux's avatar Benjamin Lefaudeux Committed by GitHub
Browse files

[feat][ShardedDDP] Support the original module's attributes (#309)

* minor, but ease of life, one less papercut
parent 43a27cd4
...@@ -162,6 +162,13 @@ class ShardedDataParallel(nn.Module): ...@@ -162,6 +162,13 @@ class ShardedDataParallel(nn.Module):
if blocking: if blocking:
_ = list(map(lambda x: x.wait(), work_handles)) _ = list(map(lambda x: x.wait(), work_handles))
def __getattr__(self, name: str) -> Any:
    """Resolve *name* via nn.Module first, then the wrapped module.

    nn.Module's own ``__getattr__`` knows about registered parameters,
    buffers and submodules; anything it cannot resolve is delegated to
    the original wrapped module so that its plain Python attributes
    stay reachable through this wrapper.
    """
    try:
        # Let nn.Module handle parameters / buffers / submodules.
        attr = super().__getattr__(name)
    except AttributeError:
        # Unknown to nn.Module -- defer to the wrapped module. If it is
        # missing there too, getattr re-raises AttributeError as before.
        attr = getattr(self.module, name)
    return attr
@contextlib.contextmanager @contextlib.contextmanager
def no_sync(self) -> Generator: def no_sync(self) -> Generator:
"""A context manager to disable gradient synchronization.""" """A context manager to disable gradient synchronization."""
......
...@@ -323,6 +323,24 @@ def test_ddp_attributes(): ...@@ -323,6 +323,24 @@ def test_ddp_attributes():
dist.destroy_process_group() dist.destroy_process_group()
def test_random_attributes():
    """Check that ShardedDDP exposes the original module's attributes.

    Wraps a small Sequential model carrying an ad-hoc attribute and
    verifies the wrapper forwards it, while still raising for attributes
    that exist on neither the wrapper nor the wrapped module.
    """
    import os  # local import: file's top-of-file import block is elsewhere

    # mkstemp() returns (fd, path); close the fd so it does not leak.
    fd, path = tempfile.mkstemp()
    os.close(fd)
    url = "file://" + path
    dist.init_process_group(init_method=url, backend="gloo", rank=0, world_size=1)
    try:
        model = Sequential(Linear(2, 3), Linear(3, 3))
        model.banana = "sweet"

        optimizer = OSS(params=model.parameters(), optim=torch.optim.SGD, lr=0.01, momentum=0.99)
        ddp_model = ShardedDataParallel(model, optimizer)

        # Attribute present on the wrapped module is forwarded ...
        assert hasattr(ddp_model, "banana")
        # ... but a truly missing attribute still reports as absent.
        assert not hasattr(ddp_model, "orange")
    finally:
        # Always tear down the process group, even on assertion failure,
        # so later tests in the same process can re-initialize it.
        dist.destroy_process_group()
def run_test_ddp_sync_batch_norm(rank, world_size, backend, device, temp_file_name): def run_test_ddp_sync_batch_norm(rank, world_size, backend, device, temp_file_name):
url = "file://" + temp_file_name url = "file://" + temp_file_name
dist.init_process_group(init_method=url, backend=backend, rank=rank, world_size=world_size) dist.init_process_group(init_method=url, backend=backend, rank=rank, world_size=world_size)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment