fairscale (OpenDAS) commit c084b202 (unverified)

[chore][SDP] privatizing all the things (#611)

Authored by Benjamin Lefaudeux on Apr 15, 2021, committed via GitHub on Apr 15, 2021.
Parent: a77c56f0

1 changed file, 74 additions and 72 deletions:

fairscale/nn/data_parallel/sharded_ddp.py (+74, -72) @ c084b202
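The diff below renames ShardedDataParallel's bookkeeping attributes (module, sharded_optimizers, process_group, backend, rank, buckets, and so on) to underscore-prefixed private names, keeping only the PytorchDDP-compatibility attributes (device, is_multi_device_module) and the already underscored fields as they were. For orientation, here is a minimal usage sketch of the wrapper together with the OSS sharded optimizer. It is not part of this commit; the model, data, learning rate and the single-rank "gloo" process-group setup are placeholder assumptions.

# Minimal usage sketch (not from this commit): wrap a model with OSS + ShardedDataParallel.
import os
import torch
import torch.distributed as dist
from fairscale.optim.oss import OSS
from fairscale.nn.data_parallel import ShardedDataParallel as ShardedDDP

def main() -> None:
    os.environ.setdefault("MASTER_ADDR", "127.0.0.1")
    os.environ.setdefault("MASTER_PORT", "29500")
    dist.init_process_group(backend="gloo", rank=0, world_size=1)

    model = torch.nn.Linear(32, 2)                                    # placeholder model
    optimizer = OSS(params=model.parameters(), optim=torch.optim.SGD, lr=1e-3)
    model = ShardedDDP(model, optimizer)                              # the class touched by this diff

    criterion = torch.nn.MSELoss()
    for _ in range(3):
        inputs, target = torch.randn(8, 32), torch.randn(8, 2)
        optimizer.zero_grad()
        criterion(model(inputs), target).backward()                   # grads reduced to their owning rank
        optimizer.step()

if __name__ == "__main__":
    main()

This calling pattern is unchanged by the commit, which only renames internal attributes; external code that reached into fields such as the wrapper's process group would need to stop doing so once they are private.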
@@ -102,41 +102,40 @@ class ShardedDataParallel(nn.Module):
     ):
         super().__init__()
-        self.module = module
-        self.sharded_optimizers = [sharded_optimizer] if not isinstance(sharded_optimizer, list) else sharded_optimizer
-        self.enable_broadcast_buffers = broadcast_buffers
-        self.auto_refresh_trainable = auto_refresh_trainable
-        self.reduce_fp16 = reduce_fp16
+        self._module = module
+        self._sharded_optimizers = [sharded_optimizer] if not isinstance(sharded_optimizer, list) else sharded_optimizer
+        self._enable_broadcast_buffers = broadcast_buffers
+        self._auto_refresh_trainable = auto_refresh_trainable
+        self._reduce_fp16 = reduce_fp16
         if reduce_buffer_size > 0 and reduce_fp16:
-            self.reduce_fp16 = False
+            self._reduce_fp16 = False
             logging.warning(
                 "fp16 gradient reduction is not compatible with reduction buffers, which are requested. fp16 grad reduction is deactivated."
             )

         # Handle a no_sync() context which prevents the gradient synchronization,
         # accumulate in place
-        self.should_accumulate_grads = False
-        self.accumulate_grads_flipped = False
+        self._should_accumulate_grads = False
+        self._accumulate_grads_flipped = False

         # Communication related attributes
-        self.process_group = process_group if process_group is not None else dist.group.WORLD
-        self.backend = dist.get_backend(self.process_group)
-        self.world_size_scaling = 1.0 / dist.get_world_size(self.process_group)  # > 0
-        self.reference_global_rank = get_global_rank(self.process_group, 0)  # picking rank 0 as the reference
-        self.rank = dist.get_rank(self.process_group)
-        self.global_rank = get_global_rank(self.process_group, self.rank)
+        self._process_group = process_group if process_group is not None else dist.group.WORLD
+        self._backend = dist.get_backend(self._process_group)
+        self._world_size_scaling = 1.0 / dist.get_world_size(self._process_group)  # > 0
+        self._reference_global_rank = get_global_rank(self._process_group, 0)  # picking rank 0 as the reference
+        self._rank = dist.get_rank(self._process_group)
+        self._global_rank = get_global_rank(self._process_group, self._rank)
         self._local_to_global_rank = [
-            get_global_rank(self.process_group, i) for i in range(dist.get_world_size(self.process_group))
+            get_global_rank(self._process_group, i) for i in range(dist.get_world_size(self._process_group))
         ]

         # Expose some of the PytorchDDP attributes, some frameworks rely on them.
         # See https://pytorch.org/docs/stable/_modules/torch/nn/parallel/distributed.html#DistributedDataParallel
         # device_id related logic is not present, this is not handled
-        devices = {p.device for p in self.module.parameters()}
+        devices = {p.device for p in self._module.parameters()}
         self.is_multi_device_module = len(devices) > 1
         self.device = list(devices)[0]

-        distinct_device_types = {p.device.type for p in self.module.parameters()}
+        distinct_device_types = {p.device.type for p in self._module.parameters()}
         assert len(distinct_device_types) == 1, (
             "ShardedDataParallel's input module must be on "
             "the same type of devices, but input module parameters are located on {} different device types."
@@ -149,7 +148,10 @@ class ShardedDataParallel(nn.Module):
         # - we build an iterator which goes through all the parameters involved globally
         self._all_params = list(
             chain(
-                *[sum([sum(p, []) for p in optim._per_device_params.values()], []) for optim in self.sharded_optimizers]
+                *[
+                    sum([sum(p, []) for p in optim._per_device_params.values()], [])
+                    for optim in self._sharded_optimizers
+                ]
             )
         )
         self._trainable_params: List[torch.Tensor] = []
@@ -158,21 +160,21 @@ class ShardedDataParallel(nn.Module):
         self._reference_trainable_mask = list(map(_trainable, self._all_params))

         # - setup buckets and tensor views
-        model_size = sum([p.numel() for p in self.module.parameters()])
-        self.buffer_max_size = min(reduce_buffer_size, model_size)
-        if dist.get_world_size(self.process_group) == 1:
-            self.buffer_max_size = 0
+        model_size = sum([p.numel() for p in self._module.parameters()])
+        self._buffer_max_size = min(reduce_buffer_size, model_size)
+        if dist.get_world_size(self._process_group) == 1:
+            self._buffer_max_size = 0
             logging.info("Training is not really distributed, single rank. Deactivating buckets")

         logging.info(
             "ShardedDDP bucket size: {:.2f}M parameters, model size {:.2f}M parameters".format(
-                self.buffer_max_size / 2 ** 20, model_size / 2 ** 20
+                self._buffer_max_size / 2 ** 20, model_size / 2 ** 20
             )
         )
-        self.use_buckets = self.buffer_max_size > 0
+        self._use_buckets = self._buffer_max_size > 0

-        self.buckets: Dict[torch.device, Dict[int, GradBucket]] = {}
+        self._buckets: Dict[torch.device, Dict[int, GradBucket]] = {}
         self._should_bucket_grad: List[bool] = []
         self._bucket_list: List[GradBucket] = []
@@ -182,7 +184,7 @@ class ShardedDataParallel(nn.Module):
         self._manual_reduce: List[Callable] = []

         # passing a handle to torch.nn.SyncBatchNorm layer
-        self._passing_sync_batchnorm_handle(self.module)
+        self._passing_sync_batchnorm_handle(self._module)

         # Make sure that all ranks start with the same model
         if sync_models_at_startup:
@@ -200,13 +202,13 @@ class ShardedDataParallel(nn.Module):
         # Deferred initialization, or change detection
         needs_setup = len(self._grad_hooks) == 0 and self.training

-        if self.auto_refresh_trainable:
+        if self._auto_refresh_trainable:
             needs_setup |= self._detect_train_change()

         if needs_setup:
             self.refresh_trainable()

-        if self.enable_broadcast_buffers:
+        if self._enable_broadcast_buffers:
             # NCCL communications are on a different stream, needs to be blocking
             # for the subsequent FW to be correct
             self.sync_buffers(blocking=True)
@@ -215,7 +217,7 @@ class ShardedDataParallel(nn.Module):
         self._clear_counters()

         # Normal FW on the base model
-        return self.module(*inputs, **kwargs)
+        return self._module(*inputs, **kwargs)

     def to(  # type: ignore
         self,
@@ -252,16 +254,16 @@ class ShardedDataParallel(nn.Module):
             Module: self.
         """

-        assert device in self.buckets.keys(), "Changing devices is not supported, because this would break OSSs state"
+        assert device in self._buckets.keys(), "Changing devices is not supported, because this would break OSSs state"
         assert (
-            len(self.buckets.keys()) == 1
+            len(self._buckets.keys()) == 1
         ), "Several devices specified to begin with, incompatible with setting a single device here"

-        for _device in self.buckets.keys():
-            for bucket in self.buckets[_device].values():
+        for _device in self._buckets.keys():
+            for bucket in self._buckets[_device].values():
                 bucket.to(device=_device, dtype=dtype, non_blocking=non_blocking)

-        self.module.to(device=device, dtype=dtype, non_blocking=non_blocking)
+        self._module.to(device=device, dtype=dtype, non_blocking=non_blocking)

     def refresh_trainable(self) -> None:
         """ If the module trainability has changed, update all the assumptions """
@@ -276,7 +278,7 @@ class ShardedDataParallel(nn.Module):
         self._trainable_params.sort(key=lambda x: x.numel())

         self._trainable_param_to_rank = {}
-        for optim in self.sharded_optimizers:
+        for optim in self._sharded_optimizers:
             # OSS may need to change the communication pattern
             optim.refresh_trainable()
@@ -320,13 +322,13 @@ class ShardedDataParallel(nn.Module):
         work_handles = []

-        for buffer in self.module.buffers(recurse=True):
+        for buffer in self._module.buffers(recurse=True):
             work_handles.append(
-                dist.broadcast(buffer.data, self.reference_global_rank, self.process_group, async_op=True)
+                dist.broadcast(buffer.data, self._reference_global_rank, self._process_group, async_op=True)
             )

         if blocking and work_handles:
-            if self.backend != dist.Backend.NCCL:
+            if self._backend != dist.Backend.NCCL:
                 _ = list(filter(lambda x: x.wait(), work_handles))
             else:
                 work_handles[-1].wait()
@@ -354,16 +356,16 @@ class ShardedDataParallel(nn.Module):
         try:
             return super().__getattr__(name)  # defer to nn.Module's logic
         except AttributeError:
-            return getattr(self.module, name)
+            return getattr(self._module, name)

     @contextlib.contextmanager
     def no_sync(self) -> Generator:
         """A context manager to disable gradient synchronization."""
-        old_should_accumulate_grads = self.should_accumulate_grads
-        self.should_accumulate_grads = True
+        old_should_accumulate_grads = self._should_accumulate_grads
+        self._should_accumulate_grads = True
         yield
-        self.accumulate_grads_flipped = self.should_accumulate_grads != old_should_accumulate_grads
-        self.should_accumulate_grads = old_should_accumulate_grads
+        self._accumulate_grads_flipped = self._should_accumulate_grads != old_should_accumulate_grads
+        self._should_accumulate_grads = old_should_accumulate_grads

     @torch.no_grad()
     def _clear_counters(self) -> None:
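The no_sync() context in the hunk above only flips _should_accumulate_grads, and the reduce hooks skip their work while it is set, so gradients accumulate locally in param.grad. A hedged gradient-accumulation sketch, reusing model, optimizer and criterion from the first sketch and a placeholder batches iterable:

# Gradient accumulation with no_sync(): only every fourth backward triggers reduction.
accumulation_steps = 4
optimizer.zero_grad()
for step, (inputs, target) in enumerate(batches):
    if (step + 1) % accumulation_steps != 0:
        # Local accumulation only: the hooks see _should_accumulate_grads == True
        with model.no_sync():
            criterion(model(inputs), target).backward()
    else:
        # Reduction happens on this backward, then the sharded optimizer steps
        criterion(model(inputs), target).backward()
        optimizer.step()
        optimizer.zero_grad()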
@@ -372,12 +374,12 @@ class ShardedDataParallel(nn.Module):
         self._grad_to_be_reduced = [True for _ in self._trainable_params]
         self._bucket_flush_callback_set = False

-        if self.use_buckets:
+        if self._use_buckets:
             for bucket in self._bucket_list:
                 bucket.reset_checked_in()

-        if not self.should_accumulate_grads:
-            self.accumulate_grads_flipped = False
+        if not self._should_accumulate_grads:
+            self._accumulate_grads_flipped = False

     def _get_reduce_fn(self, index: int, param: torch.Tensor, dst_rank: int) -> Callable:
         """
@@ -387,12 +389,12 @@ class ShardedDataParallel(nn.Module):
         Either way a delayed action is necessary and is passed as a callback.
         """

-        if not self.use_buckets or not self._should_bucket_grad[index]:
+        if not self._use_buckets or not self._should_bucket_grad[index]:
             # Direct reduction
             @torch.no_grad()
             def reduce(*_: Any) -> None:
                 # Skip gradient reduction, do not alter status flags
-                if not self.should_accumulate_grads and self._grad_to_be_reduced[index]:
+                if not self._should_accumulate_grads and self._grad_to_be_reduced[index]:
                     assert param.grad is not None, "Reducing gradients during backward pass, cannot be None"

                     if not self._bucket_flush_callback_set:
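_get_reduce_fn builds the closure that is attached as an autograd hook for each trainable parameter, so reduction can start as soon as that parameter's gradient is ready. The sketch below shows the general hook mechanism assumed here; the expand_as / AccumulateGrad indirection is the usual DDP-style trick, not code quoted from this file:

# Standalone sketch of a per-parameter "gradient ready" hook.
import torch

param = torch.nn.Parameter(torch.randn(4))

# The usual trick to reach the AccumulateGrad node that owns param.grad
p_tmp = param.expand_as(param)
grad_acc = p_tmp.grad_fn.next_functions[0][0]

def reduce_hook(*_):
    # In ShardedDDP this is where the gradient would be scaled and dist.reduce()-ed
    print("gradient ready:", param.grad is not None)

grad_acc.register_hook(reduce_hook)   # fires once the gradient has been accumulated

(param * 2.0).sum().backward()        # prints: gradient ready: True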
@@ -401,14 +403,14 @@ class ShardedDataParallel(nn.Module):
                     # Make sure that this is not fired twice
                     self._grad_to_be_reduced[index] = False
-                    param.grad.mul_(self.world_size_scaling)
+                    param.grad.mul_(self._world_size_scaling)

-                    if self.reduce_fp16:
+                    if self._reduce_fp16:
                         param.grad.data = param.grad.data.half()

                     # Future work includes clearing up the buffer if possible
                     def cleanup() -> None:
-                        if dst_rank != self.global_rank:
+                        if dst_rank != self._global_rank:
                             param.grad = None
                         else:
                             assert param.grad is not None
@@ -420,7 +422,7 @@ class ShardedDataParallel(nn.Module):
                             handle=dist.reduce(
                                 tensor=param.grad.data,
                                 dst=self._local_to_global_rank[dst_rank],
-                                group=self.process_group,
+                                group=self._process_group,
                                 async_op=True,
                             ),
                             callback=cleanup,
@@ -436,7 +438,7 @@ class ShardedDataParallel(nn.Module):
             def reduce(*_: Any) -> None:
                 # Skip gradient reduction, do not alter status flags
-                if not self.should_accumulate_grads and self._grad_to_be_reduced[index]:
+                if not self._should_accumulate_grads and self._grad_to_be_reduced[index]:
                     assert param.grad is not None, "Reducing gradients during backward pass, cannot be None"

                     if not self._bucket_flush_callback_set:
@@ -445,14 +447,14 @@ class ShardedDataParallel(nn.Module):
                     # Make sure that this is not fired twice
                     self._grad_to_be_reduced[index] = False
-                    bucket = self.buckets[param.device][dst_rank]
+                    bucket = self._buckets[param.device][dst_rank]
                     bucket.params_checked_in += 1

                     if bucket.all_checked_in:
                         assert bucket.buffer is not None

                         # Normalize the bucket in one go
-                        bucket.buffer.mul_(self.world_size_scaling)
+                        bucket.buffer.mul_(self._world_size_scaling)

                         # Reduce the bucket
                         bucket.sent = True
@@ -461,7 +463,7 @@ class ShardedDataParallel(nn.Module):
                                 handle=dist.reduce(
                                     tensor=bucket.buffer,
                                     dst=bucket.destination,
-                                    group=self.process_group,
+                                    group=self._process_group,
                                     async_op=True,
                                 ),
                                 callback=None,
@@ -520,13 +522,13 @@ class ShardedDataParallel(nn.Module):
         work_handles = []

-        for t in self.module.state_dict().values():
+        for t in self._module.state_dict().values():
             work_handles.append(
-                dist.broadcast(t, src=self.reference_global_rank, group=self.process_group, async_op=True)
+                dist.broadcast(t, src=self._reference_global_rank, group=self._process_group, async_op=True)
             )

         # gloo does not guarantee inlining like NCCL, wait for all requests
-        if self.backend != dist.Backend.NCCL:
+        if self._backend != dist.Backend.NCCL:
             _ = list(filter(lambda x: x.wait(), work_handles))
         elif work_handles:
             work_handles[-1].wait()
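Both synchronization paths above launch every broadcast with async_op=True and then wait: on gloo each handle is waited on individually, while on NCCL waiting on the last handle is sufficient because work submitted to one process group is executed in order on its stream. A hedged sketch of that pattern; the tensors, source rank, group and backend arguments are placeholders:

# Fire-and-wait pattern for asynchronous broadcasts, per backend.
import torch
import torch.distributed as dist

def broadcast_all(tensors, src_rank, group=None, backend="gloo"):
    # Kick off every broadcast without blocking
    handles = [dist.broadcast(t, src=src_rank, group=group, async_op=True) for t in tensors]

    if backend != dist.Backend.NCCL:
        # gloo gives no ordering guarantee across requests: wait on each one
        for handle in handles:
            handle.wait()
    elif handles:
        # NCCL requests complete in submission order: the last handle implies the rest
        handles[-1].wait()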
@@ -549,25 +551,25 @@ class ShardedDataParallel(nn.Module):
         This method can be a slow for big models, but it it not typically called often (not for every forward for instance)
         """
-        if not self.use_buckets:
+        if not self._use_buckets:
             return

         # Devise the bucketing strategy. Parameters are already sorted, in that:
         # - these are only the trainable parameters, so they should produce grads
         # - they are sorted by increasing size
-        self.buckets = {}
+        self._buckets = {}

         self._should_bucket_grad = [False for _ in self._trainable_params]

         for i, param in enumerate(self._trainable_params):
             device = param.device
             dst_rank = self._trainable_param_to_rank[param]

-            if param.device not in self.buckets.keys():
-                self.buckets[param.device] = {}
+            if param.device not in self._buckets.keys():
+                self._buckets[param.device] = {}

-            if dst_rank not in self.buckets[param.device].keys():
-                self.buckets[param.device][dst_rank] = GradBucket(
-                    self.buffer_max_size,
+            if dst_rank not in self._buckets[param.device].keys():
+                self._buckets[param.device][dst_rank] = GradBucket(
+                    self._buffer_max_size,
                     dtype=param.dtype,
                     device=param.device,
                     destination=self._local_to_global_rank[dst_rank],
@@ -575,11 +577,11 @@ class ShardedDataParallel(nn.Module):
             # Criteria to decide whether this parameter is to be bucketed or not:
             # - enough room in the bucket
-            if self.buckets[device][dst_rank].can_add_grad_view(param):
-                self.buckets[device][dst_rank].add_grad(param)
+            if self._buckets[device][dst_rank].can_add_grad_view(param):
+                self._buckets[device][dst_rank].add_grad(param)
                 self._should_bucket_grad[i] = True

-        self._bucket_list = list(chain(*[self.buckets[device].values() for device in self.buckets.keys()]))
+        self._bucket_list = list(chain(*[self._buckets[device].values() for device in self._buckets.keys()]))

         # Resize the buckets to remove lost space in the end
         for bucket in self._bucket_list:
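The bucketing pass above walks the trainable parameters in increasing size order and packs a gradient into the fixed-size bucket of its (device, destination rank) pair only while it still fits; the remaining gradients fall back to direct reduction. A standalone sketch of that greedy decision with hypothetical sizes, approximating can_add_grad_view by a plain capacity check:

# Greedy bucket-fill decision over parameters sorted by increasing size.
from collections import defaultdict

buffer_max_numel = 2 ** 20                          # hypothetical bucket capacity, in elements
param_sizes = [128, 512, 4096, 2 ** 19, 2 ** 20]    # hypothetical sizes, sorted ascending
bucket_key = ("cuda:0", 0)                          # hypothetical (device, destination rank)

fill = defaultdict(int)                             # elements already packed per bucket
should_bucket = []

for numel in param_sizes:
    fits = fill[bucket_key] + numel <= buffer_max_numel
    should_bucket.append(fits)
    if fits:
        fill[bucket_key] += numel                   # small grads become views into the bucket

print(should_bucket)  # [True, True, True, True, False]: the large trailing grad is reduced directly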
@@ -609,13 +611,13 @@ class ShardedDataParallel(nn.Module):
                 assert bucket.buffer is not None

                 # Normalize the bucket in one go
-                bucket.buffer.mul_(self.world_size_scaling)
+                bucket.buffer.mul_(self._world_size_scaling)

                 # Reduce the bucket
                 self._work_handles.append(
                     Workhandle(
                         handle=dist.reduce(
-                            tensor=bucket.buffer, dst=bucket.destination, group=self.process_group, async_op=True,
+                            tensor=bucket.buffer, dst=bucket.destination, group=self._process_group, async_op=True,
                         ),
                         callback=None,
                     )