Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
fairscale
Commits
ef7146d5
Unverified
Commit
ef7146d5
authored
Feb 18, 2021
by
Benjamin Lefaudeux
Committed by
GitHub
Feb 18, 2021
Browse files
[fix][minor] ShardedDDP train/eval modes (#393)
* [fix] ShardedDDP train/eval modes * Update CHANGELOG.md
parent
47042917
Changes
3
Show whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
28 additions
and
1 deletion
+28
-1
CHANGELOG.md
CHANGELOG.md
+1
-0
fairscale/nn/data_parallel/sharded_ddp.py
fairscale/nn/data_parallel/sharded_ddp.py
+1
-1
tests/nn/data_parallel/test_sharded_ddp.py
tests/nn/data_parallel/test_sharded_ddp.py
+26
-0
No files found.
CHANGELOG.md
View file @
ef7146d5
...
...
@@ -8,6 +8,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
### Fixed
-
ShardedDDP and OSS handle model trainability changes during training (
[
#369
](
https://github.com/facebookresearch/fairscale/issues/369
)
)
-
ShardedDDP state dict load/save bug (#386)
-
ShardedDDP handles train/eval modes (#393)
### Added
-
ShardedDDP manual reduce option for checkpointing (#389)
...
...
fairscale/nn/data_parallel/sharded_ddp.py
View file @
ef7146d5
...
...
@@ -352,7 +352,7 @@ class ShardedDataParallel(nn.Module):
assert
self
.
_bucket_list
is
not
None
for
bucket
in
self
.
_bucket_list
:
assert
self
.
should_accumulate_grads
or
bucket
.
sent
,
(
assert
not
self
.
training
or
self
.
should_accumulate_grads
or
bucket
.
sent
,
(
"A bucket failed to be sent, probably unused parameters."
+
"Either remove the unused parameter or de-activate ShardedDDP buckets -set reduce_buffer_size to 0-"
)
...
...
tests/nn/data_parallel/test_sharded_ddp.py
View file @
ef7146d5
...
...
@@ -440,6 +440,32 @@ def test_device_change():
mp
.
spawn
(
run_test_device_change
,
args
=
(
world_size
,
backend
,
device
,
temp_file_name
),
nprocs
=
world_size
,
join
=
True
)
def run_test_training_change(rank, world_size, backend, device, temp_file_name):
    """Per-rank worker: exercise a ShardedDataParallel model in train mode, then in eval mode.

    Runs one forward/backward pass with the wrapper in its initial mode,
    switches the wrapper to eval mode, and then runs two forward-only passes.

    Args:
        rank: Rank of this process, forwarded to ``dist.init_process_group``.
        world_size: Number of processes, forwarded to ``dist.init_process_group``.
        backend: Backend name, forwarded to ``dist.init_process_group``.
        device: Device on which the model and the input batch are created.
        temp_file_name: Path used to build the ``file://`` init method URL.
    """
    init_url = f"file://{temp_file_name}"
    # NOTE(review): torch documents init_process_group without a return value,
    # so `group` is presumably None and the wrapper falls back to whatever its
    # default process group is -- confirm against ShardedDataParallel.
    group = dist.init_process_group(
        init_method=init_url,
        backend=backend,
        rank=rank,
        world_size=world_size,
    )

    layers = Sequential(Linear(2, 3), Linear(3, 3))
    model = layers.to(device)
    sharded_optimizer = OSS(
        params=model.parameters(),
        optim=torch.optim.SGD,
        lr=0.01,
        momentum=0.99,
    )
    ddp_model = ShardedDataParallel(model, sharded_optimizer, process_group=group)

    batch = torch.rand(10, 2, device=device)

    # First pass: forward followed by backward, before any mode change.
    loss = ddp_model(batch).norm()
    loss.backward()

    # Switch to eval mode and run forward-only passes. No backward happens
    # between these calls, so they will assert if eval() is not properly
    # taken into account by the wrapper.
    ddp_model.eval()
    for _ in range(2):
        ddp_model(batch)

    dist.destroy_process_group()
def test_training_change():
    """Spawn 8 CPU/gloo workers running ``run_test_training_change``.

    The workers share a temporary file whose path is used to build their
    ``file://`` init method.
    """
    # Local imports keep this block self-contained; only the top-of-file
    # imports already present are otherwise relied upon.
    import contextlib
    import os

    world_size = 8
    backend = "gloo"
    device = "cpu"

    # mkstemp() returns an already-open OS-level descriptor together with the
    # path. Only the path is needed, so close the descriptor right away
    # instead of leaking it for the lifetime of the test process.
    fd, temp_file_name = tempfile.mkstemp()
    os.close(fd)

    try:
        mp.spawn(
            run_test_training_change,
            args=(world_size, backend, device, temp_file_name),
            nprocs=world_size,
            join=True,
        )
    finally:
        # mkstemp() leaves deletion to the caller. The file may already have
        # been removed by the time the workers exit, so tolerate its absence.
        with contextlib.suppress(FileNotFoundError):
            os.remove(temp_file_name)
def
run_test_ddp_sync_batch_norm
(
rank
,
world_size
,
backend
,
device
,
temp_file_name
):
url
=
"file://"
+
temp_file_name
dist
.
init_process_group
(
init_method
=
url
,
backend
=
backend
,
rank
=
rank
,
world_size
=
world_size
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment