OpenDAS / fairscale / Commits / e141a93e

Unverified commit e141a93e, authored Mar 31, 2021 by Siddharth Goyal and committed by GitHub on Mar 31, 2021.

[feat] experimental: Add xpipe support (#553)

Parent: 204392e5
Showing 3 changed files with 261 additions and 15 deletions (+261, -15):

benchmarks/experimental/experimental_async_approaches.py  (+235, -4)
fairscale/experimental/nn/ampnet_pipe/ampnet.py  (+25, -11)
fairscale/experimental/nn/ampnet_pipe/pipe.py  (+1, -0)
benchmarks/experimental/experimental_ampnet.py → benchmarks/experimental/experimental_async_approaches.py
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
#
# This source code is licensed under the BSD license found in the
# LICENSE file in the root directory of this source tree.
import argparse
import logging
...
@@ -203,7 +206,8 @@ class SpectrainSGDMomentum(Optimizer):
     def modify_current_params_using_reference_params(self):
         self.copy_params(self.reference_params, self.cur_params)
 
-    def update_weight_using_future_predictions(self, model_index, num_gpus, forward):
+    # chunk_index and chunks parameters are for unused for spectrain usecase
+    def update_weight_using_future_predictions(self, model_index, num_gpus, chunk_index, chunks, forward):
         if forward:
             # In forward pass:
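
Since SpectrainSGDMomentum now accepts (and simply ignores) the chunk arguments, the AMPnet event loop can drive either weight-prediction optimizer through one call shape. A minimal sketch of that shared call, with a hypothetical helper name:

def predict_weights(optimizer, rank, num_gpus, microbatch_index, chunks, forward):
    # The same call works for SpectrainSGDMomentum (which ignores microbatch_index
    # and chunks) and for the new XpipeAdam added below.
    optimizer.update_weight_using_future_predictions(rank, num_gpus, microbatch_index, chunks, forward=forward)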
...
@@ -260,6 +264,226 @@ class SpectrainSGDMomentum(Optimizer):
         return loss
 
 
+class XpipeAdam(Optimizer):
+    r"""Implements Xpipe approach on top of Adam algorithm.
+
+    It has been proposed in `Adam: A Method for Stochastic Optimization`_.
+    The implementation of the L2 penalty follows changes proposed in
+    `Decoupled Weight Decay Regularization`_.
+    Xpipe details can be found here: https://arxiv.org/abs/1911.04610
+
+    Arguments:
+        params (iterable): iterable of parameters to optimize or dicts defining
+            parameter groups
+        lr (float, optional): learning rate (default: 1e-3)
+        betas (Tuple[float, float], optional): coefficients used for computing
+            running averages of gradient and its square (default: (0.9, 0.999))
+        eps (float, optional): term added to the denominator to improve
+            numerical stability (default: 1e-8)
+        weight_decay (float, optional): weight decay (L2 penalty) (default: 0)
+        amsgrad (boolean, optional): whether to use the AMSGrad variant of this
+            algorithm from the paper `On the Convergence of Adam and Beyond`_
+            (default: False)
+
+    .. _Adam\: A Method for Stochastic Optimization:
+        https://arxiv.org/abs/1412.6980
+    .. _Decoupled Weight Decay Regularization:
+        https://arxiv.org/abs/1711.05101
+    .. _On the Convergence of Adam and Beyond:
+        https://openreview.net/forum?id=ryQu7f-RZ
+    """
+
+    def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8, weight_decay=0, amsgrad=False):
+        if not 0.0 <= lr:
+            raise ValueError("Invalid learning rate: {}".format(lr))
+        if not 0.0 <= eps:
+            raise ValueError("Invalid epsilon value: {}".format(eps))
+        if not 0.0 <= betas[0] < 1.0:
+            raise ValueError("Invalid beta parameter at index 0: {}".format(betas[0]))
+        if not 0.0 <= betas[1] < 1.0:
+            raise ValueError("Invalid beta parameter at index 1: {}".format(betas[1]))
+        if not 0.0 <= weight_decay:
+            raise ValueError("Invalid weight_decay value: {}".format(weight_decay))
+        defaults = dict(lr=lr, betas=betas, eps=eps, weight_decay=weight_decay, amsgrad=amsgrad)
+        params = list(params)
+        super(XpipeAdam, self).__init__(params, defaults)
+        self.cur_params, self.master_params = self.prep_param_copies(params)
+        _, self.forward_params = self.prep_param_copies(params)
+        _, self.backward_params = self.prep_param_copies(params)
+        for group in self.param_groups:
+            for p in group["params"]:
+                param_state = self.state[p]
+                param_state["step"] = 0
+                # Exponential moving average of gradient values
+                param_state["exp_avg"] = torch.zeros_like(p.data)
+                # Exponential moving average of squared gradient values
+                param_state["exp_avg_sq"] = torch.zeros_like(p.data)
+
+    def __setstate__(self, state):
+        super(Adam, self).__setstate__(state)
+        for group in self.param_groups:
+            group.setdefault("amsgrad", False)
+
+    def prep_param_copies(self, params):
+        model_params = [param for param in params if param.requires_grad]
+        reference_params = [param.clone().detach() for param in model_params]
+        for param in reference_params:
+            param.requires_grad = True
+        return model_params, reference_params
+
+    def copy_params(self, master_params, model_params):
+        for model, master in zip(model_params, master_params):
+            model.data.copy_(master.data)
+
+    def update_weight_using_future_predictions(
+        self, model_index, num_gpus, current_microbatch_index, microbatches_per_minibatch, forward
+    ):
+        if forward:
+            # Forward pass overview:
+            # if bell-weather:
+            #     1. read from master copy
+            #     2. predict and modify
+            #     3. flush updates to forward copy
+            # else:
+            #     1. read from forward copy
+            if current_microbatch_index % microbatches_per_minibatch == 0:
+                # read from master copy
+                self.copy_params(self.master_params, self.cur_params)
+                microbatch_index = current_microbatch_index + 1
+                # predict and modify
+                for group in self.param_groups:
+                    multiplier = group["lr"] * round((microbatch_index + num_gpus - model_index / 2 - 2) / microbatch_index)
+                    beta1, beta2 = group["betas"]
+                    eps = group["eps"]
+                    for p in group["params"]:
+                        param_state = self.state[p]
+                        temp1 = param_state["exp_avg"].data / (1 - beta1)
+                        temp2 = ((param_state["exp_avg_sq"].data / (1 - beta2)) + eps).sqrt()
+                        p.data.addcdiv_(temp1, temp2, value=-multiplier)
+                # flush updates to forward copy
+                self.copy_params(self.cur_params, self.forward_params)
+            else:
+                self.copy_params(self.forward_params, self.cur_params)
+        else:
+            # Backward pass overview:
+            # if bell-weather:
+            #     1. read from master copy
+            #     2. predict and modify
+            #     3. flush updates to backward copy
+            # else:
+            #     1. read from backward copy
+            if current_microbatch_index % microbatches_per_minibatch == 0:
+                # read from master copy
+                self.copy_params(self.master_params, self.cur_params)
+                microbatch_index = current_microbatch_index + 1
+                # predict and modify
+                for group in self.param_groups:
+                    multiplier = group["lr"] * (microbatch_index + model_index // 2 - 1) // microbatch_index
+                    beta1, beta2 = group["betas"]
+                    eps = group["eps"]
+                    for p in group["params"]:
+                        param_state = self.state[p]
+                        temp1 = param_state["exp_avg"].data / (1 - beta1)
+                        temp2 = ((param_state["exp_avg_sq"].data / (1 - beta2)) + eps).sqrt()
+                        p.data.addcdiv_(temp1, temp2, value=-multiplier)
+                # flush updates to forward copy
+                self.copy_params(self.cur_params, self.backward_params)
+            else:
+                self.copy_params(self.backward_params, self.cur_params)
+
+    @torch.no_grad()
+    def step(self, closure=None):
+        """Performs a single optimization step.
+
+        Arguments:
+            closure (callable, optional): A closure that reevaluates the model
+                and returns the loss.
+        """
+        loss = None
+        if closure is not None:
+            with torch.enable_grad():
+                loss = closure()
+
+        for group in self.param_groups:
+            for p in group["params"]:
+                if p.grad is None:
+                    continue
+                grad = p.grad.data
+                amsgrad = group.get("amsgrad", False)
+                p_data = p.data
+
+                state = self.state[p]
+
+                # State initialization
+                if len(state) == 0:
+                    state["step"] = 0
+                    # Exponential moving average of gradient values
+                    state["exp_avg"] = torch.zeros_like(p_data)
+                    # Exponential moving average of squared gradient values
+                    state["exp_avg_sq"] = torch.zeros_like(p_data)
+                    if amsgrad:
+                        # Maintains max of all exp. moving avg. of sq. grad. values
+                        state["max_exp_avg_sq"] = torch.zeros_like(p_data)
+                else:
+                    state["exp_avg"] = state["exp_avg"].to(p_data)
+                    state["exp_avg_sq"] = state["exp_avg_sq"].to(p_data)
+                    if amsgrad:
+                        state["max_exp_avg_sq"] = state["max_exp_avg_sq"].to(p_data)
+
+                exp_avg, exp_avg_sq = state["exp_avg"], state["exp_avg_sq"]
+                if amsgrad:
+                    max_exp_avg_sq = state["max_exp_avg_sq"]
+                beta1, beta2 = group["betas"]
+
+                state["step"] += 1
+
+                exp_avg_data = exp_avg.data
+                exp_avg_sq_data = exp_avg_sq.data
+
+                # Decay the first and second moment running average coefficient
+                exp_avg_data.mul_(beta1).add_(grad, alpha=1 - beta1)
+                exp_avg_sq_data.mul_(beta2).addcmul_(grad, grad, value=1 - beta2)
+                if amsgrad:
+                    # Maintains the maximum of all 2nd moment running avg. till now
+                    torch.max(max_exp_avg_sq, exp_avg_sq_data, out=max_exp_avg_sq_data)
+                    # Use the max. for normalizing running avg. of gradient
+                    denom = max_exp_avg_sq.sqrt().add_(group["eps"])
+                else:
+                    denom = exp_avg_sq_data.sqrt().add_(group["eps"])
+
+                bias_correction1 = 1 - beta1 ** state["step"]
+                bias_correction2 = 1 - beta2 ** state["step"]
+                step_size = group["lr"] * math.sqrt(bias_correction2) / bias_correction1
+
+                if group["weight_decay"] != 0:
+                    p_data.add_(p_data, alpha=-group["weight_decay"] * group["lr"])
+
+                p_data.addcdiv_(exp_avg_data, denom, value=-step_size)
+
+        return loss
+
+
 def get_data(device):
     with warnings.catch_warnings(record=True) as fjldska:
         TEXT = torchtext.data.Field(
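
For orientation, here is a minimal sketch (not part of the commit) of how XpipeAdam's weight prediction is meant to be driven per microbatch. It assumes XpipeAdam and its module's imports are in scope; the toy model, the data, and the num_gpus/model_index/chunks values are hypothetical, but the call pattern mirrors what the AMPnet event loop does in ampnet.py below.

import torch
import torch.nn as nn

model = nn.Linear(4, 1)
opt = XpipeAdam(model.parameters(), lr=1e-3)

num_gpus, model_index, chunks = 2, 0, 4  # illustrative pipeline configuration
data = [(torch.randn(8, 4), torch.randn(8, 1)) for _ in range(8)]

for microbatch_index, (x, y) in enumerate(data):
    # Predict future weights before the forward pass of this microbatch.
    opt.update_weight_using_future_predictions(model_index, num_gpus, microbatch_index, chunks, forward=True)
    loss = nn.functional.mse_loss(model(x), y)

    # Predict again before the backward pass, then accumulate gradients.
    opt.update_weight_using_future_predictions(model_index, num_gpus, microbatch_index, chunks, forward=False)
    loss.backward()

    # Step once per minibatch, i.e. every `chunks` microbatches, which matches the
    # condition perform_optimizer_step() in ampnet.py uses when weight prediction is on.
    if (microbatch_index + 1) % chunks == 0:
        opt.step()
        opt.zero_grad()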
...
@@ -321,7 +545,9 @@ def make_model(args, device, ntokens):
     return Adam(model.parameters(), lr=lr)
 
 def make_custom_optimizer(model, args):
-    if args.spectrain:
+    if args.xpipe:
+        return XpipeAdam(model.parameters(), lr=lr)
+    elif args.spectrain:
         return SpectrainSGDMomentum(model.parameters(), lr=lr)
     else:
         return MySGD(model.parameters(), lr=lr)
...
@@ -398,7 +624,9 @@ def train(lm_dataloader, model, criterion, optimizer, vocab_size, args):
     optimizer = optimizer(model, args)
     transform_and_log = AsyncDelegate(vocab_size)
-    model.interleave(lm_dataloader, criterion, optimizer, transform_and_log, args.min_update_interval, args.spectrain)
+    model.interleave(
+        lm_dataloader, criterion, optimizer, transform_and_log, args.min_update_interval, args.spectrain or args.xpipe
+    )
     if model.group.rank() == model.group.size() - 1:
         print("Done with an epoch")
...
@@ -615,6 +843,7 @@ parser.add_argument("--max-batch", type=int, default=4, help="Max number of batc
 parser.add_argument("--socket-name", type=str, default=None, help="socket ifname for gloo/tp")
 parser.add_argument("--num-decoder-layers", type=int, default=10, help="Number of decoder layers in the model")
 parser.add_argument("--spectrain", action="store_true", default=False, help="Use spectrain based weight prediction")
+parser.add_argument("--xpipe", action="store_true", default=False, help="Use xpipe based weight prediction")
 parser.add_argument(
     "--lazy-construction", action="store_true", default=False, help="Number of decoder layers in the model"
 )
...
@@ -627,7 +856,9 @@ parser.add_argument("--min-update-interval", type=int, default=1, help="min upda
 To run the script,
 1. please build a suitable version of OpenMPI with a cuda-enabled UCX backend.
 2. For running on 2 gpus:
-<open-mpi-installed-dir>/bin/mpirun --host localhost:8 -np 2 --map-by node --mca pml ucx -x UCX_TLS=rc,sm,cuda_ipc,cuda_copy -x PYTHONPATH=$PWD -x PATH=$PATH -x LD_LIBRARY_PATH=$LD_LIBRARY_PATH -x UCX_RNDV_SCHEME=put_zcopy -x UCX_MEMTYPE_CACHE=n python3 benchmarks/experimental_ampnet.py --num-decoder-layers=8 --host localhost --batch-size 4
+<open-mpi-installed-dir>/bin/mpirun --host localhost:8 -np 2 --map-by node --mca pml ucx -x UCX_TLS=rc,sm,cuda_ipc,cuda_copy -x PYTHONPATH=$PWD -x PATH=$PATH -x LD_LIBRARY_PATH=$LD_LIBRARY_PATH -x UCX_RNDV_SCHEME=put_zcopy -x UCX_MEMTYPE_CACHE=n python3 benchmarks/experimental/experimental_async_approaches.py --num-decoder-layers=8 --host localhost --batch-size 4
+3. For doing Spectrain based weight prediction, add `--spectrain` to the training command line argument.
+4. For doing Xpipe based weight prediction, add `--xpipe` to the training command line argument.
 """
 if __name__ == "__main__":
...
fairscale/experimental/nn/ampnet_pipe/ampnet.py
...
@@ -73,6 +73,7 @@ class AsyncAMPnetEventLoop:
         weight_prediction: bool,
         checkpoint_stop: int,
         input_device: Union[None, int, str, torch.device],
+        chunks: int,
     ):
         self.partitions = partitions
         self.group = group
...
@@ -81,9 +82,14 @@ class AsyncAMPnetEventLoop:
         self.weight_prediction = weight_prediction
         self.checkpoint_stop = checkpoint_stop
         self.input_device = input_device
+        self.chunks = chunks
 
     def perform_optimizer_step(self, optimizer: Any, num_gradients: Any) -> Any:
-        return (optimizer is not None) and ((num_gradients % self.min_update_interval == 0) or self.weight_prediction)
+        return (
+            (optimizer is not None)
+            and (not self.weight_prediction and num_gradients % self.min_update_interval == 0)
+            or (self.weight_prediction and num_gradients % self.chunks == 0)
+        )
 
     def async_send_inner(self, batch: Batch, index: int) -> Tuple[Batch, PipeMessage]:
         task = create_task_without_skip_trackers(
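
The reworked perform_optimizer_step() condition separates the two modes: without weight prediction the step cadence is still governed by min_update_interval, while with weight prediction (spectrain or xpipe) the optimizer steps once per minibatch, i.e. whenever the number of accumulated gradients is a multiple of chunks. A standalone illustration with a hypothetical stub (the class name and field values are made up):

class EventLoopStub:
    def __init__(self, weight_prediction, min_update_interval=1, chunks=4):
        self.weight_prediction = weight_prediction
        self.min_update_interval = min_update_interval
        self.chunks = chunks

    def perform_optimizer_step(self, optimizer, num_gradients):
        # Same boolean expression as the new AsyncAMPnetEventLoop method above.
        return (
            (optimizer is not None)
            and (not self.weight_prediction and num_gradients % self.min_update_interval == 0)
            or (self.weight_prediction and num_gradients % self.chunks == 0)
        )

opt = object()  # stands in for any non-None optimizer

plain = EventLoopStub(weight_prediction=False, min_update_interval=2)
print([n for n in range(1, 9) if plain.perform_optimizer_step(opt, n)])  # [2, 4, 6, 8]

xpipe = EventLoopStub(weight_prediction=True, chunks=4)
print([n for n in range(1, 9) if xpipe.perform_optimizer_step(opt, n)])  # [4, 8]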
...
@@ -160,7 +166,7 @@ class AsyncAMPnetEventLoop:
             reqd_input = transform_logger_object.transform_input(cur_batch).to(self.input_device)
             batch = Batch(reqd_input, count)
             if self.weight_prediction:
-                optimizer.update_weight_using_future_predictions(cur_rank, N, forward=True)
+                optimizer.update_weight_using_future_predictions(cur_rank, N, count, self.chunks, forward=True)
             activations[count], message = self.async_send_inner(batch, count)
             self.transport.send_message(message, sync=True)
             count += 1
...
@@ -177,7 +183,7 @@ class AsyncAMPnetEventLoop:
             reqd_input = transform_logger_object.transform_input(cur_batch).to(self.input_device)
             batch = Batch(reqd_input, count)
             if self.weight_prediction:
-                optimizer.update_weight_using_future_predictions(cur_rank, N, forward=True)
+                optimizer.update_weight_using_future_predictions(cur_rank, N, count, self.chunks, forward=True)
             activations[count], forward_message = self.async_send_inner(batch, count)
             count += 1
...
@@ -186,7 +192,9 @@ class AsyncAMPnetEventLoop:
             args: AsyncMessageBody = message.args
             assert args.message_type is AsyncMessageType.Gradients
             if self.weight_prediction:
-                optimizer.update_weight_using_future_predictions(cur_rank, N, forward=False)
+                optimizer.update_weight_using_future_predictions(cur_rank, N, num_gradients, self.chunks, forward=False)
             self.async_grad_inner(message, activations)
 
             # Send after grad
...
@@ -208,7 +216,7 @@ class AsyncAMPnetEventLoop:
             args = message.args
             assert args.message_type is AsyncMessageType.Gradients
             if self.weight_prediction:
-                optimizer.update_weight_using_future_predictions(cur_rank, N, forward=False)
+                optimizer.update_weight_using_future_predictions(cur_rank, N, num_gradients, self.chunks, forward=False)
             self.async_grad_inner(message, activations)
             num_gradients += 1
...
@@ -248,7 +256,7 @@ class AsyncAMPnetEventLoop:
             batch = self.get_batch_from_message(message, EVENT_LOOP_GRADIENTS_QUEUE)
             if self.weight_prediction:
-                optimizer.update_weight_using_future_predictions(cur_rank, N, forward=True)
+                optimizer.update_weight_using_future_predictions(cur_rank, N, count, self.chunks, forward=True)
             task = create_task_without_skip_trackers(
                 self.checkpoint_stop, args.microbatch_index, self.group.rank(), batch, self.partitions[0].module,
             )
...
@@ -257,7 +265,9 @@ class AsyncAMPnetEventLoop:
             task.finalize(output)
             # one backward
             if self.weight_prediction:
-                optimizer.update_weight_using_future_predictions(cur_rank, N, forward=False)
+                optimizer.update_weight_using_future_predictions(cur_rank, N, num_gradients, self.chunks, forward=False)
             output_tensor = transform_logger_object.transform_output_before_loss(output.tensor)
             loss = criterion(output_tensor, reqd_target)
...
@@ -307,7 +317,9 @@ class AsyncAMPnetEventLoop:
         n_warmup = ranks[-1] - cur_rank
         for _ in range(n_warmup):
             if self.weight_prediction:
-                optimizer.update_weight_using_future_predictions(cur_rank, N, forward=True)
+                optimizer.update_weight_using_future_predictions(cur_rank, N, num_activations, self.chunks, forward=True)
             message = self.event_loop_trunk_forward_helper(activations)
             self.transport.send_message(message, sync=True)
             num_activations += 1
...
@@ -316,13 +328,15 @@ class AsyncAMPnetEventLoop:
         while num_activations < num_microbatch:
             # 1 Forward
             if self.weight_prediction:
-                optimizer.update_weight_using_future_predictions(cur_rank, N, forward=True)
+                optimizer.update_weight_using_future_predictions(cur_rank, N, num_activations, self.chunks, forward=True)
             message = self.event_loop_trunk_forward_helper(activations)
             num_activations += 1
 
             # 1 Backward
             if self.weight_prediction:
-                optimizer.update_weight_using_future_predictions(cur_rank, N, forward=False)
+                optimizer.update_weight_using_future_predictions(cur_rank, N, num_gradients, self.chunks, forward=False)
             self.event_loop_trunk_backward_helper(activations)
             num_gradients += 1
             if self.perform_optimizer_step(optimizer, num_gradients):
...
@@ -336,7 +350,7 @@ class AsyncAMPnetEventLoop:
         remaining = len(activations)
         for _ in range(remaining):
             if self.weight_prediction:
-                optimizer.update_weight_using_future_predictions(cur_rank, N, forward=False)
+                optimizer.update_weight_using_future_predictions(cur_rank, N, num_gradients, self.chunks, forward=False)
             self.event_loop_trunk_backward_helper(activations)
             num_gradients += 1
             if self.perform_optimizer_step(optimizer, num_gradients):
...
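
At every call site above, the third argument is the running microbatch counter (count, num_activations, or num_gradients) and the fourth is self.chunks, which lets the optimizer recognize the microbatch that opens a minibatch. A toy, standalone illustration of that "bell-weather" selection (the chunks value is hypothetical):

chunks = 4
for count in range(8):
    if count % chunks == 0:
        role = "bell-weather: read master copy, predict, flush to forward/backward copy"
    else:
        role = "reuse the cached forward/backward copy"
    print(f"microbatch {count}: {role}")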
fairscale/experimental/nn/ampnet_pipe/pipe.py
...
@@ -56,6 +56,7 @@ class AMPnetPipe(AsyncPipe):
             weight_prediction,
             checkpoint_stop,
             self.input_device,
+            self.chunks,
         )
         if rank == 0:
...