Commit 32ee381a authored by dongcl

a2a overlap

parent 12b56c98
......@@ -9,6 +9,7 @@ from torch import Tensor
from megatron.core import InferenceParams, tensor_parallel
from megatron.core.config_logger import has_config_logger_enabled, log_config_to_disk
from megatron.core.inference.contexts import BaseInferenceContext
from megatron.core.packed_seq_params import PackedSeqParams
from megatron.core.models.gpt import GPTModel as MegatronCoreGPTModel
......@@ -64,11 +65,14 @@ class GPTModel(MegatronCoreGPTModel):
attention_mask: Tensor,
decoder_input: Tensor = None,
labels: Tensor = None,
inference_params: InferenceParams = None,
inference_context: BaseInferenceContext = None,
packed_seq_params: PackedSeqParams = None,
extra_block_kwargs: dict = None,
runtime_gather_output: Optional[bool] = None,
*,
inference_params: Optional[BaseInferenceContext] = None,
loss_mask: Optional[Tensor] = None,
):
"""Builds a computation schedule plan for the model.
......@@ -105,10 +109,12 @@ class GPTModel(MegatronCoreGPTModel):
attention_mask,
decoder_input=decoder_input,
labels=labels,
inference_params=inference_params,
inference_context=inference_context,
packed_seq_params=packed_seq_params,
extra_block_kwargs=extra_block_kwargs,
runtime_gather_output=runtime_gather_output,
inference_params=inference_params,
loss_mask=loss_mask,
)
def forward(
......@@ -118,14 +124,16 @@ class GPTModel(MegatronCoreGPTModel):
attention_mask: Tensor,
decoder_input: Tensor = None,
labels: Tensor = None,
inference_params: InferenceParams = None,
inference_context: BaseInferenceContext = None,
packed_seq_params: PackedSeqParams = None,
extra_block_kwargs: dict = None,
runtime_gather_output: Optional[bool] = None,
*,
inference_params: Optional[BaseInferenceContext] = None,
loss_mask: Optional[Tensor] = None,
) -> Tensor:
"""Forward function of the GPT Model This function passes the input tensors
through the embedding layer, and then the decoder and finally into the post
through the embedding layer, and then the decoeder and finally into the post
processing layer (optional).
It either returns the Loss values if labels are given or the final hidden units
......@@ -137,6 +145,8 @@ class GPTModel(MegatronCoreGPTModel):
# If decoder_input is provided (not None), then input_ids and position_ids are ignored.
# Otherwise, apply embedding layer on input_ids and position_ids to get decoder_input.
inference_context = deprecate_inference_params(inference_context, inference_params)
# Decoder embedding.
if decoder_input is not None:
pass
......@@ -152,39 +162,64 @@ class GPTModel(MegatronCoreGPTModel):
rotary_pos_cos = None
rotary_pos_sin = None
if self.position_embedding_type == 'rope' and not self.config.multi_latent_attention:
if not self.training and self.config.flash_decode and inference_params:
if not self.training and self.config.flash_decode and inference_context:
assert (
inference_context.is_static_batching()
), "GPTModel currently only supports static inference batching."
# Flash decoding uses precomputed cos and sin for RoPE
rotary_pos_cos, rotary_pos_sin = self.rotary_pos_emb_cache.setdefault(
inference_params.max_sequence_length,
self.rotary_pos_emb.get_cos_sin(inference_params.max_sequence_length),
inference_context.max_sequence_length,
self.rotary_pos_emb.get_cos_sin(inference_context.max_sequence_length),
)
else:
rotary_seq_len = self.rotary_pos_emb.get_rotary_seq_len(
inference_params, self.decoder, decoder_input, self.config, packed_seq_params
inference_context, self.decoder, decoder_input, self.config, packed_seq_params
)
rotary_pos_emb = self.rotary_pos_emb(
rotary_seq_len,
packed_seq=packed_seq_params is not None
and packed_seq_params.qkv_format == 'thd',
)
elif self.position_embedding_type == 'mrope' and not self.config.multi_latent_attention:
if self.training or not self.config.flash_decode:
rotary_pos_emb = self.rotary_pos_emb(position_ids, self.mrope_section)
else:
# Flash decoding uses precomputed cos and sin for RoPE
raise NotImplementedError(
"Flash decoding uses precomputed cos and sin for RoPE, not implmented in "
"MultimodalRotaryEmbedding yet."
)
if (
(self.config.enable_cuda_graph or self.config.flash_decode)
and rotary_pos_cos is not None
and inference_params
and inference_context
and inference_context.is_static_batching()
and not self.training
):
sequence_len_offset = torch.tensor(
[inference_params.sequence_len_offset] * inference_params.current_batch_size,
[inference_context.sequence_len_offset] * inference_context.current_batch_size,
dtype=torch.int32,
device=rotary_pos_cos.device, # Co-locate this with the rotary tensors
)
else:
sequence_len_offset = None
# Wrap decoder_input to allow the decoder (TransformerBlock) to delete the
# reference held by this caller function, enabling early garbage collection for
# inference. Skip wrapping if decoder_input is logged after decoder completion.
if (
inference_context is not None
and not self.training
and not has_config_logger_enabled(self.config)
):
decoder_input = WrappedTensor(decoder_input)
# Run decoder.
hidden_states = self.decoder(
hidden_states=decoder_input,
attention_mask=attention_mask,
inference_params=inference_params,
inference_context=inference_context,
rotary_pos_emb=rotary_pos_emb,
rotary_pos_cos=rotary_pos_cos,
rotary_pos_sin=rotary_pos_sin,
......@@ -193,6 +228,12 @@ class GPTModel(MegatronCoreGPTModel):
**(extra_block_kwargs or {}),
)
# Process inference output.
if inference_context and not inference_context.is_static_batching():
hidden_states = inference_context.last_token_logits(
hidden_states.squeeze(1).unsqueeze(0)
).unsqueeze(1)
# logits and loss
output_weight = None
if self.share_embeddings_and_output_weights:
......@@ -230,6 +271,13 @@ class GPTModel(MegatronCoreGPTModel):
if not self.post_process:
return hidden_states
if (
not self.training
and inference_context is not None
and inference_context.is_static_batching()
and inference_context.materialize_only_last_token_logits
):
hidden_states = hidden_states[-1:, :, :]
logits, _ = self.output_layer(
hidden_states, weight=output_weight, runtime_gather_output=runtime_gather_output
)
......
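# Illustrative sketch (assumption, not part of the patch): how callers of the patched
# GPTModel.forward above migrate from the removed `inference_params` keyword to the new
# `inference_context` keyword. `model`, `tokens`, `position_ids`, `attention_mask`, and
# `context` are assumed to exist.
def generate_step(model, tokens, position_ids, attention_mask, context):
    # Preferred new-style call.
    logits = model(tokens, position_ids, attention_mask, inference_context=context)
    # Legacy callers may still pass `inference_params=context`; forward() funnels it
    # through deprecate_inference_params(), which emits a deprecation warning and maps
    # it onto inference_context.
    return logits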
from megatron.core.transformer.moe.token_dispatcher import _DeepepManager as MegatronCoreDeepepManager
from megatron.core.transformer.moe.token_dispatcher import MoEAlltoAllTokenDispatcher as MegatronCoreMoEAlltoAllTokenDispatcher


# Decouple the per-batch state from MoEAlltoAllTokenDispatcher.
class MoEAlltoAllPerBatchState:
def __init__(self, build_event=False):
self.num_global_tokens_per_local_expert = None
self.output_splits_tp = None
self.output_splits = None
self.input_splits = None
self.num_out_tokens = None
self.capacity = None
self.preprocess_event = None
self.hidden_shape = None
self.probs = None
self.routing_map = None
self.reversed_local_input_permutation_mapping = None
self.cuda_sync_point = None
self.hidden_shape_before_permute = None
class MoEAlltoAllTokenDispatcher(MegatronCoreMoEAlltoAllTokenDispatcher):
def collect_per_batch_state(self, state: MoEAlltoAllPerBatchState):
state.num_global_tokens_per_local_expert = getattr(
self, "num_global_tokens_per_local_expert", None
)
state.output_splits_tp = getattr(self, "output_splits_tp", None)
state.output_splits = getattr(self, "output_splits", None)
state.input_splits = getattr(self, "input_splits", None)
state.num_out_tokens = getattr(self, "num_out_tokens", None)
state.capacity = getattr(self, "capacity", None)
state.preprocess_event = getattr(self, "preprocess_event", None)
state.hidden_shape = getattr(self, "hidden_shape", None)
state.probs = getattr(self, "probs", None)
state.routing_map = getattr(self, "routing_map", None)
state.reversed_local_input_permutation_mapping = getattr(
self, "reversed_local_input_permutation_mapping", None
)
state.hidden_shape_before_permute = getattr(self, "hidden_shape_before_permute", None)
state.cuda_sync_point = getattr(self, "cuda_sync_point", None)
def apply_per_batch_state(self, state: MoEAlltoAllPerBatchState):
self.num_global_tokens_per_local_expert = state.num_global_tokens_per_local_expert
self.output_splits_tp = state.output_splits_tp
self.output_splits = state.output_splits
self.input_splits = state.input_splits
self.num_out_tokens = state.num_out_tokens
self.capacity = state.capacity
self.preprocess_event = state.preprocess_event
self.hidden_shape = state.hidden_shape
self.probs = state.probs
self.routing_map = state.routing_map
self.reversed_local_input_permutation_mapping = (
state.reversed_local_input_permutation_mapping
)
self.hidden_shape_before_permute = state.hidden_shape_before_permute
self.cuda_sync_point = state.cuda_sync_point
@contextmanager
def per_batch_state_context(self, state: MoEAlltoAllPerBatchState):
origin_state = MoEAlltoAllPerBatchState()
self.collect_per_batch_state(origin_state)
try:
self.apply_per_batch_state(state)
yield
finally:
self.collect_per_batch_state(state)
self.apply_per_batch_state(origin_state)
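# Illustrative sketch (not from Megatron itself): how the context manager above is meant
# to be used so that several in-flight micro-batches can share one dispatcher. The
# helper name below and `expert_fn` are hypothetical; only the dispatcher API is real.
def permute_then_unpermute_later(dispatcher, hidden, probs, routing_map, expert_fn):
    state = MoEAlltoAllPerBatchState()
    # Permute under this micro-batch's own state; on exit the dispatcher's previous
    # fields are restored and `state` captures everything computed here.
    with dispatcher.per_batch_state_context(state):
        tokens, tokens_per_expert, permuted_probs = dispatcher.token_permutation(
            hidden, probs, routing_map
        )
    expert_out = expert_fn(tokens, tokens_per_expert, permuted_probs)
    # Re-install the same state to unpermute, even if other micro-batches used the
    # dispatcher in between.
    with dispatcher.per_batch_state_context(state):
        output, _ = dispatcher.token_unpermutation(expert_out)
    return output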
def meta_prepare(
self, hidden_states: torch.Tensor, probs: torch.Tensor, routing_map: torch.Tensor
    ) -> torch.Tensor:
        """
        Prepare the routing metadata needed before dispatching tokens to local experts.

        Records the hidden shape, probs, and routing map on the dispatcher and computes
        the number of tokens assigned to each local expert, which the later
        dispatch/communication steps consume.

        Args:
            hidden_states (torch.Tensor): Input token embeddings.
            probs (torch.Tensor): The probabilities of token to experts assignment.
            routing_map (torch.Tensor): The mapping of token to experts assignment.

        Returns:
            torch.Tensor: Number of tokens assigned to each local expert.
        """
        # Preprocess: Get the metadata for communication, permutation and computation operations.
self.hidden_shape = hidden_states.shape
self.probs = probs
self.routing_map = routing_map
assert probs.dim() == 2, "Expected 2D tensor for probs"
assert routing_map.dim() == 2, "Expected 2D tensor for token2expert mask"
assert routing_map.dtype == torch.bool, "Expected bool tensor for mask"
hidden_states = hidden_states.view(-1, self.hidden_shape[-1])
tokens_per_expert = self.preprocess(self.routing_map)
return tokens_per_expert
def dispatch_preprocess(self, hidden_states: torch.Tensor, routing_map: torch.Tensor, tokens_per_expert: torch.Tensor):
hidden_states = hidden_states.view(-1, self.hidden_shape[-1])
if self.shared_experts is not None:
self.shared_experts.pre_forward_comm(hidden_states.view(self.hidden_shape))
......@@ -49,12 +98,15 @@ class MoEAlltoAllTokenDispatcher(MoETokenDispatcher):
) = permute(
hidden_states,
routing_map,
probs=probs,
probs=self.probs,
num_out_tokens=self.num_out_tokens,
fused=self.config.moe_permute_fusion,
drop_and_pad=self.drop_and_pad,
)
return tokens_per_expert, permutated_local_input_tokens, permuted_probs
def dispatch_all_to_all(self, tokens_per_expert, permutated_local_input_tokens, permuted_probs):
# Perform expert parallel AlltoAll communication
tokens_per_expert = self._maybe_dtoh_and_synchronize(
"before_ep_alltoall", tokens_per_expert
......@@ -65,6 +117,10 @@ class MoEAlltoAllTokenDispatcher(MoETokenDispatcher):
global_probs = all_to_all(
self.ep_group, permuted_probs, self.output_splits, self.input_splits
)
return tokens_per_expert, global_input_tokens, global_probs
def dispatch_postprocess(self, tokens_per_expert, global_input_tokens, global_probs):
if self.shared_experts is not None:
self.shared_experts.linear_fc1_forward_and_act(global_input_tokens)
......@@ -118,184 +174,137 @@ class MoEAlltoAllTokenDispatcher(MoETokenDispatcher):
)
tokens_per_expert = self._maybe_dtoh_and_synchronize("before_finish", tokens_per_expert)
        return global_input_tokens, tokens_per_expert, global_probs
class _DeepepManager(MegatronCoreDeepepManager):
"""
patch megatron _DeepepManager. async
"""
def dispatch(
self,
hidden_states: torch.Tensor,
async_finish: bool = False,
allocate_on_comm_stream: bool = False,
) -> torch.Tensor:
# DeepEP only supports float32 probs
if self.token_probs.dtype != torch.float32:
if self.token_probs.dtype in [torch.bfloat16, torch.float16]:
print("DeepEP only supports float32 probs, please set --moe-router-dtype=fp32")
self.token_probs = self.token_probs.float() # downcast or upcast
hidden_states, dispatched_indices, dispatched_probs, num_tokens_per_expert, handle = (
fused_dispatch(
hidden_states,
self.token_indices,
self.token_probs,
self.num_experts,
self.group,
async_finish=async_finish,
allocate_on_comm_stream=allocate_on_comm_stream,
)
)
self.handle = handle
self.tokens_per_expert = num_tokens_per_expert
self.dispatched_indices = dispatched_indices
self.dispatched_probs = dispatched_probs
return hidden_states
def combine(
self,
hidden_states: torch.Tensor,
async_finish: bool = False,
allocate_on_comm_stream: bool = False,
) -> torch.Tensor:
hidden_states, _ = fused_combine(
hidden_states,
self.group,
self.handle,
async_finish=async_finish,
allocate_on_comm_stream=allocate_on_comm_stream,
)
# Release the handle after combine operation
self.handle = None
return hidden_states
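# Illustrative sketch (assumption, not upstream code): what async_finish=True in the
# patched dispatch/combine above is for. The communication kernel is launched without
# blocking, so independent work can run while tokens are in flight. `manager` is a
# set-up _DeepepManager and `overlap_work` is a hypothetical callable.
def overlapped_dispatch(manager, hidden_states, overlap_work):
    dispatched = manager.dispatch(
        hidden_states, async_finish=True, allocate_on_comm_stream=True
    )
    overlap_work()  # e.g. attention or shared-expert compute of another micro-batch
    # `dispatched` must only be consumed after the corresponding DeepEP event has been
    # synchronized (details depend on the DeepEP buffer API).
    return dispatched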
# The methods below continue the patched MoEAlltoAllTokenDispatcher defined above;
# dispatch and combine are split into stages so the all-to-all communication can be
# overlapped with other compute.
    def token_permutation(
        self, hidden_states: torch.Tensor, probs: torch.Tensor, routing_map: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """
        Dispatch tokens to local experts using AlltoAll communication.

        This method performs the following steps:
        1. Preprocess the routing map to get metadata for communication and permutation.
        2. Permute input tokens for AlltoAll communication.
        3. Perform expert parallel AlltoAll communication.
        4. Sort tokens by local expert (if multiple local experts exist).

        Args:
            hidden_states (torch.Tensor): Input token embeddings.
            probs (torch.Tensor): The probabilities of token to experts assignment.
            routing_map (torch.Tensor): The mapping of token to experts assignment.

        Returns:
            Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
                - Permuted token embeddings for local experts.
                - Number of tokens per expert.
                - Permuted probs of each token produced by the router.
        """
        # Preprocess: Get the metadata for communication, permutation and computation operations.
        # Permutation 1: input to AlltoAll input
        tokens_per_expert = self.meta_prepare(hidden_states, probs, routing_map)
        tokens_per_expert, permutated_local_input_tokens, permuted_probs = self.dispatch_preprocess(
            hidden_states, routing_map, tokens_per_expert
        )

        # Perform expert parallel AlltoAll communication
        tokens_per_expert, global_input_tokens, global_probs = self.dispatch_all_to_all(
            tokens_per_expert, permutated_local_input_tokens, permuted_probs
        )

        # Permutation 2: Sort tokens by local expert.
        global_input_tokens, tokens_per_expert, global_probs = self.dispatch_postprocess(
            tokens_per_expert, global_input_tokens, global_probs
        )

        return global_input_tokens, tokens_per_expert, global_probs

    def combine_preprocess(self, hidden_states):
        # Unpermutation 2: Unsort tokens by local expert.
        if self.num_local_experts > 1:
            if self.drop_and_pad:
                hidden_states = (
                    hidden_states.view(
                        self.num_local_experts,
                        self.tp_size * self.ep_size,
                        self.capacity,
                        *hidden_states.size()[1:],
                    )
                    .transpose(0, 1)
                    .contiguous()
                    .flatten(start_dim=0, end_dim=2)
                )
            else:
                hidden_states, _ = sort_chunks_by_idxs(
                    hidden_states,
                    self.num_global_tokens_per_local_expert.T.ravel(),
                    self.restore_output_by_local_experts,
                    fused=self.config.moe_permute_fusion,
                )

        if self.tp_size > 1:
            if self.output_splits_tp is None:
                input_split_sizes = None
            else:
                input_split_sizes = self.output_splits_tp.tolist()
            # The precision of TP reduce_scatter should be the same as the router_dtype
            hidden_states = reduce_scatter_to_sequence_parallel_region(
                hidden_states.to(self.probs.dtype),
                group=self.tp_group,
                input_split_sizes=input_split_sizes,
            ).to(hidden_states.dtype)

        return hidden_states

    def combine_all_to_all(self, hidden_states):
        # Perform expert parallel AlltoAll communication
        # hidden_states: [SEQL, H] -> [SEQL, H/TP]
        permutated_local_input_tokens = all_to_all(
            self.ep_group, hidden_states, self.input_splits, self.output_splits
        )
        return permutated_local_input_tokens

    def combine_postprocess(self, permutated_local_input_tokens):
        if self.shared_experts is not None:
            self.shared_experts.linear_fc2_forward(permutated_local_input_tokens)
            self.shared_experts.post_forward_comm()

        # Unpermutation 1: AlltoAll output to output
        output = unpermute(
            permutated_local_input_tokens,
            self.reversed_local_input_permutation_mapping,
            restore_shape=self.hidden_shape_before_permute,
            routing_map=self.routing_map,
            fused=self.config.moe_permute_fusion,
            drop_and_pad=self.drop_and_pad,
        )

        # Reshape the output tensor
        output = output.view(self.hidden_shape)

        # Add shared experts output
        if self.shared_experts is not None:
            shared_expert_output = self.shared_experts.get_output()
            output += shared_expert_output
        return output

    def token_unpermutation(
        self, hidden_states: torch.Tensor, bias: Optional[torch.Tensor] = None
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
        """
        Reverse the token permutation to restore the original order.

        This method performs the following steps:
        1. Unsort tokens by local expert (if multiple local experts exist).
        2. Perform expert parallel AlltoAll communication to restore the original order.
        3. Unpermute tokens to restore the original order.

        Args:
            hidden_states (torch.Tensor): Output from local experts.
            bias (torch.Tensor, optional): Bias tensor (not supported).

        Returns:
            Tuple[torch.Tensor, Optional[torch.Tensor]]:
                - Unpermuted token embeddings in the original order.
                - None (bias is not supported).
        """
        assert bias is None, "Bias is not supported in MoEAlltoAllTokenDispatcher"

        hidden_states = self.combine_preprocess(hidden_states)
        permutated_local_input_tokens = self.combine_all_to_all(hidden_states)
        output = self.combine_postprocess(permutated_local_input_tokens)

        return output, None


class MoEFlexTokenDispatcher(MoETokenDispatcher):
    """
    Flex token dispatcher using DeepEP.
    """

    def dispatch_preprocess(
        self, hidden_states: torch.Tensor, routing_map: torch.Tensor, probs: torch.Tensor
    ):
        """
        Preprocesses the hidden states and routing information before dispatching tokens to experts.

        Args:
            hidden_states (torch.Tensor): Input hidden states to be processed
            routing_map (torch.Tensor): Map indicating which expert each token should be routed to
            probs (torch.Tensor): Routing probabilities for each token-expert pair

        Returns:
            Tuple containing:
                - torch.Tensor: Reshaped hidden states
                - torch.Tensor: Token probabilities from the communication manager
                - None: Placeholder for compatibility
        """
        self.hidden_shape = hidden_states.shape
        hidden_states = hidden_states.view(-1, self.hidden_shape[-1])

        # Initialize metadata
        routing_map, probs = self._initialize_metadata(routing_map, probs)

        self._comm_manager.setup_metadata(routing_map, probs)
        return hidden_states, self._comm_manager.token_probs, None

    def dispatch_all_to_all(
        self,
        hidden_states: torch.Tensor,
        probs: torch.Tensor = None,
        async_finish: bool = True,
        allocate_on_comm_stream: bool = True,
    ):
        """
        Performs all-to-all communication to dispatch tokens across expert parallel ranks.
        """
        return (
            self._comm_manager.dispatch(hidden_states, async_finish, allocate_on_comm_stream),
            self._comm_manager.dispatched_probs,
        )

    def dispatch_postprocess(self, hidden_states: torch.Tensor):
        """
        Post-processes the dispatched hidden states after all-to-all communication.

        This method retrieves the permuted hidden states by experts, calculates the number of tokens
        per expert, and returns the processed data ready for expert processing.
        """
        global_input_tokens, permuted_probs = (
            self._comm_manager.get_permuted_hidden_states_by_experts(hidden_states)
        )
        tokens_per_expert = self._comm_manager.get_number_of_tokens_per_expert()
        return global_input_tokens, tokens_per_expert, permuted_probs

    def token_permutation(
        self, hidden_states: torch.Tensor, probs: torch.Tensor, routing_map: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """
        Permutes tokens according to the routing map and dispatches them to experts.

        This method implements the token permutation process in three steps:
        1. Preprocess the hidden states and routing information
        2. Perform all-to-all communication to dispatch tokens
        3. Post-process the dispatched tokens for expert processing
        """
        hidden_states, _, _ = self.dispatch_preprocess(hidden_states, routing_map, probs)
        hidden_states, _ = self.dispatch_all_to_all(
            hidden_states, async_finish=False, allocate_on_comm_stream=False
        )
        global_input_tokens, tokens_per_expert, permuted_probs = self.dispatch_postprocess(
            hidden_states
        )

        return global_input_tokens, tokens_per_expert, permuted_probs

    def combine_preprocess(self, hidden_states: torch.Tensor):
        """
        Pre-processes the hidden states before combining them after expert processing.

        This method restores the hidden states to their original ordering before expert processing
        by using the communication manager's restoration function.
        """
        hidden_states = self._comm_manager.get_restored_hidden_states_by_experts(hidden_states)
        return hidden_states

    def combine_all_to_all(
        self,
        hidden_states: torch.Tensor,
        async_finish: bool = True,
        allocate_on_comm_stream: bool = True,
    ):
        """
        Performs all-to-all communication to combine tokens after expert processing.
        """
        return self._comm_manager.combine(hidden_states, async_finish, allocate_on_comm_stream)

    def combine_postprocess(self, hidden_states: torch.Tensor):
        """
        Post-processes the combined hidden states after all-to-all communication.

        This method reshapes the combined hidden states to match the original input shape.
        """
        return hidden_states.view(self.hidden_shape)

    def token_unpermutation(
        self, hidden_states: torch.Tensor, bias: Optional[torch.Tensor] = None
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
        """
        Reverses the token permutation process to restore the original token order.

        This method implements the token unpermutation process in three steps:
        1. Pre-process the hidden states to restore their original ordering
        2. Perform all-to-all communication to combine tokens
        3. Post-process the combined tokens to match the original input shape
        """
        assert bias is None, "Bias is not supported in MoEFlexTokenDispatcher"
        hidden_states = self.combine_preprocess(hidden_states)
        hidden_states = self.combine_all_to_all(hidden_states, False, False)
        hidden_states = self.combine_postprocess(hidden_states)
        return hidden_states, None
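# Illustrative sketch (assumption, not part of the patch): the kind of "a2a overlap"
# schedule the staged dispatcher methods enable. While micro-batch A's experts run,
# micro-batch B's all-to-all can be issued. `dispatcher`, `experts`, the batch tuples
# (hidden, probs, routing_map) and the per-batch states are assumed to exist; real
# overlap additionally needs separate CUDA/communication streams.
def two_batch_overlap(dispatcher, experts, batch_a, batch_b, state_a, state_b):
    with dispatcher.per_batch_state_context(state_a):
        tpe_a = dispatcher.meta_prepare(*batch_a)
        tpe_a, perm_a, probs_a = dispatcher.dispatch_preprocess(batch_a[0], batch_a[2], tpe_a)
        tpe_a, glob_a, gprobs_a = dispatcher.dispatch_all_to_all(tpe_a, perm_a, probs_a)
        glob_a, tpe_a, gprobs_a = dispatcher.dispatch_postprocess(tpe_a, glob_a, gprobs_a)

    # Issue B's dispatch, then run A's expert MLP so the communication latency is hidden.
    with dispatcher.per_batch_state_context(state_b):
        tpe_b = dispatcher.meta_prepare(*batch_b)
        tpe_b, perm_b, probs_b = dispatcher.dispatch_preprocess(batch_b[0], batch_b[2], tpe_b)
        tpe_b, glob_b, gprobs_b = dispatcher.dispatch_all_to_all(tpe_b, perm_b, probs_b)
    out_a = experts(glob_a, tpe_a, gprobs_a)

    with dispatcher.per_batch_state_context(state_b):
        glob_b, tpe_b, gprobs_b = dispatcher.dispatch_postprocess(tpe_b, glob_b, gprobs_b)
    out_b = experts(glob_b, tpe_b, gprobs_b)
    return out_a, out_b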
from dataclasses import dataclass
from typing import Callable, Optional
@dataclass
class SubmoduleCallables:
"""
Holds references to forward, dgrad, and dw (weight-grad) callables
for a particular submodule.
"""
forward: Optional[Callable] = None
backward: Optional[Callable] = None
dgrad: Optional[Callable] = None
dw: Optional[Callable] = None
@dataclass
class TransformerLayerSubmoduleCallables:
"""
Collects the SubmoduleMethods for each of the submodules:
'attention', 'dispatch', 'mlp', 'combine'.
"""
attention: SubmoduleCallables
dispatch: SubmoduleCallables
mlp: SubmoduleCallables
combine: SubmoduleCallables
post_combine: SubmoduleCallables
def as_array(self):
return [self.attention, self.dispatch, self.mlp, self.combine, self.post_combine]
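# Illustrative sketch (assumption, not part of the patch): how a layer-level schedule
# might fill and consume these containers. The lambdas are placeholders, not real
# submodule callables.
attn = SubmoduleCallables(forward=lambda x: x, dgrad=lambda g: g, dw=lambda: None)
noop = SubmoduleCallables()

layer_callables = TransformerLayerSubmoduleCallables(
    attention=attn, dispatch=noop, mlp=noop, combine=noop, post_combine=noop
)

# A scheduler can walk the stages in order and invoke whichever callables are set.
for stage in layer_callables.as_array():
    if stage.forward is not None:
        stage.forward("activation")  # a real schedule would pass tensors here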