Unverified commit 0da60e10 authored by Masaki Kozuki, committed by GitHub

T5 pipeline parallel changes (#1279)

* Free output tensor on each pipeline stage for smaller memory footprint

see:
https://github.com/NVIDIA/Megatron-LM/commit/057b086c689b164864455430c223ab52fd86bbcb

* ref: https://github.com/NVIDIA/Megatron-LM/commit/945ece943149b63511e9d0ec3df8effe7f3c13ff

* ref: https://github.com/NVIDIA/Megatron-LM/commit/9a8b89acd8f6ba096860170d0e30ddc0bc2bacd4

* remove position embedding group in destroy

* pass deallocate_pipeline_outputs to backward_step

* fix typo

* missing deallocate_pipeline_outputs

* fix typo: grad_ouptut -> grad_output

* update tests

* remove addressed TODO

* test with data parallel size of 2 if there are 8 or more GPUs
parent a47d1a76
......@@ -166,6 +166,15 @@ def initialize_model_parallel(
        # first and last stages).
        if len(ranks) > 1:
            embedding_ranks = [ranks[0], ranks[-1]]
            position_embedding_ranks = [ranks[0]]
            if pipeline_model_parallel_split_rank_ is not None:
                if ranks[pipeline_model_parallel_split_rank_] not in embedding_ranks:
                    embedding_ranks = [ranks[0], ranks[pipeline_model_parallel_split_rank_], ranks[-1]]
                if ranks[pipeline_model_parallel_split_rank_] not in position_embedding_ranks:
                    position_embedding_ranks = [ranks[0], ranks[pipeline_model_parallel_split_rank_]]
        else:
            embedding_ranks = ranks
            position_embedding_ranks = ranks
        group = torch.distributed.new_group(embedding_ranks)
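As an illustration of the grouping logic above (hypothetical values, not part of this commit): with a single pipeline group of four ranks and the T5 encoder/decoder split at rank 2, the embedding group spans the first, split, and last stages, while the position-embedding group spans only the first and split stages.

# Hypothetical example: one pipeline group [0, 1, 2, 3] with
# pipeline_model_parallel_split_rank_ = 2 (the encoder/decoder boundary).
ranks = [0, 1, 2, 3]
split_rank = 2
embedding_ranks = [ranks[0], ranks[split_rank], ranks[-1]]   # -> [0, 2, 3]
position_embedding_ranks = [ranks[0], ranks[split_rank]]     # -> [0, 2]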
......@@ -428,6 +437,8 @@ def destroy_model_parallel():
    _DATA_PARALLEL_GROUP = None
    global _EMBEDDING_GROUP
    _EMBEDDING_GROUP = None
    global _POSITION_EMBEDDING_GROUP
    _POSITION_EMBEDDING_GROUP = None
    global _VIRTUAL_PIPELINE_MODEL_PARALLEL_RANK
    _VIRTUAL_PIPELINE_MODEL_PARALLEL_RANK = None
    global _VIRTUAL_PIPELINE_MODEL_PARALLEL_WORLD_SIZE
......
# NOTE(mkozuki): For simplicity, tentatively `timers` related operations are commented out.
from typing import Any, Callable, Dict, List, Tuple, Union, Optional, Sequence
import torch
from torch.autograd.variable import Variable
from apex.transformer import parallel_state
from apex.transformer.enums import ModelType
......@@ -145,6 +145,54 @@ def _get_params_for_weight_decay_optimization(
    return weight_decay_params, no_weight_decay_params


def free_output_tensor(
    output_tensors: Optional[Union[torch.Tensor, Sequence[torch.Tensor]]],
    deallocate_pipeline_outputs: bool = False,
) -> None:
    """Pseudo-free the output tensor's `.data` field.

    This method should be called right after the output tensor has been sent to the next
    pipeline stage. At this point, the output tensor is only needed for its `.grad_fn` field,
    not its `.data`.
    """
    if not deallocate_pipeline_outputs:
        return
    if output_tensors is None:
        return
    if isinstance(output_tensors, torch.Tensor):
        output_tensors = [output_tensors]
    for output_tensor in output_tensors:
        output_tensor.data = torch.cuda.FloatTensor([0])


def custom_backward(output: torch.Tensor, grad_output: Optional[torch.Tensor]) -> None:
    """Directly call the C++ autograd engine.

    To make the `free_output_tensor` optimization work, the C++ autograd engine must be called
    directly, bypassing PyTorch's `torch.autograd.backward`. PyTorch's `backward` checks that the
    output and grad have the same shape, while the C++ engine does not.
    """
    assert output.numel() == 1, "output should be pseudo-freed in schedule, to optimize memory consumption"
    assert isinstance(output, torch.Tensor), "output == {}.".format(type(output).__name__)
    assert isinstance(grad_output, (torch.Tensor, type(None))), "grad_output == {}.".format(type(grad_output).__name__)

    # Handle scalar output
    if grad_output is None:
        assert output.numel() == 1, "Implicit grad requires scalar output."
        grad_output = torch.ones_like(output, memory_format=torch.preserve_format)

    # Call C++ engine [see torch/csrc/autograd/python_engine.cpp]
    Variable._execution_engine.run_backward(
        tensors=(output,),
        grad_tensors=(grad_output,),
        keep_graph=False,
        create_graph=False,
        inputs=(),
        allow_unreachable=True,
        accumulate_grad=True,
    )
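A minimal usage sketch of the two helpers above (hypothetical module and shapes, assuming a CUDA device; not part of this commit): the stage output is pseudo-freed right after it has been sent downstream, and the eventual backward pass goes through `custom_backward`, because `torch.autograd.backward` would reject the shape mismatch between the freed output and its gradient.

# Hypothetical sketch, not part of the diff.
import torch
from apex.transformer.pipeline_parallel.schedules.common import custom_backward, free_output_tensor

linear = torch.nn.Linear(1024, 1024).cuda()
inp = torch.randn(8, 1024, device="cuda", requires_grad=True)
output_tensor = linear(inp)          # stage output; only its .grad_fn is needed after the send
# ... output_tensor is sent to the next pipeline stage here ...
free_output_tensor(output_tensor, deallocate_pipeline_outputs=True)
# output_tensor.data is now a one-element placeholder; the activation storage is released.

# Later, the gradient w.r.t. the original output arrives from the next stage.
output_tensor_grad = torch.randn(8, 1024, device="cuda")
custom_backward(output_tensor, output_tensor_grad)   # bypasses the Python-side shape check
assert inp.grad is not None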
def forward_step(
    forward_step_func: FwdStepFunc,
    batch: Optional[Batch],
......@@ -214,6 +262,7 @@ def backward_step(
    model_type: ModelType,
    *,
    grad_scaler: Optional[torch.cuda.amp.GradScaler] = None,
    deallocate_pipeline_outputs: bool = False,
) -> Union[None, torch.Tensor, Sequence[torch.Tensor]]:
    """Backward step through passed-in output tensor.
......@@ -250,7 +299,10 @@ def backward_step(
    # Backward pass.
    if grad_scaler is not None and output_tensor_grad[0] is None:
        output_tensor[0] = grad_scaler.scale(output_tensor[0])
    if deallocate_pipeline_outputs:
        custom_backward(output_tensor[0], output_tensor_grad[0])
    else:
        torch.autograd.backward(output_tensor[0], grad_tensors=output_tensor_grad[0])

    # Collect the grad of the input_tensor.
    input_tensor_grad = [None]
......
......@@ -8,6 +8,7 @@ from apex.transformer.pipeline_parallel.schedules.common import Batch
from apex.transformer.pipeline_parallel.schedules.common import FwdStepFunc
from apex.transformer.pipeline_parallel.schedules.common import backward_step
from apex.transformer.pipeline_parallel.schedules.common import forward_step
from apex.transformer.pipeline_parallel.schedules.common import free_output_tensor
from apex.transformer.pipeline_parallel.utils import get_kth_microbatch
from apex.transformer.pipeline_parallel.utils import get_num_microbatches
from apex.transformer.pipeline_parallel.utils import get_model_type
......@@ -31,6 +32,7 @@ def _forward_backward_pipelining_with_interleaving(
    dtype: Optional[torch.dtype] = None,
    grad_scaler: Optional[torch.cuda.amp.GradScaler] = None,
    disable_autocast: bool = False,
    deallocate_pipeline_outputs: bool = False,
) -> List[Union[torch.Tensor, Sequence[torch.Tensor]]]:
    """Run interleaved 1F1B schedule with communication between pipeline stages as needed.
......@@ -164,7 +166,7 @@ def _forward_backward_pipelining_with_interleaving(
        input_tensor = input_tensors[model_chunk_id].pop(0)
        output_tensor = output_tensors[model_chunk_id].pop(0)
        output_tensor_grad = output_tensor_grads[model_chunk_id].pop(0)
        input_tensor_grad = backward_step(
            input_tensor, output_tensor, output_tensor_grad,
            model_type=model_type, grad_scaler=grad_scaler,
            deallocate_pipeline_outputs=deallocate_pipeline_outputs,
        )
        return input_tensor_grad
......@@ -217,6 +219,7 @@ def _forward_backward_pipelining_with_interleaving(
_logger.debug("send fwd and receive fwd")
input_tensor = p2p_communication.send_forward_recv_forward(
output_tensor, recv_prev=recv_prev, tensor_shape=tensor_shape, dtype=dtype)
free_output_tensor(output_tensor, deallocate_pipeline_outputs)
input_tensors[next_forward_model_chunk_id].append(input_tensor)
###################################################################################################################
......@@ -293,6 +296,7 @@ def _forward_backward_pipelining_with_interleaving(
tensor_shape=tensor_shape,
dtype=dtype,
)
free_output_tensor(output_tensor, deallocate_pipeline_outputs)
# Put input_tensor and output_tensor_grad in data structures in the
# right location.
......
......@@ -11,8 +11,9 @@ from apex.transformer.pipeline_parallel.utils import get_num_microbatches
from apex.transformer.pipeline_parallel.utils import get_model_type
from apex.transformer.pipeline_parallel.schedules.common import Batch
from apex.transformer.pipeline_parallel.schedules.common import FwdStepFunc
from apex.transformer.pipeline_parallel.schedules.common import backward_step
from apex.transformer.pipeline_parallel.schedules.common import forward_step
from apex.transformer.pipeline_parallel.schedules.common import free_output_tensor
from apex.transformer.log_util import get_transformer_logger
......@@ -162,6 +163,7 @@ def forward_backward_pipelining_without_interleaving(
    dtype: Optional[torch.dtype] = None,
    grad_scaler: Optional[torch.cuda.amp.GradScaler] = None,
    disable_autocast: bool = False,
    deallocate_pipeline_outputs: bool = False,
) -> List[Union[torch.Tensor, Sequence[torch.Tensor]]]:
    """Run non-interleaved 1F1B schedule, with communication between pipeline stages.
......@@ -246,6 +248,7 @@ def forward_backward_pipelining_without_interleaving(
        if not forward_only:
            input_tensors.append(input_tensor)
            output_tensors.append(output_tensor)
            free_output_tensor(output_tensor, deallocate_pipeline_outputs)

    # Before running 1F1B, need to receive first forward tensor.
    # If all microbatches are run in warmup / cooldown phase, then no need to
......@@ -287,12 +290,20 @@ def forward_backward_pipelining_without_interleaving(
        # Add input_tensor and output_tensor to end of list.
        input_tensors.append(input_tensor)
        output_tensors.append(output_tensor)
        free_output_tensor(output_tensor, deallocate_pipeline_outputs)

        # Pop input_tensor and output_tensor from the start of the list for the backward pass.
        input_tensor = input_tensors.pop(0)
        output_tensor = output_tensors.pop(0)

        input_tensor_grad = backward_step(
            input_tensor,
            output_tensor,
            output_tensor_grad,
            model_type=model_type,
            grad_scaler=grad_scaler,
            deallocate_pipeline_outputs=deallocate_pipeline_outputs,
        )

        if last_iteration:
            input_tensor = None
......@@ -314,7 +325,14 @@ def forward_backward_pipelining_without_interleaving(
_logger.debug("receive bwd")
output_tensor_grad = recv_backward(tensor_shapes=send_tensor_shapes, dtype=dtype)
input_tensor_grad = backward_step(input_tensor, output_tensor, output_tensor_grad, model_type=model_type, grad_scaler=grad_scaler)
input_tensor_grad = backward_step(
input_tensor,
output_tensor,
output_tensor_grad,
model_type=model_type,
grad_scaler=grad_scaler,
deallocate_pipeline_outputs=deallocate_pipeline_outputs,
)
_logger.debug("send bwd")
send_backward(input_tensor_grad, tensor_shapes=recv_tensor_shapes, dtype=dtype)
......
import itertools
from typing import Optional
import warnings
......@@ -46,8 +47,15 @@ def forward_backward_func_template(
    forward_only: bool,
    dtype: torch.dtype,
    grad_scaler: Optional[GradScaler],
    deallocate_pipeline_outputs: bool,
    data_parallel_size: int,
) -> None:
    print_separator(
        f"{name}, {dtype}, use grad_scaler: {grad_scaler is not None}, "
        f"deallocate_pipeline_outputs: {deallocate_pipeline_outputs}, "
        f"pipeline parallel size: {pipeline_model_parallel_size}, "
        f"data parallel size: {data_parallel_size}"
    )
    virtual_pipeline_model_parallel_size = 2 if name == "interleaving" else None
    if name == "no_pipelining":
        # note (mkozuki): `forward_backward_no_pipelining` is **NOT** compatible with
......@@ -66,13 +74,13 @@ def forward_backward_func_template(
    # In megatron, `args.virtual_pipeline_model_parallel_size` is computed in megatron/arguments.py and
    # used ubiquitously but this test uses custom model so it's safe to abuse.
    parallel_state.initialize_model_parallel(
        data_parallel_size, pipeline_model_parallel_size, virtual_pipeline_model_parallel_size)
    _reconfigure_microbatch_calculator(
        args.rank,
        args.rampup_batch_size,
        args.global_batch_size,
        args.micro_batch_size,
        parallel_state.get_data_parallel_world_size(),
    )
    if virtual_pipeline_model_parallel_size is not None:
        # Check the experimental warning message
......@@ -96,7 +104,9 @@ def forward_backward_func_template(
    update_num_microbatches(0)
    forward_backward_func(
        fwd_step_func, batch, model, forward_only=forward_only, tensor_shape=tensor_shape,
        dtype=dtype, grad_scaler=grad_scaler, deallocate_pipeline_outputs=deallocate_pipeline_outputs,
    )

    if not forward_only:
        for m in model:
......@@ -121,44 +131,46 @@ if __name__ == "__main__":
    micro_batch_size = args.micro_batch_size
    dtypes = [torch.float32] + _get_autocast_dtypes()

    for forward_only, name, dtype, deallocate_pipeline_outputs in itertools.product(
        (True, False),
        fwd_bwd_functions.keys(),
        dtypes,
        (True, False),
    ):
        forward_backward_func = fwd_bwd_functions[name]
        if name == "interleaving" and torch.cuda.device_count() <= 2:
            warnings.warn(
                f"There are only {torch.cuda.device_count()} GPUs, therefore skipping {name} "
                "because the interleaved pipeline-parallel schedule requires more than 2 GPUs."
            )
            continue
        grad_scaler = torch.cuda.amp.GradScaler(init_scale=4.0) if dtype == torch.half else None
        n_tests += 1
        # Use data parallel size 2 when 8 or more GPUs are available, otherwise 1.
        data_parallel_size = 2 if world_size >= 8 and world_size % 2 == 0 else 1
        pipeline_model_parallel_size = world_size if world_size < 8 else world_size // 2
        try:
            forward_backward_func_template(
                args,
                name,
                forward_backward_func,
                pipeline_model_parallel_size,
                forward_only,
                dtype=dtype,
                grad_scaler=grad_scaler,
                deallocate_pipeline_outputs=deallocate_pipeline_outputs,
                data_parallel_size=data_parallel_size,
            )
        except Exception as e:
            failures.append(
                f"\t# {name} failed with pipeline size: {pipeline_model_parallel_size} "
                f"and forward_only: {forward_only}\n"
                f"pipeline rank: {parallel_state.get_pipeline_model_parallel_rank()}, "
                f"virtual pipeline rank: {parallel_state.get_virtual_pipeline_model_parallel_rank()}\n"
                f"{str(e)}"
            )
            print(failures[-1])
        finally:
            parallel_state.destroy_model_parallel()
    print_separator("TEST RESULT")
    if failures:
        torch.distributed.barrier()
......
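For reference, a quick check of the parallel sizing logic used in the updated test harness above (a standalone sketch, not part of this commit): with 8 or more GPUs the world is split into data parallel size 2 and pipeline parallel size world_size // 2; otherwise the whole world is used for pipeline parallelism.

# Standalone sketch of the test harness's data/pipeline parallel sizing.
def parallel_sizes(world_size: int):
    data_parallel_size = 2 if world_size >= 8 and world_size % 2 == 0 else 1
    pipeline_model_parallel_size = world_size if world_size < 8 else world_size // 2
    return data_parallel_size, pipeline_model_parallel_size

assert parallel_sizes(2) == (1, 2)
assert parallel_sizes(4) == (1, 4)
assert parallel_sizes(8) == (2, 4)
assert parallel_sizes(16) == (2, 8)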