OpenDAS / fairscale, commit 85dea5b2 (unverified)
Authored by Benjamin Lefaudeux on Apr 26, 2021; committed via GitHub on Apr 26, 2021
[chore] SDP - adding the profiler labels (#630)

* adding the labels
* longer labels, following aten::
Parent: 38ce54b7
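The labels rely on torch.autograd.profiler.record_function and follow the same "namespace::op" naming as the built-in aten:: entries. A minimal sketch of what such a label does; the toy function and the "example::scaled_add" name are made up for illustration and are not part of this commit:

import torch
import torch.autograd.profiler as profiler

# record_function() opens a named range in the autograd profiler trace; the
# "namespace::op" label below mirrors the aten:: naming convention the commit
# message refers to ("example::scaled_add" is a hypothetical label).
def scaled_add(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
    with profiler.record_function("example::scaled_add"):
        return x + 2.0 * y

with profiler.profile() as prof:
    scaled_add(torch.randn(1024), torch.randn(1024))

# The custom range appears alongside the aten:: ops it wraps.
print(prof.key_averages().table(sort_by="cpu_time_total"))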
Showing 1 changed file with 117 additions and 111 deletions.

fairscale/nn/data_parallel/sharded_ddp.py (+117, -111), view file @ 85dea5b2:
@@ -18,6 +18,7 @@ from typing import Any, Callable, Deque, Dict, Generator, List, Optional, Union
 import torch
 from torch import nn
 from torch.autograd import Variable
+import torch.autograd.profiler as profiler
 import torch.distributed as dist

 from fairscale.nn.misc import GradBucket
@@ -199,25 +200,26 @@ class ShardedDataParallel(nn.Module):
         backward pass for gradient reduction to the proper ranks.
         """
+        with profiler.record_function("fairscale::sdp::forward"):
             # Deferred initialization, or change detection
             needs_setup = len(self._grad_hooks) == 0 and self.training

             if self._auto_refresh_trainable:
                 needs_setup |= self._detect_train_change()

             if needs_setup:
                 self.refresh_trainable()

             if self._enable_broadcast_buffers:
                 # NCCL communications are on a different stream, needs to be blocking
                 # for the subsequent FW to be correct
                 self.sync_buffers(blocking=True)

             # Reset all the grad reduce and bucket state flags
             self._clear_counters()

             # Normal FW on the base model
             return self._module(*inputs, **kwargs)

     def to(  # type: ignore
         self,
@@ -274,24 +276,25 @@ class ShardedDataParallel(nn.Module):
                 "Grads waiting to be reduced. If this is on purpose (grad accumulation), please use a no_sync() context"
             )

+        with profiler.record_function("fairscale::sdp::refresh_trainable"):
             self._trainable_params = list(filter(lambda x: x.requires_grad, self._all_params))
             self._trainable_params.sort(key=lambda x: x.numel())

             self._trainable_param_to_rank = {}
             for optim in self._sharded_optimizers:
                 # OSS may need to change the communication pattern
                 optim.refresh_trainable()

                 # Update ShardedDDP given the new partitions
                 for (
                     device_per_rank_params
                 ) in optim._per_device_params.values():  # all the params on this device (inc all ranks)
                     for device_params in device_per_rank_params:
                         for param in filter(lambda x: x.requires_grad, device_params):
                             self._trainable_param_to_rank[param] = optim._param_to_rank[param]

             self._setup_bucket_strategy()
             self._setup_backward_hooks()

     def reduce(self) -> None:
         """
@@ -320,18 +323,19 @@ class ShardedDataParallel(nn.Module):
             blocking (bool): wait for the operation to conclude.
         """
+        with profiler.record_function("fairscale::sdp::sync_buffers"):
             work_handles = []

             for buffer in self._module.buffers(recurse=True):
                 work_handles.append(
                     dist.broadcast(buffer.data, self._reference_global_rank, self._process_group, async_op=True)
                 )

             if blocking and work_handles:
                 if self._backend != dist.Backend.NCCL:
                     _ = list(filter(lambda x: x.wait(), work_handles))
                 else:
                     work_handles[-1].wait()

     def zero_grad(self, set_to_none: bool = False) -> None:
         r"""Sets gradients of all model parameters to zero. See similar function
@@ -480,39 +484,39 @@ class ShardedDataParallel(nn.Module):
         Attach a reduce function to each grad-requiring parameter.
         This makes the gradient reduction automatic whenever there's a backward pass
         """
-
+        with profiler.record_function("fairscale::sdp::setup_backward_hooks"):
             # Detach possible pre-existing hooks
             while len(self._grad_hooks) > 0:
                 self._grad_hooks.pop().remove()

             # Go through the parameters, attach the hook
             self._grad_accs = []
             self._manual_reduce = []
             if not self.training:
                 return

             for index, param in enumerate(self._trainable_params):
                 if param.grad is not None and param.grad.requires_grad:
                     raise RuntimeError("ShardedDataParallel only works with gradients that don't require grad")

                 p_tmp = param.expand_as(param)

                 # See https://pytorch.org/docs/stable/tensors.html?highlight=grad_fn
                 # We're interested in the tensors which will be tracked by Autograd
                 # Some tensors can have gradients independent of the inputs (ie. pooling layer for instance),
                 # these do not need to be sync'ed
                 if p_tmp.grad_fn is not None:
                     # Register the hook to the next function in line,
                     # so that the hook is fired when this grad has properly been computed
                     # (by default the hook with Pytorch is a pre-grad, not a post-grad)
                     grad_acc = p_tmp.grad_fn.next_functions[0][0]
                     dst_rank = self._trainable_param_to_rank[param]

                     reduce_function = self._get_reduce_fn(index, param, dst_rank)

                     self._grad_hooks.append(grad_acc.register_hook(reduce_function))
                     self._grad_accs.append(grad_acc)  # keep this hook in scope
                     self._manual_reduce.append(reduce_function)

     @torch.no_grad()
     def _sync_params_and_buffers(self) -> None:
@@ -552,41 +556,42 @@ class ShardedDataParallel(nn.Module):
         This method can be a slow for big models, but it it not typically called often (not for every forward for instance)
         """
+        with profiler.record_function("fairscale::sdp::setup_buckets"):
             if not self._use_buckets:
                 return

             # Devise the bucketing strategy. Parameters are already sorted, in that:
             # - these are only the trainable parameters, so they should produce grads
             # - they are sorted by increasing size
             self._buckets = {}
             self._should_bucket_grad = [False for _ in self._trainable_params]

             for i, param in enumerate(self._trainable_params):
                 device = param.device
                 dst_rank = self._trainable_param_to_rank[param]

                 if param.device not in self._buckets.keys():
                     self._buckets[param.device] = {}

                 if dst_rank not in self._buckets[param.device].keys():
                     self._buckets[param.device][dst_rank] = GradBucket(
                         self._buffer_max_size,
                         dtype=param.dtype,
                         device=param.device,
                         destination=self._local_to_global_rank[dst_rank],
                     )

                 # Criteria to decide whether this parameter is to be bucketed or not:
                 # - enough room in the bucket
                 if self._buckets[device][dst_rank].can_add_grad_view(param):
                     self._buckets[device][dst_rank].add_grad(param)
                     self._should_bucket_grad[i] = True

             self._bucket_list = list(chain(*[self._buckets[device].values() for device in self._buckets.keys()]))

             # Resize the buckets to remove lost space in the end
             for bucket in self._bucket_list:
                 bucket.shrink()

     def _consume_work_handles(self) -> None:
         """Consume all the futures which are tied to this optimizer's buckets.
@@ -628,19 +633,20 @@ class ShardedDataParallel(nn.Module):
         self._consume_work_handles()

     def _detect_train_change(self) -> bool:
+        with profiler.record_function("fairscale::sdp::detect_train_changes"):
             # Optionally check whether the trainable parameters have changed
             trainable_mask = list(map(_trainable, self._all_params))

             # - one or more parameters trainability changed
             trainability_changed = trainable_mask != self._reference_trainable_mask

             # - the whole model is not trainable but we still have grad hooks
             trainability_changed |= not self.training and len(self._grad_hooks) > 0

             if trainability_changed:
                 logging.warning(
                     "ShardedDDP detected that the trainable params changed, either because of eval/train mode or parameter freezing/unfreeze."
                 )
                 self._reference_trainable_mask = trainable_mask

             return trainability_changed
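For context, a minimal sketch of how the fairscale::sdp::* ranges added above surface when profiling a ShardedDDP step. The single-process gloo group, the toy Linear model, and the SGD settings are assumptions made to keep the example self-contained; they are not taken from this commit:

import os

import torch
import torch.autograd.profiler as profiler
import torch.distributed as dist

from fairscale.nn.data_parallel import ShardedDataParallel as ShardedDDP
from fairscale.optim.oss import OSS

# Assumed single-process "gloo" group so the sketch runs standalone;
# a real run would launch one process per rank (e.g. via torchrun).
os.environ.setdefault("MASTER_ADDR", "localhost")
os.environ.setdefault("MASTER_PORT", "29500")
dist.init_process_group(backend="gloo", rank=0, world_size=1)

model = torch.nn.Linear(32, 32)  # toy model, illustrative only
optimizer = OSS(model.parameters(), optim=torch.optim.SGD, lr=0.1)
ddp = ShardedDDP(model, optimizer)

with profiler.profile() as prof:
    loss = ddp(torch.randn(8, 32)).sum()
    loss.backward()
    optimizer.step()

# The fairscale::sdp::* ranges added by this commit appear alongside the
# usual aten:: entries in the profiler table.
print(prof.key_averages().table(sort_by="cpu_time_total", row_limit=20))

dist.destroy_process_group()

Grouping each phase under its own label makes it easier to attribute profiler time to bucket setup, buffer sync, and gradient reduction rather than reading through individual aten:: kernels.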