Commit 4863ddcf authored by dongcl

add tokens_per_expert to common_state

parent 6a579b17
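
The commit moves `tokens_per_expert` off the dispatch/MoE call chain and onto the layer's shared `common_state`: the dispatch preprocess step writes it, and the later dispatch and expert-forward steps read it back. Below is a minimal, self-contained sketch of that producer/consumer pattern; `_Node`, `dispatch_preprocess`, and `dispatch_forward` are illustrative stand-ins, not the dcu_megatron classes.

```python
from dataclasses import dataclass, field
from types import SimpleNamespace

import torch


@dataclass
class _Node:
    """Illustrative stand-in for a schedule node that owns a shared common_state."""
    common_state: SimpleNamespace = field(default_factory=SimpleNamespace)


def dispatch_preprocess(node, tokens_per_expert, permuted_tokens):
    # Producer side: stash tokens_per_expert on the shared state
    # instead of returning it through the call chain.
    node.common_state.tokens_per_expert = tokens_per_expert
    return permuted_tokens


def dispatch_forward(permuted_tokens, state):
    # Consumer side: read tokens_per_expert back from the shared state.
    tokens_per_expert = state.tokens_per_expert
    return permuted_tokens, tokens_per_expert


node = _Node()
tokens = dispatch_preprocess(node, torch.tensor([4, 2, 2]), torch.randn(8, 16))
_, counts = dispatch_forward(tokens, node.common_state)
assert torch.equal(counts, torch.tensor([4, 2, 2]))
```
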
@@ -7,16 +7,15 @@ import torch
from torch import Tensor
from megatron.core.config_logger import has_config_logger_enabled, log_config_to_disk
from megatron.core.inference.contexts import BaseInferenceContext
from megatron.core.packed_seq_params import PackedSeqParams
from megatron.core.transformer import transformer_layer
from megatron.core.transformer.moe.moe_layer import MoELayer
from megatron.core.utils import WrappedTensor, deprecate_inference_params
from megatron.core.inference.contexts import BaseInferenceContext
from dcu_megatron.core.transformer.moe.token_dispatcher import MoEAlltoAllPerBatchState
from dcu_megatron.core.pipeline_parallel.combined_1f1b import (
AbstractSchedulePlan,
FakeScheduleNode,
FreeInputsMemoryStrategy,
NoOpMemoryStrategy,
ScheduleNode,
get_com_stream,
get_comp_stream,
@@ -776,10 +775,12 @@ def build_model_chunk_schedule_plan(
attention_mask: Tensor,
decoder_input: Tensor = None,
labels: Tensor = None,
inference_params=None,
inference_context: BaseInferenceContext = None,
packed_seq_params=None,
extra_block_kwargs=None,
runtime_gather_output: Optional[bool] = None,
inference_params=None,
loss_mask=None,
):
"""Builds a schedule plan for a model chunk.
@@ -797,6 +798,7 @@ def build_model_chunk_schedule_plan(
packed_seq_params: Parameters for packed sequences.
extra_block_kwargs: Additional keyword arguments for blocks.
runtime_gather_output: Whether to gather output at runtime.
loss_mask: Mask applied when computing the loss over the labels.
Returns:
The model chunk schedule plan.
@@ -812,10 +814,12 @@ def build_model_chunk_schedule_plan(
state.attention_mask = attention_mask
state.decoder_input = decoder_input
state.labels = labels
state.inference_params = inference_params
state.inference_context = inference_context
state.packed_seq_params = packed_seq_params
state.extra_block_kwargs = extra_block_kwargs
state.runtime_gather_output = runtime_gather_output
state.inference_params = inference_params
state.loss_mask = loss_mask
state.context = None
state.context_mask = None
state.attention_bias = None
@@ -117,8 +117,10 @@ def forward_backward_pipelining_with_interleaving(
config = get_model_config(model[0])
set_streams()
if not forward_only:
forward_step_func = wrap_forward_func(config, forward_step_func)
if config.combined_1f1b and not forward_only:
# in combined_1f1b, we need to wrap the forward_step_func
# to return a schedule plan instead of the forward output tensor
forward_step_func = wrap_forward_func(forward_step_func)
if config.overlap_p2p_comm and config.batch_p2p_comm:
raise ValueError("Can not use both overlap_p2p_comm and batch_p2p_comm")
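
The comment above notes that `forward_step_func` is wrapped so it returns a schedule plan instead of the forward output tensor. A hypothetical sketch of such a wrapper is given below; the real implementation lives in `combined_1f1b.py`, and the `build_schedule_plan` call here is an assumed placeholder, not the actual API.

```python
from functools import wraps


def wrap_forward_func(forward_step_func):
    """Illustrative sketch only: wrap forward_step_func so the scheduler
    receives a schedule plan instead of the usual forward output tensor."""

    @wraps(forward_step_func)
    def wrapped(data_iterator, model, *args, **kwargs):
        # Hypothetical placeholder: build and return the model chunk's
        # schedule plan rather than executing the forward pass eagerly.
        return model.build_schedule_plan(data_iterator, *args, **kwargs)

    return wrapped
```
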
from typing import Any, Optional
from functools import partial
import torch
from torch import Tensor
from megatron.training import get_args
from megatron.core import tensor_parallel, parallel_state
from megatron.core import parallel_state
from megatron.core.packed_seq_params import PackedSeqParams
from megatron.core.utils import (
deprecate_inference_params,
make_viewless_tensor,
nvtx_range_pop,
nvtx_range_push,
)
from megatron.core.transformer.moe.moe_layer import MoELayer
from megatron.core.utils import make_viewless_tensor
from megatron.core.transformer.transformer_layer import TransformerLayer as MegatronCoreTransformerLayer
from megatron.core.transformer.moe.token_dispatcher import MoEAlltoAllTokenDispatcher
from megatron.core.transformer.transformer_config import TransformerConfig
from dcu_megatron.core.transformer.utils import SubmoduleCallables, TransformerLayerSubmoduleCallables
def get_transformer_layer_offset(config: TransformerConfig, vp_stage: Optional[int] = None):
"""Get the index offset of current pipeline stage, given the level of pipelining."""
@@ -244,29 +241,33 @@ class TransformerLayer(MegatronCoreTransformerLayer):
residual,
context,
):
node.common_state.tokens_per_expert = tokens_per_expert
node.common_state.residual = node.detach(residual)
if self.mlp.use_shared_expert:
node.common_state.pre_mlp_layernorm_output = node.detach(pre_mlp_layernorm_output)
return tokens_per_expert, permutated_local_input_tokens, permuted_probs
return permutated_local_input_tokens, permuted_probs
def _submodule_dispatch_forward(self, tokens_per_expert, permutated_local_input_tokens, permuted_probs, state=None):
def _submodule_dispatch_forward(self, permutated_local_input_tokens, permuted_probs, state=None):
"""
Dispatches tokens to the appropriate experts based on the router output.
"""
tokens_per_expert = state.tokens_per_expert
token_dispatcher = self.mlp.token_dispatcher
tokens_per_expert, global_input_tokens, global_probs = token_dispatcher.dispatch_all_to_all(tokens_per_expert, permutated_local_input_tokens, permuted_probs)
return tokens_per_expert, global_input_tokens, global_probs
def _submodule_dispatch_postprocess(self, node, tokens_per_expert, global_input_tokens, global_probs):
return tokens_per_expert, global_input_tokens, global_probs
node.common_state.tokens_per_expert = tokens_per_expert
return global_input_tokens, global_probs
def _submodule_moe_forward(self, tokens_per_expert, global_input_tokens, global_probs, state=None):
def _submodule_moe_forward(self, global_input_tokens, global_probs, state=None):
"""
Performs a forward pass for the MLP submodule, including both expert-based
and optional shared-expert computations.
"""
tokens_per_expert = state.tokens_per_expert
shared_expert_output = None
token_dispatcher = self.mlp.token_dispatcher
@@ -275,7 +276,7 @@ class TransformerLayer(MegatronCoreTransformerLayer):
)
expert_output, mlp_bias = self.mlp.experts(
dispatched_tokens, tokens_per_expert, permuted_probs
dispatched_input, tokens_per_expert, permuted_probs
)
assert mlp_bias is None, f"Bias is not supported in {token_dispatcher.__class__.__name__}"
if self.mlp.use_shared_expert and not self.mlp.shared_expert_overlap:
@@ -371,7 +372,7 @@ class TransformerLayer(MegatronCoreTransformerLayer):
return attn_func(
hidden_states=hidden_states,
attention_mask=chunk_state.attention_mask,
content=chunk_state.context,
context=chunk_state.context,
context_mask=chunk_state.context_mask,
rotary_pos_emb=chunk_state.rotary_pos_emb,
rotary_pos_cos=chunk_state.rotary_pos_cos,
from dataclasses import dataclass
from typing import Callable, Optional
from functools import partial
@dataclass
class SubmoduleCallables:
@@ -9,10 +9,13 @@ class SubmoduleCallables:
for a particular submodule.
"""
forward: Optional[Callable] = None
backward: Optional[Callable] = None
dgrad: Optional[Callable] = None
dw: Optional[Callable] = None
def raise_not_implemented(name: str):
raise NotImplementedError(f"{name} not implemented.")
forward: Optional[Callable] = partial(raise_not_implemented, "forward")
dw: Optional[Callable] = partial(raise_not_implemented, "dw")
is_moe: bool = False
is_deepep: bool = False
@dataclass
@@ -26,7 +29,13 @@ class TransformerLayerSubmoduleCallables:
dispatch: SubmoduleCallables
mlp: SubmoduleCallables
combine: SubmoduleCallables
post_combine: SubmoduleCallables
is_moe: bool = False
is_deepep: bool = False
def as_array(self):
return [self.attention, self.dispatch, self.mlp, self.combine, self.post_combine]
return [self.attention, self.dispatch, self.mlp, self.combine]
def __post_init__(self):
for submodule in self.as_array():
submodule.is_moe = self.is_moe
submodule.is_deepep = self.is_deepep
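
To see the new `raise_not_implemented` defaults and the `__post_init__` flag propagation in isolation, here is a simplified, self-contained replica of the two dataclasses (it does not import `dcu_megatron`) followed by a short usage check.

```python
from dataclasses import dataclass
from functools import partial
from typing import Callable, Optional


def raise_not_implemented(name: str):
    raise NotImplementedError(f"{name} not implemented.")


@dataclass
class SubmoduleCallables:
    # Unset callables fail loudly when invoked instead of silently being None.
    forward: Optional[Callable] = partial(raise_not_implemented, "forward")
    dw: Optional[Callable] = partial(raise_not_implemented, "dw")
    is_moe: bool = False
    is_deepep: bool = False


@dataclass
class TransformerLayerSubmoduleCallables:
    attention: SubmoduleCallables
    dispatch: SubmoduleCallables
    mlp: SubmoduleCallables
    combine: SubmoduleCallables
    is_moe: bool = False
    is_deepep: bool = False

    def as_array(self):
        return [self.attention, self.dispatch, self.mlp, self.combine]

    def __post_init__(self):
        # Propagate the layer-level flags onto every submodule.
        for submodule in self.as_array():
            submodule.is_moe = self.is_moe
            submodule.is_deepep = self.is_deepep


callables = TransformerLayerSubmoduleCallables(
    SubmoduleCallables(), SubmoduleCallables(), SubmoduleCallables(), SubmoduleCallables(),
    is_moe=True,
)
assert callables.dispatch.is_moe          # flag propagated by __post_init__
try:
    callables.mlp.forward()               # unset callable raises immediately
except NotImplementedError as err:
    print(err)                            # "forward not implemented."
```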