Commit 69add73b authored by dongcl's avatar dongcl

dualpipev: support MoE A2A (all-to-all) overlap

parent 62f16817
__pycache__
*.bak
*.log
@@ -69,16 +69,12 @@ class PipelineFeature(AbstractFeature):
         patch_manager.register_patch(
             'megatron.training.training.evaluate', evaluate)
-        if (
-            args.schedule_method == "interleaved_1f1b"
-            and args.combined_1f1b
-        ):
+        if args.combined_1f1b:
             from megatron.core.extensions.transformer_engine import TEColumnParallelLinear, TERowParallelLinear
             from dcu_megatron.core.transformer.moe.token_dispatcher import MoEAlltoAllTokenDispatcher
             from dcu_megatron.core.transformer.transformer_layer import TransformerLayer
             from dcu_megatron.core.models.gpt.gpt_model import GPTModel
-            from dcu_megatron.core.pipeline_parallel.schedules import get_pp_rank_microbatches, forward_backward_pipelining_with_interleaving
             from dcu_megatron.core.extensions.transformer_engine import (
                 _get_extra_te_kwargs_wrapper,
                 TELinear,
@@ -89,53 +85,55 @@ class PipelineFeature(AbstractFeature):
             from dcu_megatron.core.transformer.moe.experts import TEGroupedMLP
             from dcu_megatron.core.transformer.moe.moe_layer import MoELayer
-            # num_warmup_microbatches + 1
-            patches_manager.register_patch('megatron.core.pipeline_parallel.schedules.get_pp_rank_microbatches',
-                                           get_pp_rank_microbatches)
-            # a2a_overlap
-            patches_manager.register_patch('megatron.core.pipeline_parallel.schedules.forward_backward_pipelining_with_interleaving',
-                                           forward_backward_pipelining_with_interleaving)
-            patches_manager.register_patch('megatron.core.transformer.moe.token_dispatcher.MoEAlltoAllTokenDispatcher',
+            patch_manager.register_patch('megatron.core.transformer.moe.token_dispatcher.MoEAlltoAllTokenDispatcher',
                                          MoEAlltoAllTokenDispatcher)
-            patches_manager.register_patch('megatron.core.transformer.transformer_layer.TransformerLayer',
+            patch_manager.register_patch('megatron.core.transformer.transformer_layer.TransformerLayer',
                                          TransformerLayer)
-            patches_manager.register_patch('megatron.core.models.gpt.gpt_model.GPTModel.build_schedule_plan',
+            patch_manager.register_patch('megatron.core.models.gpt.gpt_model.GPTModel.build_schedule_plan',
                                          GPTModel.build_schedule_plan,
                                          create_dummy=True)
             # backward_dw
-            patches_manager.register_patch('megatron.core.extensions.transformer_engine._get_extra_te_kwargs',
+            patch_manager.register_patch('megatron.core.extensions.transformer_engine._get_extra_te_kwargs',
                                          _get_extra_te_kwargs_wrapper,
                                          apply_wrapper=True)
-            patches_manager.register_patch('megatron.core.extensions.transformer_engine.TELinear',
+            patch_manager.register_patch('megatron.core.extensions.transformer_engine.TELinear',
                                          TELinear)
-            patches_manager.register_patch('megatron.core.extensions.transformer_engine.TELayerNormColumnParallelLinear',
+            patch_manager.register_patch('megatron.core.extensions.transformer_engine.TELayerNormColumnParallelLinear',
                                          TELayerNormColumnParallelLinear)
             TEColumnParallelLinear.__bases__ = (TELinear,)
             TERowParallelLinear.__bases__ = (TELinear,)
             if is_te_min_version("1.9.0.dev0"):
                 from megatron.core.extensions.transformer_engine import TEColumnParallelGroupedLinear, TERowParallelGroupedLinear
-                from ..core.extensions.transformer_engine import TEGroupedLinear
+                from dcu_megatron.core.extensions.transformer_engine import TEGroupedLinear
-                patches_manager.register_patch('megatron.core.extensions.transformer_engine.TEGroupedLinear',
+                patch_manager.register_patch('megatron.core.extensions.transformer_engine.TEGroupedLinear',
                                              TEGroupedLinear)
                 TEColumnParallelGroupedLinear.__bases__ = (TEGroupedLinear,)
                 TERowParallelGroupedLinear.__bases__ = (TEGroupedLinear,)
-            patches_manager.register_patch('megatron.core.transformer.multi_latent_attention.MLASelfAttention.backward_dw',
+            patch_manager.register_patch('megatron.core.transformer.multi_latent_attention.MLASelfAttention.backward_dw',
                                          MLASelfAttention.backward_dw,
                                          create_dummy=True)
-            patches_manager.register_patch('megatron.core.transformer.mlp.MLP.backward_dw',
+            patch_manager.register_patch('megatron.core.transformer.mlp.MLP.backward_dw',
                                          MLP.backward_dw,
                                          create_dummy=True)
-            patches_manager.register_patch('megatron.core.transformer.moe.experts.TEGroupedMLP.backward_dw',
+            patch_manager.register_patch('megatron.core.transformer.moe.experts.TEGroupedMLP.backward_dw',
                                          TEGroupedMLP.backward_dw,
                                          create_dummy=True)
-            patches_manager.register_patch('megatron.core.transformer.moe.moe_layer.MoELayer.backward_dw',
+            patch_manager.register_patch('megatron.core.transformer.moe.moe_layer.MoELayer.backward_dw',
                                          MoELayer.backward_dw,
-                                         create_dummy=True)
+                                         create_dummy=True)
+        if args.schedule_method == "interleaved_1f1b":
+            from dcu_megatron.core.pipeline_parallel.schedules import get_pp_rank_microbatches, forward_backward_pipelining_with_interleaving
+            # num_warmup_microbatches + 1
+            patch_manager.register_patch('megatron.core.pipeline_parallel.schedules.get_pp_rank_microbatches',
+                                         get_pp_rank_microbatches)
+            # a2a_overlap
+            patch_manager.register_patch('megatron.core.pipeline_parallel.schedules.forward_backward_pipelining_with_interleaving',
+                                         forward_backward_pipelining_with_interleaving)
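
Note on the __bases__ assignments kept in this hunk: reassigning a class's __bases__ makes every existing subclass resolve its methods through the replacement parent, which is how the patched TELinear / TEGroupedLinear reach TEColumnParallelLinear, TERowParallelLinear and the grouped variants without patching each subclass separately. A minimal, standalone sketch of the technique with toy classes (not the real Transformer Engine ones):

class Base:
    def forward(self):
        return "original behaviour"

class PatchedBase(Base):
    # Drop-in parent carrying the behaviour every subclass should inherit.
    def forward(self):
        return "patched behaviour"

class Child(Base):
    # Stands in for TEColumnParallelLinear / TERowParallelLinear.
    pass

# Rebind the parent: Child's MRO becomes Child -> PatchedBase -> Base -> object,
# so existing subclasses and their instances pick up the patched methods.
Child.__bases__ = (PatchedBase,)
assert Child().forward() == "patched behaviour"

The assignment only succeeds when the new parent is layout-compatible with the one it replaces (here it subclasses it), which is presumably why the patched TE classes keep the original base hierarchy.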
@@ -6,7 +6,6 @@ from typing import Optional
 import torch
 from torch import Tensor
-from megatron.core import parallel_state
 from megatron.core.config_logger import has_config_logger_enabled, log_config_to_disk
 from megatron.core.inference.contexts import BaseInferenceContext
@@ -720,7 +719,7 @@ def schedule_chunk_1f1b(
     if f_schedule_plan is not None and post_forward is not None:
         with f_context:
             f_schedule_plan.wait_current_stream()
-            post_forward(None if parallel_state.is_pipeline_last_stage(ignore_virtual=False) else f_input)
+            post_forward(f_input)
     # pp grad send / receive, overlapped with attn dw of cur micro-batch and forward attn of next micro-batch
     if b_schedule_plan is not None and post_backward is not None:
......
+_DUALPIPE_CHUNK = None
+def set_dualpipe_chunk(chunk_id):
+    """set_dualpipe_chunk for fp16forward patch"""
+    global _DUALPIPE_CHUNK
+    _DUALPIPE_CHUNK = chunk_id
+def get_dualpipe_chunk():
+    global _DUALPIPE_CHUNK
+    if _DUALPIPE_CHUNK is not None:
+        return _DUALPIPE_CHUNK
+    else:
+        raise AssertionError("_DUALPIPE_CHUNK is None")
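
These helpers appear to act as a process-global tag for which of the two dualpipev model chunks is currently running, so patched code elsewhere (the fp16forward patch, the loss-stage check further down) can query it without threading an extra argument through every call. A self-contained sketch of the intended usage; run_forward and the call at the bottom are illustrative, not part of this commit:

_DUALPIPE_CHUNK = None  # local copy of the helpers above, to keep the sketch standalone

def set_dualpipe_chunk(chunk_id):
    global _DUALPIPE_CHUNK
    _DUALPIPE_CHUNK = chunk_id

def get_dualpipe_chunk():
    if _DUALPIPE_CHUNK is None:
        raise AssertionError("_DUALPIPE_CHUNK is None")
    return _DUALPIPE_CHUNK

def run_forward(chunk_id, forward_fn, *inputs):
    # The scheduler tags the active chunk before running it; anything invoked
    # underneath can then ask which chunk it belongs to.
    set_dualpipe_chunk(chunk_id)
    return forward_fn(*inputs)

assert run_forward(1, lambda x: x + get_dualpipe_chunk(), 41) == 42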
@@ -7,6 +7,7 @@ import torch
 from torch import Tensor
 from torch.autograd.variable import Variable
+from megatron.training import get_args
 from megatron.core import parallel_state
 from megatron.core.distributed import DistributedDataParallel
@@ -15,6 +16,8 @@ from megatron.core.transformer.moe.router import MoEAuxLossAutoScaler
 from megatron.core.transformer.multi_token_prediction import MTPLossAutoScaler
 from megatron.core.utils import get_attr_wrapped_model, make_viewless_tensor
+from dcu_megatron.core.parallel_state import get_dualpipe_chunk
 def make_viewless(e):
     """make_viewless util func"""
@@ -432,7 +435,13 @@ def forward_backward_step(
     if f_model:
         with f_context:
             num_tokens = torch.tensor(0, dtype=torch.int)
-            if parallel_state.is_pipeline_last_stage(ignore_virtual=False):
+            args = get_args()
+            is_last_stage = False
+            if args.schedule_method == "dualpipev":
+                is_last_stage = parallel_state.is_pipeline_first_stage() and get_dualpipe_chunk() == 1
+            else:
+                is_last_stage = parallel_state.is_pipeline_last_stage(ignore_virtual=False)
+            if is_last_stage:
                 if not collect_non_loss_data:
                     loss_node = ScheduleNode(
                         loss_func,
......
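
Background for the branch added above: the dualpipev (DualPipe-V) schedule folds the pipeline so that each rank runs two model chunks, and the chunk that produces the loss ends up on the first pipeline rank as its second chunk (chunk id 1) rather than on the conventional last rank. A standalone restatement of that decision with the parallel-state queries passed in as plain booleans (the real code calls megatron.core.parallel_state and get_dualpipe_chunk()):

def is_loss_stage(schedule_method, is_first_pp_stage, is_last_pp_stage, dualpipe_chunk):
    # dualpipev: the loss is computed by chunk 1 on the first pipeline rank.
    if schedule_method == "dualpipev":
        return is_first_pp_stage and dualpipe_chunk == 1
    # other schedules: the conventional last pipeline stage computes the loss.
    return is_last_pp_stage

assert is_loss_stage("dualpipev", is_first_pp_stage=True, is_last_pp_stage=False, dualpipe_chunk=1)
assert not is_loss_stage("dualpipev", is_first_pp_stage=True, is_last_pp_stage=False, dualpipe_chunk=0)
assert is_loss_stage("interleaved_1f1b", is_first_pp_stage=False, is_last_pp_stage=True, dualpipe_chunk=0)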
@@ -20,7 +20,7 @@ from megatron.training.utils import (
     reduce_max_stat_across_model_parallel_group
 )
-from dcu_megatron.core.pipeline_parallel.dualpipev.dualpipev_schedules import get_dualpipe_chunk
+from dcu_megatron.core.parallel_state import get_dualpipe_chunk
 def dualpipev_fp16forward(self, *inputs, **kwargs):
......