OpenDAS / Megatron-LM · Commits · 2f25c570

Commit 2f25c570 authored Dec 07, 2021 by Lawrence McAfee

working: interleaving; free_output_tensor() now handles none/tensor/list

parent 86da10e9
Showing 1 changed file with 11 additions and 4 deletions

megatron/schedules.py  +11 −4
@@ -42,8 +42,13 @@ def get_forward_backward_func():
         forward_backward_func = forward_backward_no_pipelining
     return forward_backward_func

-def free_output_tensor(t):
-    t.data = torch.FloatTensor([0]).to(t.data)
+def free_output_tensor(output_tensors):
+    if output_tensors is None:
+        return
+    if isinstance(output_tensors, torch.Tensor):
+        output_tensors = [output_tensors]
+    for output_tensor in output_tensors:
+        output_tensor.data = torch.FloatTensor([0]).to(output_tensor.data)

 def custom_backward(output, grad_output):
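Note (not part of the diff): a minimal standalone sketch of the behaviour the commit message describes — free_output_tensor() now accepts None, a bare tensor, or a list of tensors. The helper is copied here only so the snippet runs on its own; nothing beyond torch is assumed.

import torch

# Copy of the new helper, reproduced only so this sketch is self-contained.
def free_output_tensor(output_tensors):
    if output_tensors is None:
        return
    if isinstance(output_tensors, torch.Tensor):
        output_tensors = [output_tensors]
    for output_tensor in output_tensors:
        output_tensor.data = torch.FloatTensor([0]).to(output_tensor.data)

free_output_tensor(None)                    # None is now a no-op

single = torch.randn(4, 8)
free_output_tensor(single)                  # a bare tensor is wrapped in a list
assert single.numel() == 1                  # payload replaced by a one-element tensor

batch = [torch.randn(4, 8), torch.randn(4, 8)]
free_output_tensor(batch)                   # lists are freed element-wise
assert all(t.numel() == 1 for t in batch)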
@@ -354,6 +359,7 @@ def forward_backward_pipelining_with_interleaving(forward_step_func, data_iterat
                 output_tensor, recv_prev=recv_prev,
                 tensor_shape=tensor_shape,
                 timers=timers)
+        free_output_tensor(output_tensor)
         input_tensors[next_forward_model_chunk_id].append(input_tensor)

     # Run 1F1B in steady state.
@@ -418,6 +424,7 @@ def forward_backward_pipelining_with_interleaving(forward_step_func, data_iterat
                 output_tensor, input_tensor_grad,
                 recv_prev=recv_prev, recv_next=recv_next,
                 tensor_shape=tensor_shape,
                 timers=timers)
+        free_output_tensor(output_tensor)
         # Put input_tensor and output_tensor_grad in data structures in the
         # right location.
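Note (not part of the diff): in both interleaving hunks above, free_output_tensor(output_tensor) is inserted right after the send/receive call, presumably so the activation that has just been shipped to the neighbouring pipeline stage can drop its storage while the Python object kept in the schedule's bookkeeping lists stays valid. A rough illustration of that effect (only torch assumed; the 4 MiB size is arbitrary):

import torch

activation = torch.randn(1024, 1024)        # stand-in for a large output tensor
bookkeeping = [activation]                  # the schedule still holds a reference

bytes_before = activation.numel() * activation.element_size()   # 4 MiB
activation.data = torch.FloatTensor([0]).to(activation.data)    # what free_output_tensor does
bytes_after = activation.numel() * activation.element_size()    # 4 bytes

assert bookkeeping[0] is activation         # the reference is untouched and still usable
print(f"payload: {bytes_before} -> {bytes_after} bytes")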
@@ -590,9 +597,9 @@ def forward_backward_pipelining_without_interleaving(forward_step_func, data_ite
         send_forward(output_tensor, send_tensor_shapes, timers=timers)

         if not forward_only:
-            [free_output_tensor(t) for t in output_tensor]
             input_tensors.append(input_tensor)
             output_tensors.append(output_tensor)
+            free_output_tensor(output_tensor)

     # Before running 1F1B, need to receive first forward tensor.
     # If all microbatches are run in warmup / cooldown phase, then no need to
@@ -619,9 +626,9 @@ def forward_backward_pipelining_without_interleaving(forward_step_func, data_ite
                               timers=timers)

         # Add input_tensor and output_tensor to end of list.
-        [free_output_tensor(t) for t in output_tensor]
         input_tensors.append(input_tensor)
         output_tensors.append(output_tensor)
+        free_output_tensor(output_tensor)

         # Pop input_tensor and output_tensor from the start of the list for
         # the backward pass.
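Note (not part of the diff): in the non-interleaved schedule, output_tensor is a list there (the removed code iterated over it), so the old per-element comprehension and the new single call free the same tensors; the new form simply pushes the iteration, plus the None/bare-tensor handling, into the helper. A quick standalone check under that assumption:

import torch

# Same helper as above, repeated so this sketch also runs on its own.
def free_output_tensor(output_tensors):
    if output_tensors is None:
        return
    if isinstance(output_tensors, torch.Tensor):
        output_tensors = [output_tensors]
    for output_tensor in output_tensors:
        output_tensor.data = torch.FloatTensor([0]).to(output_tensor.data)

old_site = [torch.randn(2, 3) for _ in range(2)]
new_site = [torch.randn(2, 3) for _ in range(2)]

[free_output_tensor(t) for t in old_site]   # removed call site
free_output_tensor(new_site)                # added call site
assert all(t.numel() == 1 for t in old_site + new_site)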