OpenDAS / fairscale
Commit 85962b97 (unverified)

[SDP] removing an assert which does not seem always accurate (#625)

Authored Apr 22, 2021 by Benjamin Lefaudeux; committed via GitHub on Apr 22, 2021.
Parent commit: b0048b28
Changes: 3 files changed, 15 additions and 9 deletions (+15 / -9)
Changed files:
  fairscale/nn/data_parallel/fully_sharded_data_parallel.py   +1 / -3
  fairscale/nn/data_parallel/sharded_ddp.py                    +4 / -4
  tests/nn/data_parallel/test_sharded_ddp_features.py          +10 / -2
fairscale/nn/data_parallel/fully_sharded_data_parallel.py

@@ -1664,9 +1664,7 @@ def auto_wrap_bn(module: nn.Module, single_rank_pg: bool = False, process_group:
-        if recurse:
-            return not isinstance(module, tuple(default_auto_wrap_policy.FORCE_LEAF_MODULES))  # type: ignore
-        else:
-            return is_bn and not isinstance(module, tuple(default_auto_wrap_policy.EXCLUDE_WRAP_MODULES))  # type: ignore
+        return is_bn and not isinstance(module, tuple(default_auto_wrap_policy.EXCLUDE_WRAP_MODULES))  # type: ignore
     pg = None
     if single_rank_pg:
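For context, the simplified policy above reduces to a single predicate: wrap a submodule only if it is a BatchNorm layer and it is not in the excluded-module set. The sketch below is illustrative only and is not part of the commit; it assumes default_auto_wrap_policy is importable from fairscale.nn.wrap and that is_bn (defined outside the hunk) is a plain isinstance check against _BatchNorm.

# Illustration only, not from the commit. Assumptions: default_auto_wrap_policy
# is importable from fairscale.nn.wrap, and is_bn is an isinstance check
# against torch.nn.modules.batchnorm._BatchNorm.
import torch.nn as nn
from fairscale.nn.wrap import default_auto_wrap_policy

def wrap_bn_only(module: nn.Module) -> bool:
    # Mirrors the predicate kept by this commit: BatchNorm layers get wrapped,
    # anything in EXCLUDE_WRAP_MODULES (and every other module type) does not.
    is_bn = isinstance(module, nn.modules.batchnorm._BatchNorm)
    return is_bn and not isinstance(module, tuple(default_auto_wrap_policy.EXCLUDE_WRAP_MODULES))

print(wrap_bn_only(nn.BatchNorm2d(8)))  # expected: True
print(wrap_bn_only(nn.Linear(8, 8)))    # expected: False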
fairscale/nn/data_parallel/sharded_ddp.py

@@ -269,9 +269,9 @@ class ShardedDataParallel(nn.Module):
         """ If the module trainability has changed, update all the assumptions """

         # Make sure that this is not done while gradients are waiting to be reduced (if no_sync context for instance)
-        assert not functools.reduce(lambda x, y: x or y, self._grad_to_be_reduced, False), (
-            "Grads waiting to be reduced: {}".format(self._grad_to_be_reduced)
-            + "\nIf this is on purpose (grad accumulation), please use a no_sync() context"
-        )
+        if functools.reduce(lambda x, y: x or y, self._grad_to_be_reduced, False):
+            logging.warning(
+                "Grads waiting to be reduced. If this is on purpose (grad accumulation), please use a no_sync() context"
+            )

         self._trainable_params = list(filter(lambda x: x.requires_grad, self._all_params))
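The replacement warning points users at ShardedDataParallel's no_sync() context manager for gradient accumulation. Below is a minimal single-rank sketch of that pattern, not taken from the commit; the toy Linear model, SGD settings, and file-based rendezvous are illustrative assumptions.

# Sketch of the no_sync() gradient-accumulation pattern the new warning refers to.
# Assumptions: single rank, "gloo" backend, a toy Linear model and SGD settings.
import tempfile

import torch
import torch.distributed as dist
from fairscale.nn.data_parallel import ShardedDataParallel
from fairscale.optim import OSS

dist.init_process_group(
    init_method="file://" + tempfile.mkstemp()[1], backend="gloo", rank=0, world_size=1
)

model = torch.nn.Linear(4, 4)
optimizer = OSS(model.parameters(), optim=torch.optim.SGD, lr=0.1)
ddp_model = ShardedDataParallel(model, optimizer)

# First micro-batch: keep gradients local instead of reducing them.
with ddp_model.no_sync():
    ddp_model(torch.randn(2, 4)).sum().backward()

# Last micro-batch: leave the context so the backward pass reduces the
# accumulated gradients, then step the sharded optimizer.
ddp_model(torch.randn(2, 4)).sum().backward()
optimizer.step()

dist.destroy_process_group()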
tests/nn/data_parallel/test_sharded_ddp_features.py

@@ -262,9 +262,9 @@ def test_mixed_types():
     dist.destroy_process_group()


-def test_train_eval_change():
+def run_test_train_eval_change(rank, world_size, file):
     # Check that ShardedDDP handles the switch from training to eval properly
-    dist.init_process_group(init_method="file://" + tempfile.mkstemp()[1], backend="gloo", rank=0, world_size=1)
+    dist.init_process_group(init_method="file://" + file, backend="gloo", rank=rank, world_size=world_size)

     model = _get_mlp()
     model.train()

@@ -288,6 +288,14 @@ def test_train_eval_change():
     dist.destroy_process_group()


+def test_train_eval_change():
+    world_size = 4
+    temp_file_name = tempfile.mkstemp()[1]
+
+    mp.spawn(
+        run_test_train_eval_change, args=(world_size, temp_file_name), nprocs=world_size, join=True,
+    )
+

 def run_test_device_change(rank, world_size, backend, device, temp_file_name, reduce_buffer_size):
     # Check that the wrapped module can change devices
     dist.init_process_group(init_method="file://" + temp_file_name, backend=backend, rank=rank, world_size=world_size)