OpenDAS / fairscale

Commit 79ded821 (unverified)
Authored Sep 28, 2020 by Benjamin Lefaudeux, committed via GitHub on Sep 28, 2020

[ShardedDDP] Sync buffers + small cleanup (#112)

- adding the buffer broadcast option
- minor cleanup in ShardedDDP
Parent: 41819af9

Showing 3 changed files with 59 additions and 42 deletions:
  benchmarks/oss.py                            +6   -2
  fairscale/nn/data_parallel/sharded_ddp.py    +40  -34
  tests/nn/data_parallel/test_sharded_ddp.py   +13  -6
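Before the per-file diffs, a minimal usage sketch of the option this commit introduces: constructing ShardedDataParallel with broadcast_buffers and running one training step. The model, hyper-parameters, backend and rendezvous file below are illustrative assumptions, not part of the commit.

# Minimal sketch (not from the commit): the new broadcast_buffers flag in use.
# Assumptions: gloo backend, a file:// rendezvous, and a toy Sequential model.
import torch
import torch.distributed as dist
from torch.nn import Linear, Sequential

from fairscale.nn.data_parallel.sharded_ddp import ShardedDataParallel


def one_step(rank: int, world_size: int, init_file: str) -> None:
    dist.init_process_group(
        backend="gloo", init_method=f"file://{init_file}", rank=rank, world_size=world_size
    )

    model = Sequential(Linear(2, 3), Linear(3, 4))
    model.register_buffer("test_buffer", torch.ones((1)) * rank)  # deliberately differs per rank

    ddp = ShardedDataParallel(
        module=model,
        optimizer=torch.optim.SGD,
        optimizer_params={"lr": 0.01, "momentum": 0.99},
        world_size=world_size,
        broadcast_buffers=True,  # buffers re-synced from the authoritative rank at each forward
    )

    loss = ddp(torch.rand((64, 2))).abs().sum()  # the forward pass broadcasts the buffers first
    loss.backward()
    ddp.reduce()          # ShardedDDP requires an explicit gradient reduction
    ddp.optimizer.step()

    dist.destroy_process_group()

Each rank would run one_step in its own process; see the launch sketch at the end of this page.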
benchmarks/oss.py

@@ -62,7 +62,7 @@ def train(
 ):
     assert not use_sdp or (use_sdp and use_oss), "ShardedDataParallel requires OSS"
     # DDP
-    dist_init(rank, world_size, backend)
+    dist_init(rank=rank, world_size=world_size, backend=backend)

     # Setup
     torch.cuda.set_device(rank)

@@ -81,7 +81,11 @@ def train(
     if use_sdp:
         ddp = ShardedDataParallel(
-            module=model, optimizer=OPTIM, optimizer_params={"lr": 1e-4, "momentum": 0.9}, world_size=world_size
+            module=model,
+            optimizer=OPTIM,
+            optimizer_params={"lr": 1e-4, "momentum": 0.9},
+            world_size=world_size,
+            broadcast_buffers=False,
         )
         ddp.train()
         optimizer = ddp.optimizer
fairscale/nn/data_parallel/sharded_ddp.py

@@ -25,7 +25,7 @@ class ShardedDataParallel(nn.Module):
     """Implements distributed data parallel training with optimizer state sharding.

     A simplified version of :class:`torch.nn.parallel.DistributedDataParallel`.
-    This version uses a c10d process group for communication and does not
+    This version uses a c10d process group for communication and optionally
     broadcast buffers.

     Args:

@@ -33,6 +33,8 @@ class ShardedDataParallel(nn.Module):
         optimizer (~torch.optim.Optimizer): optimizer to be used for training
         optimizer_params(Dict): extra parameters for the optimizer
         world_size (int): number of parallel workers
+        broadcast_buffers (bool): flag that enables syncing (broadcasting) buffers of
+            the module at beginning of the forward function. (default: ``True``)
         process_group (optional): the c10d process group to be used for
             distributed gradient reduction. If None, the default WORLD process group
             will be used.

@@ -47,6 +49,7 @@ class ShardedDataParallel(nn.Module):
         optimizer: Type[torch.optim.Optimizer],
         optimizer_params: Dict[str, Any],
         world_size: int,
+        broadcast_buffers: bool,
         process_group: Any = None,
         buffer_size: int = 2 ** 28,
     ):

@@ -56,6 +59,8 @@ class ShardedDataParallel(nn.Module):
         self.world_size = world_size
         self.process_group = process_group if process_group is not None else dist.group.WORLD
         self.rank = dist.get_rank(self.process_group)
+        self.broadcast_buffers = broadcast_buffers
+        self.authoritative_rank = 0

         # Never use a bigger buffer than the number of model params
         self.buffer_size = min(buffer_size, sum(p.numel() for p in self.module.parameters()))

@@ -71,7 +76,7 @@ class ShardedDataParallel(nn.Module):
         # Build the sharded optimizer
         self.sharded_optimizer = OSS(self.module.parameters(), optim=optimizer, group=process_group, **optimizer_params)

-        # sanity checks
+        # Sanity checks
         assert len(self.sharded_optimizer.param_to_rank) == len(
             list(self.module.parameters())
         ), "number of params do not match"

@@ -109,6 +114,9 @@ class ShardedDataParallel(nn.Module):
             raise RuntimeError("OssDdp requires explicit reduction, must call OssDdp.reduce")
         if not self.accumulate_grads:
             self.need_reduction = True
+
+        if self.broadcast_buffers and len(list(self.module.buffers())) > 0:
+            self._sync_buffers()

         return self.module(*inputs, **kwargs)

     def reduce(self) -> None:
@@ -118,52 +126,35 @@ class ShardedDataParallel(nn.Module):
         """
         assert self.module.training, "Cannot call reduce in eval"

-        def reduce_params(params: List[Parameter], params_rank: int) -> None:
-            """ Helper to reduce a list of params that should fix in the buffer. """
+        def reduce_grads(params: List[Parameter], params_rank: int) -> None:
+            """ Helper to reduce a list of params that should fit in the buffer.
+            NOTE: All param gradients are assumed to exist"""
             assert self.buffer is not None
+
+            # Fill in the packed IO buffer
             buffer: Tensor = cast(Tensor, self.buffer)
-            nonzero_buffer = False
             if len(params) > 1:
                 offset = 0
                 for p in params:
                     sz = p.numel()
-                    if p.grad is not None:
-                        # The type error could have been fixed in later
-                        # version of pytorch. Same elsewhere.
-                        buffer[offset : offset + sz].copy_(p.grad.data.view(-1))  # type: ignore
-                        nonzero_buffer = True
-                    else:
-                        buffer[offset : offset + sz].zero_()
+                    buffer[offset : offset + sz].copy_(p.grad.data.view(-1))  # type: ignore
                     offset += sz
             else:
                 # we only have a single grad to reduce
-                p = params[0]
-                if p.grad is not None:
-                    buffer = p.grad.data
-                    nonzero_buffer = True
-                elif p.numel() <= self.buffer.numel():
-                    buffer = buffer[: p.numel()]
-                    buffer.zero_()
-                else:
-                    buffer = torch.zeros_like(p)
-
-            if nonzero_buffer:
-                buffer.div_(self.world_size)  # type: ignore
-
-            dist.reduce(buffer, params_rank, group=self.process_group)  # type: ignore
+                buffer = params[0].grad.data  # type: ignore

+            # Reduce
+            buffer.div_(self.world_size)  # type: ignore
+            dist.reduce(tensor=buffer, dst=params_rank, group=self.process_group)  # type: ignore
+
+            # Copy reduced grads back into their original place, or free corresponding memory
             if params_rank == self.rank:
-                # copy reduced grads back into their original place
                 offset = 0
                 for p in params:
                     sz = p.numel()
-                    if p.grad is not None:
-                        p.grad.data.copy_(buffer[offset : offset + sz].view_as(p))  # type: ignore
-                    else:
-                        p.grad = buffer[offset : offset + sz].view_as(p).clone()
+                    p.grad.data.copy_(buffer[offset : offset + sz].view_as(p))  # type: ignore
                     offset += sz
             else:
                 # wipe the grads
                 for p in params:
                     p.grad = None
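The reworked helper above (reduce_params renamed to reduce_grads) packs a group of gradients into one flat buffer, averages and reduces it with a single collective, then unpacks it on the owning rank. The snippet below is a standalone sketch of that pack / reduce / unpack pattern, assuming an already-initialized process group and placeholder tensors; it is not the library code itself.

# Standalone sketch of the pack -> dist.reduce -> unpack pattern used by reduce_grads.
# Assumption: torch.distributed is already initialized; `grads` are placeholder tensors.
from typing import List

import torch
import torch.distributed as dist


def reduce_packed(grads: List[torch.Tensor], dst_rank: int, world_size: int) -> None:
    # Pack every gradient into one flat buffer so a single reduce call is issued.
    flat = torch.cat([g.view(-1) for g in grads])

    # Divide locally first so the SUM reduction yields an average, then reduce
    # onto the rank that owns (shards) these parameters.
    flat.div_(world_size)
    dist.reduce(tensor=flat, dst=dst_rank)

    if dist.get_rank() == dst_rank:
        # The owning rank unpacks the reduced values back into each gradient tensor.
        offset = 0
        for g in grads:
            sz = g.numel()
            g.copy_(flat[offset : offset + sz].view_as(g))
            offset += sz
    else:
        # Non-owning ranks no longer need these gradients; the library frees them
        # by setting p.grad = None, while this sketch simply zeros them.
        for g in grads:
            g.zero_()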
@@ -195,7 +186,7 @@ class ShardedDataParallel(nn.Module):
                 if sz > self.buffer.numel():
                     # reduce big params directly
                     assert param_rank is not None
-                    reduce_params([param], cast(int, param_rank))
+                    reduce_grads([param], cast(int, param_rank))
                 else:
                     # smaller params are packed together from the same device
                     # and same rank.

@@ -203,7 +194,7 @@ class ShardedDataParallel(nn.Module):
                         last_param_rank is not None and last_param_rank != param_rank
                     ):
                         assert last_param_rank is not None
-                        reduce_params(buffered_params, cast(int, last_param_rank))
+                        reduce_grads(buffered_params, cast(int, last_param_rank))
                         offset = 0
                         buffered_params.clear()
                     buffered_params.append(cast(Parameter, param))

@@ -211,6 +202,21 @@ class ShardedDataParallel(nn.Module):
             if len(buffered_params) > 0:
                 assert param_rank is not None
-                reduce_params(buffered_params, cast(int, param_rank))
+                reduce_grads(buffered_params, cast(int, param_rank))

         reduction_fn()

+    def _sync_buffers(self) -> None:
+        """
+        Sync all the param buffers in between ranks.
+        TODO: Could be worth bucketing ?
+        """
+        _ = list(
+            map(
+                lambda x: x.wait(),
+                map(
+                    lambda x: dist.broadcast(x, self.authoritative_rank, self.process_group, async_op=True),
+                    self.module.buffers(),
+                ),
+            )
+        )
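The new _sync_buffers above fires one asynchronous broadcast per module buffer from the authoritative rank and then waits on all of the resulting handles. Here is an equivalent, loop-based sketch of that pattern, written against the default process group as an assumption:

# Loop-based sketch of the async broadcast-and-wait pattern behind _sync_buffers.
# Assumption: the default process group is initialized and rank 0 is authoritative.
import torch.distributed as dist
import torch.nn as nn


def sync_buffers(module: nn.Module, authoritative_rank: int = 0) -> None:
    # Launch one non-blocking broadcast per buffer so the transfers can overlap...
    handles = [
        dist.broadcast(buf, src=authoritative_rank, async_op=True)
        for buf in module.buffers()
    ]
    # ...then block until every buffer holds the authoritative rank's values.
    for handle in handles:
        handle.wait()

Bucketing the buffers into fewer, larger broadcasts (the TODO in the docstring) would trade a bit of extra copying for fewer collectives.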
tests/nn/data_parallel/test_sharded_ddp.py

@@ -37,12 +37,20 @@ def run_one_step(rank, world_size, backend, device, temp_file_name):
     if device == torch.device("cuda"):
         torch.cuda.set_device(rank)

+    # Any model works. Add one different buffer per rank
     model = Sequential(Linear(2, 3), Linear(3, 4)).to(device)
+    model.register_buffer("test_buffer", torch.ones((1)) * rank)
+    model.to(device)
+
     ddp = ShardedDataParallel(
-        module=model, optimizer=torch.optim.SGD, optimizer_params={"lr": 0.1, "momentum": 0.99}, world_size=world_size
+        module=model,
+        optimizer=torch.optim.SGD,
+        optimizer_params={"lr": 0.01, "momentum": 0.99},
+        world_size=world_size,
+        broadcast_buffers=True,
     )
     optimizer = ddp.optimizer
     model = ddp.module

     input_tensor = torch.rand((64, 2)).to(device)
     output = ddp(input_tensor).abs().sum() / input_tensor.numel()

@@ -58,10 +66,9 @@ def run_one_step(rank, world_size, backend, device, temp_file_name):
         if param.requires_grad:
             assert param.grad.abs().sum().item() > 0.0, "The reduce step should have populated all the gradients"

     # Check that the optimization process makes sense (ie. loss goes down for the same data)
     optimizer.step()
     new_eval = ddp(input_tensor).abs().sum() / input_tensor.numel()
     # assert new_eval.item() < output.item()

+    # Check that all the buffers are in sync (authoritative rank is 0, its buffer is 0)
+    for b in model.buffers():
+        assert b.cpu().item() == 0.0


 def run_test(backend, device, world_size=2):

@@ -76,7 +83,7 @@ def run_eval_mode(_unused):
     )

     model = Sequential(Linear(2, 3), Linear(3, 4))
     optimizer_params = {"lr": 0.1, "momentum": 0.99}
-    ddp = ShardedDataParallel(model, torch.optim.SGD, optimizer_params, 1)
+    ddp = ShardedDataParallel(model, torch.optim.SGD, optimizer_params, 1, broadcast_buffers=False)
     optimizer = ddp.optimizer
     ddp.eval()
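For completeness, a per-rank test such as run_one_step above is typically driven by torch.multiprocessing.spawn with one process per rank; the launch sketch below uses placeholder values (world size, backend, device) and is not the repository's actual run_test helper.

# Launch sketch (assumed values, not the repository's run_test helper): spawn one
# process per rank and let each call run_one_step with its rank as the first argument.
import tempfile

import torch
import torch.multiprocessing as mp

if __name__ == "__main__":
    world_size = 2
    backend = "gloo"
    device = torch.device("cpu")
    temp_file_name = tempfile.mkstemp()[1]  # rendezvous file for a file:// init_method

    mp.spawn(
        run_one_step,  # the test function defined above in this file
        args=(world_size, backend, device, temp_file_name),
        nprocs=world_size,
        join=True,
    )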