Commit f3434cc7 authored by dongcl

fix dualpipev broadcast error

parent 7c63d1a4
from megatron.core.utils import is_te_min_version
def a2a_overlap_adaptation(patches_manager):
"""
patches_manager: MegatronPatchesManager
"""
from megatron.core.extensions.transformer_engine import TEColumnParallelLinear, TERowParallelLinear
from ..core.transformer.moe.token_dispatcher import MoEAlltoAllTokenDispatcher
from ..core.transformer.transformer_layer import TransformerLayer
from ..core.models.gpt.gpt_model import GPTModel
from ..core.pipeline_parallel.schedules import get_pp_rank_microbatches, forward_backward_pipelining_with_interleaving
from ..core.extensions.transformer_engine import (
_get_extra_te_kwargs_wrapper,
TELinear,
TELayerNormColumnParallelLinear,
)
from ..core.transformer.multi_latent_attention import MLASelfAttention
from ..core.transformer.mlp import MLP
from ..core.transformer.moe.experts import TEGroupedMLP
from ..core.transformer.moe.moe_layer import MoELayer
# num_warmup_microbatches + 1
patches_manager.register_patch('megatron.core.pipeline_parallel.schedules.get_pp_rank_microbatches',
get_pp_rank_microbatches)
# a2a_overlap
patches_manager.register_patch('megatron.core.pipeline_parallel.schedules.forward_backward_pipelining_with_interleaving',
forward_backward_pipelining_with_interleaving)
patches_manager.register_patch('megatron.core.transformer.moe.token_dispatcher.MoEAlltoAllTokenDispatcher',
MoEAlltoAllTokenDispatcher)
patches_manager.register_patch('megatron.core.transformer.transformer_layer.TransformerLayer',
TransformerLayer)
patches_manager.register_patch('megatron.core.models.gpt.gpt_model.GPTModel.build_schedule_plan',
GPTModel.build_schedule_plan,
create_dummy=True)
# backward_dw
patches_manager.register_patch('megatron.core.extensions.transformer_engine._get_extra_te_kwargs',
_get_extra_te_kwargs_wrapper,
apply_wrapper=True)
patches_manager.register_patch('megatron.core.extensions.transformer_engine.TELinear',
TELinear)
patches_manager.register_patch('megatron.core.extensions.transformer_engine.TELayerNormColumnParallelLinear',
TELayerNormColumnParallelLinear)
TEColumnParallelLinear.__bases__ = (TELinear,)
TERowParallelLinear.__bases__ = (TELinear,)
if is_te_min_version("1.9.0.dev0"):
from megatron.core.extensions.transformer_engine import TEColumnParallelGroupedLinear, TERowParallelGroupedLinear
from ..core.extensions.transformer_engine import TEGroupedLinear
patches_manager.register_patch('megatron.core.extensions.transformer_engine.TEGroupedLinear',
TEGroupedLinear)
TEColumnParallelGroupedLinear.__bases__ = (TEGroupedLinear,)
TERowParallelGroupedLinear.__bases__ = (TEGroupedLinear,)
patches_manager.register_patch('megatron.core.transformer.multi_latent_attention.MLASelfAttention.backward_dw',
MLASelfAttention.backward_dw,
create_dummy=True)
patches_manager.register_patch('megatron.core.transformer.mlp.MLP.backward_dw',
MLP.backward_dw,
create_dummy=True)
patches_manager.register_patch('megatron.core.transformer.moe.experts.TEGroupedMLP.backward_dw',
TEGroupedMLP.backward_dw,
create_dummy=True)
patches_manager.register_patch('megatron.core.transformer.moe.moe_layer.MoELayer.backward_dw',
MoELayer.backward_dw,
create_dummy=True)
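# The __bases__ reassignments above retarget the existing TE subclasses to the patched
# base classes, so behaviour added on the patched base (such as a backward_dw hook, per
# the comment above) becomes available without re-registering every subclass. A minimal,
# self-contained sketch of the technique with toy classes (illustration only, not part
# of this patch):
class _OrigLinear:
    pass

class _PatchedLinear(_OrigLinear):
    def backward_dw(self):
        return "decoupled weight-grad step"

class _ColumnLinear(_OrigLinear):
    pass

# Retarget the subclass to the patched base; instances now inherit backward_dw.
_ColumnLinear.__bases__ = (_PatchedLinear,)
assert hasattr(_ColumnLinear(), "backward_dw")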
......@@ -56,7 +56,7 @@ class PipelineFeature(AbstractFeature):
patch_manager.register_patch('megatron.core.pipeline_parallel.schedules.forward_backward_pipelining_without_interleaving',
forward_backward_pipelining_with_cutinhalf)
patch_manager.register_patch(
'megatron.legacy.model.module.Float16Module.forward', dualpipev_fp16forward)
'megatron.core.transformer.module.Float16Module.forward', dualpipev_fp16forward)
patch_manager.register_patch(
'megatron.core.transformer.transformer_block.get_num_layers_to_build', get_num_layers_to_build)
patch_manager.register_patch(
......
import os
import abc
import sys
import types
import argparse
import torch
......@@ -9,6 +7,10 @@ from megatron.core.utils import is_te_min_version
from .features_manager import ADAPTOR_FEATURES
from .patch_utils import MegatronPatchesManager
from dcu_megatron.training.arguments import process_adaptor_args
_ARGS = None
def add_args(args, key, value):
......@@ -42,7 +44,7 @@ def get_adaptor_args():
global _ARGS
if _ARGS is None:
parser = argparse.ArgumentParser(description='Adaptor Arguments', allow_abbrev=False)
_ARGS, unknown = process_args(parser).parse_known_args()
_ARGS, unknown = process_adaptor_args(parser).parse_known_args()
parser_unknown_args(_ARGS, unknown)
return _ARGS
......@@ -119,7 +121,7 @@ def adaptation_l2(patches_manager, adaptor_args):
"""
for feature in ADAPTOR_FEATURES:
if getattr(adaptor_args, feature.feature_name, None) and feature.optimization_level == 2:
feature.register_patches(patches_manager, mindspeed_args)
feature.register_patches(patches_manager, adaptor_args)
class MegatronAdaptationABC:
......@@ -161,6 +163,7 @@ class CoreAdaptation(MegatronAdaptationABC):
def patch_core_transformers(self):
from ..core import transformer_block_init_wrapper
from ..core.transformer.transformer_layer import get_transformer_layer_offset
from ..core.transformer.transformer_config import TransformerConfigPatch, MLATransformerConfigPatch
# Transformer block. If mtp_num_layers > 0, move final_layernorm outside
......@@ -187,6 +190,10 @@ class CoreAdaptation(MegatronAdaptationABC):
torch.compile(mode='max-autotune-no-cudagraphs'),
apply_wrapper=True)
# support dualpipev
MegatronAdaptation.register('megatron.core.transformer.transformer_layer.get_transformer_layer_offset',
get_transformer_layer_offset)
def patch_core_extentions(self):
import transformer_engine as te
......@@ -250,8 +257,9 @@ class CoreAdaptation(MegatronAdaptationABC):
from ..training.tokenizer import build_tokenizer
from ..training.initialize import _initialize_distributed
from ..training.initialize import _compile_dependencies
from ..training.training import train
from ..training.training import train, build_train_valid_test_data_iterators_wrapper
from ..training.initialize import _set_random_seed
from ..training.utils import get_batch_on_this_tp_rank
MegatronAdaptation.register('megatron.training.tokenizer.tokenizer.build_tokenizer',
build_tokenizer)
......@@ -270,6 +278,15 @@ class CoreAdaptation(MegatronAdaptationABC):
MegatronAdaptation.register('megatron.training.training.train',
train)
# support dualpipev, two data iterators
MegatronAdaptation.register('megatron.training.training.build_train_valid_test_data_iterators',
build_train_valid_test_data_iterators_wrapper,
apply_wrapper=True)
# support dualpipev, broadcast loss_mask and labels
MegatronAdaptation.register('megatron.training.utils.get_batch_on_this_tp_rank',
get_batch_on_this_tp_rank)
def patch_miscellaneous(self):
from ..training.arguments import parse_args
......
......@@ -4,6 +4,7 @@ from typing import List, Optional
from megatron.core import mpu, tensor_parallel
from megatron.core.utils import get_model_config
from megatron.core.transformer.module import Float16Module
from megatron.core.rerun_state_machine import get_rerun_state_machine
from megatron.core.distributed import DistributedDataParallelConfig
from megatron.core.distributed import DistributedDataParallel as DDP
from megatron.core.enums import ModelType
......@@ -14,7 +15,10 @@ from megatron.core.transformer.module import fp32_to_float16, float16_to_fp32
from megatron.core.num_microbatches_calculator import get_num_microbatches
from megatron.core.transformer.transformer_config import TransformerConfig
from megatron.core import parallel_state
from megatron.core.distributed.finalize_model_grads import _allreduce_layernorm_grads
from megatron.training.utils import (
logical_and_across_model_parallel_group,
reduce_max_stat_across_model_parallel_group
)
from dcu_megatron.core.pipeline_parallel.dualpipev.dualpipev_schedules import get_dualpipe_chunk
......@@ -82,7 +86,8 @@ def get_model(model_provider_func, model_type=ModelType.encoder_or_decoder, wrap
# Fp16 conversion.
if args.fp16 or args.bf16:
model = [Float16Module(model_module, args) for model_module in model]
config = get_model_config(model[0])
model = [Float16Module(config, model_module) for model_module in model]
if wrap_with_ddp:
config = get_model_config(model[0])
......@@ -217,10 +222,11 @@ def get_num_layers_to_build(config: TransformerConfig) -> int:
def _allreduce_embedding_grads_wrapper(fn):
@wraps(fn)
def wrapper(*args, **kwargs):
if get_args().schedules_method == 'dualpipev':
args = get_args()
if args.schedule_method == 'dualpipev':
# dualpipev does not need the embedding-grad allreduce:
# embedding and lm head are on the same rank.
if not get_args().untie_embeddings_and_output_weights:
if not args.untie_embeddings_and_output_weights:
raise NotImplementedError
else:
return
......
......@@ -19,9 +19,8 @@ from megatron.core.utils import (
from megatron.core.pipeline_parallel.schedules import clear_embedding_activation_buffer, deallocate_output_tensor
from megatron.core import ModelParallelConfig
from megatron.core.pipeline_parallel.p2p_communication import _communicate
from megatron.core.pipeline_parallel.schedules import backward_step, set_current_microbatch, custom_backward, finish_embedding_wgrad_compute
from megatron.core.models.gpt import GPTModel
from mindspeed.core.pipeline_parallel.fb_overlap.modules.weight_grad_store import WeightGradStore
from megatron.core.pipeline_parallel.schedules import backward_step, set_current_microbatch, finish_embedding_wgrad_compute
# from mindspeed.core.pipeline_parallel.fb_overlap.modules.weight_grad_store import WeightGradStore
# Types
......@@ -115,7 +114,7 @@ def send_backward(input_tensor_grad: torch.Tensor, tensor_shape, config: ModelPa
return reqs
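# The wait loops below now treat P2P handles as a name -> request mapping rather than a
# flat list (some entries may be None). A minimal sketch of that waiting idiom as a
# standalone helper (hypothetical, not part of this patch):
def _wait_p2p_handles(handles):
    if handles is None:
        return
    for _name, req in handles.items():
        if req is not None:
            req.wait()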
def recv_forward(tensor_shape: Shape, config: ModelParallelConfig, model_chunk_id, async_op=False) -> torch.Tensor:
def recv_forward(tensor_shape: Shape, config: ModelParallelConfig, model_chunk_id, async_op=False, step=-1) -> torch.Tensor:
""" Receive tensor from previous rank in pipeline (forward receive).
See _communicate for argument details.
......@@ -568,8 +567,6 @@ def forward_backward_pipelining_with_cutinhalf(
total_num_tokens = torch.tensor(0, dtype=torch.int).cuda()
input_tensors = [[], []]
output_tensors = [[], []]
model_graphs = [[], []]
logits_inputs = []
forward_data_store = []
master_chunk_id = 0
......@@ -584,7 +581,7 @@ def forward_backward_pipelining_with_cutinhalf(
checkpoint_activations_microbatch = None
input_tensor = recv_forward(tensor_shape, config, master_chunk_id)[0]
input_tensor = recv_forward(tensor_shape, config, master_chunk_id, step=0)[0]
fwd_wait_handles_warmup = None
# Run warmup forward passes
......@@ -627,10 +624,10 @@ def forward_backward_pipelining_with_cutinhalf(
fwd_wait_handles_slave_chunk = None
fwd_wait_handles_send = None
for i in range(schedule['interleaved_forward'][rank]):
if fwd_wait_handles is not None:
for req in fwd_wait_handles:
req.wait()
for req, req_handle in fwd_wait_handles.items():
if req_handle is not None:
req_handle.wait()
fwd_wait_handles = None
is_first_microbatch = parallel_state.is_pipeline_last_stage(ignore_virtual=True) and (i == 0)
......@@ -659,14 +656,16 @@ def forward_backward_pipelining_with_cutinhalf(
master_cur_microbatch += 1
if not parallel_state.is_pipeline_last_stage(ignore_virtual=True) and fwd_wait_handles_send is not None:
for req in fwd_wait_handles_send:
req.wait()
for req, req_handle in fwd_wait_handles_send.items():
if req_handle is not None:
req_handle.wait()
deallocate_output_tensor(
output_tensor_send, config.deallocate_pipeline_outputs)
fwd_wait_handles_send = None
if parallel_state.is_pipeline_last_stage(ignore_virtual=True):
input_tensor_slave_chunk = output_tensor
input_tensor_slave_chunk = output_tensor.detach()
input_tensor_slave_chunk.requires_grad = True
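# The detach above gives the slave chunk a leaf input, so backward_step can extract its
# input gradient independently of the master chunk's autograd graph.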
input_tensor, fwd_wait_handles = recv_forward(
tensor_shape, config, master_chunk_id, async_op=True)
......@@ -678,15 +677,17 @@ def forward_backward_pipelining_with_cutinhalf(
tensor_shape, config, master_chunk_id, async_op=True)
if fwd_wait_handles_warmup is not None:
for req in fwd_wait_handles_warmup:
req.wait()
for req, req_handle in fwd_wait_handles_warmup.items():
if req_handle is not None:
req_handle.wait()
deallocate_output_tensor(
output_tensor_warmup, config.deallocate_pipeline_outputs)
fwd_wait_handles_warmup = None
if fwd_wait_handles_slave_chunk is not None:
for req in fwd_wait_handles_slave_chunk:
req.wait()
for req, req_handle in fwd_wait_handles_slave_chunk.items():
if req_handle is not None:
req_handle.wait()
deallocate_output_tensor(
output_tensor_slave_chunk, config.deallocate_pipeline_outputs)
fwd_wait_handles_slave_chunk = None
......@@ -733,17 +734,21 @@ def forward_backward_pipelining_with_cutinhalf(
output_tensor_send = output_tensor
fwd_wait_handles_send = send_forward(
output_tensor_send, tensor_shape, config, master_chunk_id, async_op=True)
else:
# custom_backward requires output_tensor.numel() == 1
deallocate_output_tensor(output_tensor, config.deallocate_pipeline_outputs)
if fwd_wait_handles is not None:
for req in fwd_wait_handles:
req.wait()
for req, req_handle in fwd_wait_handles.items():
if req_handle is not None:
req_handle.wait()
fwd_wait_handles = None
# Run 1b1w1f stages for slave chunk
bwd_wait_handles = None
for _ in range(schedule['1b1w1f'][rank]):
WeightGradStore.start_decouple()
# WeightGradStore.start_decouple()
input_tensor_bwd = input_tensors[slave_chunk_id].pop(0)[1]
output_tensor_bwd = output_tensors[slave_chunk_id].pop(0)
......@@ -752,11 +757,7 @@ def forward_backward_pipelining_with_cutinhalf(
input_tensor_bwd, output_tensor_bwd, output_tensor_grad_bwd, model_type, config
)
WeightGradStore.end_decouple()
# If done asynchronously, memory usage rises.
bwd_wait_handles = send_backward(input_tensor_grad,
tensor_shape, config, slave_chunk_id)
# WeightGradStore.end_decouple()
if fwd_wait_handles_slave_chunk is not None:
for req in fwd_wait_handles_slave_chunk:
......@@ -765,22 +766,28 @@ def forward_backward_pipelining_with_cutinhalf(
output_tensor_slave_chunk, config.deallocate_pipeline_outputs)
fwd_wait_handles_slave_chunk = None
if fwd_wait_handles_send is not None:
for req in fwd_wait_handles_send:
req.wait()
for req, req_handle in fwd_wait_handles_send.items():
if req_handle is not None:
req_handle.wait()
deallocate_output_tensor(
output_tensor, config.deallocate_pipeline_outputs)
fwd_wait_handles_send = None
# If done asynchronously, memory usage rises.
bwd_wait_handles = send_backward(input_tensor_grad,
tensor_shape, config, slave_chunk_id)
# If done asynchronously, memory usage rises.
input_tensor_slave_chunk, recv_forward_handle = recv_forward(
tensor_shape, config, slave_chunk_id)
# 1w: Weight Grad Compute
WeightGradStore.pop()
# WeightGradStore.pop()
if recv_forward_handle is not None:
for req in recv_forward_handle:
req.wait()
for req, handle in recv_forward_handle.items():
if handle is not None:
handle.wait()
recv_forward_handle = None
# 1F: Forward pass
......@@ -816,7 +823,7 @@ def forward_backward_pipelining_with_cutinhalf(
# Run overlapping f&bw stages
fwd_model_chunk_id = master_chunk_id
bwd_model_chunk_id = slave_chunk_id
for _ in range(schedule['overlap'][rank] + schedule['1b1overlap'][rank] + schedule['interleaved_backward'][rank]):
for step_id in range(schedule['overlap'][rank] + schedule['1b1overlap'][rank] + schedule['interleaved_backward'][rank]):
only_bwd = False
if fwd_model_chunk_id == master_chunk_id and master_cur_microbatch == master_microbatch_max:
only_bwd = True
......@@ -853,24 +860,37 @@ def forward_backward_pipelining_with_cutinhalf(
fwd_send_only = (master_cur_microbatch ==
master_microbatch_max)
# Synchronize the previous stage's last slave-chunk forward send
if fwd_wait_handles_slave_chunk is not None:
for req, req_handle in fwd_wait_handles_slave_chunk.items():
if req_handle is not None:
req_handle.wait()
deallocate_output_tensor(
output_tensor_slave_chunk, config.deallocate_pipeline_outputs)
fwd_wait_handles_slave_chunk = None
if fwd_send_only:
fwd_wait_handles = send_forward(
output_tensor, tensor_shape, config, fwd_model_chunk_id, async_op=True)
else:
if parallel_state.is_pipeline_last_stage() and fwd_model_chunk_id == master_chunk_id:
input_tensor = output_tensor
input_tensor = output_tensor.detach()
input_tensor.requires_grad = True
deallocate_output_tensor(output_tensor, config.deallocate_pipeline_outputs)
else:
input_tensor, fwd_wait_handles = send_forward_recv_slave_forward(
output_tensor, tensor_shape, config, fwd_model_chunk_id, async_op=True)
if firstFB_no_overlp_handle is not None:
for req in firstFB_no_overlp_handle:
req.wait()
for req, req_handle in firstFB_no_overlp_handle.items():
if req_handle is not None:
req_handle.wait()
firstFB_no_overlp_handle = None
if bwd_wait_handles is not None:
for req in bwd_wait_handles:
req.wait()
for req, req_handle in bwd_wait_handles.items():
if req_handle is not None:
req_handle.wait()
bwd_wait_handles = None
input_tensor_bwd = input_tensors[bwd_model_chunk_id].pop(0)[
......@@ -883,8 +903,9 @@ def forward_backward_pipelining_with_cutinhalf(
)
if fwd_wait_handles is not None:
for req in fwd_wait_handles:
req.wait()
for req, req_handle in fwd_wait_handles.items():
if req_handle is not None:
req_handle.wait()
fwd_wait_handles = None
deallocate_output_tensor(
output_tensor, config.deallocate_pipeline_outputs)
......@@ -896,21 +917,15 @@ def forward_backward_pipelining_with_cutinhalf(
output_tensor_grad_bwd, bwd_wait_handles = send_backward_recv_slave_backward(input_tensor_grad,
tensor_shape, config, fwd_model_chunk_id, async_op=True)
if fwd_wait_handles_slave_chunk is not None:
for req in fwd_wait_handles_slave_chunk:  # synchronize the previous stage's last slave-chunk forward send
req.wait()
deallocate_output_tensor(
output_tensor_slave_chunk, config.deallocate_pipeline_outputs)
fwd_wait_handles_slave_chunk = None
# only run backward
else:
if bwd_model_chunk_id == slave_chunk_id and slave_cur_microbatch < slave_microbatch_max:
input_tensor, _ = recv_forward(
tensor_shape, config, slave_chunk_id)
if bwd_wait_handles is not None:
for req in bwd_wait_handles:
req.wait()
for req, req_handle in bwd_wait_handles.items():
if req_handle is not None:
req_handle.wait()
bwd_wait_handles = None
input_tensor_bwd = input_tensors[bwd_model_chunk_id].pop(0)[
......@@ -951,26 +966,28 @@ def forward_backward_pipelining_with_cutinhalf(
for i in range(pp_size):
if bwd_wait_handles is not None:
for req in bwd_wait_handles:
req.wait()
for req, req_handle in bwd_wait_handles.items():
if req_handle is not None:
req_handle.wait()
bwd_wait_handles = None
if bwd_wait_handles_recv is not None:
for req in bwd_wait_handles_recv:
req.wait()
for req, req_handle in bwd_wait_handles_recv.items():
if req_handle is not None:
req_handle.wait()
bwd_wait_handles_recv = None
input_tensor_bwd = merged_input_tensors.pop(0)[1]
output_tensor_bwd, bwd_model_chunk_id = merged_output_tensors.pop(0)
if not args.dualpipe_no_dw_detach:
WeightGradStore.start_decouple()
# if not args.dualpipe_no_dw_detach:
# WeightGradStore.start_decouple()
input_tensor_grad = backward_step(
input_tensor_bwd, output_tensor_bwd, output_tensor_grad_bwd, model_type, config
)
if not args.dualpipe_no_dw_detach:
WeightGradStore.end_decouple()
# if not args.dualpipe_no_dw_detach:
# WeightGradStore.end_decouple()
if i == pp_size - 1:
bwd_wait_handles = send_backward(input_tensor_grad,
......@@ -988,18 +1005,19 @@ def forward_backward_pipelining_with_cutinhalf(
output_tensor_grad_bwd, bwd_wait_handles = send_backward_recv_slave_backward(input_tensor_grad,
tensor_shape, config, 1 - bwd_model_chunk_id)
WeightGradStore.flush_chunk_grad()
if i >= schedule['cooldown'][rank][0] - 1:
WeightGradStore.pop_single()
# WeightGradStore.flush_chunk_grad()
# if i >= schedule['cooldown'][rank][0] - 1:
# WeightGradStore.pop_single()
for _ in range(schedule['cooldown'][rank][2] - 1):
WeightGradStore.pop_single()
# for _ in range(schedule['cooldown'][rank][2] - 1):
# WeightGradStore.pop_single()
assert WeightGradStore.weight_grad_queue.empty()
# assert WeightGradStore.weight_grad_queue.empty()
if bwd_wait_handles is not None:
for req in bwd_wait_handles:
req.wait()
for req, req_handle in bwd_wait_handles.items():
if req_handle is not None:
req_handle.wait()
bwd_wait_handles = None
if config.finalize_model_grads_func is not None and not forward_only:
......
......@@ -2,7 +2,8 @@ from typing import Any, Optional
from torch import Tensor
from megatron.core import tensor_parallel
from megatron.training import get_args
from megatron.core import tensor_parallel, parallel_state
from megatron.core.packed_seq_params import PackedSeqParams
from megatron.core.utils import (
deprecate_inference_params,
......@@ -11,6 +12,176 @@ from megatron.core.utils import (
from megatron.core.transformer.moe.moe_layer import MoELayer
from megatron.core.transformer.transformer_layer import TransformerLayer as MegatronCoreTransformerLayer
from megatron.core.transformer.moe.token_dispatcher import MoEAlltoAllTokenDispatcher
from megatron.core.transformer.transformer_config import TransformerConfig
def get_transformer_layer_offset(config: TransformerConfig):
"""Get the index offset of current pipeline stage, given the level of pipelining."""
args = get_args()
pipeline_rank = parallel_state.get_pipeline_model_parallel_rank()
if not parallel_state.is_inside_encoder():
pp_decoder_start = parallel_state.get_pipeline_model_parallel_decoder_start()
if pp_decoder_start is not None:
pipeline_rank = pipeline_rank - pp_decoder_start
if config.pipeline_model_parallel_size > 1:
if (
config.num_layers_in_first_pipeline_stage is not None
or config.num_layers_in_last_pipeline_stage is not None
):
# Calculate number of pipeline stages to distribute the remaining Transformer
# layers after deducting the Transformer layers in the first or the last stages
middle_pipeline_stages = config.pipeline_model_parallel_size
if args.schedule_method == 'dualpipev':
middle_pipeline_stages *= 2
middle_pipeline_stages -= sum(
[
1 if x is not None else 0
for x in (
config.num_layers_in_first_pipeline_stage,
config.num_layers_in_last_pipeline_stage,
)
]
)
# Calculate layers to distribute in each pipeline stage. If the
# num_layers_in_first_pipeline_stage and num_layers_in_last_pipeline_stage
# are not set, we will not enable uneven pipeline. All layers will be treated
# as middle layers.
num_layers_in_first_pipeline_stage = (
0
if config.num_layers_in_first_pipeline_stage is None
else config.num_layers_in_first_pipeline_stage
)
num_layers_in_last_pipeline_stage = (
0
if config.num_layers_in_last_pipeline_stage is None
else config.num_layers_in_last_pipeline_stage
)
middle_num_layers = (
config.num_layers
- num_layers_in_first_pipeline_stage
- num_layers_in_last_pipeline_stage
)
if parallel_state.get_virtual_pipeline_model_parallel_world_size() is not None:
vp_rank = parallel_state.get_virtual_pipeline_model_parallel_rank()
vp_size = parallel_state.get_virtual_pipeline_model_parallel_world_size()
# Calculate number of layers in each virtual model chunk
# If the num_layers_in_first_pipeline_stage and
# num_layers_in_last_pipeline_stage are not set, all pipeline stages
# will be treated as middle pipeline stages in the calculation
num_layers_per_virtual_model_chunk_in_first_pipeline_stage = (
0
if config.num_layers_in_first_pipeline_stage is None
else config.num_layers_in_first_pipeline_stage // vp_size
)
num_layers_per_virtual_model_chunk_in_last_pipeline_stage = (
0
if config.num_layers_in_last_pipeline_stage is None
else config.num_layers_in_last_pipeline_stage // vp_size
)
num_layers_per_virtual_model_chunk_in_middle_pipeline_stage = (
middle_num_layers // vp_size
)
# First stage + middle stage + last stage
total_virtual_chunks = (
num_layers_per_virtual_model_chunk_in_first_pipeline_stage
+ num_layers_per_virtual_model_chunk_in_middle_pipeline_stage
+ num_layers_per_virtual_model_chunk_in_last_pipeline_stage
)
# Calculate the layer offset with interleaved uneven pipeline parallelism
if pipeline_rank == 0:
offset = vp_rank * total_virtual_chunks
else:
offset = (
vp_rank * total_virtual_chunks
+ num_layers_per_virtual_model_chunk_in_first_pipeline_stage
+ (pipeline_rank - 1)
* (
num_layers_per_virtual_model_chunk_in_middle_pipeline_stage
// middle_pipeline_stages
)
)
else:
if middle_pipeline_stages > 0:
num_layers_per_pipeline_rank = middle_num_layers // middle_pipeline_stages
else:
num_layers_per_pipeline_rank = 0
middle_pipeline_rank = (
pipeline_rank
if config.num_layers_in_first_pipeline_stage is None
else pipeline_rank - 1
)
if not getattr(args, 'dualpipev_first_chunk', True):
middle_pipeline_rank = (
config.pipeline_model_parallel_size
if config.num_layers_in_first_pipeline_stage is None
else config.pipeline_model_parallel_size - 1
) + (config.pipeline_model_parallel_size - (pipeline_rank + 1))
if getattr(args, 'dualpipev_first_chunk', True) and pipeline_rank == 0:
offset = 0
else:
offset = (
middle_pipeline_rank * num_layers_per_pipeline_rank
) + num_layers_in_first_pipeline_stage
else:
num_layers = config.num_layers
# Increase the number of layers by one if we include the embedding (loss)
# layer into pipeline parallelism partition and placement
if config.account_for_embedding_in_pipeline_split:
num_layers += 1
if config.account_for_loss_in_pipeline_split:
num_layers += 1
num_layers_per_pipeline_rank = num_layers // config.pipeline_model_parallel_size
if args.schedule_method == 'dualpipev':
num_layers_per_pipeline_rank = num_layers_per_pipeline_rank // 2
if parallel_state.get_virtual_pipeline_model_parallel_world_size() is not None:
vp_rank = parallel_state.get_virtual_pipeline_model_parallel_rank()
vp_size = parallel_state.get_virtual_pipeline_model_parallel_world_size()
num_layers_per_virtual_rank = num_layers_per_pipeline_rank // vp_size
total_virtual_chunks = num_layers // vp_size
offset = vp_rank * total_virtual_chunks + (
pipeline_rank * num_layers_per_virtual_rank
)
# Reduce the offset of embedding layer from the total layer number
if (
config.account_for_embedding_in_pipeline_split
and not parallel_state.is_pipeline_first_stage()
):
offset -= 1
else:
if getattr(args, 'dualpipev_first_chunk', True):
offset = pipeline_rank * num_layers_per_pipeline_rank
else:
offset = num_layers - (pipeline_rank + 1) * num_layers_per_pipeline_rank
# Reduce the offset of embedding layer from the total layer number
if (
config.account_for_embedding_in_pipeline_split
and not parallel_state.is_pipeline_first_stage()
):
offset -= 1
else:
offset = 0
return offset
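# A worked sketch of the dualpipev offsets computed by the branch above (hypothetical
# helper, illustration only; assumes an even layer split, no uneven first/last stage,
# no virtual pipeline, and no embedding/loss accounting):
def _dualpipev_layer_offsets(num_layers: int, pp_size: int, pipeline_rank: int):
    layers_per_chunk = num_layers // pp_size // 2
    first_chunk_offset = pipeline_rank * layers_per_chunk
    second_chunk_offset = num_layers - (pipeline_rank + 1) * layers_per_chunk
    return first_chunk_offset, second_chunk_offset

# e.g. num_layers=8, pp_size=2: rank 0 -> (0, 6), rank 1 -> (2, 4),
# i.e. each rank holds a front slice and the matching back slice of the folded pipeline.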
class TransformerLayer(MegatronCoreTransformerLayer):
......
import gc
import sys
from functools import wraps
import torch.distributed
import torch
......@@ -53,6 +54,29 @@ from megatron.training.training import (
stimer = StragglerDetector()
def build_train_valid_test_data_iterators_wrapper(build_train_valid_test_data_iterators_func):
@wraps(build_train_valid_test_data_iterators_func)
def wrapper(train_valid_test_dataset_provider):
args = get_args()
if args.schedule_method == 'dualpipev':
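# dualpipev runs two model chunks on every pipeline rank, so build an
# independent set of train/valid/test iterators for each chunk.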
train_data_iterator = []
valid_data_iterator = []
test_data_iterator = []
for _ in range(2):
iterators = build_train_valid_test_data_iterators_func(train_valid_test_dataset_provider)
train_data_iterator.append(iterators[0])
valid_data_iterator.append(iterators[1])
test_data_iterator.append(iterators[2])
else:
train_data_iterator, valid_data_iterator, test_data_iterator \
= build_train_valid_test_data_iterators_func(
train_valid_test_dataset_provider)
return train_data_iterator, valid_data_iterator, test_data_iterator
return wrapper
def train(forward_step_func, model, optimizer, opt_param_scheduler,
train_data_iterator, valid_data_iterator,
process_non_loss_data_func, config, checkpointing_context, non_loss_data_func):
......
import torch
from megatron.training import get_args
from megatron.core import mpu
def get_batch_on_this_tp_rank(data_iterator):
args = get_args()
def _broadcast(item):
if item is not None:
torch.distributed.broadcast(item, mpu.get_tensor_model_parallel_src_rank(), group=mpu.get_tensor_model_parallel_group())
if mpu.get_tensor_model_parallel_rank() == 0:
if data_iterator is not None:
data = next(data_iterator)
else:
data = None
batch = {
'tokens': data["tokens"].cuda(non_blocking = True),
'labels': data["labels"].cuda(non_blocking = True),
'loss_mask': data["loss_mask"].cuda(non_blocking = True),
'attention_mask': None if "attention_mask" not in data else data["attention_mask"].cuda(non_blocking = True),
'position_ids': data["position_ids"].cuda(non_blocking = True)
}
if args.pipeline_model_parallel_size == 1:
_broadcast(batch['tokens'])
_broadcast(batch['labels'])
_broadcast(batch['loss_mask'])
_broadcast(batch['attention_mask'])
_broadcast(batch['position_ids'])
elif mpu.is_pipeline_first_stage():
_broadcast(batch['tokens'])
_broadcast(batch['attention_mask'])
_broadcast(batch['position_ids'])
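# dualpipev folds the pipeline so the final model chunk (and the loss) also runs on the
# first pipeline stage; its tensor-parallel ranks therefore need loss_mask and labels too.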
if args.schedule_method == "dualpipev":
_broadcast(batch['loss_mask'])
_broadcast(batch['labels'])
elif mpu.is_pipeline_last_stage():
# Multi-Token Prediction (MTP) layers need tokens and position_ids to calculate embeddings.
# Currently the MTP layers are fixed on the last stage, so we need to broadcast
# tokens and position_ids to all of the tensor parallel ranks on the last stage.
if args.mtp_num_layers is not None:
_broadcast(batch['tokens'])
_broadcast(batch['position_ids'])
_broadcast(batch['labels'])
_broadcast(batch['loss_mask'])
_broadcast(batch['attention_mask'])
else:
tokens=torch.empty((args.micro_batch_size,args.seq_length), dtype = torch.int64 , device = torch.cuda.current_device())
labels=torch.empty((args.micro_batch_size,args.seq_length), dtype = torch.int64 , device = torch.cuda.current_device())
loss_mask=torch.empty((args.micro_batch_size,args.seq_length), dtype = torch.float32 , device = torch.cuda.current_device())
if args.create_attention_mask_in_dataloader:
attention_mask=torch.empty(
(args.micro_batch_size,1,args.seq_length,args.seq_length), dtype = torch.bool , device = torch.cuda.current_device()
)
else:
attention_mask=None
position_ids=torch.empty((args.micro_batch_size,args.seq_length), dtype = torch.int64 , device = torch.cuda.current_device())
if args.pipeline_model_parallel_size == 1:
_broadcast(tokens)
_broadcast(labels)
_broadcast(loss_mask)
_broadcast(attention_mask)
_broadcast(position_ids)
elif mpu.is_pipeline_first_stage():
_broadcast(tokens)
_broadcast(attention_mask)
_broadcast(position_ids)
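# dualpipev: the first stage also computes the loss on its final chunk, so receive
# loss_mask and labels here as well.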
if args.schedule_method == "dualpipev":
_broadcast(loss_mask)
_broadcast(labels)
else:
labels=None
loss_mask=None
elif mpu.is_pipeline_last_stage():
# Multi-Token Prediction (MTP) layers need tokens and position_ids to calculate embeddings.
# Currently the MTP layers are fixed on the last stage, so we need to broadcast
# tokens and position_ids to all of the tensor parallel ranks on the last stage.
if args.mtp_num_layers is not None:
_broadcast(tokens)
_broadcast(position_ids)
else:
tokens=None
position_ids=None
_broadcast(labels)
_broadcast(loss_mask)
_broadcast(attention_mask)
batch = {
'tokens': tokens,
'labels': labels,
'loss_mask': loss_mask,
'attention_mask': attention_mask,
'position_ids': position_ids
}
return batch
\ No newline at end of file
......@@ -136,7 +136,6 @@ def model_provider(pre_process=True, post_process=True) -> Union[GPTModel, megat
def get_batch(data_iterator):
"""Generate a batch."""
# TODO: this is pretty hacky, find a better way
if (not mpu.is_pipeline_first_stage()) and (not mpu.is_pipeline_last_stage()):
return None, None, None, None, None
......