Commit dd441e9d (unverified)
Authored Jan 21, 2021 by Benjamin Lefaudeux, committed by GitHub on Jan 21, 2021
Parent: bd5d0496

[perf] ShardedDDP & OSS, small improvements (#321)

* Couple of small improvements, no logic changes
Showing 2 changed files with 34 additions and 27 deletions (+34, -27):

fairscale/nn/data_parallel/sharded_ddp.py    +24 -13
fairscale/optim/oss.py                       +10 -14
fairscale/nn/data_parallel/sharded_ddp.py @ dd441e9d
@@ -149,18 +149,22 @@ class ShardedDataParallel(nn.Module):
         """
         logging.warning("This is not useful anymore, gradients have been reduced automatically with the backward pass")
 
+    @torch.no_grad()
     def sync_buffers(self, blocking: bool = False) -> None:
         """
         Sync all the param buffers in between ranks (including for instance batch norm statistics).
         """
-        with torch.no_grad():
-            work_handles = [
-                dist.broadcast(buffer.data, self.reference_global_rank, self.process_group, async_op=True)
-                for buffer in self.module.buffers(recurse=True)
-            ]
-
-            if blocking:
-                _ = list(map(lambda x: x.wait(), work_handles))
+        last_work_handle = None
+
+        for buffer in self.module.buffers(recurse=True):
+            last_work_handle = dist.broadcast(
+                buffer.data, self.reference_global_rank, self.process_group, async_op=True
+            )
+
+        if blocking and last_work_handle:
+            # Only wait for the last coms, they're inlined on the same CUDA stream
+            last_work_handle.wait()
 
     def __getattr__(self, name: str) -> Any:
         """Forward missing attributes to wrapped module."""
@@ -177,6 +181,7 @@ class ShardedDataParallel(nn.Module):
             yield
         self.should_accumulate_grads = old_should_accumulate_grads
 
+    @torch.no_grad()
     def _clear_counters(self) -> None:
         """Reset all the grad reduce and call counters"""
         self._grad_to_be_reduced = [True for _ in self._grad_to_be_reduced]
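The only change in this hunk is applying torch.no_grad() as a decorator instead of opening a with-block inside the method; both forms disable gradient tracking for the whole body. A small illustration (the function below is an example, not fairscale code):

import torch

@torch.no_grad()
def zero_(t: torch.Tensor) -> None:
    # Equivalent to wrapping the body in `with torch.no_grad():`; the
    # in-place update is not recorded by autograd.
    t.zero_()

p = torch.randn(3, requires_grad=True)
zero_(p)   # fine under no_grad: in-place update of a leaf that requires grad
print(p)   # tensor([0., 0., 0.], requires_grad=True)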
@@ -199,6 +204,7 @@ class ShardedDataParallel(nn.Module):
         Either way a delayed action is necessary and is passed as a callback.
         """
 
+        @torch.no_grad()
         def reduce(*_: Any) -> None:
             # Skip gradient reduction, do not alter status flags
             if not self.should_accumulate_grads and self._grad_to_be_reduced[index]:
@@ -262,17 +268,22 @@ class ShardedDataParallel(nn.Module):
                 grad_acc.register_hook(self._get_reduce_fn(index, param, dst_rank, sharded_optimizer))
                 self._grad_accs.append(grad_acc)  # keep this function in scope
 
+    @torch.no_grad()
     def _sync_params_and_buffers(self) -> None:
         """
         Sync the complete model states in between the ranks
         """
-        with torch.no_grad():
-            work_handles = [
-                dist.broadcast(t, src=self.reference_global_rank, group=self.process_group, async_op=True)
-                for t in self.module.state_dict().values()
-            ]
-
-            _ = list(map(lambda x: x.wait(), work_handles))
+        last_work_handle = None
+
+        for t in self.module.state_dict().values():
+            last_work_handle = dist.broadcast(
+                t, src=self.reference_global_rank, group=self.process_group, async_op=True
+            )
+
+        # Only wait for the last handle, they're inlined in the same CUDA stream
+        if last_work_handle:
+            last_work_handle.wait()
 
     def _passing_sync_batchnorm_handle(self, module: nn.Module) -> None:
         """
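The first context lines of the hunk above show the reduce callback being registered on each parameter's gradient accumulator and the accumulator being kept alive in self._grad_accs. A hedged sketch of that standard autograd trick, with illustrative names only (not the ShardedDDP implementation):

import torch

def attach_grad_ready_hook(param: torch.Tensor, callback) -> object:
    # A trivial view exposes the AccumulateGrad node of `param`; a hook
    # registered on that node fires once param.grad has been accumulated.
    p_tmp = param.expand_as(param)
    grad_acc = p_tmp.grad_fn.next_functions[0][0]
    grad_acc.register_hook(lambda *unused: callback(param))
    return grad_acc  # keep a reference so the hook is not garbage collected

# usage sketch
w = torch.randn(4, requires_grad=True)
handle = attach_grad_ready_hook(w, lambda p: print("grad ready", p.grad.shape))
(w * 2).sum().backward()  # prints: grad ready torch.Size([4])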
fairscale/optim/oss.py @ dd441e9d
@@ -534,6 +534,7 @@ class OSS(Optimizer):
             global_rank = dist.distributed_c10d._get_global_rank(group, rank)
         return global_rank
 
+    @torch.no_grad()
     def _sync_param_groups(self, local_to_global: bool = False) -> None:
         """Sync learning rate and other optimizer attributes (needed to support schedulers).
 
         If the global param groups have been altered, and we want to make sure that the
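As the docstring says, _sync_param_groups mirrors scheduler-controlled attributes (learning rate, momentum, and so on) between two sets of param groups. A hedged sketch of that kind of sync, with made-up names and independent of the actual OSS implementation:

def sync_param_groups(source_groups, target_groups):
    # Copy every attribute except the parameter tensors themselves, so that
    # changes made by an LR scheduler on one side are reflected on the other.
    for source, target in zip(source_groups, target_groups):
        for key, value in source.items():
            if key != "params":
                target[key] = value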
@@ -548,10 +549,12 @@ class OSS(Optimizer):
                 elif k in global_group.keys():
                     local_group[k] = global_group[k]
 
+    @torch.no_grad()
     def _broadcast_params(self) -> None:
         """Helper function to broadcast all the parameters from a given device"""
         i_param = 0
+        last_work_handle = None  # Work handles are consumed within this scope, no callback
 
         for (device, device_params,) in self.per_device_params.items():  # all the params on this device (inc all ranks)
             buckets = self.buckets[device]
@@ -562,25 +565,18 @@ class OSS(Optimizer):
                 # Direct broadcasts only
                 for param in params:
                     if not self.should_bucket_param[i_param]:
-                        self.work_handles.append(
-                            Workhandle(
-                                handle=dist.broadcast(
-                                    tensor=param.data, src=global_src_rank, group=self.group, async_op=True
-                                ),
-                                callback=None,
-                            )
-                        )
+                        last_work_handle = dist.broadcast(
+                            tensor=param.data, src=global_src_rank, group=self.group, async_op=True
+                        )
                     i_param += 1
 
                 # Bucket broadcasts
-                self.work_handles.append(
-                    Workhandle(
-                        handle=dist.broadcast(tensor=bucket, src=global_src_rank, group=self.group, async_op=True),
-                        callback=None,
-                    )
-                )
-
-        self._consume_work_handles()
+                last_work_handle = dist.broadcast(
+                    tensor=bucket, src=global_src_rank, group=self.group, async_op=True
+                )
+
+        # Only check on the last handle, they're all inlined on the same CUDA stream
+        if last_work_handle:
+            last_work_handle.wait()
 
     def _consume_work_handles(self) -> None:
         """Consume all the futures which are tied to this optimizer's buckets.
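The hunk above keeps the distinction between direct broadcasts (one collective per parameter that was not bucketed) and bucket broadcasts (one collective for a flat buffer holding many small parameters). A self-contained sketch of the pack/broadcast/unpack idea behind such a bucket, with illustrative names only (not the fairscale buckets):

import torch
import torch.distributed as dist

def broadcast_small_params(params, src, group=None):
    # Pack the small tensors into one flat buffer so a single broadcast
    # replaces many tiny collectives.
    bucket = torch.cat([p.data.flatten() for p in params])
    dist.broadcast(bucket, src=src, group=group)

    # Unpack the received values back into the original parameters.
    offset = 0
    for p in params:
        numel = p.data.numel()
        p.data.copy_(bucket[offset : offset + numel].view_as(p.data))
        offset += numel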