Commit 69add73b authored by dongcl

dualpipev: support MoE a2a overlap

parent 62f16817
 __pycache__
 *.bak
+*.log
@@ -69,16 +69,12 @@ class PipelineFeature(AbstractFeature):
         patch_manager.register_patch(
             'megatron.training.training.evaluate', evaluate)
 
-        if (
-            args.schedule_method == "interleaved_1f1b"
-            and args.combined_1f1b
-        ):
+        if args.combined_1f1b:
             from megatron.core.extensions.transformer_engine import TEColumnParallelLinear, TERowParallelLinear
             from dcu_megatron.core.transformer.moe.token_dispatcher import MoEAlltoAllTokenDispatcher
             from dcu_megatron.core.transformer.transformer_layer import TransformerLayer
             from dcu_megatron.core.models.gpt.gpt_model import GPTModel
-            from dcu_megatron.core.pipeline_parallel.schedules import get_pp_rank_microbatches, forward_backward_pipelining_with_interleaving
             from dcu_megatron.core.extensions.transformer_engine import (
                 _get_extra_te_kwargs_wrapper,
                 TELinear,
@@ -89,53 +85,55 @@ class PipelineFeature(AbstractFeature):
             from dcu_megatron.core.transformer.moe.experts import TEGroupedMLP
             from dcu_megatron.core.transformer.moe.moe_layer import MoELayer
 
-            # num_warmup_microbatches + 1
-            patches_manager.register_patch('megatron.core.pipeline_parallel.schedules.get_pp_rank_microbatches',
-                get_pp_rank_microbatches)
-            # a2a_overlap
-            patches_manager.register_patch('megatron.core.pipeline_parallel.schedules.forward_backward_pipelining_with_interleaving',
-                forward_backward_pipelining_with_interleaving)
-            patches_manager.register_patch('megatron.core.transformer.moe.token_dispatcher.MoEAlltoAllTokenDispatcher',
+            patch_manager.register_patch('megatron.core.transformer.moe.token_dispatcher.MoEAlltoAllTokenDispatcher',
                 MoEAlltoAllTokenDispatcher)
-            patches_manager.register_patch('megatron.core.transformer.transformer_layer.TransformerLayer',
+            patch_manager.register_patch('megatron.core.transformer.transformer_layer.TransformerLayer',
                 TransformerLayer)
-            patches_manager.register_patch('megatron.core.models.gpt.gpt_model.GPTModel.build_schedule_plan',
+            patch_manager.register_patch('megatron.core.models.gpt.gpt_model.GPTModel.build_schedule_plan',
                 GPTModel.build_schedule_plan,
                 create_dummy=True)
 
             # backward_dw
-            patches_manager.register_patch('megatron.core.extensions.transformer_engine._get_extra_te_kwargs',
+            patch_manager.register_patch('megatron.core.extensions.transformer_engine._get_extra_te_kwargs',
                 _get_extra_te_kwargs_wrapper,
                 apply_wrapper=True)
-            patches_manager.register_patch('megatron.core.extensions.transformer_engine.TELinear',
+            patch_manager.register_patch('megatron.core.extensions.transformer_engine.TELinear',
                 TELinear)
-            patches_manager.register_patch('megatron.core.extensions.transformer_engine.TELayerNormColumnParallelLinear',
+            patch_manager.register_patch('megatron.core.extensions.transformer_engine.TELayerNormColumnParallelLinear',
                 TELayerNormColumnParallelLinear)
             TEColumnParallelLinear.__bases__ = (TELinear,)
             TERowParallelLinear.__bases__ = (TELinear,)
 
             if is_te_min_version("1.9.0.dev0"):
                 from megatron.core.extensions.transformer_engine import TEColumnParallelGroupedLinear, TERowParallelGroupedLinear
-                from ..core.extensions.transformer_engine import TEGroupedLinear
-                patches_manager.register_patch('megatron.core.extensions.transformer_engine.TEGroupedLinear',
+                from dcu_megatron.core.extensions.transformer_engine import TEGroupedLinear
+                patch_manager.register_patch('megatron.core.extensions.transformer_engine.TEGroupedLinear',
                     TEGroupedLinear)
                 TEColumnParallelGroupedLinear.__bases__ = (TEGroupedLinear,)
                 TERowParallelGroupedLinear.__bases__ = (TEGroupedLinear,)
 
-            patches_manager.register_patch('megatron.core.transformer.multi_latent_attention.MLASelfAttention.backward_dw',
+            patch_manager.register_patch('megatron.core.transformer.multi_latent_attention.MLASelfAttention.backward_dw',
                 MLASelfAttention.backward_dw,
                 create_dummy=True)
-            patches_manager.register_patch('megatron.core.transformer.mlp.MLP.backward_dw',
+            patch_manager.register_patch('megatron.core.transformer.mlp.MLP.backward_dw',
                 MLP.backward_dw,
                 create_dummy=True)
-            patches_manager.register_patch('megatron.core.transformer.moe.experts.TEGroupedMLP.backward_dw',
+            patch_manager.register_patch('megatron.core.transformer.moe.experts.TEGroupedMLP.backward_dw',
                 TEGroupedMLP.backward_dw,
                 create_dummy=True)
-            patches_manager.register_patch('megatron.core.transformer.moe.moe_layer.MoELayer.backward_dw',
+            patch_manager.register_patch('megatron.core.transformer.moe.moe_layer.MoELayer.backward_dw',
                 MoELayer.backward_dw,
                 create_dummy=True)
+
+            if args.schedule_method == "interleaved_1f1b":
+                from dcu_megatron.core.pipeline_parallel.schedules import get_pp_rank_microbatches, forward_backward_pipelining_with_interleaving
+                # num_warmup_microbatches + 1
+                patch_manager.register_patch('megatron.core.pipeline_parallel.schedules.get_pp_rank_microbatches',
+                    get_pp_rank_microbatches)
+                # a2a_overlap
+                patch_manager.register_patch('megatron.core.pipeline_parallel.schedules.forward_backward_pipelining_with_interleaving',
+                    forward_backward_pipelining_with_interleaving)
@@ -6,7 +6,6 @@ from typing import Optional
 
 import torch
 from torch import Tensor
 
-from megatron.core import parallel_state
 from megatron.core.config_logger import has_config_logger_enabled, log_config_to_disk
 from megatron.core.inference.contexts import BaseInferenceContext
@@ -720,7 +719,7 @@ def schedule_chunk_1f1b(
     if f_schedule_plan is not None and post_forward is not None:
         with f_context:
             f_schedule_plan.wait_current_stream()
-            post_forward(None if parallel_state.is_pipeline_last_stage(ignore_virtual=False) else f_input)
+            post_forward(f_input)
 
     # pp grad send / receive, overlapped with attn dw of cur micro-batch and forward attn of next micro-batch
     if b_schedule_plan is not None and post_backward is not None:
...
+_DUALPIPE_CHUNK = None
+
+
+def set_dualpipe_chunk(chunk_id):
+    """set_dualpipe_chunk for fp16forward patch"""
+    global _DUALPIPE_CHUNK
+    _DUALPIPE_CHUNK = chunk_id
+
+
+def get_dualpipe_chunk():
+    global _DUALPIPE_CHUNK
+    if _DUALPIPE_CHUNK is not None:
+        return _DUALPIPE_CHUNK
+    else:
+        raise AssertionError("_DUALPIPE_CHUNK is None")
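For context, a minimal usage sketch of these new helpers (not part of this commit): it assumes the module is dcu_megatron.core.parallel_state, the path the imports added below use for get_dualpipe_chunk; run_model_chunk and is_loss_chunk are hypothetical names used only for illustration.

from dcu_megatron.core.parallel_state import get_dualpipe_chunk, set_dualpipe_chunk

def run_model_chunk(chunk_id, forward_fn, *args, **kwargs):
    # Hypothetical call site: the dualpipev schedule records which of the two
    # model chunks is about to run so that patched code can query it later.
    set_dualpipe_chunk(chunk_id)
    return forward_fn(*args, **kwargs)

def is_loss_chunk():
    # Patched code (e.g. the fp16forward patch or the last-stage check below)
    # reads the recorded id; chunk 1 is the loss-side chunk under dualpipev.
    return get_dualpipe_chunk() == 1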
@@ -7,6 +7,7 @@ import torch
 from torch import Tensor
 from torch.autograd.variable import Variable
 
+from megatron.training import get_args
 from megatron.core import parallel_state
 from megatron.core.distributed import DistributedDataParallel
@@ -15,6 +16,8 @@ from megatron.core.transformer.moe.router import MoEAuxLossAutoScaler
 from megatron.core.transformer.multi_token_prediction import MTPLossAutoScaler
 from megatron.core.utils import get_attr_wrapped_model, make_viewless_tensor
 
+from dcu_megatron.core.parallel_state import get_dualpipe_chunk
+
 
 def make_viewless(e):
     """make_viewless util func"""
@@ -432,7 +435,13 @@ def forward_backward_step(
     if f_model:
         with f_context:
             num_tokens = torch.tensor(0, dtype=torch.int)
-            if parallel_state.is_pipeline_last_stage(ignore_virtual=False):
+            args = get_args()
+            is_last_stage = False
+            if args.schedule_method == "dualpipev":
+                is_last_stage = parallel_state.is_pipeline_first_stage() and get_dualpipe_chunk() == 1
+            else:
+                is_last_stage = parallel_state.is_pipeline_last_stage(ignore_virtual=False)
+            if is_last_stage:
                 if not collect_non_loss_data:
                     loss_node = ScheduleNode(
                         loss_func,
...
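The hunk above carries the dualpipev-specific change: with the folded V-shaped dualpipev schedule, the first pipeline rank hosts both the embedding-side chunk (chunk 0) and the loss-side chunk (chunk 1), so the loss-computing "last stage" is detected as chunk 1 on the first pipeline rank instead of via is_pipeline_last_stage. A hypothetical helper capturing the same predicate (illustration only, not code from this commit):

from megatron.core import parallel_state
from dcu_megatron.core.parallel_state import get_dualpipe_chunk

def is_loss_stage(schedule_method):
    # Mirrors the check introduced above: dualpipev computes the loss on
    # chunk 1 of the first pipeline rank; other schedules keep the usual
    # last-pipeline-stage test.
    if schedule_method == "dualpipev":
        return parallel_state.is_pipeline_first_stage() and get_dualpipe_chunk() == 1
    return parallel_state.is_pipeline_last_stage(ignore_virtual=False)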
@@ -20,7 +20,7 @@ from megatron.training.utils import (
     reduce_max_stat_across_model_parallel_group
 )
 
-from dcu_megatron.core.pipeline_parallel.dualpipev.dualpipev_schedules import get_dualpipe_chunk
+from dcu_megatron.core.parallel_state import get_dualpipe_chunk
 
 
 def dualpipev_fp16forward(self, *inputs, **kwargs):
...