OpenDAS / fairscale · Commits

Commit ef7146d5 (unverified), authored Feb 18, 2021 by Benjamin Lefaudeux, committed by GitHub on Feb 18, 2021.
[fix][minor] ShardedDDP train/eval modes (#393)
* [fix] ShardedDDP train/eval modes
* Update CHANGELOG.md
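In user-facing terms, the fix makes it safe to toggle a ShardedDDP-wrapped model between train() and eval(): in eval mode no backward pass runs, so the gradient reduce buckets are never sent, and the pre-existing sanity assert could fire on a perfectly valid inference pass. A minimal single-process sketch of that scenario, assuming a gloo backend and world size 1 (illustration only, not code from this commit):

```python
import tempfile

import torch
import torch.distributed as dist
from torch.nn import Linear, Sequential

from fairscale.nn.data_parallel import ShardedDataParallel
from fairscale.optim import OSS

# Single-rank process group through a file store, for illustration only.
dist.init_process_group(
    init_method="file://" + tempfile.mkstemp()[1], backend="gloo", rank=0, world_size=1
)

model = Sequential(Linear(2, 3), Linear(3, 3))
optimizer = OSS(params=model.parameters(), optim=torch.optim.SGD, lr=0.01)
ddp_model = ShardedDataParallel(model, optimizer)

# Normal training step: backward() triggers the gradient reduce, buckets get sent.
ddp_model(torch.rand(10, 2)).norm().backward()
optimizer.step()

# Eval-mode forwards: no backward, so no bucket is ever sent. Before #393, a
# repeated eval forward could trip the "A bucket failed to be sent" assert.
ddp_model.eval()
with torch.no_grad():
    ddp_model(torch.rand(10, 2))
    ddp_model(torch.rand(10, 2))

dist.destroy_process_group()
```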
Parent: 47042917

Showing 3 changed files with 28 additions and 1 deletion:

    CHANGELOG.md                                 +1  -0
    fairscale/nn/data_parallel/sharded_ddp.py    +1  -1
    tests/nn/data_parallel/test_sharded_ddp.py   +26 -0
CHANGELOG.md @ ef7146d5

```diff
@@ -8,6 +8,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
 ### Fixed
 - ShardedDDP and OSS handle model trainability changes during training ([#369](https://github.com/facebookresearch/fairscale/issues/369))
 - ShardedDDP state dict load/save bug (#386)
+- ShardedDDP handle train/eval modes (#393)

 ### Added
 - ShardedDDP manual reduce option for checkpointing (#389)
```
fairscale/nn/data_parallel/sharded_ddp.py @ ef7146d5

```diff
@@ -352,7 +352,7 @@ class ShardedDataParallel(nn.Module):
         assert self._bucket_list is not None

         for bucket in self._bucket_list:
-            assert self.should_accumulate_grads or bucket.sent, (
+            assert not self.training or self.should_accumulate_grads or bucket.sent, (
                 "A bucket failed to be sent, probably unused parameters."
                 + "Either remove the unused parameter or de-activate ShardedDDP buckets -set reduce_buffer_size to 0-"
             )
```
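The one-line change adds `not self.training` in front of the existing condition, so the unsent-bucket check is skipped entirely in eval mode, where no backward pass flushes the buckets. A quick stand-alone check of the new predicate (plain Python, no fairscale needed; the tuples are hypothetical cases, not from the commit):

```python
# (training, should_accumulate_grads, bucket_sent) -> the assert should pass:
ok_cases = [
    (False, False, False),  # eval mode: buckets legitimately never sent
    (True, True, False),    # gradient accumulation: the reduce is deferred
    (True, False, True),    # normal training step, bucket was sent
]
for training, accumulate, sent in ok_cases:
    assert not training or accumulate or sent

# The only combination that still trips the assert is an unsent bucket during
# a regular training step, which keeps pointing at unused parameters.
training, accumulate, sent = True, False, False
assert (not training or accumulate or sent) is False
```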
tests/nn/data_parallel/test_sharded_ddp.py @ ef7146d5

```diff
@@ -440,6 +440,32 @@ def test_device_change():
     mp.spawn(run_test_device_change, args=(world_size, backend, device, temp_file_name), nprocs=world_size, join=True)


+def run_test_training_change(rank, world_size, backend, device, temp_file_name):
+    url = "file://" + temp_file_name
+    group = dist.init_process_group(init_method=url, backend=backend, rank=rank, world_size=world_size)
+    model = Sequential(Linear(2, 3), Linear(3, 3)).to(device)
+    optimizer = OSS(params=model.parameters(), optim=torch.optim.SGD, lr=0.01, momentum=0.99)
+    ddp_model = ShardedDataParallel(model, optimizer, process_group=group)
+
+    inputs = torch.rand((10, 2), device=device)
+    outputs = ddp_model(inputs)  # assert if the module has not been changed properly
+    _ = outputs.norm().backward()
+
+    ddp_model.eval()
+    ddp_model(inputs)  # This will assert if eval() is not properly taken into account
+    ddp_model(inputs)
+
+    dist.destroy_process_group()
+
+
+def test_training_change():
+    world_size = 8
+    backend = "gloo"
+    temp_file_name = tempfile.mkstemp()[1]
+    device = "cpu"
+    mp.spawn(run_test_training_change, args=(world_size, backend, device, temp_file_name), nprocs=world_size, join=True)
+
+
 def run_test_ddp_sync_batch_norm(rank, world_size, backend, device, temp_file_name):
     url = "file://" + temp_file_name
     dist.init_process_group(init_method=url, backend=backend, rank=rank, world_size=world_size)
```
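The new test follows the harness pattern used throughout this file: a file:// rendezvous shared by all ranks, plus torch.multiprocessing.spawn fanning out one worker per rank; it can be run on its own with `pytest tests/nn/data_parallel/test_sharded_ddp.py -k test_training_change`. Stripped to a skeleton (hypothetical `worker` function, illustration only):

```python
import tempfile

import torch.distributed as dist
import torch.multiprocessing as mp


def worker(rank, world_size, url):
    # spawn() injects the rank as the first argument; every process joins
    # the same group through the shared file store behind `url`.
    dist.init_process_group(init_method=url, backend="gloo", rank=rank, world_size=world_size)
    # ... per-rank test body would go here ...
    dist.destroy_process_group()


if __name__ == "__main__":
    world_size = 2
    url = "file://" + tempfile.mkstemp()[1]
    mp.spawn(worker, args=(world_size, url), nprocs=world_size, join=True)
```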