Commit 86da10e9 authored by Lawrence McAfee

working for pure pipeline parallelism, w/ no interleaving

parent d4169684
@@ -15,6 +15,7 @@
 from contextlib import contextmanager
 import torch
+from torch.autograd.variable import Variable
 from torch.nn.parallel.distributed import DistributedDataParallel as torchDDP
 from megatron import get_args
@@ -27,7 +28,6 @@ from megatron.model import DistributedDataParallel as LocalDDP
 from megatron.model import Float16Module
 from megatron.model import ModelType

 def get_forward_backward_func():
     args = get_args()
     if mpu.get_pipeline_model_parallel_world_size() > 1:
@@ -42,6 +42,36 @@ def get_forward_backward_func():
         forward_backward_func = forward_backward_no_pipelining
     return forward_backward_func

+def free_output_tensor(t):
+    # Pseudo-free the output tensor by swapping its '.data' for a scalar:
+    # the activation storage is released, while the tensor object and its
+    # '.grad_fn' (the autograd graph) stay alive for the backward pass.
+    t.data = torch.FloatTensor([0]).to(t.data)
+
+def custom_backward(output, grad_output):
+    # Call the C++ autograd engine directly, bypassing
+    # torch.autograd.backward, whose shape check between 'output' and
+    # 'grad_output' would fail after the output has been pseudo-freed.
+    assert output.numel() == 1, \
+        "output should be pseudo-'freed' in schedule, to optimize memory"
+    assert isinstance(output, torch.Tensor), \
+        "output == '%s'." % type(output).__name__
+    assert isinstance(grad_output, (torch.Tensor, type(None))), \
+        "grad_output == '%s'." % type(grad_output).__name__
+
+    # Handle scalar output.
+    if grad_output is None:
+        assert output.numel() == 1, "implicit grad requires scalar output."
+        grad_output = torch.ones_like(
+            output,
+            memory_format = torch.preserve_format,
+        )
+
+    # Call c++ engine [ see torch/csrc/autograd/python_engine.cpp ]
+    Variable._execution_engine.run_backward(
+        tensors = (output,),
+        grad_tensors = (grad_output,),
+        keep_graph = False,
+        create_graph = False,
+        inputs = tuple(),
+        allow_unreachable=True,
+        accumulate_grad=True,
+    )

 def forward_step(forward_step_func, data_iterator, model, input_tensor, losses_reduced):
     """Forward step for passed-in model.
@@ -116,7 +146,7 @@ def backward_step(optimizer, input_tensor, output_tensor, output_tensor_grad):
     # Backward pass.
     if output_tensor_grad[0] is None:
         output_tensor = optimizer.scale_loss(output_tensor[0])
-    torch.autograd.backward(output_tensor[0], grad_tensors=output_tensor_grad[0])
+    custom_backward(output_tensor[0], output_tensor_grad[0])

     # Collect the grad of the input_tensor.
     input_tensor_grad = [None]
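The swap from `torch.autograd.backward` to `custom_backward` is forced by the pseudo-free: the Python-level `backward` checks each supplied grad against the *current* shape of its tensor, which is now a scalar, while the engine only needs the grad to match the shape recorded in the `grad_fn`'s metadata at forward time. A hedged sketch of the failure and the workaround (the exact `run_backward` kwargs may vary across PyTorch versions; this mirrors the call added above):

```python
import torch
from torch.autograd.variable import Variable

x = torch.randn(4, 8, requires_grad=True)
y = x * 3
grad = torch.ones(4, 8)        # grad for the *original* output shape

y.data = torch.FloatTensor([0]).to(y.data)   # pseudo-free, as in the commit

# The Python API compares 'grad' against y's current (scalar) shape.
try:
    torch.autograd.backward(y, grad_tensors=grad)
except RuntimeError as e:
    print("torch.autograd.backward failed:", e)   # shape mismatch

# The direct engine call validates 'grad' against the shape recorded in
# y.grad_fn at forward time, so gradients still flow correctly.
Variable._execution_engine.run_backward(
    tensors=(y,),
    grad_tensors=(grad,),
    keep_graph=False,
    create_graph=False,
    inputs=tuple(),
    allow_unreachable=True,
    accumulate_grad=True,
)
print(x.grad.shape)            # torch.Size([4, 8])
```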
@@ -560,6 +590,7 @@ def forward_backward_pipelining_without_interleaving(forward_step_func, data_ite
         send_forward(output_tensor, send_tensor_shapes, timers=timers)

         if not forward_only:
+            [ free_output_tensor(t) for t in output_tensor ]
             input_tensors.append(input_tensor)
             output_tensors.append(output_tensor)
@@ -588,6 +619,7 @@ def forward_backward_pipelining_without_interleaving(forward_step_func, data_ite
                            timers=timers)

            # Add input_tensor and output_tensor to end of list.
+           [ free_output_tensor(t) for t in output_tensor ]
            input_tensors.append(input_tensor)
            output_tensors.append(output_tensor)
...
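Putting the pieces together, the schedule changes free each output tensor immediately after it has been sent downstream, since from that point on the stage only needs the output's `grad_fn` for the eventual backward step. A hypothetical single-stage walkthrough of that ordering (the send/recv stand-ins below are illustrative placeholders, not Megatron APIs):

```python
import torch
from torch.autograd.variable import Variable

def free_output_tensor(t):
    t.data = torch.FloatTensor([0]).to(t.data)

# 1. Forward step: the stage computes its output activation.
weight = torch.randn(8, 8, requires_grad=True)
inp = torch.randn(4, 8)
output_tensor = inp @ weight

# 2. send_forward would ship output_tensor's data to the next stage;
#    a detached copy stands in for the P2P send here.
sent = output_tensor.detach().clone()

# 3. Immediately pseudo-free: only output_tensor.grad_fn is needed now.
free_output_tensor(output_tensor)

# 4. backward_step later receives the grad from the next stage and
#    runs custom_backward(output_tensor, recv_grad).
recv_grad = torch.ones(4, 8)    # stand-in for the P2P recv
Variable._execution_engine.run_backward(
    tensors=(output_tensor,),
    grad_tensors=(recv_grad,),
    keep_graph=False,
    create_graph=False,
    inputs=tuple(),
    allow_unreachable=True,
    accumulate_grad=True,
)
print(weight.grad.shape)        # torch.Size([8, 8])
```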