OpenDAS / fairscale · Commits · c084b202

Unverified commit c084b202, authored Apr 15, 2021 by Benjamin Lefaudeux, committed by GitHub Apr 15, 2021
[chore][SDP] privatizing all the things (#611)
parent a77c56f0
Showing 1 changed file with 74 additions and 72 deletions:
fairscale/nn/data_parallel/sharded_ddp.py (+74, -72)
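For context, the attributes being renamed all live on the ShardedDataParallel wrapper. A minimal usage sketch, with illustrative values, assuming torch.distributed has a process group available (here a single-rank gloo group) and the fairscale OSS / ShardedDDP API of this period:

import torch
import torch.distributed as dist
from fairscale.nn.data_parallel import ShardedDataParallel as ShardedDDP
from fairscale.optim.oss import OSS

# Single-rank "world" just to exercise the wrapper; real training launches one process per rank.
dist.init_process_group(backend="gloo", init_method="tcp://127.0.0.1:29500", rank=0, world_size=1)

base = torch.nn.Linear(16, 4)                                             # the wrapped module
optimizer = OSS(params=base.parameters(), optim=torch.optim.SGD, lr=0.1)  # sharded optimizer
ddp = ShardedDDP(base, optimizer)                                         # the class edited in this commit

loss = ddp(torch.randn(8, 16)).sum()
loss.backward()       # gradients are reduced to the rank that owns each shard
optimizer.step()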
fairscale/nn/data_parallel/sharded_ddp.py (view file @ c084b202)
@@ -102,41 +102,40 @@ class ShardedDataParallel(nn.Module):
     ):
         super().__init__()

-        self.module = module
-        self.sharded_optimizers = [sharded_optimizer] if not isinstance(sharded_optimizer, list) else sharded_optimizer
-        self.enable_broadcast_buffers = broadcast_buffers
-        self.auto_refresh_trainable = auto_refresh_trainable
-        self.reduce_fp16 = reduce_fp16
+        self._module = module
+        self._sharded_optimizers = [sharded_optimizer] if not isinstance(sharded_optimizer, list) else sharded_optimizer
+        self._enable_broadcast_buffers = broadcast_buffers
+        self._auto_refresh_trainable = auto_refresh_trainable
+        self._reduce_fp16 = reduce_fp16

         if reduce_buffer_size > 0 and reduce_fp16:
-            self.reduce_fp16 = False
+            self._reduce_fp16 = False
             logging.warning(
                 "fp16 gradient reduction is not compatible with reduction buffers, which are requested. fp16 grad reduction is deactivated."
             )

         # Handle a no_sync() context which prevents the gradient synchronization,
         # accumulate in place
-        self.should_accumulate_grads = False
-        self.accumulate_grads_flipped = False
+        self._should_accumulate_grads = False
+        self._accumulate_grads_flipped = False

         # Communication related attributes
-        self.process_group = process_group if process_group is not None else dist.group.WORLD
-        self.backend = dist.get_backend(self.process_group)
-        self.world_size_scaling = 1.0 / dist.get_world_size(self.process_group)  # > 0
-        self.reference_global_rank = get_global_rank(self.process_group, 0)  # picking rank 0 as the reference
-        self.rank = dist.get_rank(self.process_group)
-        self.global_rank = get_global_rank(self.process_group, self.rank)
+        self._process_group = process_group if process_group is not None else dist.group.WORLD
+        self._backend = dist.get_backend(self._process_group)
+        self._world_size_scaling = 1.0 / dist.get_world_size(self._process_group)  # > 0
+        self._reference_global_rank = get_global_rank(self._process_group, 0)  # picking rank 0 as the reference
+        self._rank = dist.get_rank(self._process_group)
+        self._global_rank = get_global_rank(self._process_group, self._rank)
         self._local_to_global_rank = [
-            get_global_rank(self.process_group, i) for i in range(dist.get_world_size(self.process_group))
+            get_global_rank(self._process_group, i) for i in range(dist.get_world_size(self._process_group))
         ]

         # Expose some of the PytorchDDP attributes, some frameworks rely on them.
         # See https://pytorch.org/docs/stable/_modules/torch/nn/parallel/distributed.html#DistributedDataParallel
         # device_id related logic is not present, this is not handled
-        devices = {p.device for p in self.module.parameters()}
+        devices = {p.device for p in self._module.parameters()}
         self.is_multi_device_module = len(devices) > 1
         self.device = list(devices)[0]

-        distinct_device_types = {p.device.type for p in self.module.parameters()}
+        distinct_device_types = {p.device.type for p in self._module.parameters()}
         assert len(distinct_device_types) == 1, (
             "ShardedDataParallel's input module must be on "
             "the same type of devices, but input module parameters are located on {} different device types."
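The hunk above also shows that fp16 gradient reduction is silently turned off whenever a reduction buffer is requested (reduce_buffer_size > 0). Continuing the sketch above, the two constructor combinations that keep their meaning look roughly like this; the argument names are taken from the __init__ shown here, the values are illustrative:

# Bucketed (buffered) gradient reduction; fp16 reduction stays off.
ddp = ShardedDDP(base, optimizer, reduce_buffer_size=2 ** 23, reduce_fp16=False)

# fp16 gradient reduction is only honoured with the buffer disabled; with
# reduce_buffer_size > 0 the constructor resets it to False and logs a warning.
ddp = ShardedDDP(base, optimizer, reduce_fp16=True, reduce_buffer_size=0)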
@@ -149,7 +148,10 @@ class ShardedDataParallel(nn.Module):
         # - we build an iterator which goes through all the parameters involved globally
         self._all_params = list(
             chain(
-                *[sum([sum(p, []) for p in optim._per_device_params.values()], []) for optim in self.sharded_optimizers]
+                *[
+                    sum([sum(p, []) for p in optim._per_device_params.values()], [])
+                    for optim in self._sharded_optimizers
+                ]
             )
         )
         self._trainable_params: List[torch.Tensor] = []
@@ -158,21 +160,21 @@ class ShardedDataParallel(nn.Module):
         self._reference_trainable_mask = list(map(_trainable, self._all_params))

         # - setup buckets and tensor views
-        model_size = sum([p.numel() for p in self.module.parameters()])
-        self.buffer_max_size = min(reduce_buffer_size, model_size)
+        model_size = sum([p.numel() for p in self._module.parameters()])
+        self._buffer_max_size = min(reduce_buffer_size, model_size)

-        if dist.get_world_size(self.process_group) == 1:
-            self.buffer_max_size = 0
+        if dist.get_world_size(self._process_group) == 1:
+            self._buffer_max_size = 0
             logging.info("Training is not really distributed, single rank. Deactivating buckets")

         logging.info(
             "ShardedDDP bucket size: {:.2f}M parameters, model size {:.2f}M parameters".format(
-                self.buffer_max_size / 2 ** 20, model_size / 2 ** 20
+                self._buffer_max_size / 2 ** 20, model_size / 2 ** 20
             )
         )
-        self.use_buckets = self.buffer_max_size > 0
+        self._use_buckets = self._buffer_max_size > 0

-        self.buckets: Dict[torch.device, Dict[int, GradBucket]] = {}
+        self._buckets: Dict[torch.device, Dict[int, GradBucket]] = {}
         self._should_bucket_grad: List[bool] = []
         self._bucket_list: List[GradBucket] = []
@@ -182,7 +184,7 @@ class ShardedDataParallel(nn.Module):
         self._manual_reduce: List[Callable] = []

         # passing a handle to torch.nn.SyncBatchNorm layer
-        self._passing_sync_batchnorm_handle(self.module)
+        self._passing_sync_batchnorm_handle(self._module)

         # Make sure that all ranks start with the same model
         if sync_models_at_startup:
@@ -200,13 +202,13 @@ class ShardedDataParallel(nn.Module):
         # Deferred initialization, or change detection
         needs_setup = len(self._grad_hooks) == 0 and self.training

-        if self.auto_refresh_trainable:
+        if self._auto_refresh_trainable:
             needs_setup |= self._detect_train_change()

         if needs_setup:
             self.refresh_trainable()

-        if self.enable_broadcast_buffers:
+        if self._enable_broadcast_buffers:
             # NCCL communications are on a different stream, needs to be blocking
             # for the subsequent FW to be correct
             self.sync_buffers(blocking=True)
@@ -215,7 +217,7 @@ class ShardedDataParallel(nn.Module):
             self._clear_counters()

         # Normal FW on the base model
-        return self.module(*inputs, **kwargs)
+        return self._module(*inputs, **kwargs)

     def to(  # type: ignore
         self,
@@ -252,16 +254,16 @@ class ShardedDataParallel(nn.Module):
             Module: self.
         """
-        assert device in self.buckets.keys(), "Changing devices is not supported, because this would break OSSs state"
+        assert device in self._buckets.keys(), "Changing devices is not supported, because this would break OSSs state"
         assert (
-            len(self.buckets.keys()) == 1
+            len(self._buckets.keys()) == 1
         ), "Several devices specified to begin with, incompatible with setting a single device here"

-        for _device in self.buckets.keys():
-            for bucket in self.buckets[_device].values():
+        for _device in self._buckets.keys():
+            for bucket in self._buckets[_device].values():
                 bucket.to(device=_device, dtype=dtype, non_blocking=non_blocking)

-        self.module.to(device=device, dtype=dtype, non_blocking=non_blocking)
+        self._module.to(device=device, dtype=dtype, non_blocking=non_blocking)

     def refresh_trainable(self) -> None:
         """ If the module trainability has changed, update all the assumptions """
@@ -276,7 +278,7 @@ class ShardedDataParallel(nn.Module):
         self._trainable_params.sort(key=lambda x: x.numel())

         self._trainable_param_to_rank = {}
-        for optim in self.sharded_optimizers:
+        for optim in self._sharded_optimizers:
             # OSS may need to change the communication pattern
             optim.refresh_trainable()
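refresh_trainable() walks the sharded optimizers again, so code that freezes or unfreezes parameters can trigger it explicitly. A small sketch continuing the example above, assuming auto_refresh_trainable=False was chosen to skip the per-forward change detection:

ddp = ShardedDDP(base, optimizer, auto_refresh_trainable=False)

# Freeze part of the model mid-training...
base.weight.requires_grad_(False)

# ...then rebuild the trainable-parameter bookkeeping (partitions, hooks, buckets)
# by hand, since forward() will no longer detect the change on its own.
ddp.refresh_trainable()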
@@ -320,13 +322,13 @@ class ShardedDataParallel(nn.Module):
         work_handles = []

-        for buffer in self.module.buffers(recurse=True):
+        for buffer in self._module.buffers(recurse=True):
             work_handles.append(
-                dist.broadcast(buffer.data, self.reference_global_rank, self.process_group, async_op=True)
+                dist.broadcast(buffer.data, self._reference_global_rank, self._process_group, async_op=True)
             )

         if blocking and work_handles:
-            if self.backend != dist.Backend.NCCL:
+            if self._backend != dist.Backend.NCCL:
                 _ = list(filter(lambda x: x.wait(), work_handles))
             else:
                 work_handles[-1].wait()
@@ -354,16 +356,16 @@ class ShardedDataParallel(nn.Module):
         try:
             return super().__getattr__(name)  # defer to nn.Module's logic
         except AttributeError:
-            return getattr(self.module, name)
+            return getattr(self._module, name)

     @contextlib.contextmanager
     def no_sync(self) -> Generator:
         """A context manager to disable gradient synchronization."""
-        old_should_accumulate_grads = self.should_accumulate_grads
-        self.should_accumulate_grads = True
+        old_should_accumulate_grads = self._should_accumulate_grads
+        self._should_accumulate_grads = True
         yield
-        self.accumulate_grads_flipped = self.should_accumulate_grads != old_should_accumulate_grads
-        self.should_accumulate_grads = old_should_accumulate_grads
+        self._accumulate_grads_flipped = self._should_accumulate_grads != old_should_accumulate_grads
+        self._should_accumulate_grads = old_should_accumulate_grads

     @torch.no_grad()
     def _clear_counters(self) -> None:
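The no_sync() context shown above underlies the usual gradient-accumulation pattern. A sketch continuing the earlier example; accumulation_steps and the toy batches are illustrative, not part of the diff:

accumulation_steps = 2
batches = [torch.randn(8, 16) for _ in range(4)]

for step, batch in enumerate(batches):
    if (step + 1) % accumulation_steps != 0:
        with ddp.no_sync():               # skip the per-backward reductions
            ddp(batch).sum().backward()   # grads accumulate locally
    else:
        ddp(batch).sum().backward()       # this backward reduces as usual
        optimizer.step()
        optimizer.zero_grad()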
@@ -372,12 +374,12 @@ class ShardedDataParallel(nn.Module):
         self._grad_to_be_reduced = [True for _ in self._trainable_params]
         self._bucket_flush_callback_set = False

-        if self.use_buckets:
+        if self._use_buckets:
             for bucket in self._bucket_list:
                 bucket.reset_checked_in()

-        if not self.should_accumulate_grads:
-            self.accumulate_grads_flipped = False
+        if not self._should_accumulate_grads:
+            self._accumulate_grads_flipped = False

     def _get_reduce_fn(self, index: int, param: torch.Tensor, dst_rank: int) -> Callable:
         """
@@ -387,12 +389,12 @@ class ShardedDataParallel(nn.Module):
         Either way a delayed action is necessary and is passed as a callback.
         """
-        if not self.use_buckets or not self._should_bucket_grad[index]:
+        if not self._use_buckets or not self._should_bucket_grad[index]:
             # Direct reduction
             @torch.no_grad()
             def reduce(*_: Any) -> None:
                 # Skip gradient reduction, do not alter status flags
-                if not self.should_accumulate_grads and self._grad_to_be_reduced[index]:
+                if not self._should_accumulate_grads and self._grad_to_be_reduced[index]:
                     assert param.grad is not None, "Reducing gradients during backward pass, cannot be None"

                     if not self._bucket_flush_callback_set:
@@ -401,14 +403,14 @@ class ShardedDataParallel(nn.Module):
                     # Make sure that this is not fired twice
                     self._grad_to_be_reduced[index] = False
-                    param.grad.mul_(self.world_size_scaling)
+                    param.grad.mul_(self._world_size_scaling)

-                    if self.reduce_fp16:
+                    if self._reduce_fp16:
                         param.grad.data = param.grad.data.half()

                     # Future work includes clearing up the buffer if possible
                     def cleanup() -> None:
-                        if dst_rank != self.global_rank:
+                        if dst_rank != self._global_rank:
                             param.grad = None
                         else:
                             assert param.grad is not None
@@ -420,7 +422,7 @@ class ShardedDataParallel(nn.Module):
                             handle=dist.reduce(
                                 tensor=param.grad.data,
                                 dst=self._local_to_global_rank[dst_rank],
-                                group=self.process_group,
+                                group=self._process_group,
                                 async_op=True,
                             ),
                             callback=cleanup,
@@ -436,7 +438,7 @@ class ShardedDataParallel(nn.Module):
             def reduce(*_: Any) -> None:
                 # Skip gradient reduction, do not alter status flags
-                if not self.should_accumulate_grads and self._grad_to_be_reduced[index]:
+                if not self._should_accumulate_grads and self._grad_to_be_reduced[index]:
                     assert param.grad is not None, "Reducing gradients during backward pass, cannot be None"

                     if not self._bucket_flush_callback_set:
@@ -445,14 +447,14 @@ class ShardedDataParallel(nn.Module):
                     # Make sure that this is not fired twice
                     self._grad_to_be_reduced[index] = False
-                    bucket = self.buckets[param.device][dst_rank]
+                    bucket = self._buckets[param.device][dst_rank]
                     bucket.params_checked_in += 1

                     if bucket.all_checked_in:
                         assert bucket.buffer is not None

                         # Normalize the bucket in one go
-                        bucket.buffer.mul_(self.world_size_scaling)
+                        bucket.buffer.mul_(self._world_size_scaling)

                         # Reduce the bucket
                         bucket.sent = True
@@ -461,7 +463,7 @@ class ShardedDataParallel(nn.Module):
                                 handle=dist.reduce(
                                     tensor=bucket.buffer,
                                     dst=bucket.destination,
-                                    group=self.process_group,
+                                    group=self._process_group,
                                     async_op=True,
                                 ),
                                 callback=None,
@@ -520,13 +522,13 @@ class ShardedDataParallel(nn.Module):
         work_handles = []

-        for t in self.module.state_dict().values():
+        for t in self._module.state_dict().values():
             work_handles.append(
-                dist.broadcast(t, src=self.reference_global_rank, group=self.process_group, async_op=True)
+                dist.broadcast(t, src=self._reference_global_rank, group=self._process_group, async_op=True)
             )

         # gloo does not guarantee inlining like NCCL, wait for all requests
-        if self.backend != dist.Backend.NCCL:
+        if self._backend != dist.Backend.NCCL:
             _ = list(filter(lambda x: x.wait(), work_handles))
         elif work_handles:
             work_handles[-1].wait()
@@ -549,25 +551,25 @@ class ShardedDataParallel(nn.Module):
         This method can be a slow for big models, but it it not typically called often (not for every forward for instance)
         """
-        if not self.use_buckets:
+        if not self._use_buckets:
             return

         # Devise the bucketing strategy. Parameters are already sorted, in that:
         # - these are only the trainable parameters, so they should produce grads
         # - they are sorted by increasing size
-        self.buckets = {}
+        self._buckets = {}
         self._should_bucket_grad = [False for _ in self._trainable_params]

         for i, param in enumerate(self._trainable_params):
             device = param.device
             dst_rank = self._trainable_param_to_rank[param]

-            if param.device not in self.buckets.keys():
-                self.buckets[param.device] = {}
+            if param.device not in self._buckets.keys():
+                self._buckets[param.device] = {}

-            if dst_rank not in self.buckets[param.device].keys():
-                self.buckets[param.device][dst_rank] = GradBucket(
-                    self.buffer_max_size,
+            if dst_rank not in self._buckets[param.device].keys():
+                self._buckets[param.device][dst_rank] = GradBucket(
+                    self._buffer_max_size,
                     dtype=param.dtype,
                     device=param.device,
                     destination=self._local_to_global_rank[dst_rank],
@@ -575,11 +577,11 @@ class ShardedDataParallel(nn.Module):
             # Criteria to decide whether this parameter is to be bucketed or not:
             # - enough room in the bucket
-            if self.buckets[device][dst_rank].can_add_grad_view(param):
-                self.buckets[device][dst_rank].add_grad(param)
+            if self._buckets[device][dst_rank].can_add_grad_view(param):
+                self._buckets[device][dst_rank].add_grad(param)
                 self._should_bucket_grad[i] = True

-        self._bucket_list = list(chain(*[self.buckets[device].values() for device in self.buckets.keys()]))
+        self._bucket_list = list(chain(*[self._buckets[device].values() for device in self._buckets.keys()]))

         # Resize the buckets to remove lost space in the end
         for bucket in self._bucket_list:
@@ -609,13 +611,13 @@ class ShardedDataParallel(nn.Module):
                 assert bucket.buffer is not None

                 # Normalize the bucket in one go
-                bucket.buffer.mul_(self.world_size_scaling)
+                bucket.buffer.mul_(self._world_size_scaling)

                 # Reduce the bucket
                 self._work_handles.append(
                     Workhandle(
                         handle=dist.reduce(
-                            tensor=bucket.buffer,
-                            dst=bucket.destination,
-                            group=self.process_group,
-                            async_op=True,
+                            tensor=bucket.buffer,
+                            dst=bucket.destination,
+                            group=self._process_group,
+                            async_op=True,
                         ),
                         callback=None,
                     )
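One practical consequence of the rename, worth noting for downstream code: the wrapper no longer exposes reduce_fp16, process_group, should_accumulate_grads and friends as public attributes, and the __getattr__ shown in an earlier hunk forwards unknown names to the wrapped module. A hypothetical illustration, assuming the wrapped module does not itself define a same-named attribute:

ddp = ShardedDDP(base, optimizer)

try:
    ddp.reduce_fp16  # public before #611
except AttributeError:
    # The lookup now falls through __getattr__ to the wrapped module, which has
    # no such attribute, so downstream code should rely on its own construction
    # arguments or the remaining public surface instead.
    pass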