OpenDAS / Megatron-LM
Commit 3e6898e6, authored Oct 24, 2020 by Deepak Narayanan
Move training schedule to 1F1B for memory efficiency
Parent 6abf39be

Showing 1 changed file with 80 additions and 12 deletions

megatron/training.py  (+80, -12)
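For context on the commit message: under the previous schedule, all forward passes for a minibatch ran before any backward pass, so every pipeline stage had to keep the activations of every microbatch alive until its backward pass. Under the 1F1B (one-forward-one-backward) schedule introduced here, a stage only holds activations for the microbatches still in flight, which is bounded by its distance from the end of the pipeline. The snippet below is a minimal, self-contained sketch of that memory argument, not Megatron-LM code; num_stages, rank, and num_microbatches are hypothetical inputs.

def peak_in_flight_all_forward_then_backward(num_microbatches):
    # Previous schedule: all forwards complete before any backward, so every
    # microbatch's activations are resident at the peak.
    return num_microbatches

def peak_in_flight_1f1b(num_stages, rank, num_microbatches):
    # 1F1B: a stage runs (num_stages - rank - 1) warmup forwards, then
    # alternates one forward with one backward, so at most warmup + 1
    # microbatches are resident (capped by the minibatch size).
    warmup = num_stages - rank - 1
    return min(warmup + 1, num_microbatches)

if __name__ == "__main__":
    num_stages, num_microbatches = 4, 16
    for rank in range(num_stages):
        print(rank,
              peak_in_flight_all_forward_then_backward(num_microbatches),
              peak_in_flight_1f1b(num_stages, rank, num_microbatches))
    # With these hypothetical numbers, stage 0 drops from 16 resident
    # microbatches to at most 4 under 1F1B.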
...
@@ -357,6 +357,60 @@ def backward_step_with_communication(optimizer, model, input_tensors, output_ten
         timers('backward-send').stop()


+def forward_and_backward_steps_with_communication(forward_step_func, data_iterator,
+                                                  model, optimizer, input_tensor,
+                                                  last_microbatch, input_tensors,
+                                                  output_tensors, losses_reduced,
+                                                  timers):
+    # Forward model for one step.
+    timers('forward-compute').start()
+    output_tensor = forward_step_func(data_iterator, model, input_tensor)
+    timers('forward-compute').stop()
+
+    if mpu.is_pipeline_last_stage():
+        loss, loss_reduced = output_tensor
+        output_tensor = loss
+        output_tensor_grad = None
+        losses_reduced.append(loss_reduced)
+    else:
+        timers('forward-send').start()
+        timers('backward-recv').start()
+        _, output_tensor_grad = communicate(
+            tensor_send_next=output_tensor,
+            tensor_send_prev=None,
+            recv_forward=False,
+            recv_backward=True)
+        timers('forward-send').stop()
+        timers('backward-recv').stop()
+
+    input_tensors.append(input_tensor)
+    output_tensors.append(output_tensor)
+    input_tensor = input_tensors.pop(0)
+    output_tensor = output_tensors.pop(0)
+
+    # Backward pass for one step.
+    timers('backward-compute').start()
+    input_grad_tensor = \
+        backward_step(optimizer, model, input_tensor, output_tensor,
+                      output_tensor_grad)
+    timers('backward-compute').stop()
+
+    if not mpu.is_pipeline_first_stage():
+        timers('backward-send').start()
+        timers('forward-recv').start()
+        input_tensor, _ = communicate(
+            tensor_send_next=None,
+            tensor_send_prev=input_grad_tensor,
+            recv_forward=(not last_microbatch),
+            recv_backward=False)
+        timers('backward-send').stop()
+        timers('forward-recv').stop()
+    else:
+        input_tensor = None
+
+    return input_tensor
+
+
 def train_step(forward_step_func, data_iterator,
                model, optimizer, lr_scheduler):
     """Single training step."""
...
@@ -371,18 +425,12 @@ def train_step(forward_step_func, data_iterator,
     # Compute number of microbatches in a minibatch.
     num_microbatches_in_minibatch = args.num_microbatches_in_minibatch

-    # For now, perform training without warmup. Perform forward
-    # passes for all microbatches, then backward passes for all
-    # microbatches.
-    # TODO: Switch to the following schedule to facilitate more
-    # memory-efficient training.
-    # num_warmup_microbatches = \
-    #     (torch.distributed.get_world_size(group=mpu.get_pipeline_model_parallel_group()) -
-    #      torch.distributed.get_rank(group=mpu.get_pipeline_model_parallel_group()) - 1)
-    # num_warmup_microbatches = min(
-    #     num_warmup_microbatches,
-    #     num_microbatches_in_minibatch)
-    num_warmup_microbatches = num_microbatches_in_minibatch
+    num_warmup_microbatches = \
+        (mpu.get_pipeline_model_parallel_world_size() -
+         mpu.get_pipeline_model_parallel_rank() - 1)
+    num_warmup_microbatches = min(
+        num_warmup_microbatches,
+        num_microbatches_in_minibatch)

     input_tensors = []
     output_tensors = []
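A small worked example of the warmup computation the hunk above introduces: each stage performs as many warmup forward passes as there are stages after it in the pipeline, capped by the number of microbatches in the minibatch. The function below simply restates that formula with hypothetical argument names; it is not part of the Megatron-LM API.

def num_warmup(pipeline_world_size, pipeline_rank, num_microbatches):
    # Stages after this one in the pipeline, never more than the number of
    # microbatches available in the minibatch.
    warmup = pipeline_world_size - pipeline_rank - 1
    return min(warmup, num_microbatches)

# With 4 pipeline stages and 8 microbatches, ranks 0..3 warm up with
# 3, 2, 1, 0 forward passes respectively; the last stage starts 1F1B
# immediately.
assert [num_warmup(4, r, 8) for r in range(4)] == [3, 2, 1, 0]
assert num_warmup(8, 0, 2) == 2   # cap applies when the minibatch is small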
...
@@ -407,6 +455,26 @@ def train_step(forward_step_func, data_iterator,
         timers('forward-compute').stop()
     timers('forward').stop()

+    # Before running 1F1B, need to receive first forward tensor.
+    if (num_microbatches_in_minibatch - num_warmup_microbatches) > 0:
+        if mpu.is_pipeline_first_stage():
+            input_tensor = None
+        else:
+            input_tensor, _ = communicate(
+                tensor_send_next=None,
+                tensor_send_prev=None,
+                recv_forward=True,
+                recv_backward=False)
+
+    # Run 1F1B.
+    for i in range(num_microbatches_in_minibatch - num_warmup_microbatches):
+        last_iteration = (i == (num_microbatches_in_minibatch -
+                                num_warmup_microbatches - 1))
+
+        input_tensor = \
+            forward_and_backward_steps_with_communication(
+                forward_step_func, data_iterator, model, optimizer,
+                input_tensor, last_iteration, input_tensors,
+                output_tensors, losses_reduced, timers)
+
     # Run cooldown backward passes.
     timers('backward').start()
     for i in range(num_warmup_microbatches):
...
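Putting the three hunks together, each stage's train_step now follows the sequence: warmup forward passes (the existing code above the last hunk), the 1F1B steady-state loop added here, and finally the cooldown backward passes for the microbatches left over from warmup. The simulation below is a rough sketch of that per-stage ordering, not the actual pipeline driver; stage_schedule and its arguments are hypothetical names.

def stage_schedule(pipeline_world_size, pipeline_rank, num_microbatches):
    warmup = min(pipeline_world_size - pipeline_rank - 1, num_microbatches)
    remaining = num_microbatches - warmup

    ops = []
    ops += ['F'] * warmup                 # warmup forward passes
    for _ in range(remaining):            # 1F1B steady state
        ops += ['F', 'B']
    ops += ['B'] * warmup                 # cooldown backward passes
    return ops

# With 4 stages and 6 microbatches, the last stage (rank 3) alternates
# strictly F, B, ...; earlier stages prepend warmup forwards and append
# the matching cooldown backwards.
print(stage_schedule(4, 0, 6))  # ['F','F','F','F','B','F','B','F','B','B','B','B']
print(stage_schedule(4, 3, 6))  # ['F','B'] repeated 6 times

The total work per stage is unchanged (one forward and one backward per microbatch); only the interleaving changes, which is what bounds how many microbatches' activations a stage holds at once.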