OpenDAS / fairscale · Commits

Commit 85dea5b2 (unverified)
Authored Apr 26, 2021 by Benjamin Lefaudeux, committed via GitHub on Apr 26, 2021
Parent: 38ce54b7

[chore] SDP - adding the profiler labels (#630)

* adding the labels
* longer labels, following aten::

1 changed file with 117 additions and 111 deletions:

fairscale/nn/data_parallel/sharded_ddp.py (+117, -111)
@@ -18,6 +18,7 @@ from typing import Any, Callable, Deque, Dict, Generator, List, Optional, Union
 import torch
 from torch import nn
 from torch.autograd import Variable
+import torch.autograd.profiler as profiler
 import torch.distributed as dist
 
 from fairscale.nn.misc import GradBucket
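Every remaining hunk in this commit applies the same pattern: wrap a hot method body in a torch.autograd.profiler.record_function context manager so the method shows up as a named range next to the built-in aten:: operator labels mentioned in the commit message. A minimal, self-contained sketch of that pattern follows; the example::forward label and the Linear layer are illustrative stand-ins, not fairscale code.

import torch
import torch.autograd.profiler as profiler

layer = torch.nn.Linear(128, 128)   # illustrative module, not from fairscale
x = torch.randn(32, 128)

with profiler.profile() as prof:
    # record_function opens a named range; the aten:: ops executed inside it
    # (here, the ops behind the Linear forward) are nested under that label
    with profiler.record_function("example::forward"):
        y = layer(x)

# The custom label shows up as its own row, alongside the aten:: entries
print(prof.key_averages().table(sort_by="cpu_time_total"))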
@@ -199,25 +200,26 @@ class ShardedDataParallel(nn.Module):
         backward pass for gradient reduction to the proper ranks.
         """
 
-        # Deferred initialization, or change detection
-        needs_setup = len(self._grad_hooks) == 0 and self.training
+        with profiler.record_function("fairscale::sdp::forward"):
+            # Deferred initialization, or change detection
+            needs_setup = len(self._grad_hooks) == 0 and self.training
 
-        if self._auto_refresh_trainable:
-            needs_setup |= self._detect_train_change()
+            if self._auto_refresh_trainable:
+                needs_setup |= self._detect_train_change()
 
-        if needs_setup:
-            self.refresh_trainable()
+            if needs_setup:
+                self.refresh_trainable()
 
-        if self._enable_broadcast_buffers:
-            # NCCL communications are on a different stream, needs to be blocking
-            # for the subsequent FW to be correct
-            self.sync_buffers(blocking=True)
+            if self._enable_broadcast_buffers:
+                # NCCL communications are on a different stream, needs to be blocking
+                # for the subsequent FW to be correct
+                self.sync_buffers(blocking=True)
 
-        # Reset all the grad reduce and bucket state flags
-        self._clear_counters()
+            # Reset all the grad reduce and bucket state flags
+            self._clear_counters()
 
-        # Normal FW on the base model
-        return self._module(*inputs, **kwargs)
+            # Normal FW on the base model
+            return self._module(*inputs, **kwargs)
 
     def to(  # type: ignore
         self,
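With the forward pass instrumented as above, a profiling run over a training step should surface a fairscale::sdp::forward range. The sketch below shows one way to capture and filter those events; the single-process gloo group, the toy Linear model, and the variable names are assumptions made for illustration, not part of this commit.

import os
import torch
import torch.autograd.profiler as profiler
import torch.distributed as dist
from fairscale.nn.data_parallel import ShardedDataParallel
from fairscale.optim import OSS

# Single-process "distributed" setup, only so the sketch runs standalone
os.environ.setdefault("MASTER_ADDR", "localhost")
os.environ.setdefault("MASTER_PORT", "29500")
dist.init_process_group(backend="gloo", rank=0, world_size=1)

model = torch.nn.Linear(64, 64)               # toy model, illustrative only
optimizer = OSS(model.parameters(), lr=0.1)   # sharded optimizer required by ShardedDDP
ddp_model = ShardedDataParallel(model, optimizer)

x = torch.randn(8, 64)
with profiler.profile() as prof:
    loss = ddp_model(x).sum()
    loss.backward()
    optimizer.step()

# The new labels sit next to the aten:: entries; filter on the prefix
for evt in prof.key_averages():
    if evt.key.startswith("fairscale::sdp::"):
        print(evt.key, evt.cpu_time_total)

dist.destroy_process_group()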
@@ -274,24 +276,25 @@ class ShardedDataParallel(nn.Module):
                 "Grads waiting to be reduced. If this is on purpose (grad accumulation), please use a no_sync() context"
             )
 
-        self._trainable_params = list(filter(lambda x: x.requires_grad, self._all_params))
-        self._trainable_params.sort(key=lambda x: x.numel())
+        with profiler.record_function("fairscale::sdp::refresh_trainable"):
+            self._trainable_params = list(filter(lambda x: x.requires_grad, self._all_params))
+            self._trainable_params.sort(key=lambda x: x.numel())
 
-        self._trainable_param_to_rank = {}
-        for optim in self._sharded_optimizers:
-            # OSS may need to change the communication pattern
-            optim.refresh_trainable()
+            self._trainable_param_to_rank = {}
+            for optim in self._sharded_optimizers:
+                # OSS may need to change the communication pattern
+                optim.refresh_trainable()
 
-            # Update ShardedDDP given the new partitions
-            for (
-                device_per_rank_params
-            ) in optim._per_device_params.values():  # all the params on this device (inc all ranks)
-                for device_params in device_per_rank_params:
-                    for param in filter(lambda x: x.requires_grad, device_params):
-                        self._trainable_param_to_rank[param] = optim._param_to_rank[param]
+                # Update ShardedDDP given the new partitions
+                for (
+                    device_per_rank_params
+                ) in optim._per_device_params.values():  # all the params on this device (inc all ranks)
+                    for device_params in device_per_rank_params:
+                        for param in filter(lambda x: x.requires_grad, device_params):
+                            self._trainable_param_to_rank[param] = optim._param_to_rank[param]
 
-        self._setup_bucket_strategy()
-        self._setup_backward_hooks()
+            self._setup_bucket_strategy()
+            self._setup_backward_hooks()
 
     def reduce(self) -> None:
         """
@@ -320,18 +323,19 @@ class ShardedDataParallel(nn.Module):
             blocking (bool): wait for the operation to conclude.
         """
-        work_handles = []
+        with profiler.record_function("fairscale::sdp::sync_buffers"):
+            work_handles = []
 
-        for buffer in self._module.buffers(recurse=True):
-            work_handles.append(
-                dist.broadcast(buffer.data, self._reference_global_rank, self._process_group, async_op=True)
-            )
+            for buffer in self._module.buffers(recurse=True):
+                work_handles.append(
+                    dist.broadcast(buffer.data, self._reference_global_rank, self._process_group, async_op=True)
+                )
 
-        if blocking and work_handles:
-            if self._backend != dist.Backend.NCCL:
-                _ = list(filter(lambda x: x.wait(), work_handles))
-            else:
-                work_handles[-1].wait()
+            if blocking and work_handles:
+                if self._backend != dist.Backend.NCCL:
+                    _ = list(filter(lambda x: x.wait(), work_handles))
+                else:
+                    work_handles[-1].wait()
 
     def zero_grad(self, set_to_none: bool = False) -> None:
         r"""Sets gradients of all model parameters to zero. See similar function
@@ -480,39 +484,39 @@ class ShardedDataParallel(nn.Module):
         Attach a reduce function to each grad-requiring parameter.
         This makes the gradient reduction automatic whenever there's a backward pass
         """
 
-        # Detach possible pre-existing hooks
-        while len(self._grad_hooks) > 0:
-            self._grad_hooks.pop().remove()
+        with profiler.record_function("fairscale::sdp::setup_backward_hooks"):
+            # Detach possible pre-existing hooks
+            while len(self._grad_hooks) > 0:
+                self._grad_hooks.pop().remove()
 
-        # Go through the parameters, attach the hook
-        self._grad_accs = []
-        self._manual_reduce = []
-        if not self.training:
-            return
+            # Go through the parameters, attach the hook
+            self._grad_accs = []
+            self._manual_reduce = []
+            if not self.training:
+                return
 
-        for index, param in enumerate(self._trainable_params):
-            if param.grad is not None and param.grad.requires_grad:
-                raise RuntimeError("ShardedDataParallel only works with gradients that don't require grad")
+            for index, param in enumerate(self._trainable_params):
+                if param.grad is not None and param.grad.requires_grad:
+                    raise RuntimeError("ShardedDataParallel only works with gradients that don't require grad")
 
-            p_tmp = param.expand_as(param)
+                p_tmp = param.expand_as(param)
 
-            # See https://pytorch.org/docs/stable/tensors.html?highlight=grad_fn
-            # We're interested in the tensors which will be tracked by Autograd
-            # Some tensors can have gradients independent of the inputs (ie. pooling layer for instance),
-            # these do not need to be sync'ed
-            if p_tmp.grad_fn is not None:
-                # Register the hook to the next function in line,
-                # so that the hook is fired when this grad has properly been computed
-                # (by default the hook with Pytorch is a pre-grad, not a post-grad)
-                grad_acc = p_tmp.grad_fn.next_functions[0][0]
-                dst_rank = self._trainable_param_to_rank[param]
+                # See https://pytorch.org/docs/stable/tensors.html?highlight=grad_fn
+                # We're interested in the tensors which will be tracked by Autograd
+                # Some tensors can have gradients independent of the inputs (ie. pooling layer for instance),
+                # these do not need to be sync'ed
+                if p_tmp.grad_fn is not None:
+                    # Register the hook to the next function in line,
+                    # so that the hook is fired when this grad has properly been computed
+                    # (by default the hook with Pytorch is a pre-grad, not a post-grad)
+                    grad_acc = p_tmp.grad_fn.next_functions[0][0]
+                    dst_rank = self._trainable_param_to_rank[param]
 
-                reduce_function = self._get_reduce_fn(index, param, dst_rank)
+                    reduce_function = self._get_reduce_fn(index, param, dst_rank)
 
-                self._grad_hooks.append(grad_acc.register_hook(reduce_function))
-                self._grad_accs.append(grad_acc)  # keep this hook in scope
-                self._manual_reduce.append(reduce_function)
+                    self._grad_hooks.append(grad_acc.register_hook(reduce_function))
+                    self._grad_accs.append(grad_acc)  # keep this hook in scope
+                    self._manual_reduce.append(reduce_function)
 
     @torch.no_grad()
     def _sync_params_and_buffers(self) -> None:
@@ -552,41 +556,42 @@ class ShardedDataParallel(nn.Module):
         This method can be a slow for big models, but it it not typically called often (not for every forward for instance)
         """
 
-        if not self._use_buckets:
-            return
+        with profiler.record_function("fairscale::sdp::setup_buckets"):
+            if not self._use_buckets:
+                return
 
-        # Devise the bucketing strategy. Parameters are already sorted, in that:
-        # - these are only the trainable parameters, so they should produce grads
-        # - they are sorted by increasing size
-        self._buckets = {}
-        self._should_bucket_grad = [False for _ in self._trainable_params]
+            # Devise the bucketing strategy. Parameters are already sorted, in that:
+            # - these are only the trainable parameters, so they should produce grads
+            # - they are sorted by increasing size
+            self._buckets = {}
+            self._should_bucket_grad = [False for _ in self._trainable_params]
 
-        for i, param in enumerate(self._trainable_params):
-            device = param.device
-            dst_rank = self._trainable_param_to_rank[param]
+            for i, param in enumerate(self._trainable_params):
+                device = param.device
+                dst_rank = self._trainable_param_to_rank[param]
 
-            if param.device not in self._buckets.keys():
-                self._buckets[param.device] = {}
+                if param.device not in self._buckets.keys():
+                    self._buckets[param.device] = {}
 
-            if dst_rank not in self._buckets[param.device].keys():
-                self._buckets[param.device][dst_rank] = GradBucket(
-                    self._buffer_max_size,
-                    dtype=param.dtype,
-                    device=param.device,
-                    destination=self._local_to_global_rank[dst_rank],
-                )
+                if dst_rank not in self._buckets[param.device].keys():
+                    self._buckets[param.device][dst_rank] = GradBucket(
+                        self._buffer_max_size,
+                        dtype=param.dtype,
+                        device=param.device,
+                        destination=self._local_to_global_rank[dst_rank],
+                    )
 
-            # Criteria to decide whether this parameter is to be bucketed or not:
-            # - enough room in the bucket
-            if self._buckets[device][dst_rank].can_add_grad_view(param):
-                self._buckets[device][dst_rank].add_grad(param)
-                self._should_bucket_grad[i] = True
+                # Criteria to decide whether this parameter is to be bucketed or not:
+                # - enough room in the bucket
+                if self._buckets[device][dst_rank].can_add_grad_view(param):
+                    self._buckets[device][dst_rank].add_grad(param)
+                    self._should_bucket_grad[i] = True
 
-        self._bucket_list = list(chain(*[self._buckets[device].values() for device in self._buckets.keys()]))
+            self._bucket_list = list(chain(*[self._buckets[device].values() for device in self._buckets.keys()]))
 
-        # Resize the buckets to remove lost space in the end
-        for bucket in self._bucket_list:
-            bucket.shrink()
+            # Resize the buckets to remove lost space in the end
+            for bucket in self._bucket_list:
+                bucket.shrink()
 
     def _consume_work_handles(self) -> None:
         """Consume all the futures which are tied to this optimizer's buckets.
@@ -628,19 +633,20 @@ class ShardedDataParallel(nn.Module):
         self._consume_work_handles()
 
     def _detect_train_change(self) -> bool:
-        # Optionally check whether the trainable parameters have changed
-        trainable_mask = list(map(_trainable, self._all_params))
+        with profiler.record_function("fairscale::sdp::detect_train_changes"):
+            # Optionally check whether the trainable parameters have changed
+            trainable_mask = list(map(_trainable, self._all_params))
 
-        # - one or more parameters trainability changed
-        trainability_changed = trainable_mask != self._reference_trainable_mask
+            # - one or more parameters trainability changed
+            trainability_changed = trainable_mask != self._reference_trainable_mask
 
-        # - the whole model is not trainable but we still have grad hooks
-        trainability_changed |= not self.training and len(self._grad_hooks) > 0
+            # - the whole model is not trainable but we still have grad hooks
+            trainability_changed |= not self.training and len(self._grad_hooks) > 0
 
-        if trainability_changed:
-            logging.warning(
-                "ShardedDDP detected that the trainable params changed, either because of eval/train mode or parameter freezing/unfreeze."
-            )
-            self._reference_trainable_mask = trainable_mask
+            if trainability_changed:
+                logging.warning(
+                    "ShardedDDP detected that the trainable params changed, either because of eval/train mode or parameter freezing/unfreeze."
+                )
+                self._reference_trainable_mask = trainable_mask
 
         return trainability_changed