Commit 12b56c98 authored by dongcl

support a2a overlap

parent 8551c38e
@@ -10,6 +10,7 @@ from torch import Tensor
from megatron.core import InferenceParams, tensor_parallel
from megatron.core.config_logger import has_config_logger_enabled, log_config_to_disk
from megatron.core.packed_seq_params import PackedSeqParams
from megatron.core.models.gpt import GPTModel as MegatronCoreGPTModel
from dcu_megatron.core.tensor_parallel import FluxColumnParallelLinear
@@ -45,7 +46,72 @@ def gpt_model_init_wrapper(fn):
return wrapper
class GPTModel(MegatronCoreGPTModel):
"""
Patch of the Megatron-Core GPTModel that adds fine-grained schedule-plan support for a2a overlap.
"""
def get_transformer_callables_by_layer(self, layer_number: int):
"""
Get the callables for the layer at the given transformer layer number.
"""
return self.decoder.get_layer_callables(layer_number)
def build_schedule_plan(
self,
input_ids: Tensor,
position_ids: Tensor,
attention_mask: Tensor,
decoder_input: Tensor = None,
labels: Tensor = None,
inference_params: InferenceParams = None,
packed_seq_params: PackedSeqParams = None,
extra_block_kwargs: dict = None,
runtime_gather_output: Optional[bool] = None,
loss_mask: Optional[Tensor] = None,
):
"""Builds a computation schedule plan for the model.
This function creates a schedule plan for a model chunk, including
preprocessing, transformer layers, and postprocessing.
The schedule plan is used to optimize computation and memory usage
in distributed environments.
Args:
input_ids (Tensor): Input token IDs.
position_ids (Tensor): Position IDs.
attention_mask (Tensor): Attention mask.
decoder_input (Tensor, optional): Decoder input tensor. Defaults to None.
labels (Tensor, optional): Labels for loss computation. Defaults to None.
inference_params (InferenceParams, optional):
Parameters for inference. Defaults to None.
packed_seq_params (PackedSeqParams, optional):
Parameters for packed sequences. Defaults to None.
extra_block_kwargs (dict, optional):
Additional keyword arguments for blocks. Defaults to None.
runtime_gather_output (Optional[bool], optional):
Whether to gather output at runtime. Defaults to None.
loss_mask (Optional[Tensor], optional): Loss mask. Defaults to None.
Returns:
ModelChunkSchedulePlan: The model chunk schedule plan.
"""
from .fine_grained_schedule import build_model_chunk_schedule_plan
return build_model_chunk_schedule_plan(
self,
input_ids,
position_ids,
attention_mask,
decoder_input=decoder_input,
labels=labels,
inference_params=inference_params,
packed_seq_params=packed_seq_params,
extra_block_kwargs=extra_block_kwargs,
runtime_gather_output=runtime_gather_output,
)
def forward(
self,
input_ids: Tensor,
position_ids: Tensor,
@@ -57,7 +123,7 @@ def gpt_model_forward(
extra_block_kwargs: dict = None,
runtime_gather_output: Optional[bool] = None,
loss_mask: Optional[Tensor] = None,
) -> Tensor:
"""Forward function of the GPT Model This function passes the input tensors
through the embedding layer, and then the decoder and finally into the post
processing layer (optional).
...
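# Illustrative sketch (toy classes, not part of the patch above; all names invented): the
# schedule plan relies on the model exposing per-layer callables instead of one opaque
# forward(), so a scheduler is free to interleave the stages of different microbatches.
class ToyLayer:
    def get_submodule_callables(self):
        return {
            "attention": lambda x: x + 1,
            "dispatch": lambda x: x,   # stands in for the AlltoAll send
            "mlp": lambda x: x * 2,
            "combine": lambda x: x,    # stands in for the AlltoAll receive
        }

class ToyModel:
    def __init__(self, num_layers=2):
        self.layers = [ToyLayer() for _ in range(num_layers)]

    def get_transformer_callables_by_layer(self, layer_number):
        return self.layers[layer_number].get_submodule_callables()

toy = ToyModel()
x = 1
for i in range(len(toy.layers)):
    callables = toy.get_transformer_callables_by_layer(i)
    for stage in ("attention", "dispatch", "mlp", "combine"):
        x = callables[stage](x)
print(x)  # layer 0: (1+1)*2 = 4; layer 1: (4+1)*2 = 10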
from megatron.core.transformer.moe.token_dispatcher import _DeepepManager as MegatronCoreDeepepManager
class MoEAlltoAllTokenDispatcher(MoETokenDispatcher):
def token_permutation(
self, hidden_states: torch.Tensor, probs: torch.Tensor, routing_map: torch.Tensor
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
"""
Dispatch tokens to local experts using AlltoAll communication.
This method performs the following steps:
1. Preprocess the routing map to get metadata for communication and permutation.
2. Permute input tokens for AlltoAll communication.
3. Perform expert parallel AlltoAll communication.
4. Sort tokens by local expert (if multiple local experts exist).
Args:
hidden_states (torch.Tensor): Input token embeddings.
probs (torch.Tensor): The probabilities of token to experts assignment.
routing_map (torch.Tensor): The mapping of token to experts assignment.
Returns:
Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
- Permuted token embeddings for local experts.
- Number of tokens per expert.
- Permuted probs of each token produced by the router.
"""
# Preprocess: Get the metadata for communication, permutation and computation operations.
self.hidden_shape = hidden_states.shape
self.probs = probs
self.routing_map = routing_map
assert probs.dim() == 2, "Expected 2D tensor for probs"
assert routing_map.dim() == 2, "Expected 2D tensor for token2expert mask"
assert routing_map.dtype == torch.bool, "Expected bool tensor for mask"
hidden_states = hidden_states.view(-1, self.hidden_shape[-1])
tokens_per_expert = self.preprocess(self.routing_map)
if self.shared_experts is not None:
self.shared_experts.pre_forward_comm(hidden_states.view(self.hidden_shape))
# Permutation 1: input to AlltoAll input
tokens_per_expert = self._maybe_dtoh_and_synchronize(
"before_permutation_1", tokens_per_expert
)
self.hidden_shape_before_permute = hidden_states.shape
(
permutated_local_input_tokens,
permuted_probs,
self.reversed_local_input_permutation_mapping,
) = permute(
hidden_states,
routing_map,
probs=probs,
num_out_tokens=self.num_out_tokens,
fused=self.config.moe_permute_fusion,
drop_and_pad=self.drop_and_pad,
)
# Perform expert parallel AlltoAll communication
tokens_per_expert = self._maybe_dtoh_and_synchronize(
"before_ep_alltoall", tokens_per_expert
)
global_input_tokens = all_to_all(
self.ep_group, permutated_local_input_tokens, self.output_splits, self.input_splits
)
global_probs = all_to_all(
self.ep_group, permuted_probs, self.output_splits, self.input_splits
)
if self.shared_experts is not None:
self.shared_experts.linear_fc1_forward_and_act(global_input_tokens)
if self.tp_size > 1:
if self.output_splits_tp is None:
output_split_sizes = None
else:
output_split_sizes = self.output_splits_tp.tolist()
global_input_tokens = gather_from_sequence_parallel_region(
global_input_tokens, group=self.tp_group, output_split_sizes=output_split_sizes
)
global_probs = gather_from_sequence_parallel_region(
global_probs, group=self.tp_group, output_split_sizes=output_split_sizes
)
# Permutation 2: Sort tokens by local expert.
tokens_per_expert = self._maybe_dtoh_and_synchronize(
"before_permutation_2", tokens_per_expert
)
if self.num_local_experts > 1:
if self.drop_and_pad:
global_input_tokens = (
global_input_tokens.view(
self.tp_size * self.ep_size,
self.num_local_experts,
self.capacity,
*global_input_tokens.size()[1:],
)
.transpose(0, 1)
.contiguous()
.flatten(start_dim=0, end_dim=2)
)
global_probs = (
global_probs.view(
self.tp_size * self.ep_size,
self.num_local_experts,
self.capacity,
*global_probs.size()[1:],
)
.transpose(0, 1)
.contiguous()
.flatten(start_dim=0, end_dim=2)
)
else:
global_input_tokens, global_probs = sort_chunks_by_idxs(
global_input_tokens,
self.num_global_tokens_per_local_expert.ravel(),
self.sort_input_by_local_experts,
probs=global_probs,
fused=self.config.moe_permute_fusion,
)
tokens_per_expert = self._maybe_dtoh_and_synchronize("before_finish", tokens_per_expert)
return global_input_tokens, tokens_per_expert, global_probs
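# Illustrative sketch (single process, plain PyTorch, no fused kernels or expert
# parallelism): the permutation step above replicates each token once per expert that
# selected it and groups the copies expert-by-expert, which is the layout the
# expert-parallel AlltoAll expects.
import torch

num_tokens, hidden, num_experts = 6, 4, 3
hidden_states = torch.randn(num_tokens, hidden)
routing_map = torch.zeros(num_tokens, num_experts, dtype=torch.bool)
routing_map[torch.arange(num_tokens), torch.randint(0, num_experts, (num_tokens,))] = True

token_idx_per_expert = [routing_map[:, e].nonzero(as_tuple=True)[0] for e in range(num_experts)]
permuted_tokens = torch.cat([hidden_states[idx] for idx in token_idx_per_expert], dim=0)
tokens_per_expert = torch.tensor([len(idx) for idx in token_idx_per_expert])
print(tokens_per_expert.tolist(), permuted_tokens.shape)  # e.g. [2, 1, 3], torch.Size([6, 4])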
class _DeepepManager(MegatronCoreDeepepManager):
"""
Patch of the Megatron-Core _DeepepManager that supports asynchronous dispatch and combine.
"""
def dispatch(
self,
hidden_states: torch.Tensor,
async_finish: bool = False,
allocate_on_comm_stream: bool = False,
) -> torch.Tensor:
# DeepEP only supports float32 probs
if self.token_probs.dtype != torch.float32:
if self.token_probs.dtype in [torch.bfloat16, torch.float16]:
print("DeepEP only supports float32 probs, please set --moe-router-dtype=fp32")
self.token_probs = self.token_probs.float() # downcast or upcast
hidden_states, dispatched_indices, dispatched_probs, num_tokens_per_expert, handle = (
fused_dispatch(
hidden_states,
self.token_indices,
self.token_probs,
self.num_experts,
self.group,
async_finish=async_finish,
allocate_on_comm_stream=allocate_on_comm_stream,
)
)
self.handle = handle
self.tokens_per_expert = num_tokens_per_expert
self.dispatched_indices = dispatched_indices
self.dispatched_probs = dispatched_probs
return hidden_states
def combine(
self,
hidden_states: torch.Tensor,
async_finish: bool = False,
allocate_on_comm_stream: bool = False,
) -> torch.Tensor:
hidden_states, _ = fused_combine(
hidden_states,
self.group,
self.handle,
async_finish=async_finish,
allocate_on_comm_stream=allocate_on_comm_stream,
)
# Release the handle after combine operation
self.handle = None
return hidden_states
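# Illustrative sketch (toy class, no DeepEP): dispatch() stores the communication handle
# returned by the async kernel and combine() must consume it exactly once, mirroring the
# handle bookkeeping in the patched _DeepepManager above.
class ToyCommManager:
    def __init__(self):
        self.handle = None

    def dispatch(self, tokens):
        self.handle = object()  # stand-in for the DeepEP dispatch handle
        return tokens

    def combine(self, tokens):
        assert self.handle is not None, "combine() called without a pending dispatch()"
        self.handle = None  # release the handle after the combine
        return tokens

manager = ToyCommManager()
out = manager.combine(manager.dispatch([1, 2, 3]))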
class MoEFlexTokenDispatcher(MoETokenDispatcher):
"""
Flex token dispatcher using DeepEP.
"""
def dispatch_preprocess(
self, hidden_states: torch.Tensor, routing_map: torch.Tensor, probs: torch.Tensor
):
"""
Preprocesses the hidden states and routing information before dispatching tokens to experts.
Args:
hidden_states (torch.Tensor): Input hidden states to be processed
routing_map (torch.Tensor): Map indicating which expert each token should be routed to
probs (torch.Tensor): Routing probabilities for each token-expert pair
Returns:
Tuple containing:
- torch.Tensor: Reshaped hidden states
- torch.Tensor: Token probabilities from the communication manager
- None: Placeholder for compatibility
"""
self.hidden_shape = hidden_states.shape
hidden_states = hidden_states.view(-1, self.hidden_shape[-1])
# Initialize metadata
routing_map, probs = self._initialize_metadata(routing_map, probs)
self._comm_manager.setup_metadata(routing_map, probs)
return hidden_states, self._comm_manager.token_probs, None
def dispatch_all_to_all(
self,
hidden_states: torch.Tensor,
probs: torch.Tensor = None,
async_finish: bool = True,
allocate_on_comm_stream: bool = True,
):
"""
Performs all-to-all communication to dispatch tokens across expert parallel ranks.
"""
return (
self._comm_manager.dispatch(hidden_states, async_finish, allocate_on_comm_stream),
self._comm_manager.dispatched_probs,
)
def dispatch_postprocess(self, hidden_states: torch.Tensor):
"""
Post-processes the dispatched hidden states after all-to-all communication.
This method retrieves the permuted hidden states by experts, calculates the number of tokens
per expert, and returns the processed data ready for expert processing.
"""
global_input_tokens, permuted_probs = (
self._comm_manager.get_permuted_hidden_states_by_experts(hidden_states)
)
tokens_per_expert = self._comm_manager.get_number_of_tokens_per_expert()
return global_input_tokens, tokens_per_expert, permuted_probs
def token_permutation(
self, hidden_states: torch.Tensor, probs: torch.Tensor, routing_map: torch.Tensor
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
"""
Permutes tokens according to the routing map and dispatches them to experts.
This method implements the token permutation process in three steps:
1. Preprocess the hidden states and routing information
2. Perform all-to-all communication to dispatch tokens
3. Post-process the dispatched tokens for expert processing
"""
hidden_states, _, _ = self.dispatch_preprocess(hidden_states, routing_map, probs)
hidden_states, _ = self.dispatch_all_to_all(
hidden_states, async_finish=False, allocate_on_comm_stream=False
)
global_input_tokens, tokens_per_expert, permuted_probs = self.dispatch_postprocess(
hidden_states
)
return global_input_tokens, tokens_per_expert, permuted_probs
def combine_preprocess(self, hidden_states: torch.Tensor):
"""
Pre-processes the hidden states before combining them after expert processing.
This method restores the hidden states to their original ordering before expert processing
by using the communication manager's restoration function.
"""
hidden_states = self._comm_manager.get_restored_hidden_states_by_experts(hidden_states)
return hidden_states
def combine_all_to_all(
self,
hidden_states: torch.Tensor,
async_finish: bool = True,
allocate_on_comm_stream: bool = True,
):
"""
Performs all-to-all communication to combine tokens after expert processing.
"""
return self._comm_manager.combine(hidden_states, async_finish, allocate_on_comm_stream)
def combine_postprocess(self, hidden_states: torch.Tensor):
"""
Post-processes the combined hidden states after all-to-all communication.
This method reshapes the combined hidden states to match the original input shape.
"""
return hidden_states.view(self.hidden_shape)
def token_unpermutation(
self, hidden_states: torch.Tensor, bias: Optional[torch.Tensor] = None
) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
"""
Reverses the token permutation process to restore the original token order.
This method implements the token unpermutation process in three steps:
1. Pre-process the hidden states to restore their original ordering
2. Perform all-to-all communication to combine tokens
3. Post-process the combined tokens to match the original input shape
"""
assert bias is None, "Bias is not supported in MoEFlexTokenDispatcher"
hidden_states = self.combine_preprocess(hidden_states)
hidden_states = self.combine_all_to_all(hidden_states, False, False)
hidden_states = self.combine_postprocess(hidden_states)
return hidden_states, None
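# Illustrative sketch (Python threads instead of CUDA streams, no Megatron/DeepEP):
# splitting dispatch into preprocess / all_to_all / postprocess lets a scheduler run
# another microbatch's compute while the communication of the first one is still in
# flight, which is the a2a overlap this dispatcher is built for.
import threading
import time

def fake_all_to_all(tokens, result):
    time.sleep(0.05)        # stands in for network latency
    result["out"] = tokens  # identity "exchange", good enough for the sketch

tokens_a = [1, 2, 3]                      # dispatch_preprocess(microbatch A)
result = {}
comm = threading.Thread(target=fake_all_to_all, args=(tokens_a, result))
comm.start()                              # dispatch_all_to_all(A, async_finish=True)
compute_b = [x * x for x in range(4)]     # overlapped compute for microbatch B
comm.join()                               # wait before dispatch_postprocess(A)
tokens_a = result["out"]
print(tokens_a, compute_b)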
from functools import wraps
from megatron.core.transformer.transformer_block import TransformerBlock as MegatronCoreTransformerBlock
def transformer_block_init_wrapper(fn):
@wraps(fn)
@@ -13,3 +14,22 @@ def transformer_block_init_wrapper(fn):
self.final_layernorm = None
return wrapper
class TransformerBlock(MegatronCoreTransformerBlock):
def __init__(
self, *args, **kwargs
):
super().__init__(*args, **kwargs)
# MTP requires separate layernorms for the main model and the MTP modules, so move the final layernorm out of the block
config = args[0] if len(args) > 0 else kwargs['config']
if getattr(config, "mtp_num_layers", 0) > 0:
self.main_final_layernorm = self.final_layernorm
self.final_layernorm = None
def get_layer_callables(self, layer_number: int):
"""
Get the callables for the layer at the given layer number.
"""
return self.layers[layer_number].get_submodule_callables()
from megatron.core.transformer.transformer_layer import TransformerLayer as MegatronCoreTransformerLayer
class TransformerLayer(MegatronCoreTransformerLayer):
def _submodule_attn_router_forward(
self,
hidden_states,
attention_mask=None,
inference_params=None,
rotary_pos_emb=None,
rotary_pos_cos=None,
rotary_pos_sin=None,
attention_bias=None,
packed_seq_params=None,
sequence_len_offset=None,
state=None,
):
"""
Performs a combined forward pass that includes self-attention and MLP routing logic.
"""
hidden_states, _ = self._forward_attention(
hidden_states=hidden_states,
attention_mask=attention_mask,
rotary_pos_emb=rotary_pos_emb,
rotary_pos_cos=rotary_pos_cos,
rotary_pos_sin=rotary_pos_sin,
attention_bias=attention_bias,
packed_seq_params=packed_seq_params,
sequence_len_offset=sequence_len_offset,
inference_params=inference_params,
)
pre_mlp_layernorm_output = self.pre_mlp_layernorm(hidden_states)
probs, routing_map = self.mlp.router(pre_mlp_layernorm_output)
local_tokens, probs, tokens_per_expert = self.mlp.token_dispatcher.dispatch_preprocess(
pre_mlp_layernorm_output, routing_map, probs
)
return (local_tokens, probs, hidden_states, pre_mlp_layernorm_output, tokens_per_expert)
def _submodule_dispatch_forward(self, local_tokens, probs, state=None):
"""
Dispatches tokens to the appropriate experts based on the router output.
"""
token_dispatcher = self.mlp.token_dispatcher
if self.is_deepep:
token_dispatcher._comm_manager.token_probs = probs
return token_dispatcher.dispatch_all_to_all(local_tokens, probs)
def _submodule_moe_forward(self, dispatched_tokens, probs=None, state=None):
"""
Performs a forward pass for the MLP submodule, including both expert-based
and optional shared-expert computations.
"""
shared_expert_output = None
token_dispatcher = self.mlp.token_dispatcher
if self.is_deepep:
token_dispatcher._comm_manager.dispatched_probs = state.dispatched_probs
dispatched_tokens, tokens_per_expert, permuted_probs = (
token_dispatcher.dispatch_postprocess(dispatched_tokens)
)
else:
dispatched_tokens, permuted_probs = token_dispatcher.dispatch_postprocess(
dispatched_tokens, probs
)
tokens_per_expert = state.tokens_per_expert
expert_output, mlp_bias = self.mlp.experts(
dispatched_tokens, tokens_per_expert, permuted_probs
)
assert mlp_bias is None, f"Bias is not supported in {token_dispatcher.__class__.__name__}"
if self.mlp.use_shared_expert and not self.mlp.shared_expert_overlap:
shared_expert_output = self.mlp.shared_experts(state.pre_mlp_layernorm_output)
expert_output = self.mlp.token_dispatcher.combine_preprocess(expert_output)
return expert_output, shared_expert_output, mlp_bias
def _submodule_combine_forward(self, output, shared_expert_output=None, state=None):
residual = state.residual
token_dispatcher = self.mlp.token_dispatcher
output = token_dispatcher.combine_all_to_all(output)
output = token_dispatcher.combine_postprocess(output)
if shared_expert_output is not None:
output = output + shared_expert_output
mlp_output_with_bias = (output, None)
with self.bias_dropout_add_exec_handler():
hidden_states = self.mlp_bda(self.training, self.config.bias_dropout_fusion)(
mlp_output_with_bias, residual, self.hidden_dropout
)
output = make_viewless_tensor(
inp=hidden_states, requires_grad=hidden_states.requires_grad, keep_graph=True
)
return output
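# Illustrative sketch (plain PyTorch, unfused): what mlp_bda computes in the combine step
# above when bias_dropout_fusion is off: optional bias add, dropout, then residual add.
import torch
import torch.nn.functional as F

def bias_dropout_add(x_with_bias, residual, prob, training=True):
    x, bias = x_with_bias
    if bias is not None:
        x = x + bias
    return residual + F.dropout(x, p=prob, training=training)

hidden = bias_dropout_add((torch.randn(4, 8), None), torch.randn(4, 8), prob=0.1)
print(hidden.shape)  # torch.Size([4, 8])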
def _submodule_attn_router_dw(self):
self.self_attention.backward_dw()
def _submodule_mlp_dw(self):
self.mlp.backward_dw()
def _submodule_attn_router_postprocess(
self, node, local_tokens, probs, residual, pre_mlp_layernorm_output, tokens_per_expert
):
node.common_state.residual = node.detach(residual)
if self.mlp.use_shared_expert:
node.common_state.pre_mlp_layernorm_output = node.detach(pre_mlp_layernorm_output)
if not self.is_deepep:
node.common_state.tokens_per_expert = tokens_per_expert
return local_tokens, probs
def _submodule_dispatch_postprocess(self, node, dispatched_tokens, probs):
if self.is_deepep:
node.common_state.dispatched_probs = node.detach(probs)
return dispatched_tokens
else:
return dispatched_tokens, probs
def _submodule_mlp_postprocess(self, node, expert_output, shared_expert_output, mlp_bias):
assert mlp_bias is None
node.common_state.pre_mlp_layernorm_output = None
if shared_expert_output is None:
return expert_output
return expert_output, shared_expert_output
def _submodule_combine_postprocess(self, node, output):
cur_stream = torch.cuda.current_stream()
node.common_state.residual.record_stream(cur_stream)
node.common_state.residual = None
return output
def _submodule_attn_postprocess(self, node, hidden_states, context):
return hidden_states
def _submodule_dense_postprocess(self, node, hidden_states):
return hidden_states
def _submodule_not_implemented(self, *args):
raise NotImplementedError("This callable is not implemented.")
def get_submodule_callables(self, chunk_state):
"""
The forward callables take two kinds of inputs:
1. The ScheduleNode object.
2. The input tensors.
"""
from megatron.core.transformer.moe.moe_layer import MoELayer
from megatron.core.transformer.moe.token_dispatcher import MoEFlexTokenDispatcher
self.is_moe = isinstance(self.mlp, MoELayer)
self.is_deepep = False
if self.is_moe:
self.is_deepep = isinstance(self.mlp.token_dispatcher, MoEFlexTokenDispatcher)
def get_func_with_default(func, default_func):
if self.is_moe:
return func
return default_func
def callable_wrapper(forward_func, postprocess_func, node, *args):
state = getattr(node, 'common_state', None)
callable_outputs = forward_func(*args, state=state)
if isinstance(callable_outputs, tuple):
outputs = postprocess_func(node, *callable_outputs)
else:
outputs = postprocess_func(node, callable_outputs)
return outputs
attn_func = get_func_with_default(
self._submodule_attn_router_forward, self._forward_attention
)
def attn_wrapper(hidden_states, state=None):
return attn_func(
hidden_states=hidden_states,
attention_mask=chunk_state.attention_mask,
attention_bias=chunk_state.attention_bias,
inference_params=chunk_state.inference_params,
packed_seq_params=chunk_state.packed_seq_params,
sequence_len_offset=chunk_state.sequence_len_offset,
rotary_pos_emb=chunk_state.rotary_pos_emb,
rotary_pos_cos=chunk_state.rotary_pos_cos,
rotary_pos_sin=chunk_state.rotary_pos_sin,
state=state,
)
attn_postprocess_func = get_func_with_default(
self._submodule_attn_router_postprocess, self._submodule_attn_postprocess
)
dispatch_func = get_func_with_default(
self._submodule_dispatch_forward, self._submodule_not_implemented
)
dispatch_postprocess_func = get_func_with_default(
self._submodule_dispatch_postprocess, self._submodule_not_implemented
)
mlp_func = get_func_with_default(self._submodule_moe_forward, self._forward_mlp)
mlp_postprocess_func = get_func_with_default(
self._submodule_mlp_postprocess, self._submodule_dense_postprocess
)
combine_func = get_func_with_default(
self._submodule_combine_forward, self._submodule_not_implemented
)
combine_postprocess_func = get_func_with_default(
self._submodule_combine_postprocess, self._submodule_not_implemented
)
attn_forward = partial(callable_wrapper, attn_wrapper, attn_postprocess_func)
dispatch_forward = partial(callable_wrapper, dispatch_func, dispatch_postprocess_func)
mlp_forward = partial(callable_wrapper, mlp_func, mlp_postprocess_func)
combine_forward = partial(callable_wrapper, combine_func, combine_postprocess_func)
callables = TransformerLayerSubmoduleCallables(
attention=SubmoduleCallables(forward=attn_forward, dw=self._submodule_attn_router_dw),
dispatch=SubmoduleCallables(forward=dispatch_forward),
mlp=SubmoduleCallables(forward=mlp_forward, dw=self._submodule_mlp_dw),
combine=SubmoduleCallables(forward=combine_forward),
is_moe=self.is_moe,
is_deepep=self.is_deepep,
)
return callables
\ No newline at end of file
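# Illustrative sketch (plain Python, invented toy functions): each submodule callable built
# above is a forward function and a postprocess function bound together with
# functools.partial, and the scheduler later invokes it as callable(node, *tensors).
from functools import partial

class ToyNode:
    common_state = None  # stands in for the ScheduleNode state slot

def toy_forward(x, state=None):
    return x + 1, x * 2  # returns a tuple, like the MoE submodule forwards

def toy_postprocess(node, a, b):
    return a + b

def callable_wrapper(forward_func, postprocess_func, node, *args):
    state = getattr(node, "common_state", None)
    outputs = forward_func(*args, state=state)
    if isinstance(outputs, tuple):
        return postprocess_func(node, *outputs)
    return postprocess_func(node, outputs)

toy_callable = partial(callable_wrapper, toy_forward, toy_postprocess)
print(toy_callable(ToyNode(), 3))  # (3 + 1) + (3 * 2) = 10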