OpenDAS / fairscale

Commit 3e2547c3 (unverified)
Authored Jan 15, 2021 by Benjamin Lefaudeux; committed by GitHub on Jan 15, 2021
[feat][ShardedDDP] Support the original module's attributes (#309)
* minor, but ease of life, one less papercut
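
The feature in one picture: a user attribute set on the plain module stays reachable through the ShardedDDP wrapper. Below is a minimal single-process sketch, not part of the commit itself; it assumes fairscale at or after this commit and mirrors the names used in the new unit test further down (banana, Gloo backend, world size 1):

```python
import tempfile

import torch
import torch.distributed as dist
from torch.nn import Linear, Sequential

from fairscale.nn.data_parallel import ShardedDataParallel
from fairscale.optim import OSS

# Single-process "distributed" group, as in the new unit test below.
dist.init_process_group(
    init_method="file://" + tempfile.mkstemp()[1], backend="gloo", rank=0, world_size=1
)

model = Sequential(Linear(2, 3), Linear(3, 3))
model.banana = "sweet"  # arbitrary user attribute on the plain module

optimizer = OSS(params=model.parameters(), optim=torch.optim.SGD, lr=0.01, momentum=0.99)
ddp_model = ShardedDataParallel(model, optimizer)

# Before this commit this lookup raised AttributeError; now it is
# forwarded to the wrapped module by the new __getattr__.
print(ddp_model.banana)  # -> "sweet"

dist.destroy_process_group()
```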
Parent: 43a27cd4

Showing 2 changed files with 25 additions and 0 deletions (+25 −0)
fairscale/nn/data_parallel/sharded_ddp.py: +7 −0
tests/nn/data_parallel/test_sharded_ddp.py: +18 −0
fairscale/nn/data_parallel/sharded_ddp.py @ 3e2547c3
...
@@ -162,6 +162,13 @@ class ShardedDataParallel(nn.Module):
         if blocking:
             _ = list(map(lambda x: x.wait(), work_handles))
 
+    def __getattr__(self, name: str) -> Any:
+        """Forward missing attributes to wrapped module."""
+        try:
+            return super().__getattr__(name)  # defer to nn.Module's logic
+        except AttributeError:
+            return getattr(self.module, name)
+
     @contextlib.contextmanager
     def no_sync(self) -> Generator:
         """A context manager to disable gradient synchronization."""
...
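
Why the try/except shape: nn.Module overrides __getattr__ to resolve names from its internal _parameters, _buffers, and _modules registries, so the wrapper first defers to that logic (which is also what resolves self.module itself) and only falls back to the wrapped module when that lookup fails. A self-contained sketch of the same pattern outside fairscale; the Wrapper class and demo names are illustrative, not part of the library:

```python
from typing import Any

import torch.nn as nn


class Wrapper(nn.Module):
    """Toy wrapper reproducing the attribute-forwarding pattern above."""

    def __init__(self, module: nn.Module) -> None:
        super().__init__()
        self.module = module  # stored in self._modules by nn.Module.__setattr__

    def forward(self, *inputs: Any, **kwargs: Any) -> Any:
        return self.module(*inputs, **kwargs)

    def __getattr__(self, name: str) -> Any:
        """Forward missing attributes to the wrapped module."""
        try:
            # nn.Module.__getattr__ checks _parameters, _buffers and _modules;
            # it is also what resolves `self.module` here without recursion.
            return super().__getattr__(name)
        except AttributeError:
            return getattr(self.module, name)


inner = nn.Linear(2, 3)
inner.banana = "sweet"            # plain Python attribute on the inner module
wrapped = Wrapper(inner)

assert wrapped.module is inner    # found via nn.Module's _modules registry
assert wrapped.banana == "sweet"  # forwarded by the AttributeError fallback
assert not hasattr(wrapped, "orange")  # unknown names still raise AttributeError
```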
tests/nn/data_parallel/test_sharded_ddp.py @ 3e2547c3
...
@@ -323,6 +323,24 @@ def test_ddp_attributes():
     dist.destroy_process_group()
 
+
+def test_random_attributes():
+    # Check that ShardedDDP exposes the original module's attributes
+    url = "file://" + tempfile.mkstemp()[1]
+    dist.init_process_group(init_method=url, backend="gloo", rank=0, world_size=1)
+
+    model = Sequential(Linear(2, 3), Linear(3, 3))
+    model.banana = "sweet"
+
+    optimizer = OSS(params=model.parameters(), optim=torch.optim.SGD, lr=0.01, momentum=0.99)
+    ddp_model = ShardedDataParallel(model, optimizer)
+
+    assert hasattr(ddp_model, "banana")
+    assert not hasattr(ddp_model, "orange")
+
+    dist.destroy_process_group()
+
+
 def run_test_ddp_sync_batch_norm(rank, world_size, backend, device, temp_file_name):
     url = "file://" + temp_file_name
     dist.init_process_group(init_method=url, backend=backend, rank=rank, world_size=world_size)
...
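
The negative assertion in the test works because hasattr() reports False exactly when attribute lookup raises AttributeError: getattr(self.module, name) re-raises for unknown names, so the wrapper does not pretend to own attributes the module lacks. A tiny illustration of that contract (the Box class is illustrative only):

```python
class Box:
    def __getattr__(self, name: str):
        raise AttributeError(name)  # "not found" is signalled by raising


assert not hasattr(Box(), "orange")  # hasattr() turns AttributeError into False
```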