Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
fairscale
Commits
dd441e9d
Unverified
Commit
dd441e9d
authored
Jan 21, 2021
by
Benjamin Lefaudeux
Committed by
GitHub
Jan 21, 2021
Browse files
[perf] ShardedDDP & OSS, small improvements (#321)
* A couple of small improvements, no logic changes
parent
bd5d0496
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
34 additions
and
27 deletions
+34
-27
fairscale/nn/data_parallel/sharded_ddp.py
fairscale/nn/data_parallel/sharded_ddp.py
+24
-13
fairscale/optim/oss.py
fairscale/optim/oss.py
+10
-14
No files found.
fairscale/nn/data_parallel/sharded_ddp.py
View file @
dd441e9d
...
...
@@ -149,18 +149,22 @@ class ShardedDataParallel(nn.Module):
"""
logging
.
warning
(
"This is not useful anymore, gradients have been reduced automatically with the backward pass"
)
@
torch
.
no_grad
()
def
sync_buffers
(
self
,
blocking
:
bool
=
False
)
->
None
:
"""
Sync all the param buffers in between ranks (including for instance batch norm statistics).
"""
with
torch
.
no_grad
():
work_handles
=
[
dist
.
broadcast
(
buffer
.
data
,
self
.
reference_global_rank
,
self
.
process_group
,
async_op
=
True
)
for
buffer
in
self
.
module
.
buffers
(
recurse
=
True
)
]
if
blocking
:
_
=
list
(
map
(
lambda
x
:
x
.
wait
(),
work_handles
))
last_work_handle
=
None
for
buffer
in
self
.
module
.
buffers
(
recurse
=
True
):
last_work_handle
=
dist
.
broadcast
(
buffer
.
data
,
self
.
reference_global_rank
,
self
.
process_group
,
async_op
=
True
)
if
blocking
and
last_work_handle
:
# Only wait for the last communication handle; they are all inlined on the same CUDA stream
last_work_handle
.
wait
()
def
__getattr__
(
self
,
name
:
str
)
->
Any
:
"""Forward missing attributes to wrapped module."""
...
...
@@ -177,6 +181,7 @@ class ShardedDataParallel(nn.Module):
yield
self
.
should_accumulate_grads
=
old_should_accumulate_grads
@
torch
.
no_grad
()
def
_clear_counters
(
self
)
->
None
:
"""Reset all the grad reduce and call counters"""
self
.
_grad_to_be_reduced
=
[
True
for
_
in
self
.
_grad_to_be_reduced
]
...
...
@@ -199,6 +204,7 @@ class ShardedDataParallel(nn.Module):
Either way a delayed action is necessary and is passed as a callback.
"""
@
torch
.
no_grad
()
def
reduce
(
*
_
:
Any
)
->
None
:
# Skip gradient reduction, do not alter status flags
if
not
self
.
should_accumulate_grads
and
self
.
_grad_to_be_reduced
[
index
]:
...
...
@@ -262,17 +268,22 @@ class ShardedDataParallel(nn.Module):
grad_acc
.
register_hook
(
self
.
_get_reduce_fn
(
index
,
param
,
dst_rank
,
sharded_optimizer
))
self
.
_grad_accs
.
append
(
grad_acc
)
# keep this function in scope
@
torch
.
no_grad
()
def
_sync_params_and_buffers
(
self
)
->
None
:
"""
Sync the complete model states in between the ranks
"""
with
torch
.
no_grad
():
work_handles
=
[
dist
.
broadcast
(
t
,
src
=
self
.
reference_global_rank
,
group
=
self
.
process_group
,
async_op
=
True
)
for
t
in
self
.
module
.
state_dict
().
values
()
]
_
=
list
(
map
(
lambda
x
:
x
.
wait
(),
work_handles
))
last_work_handle
=
None
for
t
in
self
.
module
.
state_dict
().
values
():
last_work_handle
=
dist
.
broadcast
(
t
,
src
=
self
.
reference_global_rank
,
group
=
self
.
process_group
,
async_op
=
True
)
# Only wait for the last handle, they're inlined in the same CUDA stream
if
last_work_handle
:
last_work_handle
.
wait
()
def
_passing_sync_batchnorm_handle
(
self
,
module
:
nn
.
Module
)
->
None
:
"""
...
...
fairscale/optim/oss.py
View file @
dd441e9d
...
...
@@ -534,6 +534,7 @@ class OSS(Optimizer):
global_rank
=
dist
.
distributed_c10d
.
_get_global_rank
(
group
,
rank
)
return
global_rank
@
torch
.
no_grad
()
def
_sync_param_groups
(
self
,
local_to_global
:
bool
=
False
)
->
None
:
"""Sync learning rate and other optimizer attributes (needed to support schedulers).
If the global param groups have been altered, and we want to make sure that the
...
...
@@ -548,10 +549,12 @@ class OSS(Optimizer):
elif
k
in
global_group
.
keys
():
local_group
[
k
]
=
global_group
[
k
]
@
torch
.
no_grad
()
def
_broadcast_params
(
self
)
->
None
:
"""Helper function to broadcast all the parameters from a given device"""
i_param
=
0
last_work_handle
=
None
# Work handles are consumed within this scope, no callback
for
(
device
,
device_params
,)
in
self
.
per_device_params
.
items
():
# all the params on this device (inc all ranks)
buckets
=
self
.
buckets
[
device
]
...
...
@@ -562,25 +565,18 @@ class OSS(Optimizer):
# Direct broadcasts only
for
param
in
params
:
if
not
self
.
should_bucket_param
[
i_param
]:
self
.
work_handles
.
append
(
Workhandle
(
handle
=
dist
.
broadcast
(
tensor
=
param
.
data
,
src
=
global_src_rank
,
group
=
self
.
group
,
async_op
=
True
),
callback
=
None
,
)
last_work_handle
=
dist
.
broadcast
(
tensor
=
param
.
data
,
src
=
global_src_rank
,
group
=
self
.
group
,
async_op
=
True
)
i_param
+=
1
# Bucket broadcasts
self
.
work_handles
.
append
(
Workhandle
(
handle
=
dist
.
broadcast
(
tensor
=
bucket
,
src
=
global_src_rank
,
group
=
self
.
group
,
async_op
=
True
),
callback
=
None
,
)
)
last_work_handle
=
dist
.
broadcast
(
tensor
=
bucket
,
src
=
global_src_rank
,
group
=
self
.
group
,
async_op
=
True
)
self
.
_consume_work_handles
()
# Only check on the last handle, they're all inlined on the same CUDA stream
if
last_work_handle
:
last_work_handle
.
wait
()
def
_consume_work_handles
(
self
)
->
None
:
"""Consume all the futures which are tied to this optimizer's buckets.
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment