Commit 4863ddcf authored by dongcl

add tokens_per_expert to common_state

parent 6a579b17
@@ -7,16 +7,15 @@ import torch
 from torch import Tensor
 from megatron.core.config_logger import has_config_logger_enabled, log_config_to_disk
+from megatron.core.inference.contexts import BaseInferenceContext
+from megatron.core.packed_seq_params import PackedSeqParams
 from megatron.core.transformer import transformer_layer
+from megatron.core.transformer.moe.moe_layer import MoELayer
 from megatron.core.utils import WrappedTensor, deprecate_inference_params
-from megatron.core.inference.contexts import BaseInferenceContext
-from dcu_megatron.core.transformer.moe.token_dispatcher import MoEAlltoAllPerBatchState
 from dcu_megatron.core.pipeline_parallel.combined_1f1b import (
     AbstractSchedulePlan,
-    FakeScheduleNode,
-    FreeInputsMemoryStrategy,
-    NoOpMemoryStrategy,
     ScheduleNode,
     get_com_stream,
     get_comp_stream,
@@ -776,10 +775,12 @@ def build_model_chunk_schedule_plan(
     attention_mask: Tensor,
     decoder_input: Tensor = None,
     labels: Tensor = None,
-    inference_params=None,
+    inference_context: BaseInferenceContext = None,
     packed_seq_params=None,
     extra_block_kwargs=None,
     runtime_gather_output: Optional[bool] = None,
+    inference_params=None,
+    loss_mask=None,
 ):
     """Builds a schedule plan for a model chunk.
@@ -797,6 +798,7 @@ def build_model_chunk_schedule_plan(
         packed_seq_params: Parameters for packed sequences.
         extra_block_kwargs: Additional keyword arguments for blocks.
         runtime_gather_output: Whether to gather output at runtime.
+        loss_mask: Loss mask
     Returns:
         The model chunk schedule plan.
@@ -812,10 +814,12 @@ def build_model_chunk_schedule_plan(
     state.attention_mask = attention_mask
     state.decoder_input = decoder_input
     state.labels = labels
-    state.inference_params = inference_params
+    state.inference_context = inference_context
     state.packed_seq_params = packed_seq_params
     state.extra_block_kwargs = extra_block_kwargs
     state.runtime_gather_output = runtime_gather_output
+    state.inference_params = inference_params
+    state.loss_mask = loss_mask
     state.context = None
     state.context_mask = None
     state.attention_bias = None
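
Both the new inference_context field and the legacy inference_params now ride on the chunk state. A minimal sketch of how a consumer of the plan could reconcile the two, assuming it follows the same deprecate_inference_params pattern this file already imports (the exact call site is not shown in this diff):

    # Sketch only; deprecate_inference_params is the upstream megatron.core helper
    # imported at the top of this file, and the call site below is an assumption.
    inference_context = deprecate_inference_params(
        state.inference_context,   # preferred, typed field
        state.inference_params,    # legacy kwarg kept so older callers still work
    )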
...
@@ -117,8 +117,10 @@ def forward_backward_pipelining_with_interleaving(
     config = get_model_config(model[0])
     set_streams()
-    if not forward_only:
-        forward_step_func = wrap_forward_func(config, forward_step_func)
+    if config.combined_1f1b and not forward_only:
+        # in combined_1f1b, we need to wrap the forward_step_func
+        # to return a schedule plan instead of the forward output tensor
+        forward_step_func = wrap_forward_func(forward_step_func)
     if config.overlap_p2p_comm and config.batch_p2p_comm:
         raise ValueError("Can not use both overlap_p2p_comm and batch_p2p_comm")
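
The wrapping is now opt-in: it only happens when the model config enables combined_1f1b and the pass is a training pass, so evaluation (forward_only=True) keeps the stock step function. A caller-side sketch, assuming combined_1f1b is a boolean that dcu_megatron adds to the config returned by get_model_config:

    # Hypothetical caller; parameter names follow megatron.core's interleaved
    # schedule, and combined_1f1b is assumed to be a dcu_megatron config flag.
    config = get_model_config(model_chunks[0])
    config.combined_1f1b = True   # training steps go through the schedule-plan path

    losses = forward_backward_pipelining_with_interleaving(
        forward_step_func=forward_step_func,
        data_iterator=data_iterators,
        model=model_chunks,
        num_microbatches=num_microbatches,
        seq_length=seq_length,
        micro_batch_size=micro_batch_size,
        forward_only=False,   # with forward_only=True the wrapping is skipped
    )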
...
 from typing import Any, Optional
+from functools import partial
+import torch
 from torch import Tensor
 from megatron.training import get_args
-from megatron.core import tensor_parallel, parallel_state
+from megatron.core import parallel_state
 from megatron.core.packed_seq_params import PackedSeqParams
-from megatron.core.utils import (
-    deprecate_inference_params,
-    make_viewless_tensor,
-    nvtx_range_pop,
-    nvtx_range_push,
-)
+from megatron.core.utils import make_viewless_tensor
+from megatron.core.transformer.moe.moe_layer import MoELayer
 from megatron.core.transformer.transformer_layer import TransformerLayer as MegatronCoreTransformerLayer
+from megatron.core.transformer.moe.token_dispatcher import MoEAlltoAllTokenDispatcher
 from megatron.core.transformer.transformer_config import TransformerConfig
+from dcu_megatron.core.transformer.utils import SubmoduleCallables, TransformerLayerSubmoduleCallables

 def get_transformer_layer_offset(config: TransformerConfig, vp_stage: Optional[int] = None):
     """Get the index offset of current pipeline stage, given the level of pipelining."""
@@ -244,29 +241,33 @@ class TransformerLayer(MegatronCoreTransformerLayer):
         residual,
         context,
     ):
+        node.common_state.tokens_per_expert = tokens_per_expert
         node.common_state.residual = node.detach(residual)
         if self.mlp.use_shared_expert:
             node.common_state.pre_mlp_layernorm_output = node.detach(pre_mlp_layernorm_output)
-        return tokens_per_expert, permutated_local_input_tokens, permuted_probs
+        return permutated_local_input_tokens, permuted_probs

-    def _submodule_dispatch_forward(self, tokens_per_expert, permutated_local_input_tokens, permuted_probs, state=None):
+    def _submodule_dispatch_forward(self, permutated_local_input_tokens, permuted_probs, state=None):
         """
         Dispatches tokens to the appropriate experts based on the router output.
         """
+        tokens_per_expert = state.tokens_per_expert
         token_dispatcher = self.mlp.token_dispatcher
         tokens_per_expert, global_input_tokens, global_probs = token_dispatcher.dispatch_all_to_all(tokens_per_expert, permutated_local_input_tokens, permuted_probs)
         return tokens_per_expert, global_input_tokens, global_probs

     def _submodule_dispatch_postprocess(self, node, tokens_per_expert, global_input_tokens, global_probs):
-        return tokens_per_expert, global_input_tokens, global_probs
+        node.common_state.tokens_per_expert = tokens_per_expert
+        return global_input_tokens, global_probs

-    def _submodule_moe_forward(self, tokens_per_expert, global_input_tokens, global_probs, state=None):
+    def _submodule_moe_forward(self, global_input_tokens, global_probs, state=None):
         """
         Performs a forward pass for the MLP submodule, including both expert-based
         and optional shared-expert computations.
         """
+        tokens_per_expert = state.tokens_per_expert
         shared_expert_output = None
         token_dispatcher = self.mlp.token_dispatcher
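
Taken together, the hunk above is the commit in a nutshell: tokens_per_expert no longer threads through every submodule's argument and return lists; it is written to the node's common_state once and read back where it is needed. A toy sketch of the pattern (SimpleNamespace stands in for the real common_state object, which this diff does not show):

    from types import SimpleNamespace

    common_state = SimpleNamespace()   # stand-in for node.common_state

    def preprocess(tokens_per_expert, tokens, probs):
        common_state.tokens_per_expert = tokens_per_expert   # parked on the shared state
        return tokens, probs                                  # no longer part of the return

    def dispatch_forward(tokens, probs, state=common_state):
        tokens_per_expert = state.tokens_per_expert           # read back at the next stage
        # ... all-to-all dispatch would run here ...
        return tokens, probs

    def moe_forward(tokens, probs, state=common_state):
        tokens_per_expert = state.tokens_per_expert           # and again for the expert MLP
        # ... self.mlp.experts(tokens, tokens_per_expert, probs) in the real code ...
        return tokens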
@@ -275,7 +276,7 @@ class TransformerLayer(MegatronCoreTransformerLayer):
         )
         expert_output, mlp_bias = self.mlp.experts(
-            dispatched_tokens, tokens_per_expert, permuted_probs
+            dispatched_input, tokens_per_expert, permuted_probs
         )
         assert mlp_bias is None, f"Bias is not supported in {token_dispatcher.__class__.__name__}"
         if self.mlp.use_shared_expert and not self.mlp.shared_expert_overlap:
@@ -371,7 +372,7 @@ class TransformerLayer(MegatronCoreTransformerLayer):
         return attn_func(
             hidden_states=hidden_states,
             attention_mask=chunk_state.attention_mask,
-            content=chunk_state.context,
+            context=chunk_state.context,
             context_mask=chunk_state.context_mask,
             rotary_pos_emb=chunk_state.rotary_pos_emb,
             rotary_pos_cos=chunk_state.rotary_pos_cos,
...
 from dataclasses import dataclass
 from typing import Callable, Optional
+from functools import partial

 @dataclass
 class SubmoduleCallables:
@@ -9,10 +9,13 @@ class SubmoduleCallables:
     for a particular submodule.
     """
-    forward: Optional[Callable] = None
-    backward: Optional[Callable] = None
-    dgrad: Optional[Callable] = None
-    dw: Optional[Callable] = None
+    def raise_not_implemented(name: str):
+        raise NotImplementedError(f"{name} not implemented.")
+
+    forward: Optional[Callable] = partial(raise_not_implemented, "forward")
+    dw: Optional[Callable] = partial(raise_not_implemented, "dw")
+    is_moe: bool = False
+    is_deepep: bool = False
 @dataclass
@@ -26,7 +29,13 @@ class TransformerLayerSubmoduleCallables:
     dispatch: SubmoduleCallables
     mlp: SubmoduleCallables
     combine: SubmoduleCallables
-    post_combine: SubmoduleCallables
+    is_moe: bool = False
+    is_deepep: bool = False

     def as_array(self):
-        return [self.attention, self.dispatch, self.mlp, self.combine, self.post_combine]
+        return [self.attention, self.dispatch, self.mlp, self.combine]
+
+    def __post_init__(self):
+        for submodule in self.as_array():
+            submodule.is_moe = self.is_moe
+            submodule.is_deepep = self.is_deepep
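
With the defaults above, a submodule hook that was never wired up fails loudly at call time instead of surfacing later as a None, and the layer-level is_moe/is_deepep flags are pushed down onto every per-submodule entry by __post_init__. A usage sketch based only on the definitions shown here (the attention field is assumed to sit alongside dispatch/mlp/combine, as as_array() implies; my_attention_forward is a placeholder callable):

    callables = TransformerLayerSubmoduleCallables(
        attention=SubmoduleCallables(forward=my_attention_forward),
        dispatch=SubmoduleCallables(),
        mlp=SubmoduleCallables(),
        combine=SubmoduleCallables(),
        is_moe=True,
        is_deepep=False,
    )

    # __post_init__ propagates the layer-level flags to every submodule entry.
    assert callables.dispatch.is_moe and not callables.dispatch.is_deepep

    # Unset hooks raise immediately rather than being silently None:
    callables.mlp.forward()   # NotImplementedError: forward not implemented.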