Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
fairscale
Commits
e141a93e
Unverified
Commit
e141a93e
authored
Mar 31, 2021
by
Siddharth Goyal
Committed by
GitHub
Mar 31, 2021
Browse files
[feat] experimental: Add xpipe support (#553)
parent
204392e5
Changes
3
Show whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
261 additions
and
15 deletions
+261
-15
benchmarks/experimental/experimental_async_approaches.py
benchmarks/experimental/experimental_async_approaches.py
+235
-4
fairscale/experimental/nn/ampnet_pipe/ampnet.py
fairscale/experimental/nn/ampnet_pipe/ampnet.py
+25
-11
fairscale/experimental/nn/ampnet_pipe/pipe.py
fairscale/experimental/nn/ampnet_pipe/pipe.py
+1
-0
No files found.
benchmarks/experimental/experimental_a
mpnet
.py
→
benchmarks/experimental/experimental_a
sync_approaches
.py
View file @
e141a93e
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
#
# This source code is licensed under the BSD license found in the
# LICENSE file in the root directory of this source tree.
import
argparse
import
logging
...
...
@@ -203,7 +206,8 @@ class SpectrainSGDMomentum(Optimizer):
def
modify_current_params_using_reference_params
(
self
):
self
.
copy_params
(
self
.
reference_params
,
self
.
cur_params
)
def
update_weight_using_future_predictions
(
self
,
model_index
,
num_gpus
,
forward
):
# chunk_index and chunks parameters are for unused for spectrain usecase
def
update_weight_using_future_predictions
(
self
,
model_index
,
num_gpus
,
chunk_index
,
chunks
,
forward
):
if
forward
:
# In forward pass:
...
...
@@ -260,6 +264,226 @@ class SpectrainSGDMomentum(Optimizer):
return
loss
class
XpipeAdam
(
Optimizer
):
r
"""Implements Xpipe approach on top of Adam algorithm.
It has been proposed in `Adam: A Method for Stochastic Optimization`_.
The implementation of the L2 penalty follows changes proposed in
`Decoupled Weight Decay Regularization`_.
Xpipe details can be found here: https://arxiv.org/abs/1911.04610
Arguments:
params (iterable): iterable of parameters to optimize or dicts defining
parameter groups
lr (float, optional): learning rate (default: 1e-3)
betas (Tuple[float, float], optional): coefficients used for computing
running averages of gradient and its square (default: (0.9, 0.999))
eps (float, optional): term added to the denominator to improve
numerical stability (default: 1e-8)
weight_decay (float, optional): weight decay (L2 penalty) (default: 0)
amsgrad (boolean, optional): whether to use the AMSGrad variant of this
algorithm from the paper `On the Convergence of Adam and Beyond`_
(default: False)
.. _Adam\: A Method for Stochastic Optimization:
https://arxiv.org/abs/1412.6980
.. _Decoupled Weight Decay Regularization:
https://arxiv.org/abs/1711.05101
.. _On the Convergence of Adam and Beyond:
https://openreview.net/forum?id=ryQu7f-RZ
"""
def
__init__
(
self
,
params
,
lr
=
1e-3
,
betas
=
(
0.9
,
0.999
),
eps
=
1e-8
,
weight_decay
=
0
,
amsgrad
=
False
):
if
not
0.0
<=
lr
:
raise
ValueError
(
"Invalid learning rate: {}"
.
format
(
lr
))
if
not
0.0
<=
eps
:
raise
ValueError
(
"Invalid epsilon value: {}"
.
format
(
eps
))
if
not
0.0
<=
betas
[
0
]
<
1.0
:
raise
ValueError
(
"Invalid beta parameter at index 0: {}"
.
format
(
betas
[
0
]))
if
not
0.0
<=
betas
[
1
]
<
1.0
:
raise
ValueError
(
"Invalid beta parameter at index 1: {}"
.
format
(
betas
[
1
]))
if
not
0.0
<=
weight_decay
:
raise
ValueError
(
"Invalid weight_decay value: {}"
.
format
(
weight_decay
))
defaults
=
dict
(
lr
=
lr
,
betas
=
betas
,
eps
=
eps
,
weight_decay
=
weight_decay
,
amsgrad
=
amsgrad
)
params
=
list
(
params
)
super
(
XpipeAdam
,
self
).
__init__
(
params
,
defaults
)
self
.
cur_params
,
self
.
master_params
=
self
.
prep_param_copies
(
params
)
_
,
self
.
forward_params
=
self
.
prep_param_copies
(
params
)
_
,
self
.
backward_params
=
self
.
prep_param_copies
(
params
)
for
group
in
self
.
param_groups
:
for
p
in
group
[
"params"
]:
param_state
=
self
.
state
[
p
]
param_state
[
"step"
]
=
0
# Exponential moving average of gradient values
param_state
[
"exp_avg"
]
=
torch
.
zeros_like
(
p
.
data
)
# Exponential moving average of squared gradient values
param_state
[
"exp_avg_sq"
]
=
torch
.
zeros_like
(
p
.
data
)
def
__setstate__
(
self
,
state
):
super
(
Adam
,
self
).
__setstate__
(
state
)
for
group
in
self
.
param_groups
:
group
.
setdefault
(
"amsgrad"
,
False
)
def
prep_param_copies
(
self
,
params
):
model_params
=
[
param
for
param
in
params
if
param
.
requires_grad
]
reference_params
=
[
param
.
clone
().
detach
()
for
param
in
model_params
]
for
param
in
reference_params
:
param
.
requires_grad
=
True
return
model_params
,
reference_params
def
copy_params
(
self
,
master_params
,
model_params
):
for
model
,
master
in
zip
(
model_params
,
master_params
):
model
.
data
.
copy_
(
master
.
data
)
def
update_weight_using_future_predictions
(
self
,
model_index
,
num_gpus
,
current_microbatch_index
,
microbatches_per_minibatch
,
forward
):
if
forward
:
# Forward pass overview:
# if bell-weather:
# 1. read from master copy
# 2. predict and modify
# 3. flush updates to forward copy
# else:
# 1. read from forward copy
if
current_microbatch_index
%
microbatches_per_minibatch
==
0
:
# read from master copy
self
.
copy_params
(
self
.
master_params
,
self
.
cur_params
)
microbatch_index
=
current_microbatch_index
+
1
# predict and modify
for
group
in
self
.
param_groups
:
multiplier
=
group
[
"lr"
]
*
round
(
(
microbatch_index
+
num_gpus
-
model_index
/
2
-
2
)
/
microbatch_index
)
beta1
,
beta2
=
group
[
"betas"
]
eps
=
group
[
"eps"
]
for
p
in
group
[
"params"
]:
param_state
=
self
.
state
[
p
]
temp1
=
param_state
[
"exp_avg"
].
data
/
(
1
-
beta1
)
temp2
=
((
param_state
[
"exp_avg_sq"
].
data
/
(
1
-
beta2
))
+
eps
).
sqrt
()
p
.
data
.
addcdiv_
(
temp1
,
temp2
,
value
=-
multiplier
)
# flush updates to forward copy
self
.
copy_params
(
self
.
cur_params
,
self
.
forward_params
)
else
:
self
.
copy_params
(
self
.
forward_params
,
self
.
cur_params
)
else
:
# Backward pass overview:
# if bell-weather:
# 1. read from master copy
# 2. predict and modify
# 3. flush updates to backward copy
# else:
# 1. read from backward copy
if
current_microbatch_index
%
microbatches_per_minibatch
==
0
:
# read from master copy
self
.
copy_params
(
self
.
master_params
,
self
.
cur_params
)
microbatch_index
=
current_microbatch_index
+
1
# predict and modify
for
group
in
self
.
param_groups
:
multiplier
=
group
[
"lr"
]
*
(
microbatch_index
+
model_index
//
2
-
1
)
//
microbatch_index
beta1
,
beta2
=
group
[
"betas"
]
eps
=
group
[
"eps"
]
for
p
in
group
[
"params"
]:
param_state
=
self
.
state
[
p
]
temp1
=
param_state
[
"exp_avg"
].
data
/
(
1
-
beta1
)
temp2
=
((
param_state
[
"exp_avg_sq"
].
data
/
(
1
-
beta2
))
+
eps
).
sqrt
()
p
.
data
.
addcdiv_
(
temp1
,
temp2
,
value
=-
multiplier
)
# flush updates to forward copy
self
.
copy_params
(
self
.
cur_params
,
self
.
backward_params
)
else
:
self
.
copy_params
(
self
.
backward_params
,
self
.
cur_params
)
@
torch
.
no_grad
()
def
step
(
self
,
closure
=
None
):
"""Performs a single optimization step.
Arguments:
closure (callable, optional): A closure that reevaluates the model
and returns the loss.
"""
loss
=
None
if
closure
is
not
None
:
with
torch
.
enable_grad
():
loss
=
closure
()
for
group
in
self
.
param_groups
:
for
p
in
group
[
"params"
]:
if
p
.
grad
is
None
:
continue
grad
=
p
.
grad
.
data
amsgrad
=
group
.
get
(
"amsgrad"
,
False
)
p_data
=
p
.
data
state
=
self
.
state
[
p
]
# State initialization
if
len
(
state
)
==
0
:
state
[
"step"
]
=
0
# Exponential moving average of gradient values
state
[
"exp_avg"
]
=
torch
.
zeros_like
(
p_data
)
# Exponential moving average of squared gradient values
state
[
"exp_avg_sq"
]
=
torch
.
zeros_like
(
p_data
)
if
amsgrad
:
# Maintains max of all exp. moving avg. of sq. grad. values
state
[
"max_exp_avg_sq"
]
=
torch
.
zeros_like
(
p_data
)
else
:
state
[
"exp_avg"
]
=
state
[
"exp_avg"
].
to
(
p_data
)
state
[
"exp_avg_sq"
]
=
state
[
"exp_avg_sq"
].
to
(
p_data
)
if
amsgrad
:
state
[
"max_exp_avg_sq"
]
=
state
[
"max_exp_avg_sq"
].
to
(
p_data
)
exp_avg
,
exp_avg_sq
=
state
[
"exp_avg"
],
state
[
"exp_avg_sq"
]
if
amsgrad
:
max_exp_avg_sq
=
state
[
"max_exp_avg_sq"
]
beta1
,
beta2
=
group
[
"betas"
]
state
[
"step"
]
+=
1
exp_avg_data
=
exp_avg
.
data
exp_avg_sq_data
=
exp_avg_sq
.
data
# Decay the first and second moment running average coefficient
exp_avg_data
.
mul_
(
beta1
).
add_
(
grad
,
alpha
=
1
-
beta1
)
exp_avg_sq_data
.
mul_
(
beta2
).
addcmul_
(
grad
,
grad
,
value
=
1
-
beta2
)
if
amsgrad
:
# Maintains the maximum of all 2nd moment running avg. till now
torch
.
max
(
max_exp_avg_sq
,
exp_avg_sq_data
,
out
=
max_exp_avg_sq_data
)
# Use the max. for normalizing running avg. of gradient
denom
=
max_exp_avg_sq
.
sqrt
().
add_
(
group
[
"eps"
])
else
:
denom
=
exp_avg_sq_data
.
sqrt
().
add_
(
group
[
"eps"
])
bias_correction1
=
1
-
beta1
**
state
[
"step"
]
bias_correction2
=
1
-
beta2
**
state
[
"step"
]
step_size
=
group
[
"lr"
]
*
math
.
sqrt
(
bias_correction2
)
/
bias_correction1
if
group
[
"weight_decay"
]
!=
0
:
p_data
.
add_
(
p_data
,
alpha
=-
group
[
"weight_decay"
]
*
group
[
"lr"
])
p_data
.
addcdiv_
(
exp_avg_data
,
denom
,
value
=-
step_size
)
return
loss
def
get_data
(
device
):
with
warnings
.
catch_warnings
(
record
=
True
)
as
fjldska
:
TEXT
=
torchtext
.
data
.
Field
(
...
...
@@ -321,7 +545,9 @@ def make_model(args, device, ntokens):
return
Adam
(
model
.
parameters
(),
lr
=
lr
)
def
make_custom_optimizer
(
model
,
args
):
if
args
.
spectrain
:
if
args
.
xpipe
:
return
XpipeAdam
(
model
.
parameters
(),
lr
=
lr
)
elif
args
.
spectrain
:
return
SpectrainSGDMomentum
(
model
.
parameters
(),
lr
=
lr
)
else
:
return
MySGD
(
model
.
parameters
(),
lr
=
lr
)
...
...
@@ -398,7 +624,9 @@ def train(lm_dataloader, model, criterion, optimizer, vocab_size, args):
optimizer
=
optimizer
(
model
,
args
)
transform_and_log
=
AsyncDelegate
(
vocab_size
)
model
.
interleave
(
lm_dataloader
,
criterion
,
optimizer
,
transform_and_log
,
args
.
min_update_interval
,
args
.
spectrain
)
model
.
interleave
(
lm_dataloader
,
criterion
,
optimizer
,
transform_and_log
,
args
.
min_update_interval
,
args
.
spectrain
or
args
.
xpipe
)
if
model
.
group
.
rank
()
==
model
.
group
.
size
()
-
1
:
print
(
"Done with an epoch"
)
...
...
@@ -615,6 +843,7 @@ parser.add_argument("--max-batch", type=int, default=4, help="Max number of batc
parser
.
add_argument
(
"--socket-name"
,
type
=
str
,
default
=
None
,
help
=
"socket ifname for gloo/tp"
)
parser
.
add_argument
(
"--num-decoder-layers"
,
type
=
int
,
default
=
10
,
help
=
"Number of decoder layers in the model"
)
parser
.
add_argument
(
"--spectrain"
,
action
=
"store_true"
,
default
=
False
,
help
=
"Use spectrain based weight prediction"
)
parser
.
add_argument
(
"--xpipe"
,
action
=
"store_true"
,
default
=
False
,
help
=
"Use xpipe based weight prediction"
)
parser
.
add_argument
(
"--lazy-construction"
,
action
=
"store_true"
,
default
=
False
,
help
=
"Number of decoder layers in the model"
)
...
...
@@ -627,7 +856,9 @@ parser.add_argument("--min-update-interval", type=int, default=1, help="min upda
To run the script,
1. please build a suitable version of OpenMPI with a cuda-enabled UCX backend.
2. For running on 2 gpus:
<open-mpi-installed-dir>/bin/mpirun --host localhost:8 -np 2 --map-by node --mca pml ucx -x UCX_TLS=rc,sm,cuda_ipc,cuda_copy -x PYTHONPATH=$PWD -x PATH=$PATH -x LD_LIBRARY_PATH=$LD_LIBRARY_PATH -x UCX_RNDV_SCHEME=put_zcopy -x UCX_MEMTYPE_CACHE=n python3 benchmarks/experimental_ampnet.py --num-decoder-layers=8 --host localhost --batch-size 4
<open-mpi-installed-dir>/bin/mpirun --host localhost:8 -np 2 --map-by node --mca pml ucx -x UCX_TLS=rc,sm,cuda_ipc,cuda_copy -x PYTHONPATH=$PWD -x PATH=$PATH -x LD_LIBRARY_PATH=$LD_LIBRARY_PATH -x UCX_RNDV_SCHEME=put_zcopy -x UCX_MEMTYPE_CACHE=n python3 benchmarks/experimental/experimental_async_approaches.py --num-decoder-layers=8 --host localhost --batch-size 4
3. For doing Spectrain based weight prediction, add `--spectrain` to the training command line argument.
4. For doing Xpipe based weight prediction, add `--xpipe` to the training command line argument.
"""
if
__name__
==
"__main__"
:
...
...
fairscale/experimental/nn/ampnet_pipe/ampnet.py
View file @
e141a93e
...
...
@@ -73,6 +73,7 @@ class AsyncAMPnetEventLoop:
weight_prediction
:
bool
,
checkpoint_stop
:
int
,
input_device
:
Union
[
None
,
int
,
str
,
torch
.
device
],
chunks
:
int
,
):
self
.
partitions
=
partitions
self
.
group
=
group
...
...
@@ -81,9 +82,14 @@ class AsyncAMPnetEventLoop:
self
.
weight_prediction
=
weight_prediction
self
.
checkpoint_stop
=
checkpoint_stop
self
.
input_device
=
input_device
self
.
chunks
=
chunks
def
perform_optimizer_step
(
self
,
optimizer
:
Any
,
num_gradients
:
Any
)
->
Any
:
return
(
optimizer
is
not
None
)
and
((
num_gradients
%
self
.
min_update_interval
==
0
)
or
self
.
weight_prediction
)
return
(
(
optimizer
is
not
None
)
and
(
not
self
.
weight_prediction
and
num_gradients
%
self
.
min_update_interval
==
0
)
or
(
self
.
weight_prediction
and
num_gradients
%
self
.
chunks
==
0
)
)
def
async_send_inner
(
self
,
batch
:
Batch
,
index
:
int
)
->
Tuple
[
Batch
,
PipeMessage
]:
task
=
create_task_without_skip_trackers
(
...
...
@@ -160,7 +166,7 @@ class AsyncAMPnetEventLoop:
reqd_input
=
transform_logger_object
.
transform_input
(
cur_batch
).
to
(
self
.
input_device
)
batch
=
Batch
(
reqd_input
,
count
)
if
self
.
weight_prediction
:
optimizer
.
update_weight_using_future_predictions
(
cur_rank
,
N
,
forward
=
True
)
optimizer
.
update_weight_using_future_predictions
(
cur_rank
,
N
,
count
,
self
.
chunks
,
forward
=
True
)
activations
[
count
],
message
=
self
.
async_send_inner
(
batch
,
count
)
self
.
transport
.
send_message
(
message
,
sync
=
True
)
count
+=
1
...
...
@@ -177,7 +183,7 @@ class AsyncAMPnetEventLoop:
reqd_input
=
transform_logger_object
.
transform_input
(
cur_batch
).
to
(
self
.
input_device
)
batch
=
Batch
(
reqd_input
,
count
)
if
self
.
weight_prediction
:
optimizer
.
update_weight_using_future_predictions
(
cur_rank
,
N
,
forward
=
True
)
optimizer
.
update_weight_using_future_predictions
(
cur_rank
,
N
,
count
,
self
.
chunks
,
forward
=
True
)
activations
[
count
],
forward_message
=
self
.
async_send_inner
(
batch
,
count
)
count
+=
1
...
...
@@ -186,7 +192,9 @@ class AsyncAMPnetEventLoop:
args
:
AsyncMessageBody
=
message
.
args
assert
args
.
message_type
is
AsyncMessageType
.
Gradients
if
self
.
weight_prediction
:
optimizer
.
update_weight_using_future_predictions
(
cur_rank
,
N
,
forward
=
False
)
optimizer
.
update_weight_using_future_predictions
(
cur_rank
,
N
,
num_gradients
,
self
.
chunks
,
forward
=
False
)
self
.
async_grad_inner
(
message
,
activations
)
# Send after grad
...
...
@@ -208,7 +216,7 @@ class AsyncAMPnetEventLoop:
args
=
message
.
args
assert
args
.
message_type
is
AsyncMessageType
.
Gradients
if
self
.
weight_prediction
:
optimizer
.
update_weight_using_future_predictions
(
cur_rank
,
N
,
forward
=
False
)
optimizer
.
update_weight_using_future_predictions
(
cur_rank
,
N
,
num_gradients
,
self
.
chunks
,
forward
=
False
)
self
.
async_grad_inner
(
message
,
activations
)
num_gradients
+=
1
...
...
@@ -248,7 +256,7 @@ class AsyncAMPnetEventLoop:
batch
=
self
.
get_batch_from_message
(
message
,
EVENT_LOOP_GRADIENTS_QUEUE
)
if
self
.
weight_prediction
:
optimizer
.
update_weight_using_future_predictions
(
cur_rank
,
N
,
forward
=
True
)
optimizer
.
update_weight_using_future_predictions
(
cur_rank
,
N
,
count
,
self
.
chunks
,
forward
=
True
)
task
=
create_task_without_skip_trackers
(
self
.
checkpoint_stop
,
args
.
microbatch_index
,
self
.
group
.
rank
(),
batch
,
self
.
partitions
[
0
].
module
,
)
...
...
@@ -257,7 +265,9 @@ class AsyncAMPnetEventLoop:
task
.
finalize
(
output
)
# one backward
if
self
.
weight_prediction
:
optimizer
.
update_weight_using_future_predictions
(
cur_rank
,
N
,
forward
=
False
)
optimizer
.
update_weight_using_future_predictions
(
cur_rank
,
N
,
num_gradients
,
self
.
chunks
,
forward
=
False
)
output_tensor
=
transform_logger_object
.
transform_output_before_loss
(
output
.
tensor
)
loss
=
criterion
(
output_tensor
,
reqd_target
)
...
...
@@ -307,7 +317,9 @@ class AsyncAMPnetEventLoop:
n_warmup
=
ranks
[
-
1
]
-
cur_rank
for
_
in
range
(
n_warmup
):
if
self
.
weight_prediction
:
optimizer
.
update_weight_using_future_predictions
(
cur_rank
,
N
,
forward
=
True
)
optimizer
.
update_weight_using_future_predictions
(
cur_rank
,
N
,
num_activations
,
self
.
chunks
,
forward
=
True
)
message
=
self
.
event_loop_trunk_forward_helper
(
activations
)
self
.
transport
.
send_message
(
message
,
sync
=
True
)
num_activations
+=
1
...
...
@@ -316,13 +328,15 @@ class AsyncAMPnetEventLoop:
while
num_activations
<
num_microbatch
:
# 1 Forward
if
self
.
weight_prediction
:
optimizer
.
update_weight_using_future_predictions
(
cur_rank
,
N
,
forward
=
True
)
optimizer
.
update_weight_using_future_predictions
(
cur_rank
,
N
,
num_activations
,
self
.
chunks
,
forward
=
True
)
message
=
self
.
event_loop_trunk_forward_helper
(
activations
)
num_activations
+=
1
# 1 Backward
if
self
.
weight_prediction
:
optimizer
.
update_weight_using_future_predictions
(
cur_rank
,
N
,
forward
=
False
)
optimizer
.
update_weight_using_future_predictions
(
cur_rank
,
N
,
num_gradients
,
self
.
chunks
,
forward
=
False
)
self
.
event_loop_trunk_backward_helper
(
activations
)
num_gradients
+=
1
if
self
.
perform_optimizer_step
(
optimizer
,
num_gradients
):
...
...
@@ -336,7 +350,7 @@ class AsyncAMPnetEventLoop:
remaining
=
len
(
activations
)
for
_
in
range
(
remaining
):
if
self
.
weight_prediction
:
optimizer
.
update_weight_using_future_predictions
(
cur_rank
,
N
,
forward
=
False
)
optimizer
.
update_weight_using_future_predictions
(
cur_rank
,
N
,
num_gradients
,
self
.
chunks
,
forward
=
False
)
self
.
event_loop_trunk_backward_helper
(
activations
)
num_gradients
+=
1
if
self
.
perform_optimizer_step
(
optimizer
,
num_gradients
):
...
...
fairscale/experimental/nn/ampnet_pipe/pipe.py
View file @
e141a93e
...
...
@@ -56,6 +56,7 @@ class AMPnetPipe(AsyncPipe):
weight_prediction
,
checkpoint_stop
,
self
.
input_device
,
self
.
chunks
,
)
if
rank
==
0
:
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment