OpenDAS / Megatron-LM · Commit 0865c4dc
Authored Aug 09, 2021 by Lawrence McAfee
Parent: 6a680986

removed saving of input/output tensors for forward-only passes of pipeline schedules

Showing 1 changed file with 26 additions and 13 deletions:
megatron/schedules.py (+26, -13), view file @ 0865c4dc
@@ -194,6 +194,7 @@ def forward_backward_pipelining_with_interleaving(forward_step_func, data_iterat
         model_chunk_id = get_model_chunk_id(microbatch_id, forward=True)
         mpu.set_virtual_pipeline_model_parallel_rank(model_chunk_id)
 
+        # forward step
         if mpu.is_pipeline_first_stage():
             if len(input_tensors[model_chunk_id]) == \
                     len(output_tensors[model_chunk_id]):
@@ -205,6 +206,11 @@ def forward_backward_pipelining_with_interleaving(forward_step_func, data_iterat
                                      input_tensor, losses_reduced)
         output_tensors[model_chunk_id].append(output_tensor)
 
+        # if forward-only, no need to save tensors for a backward pass
+        if forward_only:
+            input_tensors[model_chunk_id].pop()
+            output_tensors[model_chunk_id].pop()
+
         return output_tensor
 
     def backward_step_helper(microbatch_id):
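The interleaved schedule keeps per-chunk lists whose relative lengths are consulted elsewhere (the `len(input_tensors[...]) == len(output_tensors[...])` check in the first hunk), so rather than skipping the stores, the new code appends and then immediately pops when `forward_only` is set. A minimal sketch of that pattern, assuming a hypothetical `run_forward` callable in place of Megatron-LM's `forward_step`:

```python
# Sketch of the append-then-pop bookkeeping for forward-only passes.
# `run_forward` and the list arguments are illustrative stand-ins,
# not Megatron-LM's actual API.
def forward_step_helper_sketch(input_tensors, output_tensors,
                               model_chunk_id, input_tensor,
                               run_forward, forward_only):
    # In the real schedule the input append happens earlier, on the
    # receive path; it is inlined here to keep the sketch self-contained.
    input_tensors[model_chunk_id].append(input_tensor)
    output_tensor = run_forward(input_tensor)
    output_tensors[model_chunk_id].append(output_tensor)

    # If forward-only, no backward pass will consume these activations,
    # so pop them right away instead of holding them for the rest of
    # the schedule.
    if forward_only:
        input_tensors[model_chunk_id].pop()
        output_tensors[model_chunk_id].pop()
    return output_tensor
```

Popping after the append, instead of conditionally skipping the append, keeps the first-stage length check valid at the moment it runs.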
@@ -383,8 +389,12 @@ def forward_backward_pipelining_without_interleaving(forward_step_func, data_ite
     num_microbatches_remaining = \
         num_microbatches - num_warmup_microbatches
 
-    input_tensors = []
-    output_tensors = []
+    # Input, output tensors only need to be saved when doing backward passes
+    input_tensors = None
+    output_tensors = None
+    if not forward_only:
+        input_tensors = []
+        output_tensors = []
     losses_reduced = []
 
     # Run warmup forward passes.
@@ -394,8 +404,9 @@ def forward_backward_pipelining_without_interleaving(forward_step_func, data_ite
                                      input_tensor, losses_reduced)
         p2p_communication.send_forward(output_tensor, timers=timers)
 
-        input_tensors.append(input_tensor)
-        output_tensors.append(output_tensor)
+        if not forward_only:
+            input_tensors.append(input_tensor)
+            output_tensors.append(output_tensor)
 
     # Before running 1F1B, need to receive first forward tensor.
     # If all microbatches are run in warmup / cooldown phase, then no need to
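Together, the last two hunks make the non-interleaved schedule allocate its stashes only when a backward pass will follow: in forward-only mode the lists stay `None`, and warmup activations are dropped as soon as they have been sent downstream. A condensed, runnable sketch of that warmup flow, with stub communication and forward functions standing in for `p2p_communication` and `forward_step`:

```python
# Condensed sketch of the warmup path after this commit. The callable
# arguments are illustrative placeholders, not Megatron-LM's real APIs.
def run_warmup(num_warmup_microbatches, recv_forward, forward_step,
               send_forward, forward_only):
    # Input, output tensors only need to be saved when doing backward passes
    input_tensors = None
    output_tensors = None
    if not forward_only:
        input_tensors = []
        output_tensors = []

    for _ in range(num_warmup_microbatches):
        input_tensor = recv_forward()
        output_tensor = forward_step(input_tensor)
        send_forward(output_tensor)

        # Stash activations only if a later backward pass will consume them.
        if not forward_only:
            input_tensors.append(input_tensor)
            output_tensors.append(output_tensor)

    return input_tensors, output_tensors
```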
@@ -411,21 +422,23 @@ def forward_backward_pipelining_without_interleaving(forward_step_func, data_ite
                                      input_tensor, losses_reduced)
         if forward_only:
             p2p_communication.send_forward(output_tensor, timers=timers)
 
+            if not last_iteration:
+                input_tensor = p2p_communication.recv_forward(timers=timers)
+
         else:
             output_tensor_grad = \
                 p2p_communication.send_forward_recv_backward(output_tensor,
                                                              timers=timers)
 
-        # Add input_tensor and output_tensor to end of list, then pop from the
-        # start of the list for backward pass.
-        input_tensors.append(input_tensor)
-        output_tensors.append(output_tensor)
-        input_tensor, output_tensor = input_tensors.pop(0), output_tensors.pop(0)
-
-        if forward_only:
-            if not last_iteration:
-                input_tensor = p2p_communication.recv_forward(timers=timers)
-        else:
+            # Add input_tensor and output_tensor to end of list.
+            input_tensors.append(input_tensor)
+            output_tensors.append(output_tensor)
+
+            # Pop input_tensor and output_tensor from the start of the list for
+            # the backward pass.
+            input_tensor = input_tensors.pop(0)
+            output_tensor = output_tensors.pop(0)
+
             input_tensor_grad = \
                 backward_step(optimizer, input_tensor, output_tensor,
...
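The net effect of the steady-state restructuring is that, in forward-only mode, each microbatch's activations are released as soon as its output has been sent, so the stash no longer grows with the number of in-flight microbatches. A toy illustration of the difference during the warmup phase, using plain Python objects as stand-ins for activation tensors (not Megatron-LM code):

```python
# Toy illustration: compare peak stash size with and without the
# forward-only early release during warmup.
def peak_stashed_activations(num_warmup_microbatches, forward_only):
    stash = []   # activations held for future backward passes
    peak = 0
    for _ in range(num_warmup_microbatches):
        stash.append(object())        # stand-in for an input/output tensor
        peak = max(peak, len(stash))
        if forward_only:
            stash.pop()               # nothing will ever read it back
    return peak

assert peak_stashed_activations(4, forward_only=True) == 1
assert peak_stashed_activations(4, forward_only=False) == 4
```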