OpenDAS / Megatron-LM / Commits

Commit 86da10e9
Authored Dec 07, 2021 by Lawrence McAfee
Parent d4169684

    working for pure pipeline parallelism, w/ no interleaving
Showing 1 changed file with 34 additions and 2 deletions.

megatron/schedules.py @ 86da10e9  (+34, -2)
@@ -15,6 +15,7 @@
 from contextlib import contextmanager
 import torch
+from torch.autograd.variable import Variable
 from torch.nn.parallel.distributed import DistributedDataParallel as torchDDP

 from megatron import get_args
@@ -27,7 +28,6 @@ from megatron.model import DistributedDataParallel as LocalDDP
 from megatron.model import Float16Module
 from megatron.model import ModelType

 def get_forward_backward_func():
     args = get_args()
     if mpu.get_pipeline_model_parallel_world_size() > 1:
@@ -42,6 +42,36 @@ def get_forward_backward_func():
         forward_backward_func = forward_backward_no_pipelining
     return forward_backward_func
+
+def free_output_tensor(t):
+    t.data = torch.FloatTensor([0]).to(t.data)
+
+def custom_backward(output, grad_output):
+    assert output.numel() == 1, \
+        "output should be pseudo-'freed' in schedule, to optimize memory"
+    assert isinstance(output, torch.Tensor), \
+        "output == '%s'." % type(output).__name__
+    assert isinstance(grad_output, (torch.Tensor, type(None))), \
+        "grad_output == '%s'." % type(grad_output).__name__
+
+    # Handle scalar output
+    if grad_output is None:
+        assert output.numel() == 1, "implicit grad requires scalar output."
+        grad_output = torch.ones_like(
+            output,
+            memory_format = torch.preserve_format,
+        )
+
+    # Call c++ engine [ see torch/csrc/autograd/python_engine.cpp ]
+    Variable._execution_engine.run_backward(
+        tensors = (output,),
+        grad_tensors = (grad_output,),
+        keep_graph = False,
+        create_graph = False,
+        inputs = tuple(),
+        allow_unreachable=True,
+        accumulate_grad=True,
+    )


 def forward_step(forward_step_func, data_iterator, model, input_tensor, losses_reduced):
     """Forward step for passed-in model.
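The two helpers above carry the commit's memory optimization: free_output_tensor swaps a stage's output activation for a one-element placeholder once it has been sent downstream, and custom_backward later starts the backward pass from that placeholder by calling the autograd engine directly. Below is a minimal, self-contained sketch of what the pseudo-free leaves behind; it is not part of the commit, the matmul and shapes are illustrative only, and the helper is copied from the hunk above.

# Hedged sketch (not part of the commit): what free_output_tensor leaves behind.
# The storage is swapped for a 1-element placeholder, but grad_fn and the graph
# hanging off the tensor survive, so a backward pass can still start from it.
import torch

def free_output_tensor(t):                       # copied from the hunk above
    t.data = torch.FloatTensor([0]).to(t.data)

x = torch.randn(1024, 1024, requires_grad=True)
out = x @ x.t()                                  # stands in for a stage's output activation

bytes_before = out.element_size() * out.nelement()
free_output_tensor(out)
bytes_after = out.element_size() * out.nelement()

print(bytes_before, "->", bytes_after)           # 4194304 -> 4 bytes held by this tensor
print(out.grad_fn)                               # graph still attached (MmBackward0 here)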
@@ -116,7 +146,7 @@ def backward_step(optimizer, input_tensor, output_tensor, output_tensor_grad):
     # Backward pass.
     if output_tensor_grad[0] is None:
         output_tensor = optimizer.scale_loss(output_tensor[0])
-    torch.autograd.backward(output_tensor[0], grad_tensors=output_tensor_grad[0])
+    custom_backward(output_tensor[0], output_tensor_grad[0])

     # Collect the grad of the input_tensor.
     input_tensor_grad = [None]
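The one-line swap above appears to be forced by the pseudo-free: torch.autograd.backward checks that each grad matches the shape of its tensor, and after free_output_tensor the output holds a single element while the incoming grad keeps the full activation shape. The short repro below is not Megatron code; the shapes are made up.

# Hedged repro (not part of the commit): after the pseudo-free, the Python-level
# torch.autograd.backward rejects the grad, while the direct C++ engine call
# inside custom_backward performs no such shape check.
import torch

out = 2 * torch.randn(4, 8, requires_grad=True)   # non-leaf tensor with a grad_fn
grad = torch.randn(4, 8)                          # grad as it would arrive from the next stage
out.data = torch.FloatTensor([0]).to(out.data)    # pseudo-free, as in free_output_tensor

try:
    torch.autograd.backward(out, grad_tensors=grad)
except RuntimeError as err:
    print("rejected:", err)                       # "Mismatch in shape: ..."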
@@ -560,6 +590,7 @@ def forward_backward_pipelining_without_interleaving(forward_step_func, data_ite
         send_forward(output_tensor, send_tensor_shapes, timers=timers)

         if not forward_only:
+            [free_output_tensor(t) for t in output_tensor]
             input_tensors.append(input_tensor)
             output_tensors.append(output_tensor)
@@ -588,6 +619,7 @@ def forward_backward_pipelining_without_interleaving(forward_step_func, data_ite
                                            timers=timers)

                 # Add input_tensor and output_tensor to end of list.
+                [free_output_tensor(t) for t in output_tensor]
                 input_tensors.append(input_tensor)
                 output_tensors.append(output_tensor)
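Together, these hunks free the sent output right after send_forward in both the warmup and steady-state loops and route the eventual backward pass through custom_backward. The single-process sketch below imitates that flow; it is not Megatron code: the two Linear "stages", the shapes, and the plain Python handoffs standing in for send_forward / recv_backward are invented, and the trimmed helper copies drop the asserts. Freeing is safe here because Linear's backward saves its inputs rather than its output, which mirrors what the schedule assumes about a stage's sent activation.

# Hedged, single-process sketch (not Megatron code) of the pattern above.
# "Send"/"recv" are plain Python handoffs; stages and shapes are made up.
import torch
import torch.nn as nn
from torch.autograd.variable import Variable

def free_output_tensor(t):                      # trimmed copy of the helper above
    t.data = torch.FloatTensor([0]).to(t.data)

def custom_backward(output, grad_output):       # trimmed copy, asserts omitted
    Variable._execution_engine.run_backward(
        tensors=(output,), grad_tensors=(grad_output,),
        keep_graph=False, create_graph=False,
        inputs=tuple(), allow_unreachable=True, accumulate_grad=True)

stage0, stage1 = nn.Linear(16, 16), nn.Linear(16, 1)

# Stage 0: forward, "send" the activation downstream, then pseudo-free it.
x = torch.randn(4, 16)
output_tensor = stage0(x)                                        # forward_step
received = output_tensor.detach().clone().requires_grad_(True)   # send_forward / recv_forward
free_output_tensor(output_tensor)                                # as added after send_forward above

# Stage 1: forward + backward, producing the grad that flows back upstream.
stage1(received).sum().backward()
output_tensor_grad = received.grad                               # send_backward / recv_backward

# Stage 0: backward_step, which now goes through custom_backward.
custom_backward(output_tensor, output_tensor_grad)
print(stage0.weight.grad.shape)                                  # torch.Size([16, 16])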