OpenDAS / fairscale · Commits
Commit 79ded821 (unverified)
Authored Sep 28, 2020 by Benjamin Lefaudeux, committed by GitHub on Sep 28, 2020
Parent: 41819af9

[ShardedDDP] Sync buffers + small cleanup (#112)

- adding the buffer broadcast option
- minor cleanup in shardedDDP
Showing 3 changed files with 59 additions and 42 deletions (+59 -42):

- benchmarks/oss.py (+6 -2)
- fairscale/nn/data_parallel/sharded_ddp.py (+40 -34)
- tests/nn/data_parallel/test_sharded_ddp.py (+13 -6)
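For context, a minimal usage sketch of the new broadcast_buffers option, mirroring the constructor calls in the diffs below. This is not part of the commit: it assumes an already-initialized torch.distributed process group whose size matches world_size, the import path is inferred from the repository layout, and the learning-rate numbers are illustrative.

# Minimal usage sketch (not from the commit). Assumes torch.distributed is already
# initialized with 2 ranks; import path inferred from the repository layout.
import torch
from torch.nn import Linear, Sequential

from fairscale.nn.data_parallel import ShardedDataParallel

model = Sequential(Linear(2, 3), Linear(3, 4))
ddp = ShardedDataParallel(
    module=model,
    optimizer=torch.optim.SGD,
    optimizer_params={"lr": 0.01, "momentum": 0.99},
    world_size=2,
    broadcast_buffers=True,  # sync module buffers from the authoritative rank in forward()
)
optimizer = ddp.optimizer    # the sharded (OSS) optimizer built by the wrapper

loss = ddp(torch.rand((64, 2))).abs().sum()
loss.backward()
ddp.reduce()                 # ShardedDataParallel requires an explicit gradient reduction
optimizer.step()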
benchmarks/oss.py
@@ -62,7 +62,7 @@ def train(
 ):
     assert not use_sdp or (use_sdp and use_oss), "ShardedDataParallel requires OSS"
     # DDP
-    dist_init(rank, world_size, backend)
+    dist_init(rank=rank, world_size=world_size, backend=backend)

     # Setup
     torch.cuda.set_device(rank)

@@ -81,7 +81,11 @@ def train(
     if use_sdp:
         ddp = ShardedDataParallel(
-            module=model, optimizer=OPTIM, optimizer_params={"lr": 1e-4, "momentum": 0.9}, world_size=world_size,
+            module=model,
+            optimizer=OPTIM,
+            optimizer_params={"lr": 1e-4, "momentum": 0.9},
+            world_size=world_size,
+            broadcast_buffers=False,
         )
         ddp.train()
         optimizer = ddp.optimizer
fairscale/nn/data_parallel/sharded_ddp.py
@@ -25,7 +25,7 @@ class ShardedDataParallel(nn.Module):
     """Implements distributed data parallel training with optimizer state sharding.

     A simplified version of :class:`torch.nn.parallel.DistributedDataParallel`.
-    This version uses a c10d process group for communication and does not
+    This version uses a c10d process group for communication and optionally
     broadcast buffers.

     Args:
@@ -33,6 +33,8 @@ class ShardedDataParallel(nn.Module):
         optimizer (~torch.optim.Optimizer): optimizer to be used for training
         optimizer_params(Dict): extra parameters for the optimizer
         world_size (int): number of parallel workers
+        broadcast_buffers (bool): flag that enables syncing (broadcasting) buffers of
+            the module at beginning of the forward function. (default: ``True``)
         process_group (optional): the c10d process group to be used for
             distributed gradient reduction. If None, the default WORLD process group
             will be used.
@@ -47,6 +49,7 @@ class ShardedDataParallel(nn.Module):
         optimizer: Type[torch.optim.Optimizer],
         optimizer_params: Dict[str, Any],
         world_size: int,
+        broadcast_buffers: bool,
         process_group: Any = None,
         buffer_size: int = 2 ** 28,
     ):
@@ -56,6 +59,8 @@ class ShardedDataParallel(nn.Module):
         self.world_size = world_size
         self.process_group = process_group if process_group is not None else dist.group.WORLD
         self.rank = dist.get_rank(self.process_group)
+        self.broadcast_buffers = broadcast_buffers
+        self.authoritative_rank = 0

         # Never use a bigger buffer than the number of model params
         self.buffer_size = min(buffer_size, sum(p.numel() for p in self.module.parameters()))
@@ -71,7 +76,7 @@ class ShardedDataParallel(nn.Module):
         # Build the sharded optimizer
         self.sharded_optimizer = OSS(self.module.parameters(), optim=optimizer, group=process_group, **optimizer_params)

-        # sanity checks
+        # Sanity checks
         assert len(self.sharded_optimizer.param_to_rank) == len(
             list(self.module.parameters())
         ), "number of params do not match"
@@ -109,6 +114,9 @@ class ShardedDataParallel(nn.Module):
             raise RuntimeError("OssDdp requires explicit reduction, must call OssDdp.reduce")
         if not self.accumulate_grads:
             self.need_reduction = True

+        if self.broadcast_buffers and len(list(self.module.buffers())) > 0:
+            self._sync_buffers()
+
         return self.module(*inputs, **kwargs)

     def reduce(self) -> None:
@@ -118,52 +126,35 @@ class ShardedDataParallel(nn.Module):
         """
         assert self.module.training, "Cannot call reduce in eval"

-        def reduce_params(params: List[Parameter], params_rank: int) -> None:
-            """ Helper to reduce a list of params that should fix in the buffer. """
+        def reduce_grads(params: List[Parameter], params_rank: int) -> None:
+            """ Helper to reduce a list of params that should fit in the buffer.
+            NOTE: All param gradients are assumed to exist"""
             assert self.buffer is not None
+
+            # Fill in the packed IO buffer
             buffer: Tensor = cast(Tensor, self.buffer)
-            nonzero_buffer = False
             if len(params) > 1:
                 offset = 0
                 for p in params:
                     sz = p.numel()
-                    if p.grad is not None:
-                        buffer[offset : offset + sz].copy_(p.grad.data.view(-1))  # type: ignore
-                        nonzero_buffer = True
-                    else:
-                        buffer[offset : offset + sz].zero_()
+                    buffer[offset : offset + sz].copy_(p.grad.data.view(-1))  # type: ignore
+                    # The type error could have been fixed in later
+                    # version of pytorch. Same elsewhere.
                     offset += sz
             else:
                 # we only have a single grad to reduce
-                p = params[0]
-                if p.grad is not None:
-                    buffer = p.grad.data
-                    nonzero_buffer = True
-                elif p.numel() <= self.buffer.numel():
-                    buffer = buffer[: p.numel()]
-                    buffer.zero_()
-                else:
-                    buffer = torch.zeros_like(p)
-
-            if nonzero_buffer:
-                buffer.div_(self.world_size)  # type: ignore
-
-            dist.reduce(buffer, params_rank, group=self.process_group)  # type: ignore
-
+                buffer = params[0].grad.data  # type: ignore
+
+            # Reduce
+            buffer.div_(self.world_size)  # type: ignore
+            dist.reduce(tensor=buffer, dst=params_rank, group=self.process_group)  # type: ignore

+            # Copy reduced grads back into their original place, or free corresponding memory
             if params_rank == self.rank:
-                # copy reduced grads back into their original place
                 offset = 0
                 for p in params:
                     sz = p.numel()
-                    if p.grad is not None:
-                        p.grad.data.copy_(buffer[offset : offset + sz].view_as(p))  # type: ignore
-                    else:
-                        p.grad = buffer[offset : offset + sz].view_as(p).clone()
+                    p.grad.data.copy_(buffer[offset : offset + sz].view_as(p))  # type: ignore
                     offset += sz
             else:
                 # wipe the grads
                 for p in params:
                     p.grad = None
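The rewritten helper packs a group of gradients into one flat buffer, performs a single reduce to the rank that owns the corresponding optimizer shard, then unpacks the result there and frees the gradients elsewhere. A standalone sketch of that technique is shown below; it is not the fairscale implementation, and it assumes an already-initialized torch.distributed process group.

# Illustrative sketch of flat-buffer gradient reduction (not the fairscale code;
# assumes torch.distributed has been initialized, e.g. via init_process_group).
from typing import List

import torch
import torch.distributed as dist
from torch.nn import Parameter


def reduce_grads_sketch(params: List[Parameter], dst_rank: int, world_size: int) -> None:
    # Pack all gradients into one flat buffer so only a single reduce call is needed.
    flat = torch.cat([p.grad.data.view(-1) for p in params])
    flat.div_(world_size)  # pre-divide so the reduced sum is an average
    dist.reduce(tensor=flat, dst=dst_rank)

    if dist.get_rank() == dst_rank:
        # This rank owns the optimizer shard for these params: unpack the averaged grads.
        offset = 0
        for p in params:
            sz = p.numel()
            p.grad.data.copy_(flat[offset : offset + sz].view_as(p))
            offset += sz
    else:
        # Other ranks will never step these params, so their grads can be freed.
        for p in params:
            p.grad = None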
@@ -195,7 +186,7 @@ class ShardedDataParallel(nn.Module):
             if sz > self.buffer.numel():
                 # reduce big params directly
                 assert param_rank is not None
-                reduce_params([param], cast(int, param_rank))
+                reduce_grads([param], cast(int, param_rank))
             else:
                 # smaller params are packed together from the same device
                 # and same rank.
@@ -203,7 +194,7 @@ class ShardedDataParallel(nn.Module):
                     last_param_rank is not None and last_param_rank != param_rank
                 ):
                     assert last_param_rank is not None
-                    reduce_params(buffered_params, cast(int, last_param_rank))
+                    reduce_grads(buffered_params, cast(int, last_param_rank))
                     offset = 0
                     buffered_params.clear()
                 buffered_params.append(cast(Parameter, param))
@@ -211,6 +202,21 @@ class ShardedDataParallel(nn.Module):
         if len(buffered_params) > 0:
             assert param_rank is not None
-            reduce_params(buffered_params, cast(int, param_rank))
+            reduce_grads(buffered_params, cast(int, param_rank))

         reduction_fn()

+    def _sync_buffers(self) -> None:
+        """
+        Sync all the param buffers in between ranks.
+        TODO: Could be worth bucketing ?
+        """
+        _ = list(
+            map(
+                lambda x: x.wait(),
+                map(
+                    lambda x: dist.broadcast(x, self.authoritative_rank, self.process_group, async_op=True),
+                    self.module.buffers(),
+                ),
+            )
+        )
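The new _sync_buffers launches one asynchronous broadcast per module buffer and then waits on all of them, so the transfers can overlap instead of running strictly one after another. A standalone sketch of the same pattern, assuming an initialized default process group (not the fairscale code itself):

# Standalone sketch of the async broadcast-then-wait pattern used by _sync_buffers
# (assumes an initialized torch.distributed process group; illustrative only).
import torch.distributed as dist
from torch import nn


def sync_buffers_sketch(module: nn.Module, authoritative_rank: int = 0) -> None:
    # Kick off one asynchronous broadcast per buffer, then wait for all of them.
    work_handles = [
        dist.broadcast(buf, src=authoritative_rank, async_op=True)
        for buf in module.buffers()
    ]
    for handle in work_handles:
        handle.wait()

A list comprehension is used here instead of the nested map calls purely for readability; the overlap behaviour is the same.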
tests/nn/data_parallel/test_sharded_ddp.py
@@ -37,12 +37,20 @@ def run_one_step(rank, world_size, backend, device, temp_file_name):
     if device == torch.device("cuda"):
         torch.cuda.set_device(rank)

+    # Any model works. Add one different buffer per rank
     model = Sequential(Linear(2, 3), Linear(3, 4)).to(device)
+    model.register_buffer("test_buffer", torch.ones((1)) * rank)
+    model.to(device)
+
     ddp = ShardedDataParallel(
-        module=model, optimizer=torch.optim.SGD, optimizer_params={"lr": 0.1, "momentum": 0.99}, world_size=world_size
+        module=model,
+        optimizer=torch.optim.SGD,
+        optimizer_params={"lr": 0.01, "momentum": 0.99},
+        world_size=world_size,
+        broadcast_buffers=True,
     )
     optimizer = ddp.optimizer
+    model = ddp.module

     input_tensor = torch.rand((64, 2)).to(device)
     output = ddp(input_tensor).abs().sum() / input_tensor.numel()

@@ -58,10 +66,9 @@ def run_one_step(rank, world_size, backend, device, temp_file_name):
         if param.requires_grad:
             assert param.grad.abs().sum().item() > 0.0, "The reduce step should have populated all the gradients"

-    # Check that the optimization process makes sense (ie. loss goes down for the same data)
-    optimizer.step()
-    new_eval = ddp(input_tensor).abs().sum() / input_tensor.numel()
-    # assert new_eval.item() < output.item()
+    # Check that all the buffers are in sync (authoritative rank is 0, its buffer is 0)
+    for b in model.buffers():
+        assert b.cpu().item() == 0.0


 def run_test(backend, device, world_size=2):

@@ -76,7 +83,7 @@ def run_eval_mode(_unused):
     )
     model = Sequential(Linear(2, 3), Linear(3, 4))
     optimizer_params = {"lr": 0.1, "momentum": 0.99}
-    ddp = ShardedDataParallel(model, torch.optim.SGD, optimizer_params, 1)
+    ddp = ShardedDataParallel(model, torch.optim.SGD, optimizer_params, 1, broadcast_buffers=False)
     optimizer = ddp.optimizer
     ddp.eval()
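The run_test harness that spawns run_one_step across ranks is not part of this diff. For orientation, a typical launcher for such a per-rank test body might look roughly like the sketch below; the backend, port, and helper names here are assumptions, not part of the commit.

# Hypothetical launcher sketch for a per-rank test body such as run_one_step above
# (the actual run_test implementation is not shown in this diff; port/backend assumed).
import torch.distributed as dist
import torch.multiprocessing as mp


def _worker(rank: int, world_size: int, backend: str) -> None:
    dist.init_process_group(
        backend=backend, init_method="tcp://localhost:29501", rank=rank, world_size=world_size
    )
    # ... call the per-rank test body here, e.g. run_one_step(rank, world_size, ...)
    dist.destroy_process_group()


if __name__ == "__main__":
    world_size = 2
    mp.spawn(_worker, args=(world_size, "gloo"), nprocs=world_size, join=True)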