Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
fairscale
Commits
79ded821
Unverified
Commit
79ded821
authored
Sep 28, 2020
by
Benjamin Lefaudeux
Committed by
GitHub
Sep 28, 2020
Browse files
[ShardedDDP] Sync buffers + small cleanup (#112)
- adding the buffer broadcast option - minor cleanup in shardedDDP
parent
41819af9
Changes
3
Show whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
59 additions
and
42 deletions
+59
-42
benchmarks/oss.py
benchmarks/oss.py
+6
-2
fairscale/nn/data_parallel/sharded_ddp.py
fairscale/nn/data_parallel/sharded_ddp.py
+40
-34
tests/nn/data_parallel/test_sharded_ddp.py
tests/nn/data_parallel/test_sharded_ddp.py
+13
-6
No files found.
benchmarks/oss.py
View file @
79ded821
...
...
@@ -62,7 +62,7 @@ def train(
):
assert
not
use_sdp
or
(
use_sdp
and
use_oss
),
"ShardedDataParallel requires OSS"
# DDP
dist_init
(
rank
,
world_size
,
backend
)
dist_init
(
rank
=
rank
,
world_size
=
world_size
,
backend
=
backend
)
# Setup
torch
.
cuda
.
set_device
(
rank
)
...
...
@@ -81,7 +81,11 @@ def train(
if
use_sdp
:
ddp
=
ShardedDataParallel
(
module
=
model
,
optimizer
=
OPTIM
,
optimizer_params
=
{
"lr"
:
1e-4
,
"momentum"
:
0.9
},
world_size
=
world_size
,
module
=
model
,
optimizer
=
OPTIM
,
optimizer_params
=
{
"lr"
:
1e-4
,
"momentum"
:
0.9
},
world_size
=
world_size
,
broadcast_buffers
=
False
,
)
ddp
.
train
()
optimizer
=
ddp
.
optimizer
...
...
fairscale/nn/data_parallel/sharded_ddp.py
View file @
79ded821
...
...
@@ -25,7 +25,7 @@ class ShardedDataParallel(nn.Module):
"""Implements distributed data parallel training with optimizer state sharding.
A simplified version of :class:`torch.nn.parallel.DistributedDataParallel`.
This version uses a c10d process group for communication and
does not
This version uses a c10d process group for communication and
optionally
broadcast buffers.
Args:
...
...
@@ -33,6 +33,8 @@ class ShardedDataParallel(nn.Module):
optimizer (~torch.optim.Optimizer): optimizer to be used for training
optimizer_params(Dict): extra parameters for the optimizer
world_size (int): number of parallel workers
broadcast_buffers (bool): flag that enables syncing (broadcasting) buffers of
the module at the beginning of the forward function. This argument is required (the constructor declares no default value).
process_group (optional): the c10d process group to be used for
distributed gradient reduction. If None, the default WORLD process group
will be used.
...
...
@@ -47,6 +49,7 @@ class ShardedDataParallel(nn.Module):
optimizer
:
Type
[
torch
.
optim
.
Optimizer
],
optimizer_params
:
Dict
[
str
,
Any
],
world_size
:
int
,
broadcast_buffers
:
bool
,
process_group
:
Any
=
None
,
buffer_size
:
int
=
2
**
28
,
):
...
...
@@ -56,6 +59,8 @@ class ShardedDataParallel(nn.Module):
self
.
world_size
=
world_size
self
.
process_group
=
process_group
if
process_group
is
not
None
else
dist
.
group
.
WORLD
self
.
rank
=
dist
.
get_rank
(
self
.
process_group
)
self
.
broadcast_buffers
=
broadcast_buffers
self
.
authoritative_rank
=
0
# Never use a bigger buffer than the number of model params
self
.
buffer_size
=
min
(
buffer_size
,
sum
(
p
.
numel
()
for
p
in
self
.
module
.
parameters
()))
...
...
@@ -71,7 +76,7 @@ class ShardedDataParallel(nn.Module):
# Build the sharded optimizer
self
.
sharded_optimizer
=
OSS
(
self
.
module
.
parameters
(),
optim
=
optimizer
,
group
=
process_group
,
**
optimizer_params
)
#
s
anity checks
#
S
anity checks
assert
len
(
self
.
sharded_optimizer
.
param_to_rank
)
==
len
(
list
(
self
.
module
.
parameters
())
),
"number of params do not match"
...
...
@@ -109,6 +114,9 @@ class ShardedDataParallel(nn.Module):
raise
RuntimeError
(
"OssDdp requires explicit reduction, must call OssDdp.reduce"
)
if
not
self
.
accumulate_grads
:
self
.
need_reduction
=
True
if
self
.
broadcast_buffers
and
len
(
list
(
self
.
module
.
buffers
()))
>
0
:
self
.
_sync_buffers
()
return
self
.
module
(
*
inputs
,
**
kwargs
)
def
reduce
(
self
)
->
None
:
...
...
@@ -118,52 +126,35 @@ class ShardedDataParallel(nn.Module):
"""
assert
self
.
module
.
training
,
"Cannot call reduce in eval"
def
reduce_params
(
params
:
List
[
Parameter
],
params_rank
:
int
)
->
None
:
""" Helper to reduce a list of params that should fix in the buffer. """
def
reduce_grads
(
params
:
List
[
Parameter
],
params_rank
:
int
)
->
None
:
""" Helper to reduce a list of params that should fit in the buffer.
NOTE: All param gradients are assumed to exist"""
assert
self
.
buffer
is
not
None
# Fill in the packed IO buffer
buffer
:
Tensor
=
cast
(
Tensor
,
self
.
buffer
)
nonzero_buffer
=
False
if
len
(
params
)
>
1
:
offset
=
0
for
p
in
params
:
sz
=
p
.
numel
()
if
p
.
grad
is
not
None
:
# The type error could have been fixed in later
# version of pytorch. Same elsewhere.
buffer
[
offset
:
offset
+
sz
].
copy_
(
p
.
grad
.
data
.
view
(
-
1
))
# type: ignore
nonzero_buffer
=
True
else
:
buffer
[
offset
:
offset
+
sz
].
zero_
()
offset
+=
sz
else
:
# we only have a single grad to reduce
p
=
params
[
0
]
if
p
.
grad
is
not
None
:
buffer
=
p
.
grad
.
data
nonzero_buffer
=
True
elif
p
.
numel
()
<=
self
.
buffer
.
numel
():
buffer
=
buffer
[:
p
.
numel
()]
buffer
.
zero_
()
else
:
buffer
=
torch
.
zeros_like
(
p
)
buffer
=
params
[
0
].
grad
.
data
# type: ignore
if
nonzero_buffer
:
# Reduce
buffer
.
div_
(
self
.
world_size
)
# type: ignore
dist
.
reduce
(
tensor
=
buffer
,
dst
=
params_rank
,
group
=
self
.
process_group
)
# type: ignore
dist
.
reduce
(
buffer
,
params_rank
,
group
=
self
.
process_group
)
# type: ignore
# Copy reduced grads back into their original place, or free corresponding memory
if
params_rank
==
self
.
rank
:
# copy reduced grads back into their original place
offset
=
0
for
p
in
params
:
sz
=
p
.
numel
()
if
p
.
grad
is
not
None
:
p
.
grad
.
data
.
copy_
(
buffer
[
offset
:
offset
+
sz
].
view_as
(
p
))
# type: ignore
else
:
p
.
grad
=
buffer
[
offset
:
offset
+
sz
].
view_as
(
p
).
clone
()
offset
+=
sz
else
:
# wipe the grads
for
p
in
params
:
p
.
grad
=
None
...
...
@@ -195,7 +186,7 @@ class ShardedDataParallel(nn.Module):
if
sz
>
self
.
buffer
.
numel
():
# reduce big params directly
assert
param_rank
is
not
None
reduce_
pa
ra
m
s
([
param
],
cast
(
int
,
param_rank
))
reduce_
g
ra
d
s
([
param
],
cast
(
int
,
param_rank
))
else
:
# smaller params are packed together from the same device
# and same rank.
...
...
@@ -203,7 +194,7 @@ class ShardedDataParallel(nn.Module):
last_param_rank
is
not
None
and
last_param_rank
!=
param_rank
):
assert
last_param_rank
is
not
None
reduce_
pa
ra
m
s
(
buffered_params
,
cast
(
int
,
last_param_rank
))
reduce_
g
ra
d
s
(
buffered_params
,
cast
(
int
,
last_param_rank
))
offset
=
0
buffered_params
.
clear
()
buffered_params
.
append
(
cast
(
Parameter
,
param
))
...
...
@@ -211,6 +202,21 @@ class ShardedDataParallel(nn.Module):
if
len
(
buffered_params
)
>
0
:
assert
param_rank
is
not
None
reduce_
pa
ra
m
s
(
buffered_params
,
cast
(
int
,
param_rank
))
reduce_
g
ra
d
s
(
buffered_params
,
cast
(
int
,
param_rank
))
reduction_fn
()
def _sync_buffers(self) -> None:
    """
    Broadcast every buffer of the wrapped module from the authoritative rank.

    All broadcasts are issued first (asynchronously) and only then waited on,
    so that the requests can overlap. Chaining two lazy ``map`` calls would
    instead pull one buffer at a time and wait on its broadcast before the
    next one is even issued, which defeats the purpose of ``async_op=True``.

    TODO: Could be worth bucketing ?
    """
    # Launch phase: the list comprehension is eager, so every broadcast
    # request is issued before we start blocking on any of them.
    work_handles = [
        dist.broadcast(buffer, self.authoritative_rank, self.process_group, async_op=True)
        for buffer in self.module.buffers()
    ]

    # Wait phase: block until all the buffers have been synced.
    for handle in work_handles:
        handle.wait()
tests/nn/data_parallel/test_sharded_ddp.py
View file @
79ded821
...
...
@@ -37,12 +37,20 @@ def run_one_step(rank, world_size, backend, device, temp_file_name):
if
device
==
torch
.
device
(
"cuda"
):
torch
.
cuda
.
set_device
(
rank
)
# Any model works. Add one different buffer per rank
model
=
Sequential
(
Linear
(
2
,
3
),
Linear
(
3
,
4
)).
to
(
device
)
model
.
register_buffer
(
"test_buffer"
,
torch
.
ones
((
1
))
*
rank
)
model
.
to
(
device
)
ddp
=
ShardedDataParallel
(
module
=
model
,
optimizer
=
torch
.
optim
.
SGD
,
optimizer_params
=
{
"lr"
:
0.1
,
"momentum"
:
0.99
},
world_size
=
world_size
module
=
model
,
optimizer
=
torch
.
optim
.
SGD
,
optimizer_params
=
{
"lr"
:
0.01
,
"momentum"
:
0.99
},
world_size
=
world_size
,
broadcast_buffers
=
True
,
)
optimizer
=
ddp
.
optimizer
model
=
ddp
.
module
input_tensor
=
torch
.
rand
((
64
,
2
)).
to
(
device
)
output
=
ddp
(
input_tensor
).
abs
().
sum
()
/
input_tensor
.
numel
()
...
...
@@ -58,10 +66,9 @@ def run_one_step(rank, world_size, backend, device, temp_file_name):
if
param
.
requires_grad
:
assert
param
.
grad
.
abs
().
sum
().
item
()
>
0.0
,
"The reduce step should have populated all the gradients"
# Check that the optimization process makes sense (ie. loss goes down for the same data)
optimizer
.
step
()
new_eval
=
ddp
(
input_tensor
).
abs
().
sum
()
/
input_tensor
.
numel
()
# assert new_eval.item() < output.item()
# Check that all the buffers are in sync (authoritative rank is 0, its buffer is 0)
for
b
in
model
.
buffers
():
assert
b
.
cpu
().
item
()
==
0.0
def
run_test
(
backend
,
device
,
world_size
=
2
):
...
...
@@ -76,7 +83,7 @@ def run_eval_mode(_unused):
)
model
=
Sequential
(
Linear
(
2
,
3
),
Linear
(
3
,
4
))
optimizer_params
=
{
"lr"
:
0.1
,
"momentum"
:
0.99
}
ddp
=
ShardedDataParallel
(
model
,
torch
.
optim
.
SGD
,
optimizer_params
,
1
)
ddp
=
ShardedDataParallel
(
model
,
torch
.
optim
.
SGD
,
optimizer_params
,
1
,
broadcast_buffers
=
False
)
optimizer
=
ddp
.
optimizer
ddp
.
eval
()
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment