"docs/source/nonpytorchcuda.mdx" did not exist on "d76b6ca91b827b5c522bb794d96628d290ee29f6"
Unverified Commit 3ff1a10f authored by Masaki Kozuki, committed by GitHub

[transformer] Port Sequence Parallelism (takeover of #1396) (#1400)

* it looks possible to remove this file

* add communication collectives

* update Column|RowParallelLinear

* update checkpoint function

* update function name

* parity between public and private collectives

* row parallel linear

* column parallel linear

* sequence parallel: p2p comm

fix typo

* sequence parallel: pipeline parallel

* fix typo

* add layernorm with sequence_parallel_enabled attr

* class variable -> member variable

* fix col parallel test with sequence parallel

* Initial test of `forward_backward_pipelining_without_interleaving` with `model_type=ModelType.encoder_and_decoder`

* add cases pretending to test sequence_parallel

* Apply 2 suggestion(s) to 1 file(s)

* update sequence_parallel_enabled docstring

* update docstring: order of tensor dimensions, sequence_parallel_enabled behavior

* Divide sequence_length if sequence parallel

The tensor shape should be updated when sequence parallelism is enabled (see the sketch after this change list).

* cherry-pick https://github.com/NVIDIA/Megatron-LM/commit/8474e6e54fcb9dfa37aea039352f9fb485fb6f61

* type annotation

* Fix matmul call in RowParallelLinear

Pass `sequence_parallel_enabled=False` in the matmul call, as seen in
https://github.com/NVIDIA/Megatron-LM/blob/d898a8991d1a08d29074f87819d1bf41517e35f5/megatron/mpu/layers.py#L511-L514

* update rowparallellinear test

* fix "`loss_weight` is not defined" in test_layers

* @eqy's comment

* mixed fused layer norm

* fix typo

* misc

* test_layers cleanup

* Skip Bert/GPT script

Since these two models haven't been updated for sequence parallelism yet, e.g. the change of dimension order from (batch, sequence, feature) to (sequence, batch, feature) and the global argument variables

* debug part 1/N: comment out `x.retain_grad`

* debug part 2/N: [ColumnParallelLinear] comment out overriding of sequence_parallel_enabled

* debug 3/N: add pipeline test with parallel mlp

* Fix handling `self.input_tensor` and argument

* tp2pp4 ModelType.encoder_or_decoder is failing, which may be my fault because the backward pass complains that the output and grad_output shapes don't match

* revert debug 1/N

* defer tensor model parallel size > 1

* split tensor in sequence dim

* cosmetic

* cosmetic: remove archaic comment

* enable TP>1 for encoder_and_decoder as well

* set requires_grad=True always...

* Set `scatter_gather_tensors_in_pipeline` to :obj:`False`

so that NeMo Megatron's GPT works with sequence parallelism enabled.

* brush up comment of `requires_grad()`

According to @ptrblck, PyTorch DistributedDataParallel may hang
when some tensor (or parameter) doesn't require grad.
As I understand it, this forced `requires_grad` is a different issue.

* misc changes of scatter_gather_tensors_in_pipeline comment

* guard for torch_ucc

* cosmetic changes related to tests

* update command line arguments

* update TransformerLanguageModel

* rename

* move gpt to gpt.py

* update bert

* add all_gather for params in sequence parallel region

* misc. some diffs were lost during rebasing...

* updates for non sequence parallel execution

* gpt with sequence parallel

* Apply 2 suggestion(s) to 2 file(s)

* update tensor&pipeline parallel size

* why is `sequence_parallel_enabled` not supplied!? Did I mess up when rebasing?

* cosmetic fix

* correct key is sequence_parallel_enabled
parent 57f890a7
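As an illustrative aside on two recurring themes in the change list above, the (sequence, batch, feature) activation layout and the division of the sequence dimension under sequence parallelism, here is a minimal sketch. It is not the apex implementation; the tensor sizes and the `tp_world_size` value are made up for the example.

# Illustrative only -- not the apex implementation.
import torch

tp_world_size = 2                   # hypothetical tensor-model-parallel size
x = torch.randn(16, 32, 64)         # (batch, sequence, hidden)

# The ported layers expect (sequence, batch, hidden).
x = x.transpose(0, 1).contiguous()
assert x.shape == (32, 16, 64)

# With sequence parallelism enabled, each tensor-parallel rank holds only a
# slice of the sequence dimension, so communicated shapes divide seq by tp size.
local_chunks = torch.chunk(x, tp_world_size, dim=0)
assert local_chunks[0].shape == (32 // tp_world_size, 16, 64)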
@@ -8,10 +8,9 @@ import torch
 from torch.testing._internal import common_utils
 from torch.testing._internal import common_cuda
-logging.getLogger("torch").setLevel(logging.WARNING)
 from apex._autocast_utils import _get_autocast_dtypes
 from apex.transformer import parallel_state
+from apex.transformer.enums import ModelType
 from apex.transformer.pipeline_parallel import utils as pp_utils
 from apex.transformer.pipeline_parallel.schedules.common import (
     FwdStepFunc,
@@ -33,6 +32,8 @@ from apex.transformer.testing.distributed_test_base import HAS_TORCH_UCC
 from apex.transformer.testing.distributed_test_base import HAS_TORCH_UCC_COMPAT_NVIDIA_DRIVER
 from apex.transformer.testing import commons as testing_utils
+logging.getLogger("torch").setLevel(logging.WARNING)
 logging.getLogger("apex").setLevel(logging.WARNING)
 weight_coeff = 1024
@@ -300,5 +301,127 @@ class UccPipelineParallelForwardBackwardTest(UccDistributedTestBase, PipelinePar
     dtypes = (torch.float32,)
+
+
+# Sanity checking the functionality of `forward_backward_pipelining_without_interleaving` with
+# `model_type=ModelType.encoder_and_decoder`, which is used for pipeline training of transformer
+# models such as T5.
+@unittest.skipIf(torch.cuda.device_count() < 4, "Requires >= 4 GPUs")
+class NcclPipelineParallelWithToyParallelMLP(NcclDistributedTestBase):
+
+    GLOBAL_BATCH_SIZE = 16
+    MICRO_BATCH_SIZE = 2
+    HIDDEN_SIZE = 64
+    # TODO(mkozuki): Change `DECODER_SEQUENCE_LENGTH` to a value different from `ENCODER_SEQUENCE_LENGTH`.
+    # To test `forward_backward_pipelining_without_interleaving` with `model_type=ModelType.encoder_and_decoder`,
+    # `decoder_seq_length` is necessary and ideally should be different from `encoder_sequence_length`,
+    # but my laziness let me use the same value.
+    # Note that you may have to either update the `MyModel` definition or define another `MyModel`
+    # to support a different `DECODER_SEQUENCE_LENGTH`.
+    ENCODER_SEQUENCE_LENGTH = 32
+    DECODER_SEQUENCE_LENGTH = 32
+
+    @property
+    def world_size(self) -> int:
+        return min(torch.cuda.device_count(), 8)
+
+    # TODO(mkozuki): Add cases of async_comm=True.
+    # TODO(mkozuki): Add loss check.
+    # TODO(mkozuki): Call `build_model` with `model_type`.
+    # TODO(mkozuki): Set tensor model parallel size > 1 for encoder_and_decoder as well if there are enough GPUs,
+    # in order to let `sequence_parallel_enabled` have an effect on the tensor shape logic.
+    def _forward_backward_test_impl(
+        self,
+        *,
+        forward_only: bool,
+        sequence_parallel_enabled: bool,
+        model_type: ModelType,
+        dtype: torch.dtype = torch.float32,
+    ) -> None:
+        # N.B.(mkozuki): It might be better to set `tensor_model_parallel_size` to > 1
+        # if `self.world_size > 5`. Otherwise, `pipeline_model_parallel_split_rank`
+        # can be 1, which can be too far from a real use case.
+        tensor_model_parallel_size = 1 + int(self.world_size >= 4)
+        pipeline_model_parallel_world_size = self.world_size // tensor_model_parallel_size
+        if model_type == ModelType.encoder_and_decoder:
+            pipeline_model_parallel_split_rank = pipeline_model_parallel_world_size // 2
+        else:
+            pipeline_model_parallel_split_rank = None
+        parallel_state.initialize_model_parallel(
+            tensor_model_parallel_size_=tensor_model_parallel_size,
+            pipeline_model_parallel_size_=pipeline_model_parallel_world_size,
+            virtual_pipeline_model_parallel_size_=None,
+            pipeline_model_parallel_split_rank_=pipeline_model_parallel_split_rank,
+        )
+        testing_utils.set_random_seed(567)
+        pp_utils._reconfigure_microbatch_calculator(
+            rank=parallel_state.get_tensor_model_parallel_rank(),
+            rampup_batch_size=None,
+            global_batch_size=self.GLOBAL_BATCH_SIZE,
+            micro_batch_size=self.MICRO_BATCH_SIZE,
+            data_parallel_size=parallel_state.get_data_parallel_world_size(),
+        )
+
+        model = build_model(
+            testing_utils.mlp_provider_func,
+            wrap_with_ddp=False,
+            virtual_pipeline_model_parallel_size=None,
+            hidden_size=self.HIDDEN_SIZE,
+            sequence_parallel_enabled=sequence_parallel_enabled,
+        )
+        model = [m.to(dtype=dtype) for m in model]
+
+        if parallel_state.is_pipeline_first_stage():
+            batch: Tuple[torch.Tensor] = (
+                torch.ones(
+                    (self.GLOBAL_BATCH_SIZE, self.ENCODER_SEQUENCE_LENGTH, self.HIDDEN_SIZE),
+                    dtype=dtype,
+                    device="cuda",
+                ),
+            )
+        else:
+            batch = None
+
+        forward_backward_pipelining_without_interleaving(
+            forward_step_func=testing_utils.ToyParallelMLPFwdBwdStepFunc(
+                sequence_parallel_enabled=sequence_parallel_enabled,
+            ),
+            batch=batch,
+            model=model,
+            forward_only=forward_only,
+            tensor_shape=(
+                self.ENCODER_SEQUENCE_LENGTH,
+                self.MICRO_BATCH_SIZE,
+                self.HIDDEN_SIZE,
+            ),
+            model_type=model_type,
+            decoder_sequence_length=self.DECODER_SEQUENCE_LENGTH,
+            async_comm=False,
+            grad_scaler=None,
+            deallocate_pipeline_outputs=False,
+            dtype=dtype,
+            sequence_parallel_enabled=sequence_parallel_enabled,
+        )
+
+    def test_pipelining_without_interleaving_encoder_and_decoder(self) -> None:
+        self._forward_backward_test_impl(forward_only=False, sequence_parallel_enabled=False, model_type=ModelType.encoder_and_decoder)
+
+    def test_pipelining_without_interleaving_inference_encoder_and_decoder(self) -> None:
+        self._forward_backward_test_impl(forward_only=True, sequence_parallel_enabled=False, model_type=ModelType.encoder_and_decoder)
+
+    def test_pipelining_without_interleaving_sequence_parallel_encoder_and_decoder(self) -> None:
+        self._forward_backward_test_impl(forward_only=False, sequence_parallel_enabled=True, model_type=ModelType.encoder_and_decoder)
+
+    def test_pipelining_without_interleaving_inference_sequence_parallel_encoder_and_decoder(self) -> None:
+        self._forward_backward_test_impl(forward_only=True, sequence_parallel_enabled=True, model_type=ModelType.encoder_and_decoder)
+
+    def test_pipelining_without_interleaving_encoder_or_decoder(self) -> None:
+        self._forward_backward_test_impl(forward_only=False, sequence_parallel_enabled=False, model_type=ModelType.encoder_or_decoder)
+
+    def test_pipelining_without_interleaving_sequence_parallel_encoder_or_decoder(self) -> None:
+        self._forward_backward_test_impl(forward_only=False, sequence_parallel_enabled=True, model_type=ModelType.encoder_or_decoder)
+
+    def test_pipelining_without_interleaving_sequence_parallel_encoder_or_decoder_half(self) -> None:
+        self._forward_backward_test_impl(forward_only=False, sequence_parallel_enabled=True, model_type=ModelType.encoder_or_decoder, dtype=torch.half)
+
+
 if __name__ == "__main__":
     common_utils.run_tests()
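For a quick sanity check of the parallel layout `_forward_backward_test_impl` derives (not part of the test itself), the loop below simply re-evaluates the formulas from the diff for two example world sizes; the printed values are what the test would use.

# Illustrative only: re-evaluates the layout formulas from `_forward_backward_test_impl`.
for world_size in (4, 8):
    tensor_model_parallel_size = 1 + int(world_size >= 4)
    pipeline_model_parallel_world_size = world_size // tensor_model_parallel_size
    split_rank = pipeline_model_parallel_world_size // 2  # encoder_and_decoder case
    print(world_size, tensor_model_parallel_size, pipeline_model_parallel_world_size, split_rank)
# world_size=4 -> tp=2, pp=2, split_rank=1; world_size=8 -> tp=2, pp=4, split_rank=2,
# which matches the "split_rank can be 1" caveat in the in-code comment.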
@@ -33,8 +33,6 @@ def get_launch_option(test_filename) -> Tuple[bool, str]:
 def run_transformer_tests():
     python_executable_path = sys.executable
-    # repository_root = os.path.join(os.path.dirname(__file__), "../../../")
-    # directory = os.path.abspath(os.path.join(repository_root, "tests/mpu"))
     directory = os.path.dirname(__file__)
     files = [
         os.path.join(directory, f)
@@ -63,9 +61,17 @@ def run_transformer_tests():
             import torch
             num_devices = torch.cuda.device_count()
-            tensor_model_parallel_size = 1 + (1 - (num_devices % 2 and num_devices > 4))
+            if "bert" in test_file:
+                # "bert" uses the interleaving.
+                tensor_model_parallel_size = 2 if num_devices % 2 == 0 and num_devices > 4 else 1
+            if "gpt" in test_file:
+                # "gpt" uses the non-interleaving.
+                tensor_model_parallel_size = 2 if num_devices % 2 == 0 and num_devices >= 4 else 1
             pipeline_model_parallel_size = num_devices // tensor_model_parallel_size
             test_run_cmd += f" --pipeline-model-parallel-size {pipeline_model_parallel_size} --tensor-model-parallel-size {tensor_model_parallel_size}"
+            if "bert" in test_file:
+                test_run_cmd += f" --bert-no-binary-head"
         else:
             test_run_cmd += f" --use-cpu-initialization"
         print(f"### {i} / {len(files)}: cmd: {test_run_cmd}")
...
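To see why the one-liner was replaced, the throwaway sketch below (not part of the test runner; `old_tp` and `new_tp_gpt` are made-up helper names wrapping the two expressions from the diff) evaluates both for a few device counts.

# Illustrative comparison of the removed expression and the new "gpt" branch.
def old_tp(num_devices):
    return 1 + (1 - (num_devices % 2 and num_devices > 4))

def new_tp_gpt(num_devices):
    return 2 if num_devices % 2 == 0 and num_devices >= 4 else 1

for n in (2, 3, 4, 8):
    print(n, old_tp(n), new_tp_gpt(n))
# The old expression yields 2 even for 2 or 3 devices (e.g. for n=3, `3 % 2` is truthy,
# `3 > 4` is False, and `1 - False == 1`), while the new branch only enables tensor
# parallelism when there are at least 4 devices and the count is even.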