Commit 32ee381a authored by dongcl

a2a overlap

parent 12b56c98
@@ -9,6 +9,7 @@ from torch import Tensor
 from megatron.core import InferenceParams, tensor_parallel
 from megatron.core.config_logger import has_config_logger_enabled, log_config_to_disk
+from megatron.core.inference.contexts import BaseInferenceContext
 from megatron.core.packed_seq_params import PackedSeqParams
 from megatron.core.models.gpt import GPTModel as MegatronCoreGPTModel
@@ -64,11 +65,14 @@ class GPTModel(MegatronCoreGPTModel):
         attention_mask: Tensor,
         decoder_input: Tensor = None,
         labels: Tensor = None,
-        inference_params: InferenceParams = None,
+        inference_context: BaseInferenceContext = None,
         packed_seq_params: PackedSeqParams = None,
         extra_block_kwargs: dict = None,
         runtime_gather_output: Optional[bool] = None,
+        *,
+        inference_params: Optional[BaseInferenceContext] = None,
         loss_mask: Optional[Tensor] = None,
     ):
         """Builds a computation schedule plan for the model.
@@ -105,10 +109,12 @@ class GPTModel(MegatronCoreGPTModel):
             attention_mask,
             decoder_input=decoder_input,
             labels=labels,
-            inference_params=inference_params,
+            inference_context=inference_context,
             packed_seq_params=packed_seq_params,
             extra_block_kwargs=extra_block_kwargs,
             runtime_gather_output=runtime_gather_output,
+            inference_params=inference_params,
+            loss_mask=loss_mask,
         )

     def forward(
@@ -118,14 +124,16 @@ class GPTModel(MegatronCoreGPTModel):
         attention_mask: Tensor,
         decoder_input: Tensor = None,
         labels: Tensor = None,
-        inference_params: InferenceParams = None,
+        inference_context: BaseInferenceContext = None,
         packed_seq_params: PackedSeqParams = None,
         extra_block_kwargs: dict = None,
         runtime_gather_output: Optional[bool] = None,
+        *,
+        inference_params: Optional[BaseInferenceContext] = None,
         loss_mask: Optional[Tensor] = None,
     ) -> Tensor:
         """Forward function of the GPT Model. This function passes the input tensors
         through the embedding layer, then the decoder, and finally into the post
         processing layer (optional).

         It either returns the Loss values if labels are given or the final hidden units
@@ -137,6 +145,8 @@ class GPTModel(MegatronCoreGPTModel):
         # If decoder_input is provided (not None), then input_ids and position_ids are ignored.
         # Otherwise, apply embedding layer on input_ids and position_ids to get decoder_input.

+        inference_context = deprecate_inference_params(inference_context, inference_params)
+
         # Decoder embedding.
         if decoder_input is not None:
             pass
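The keyword-only `inference_params` argument kept in the signatures above is a backward-compatibility alias that `deprecate_inference_params` folds into `inference_context`. A minimal, self-contained sketch of that shim pattern (hypothetical helper name; Megatron Core's own implementation may differ):

# Sketch of the deprecation-shim pattern: fold a deprecated keyword-only argument into
# its replacement, warning when the old name is still used. Illustrative helper;
# Megatron Core ships its own `deprecate_inference_params`.
import warnings
from typing import Any, Optional


def fold_deprecated_kwarg(
    new_value: Optional[Any],
    old_value: Optional[Any],
    old_name: str = "inference_params",
    new_name: str = "inference_context",
) -> Optional[Any]:
    if old_value is not None:
        warnings.warn(f"`{old_name}` is deprecated; pass `{new_name}` instead.", DeprecationWarning)
        # Prefer the new argument if both were given; otherwise adopt the deprecated one.
        return new_value if new_value is not None else old_value
    return new_value


# Usage mirrors the line added above:
# inference_context = fold_deprecated_kwarg(inference_context, inference_params)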
@@ -152,39 +162,64 @@ class GPTModel(MegatronCoreGPTModel):
         rotary_pos_cos = None
         rotary_pos_sin = None
         if self.position_embedding_type == 'rope' and not self.config.multi_latent_attention:
-            if not self.training and self.config.flash_decode and inference_params:
+            if not self.training and self.config.flash_decode and inference_context:
+                assert (
+                    inference_context.is_static_batching()
+                ), "GPTModel currently only supports static inference batching."
                 # Flash decoding uses precomputed cos and sin for RoPE
                 rotary_pos_cos, rotary_pos_sin = self.rotary_pos_emb_cache.setdefault(
-                    inference_params.max_sequence_length,
-                    self.rotary_pos_emb.get_cos_sin(inference_params.max_sequence_length),
+                    inference_context.max_sequence_length,
+                    self.rotary_pos_emb.get_cos_sin(inference_context.max_sequence_length),
                 )
             else:
                 rotary_seq_len = self.rotary_pos_emb.get_rotary_seq_len(
-                    inference_params, self.decoder, decoder_input, self.config, packed_seq_params
+                    inference_context, self.decoder, decoder_input, self.config, packed_seq_params
                 )
                 rotary_pos_emb = self.rotary_pos_emb(
                     rotary_seq_len,
                     packed_seq=packed_seq_params is not None
                     and packed_seq_params.qkv_format == 'thd',
                 )
+        elif self.position_embedding_type == 'mrope' and not self.config.multi_latent_attention:
+            if self.training or not self.config.flash_decode:
+                rotary_pos_emb = self.rotary_pos_emb(position_ids, self.mrope_section)
+            else:
+                # Flash decoding uses precomputed cos and sin for RoPE
+                raise NotImplementedError(
+                    "Flash decoding uses precomputed cos and sin for RoPE, not implemented in "
+                    "MultimodalRotaryEmbedding yet."
+                )
         if (
             (self.config.enable_cuda_graph or self.config.flash_decode)
             and rotary_pos_cos is not None
-            and inference_params
+            and inference_context
+            and inference_context.is_static_batching()
+            and not self.training
         ):
             sequence_len_offset = torch.tensor(
-                [inference_params.sequence_len_offset] * inference_params.current_batch_size,
+                [inference_context.sequence_len_offset] * inference_context.current_batch_size,
                 dtype=torch.int32,
                 device=rotary_pos_cos.device,  # Co-locate this with the rotary tensors
             )
         else:
             sequence_len_offset = None

+        # Wrap decoder_input to allow the decoder (TransformerBlock) to delete the
+        # reference held by this caller function, enabling early garbage collection for
+        # inference. Skip wrapping if decoder_input is logged after decoder completion.
+        if (
+            inference_context is not None
+            and not self.training
+            and not has_config_logger_enabled(self.config)
+        ):
+            decoder_input = WrappedTensor(decoder_input)
+
         # Run decoder.
         hidden_states = self.decoder(
             hidden_states=decoder_input,
             attention_mask=attention_mask,
-            inference_params=inference_params,
+            inference_context=inference_context,
             rotary_pos_emb=rotary_pos_emb,
             rotary_pos_cos=rotary_pos_cos,
             rotary_pos_sin=rotary_pos_sin,
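The `WrappedTensor` wrapping added above lets the decoder drop the caller's reference to the embedding output so the activation can be freed earlier during inference. A self-contained toy of the same early-release idea (the real `WrappedTensor` in Megatron Core may expose a different API):

# Toy illustration of the early-release pattern behind WrappedTensor: the callee unwraps
# the tensor and clears the wrapper's internal reference, so once the caller holds only
# the wrapper, the large activation can be garbage-collected as soon as the callee is
# done with it. Class and function names here are illustrative.
import torch


class WrappedTensorSketch:
    def __init__(self, tensor: torch.Tensor):
        self._tensor = tensor

    def unwrap(self) -> torch.Tensor:
        tensor, self._tensor = self._tensor, None  # drop the wrapper's reference
        return tensor


def decoder_like(wrapped: WrappedTensorSketch) -> torch.Tensor:
    hidden = wrapped.unwrap()  # the caller's wrapper no longer pins the activation
    return hidden * 2.0


wrapped = WrappedTensorSketch(torch.ones(4, 8))
out = decoder_like(wrapped)  # after unwrap(), only `out` and callee locals keep the data alive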
@@ -193,6 +228,12 @@ class GPTModel(MegatronCoreGPTModel):
             **(extra_block_kwargs or {}),
         )

+        # Process inference output.
+        if inference_context and not inference_context.is_static_batching():
+            hidden_states = inference_context.last_token_logits(
+                hidden_states.squeeze(1).unsqueeze(0)
+            ).unsqueeze(1)
+
         # logits and loss
         output_weight = None
         if self.share_embeddings_and_output_weights:
@@ -230,6 +271,13 @@ class GPTModel(MegatronCoreGPTModel):
         if not self.post_process:
             return hidden_states

+        if (
+            not self.training
+            and inference_context is not None
+            and inference_context.is_static_batching()
+            and inference_context.materialize_only_last_token_logits
+        ):
+            hidden_states = hidden_states[-1:, :, :]
+
         logits, _ = self.output_layer(
             hidden_states, weight=output_weight, runtime_gather_output=runtime_gather_output
         )
...
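The new static-batching branch slices the decoder output down to the final position before the vocabulary projection, since generation only needs the last token's logits. A small shape check of that slice (sequence-first layout assumed):

# Shapes only: with [sequence, batch, hidden] activations, keeping the last position
# means the output layer computes [1, batch, vocab] logits instead of [seq, batch, vocab].
import torch

seq_len, batch, hidden = 128, 4, 16
hidden_states = torch.randn(seq_len, batch, hidden)
last_only = hidden_states[-1:, :, :]
assert last_only.shape == (1, batch, hidden)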
-from megatron.core.transformer.moe.token_dispatcher import _DeepepManager as MegatronCoreDeepepManager
+from megatron.core.transformer.moe.token_dispatcher import MoEAlltoAllTokenDispatcher as MegatronCoreMoEAlltoAllTokenDispatcher

-class MoEAlltoAllTokenDispatcher(MoETokenDispatcher):
-    def token_permutation(
+# decouple per-batch state from MoEAlltoAllTokenDispatcher
+class MoEAlltoAllPerBatchState:
+    def __init__(self, build_event=False):
+        self.num_global_tokens_per_local_expert = None
+        self.output_splits_tp = None
+        self.output_splits = None
+        self.input_splits = None
+        self.num_out_tokens = None
+        self.capacity = None
+        self.preprocess_event = None
+        self.hidden_shape = None
+        self.probs = None
+        self.routing_map = None
+        self.reversed_local_input_permutation_mapping = None
+        self.cuda_sync_point = None
+        self.hidden_shape_before_permute = None
+
+
+class MoEAlltoAllTokenDispatcher(MegatronCoreMoEAlltoAllTokenDispatcher):
+    def collect_per_batch_state(self, state: MoEAlltoAllPerBatchState):
+        state.num_global_tokens_per_local_expert = getattr(
+            self, "num_global_tokens_per_local_expert", None
+        )
+        state.output_splits_tp = getattr(self, "output_splits_tp", None)
+        state.output_splits = getattr(self, "output_splits", None)
+        state.input_splits = getattr(self, "input_splits", None)
+        state.num_out_tokens = getattr(self, "num_out_tokens", None)
+        state.capacity = getattr(self, "capacity", None)
+        state.preprocess_event = getattr(self, "preprocess_event", None)
+        state.hidden_shape = getattr(self, "hidden_shape", None)
+        state.probs = getattr(self, "probs", None)
+        state.routing_map = getattr(self, "routing_map", None)
+        state.reversed_local_input_permutation_mapping = getattr(
+            self, "reversed_local_input_permutation_mapping", None
+        )
+        state.hidden_shape_before_permute = getattr(self, "hidden_shape_before_permute", None)
+        state.cuda_sync_point = getattr(self, "cuda_sync_point", None)
+
+    def apply_per_batch_state(self, state: MoEAlltoAllPerBatchState):
+        self.num_global_tokens_per_local_expert = state.num_global_tokens_per_local_expert
+        self.output_splits_tp = state.output_splits_tp
+        self.output_splits = state.output_splits
+        self.input_splits = state.input_splits
+        self.num_out_tokens = state.num_out_tokens
+        self.capacity = state.capacity
+        self.preprocess_event = state.preprocess_event
+        self.hidden_shape = state.hidden_shape
+        self.probs = state.probs
+        self.routing_map = state.routing_map
+        self.reversed_local_input_permutation_mapping = (
+            state.reversed_local_input_permutation_mapping
+        )
+        self.hidden_shape_before_permute = state.hidden_shape_before_permute
+        self.cuda_sync_point = state.cuda_sync_point
+
+    @contextmanager
+    def per_batch_state_context(self, state: MoEAlltoAllPerBatchState):
+        origin_state = MoEAlltoAllPerBatchState()
+        self.collect_per_batch_state(origin_state)
+        try:
+            self.apply_per_batch_state(state)
+            yield
+        finally:
+            self.collect_per_batch_state(state)
+            self.apply_per_batch_state(origin_state)
+
+    def meta_prepare(
         self, hidden_states: torch.Tensor, probs: torch.Tensor, routing_map: torch.Tensor
-    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
-        """
-        Dispatch tokens to local experts using AlltoAll communication.
-        This method performs the following steps:
-        1. Preprocess the routing map to get metadata for communication and permutation.
-        2. Permute input tokens for AlltoAll communication.
-        3. Perform expert parallel AlltoAll communication.
-        4. Sort tokens by local expert (if multiple local experts exist).
-        Args:
-            hidden_states (torch.Tensor): Input token embeddings.
-            probs (torch.Tensor): The probabilities of token to experts assignment.
-            routing_map (torch.Tensor): The mapping of token to experts assignment.
-        Returns:
-            Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
-                - Permuted token embeddings for local experts.
-                - Number of tokens per expert.
-                - Permuted probs of each token produced by the router.
-        """
-        # Preprocess: Get the metadata for communication, permutation and computation operations.
+    ):
         self.hidden_shape = hidden_states.shape
         self.probs = probs
         self.routing_map = routing_map
         assert probs.dim() == 2, "Expected 2D tensor for probs"
         assert routing_map.dim() == 2, "Expected 2D tensor for token2expert mask"
         assert routing_map.dtype == torch.bool, "Expected bool tensor for mask"
-        hidden_states = hidden_states.view(-1, self.hidden_shape[-1])
         tokens_per_expert = self.preprocess(self.routing_map)
+        return tokens_per_expert
+
+    def dispatch_preprocess(self, hidden_states: torch.Tensor, routing_map: torch.Tensor, tokens_per_expert: torch.Tensor):
+        hidden_states = hidden_states.view(-1, self.hidden_shape[-1])

         if self.shared_experts is not None:
             self.shared_experts.pre_forward_comm(hidden_states.view(self.hidden_shape))
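The `collect_per_batch_state` / `apply_per_batch_state` pair and the `per_batch_state_context` manager above implement a snapshot-and-restore discipline so a single dispatcher can hold metadata for several in-flight micro-batches without them clobbering each other. A self-contained toy of the same pattern (the class and fields below are illustrative, not Megatron code):

# Snapshot-and-restore, as in per_batch_state_context: collect the current batch's
# attributes, install this batch's state, then on exit persist what this batch computed
# before handing the attributes back to the previous batch.
from contextlib import contextmanager


class ToyDispatcher:
    def __init__(self):
        self.input_splits = None
        self.output_splits = None

    def collect(self, state: dict):
        state["input_splits"] = self.input_splits
        state["output_splits"] = self.output_splits

    def apply(self, state: dict):
        self.input_splits = state.get("input_splits")
        self.output_splits = state.get("output_splits")

    @contextmanager
    def per_batch(self, state: dict):
        saved = {}
        self.collect(saved)        # snapshot whichever batch currently owns the attributes
        try:
            self.apply(state)      # install this batch's metadata
            yield
        finally:
            self.collect(state)    # persist what this batch computed inside the block
            self.apply(saved)      # restore the previous batch's attributes


d = ToyDispatcher()
batch_a, batch_b = {}, {}
with d.per_batch(batch_a):
    d.input_splits = [4, 2]        # stand-in for what preprocess() would compute
with d.per_batch(batch_b):
    d.input_splits = [1, 5]
assert batch_a["input_splits"] == [4, 2] and batch_b["input_splits"] == [1, 5]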
@@ -49,12 +98,15 @@ class MoEAlltoAllTokenDispatcher(MoETokenDispatcher):
         ) = permute(
             hidden_states,
             routing_map,
-            probs=probs,
+            probs=self.probs,
             num_out_tokens=self.num_out_tokens,
             fused=self.config.moe_permute_fusion,
             drop_and_pad=self.drop_and_pad,
         )
+        return tokens_per_expert, permutated_local_input_tokens, permuted_probs
+
+    def dispatch_all_to_all(self, tokens_per_expert, permutated_local_input_tokens, permuted_probs):
         # Perform expert parallel AlltoAll communication
         tokens_per_expert = self._maybe_dtoh_and_synchronize(
             "before_ep_alltoall", tokens_per_expert
@@ -65,6 +117,10 @@ class MoEAlltoAllTokenDispatcher(MoETokenDispatcher):
         global_probs = all_to_all(
             self.ep_group, permuted_probs, self.output_splits, self.input_splits
         )
+        return tokens_per_expert, global_input_tokens, global_probs
+
+    def dispatch_postprocess(self, tokens_per_expert, global_input_tokens, global_probs):
         if self.shared_experts is not None:
             self.shared_experts.linear_fc1_forward_and_act(global_input_tokens)
@@ -118,184 +174,137 @@ class MoEAlltoAllTokenDispatcher(MoETokenDispatcher):
             )
         tokens_per_expert = self._maybe_dtoh_and_synchronize("before_finish", tokens_per_expert)
         return global_input_tokens, tokens_per_expert, global_probs

-class _DeepepManager(MegatronCoreDeepepManager):
-    """
-    patch megatron _DeepepManager. async
-    """
-
-    def dispatch(
-        self,
-        hidden_states: torch.Tensor,
-        async_finish: bool = False,
-        allocate_on_comm_stream: bool = False,
-    ) -> torch.Tensor:
-        # DeepEP only supports float32 probs
-        if self.token_probs.dtype != torch.float32:
-            if self.token_probs.dtype in [torch.bfloat16, torch.float16]:
-                print("DeepEP only supports float32 probs, please set --moe-router-dtype=fp32")
-            self.token_probs = self.token_probs.float()  # downcast or upcast
-        hidden_states, dispatched_indices, dispatched_probs, num_tokens_per_expert, handle = (
-            fused_dispatch(
-                hidden_states,
-                self.token_indices,
-                self.token_probs,
-                self.num_experts,
-                self.group,
-                async_finish=async_finish,
-                allocate_on_comm_stream=allocate_on_comm_stream,
-            )
-        )
-        self.handle = handle
-        self.tokens_per_expert = num_tokens_per_expert
-        self.dispatched_indices = dispatched_indices
-        self.dispatched_probs = dispatched_probs
-        return hidden_states
-
-    def combine(
-        self,
-        hidden_states: torch.Tensor,
-        async_finish: bool = False,
-        allocate_on_comm_stream: bool = False,
-    ) -> torch.Tensor:
-        hidden_states, _ = fused_combine(
-            hidden_states,
-            self.group,
-            self.handle,
-            async_finish=async_finish,
-            allocate_on_comm_stream=allocate_on_comm_stream,
-        )
-        # Release the handle after combine operation
-        self.handle = None
-        return hidden_states
-
-class MoEFlexTokenDispatcher(MoETokenDispatcher):
-    """
-    Flex token dispatcher using DeepEP.
-    """
-
-    def dispatch_preprocess(
-        self, hidden_states: torch.Tensor, routing_map: torch.Tensor, probs: torch.Tensor
-    ):
-        """
-        Preprocesses the hidden states and routing information before dispatching tokens to experts.
-        Args:
-            hidden_states (torch.Tensor): Input hidden states to be processed
-            routing_map (torch.Tensor): Map indicating which expert each token should be routed to
-            probs (torch.Tensor): Routing probabilities for each token-expert pair
-        Returns:
-            Tuple containing:
-                - torch.Tensor: Reshaped hidden states
-                - torch.Tensor: Token probabilities from the communication manager
-                - None: Placeholder for compatibility
-        """
-        self.hidden_shape = hidden_states.shape
-        hidden_states = hidden_states.view(-1, self.hidden_shape[-1])
-
-        # Initialize metadata
-        routing_map, probs = self._initialize_metadata(routing_map, probs)
-
-        self._comm_manager.setup_metadata(routing_map, probs)
-        return hidden_states, self._comm_manager.token_probs, None
-
-    def dispatch_all_to_all(
-        self,
-        hidden_states: torch.Tensor,
-        probs: torch.Tensor = None,
-        async_finish: bool = True,
-        allocate_on_comm_stream: bool = True,
-    ):
-        """
-        Performs all-to-all communication to dispatch tokens across expert parallel ranks.
-        """
-        return (
-            self._comm_manager.dispatch(hidden_states, async_finish, allocate_on_comm_stream),
-            self._comm_manager.dispatched_probs,
-        )
-
-    def dispatch_postprocess(self, hidden_states: torch.Tensor):
-        """
-        Post-processes the dispatched hidden states after all-to-all communication.
-        This method retrieves the permuted hidden states by experts, calculates the number of tokens
-        per expert, and returns the processed data ready for expert processing.
-        """
-        global_input_tokens, permuted_probs = (
-            self._comm_manager.get_permuted_hidden_states_by_experts(hidden_states)
-        )
-        tokens_per_expert = self._comm_manager.get_number_of_tokens_per_expert()
-        return global_input_tokens, tokens_per_expert, permuted_probs
-
-    def token_permutation(
-        self, hidden_states: torch.Tensor, probs: torch.Tensor, routing_map: torch.Tensor
-    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
-        """
-        Permutes tokens according to the routing map and dispatches them to experts.
-        This method implements the token permutation process in three steps:
-        1. Preprocess the hidden states and routing information
-        2. Perform all-to-all communication to dispatch tokens
-        3. Post-process the dispatched tokens for expert processing
-        """
-        hidden_states, _, _ = self.dispatch_preprocess(hidden_states, routing_map, probs)
-        hidden_states, _ = self.dispatch_all_to_all(
-            hidden_states, async_finish=False, allocate_on_comm_stream=False
-        )
-        global_input_tokens, tokens_per_expert, permuted_probs = self.dispatch_postprocess(
-            hidden_states
-        )
-        return global_input_tokens, tokens_per_expert, permuted_probs
-
-    def combine_preprocess(self, hidden_states: torch.Tensor):
-        """
-        Pre-processes the hidden states before combining them after expert processing.
-
-        This method restores the hidden states to their original ordering before expert processing
-        by using the communication manager's restoration function.
-        """
-        hidden_states = self._comm_manager.get_restored_hidden_states_by_experts(hidden_states)
-        return hidden_states
-
-    def combine_all_to_all(
-        self,
-        hidden_states: torch.Tensor,
-        async_finish: bool = True,
-        allocate_on_comm_stream: bool = True,
-    ):
-        """
-        Performs all-to-all communication to combine tokens after expert processing.
-        """
-        return self._comm_manager.combine(hidden_states, async_finish, allocate_on_comm_stream)
-
-    def combine_postprocess(self, hidden_states: torch.Tensor):
-        """
-        Post-processes the combined hidden states after all-to-all communication.
-        This method reshapes the combined hidden states to match the original input shape.
-        """
-        return hidden_states.view(self.hidden_shape)
-
-    def token_unpermutation(
-        self, hidden_states: torch.Tensor, bias: Optional[torch.Tensor] = None
-    ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
-        """
-        Reverses the token permutation process to restore the original token order.
-
-        This method implements the token unpermutation process in three steps:
-        1. Pre-process the hidden states to restore their original ordering
-        2. Perform all-to-all communication to combine tokens
-        3. Post-process the combined tokens to match the original input shape
-        """
-        assert bias is None, "Bias is not supported in MoEFlexTokenDispatcher"
-        hidden_states = self.combine_preprocess(hidden_states)
-        hidden_states = self.combine_all_to_all(hidden_states, False, False)
-        hidden_states = self.combine_postprocess(hidden_states)
-        return hidden_states, None
+    def token_permutation(
+        self, hidden_states: torch.Tensor, probs: torch.Tensor, routing_map: torch.Tensor
+    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
+        """
+        Dispatch tokens to local experts using AlltoAll communication.
+        This method performs the following steps:
+        1. Preprocess the routing map to get metadata for communication and permutation.
+        2. Permute input tokens for AlltoAll communication.
+        3. Perform expert parallel AlltoAll communication.
+        4. Sort tokens by local expert (if multiple local experts exist).
+        Args:
+            hidden_states (torch.Tensor): Input token embeddings.
+            probs (torch.Tensor): The probabilities of token to experts assignment.
+            routing_map (torch.Tensor): The mapping of token to experts assignment.
+        Returns:
+            Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
+                - Permuted token embeddings for local experts.
+                - Number of tokens per expert.
+                - Permuted probs of each token produced by the router.
+        """
+        # Preprocess: Get the metadata for communication, permutation and computation operations.
+        # Permutation 1: input to AlltoAll input
+        tokens_per_expert = self.meta_prepare(hidden_states, probs, routing_map)
+        tokens_per_expert, permutated_local_input_tokens, permuted_probs = self.dispatch_preprocess(hidden_states, routing_map, tokens_per_expert)
+
+        # Perform expert parallel AlltoAll communication
+        tokens_per_expert, global_input_tokens, global_probs = self.dispatch_all_to_all(tokens_per_expert, permutated_local_input_tokens, permuted_probs)
+
+        # Permutation 2: Sort tokens by local expert.
+        global_input_tokens, tokens_per_expert, global_probs = self.dispatch_postprocess(tokens_per_expert, global_input_tokens, global_probs)
+
+        return global_input_tokens, tokens_per_expert, global_probs
+
+    def combine_preprocess(self, hidden_states):
+        # Unpermutation 2: Unsort tokens by local expert.
+        if self.num_local_experts > 1:
+            if self.drop_and_pad:
+                hidden_states = (
+                    hidden_states.view(
+                        self.num_local_experts,
+                        self.tp_size * self.ep_size,
+                        self.capacity,
+                        *hidden_states.size()[1:],
+                    )
+                    .transpose(0, 1)
+                    .contiguous()
+                    .flatten(start_dim=0, end_dim=2)
+                )
+            else:
+                hidden_states, _ = sort_chunks_by_idxs(
+                    hidden_states,
+                    self.num_global_tokens_per_local_expert.T.ravel(),
+                    self.restore_output_by_local_experts,
+                    fused=self.config.moe_permute_fusion,
+                )
+
+        if self.tp_size > 1:
+            if self.output_splits_tp is None:
+                input_split_sizes = None
+            else:
+                input_split_sizes = self.output_splits_tp.tolist()
+            # The precision of TP reduce_scatter should be the same as the router_dtype
+            hidden_states = reduce_scatter_to_sequence_parallel_region(
+                hidden_states.to(self.probs.dtype),
+                group=self.tp_group,
+                input_split_sizes=input_split_sizes,
+            ).to(hidden_states.dtype)
+
+        return hidden_states
+
+    def combine_all_to_all(self, hidden_states):
+        # Perform expert parallel AlltoAll communication
+        # hidden_states: [SEQL, H] -> [SEQL, H/TP]
+        permutated_local_input_tokens = all_to_all(
+            self.ep_group, hidden_states, self.input_splits, self.output_splits
+        )
+        return permutated_local_input_tokens
+
+    def combine_postprocess(self, permutated_local_input_tokens):
+        if self.shared_experts is not None:
+            self.shared_experts.linear_fc2_forward(permutated_local_input_tokens)
+            self.shared_experts.post_forward_comm()
+
+        # Unpermutation 1: AlltoAll output to output
+        output = unpermute(
+            permutated_local_input_tokens,
+            self.reversed_local_input_permutation_mapping,
+            restore_shape=self.hidden_shape_before_permute,
+            routing_map=self.routing_map,
+            fused=self.config.moe_permute_fusion,
+            drop_and_pad=self.drop_and_pad,
+        )
+
+        # Reshape the output tensor
+        output = output.view(self.hidden_shape)
+
+        # Add shared experts output
+        if self.shared_experts is not None:
+            shared_expert_output = self.shared_experts.get_output()
+            output += shared_expert_output
+        return output
+
+    def token_unpermutation(
+        self, hidden_states: torch.Tensor, bias: Optional[torch.Tensor] = None
+    ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
+        """
+        Reverse the token permutation to restore the original order.
+        This method performs the following steps:
+        1. Unsort tokens by local expert (if multiple local experts exist).
+        2. Perform expert parallel AlltoAll communication to restore the original order.
+        3. Unpermute tokens to restore the original order.
+        Args:
+            hidden_states (torch.Tensor): Output from local experts.
+            bias (torch.Tensor, optional): Bias tensor (not supported).
+
+        Returns:
+            Tuple[torch.Tensor, Optional[torch.Tensor]]:
+                - Unpermuted token embeddings in the original order.
+                - None (bias is not supported).
+        """
+        assert bias is None, "Bias is not supported in MoEAlltoAllTokenDispatcher"
+        hidden_states = self.combine_preprocess(hidden_states)
+        permutated_local_input_tokens = self.combine_all_to_all(hidden_states)
+        output = self.combine_postprocess(permutated_local_input_tokens)
+        return output, None
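With dispatch split into `meta_prepare` / `dispatch_preprocess` / `dispatch_all_to_all` / `dispatch_postprocess` and the per-batch state made swappable, a scheduler can interleave two micro-batches so that one batch's expert-parallel all-to-all overlaps the other's local permutation work. The driver below is an illustrative sketch built only on the methods added in this diff; it is not part of the commit, and real overlap would additionally issue the all-to-all asynchronously or on a separate stream.

# Illustrative driver over the split phases above. `dispatcher` is assumed to be the
# patched MoEAlltoAllTokenDispatcher; each batch is a (hidden_states, probs, routing_map)
# tuple. The phase boundaries are where a scheduler could launch one batch's all-to-all
# while the other batch runs its permutation/compute.
def dispatch_two_batches(dispatcher, batch_a, batch_b,
                         state_a: "MoEAlltoAllPerBatchState", state_b: "MoEAlltoAllPerBatchState"):
    with dispatcher.per_batch_state_context(state_a):
        hidden_a, probs_a, map_a = batch_a
        n_a = dispatcher.meta_prepare(hidden_a, probs_a, map_a)
        n_a, perm_a, pprobs_a = dispatcher.dispatch_preprocess(hidden_a, map_a, n_a)

    # Batch B's metadata and permutation work does not depend on batch A's communication,
    # so it can run while A's all-to-all would be in flight.
    with dispatcher.per_batch_state_context(state_b):
        hidden_b, probs_b, map_b = batch_b
        n_b = dispatcher.meta_prepare(hidden_b, probs_b, map_b)
        n_b, perm_b, pprobs_b = dispatcher.dispatch_preprocess(hidden_b, map_b, n_b)

    with dispatcher.per_batch_state_context(state_a):
        n_a, glob_a, gprobs_a = dispatcher.dispatch_all_to_all(n_a, perm_a, pprobs_a)
        result_a = dispatcher.dispatch_postprocess(n_a, glob_a, gprobs_a)

    with dispatcher.per_batch_state_context(state_b):
        n_b, glob_b, gprobs_b = dispatcher.dispatch_all_to_all(n_b, perm_b, pprobs_b)
        result_b = dispatcher.dispatch_postprocess(n_b, glob_b, gprobs_b)

    return result_a, result_b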
+from dataclasses import dataclass
+from typing import Callable, Optional
+
+
+@dataclass
+class SubmoduleCallables:
+    """
+    Holds references to forward, dgrad, and dw (weight-grad) callables
+    for a particular submodule.
+    """
+
+    forward: Optional[Callable] = None
+    backward: Optional[Callable] = None
+    dgrad: Optional[Callable] = None
+    dw: Optional[Callable] = None
+
+
+@dataclass
+class TransformerLayerSubmoduleCallables:
+    """
+    Collects the SubmoduleMethods for each of the submodules:
+    'attention', 'dispatch', 'mlp', 'combine'.
+    """
+
+    attention: SubmoduleCallables
+    dispatch: SubmoduleCallables
+    mlp: SubmoduleCallables
+    combine: SubmoduleCallables
+    post_combine: SubmoduleCallables
+
+    def as_array(self):
+        return [self.attention, self.dispatch, self.mlp, self.combine, self.post_combine]
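These dataclasses give a scheduler named forward/dgrad/dw hooks for each submodule (`attention`, `dispatch`, `mlp`, `combine`, `post_combine`) of a layer. Below is a self-contained toy showing how two micro-batches' phase lists, in the same order as `as_array()`, could be interleaved so one batch's dispatch/combine (all-to-all) slots sit next to the other batch's attention/mlp; the round-robin policy here is purely illustrative, not the commit's actual schedule.

# Toy schedule over two micro-batches' submodule callables, mirroring the as_array()
# ordering [attention, dispatch, mlp, combine, post_combine].
from dataclasses import dataclass
from typing import Callable, List


@dataclass
class Phase:
    name: str
    run: Callable[[], None]


def make_phases(batch_id: int, log: List[str]) -> List[Phase]:
    names = ["attention", "dispatch", "mlp", "combine", "post_combine"]
    return [Phase(n, lambda n=n: log.append(f"mb{batch_id}.{n}")) for n in names]


log: List[str] = []
mb0, mb1 = make_phases(0, log), make_phases(1, log)

# Alternate batches phase by phase, so mb1's compute phases sit next to mb0's
# communication phases (dispatch/combine) in the schedule.
for p0, p1 in zip(mb0, mb1):
    p0.run()
    p1.run()

print(log)  # ['mb0.attention', 'mb1.attention', 'mb0.dispatch', 'mb1.dispatch', ...]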