Commit 86da10e9 authored by Lawrence McAfee

working for pure pipeline parallelism, w/ no interleaving

parent d4169684
@@ -15,6 +15,7 @@
 from contextlib import contextmanager
 import torch
+from torch.autograd.variable import Variable
 from torch.nn.parallel.distributed import DistributedDataParallel as torchDDP
 from megatron import get_args
@@ -27,7 +28,6 @@ from megatron.model import DistributedDataParallel as LocalDDP
 from megatron.model import Float16Module
 from megatron.model import ModelType

 def get_forward_backward_func():
     args = get_args()
     if mpu.get_pipeline_model_parallel_world_size() > 1:
@@ -42,6 +42,36 @@ def get_forward_backward_func():
         forward_backward_func = forward_backward_no_pipelining
     return forward_backward_func

+def free_output_tensor(t):
+    t.data = torch.FloatTensor([0]).to(t.data)
+
+def custom_backward(output, grad_output):
+    assert output.numel() == 1, \
+        "output should be pseudo-'freed' in schedule, to optimize memory"
+    assert isinstance(output, torch.Tensor), \
+        "output == '%s'." % type(output).__name__
+    assert isinstance(grad_output, (torch.Tensor, type(None))), \
+        "grad_output == '%s'." % type(grad_output).__name__
+
+    # Handle scalar output
+    if grad_output is None:
+        assert output.numel() == 1, "implicit grad requires scalar output."
+        grad_output = torch.ones_like(
+            output,
+            memory_format = torch.preserve_format,
+        )
+
+    # Call c++ engine [ see torch/csrc/autograd/python_engine.cpp ]
+    Variable._execution_engine.run_backward(
+        tensors = (output,),
+        grad_tensors = (grad_output,),
+        keep_graph = False,
+        create_graph = False,
+        inputs = tuple(),
+        allow_unreachable=True,
+        accumulate_grad=True,
+    )

 def forward_step(forward_step_func, data_iterator, model, input_tensor, losses_reduced):
     """Forward step for passed-in model.
@@ -116,7 +146,7 @@ def backward_step(optimizer, input_tensor, output_tensor, output_tensor_grad):
     # Backward pass.
     if output_tensor_grad[0] is None:
         output_tensor = optimizer.scale_loss(output_tensor[0])
-    torch.autograd.backward(output_tensor[0], grad_tensors=output_tensor_grad[0])
+    custom_backward(output_tensor[0], output_tensor_grad[0])

     # Collect the grad of the input_tensor.
     input_tensor_grad = [None]
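
Why the swap from torch.autograd.backward to custom_backward is needed here, as I read the change (the snippet below is illustrative, not from the patch): by the time backward_step runs, the schedule has already pseudo-freed output_tensor[0], so its current shape no longer matches the gradient received from the next stage. torch.autograd.backward validates the gradient against the tensor's current shape and would reject the call, whereas custom_backward hands the original-shaped gradient straight to the engine, which checks it against the shape recorded in the graph:

    import torch

    out = torch.randn(1, requires_grad=True) * 1.0   # stands in for a pseudo-freed output
    grad = torch.ones(4, 8)                          # gradient in the original (pre-free) shape

    try:
        torch.autograd.backward(out, grad_tensors=grad)
    except RuntimeError as e:
        print(e)   # shape mismatch between grad and the (shrunken) output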
@@ -560,6 +590,7 @@ def forward_backward_pipelining_without_interleaving(forward_step_func, data_ite
         send_forward(output_tensor, send_tensor_shapes, timers=timers)

         if not forward_only:
+            [ free_output_tensor(t) for t in output_tensor ]
             input_tensors.append(input_tensor)
             output_tensors.append(output_tensor)
@@ -588,6 +619,7 @@ def forward_backward_pipelining_without_interleaving(forward_step_func, data_ite
                 timers=timers)

             # Add input_tensor and output_tensor to end of list.
+            [ free_output_tensor(t) for t in output_tensor ]
             input_tensors.append(input_tensor)
             output_tensors.append(output_tensor)
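
Both schedule hunks apply the same pattern inside forward_backward_pipelining_without_interleaving: as soon as a microbatch's output has been handed to the send (warmup phase) or to the combined send/receive (steady 1F1B phase), its storage is released, and only the slim tensor, still carrying the graph, is kept in output_tensors for the later backward_step. A rough single-process sketch of that flow (not from the patch; the communication calls are omitted and the shapes are invented for illustration):

    import torch

    # free_output_tensor() and custom_backward() are the helpers from this commit;
    # the rest is a made-up stand-in for the schedule, with no pipeline communication.
    weight = torch.randn(256, 256, requires_grad=True)
    input_tensors, output_tensors, output_grads = [], [], []

    for _ in range(4):                                        # a few "microbatches"
        input_tensor = torch.randn(32, 256, requires_grad=True)
        output_tensor = input_tensor @ weight
        output_grads.append(torch.ones_like(output_tensor))   # stand-in for the received grad
        # send_forward(output_tensor, ...) would happen here in the real schedule.
        free_output_tensor(output_tensor)                     # activation memory freed now,
        input_tensors.append(input_tensor)                    # not after the matching backward
        output_tensors.append(output_tensor)

    for output_tensor, grad in zip(output_tensors, output_grads):
        custom_backward(output_tensor, grad)                  # as backward_step() now does

    print(weight.grad.shape)    # torch.Size([256, 256]), accumulated over microbatches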