Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
fairscale
Commits
9c4e6d1a
Unverified
Commit
9c4e6d1a
authored
Mar 09, 2021
by
Benjamin Lefaudeux
Committed by
GitHub
Mar 09, 2021
Browse files
[fix] flaky SDP tests with Gloo, checking all handles (#499)
* seemingly fix flakyness for gloo by checking all coms handles
parent
8eaa3622
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
23 additions
and
18 deletions
+23
-18
fairscale/nn/data_parallel/sharded_ddp.py
fairscale/nn/data_parallel/sharded_ddp.py
+17
-12
fairscale/optim/oss.py
fairscale/optim/oss.py
+6
-6
No files found.
fairscale/nn/data_parallel/sharded_ddp.py
View file @
9c4e6d1a
...
@@ -120,6 +120,7 @@ class ShardedDataParallel(nn.Module):
...
@@ -120,6 +120,7 @@ class ShardedDataParallel(nn.Module):
# Communication related attributes
# Communication related attributes
self
.
process_group
=
process_group
if
process_group
is
not
None
else
dist
.
group
.
WORLD
self
.
process_group
=
process_group
if
process_group
is
not
None
else
dist
.
group
.
WORLD
self
.
backend
=
dist
.
get_backend
(
self
.
process_group
)
self
.
world_size_scaling
=
1.0
/
dist
.
get_world_size
(
self
.
process_group
)
# > 0
self
.
world_size_scaling
=
1.0
/
dist
.
get_world_size
(
self
.
process_group
)
# > 0
self
.
reference_global_rank
=
OSS
.
get_global_rank
(
self
.
process_group
,
0
)
# picking rank 0 as the reference
self
.
reference_global_rank
=
OSS
.
get_global_rank
(
self
.
process_group
,
0
)
# picking rank 0 as the reference
self
.
rank
=
dist
.
get_rank
(
self
.
process_group
)
self
.
rank
=
dist
.
get_rank
(
self
.
process_group
)
...
@@ -311,16 +312,18 @@ class ShardedDataParallel(nn.Module):
...
@@ -311,16 +312,18 @@ class ShardedDataParallel(nn.Module):
blocking (bool): wait for the operation to conclude.
blocking (bool): wait for the operation to conclude.
"""
"""
last_
work_handle
=
None
work_handle
s
=
[]
for
buffer
in
self
.
module
.
buffers
(
recurse
=
True
):
for
buffer
in
self
.
module
.
buffers
(
recurse
=
True
):
last_
work_handle
=
dist
.
broadcast
(
work_handle
s
.
append
(
buffer
.
data
,
self
.
reference_global_rank
,
self
.
process_group
,
async_op
=
True
dist
.
broadcast
(
buffer
.
data
,
self
.
reference_global_rank
,
self
.
process_group
,
async_op
=
True
)
)
)
if
blocking
and
last_work_handle
:
if
blocking
and
work_handles
:
# Only wait for the last coms, they're inlined on the same CUDA stream
if
self
.
backend
!=
dist
.
Backend
.
NCCL
:
last_work_handle
.
wait
()
_
=
list
(
filter
(
lambda
x
:
x
.
wait
(),
work_handles
))
else
:
work_handles
[
-
1
].
wait
()
def
zero_grad
(
self
,
set_to_none
:
bool
=
False
)
->
None
:
def
zero_grad
(
self
,
set_to_none
:
bool
=
False
)
->
None
:
r
"""Sets gradients of all model parameters to zero. See similar function
r
"""Sets gradients of all model parameters to zero. See similar function
...
@@ -505,16 +508,18 @@ class ShardedDataParallel(nn.Module):
...
@@ -505,16 +508,18 @@ class ShardedDataParallel(nn.Module):
Sync the complete model states in between the ranks
Sync the complete model states in between the ranks
"""
"""
last_
work_handle
=
None
work_handle
s
=
[]
for
t
in
self
.
module
.
state_dict
().
values
():
for
t
in
self
.
module
.
state_dict
().
values
():
last_
work_handle
=
dist
.
broadcast
(
work_handle
s
.
append
(
t
,
src
=
self
.
reference_global_rank
,
group
=
self
.
process_group
,
async_op
=
True
dist
.
broadcast
(
t
,
src
=
self
.
reference_global_rank
,
group
=
self
.
process_group
,
async_op
=
True
)
)
)
# Only wait for the last handle, they're inlined in the same CUDA stream
# gloo does not guarantee inlining like NCCL, wait for all requests
if
last_work_handle
:
if
self
.
backend
!=
dist
.
Backend
.
NCCL
:
last_work_handle
.
wait
()
_
=
list
(
filter
(
lambda
x
:
x
.
wait
(),
work_handles
))
elif
work_handles
:
work_handles
[
-
1
].
wait
()
def
_passing_sync_batchnorm_handle
(
self
,
module
:
nn
.
Module
)
->
None
:
def
_passing_sync_batchnorm_handle
(
self
,
module
:
nn
.
Module
)
->
None
:
"""
"""
...
...
fairscale/optim/oss.py
View file @
9c4e6d1a
...
@@ -546,19 +546,19 @@ class OSS(Optimizer):
...
@@ -546,19 +546,19 @@ class OSS(Optimizer):
def
_broadcast_params
(
self
)
->
None
:
def
_broadcast_params
(
self
)
->
None
:
"""Helper function to broadcast all the parameters from a given device"""
"""Helper function to broadcast all the parameters from a given device"""
last_
work_handle
=
None
# Work handles are consumed within this scope, no callback
work_handle
s
=
[]
# Work handles are consumed within this scope, no callback
for
device
in
self
.
buckets
.
keys
():
for
device
in
self
.
buckets
.
keys
():
for
src_rank
,
bucket
in
enumerate
(
self
.
buckets
[
device
]):
for
src_rank
,
bucket
in
enumerate
(
self
.
buckets
[
device
]):
if
bucket
.
numel
()
>
0
:
if
bucket
.
numel
()
>
0
:
global_src_rank
=
self
.
get_global_rank
(
self
.
group
,
src_rank
)
work_handles
.
append
(
last_work_handle
=
dist
.
broadcast
(
dist
.
broadcast
(
tensor
=
bucket
,
src
=
self
.
_local_to_global_rank
[
src_rank
],
group
=
self
.
group
,
async_op
=
True
tensor
=
bucket
,
src
=
self
.
_local_to_global_rank
[
src_rank
],
group
=
self
.
group
,
async_op
=
True
)
)
)
# Only check on the last handle, they're all inlined on the same CUDA stream
# Only check on the last handle, they're all inlined on the same CUDA stream
if
last_work_handle
:
_
=
list
(
filter
(
lambda
x
:
x
.
wait
(),
work_handles
))
last_work_handle
.
wait
()
def
_setup_flat_buffers
(
self
)
->
None
:
def
_setup_flat_buffers
(
self
)
->
None
:
"""Make all params which are on the same device and tied to the same rank views of a single buffer.
"""Make all params which are on the same device and tied to the same rank views of a single buffer.
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment