Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
fairscale
Commits
85962b97
Unverified
Commit
85962b97
authored
Apr 22, 2021
by
Benjamin Lefaudeux
Committed by
GitHub
Apr 22, 2021
Browse files
[SDP] removing an assert which does not seem always accurate (#625)
parent
b0048b28
Changes
3
Show whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
15 additions
and
9 deletions
+15
-9
fairscale/nn/data_parallel/fully_sharded_data_parallel.py
fairscale/nn/data_parallel/fully_sharded_data_parallel.py
+1
-3
fairscale/nn/data_parallel/sharded_ddp.py
fairscale/nn/data_parallel/sharded_ddp.py
+4
-4
tests/nn/data_parallel/test_sharded_ddp_features.py
tests/nn/data_parallel/test_sharded_ddp_features.py
+10
-2
No files found.
fairscale/nn/data_parallel/fully_sharded_data_parallel.py
View file @
85962b97
...
...
@@ -1664,9 +1664,7 @@ def auto_wrap_bn(module: nn.Module, single_rank_pg: bool = False, process_group:
if
recurse
:
return
not
isinstance
(
module
,
tuple
(
default_auto_wrap_policy
.
FORCE_LEAF_MODULES
))
# type: ignore
else
:
return
is_bn
and
not
isinstance
(
module
,
tuple
(
default_auto_wrap_policy
.
EXCLUDE_WRAP_MODULES
)
)
# type: ignore
return
is_bn
and
not
isinstance
(
module
,
tuple
(
default_auto_wrap_policy
.
EXCLUDE_WRAP_MODULES
))
# type: ignore
pg
=
None
if
single_rank_pg
:
...
...
fairscale/nn/data_parallel/sharded_ddp.py
View file @
85962b97
...
...
@@ -269,9 +269,9 @@ class ShardedDataParallel(nn.Module):
""" If the module trainability has changed, update all the assumptions """
# Make sure that this is not done while gradients are waiting to be reduced (if no_sync context for instance)
assert
not
functools
.
reduce
(
lambda
x
,
y
:
x
or
y
,
self
.
_grad_to_be_reduced
,
False
)
,
(
"Grads waiting to be reduced: {}"
.
format
(
self
.
_grad_to_be_reduced
)
+
"
\n
If this is on purpose (grad accumulation), please use a no_sync() context"
if
functools
.
reduce
(
lambda
x
,
y
:
x
or
y
,
self
.
_grad_to_be_reduced
,
False
)
:
logging
.
warning
(
"Grads waiting to be reduced.
If this is on purpose (grad accumulation), please use a no_sync() context"
)
self
.
_trainable_params
=
list
(
filter
(
lambda
x
:
x
.
requires_grad
,
self
.
_all_params
))
...
...
tests/nn/data_parallel/test_sharded_ddp_features.py
View file @
85962b97
...
...
@@ -262,9 +262,9 @@ def test_mixed_types():
dist
.
destroy_process_group
()
def
test_train_eval_change
():
def
run_
test_train_eval_change
(
rank
,
world_size
,
file
):
# Check that ShardedDDP handles the switch from training to eval properly
dist
.
init_process_group
(
init_method
=
"file://"
+
tempfile
.
mkstemp
()[
1
]
,
backend
=
"gloo"
,
rank
=
0
,
world_size
=
1
)
dist
.
init_process_group
(
init_method
=
"file://"
+
file
,
backend
=
"gloo"
,
rank
=
rank
,
world_size
=
world_size
)
model
=
_get_mlp
()
model
.
train
()
...
...
@@ -288,6 +288,14 @@ def test_train_eval_change():
dist
.
destroy_process_group
()
def
test_train_eval_change
():
world_size
=
4
temp_file_name
=
tempfile
.
mkstemp
()[
1
]
mp
.
spawn
(
run_test_train_eval_change
,
args
=
(
world_size
,
temp_file_name
),
nprocs
=
world_size
,
join
=
True
,
)
def
run_test_device_change
(
rank
,
world_size
,
backend
,
device
,
temp_file_name
,
reduce_buffer_size
):
# Check that the wrapped module can change devices
dist
.
init_process_group
(
init_method
=
"file://"
+
temp_file_name
,
backend
=
backend
,
rank
=
rank
,
world_size
=
world_size
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment