OpenDAS / fairscale
Commit 9c4e6d1a (unverified)
Authored Mar 09, 2021 by Benjamin Lefaudeux, committed via GitHub on Mar 09, 2021
[fix] flaky SDP tests with Gloo, checking all handles (#499)
* seemingly fix flakiness for Gloo by checking all comms handles
parent 8eaa3622
Showing 2 changed files with 23 additions and 18 deletions.
fairscale/nn/data_parallel/sharded_ddp.py  (+17 -12)
fairscale/optim/oss.py  (+6 -6)
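Both files apply the same fix: instead of keeping only the last asynchronous broadcast handle, every work handle is collected, and all of them are waited on unless the backend is NCCL, whose collectives are inlined on a single CUDA stream. The snippet below is a minimal standalone sketch of that pattern, not fairscale code; the function name broadcast_tensors and its arguments are made up for illustration.

from typing import List

import torch
import torch.distributed as dist


def broadcast_tensors(tensors: List[torch.Tensor], src_rank: int, group=None) -> None:
    # Illustrative helper (not part of the commit): broadcast a list of tensors
    # from src_rank and wait according to the backend, as this fix does.
    group = group if group is not None else dist.group.WORLD
    backend = dist.get_backend(group)

    # Launch all broadcasts asynchronously and keep every work handle.
    work_handles = [
        dist.broadcast(t, src=src_rank, group=group, async_op=True) for t in tensors
    ]

    if not work_handles:
        return

    if backend != dist.Backend.NCCL:
        # Gloo does not guarantee inlining across requests: wait on each handle.
        for handle in work_handles:
            handle.wait()
    else:
        # NCCL inlines the ops on the same CUDA stream: the last handle is enough.
        work_handles[-1].wait()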
fairscale/nn/data_parallel/sharded_ddp.py
@@ -120,6 +120,7 @@ class ShardedDataParallel(nn.Module):
         # Communication related attributes
         self.process_group = process_group if process_group is not None else dist.group.WORLD
+        self.backend = dist.get_backend(self.process_group)
         self.world_size_scaling = 1.0 / dist.get_world_size(self.process_group)  # > 0
         self.reference_global_rank = OSS.get_global_rank(self.process_group, 0)  # picking rank 0 as the reference
         self.rank = dist.get_rank(self.process_group)
@@ -311,16 +312,18 @@ class ShardedDataParallel(nn.Module):
             blocking (bool): wait for the operation to conclude.
         """
-        last_work_handle = None
+        work_handles = []

         for buffer in self.module.buffers(recurse=True):
-            last_work_handle = dist.broadcast(
-                buffer.data, self.reference_global_rank, self.process_group, async_op=True
+            work_handles.append(
+                dist.broadcast(buffer.data, self.reference_global_rank, self.process_group, async_op=True)
             )

-        if blocking and last_work_handle:
-            # Only wait for the last coms, they're inlined on the same CUDA stream
-            last_work_handle.wait()
+        if blocking and work_handles:
+            if self.backend != dist.Backend.NCCL:
+                _ = list(filter(lambda x: x.wait(), work_handles))
+            else:
+                work_handles[-1].wait()

     def zero_grad(self, set_to_none: bool = False) -> None:
         r"""Sets gradients of all model parameters to zero. See similar function
@@ -505,16 +508,18 @@ class ShardedDataParallel(nn.Module):
         Sync the complete model states in between the ranks
         """
-        last_work_handle = None
+        work_handles = []

         for t in self.module.state_dict().values():
-            last_work_handle = dist.broadcast(
-                t, src=self.reference_global_rank, group=self.process_group, async_op=True
+            work_handles.append(
+                dist.broadcast(t, src=self.reference_global_rank, group=self.process_group, async_op=True)
             )

-        # Only wait for the last handle, they're inlined in the same CUDA stream
-        if last_work_handle:
-            last_work_handle.wait()
+        # gloo does not guarantee inlining like NCCL, wait for all requests
+        if self.backend != dist.Backend.NCCL:
+            _ = list(filter(lambda x: x.wait(), work_handles))
+        elif work_handles:
+            work_handles[-1].wait()

     def _passing_sync_batchnorm_handle(self, module: nn.Module) -> None:
         """
fairscale/optim/oss.py
@@ -546,19 +546,19 @@ class OSS(Optimizer):
     def _broadcast_params(self) -> None:
         """Helper function to broadcast all the parameters from a given device"""
-        last_work_handle = None  # Work handles are consumed within this scope, no callback
+        work_handles = []  # Work handles are consumed within this scope, no callback

         for device in self.buckets.keys():
             for src_rank, bucket in enumerate(self.buckets[device]):
                 if bucket.numel() > 0:
-                    global_src_rank = self.get_global_rank(self.group, src_rank)
-                    last_work_handle = dist.broadcast(
+                    work_handles.append(
+                        dist.broadcast(
                             tensor=bucket, src=self._local_to_global_rank[src_rank], group=self.group, async_op=True
-                    )
+                        )
+                    )

         # Only check on the last handle, they're all inlined on the same CUDA stream
-        if last_work_handle:
-            last_work_handle.wait()
+        _ = list(filter(lambda x: x.wait(), work_handles))

     def _setup_flat_buffers(self) -> None:
         """Make all params which are on the same device and tied to the same rank views of a single buffer.
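A side note on the `_ = list(filter(lambda x: x.wait(), work_handles))` idiom used in both files: `filter` is lazy, so the surrounding `list(...)` is what forces `wait()` to be called on every handle (the resulting list is discarded). A plain loop, as in the hypothetical helper below, is an equivalent and arguably more explicit spelling; it is shown only to illustrate the design choice and is not part of the commit.

def wait_all(work_handles) -> None:
    """Block until every outstanding collective request has completed.

    Equivalent to the `_ = list(filter(lambda x: x.wait(), work_handles))`
    idiom in the diffs above, written as an explicit loop.
    """
    for handle in work_handles:
        handle.wait()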