OpenDAS / fairscale

Commit dd441e9d (unverified)
Authored Jan 21, 2021 by Benjamin Lefaudeux, committed by GitHub on Jan 21, 2021
Parent: bd5d0496

[perf] ShardedDDP & OSS, small improvements (#321)

* Couple of small improvements, no logic changes
Showing 2 changed files with 34 additions and 27 deletions:

  fairscale/nn/data_parallel/sharded_ddp.py  +24 -13
  fairscale/optim/oss.py                     +10 -14
fairscale/nn/data_parallel/sharded_ddp.py
@@ -149,18 +149,22 @@ class ShardedDataParallel(nn.Module):
         """
         logging.warning("This is not useful anymore, gradients have been reduced automatically with the backward pass")
 
+    @torch.no_grad()
     def sync_buffers(self, blocking: bool = False) -> None:
         """
         Sync all the param buffers in between ranks (including for instance batch norm statistics).
         """
-        with torch.no_grad():
-            work_handles = [
-                dist.broadcast(buffer.data, self.reference_global_rank, self.process_group, async_op=True)
-                for buffer in self.module.buffers(recurse=True)
-            ]
-            if blocking:
-                _ = list(map(lambda x: x.wait(), work_handles))
+        last_work_handle = None
+
+        for buffer in self.module.buffers(recurse=True):
+            last_work_handle = dist.broadcast(
+                buffer.data, self.reference_global_rank, self.process_group, async_op=True
+            )
+
+        if blocking and last_work_handle:
+            # Only wait for the last coms, they're inlined on the same CUDA stream
+            last_work_handle.wait()
 
     def __getattr__(self, name: str) -> Any:
         """Forward missing attributes to wrapped module."""
@@ -177,6 +181,7 @@ class ShardedDataParallel(nn.Module):
         yield
         self.should_accumulate_grads = old_should_accumulate_grads
 
+    @torch.no_grad()
     def _clear_counters(self) -> None:
         """Reset all the grad reduce and call counters"""
         self._grad_to_be_reduced = [True for _ in self._grad_to_be_reduced]
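Note: several of these hunks swap a `with torch.no_grad():` body for a `@torch.no_grad()` decorator. The two are equivalent: torch.no_grad works both as a context manager and as a decorator, so decorating the function runs its whole body with autograd disabled while saving an indentation level. A quick illustration with hypothetical names:

import torch


@torch.no_grad()
def scale_(t: torch.Tensor, factor: float) -> None:
    # Runs exactly as if the body were wrapped in `with torch.no_grad():`,
    # so this in-place update on a leaf tensor requiring grad is legal and
    # records nothing in the autograd graph
    t.mul_(factor)


x = torch.ones(3, requires_grad=True)
scale_(x, 0.5)
assert x.grad_fn is None and bool((x == 0.5).all())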
@@ -199,6 +204,7 @@ class ShardedDataParallel(nn.Module):
         Either way a delayed action is necessary and is passed as a callback.
         """
 
+        @torch.no_grad()
         def reduce(*_: Any) -> None:
             # Skip gradient reduction, do not alter status flags
             if not self.should_accumulate_grads and self._grad_to_be_reduced[index]:
@@ -262,17 +268,22 @@ class ShardedDataParallel(nn.Module):
             grad_acc.register_hook(self._get_reduce_fn(index, param, dst_rank, sharded_optimizer))
             self._grad_accs.append(grad_acc)  # keep this function in scope
 
+    @torch.no_grad()
     def _sync_params_and_buffers(self) -> None:
         """
         Sync the complete model states in between the ranks
         """
-        with torch.no_grad():
-            work_handles = [
-                dist.broadcast(t, src=self.reference_global_rank, group=self.process_group, async_op=True)
-                for t in self.module.state_dict().values()
-            ]
-            _ = list(map(lambda x: x.wait(), work_handles))
+        last_work_handle = None
+        for t in self.module.state_dict().values():
+            last_work_handle = dist.broadcast(
+                t, src=self.reference_global_rank, group=self.process_group, async_op=True
+            )
+
+        # Only wait for the last handle, they're inlined in the same CUDA stream
+        if last_work_handle:
+            last_work_handle.wait()
 
     def _passing_sync_batchnorm_handle(self, module: nn.Module) -> None:
         """
fairscale/optim/oss.py
@@ -534,6 +534,7 @@ class OSS(Optimizer):
             global_rank = dist.distributed_c10d._get_global_rank(group, rank)
         return global_rank
 
+    @torch.no_grad()
     def _sync_param_groups(self, local_to_global: bool = False) -> None:
         """Sync learning rate and other optimizer attributes (needed to support schedulers).
         If the global param groups have been altered, and we want to make sure that the
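Note, for context on what _sync_param_groups does: OSS exposes wrapper-level param_groups that an LR scheduler mutates, and those attribute values then have to be mirrored into the sharded local optimizer's groups before each step. A toy version of the copy, where plain dictionaries stand in for real param groups:

# Scheduler has updated the wrapper-level ("global") group; mirror the
# scalar attributes into the local group, leaving the params untouched.
global_group = {"params": "<all params>", "lr": 0.01, "momentum": 0.9}
local_group = {"params": "<this rank's shard>", "lr": 0.1, "momentum": 0.9}

for k in local_group.keys():
    if k != "params" and k in global_group.keys():
        # Copy scalar attributes (lr, momentum, ...); never touch the shard
        local_group[k] = global_group[k]

assert local_group["lr"] == 0.01
assert local_group["params"] == "<this rank's shard>"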
@@ -548,10 +549,12 @@ class OSS(Optimizer):
             elif k in global_group.keys():
                 local_group[k] = global_group[k]
 
+    @torch.no_grad()
     def _broadcast_params(self) -> None:
         """Helper function to broadcast all the parameters from a given device"""
         i_param = 0
+        last_work_handle = None  # Work handles are consumed within this scope, no callback
 
         for (device, device_params,) in self.per_device_params.items():  # all the params on this device (inc all ranks)
             buckets = self.buckets[device]
@@ -562,25 +565,18 @@ class OSS(Optimizer):
                 # Direct broadcasts only
                 for param in params:
                     if not self.should_bucket_param[i_param]:
-                        self.work_handles.append(
-                            Workhandle(
-                                handle=dist.broadcast(
-                                    tensor=param.data, src=global_src_rank, group=self.group, async_op=True
-                                ),
-                                callback=None,
-                            )
-                        )
+                        last_work_handle = dist.broadcast(
+                            tensor=param.data, src=global_src_rank, group=self.group, async_op=True
+                        )
                     i_param += 1
 
                 # Bucket broadcasts
-                self.work_handles.append(
-                    Workhandle(
-                        handle=dist.broadcast(tensor=bucket, src=global_src_rank, group=self.group, async_op=True),
-                        callback=None,
-                    )
-                )
-
-        self._consume_work_handles()
+                last_work_handle = dist.broadcast(tensor=bucket, src=global_src_rank, group=self.group, async_op=True)
+
+        # Only check on the last handle, they're all inlined on the same CUDA stream
+        if last_work_handle:
+            last_work_handle.wait()
 
     def _consume_work_handles(self) -> None:
         """Consume all the futures which are tied to this optimizer's buckets.