"...text-generation-inference.git" did not exist on "55bd4fed7da83a566dca08b0bb29dbc5929a90eb"
Unverified commit 0da60e10 authored by Masaki Kozuki, committed by GitHub

T5 pipeline parallel changes (#1279)

* Free output tensor on each pipeline stage for smaller memory footprint

see:
https://github.com/NVIDIA/Megatron-LM/commit/057b086c689b164864455430c223ab52fd86bbcb

* ref: https://github.com/NVIDIA/Megatron-LM/commit/945ece943149b63511e9d0ec3df8effe7f3c13ff

* ref: https://github.com/NVIDIA/Megatron-LM/commit/9a8b89acd8f6ba096860170d0e30ddc0bc2bacd4

* remove position embedding group in destroy

* pass deallocate_pipeline_outputs to backward_step

* fix typo

* missing deallocate_pipeline_outputs

* fix typo: grad_ouptut -> grad_output

* update tests

* remove addressed todo

* test with data parallel size of 2 if there are 8 or more GPUs
parent a47d1a76
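The headline change ("Free output tensor on each pipeline stage") boils down to two moves: swap a sent activation's `.data` for a one-element dummy while keeping its `.grad_fn`, and later run that stage's backward through the autograd engine directly. A minimal single-process sketch of the idea follows; it is illustrative only (two `nn.Linear` modules stand in for pipeline stages, a detached clone stands in for the P2P send/recv, shapes are made up), with the engine call mirroring the `custom_backward` helper added in this commit.

# Single-process toy of the pseudo-free trick (illustrative only; no apex or
# torch.distributed involved).
import torch
from torch.autograd.variable import Variable

stage0 = torch.nn.Linear(16, 16)
stage1 = torch.nn.Linear(16, 16)

x = torch.randn(2, 16)
out0 = stage0(x)                                   # stage 0 forward
recv1 = out0.detach().clone().requires_grad_()     # "send" the activation to stage 1
out0.data = torch.FloatTensor([0])                 # pseudo-free: storage dropped, grad_fn kept

loss = stage1(recv1).sum()                         # stage 1 forward + dummy loss
loss.backward()                                    # stage 1 backward
grad_back = recv1.grad                             # gradient "sent" back to stage 0

# Stage 0 backward bypasses torch.autograd.backward's shape check, because out0
# is now a 1-element dummy while grad_back has the real activation shape.
Variable._execution_engine.run_backward(
    tensors=(out0,),
    grad_tensors=(grad_back,),
    keep_graph=False,
    create_graph=False,
    inputs=(),
    allow_unreachable=True,
    accumulate_grad=True,
)
print(stage0.weight.grad is not None)  # True: gradients reached stage 0's parameters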
@@ -166,6 +166,15 @@ def initialize_model_parallel(
         # first and last stages).
         if len(ranks) > 1:
             embedding_ranks = [ranks[0], ranks[-1]]
+            position_embedding_ranks = [ranks[0]]
+            if (
+                pipeline_model_parallel_split_rank_ is not None and
+                ranks[pipeline_model_parallel_split_rank_] not in embedding_ranks
+            ):
+                if ranks[pipeline_model_parallel_split_rank_] not in embedding_ranks:
+                    embedding_ranks = [ranks[0], ranks[pipeline_model_parallel_split_rank_], ranks[-1]]
+                if ranks[pipeline_model_parallel_split_rank_] not in position_embedding_ranks:
+                    position_embedding_ranks = [ranks[0], ranks[pipeline_model_parallel_split_rank_]]
         else:
             embedding_ranks = ranks
         group = torch.distributed.new_group(embedding_ranks)
@@ -428,6 +437,8 @@ def destroy_model_parallel():
     _DATA_PARALLEL_GROUP = None
     global _EMBEDDING_GROUP
     _EMBEDDING_GROUP = None
+    global _POSITION_EMBEDDING_GROUP
+    _POSITION_EMBEDDING_GROUP = None
     global _VIRTUAL_PIPELINE_MODEL_PARALLEL_RANK
     _VIRTUAL_PIPELINE_MODEL_PARALLEL_RANK = None
     global _VIRTUAL_PIPELINE_MODEL_PARALLEL_WORLD_SIZE
......
 # NOTE(mkozuki): For simplicity, tentatively `timers` related operations are commented out.
 from typing import Any, Callable, Dict, List, Tuple, Union, Optional, Sequence
 import torch
+from torch.autograd.variable import Variable
 from apex.transformer import parallel_state
 from apex.transformer.enums import ModelType
@@ -145,6 +145,54 @@ def _get_params_for_weight_decay_optimization(
     return weight_decay_params, no_weight_decay_params
 
 
+def free_output_tensor(
+    output_tensors: Optional[Union[torch.Tensor, Sequence[torch.Tensor]]],
+    deallocate_pipeline_outputs: bool = False
+) -> None:
+    """Pseudo-free the output tensor's `.data` field.
+
+    This method should be called right after the output tensor has been sent to the next
+    pipeline stage. At this point, the output tensor is only useful for its `.grad_fn` field,
+    and not its `.data`.
+    """
+    if not deallocate_pipeline_outputs:
+        return
+    if output_tensors is None:
+        return
+    if isinstance(output_tensors, torch.Tensor):
+        output_tensors = [output_tensors]
+    for output_tensor in output_tensors:
+        output_tensor.data = torch.cuda.FloatTensor([0])
+
+
+def custom_backward(output: torch.Tensor, grad_output: Optional[torch.Tensor]) -> None:
+    """Directly call C++ autograd engine.
+
+    To make the `free_output_tensor` optimization work, the C++ autograd engine must be called
+    directly, bypassing PyTorch's `torch.autograd.backward`. PyTorch's `backward` checks that the
+    output and grad have the same shape, while C++ `backward` does not.
+    """
+    assert output.numel() == 1, "output should be pseudo-freed in schedule, to optimize memory consumption"
+    assert isinstance(output, torch.Tensor), "output == {}.".format(type(output).__name__)
+    assert isinstance(grad_output, (torch.Tensor, type(None))), "grad_output == {}.".format(type(grad_output).__name__)
+
+    # Handle scalar output
+    if grad_output is None:
+        assert output.numel() == 1, "Implicit grad requires scalar output."
+        grad_output = torch.ones_like(output, memory_format=torch.preserve_format)
+
+    # Call C++ engine [ see torch/csrc/autograd/python_engine.cpp ]
+    Variable._execution_engine.run_backward(
+        tensors=(output,),
+        grad_tensors=(grad_output,),
+        keep_graph=False,
+        create_graph=False,
+        inputs=(),
+        allow_unreachable=True,
+        accumulate_grad=True,
+    )
+
+
 def forward_step(
     forward_step_func: FwdStepFunc,
     batch: Optional[Batch],
@@ -214,6 +262,7 @@ def backward_step(
     model_type: ModelType,
     *,
     grad_scaler: Optional[torch.cuda.amp.GradScaler] = None,
+    deallocate_pipeline_outputs: bool = False,
 ) -> Union[None, torch.Tensor, Sequence[torch.Tensor]]:
     """Backward step through passed-in output tensor.
@@ -250,7 +299,10 @@ def backward_step(
     # Backward pass.
     if grad_scaler is not None and output_tensor_grad[0] is None:
         output_tensor[0] = grad_scaler.scale(output_tensor[0])
-    torch.autograd.backward(output_tensor[0], grad_tensors=output_tensor_grad[0])
+    if deallocate_pipeline_outputs:
+        custom_backward(output_tensor[0], output_tensor_grad[0])
+    else:
+        torch.autograd.backward(output_tensor[0], grad_tensors=output_tensor_grad[0])
 
     # Collect the grad of the input_tensor.
     input_tensor_grad = [None]
......
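A quick standalone check of the claim in `custom_backward`'s docstring, namely that the Python-level `torch.autograd.backward` rejects the pseudo-freed output while the direct engine call validates the gradient against the shape recorded in the graph. This is a toy (CPU tensors, arbitrary shapes, nothing apex-specific); the matmul is chosen because its backward saves the inputs, not the output, so the pseudo-free cannot corrupt it.

import torch
from torch.autograd.variable import Variable

x = torch.randn(3, 5, requires_grad=True)
w = torch.randn(5, 2, requires_grad=True)
out = x @ w                            # MmBackward saves x and w, not `out`
grad_out = torch.ones(3, 2)            # the gradient "received" from the next stage
out.data = torch.FloatTensor([0])      # pseudo-free: `out` is now a 1-element dummy

try:
    torch.autograd.backward(out, grad_tensors=grad_out)
except RuntimeError as err:
    print("python-level backward refuses:", err)   # shape (1,) vs. (3, 2) mismatch

# The direct engine call only checks grad_out against the shape recorded in out.grad_fn.
Variable._execution_engine.run_backward(
    tensors=(out,),
    grad_tensors=(grad_out,),
    keep_graph=False,
    create_graph=False,
    inputs=(),
    allow_unreachable=True,
    accumulate_grad=True,
)
print(x.grad.shape, w.grad.shape)      # torch.Size([3, 5]) torch.Size([5, 2])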
@@ -8,6 +8,7 @@ from apex.transformer.pipeline_parallel.schedules.common import Batch
 from apex.transformer.pipeline_parallel.schedules.common import FwdStepFunc
 from apex.transformer.pipeline_parallel.schedules.common import backward_step
 from apex.transformer.pipeline_parallel.schedules.common import forward_step
+from apex.transformer.pipeline_parallel.schedules.common import free_output_tensor
 from apex.transformer.pipeline_parallel.utils import get_kth_microbatch
 from apex.transformer.pipeline_parallel.utils import get_num_microbatches
 from apex.transformer.pipeline_parallel.utils import get_model_type
@@ -31,6 +32,7 @@ def _forward_backward_pipelining_with_interleaving(
     dtype: Optional[torch.dtype] = None,
     grad_scaler: Optional[torch.cuda.amp.GradScaler] = None,
     disable_autocast: bool = False,
+    deallocate_pipeline_outputs: bool = False,
 ) -> List[Union[torch.Tensor, Sequence[torch.Tensor]]]:
     """Run interleaved 1F1B schedule with communication between pipeline stages as needed.
@@ -164,7 +166,7 @@ def _forward_backward_pipelining_with_interleaving(
         input_tensor = input_tensors[model_chunk_id].pop(0)
         output_tensor = output_tensors[model_chunk_id].pop(0)
         output_tensor_grad = output_tensor_grads[model_chunk_id].pop(0)
-        input_tensor_grad = backward_step(input_tensor, output_tensor, output_tensor_grad, model_type=model_type, grad_scaler=grad_scaler)
+        input_tensor_grad = backward_step(input_tensor, output_tensor, output_tensor_grad, model_type=model_type, grad_scaler=grad_scaler, deallocate_pipeline_outputs=deallocate_pipeline_outputs)
 
         return input_tensor_grad
@@ -217,6 +219,7 @@ def _forward_backward_pipelining_with_interleaving(
             _logger.debug("send fwd and receive fwd")
             input_tensor = p2p_communication.send_forward_recv_forward(
                 output_tensor, recv_prev=recv_prev, tensor_shape=tensor_shape, dtype=dtype)
+            free_output_tensor(output_tensor, deallocate_pipeline_outputs)
         input_tensors[next_forward_model_chunk_id].append(input_tensor)
@@ -293,6 +296,7 @@ def _forward_backward_pipelining_with_interleaving(
                 tensor_shape=tensor_shape,
                 dtype=dtype,
             )
+            free_output_tensor(output_tensor, deallocate_pipeline_outputs)
 
             # Put input_tensor and output_tensor_grad in data structures in the
             # right location.
......
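In the interleaved schedule above, every microbatch's output stays in `output_tensors[model_chunk_id]` until its backward step; calling `free_output_tensor` right after each send is what keeps those retained entries cheap. A rough single-process illustration (microbatch count and shapes are invented, no model chunks or communication involved):

import torch

w = torch.randn(512, 512, requires_grad=True)
output_tensors = []                       # stand-in for the schedule's per-chunk bookkeeping

for _ in range(4):                        # pretend: 4 warmup microbatches
    x = torch.randn(8, 512)
    out = x @ w                           # this stage's output for one microbatch
    # ... the real schedule would send `out` to the next stage here ...
    output_tensors.append(out)            # kept so backward_step can reach out.grad_fn later
    out.data = torch.FloatTensor([0])     # pseudo-free right after the "send"

# Each retained entry now holds a single float instead of an 8x512 activation,
# while its autograd graph is still reachable for the later backward step.
print([tuple(t.shape) for t in output_tensors])            # [(1,), (1,), (1,), (1,)]
print(all(t.grad_fn is not None for t in output_tensors))  # True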
@@ -11,8 +11,9 @@ from apex.transformer.pipeline_parallel.utils import get_num_microbatches
 from apex.transformer.pipeline_parallel.utils import get_model_type
 from apex.transformer.pipeline_parallel.schedules.common import Batch
 from apex.transformer.pipeline_parallel.schedules.common import FwdStepFunc
-from apex.transformer.pipeline_parallel.schedules.common import forward_step
 from apex.transformer.pipeline_parallel.schedules.common import backward_step
+from apex.transformer.pipeline_parallel.schedules.common import forward_step
+from apex.transformer.pipeline_parallel.schedules.common import free_output_tensor
 from apex.transformer.log_util import get_transformer_logger
@@ -162,6 +163,7 @@ def forward_backward_pipelining_without_interleaving(
     dtype: Optional[torch.dtype] = None,
     grad_scaler: Optional[torch.cuda.amp.GradScaler] = None,
     disable_autocast: bool = False,
+    deallocate_pipeline_outputs: bool = False,
 ) -> List[Union[torch.Tensor, Sequence[torch.Tensor]]]:
     """Run non-interleaved 1F1B schedule, with communication between pipeline stages.
@@ -246,6 +248,7 @@ def forward_backward_pipelining_without_interleaving(
         if not forward_only:
             input_tensors.append(input_tensor)
             output_tensors.append(output_tensor)
+            free_output_tensor(output_tensor, deallocate_pipeline_outputs)
 
     # Before running 1F1B, need to receive first forward tensor.
     # If all microbatches are run in warmup / cooldown phase, then no need to
@@ -287,12 +290,20 @@ def forward_backward_pipelining_without_interleaving(
             # Add input_tensor and output_tensor to end of list.
             input_tensors.append(input_tensor)
             output_tensors.append(output_tensor)
+            free_output_tensor(output_tensor, deallocate_pipeline_outputs)
 
             # Pop input_tensor and output_tensor from the start of the list for the backward pass.
             input_tensor = input_tensors.pop(0)
             output_tensor = output_tensors.pop(0)
-            input_tensor_grad = backward_step(input_tensor, output_tensor, output_tensor_grad, model_type=model_type, grad_scaler=grad_scaler)
+            input_tensor_grad = backward_step(
+                input_tensor,
+                output_tensor,
+                output_tensor_grad,
+                model_type=model_type,
+                grad_scaler=grad_scaler,
+                deallocate_pipeline_outputs=deallocate_pipeline_outputs,
+            )
 
             if last_iteration:
                 input_tensor = None
@@ -314,7 +325,14 @@ def forward_backward_pipelining_without_interleaving(
             _logger.debug("receive bwd")
             output_tensor_grad = recv_backward(tensor_shapes=send_tensor_shapes, dtype=dtype)
-            input_tensor_grad = backward_step(
+            input_tensor_grad = backward_step(
+                input_tensor,
+                output_tensor,
+                output_tensor_grad,
+                model_type=model_type,
+                grad_scaler=grad_scaler,
+                deallocate_pipeline_outputs=deallocate_pipeline_outputs,
+            )
             _logger.debug("send bwd")
             send_backward(input_tensor_grad, tensor_shapes=recv_tensor_shapes, dtype=dtype)
......
+import itertools
 from typing import Optional
 import warnings
@@ -46,8 +47,15 @@ def forward_backward_func_template(
     forward_only: bool,
     dtype: torch.dtype,
     grad_scaler: Optional[GradScaler],
+    deallocate_pipeline_outputs: bool,
+    data_parallel_size: int,
 ) -> None:
-    print_separator(f"name: {name}, dtype: {dtype}, use grad_scaler: {grad_scaler is not None}, pipeline model parallel size: {pipeline_model_parallel_size}")
+    print_separator(
+        f"{name}, {dtype}, use grad_scaler: {grad_scaler is not None}, "
+        f"deallocate_pipeline_outputs: {deallocate_pipeline_outputs}, "
+        f"pipeline parallel size: {pipeline_model_parallel_size}, "
+        f"data parallel size: {data_parallel_size}"
+    )
     virtual_pipeline_model_parallel_size = 2 if name == "interleaving" else None
     if name == "no_pipelining":
         # note (mkozuki): `forward_backward_no_pipelining` is **NOT** compatible with
@@ -66,13 +74,13 @@ def forward_backward_func_template(
     # In megatron, `args.virtual_pipeline_model_parallel_size` is computed in megatron/arguments.py and
     # used ubiquitously but this test uses custom model so it's safe to abuse.
     parallel_state.initialize_model_parallel(
-        1, pipeline_model_parallel_size, virtual_pipeline_model_parallel_size)
+        data_parallel_size, pipeline_model_parallel_size, virtual_pipeline_model_parallel_size)
     _reconfigure_microbatch_calculator(
         args.rank,
         args.rampup_batch_size,
         args.global_batch_size,
         args.micro_batch_size,
-        1,  # args.data_parallel_size,
+        parallel_state.get_data_parallel_world_size(),
     )
     if virtual_pipeline_model_parallel_size is not None:
         # Check the experimental warning message
@@ -96,7 +104,9 @@ def forward_backward_func_template(
     update_num_microbatches(0)
     forward_backward_func(
-        fwd_step_func, batch, model, forward_only=forward_only, tensor_shape=tensor_shape, dtype=dtype, grad_scaler=grad_scaler)
+        fwd_step_func, batch, model, forward_only=forward_only, tensor_shape=tensor_shape,
+        dtype=dtype, grad_scaler=grad_scaler, deallocate_pipeline_outputs=deallocate_pipeline_outputs,
+    )
 
     if not forward_only:
         for m in model:
@@ -121,44 +131,46 @@ if __name__ == "__main__":
     micro_batch_size = args.micro_batch_size
     dtypes = [torch.float32] + _get_autocast_dtypes()
 
-    for forward_only in (True, False):
-        for name, forward_backward_func in fwd_bwd_functions.items():
-            if name == "interleaving" and torch.cuda.device_count() <= 2:
-                warnings.warn(
-                    f"There's only {torch.cuda.device_count()} gpus therefore skipping {name} "
-                    "while interleaved scheduled pipeline parallel requires >2 gpus."
-                )
-                continue
-            for dtype in dtypes:
-                if torch.distributed.get_rank() == 0:
-                    _logger.info(f"forward_only: {forward_only}, name: {name}, dtype: {dtype}")
-                grad_scaler = torch.cuda.amp.GradScaler(init_scale=4.0) if dtype == torch.half else None
-                n_tests += 1
-                # TODO (mkozuki): Test with data parallel size > 1.
-                pipeline_model_parallel_size = world_size
-                try:
-                    forward_backward_func_template(
-                        args,
-                        name,
-                        forward_backward_func,
-                        pipeline_model_parallel_size,
-                        forward_only,
-                        dtype=dtype,
-                        grad_scaler=grad_scaler,
-                    )
-                except Exception as e:
-                    failures.append(
-                        f"\t# {name} failed with pipeline size: {pipeline_model_parallel_size} "
-                        f"and forward_only: {forward_only}\n"
-                        f"pipeline rank: {parallel_state.get_pipeline_model_parallel_rank()}, "
-                        f"virtual pipeline rank: {parallel_state.get_virtual_pipeline_model_parallel_rank()}\n"
-                        f"{str(e)}"
-                    )
-                    print(failures[-1])
-                finally:
-                    parallel_state.destroy_model_parallel()
-            else:
-                print_separator(f"{name} works")
+    for forward_only, name, dtype, deallocate_pipeline_outputs in itertools.product(
+        (True, False),
+        fwd_bwd_functions.keys(),
+        dtypes,
+        (True, False),
+    ):
+        forward_backward_func = fwd_bwd_functions[name]
+        if name == "interleaving" and torch.cuda.device_count() <= 2:
+            warnings.warn(
+                f"There's only {torch.cuda.device_count()} gpus therefore skipping {name} "
+                "while interleaved scheduled pipeline parallel requires >2 gpus."
+            )
+            continue
+        grad_scaler = torch.cuda.amp.GradScaler(init_scale=4.0) if dtype == torch.half else None
+        n_tests += 1
+        data_parallel_size = 2 if world_size >= 8 and world_size % 2 == 0 else 1
+        pipeline_model_parallel_size = world_size if world_size < 8 else world_size // 2
+        try:
+            forward_backward_func_template(
+                args,
+                name,
+                forward_backward_func,
+                pipeline_model_parallel_size,
+                forward_only,
+                dtype=dtype,
+                grad_scaler=grad_scaler,
+                deallocate_pipeline_outputs=deallocate_pipeline_outputs,
+                data_parallel_size=data_parallel_size,
+            )
+        except Exception as e:
+            failures.append(
+                f"\t# {name} failed with pipeline size: {pipeline_model_parallel_size} "
+                f"and forward_only: {forward_only}\n"
+                f"pipeline rank: {parallel_state.get_pipeline_model_parallel_rank()}, "
+                f"virtual pipeline rank: {parallel_state.get_virtual_pipeline_model_parallel_rank()}\n"
+                f"{str(e)}"
+            )
+            print(failures[-1])
+        finally:
+            parallel_state.destroy_model_parallel()
 
     print_separator("TEST RESULT")
     if failures:
         torch.distributed.barrier()
......
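The `__main__` rewrite above folds the nested `forward_only` / schedule / dtype loops into one `itertools.product` loop and adds `deallocate_pipeline_outputs` as a fourth axis. A small standalone check that the flattened loop visits the same combinations in the same order as the nested version (the dict keys and dtypes below are placeholders, not the real test objects):

import itertools

fwd_bwd_functions = {"no_pipelining": None, "without_interleaving": None, "interleaving": None}
dtypes = ["float32", "bfloat16", "float16"]

nested = [
    (forward_only, name, dtype, deallocate)
    for forward_only in (True, False)
    for name in fwd_bwd_functions
    for dtype in dtypes
    for deallocate in (True, False)
]
flat = list(itertools.product((True, False), fwd_bwd_functions.keys(), dtypes, (True, False)))

assert nested == flat      # same combinations, same iteration order
print(len(flat))           # 2 * 3 * 3 * 2 = 36 combinations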