Unverified Commit a889c854 authored by Yuxuan Zhang, committed by GitHub

[Grammar Fix] GLM-4-MOE self.first_k_dense_replace is undefined. (#12455)

parent 4d84f886
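
The substance of the fix: `Glm4MoeModel.forward` compares `self.first_k_dense_replace` against the pipeline-parallel layer range (see the `@@ -766,8 +893,11 @@` hunk below), but `__init__` never set that attribute, so the comparison raised an `AttributeError`. The patch copies it from the config. Below is a minimal, self-contained sketch of the failure and the fix; `BrokenModel`, `FixedModel`, and the `SimpleNamespace` config are illustrative stand-ins, not code from the repository:

from types import SimpleNamespace

# Hypothetical stand-in for the HuggingFace config; real GLM-4.5/4.6 configs carry
# first_k_dense_replace (the number of leading dense layers before MoE layers start).
config = SimpleNamespace(first_k_dense_replace=1, num_hidden_layers=4)


class BrokenModel:
    def __init__(self, config):
        self.config = config  # the attribute is never copied onto self

    def forward(self, normal_start_layer=2):
        # Raises AttributeError: 'BrokenModel' object has no attribute
        # 'first_k_dense_replace' -- the failure this commit fixes.
        return self.first_k_dense_replace < normal_start_layer


class FixedModel(BrokenModel):
    def __init__(self, config):
        super().__init__(config)
        # The one-line fix from this commit:
        self.first_k_dense_replace = config.first_k_dense_replace


print(FixedModel(config).forward())  # True: the dense prefix ends before layer 2
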
@@ -15,7 +15,7 @@
 """Inference-only GLM-4.5, GLM-4.6 model compatible with HuggingFace weights"""
 import logging
-from typing import Any, Dict, Iterable, Optional, Tuple, Union
+from typing import Any, Dict, Iterable, List, Optional, Tuple, Union
 import torch
 import torch.nn.functional as F
@@ -84,6 +84,7 @@ from sglang.srt.utils import (
     is_cpu,
     is_cuda,
     is_hip,
+    is_non_idle_and_non_empty,
     make_layers,
 )
@@ -142,14 +143,17 @@ class Glm4MoeMLP(nn.Module):
         self,
         x,
         forward_batch=None,
-        should_allreduce_fusion=False,
+        should_allreduce_fusion: bool = False,
+        use_reduce_scatter: bool = False,
     ):
         if (self.tp_size == 1) and x.shape[0] == 0:
             return x
         gate_up, _ = self.gate_up_proj(x)
         x = self.act_fn(gate_up)
-        x, _ = self.down_proj(x, skip_all_reduce=should_allreduce_fusion)
+        x, _ = self.down_proj(
+            x, skip_all_reduce=should_allreduce_fusion or use_reduce_scatter
+        )
         return x
@@ -442,63 +446,14 @@ class Glm4MoeSparseMoeBlock(nn.Module):
         should_allreduce_fusion: bool = False,
         use_reduce_scatter: bool = False,
     ) -> torch.Tensor:
-        if not self._enable_a2a_moe:
-            DUAL_STREAM_TOKEN_THRESHOLD = 1024
-            if (
-                self.alt_stream is not None
-                and hidden_states.shape[0] > 0
-                and hidden_states.shape[0] <= DUAL_STREAM_TOKEN_THRESHOLD
-            ):
-                return self.forward_normal_dual_stream(
-                    hidden_states,
-                    should_allreduce_fusion,
-                    use_reduce_scatter,
-                )
-            else:
-                return self.forward_normal(
-                    hidden_states,
-                    should_allreduce_fusion,
-                    use_reduce_scatter,
-                )
+        if not get_moe_a2a_backend().is_deepep():
+            return self.forward_normal(
+                hidden_states, should_allreduce_fusion, use_reduce_scatter
+            )
         else:
             return self.forward_deepep(hidden_states, forward_batch)
 
-    def forward_normal_dual_stream(
-        self,
-        hidden_states: torch.Tensor,
-        should_allreduce_fusion: bool = False,
-        use_reduce_scatter: bool = False,
-    ) -> torch.Tensor:
-        current_stream = torch.cuda.current_stream()
-        self.alt_stream.wait_stream(current_stream)
-        shared_output = self._forward_shared_experts(hidden_states)
-        with torch.cuda.stream(self.alt_stream):
-            # router_logits: (num_tokens, n_experts)
-            router_logits = self.gate(hidden_states)
-            topk_output = self.topk(hidden_states, router_logits)
-            final_hidden_states = self.experts(hidden_states, topk_output)
-            if not _is_cuda:
-                final_hidden_states *= self.routed_scaling_factor
-        current_stream.wait_stream(self.alt_stream)
-        with use_symmetric_memory(
-            parallel_state.get_tp_group(), disabled=not is_allocation_symmetric()
-        ):
-            final_hidden_states_out = torch.empty_like(final_hidden_states)
-            torch.add(final_hidden_states, shared_output, out=final_hidden_states_out)
-            final_hidden_states = final_hidden_states_out
-        if (
-            self.tp_size > 1
-            and not should_allreduce_fusion
-            and not use_reduce_scatter
-            and not should_use_flashinfer_cutlass_moe_fp4_allgather()
-        ):
-            final_hidden_states = tensor_model_parallel_all_reduce(final_hidden_states)
-        return final_hidden_states
-
     def forward_normal(
         self,
         hidden_states: torch.Tensor,
@@ -534,11 +489,13 @@ class Glm4MoeSparseMoeBlock(nn.Module):
             final_hidden_states = tensor_model_parallel_all_reduce(final_hidden_states)
         return final_hidden_states
 
-    def _forward_deepep(self, hidden_states: torch.Tensor, forward_batch: ForwardBatch):
+    def forward_deepep(
+        self, hidden_states: torch.Tensor, forward_batch: ForwardBatch
+    ) -> torch.Tensor:
         shared_output = None
         if hidden_states.shape[0] > 0:
             # router_logits: (num_tokens, n_experts)
-            router_logits, _ = self.gate(hidden_states)
+            router_logits = self.gate(hidden_states)
             shared_output = self._forward_shared_experts(hidden_states)
             topk_output = self.topk(
                 hidden_states,
@@ -556,7 +513,15 @@ class Glm4MoeSparseMoeBlock(nn.Module):
         )
         if shared_output is not None:
-            final_hidden_states.add_(shared_output)
+            x = shared_output
+            if self.experts.should_fuse_routed_scaling_factor_in_topk:
+                x.add_(final_hidden_states)
+            else:
+                x.add_(final_hidden_states, alpha=self.routed_scaling_factor)
+            final_hidden_states = x
+        else:
+            if not self.experts.should_fuse_routed_scaling_factor_in_topk:
+                final_hidden_states *= self.routed_scaling_factor
         return final_hidden_states
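
Note on the `forward_deepep` change just above: the routed-expert output must be scaled by `routed_scaling_factor` exactly once, and the scaling may already be fused into the top-k weights (`should_fuse_routed_scaling_factor_in_topk`), in which case the combine step must not apply it again. A small sketch of the two combine paths with made-up tensors; the variable names are illustrative, not the module's real attributes:

import torch

routed_scaling_factor = 2.5
shared_output = torch.randn(4, 8)   # shared-expert branch
routed_output = torch.randn(4, 8)   # routed-expert (MoE) branch, unscaled

# Case 1: scaling is NOT fused into the top-k weights -> apply it while combining.
# Tensor.add_(other, alpha=s) computes an in-place x += s * other, as in the patch.
x = shared_output.clone()
x.add_(routed_output, alpha=routed_scaling_factor)
reference = shared_output + routed_scaling_factor * routed_output
assert torch.allclose(x, reference)

# Case 2: scaling already folded into the router weights -> plain addition suffices,
# and applying the factor again would double-scale the routed branch.
prescaled_routed = routed_scaling_factor * routed_output
y = shared_output.clone()
y.add_(prescaled_routed)
assert torch.allclose(y, reference)
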
@@ -566,6 +531,82 @@ class Glm4MoeSparseMoeBlock(nn.Module):
         shared_output = self.shared_experts(hidden_states)
         return shared_output
 
+    def op_gate(self, state):
+        if is_non_idle_and_non_empty(
+            state.forward_batch.forward_mode, state.hidden_states_mlp_input
+        ):
+            # router_logits: (num_tokens, n_experts)
+            state.router_logits = self.gate(state.hidden_states_mlp_input)
+        else:
+            state.router_logits = None
+
+    def op_select_experts(self, state):
+        router_logits = state.pop("router_logits")
+        hidden_states = state.hidden_states_mlp_input
+        if router_logits is not None:
+            with get_global_expert_distribution_recorder().with_current_layer(
+                self.layer_id
+            ):
+                state.topk_output = self.topk(
+                    hidden_states=hidden_states,
+                    router_logits=router_logits,
+                    num_token_non_padded=state.forward_batch.num_token_non_padded,
+                    expert_location_dispatch_info=ExpertLocationDispatchInfo.init_new(
+                        layer_id=self.layer_id,
+                    ),
+                )
+        else:
+            state.topk_output = self.topk.empty_topk_output(hidden_states.device)
+
+    def op_dispatch_a(self, state):
+        if self.ep_size > 1:
+            self.experts.dispatcher.dispatch_a(
+                hidden_states=state.hidden_states_mlp_input,
+                topk_output=state.pop("topk_output"),
+                tbo_subbatch_index=state.get("tbo_subbatch_index"),
+            )
+
+    def op_dispatch_b(self, state):
+        if self.ep_size > 1:
+            with get_global_expert_distribution_recorder().with_current_layer(
+                self.layer_id
+            ):
+                state.dispatch_output = self.experts.dispatcher.dispatch_b(
+                    tbo_subbatch_index=state.get("tbo_subbatch_index"),
+                )
+
+    def op_experts(self, state):
+        state.combine_input = self.experts.run_moe_core(
+            dispatch_output=state.dispatch_output,
+        )
+
+    def op_combine_a(self, state):
+        if self.ep_size > 1:
+            self.experts.dispatcher.combine_a(
+                combine_input=state.pop("combine_input"),
+                tbo_subbatch_index=state.get("tbo_subbatch_index"),
+            )
+            state.pop("dispatch_output")
+
+    def op_combine_b(self, state):
+        if self.ep_size > 1:
+            state.hidden_states_after_combine = self.experts.dispatcher.combine_b(
+                tbo_subbatch_index=state.get("tbo_subbatch_index"),
+            )
+
+    def op_output(self, state):
+        final_hidden_states = state.pop("hidden_states_after_combine")
+        if (shared_output := state.pop("shared_output")) is not None:
+            x = shared_output
+            x.add_(final_hidden_states, alpha=self.routed_scaling_factor)
+            final_hidden_states = x
+        else:
+            final_hidden_states *= self.routed_scaling_factor
+        state.hidden_states_mlp_output = final_hidden_states
+
 
 class Glm4MoeDecoderLayer(nn.Module):
     def __init__(
@@ -670,6 +711,7 @@ class Glm4MoeDecoderLayer(nn.Module):
         forward_batch: ForwardBatch,
         residual: Optional[torch.Tensor],
     ) -> torch.Tensor:
+
         hidden_states, residual = self.layer_communicator.prepare_attn(
             hidden_states, residual, forward_batch
         )
@@ -684,14 +726,96 @@ class Glm4MoeDecoderLayer(nn.Module):
             hidden_states, residual, forward_batch
         )
 
-        hidden_states = self.mlp(hidden_states, forward_batch)
-
-        hidden_states, residual = self.layer_communicator.postprocess_layer(
-            hidden_states, residual, forward_batch
-        )
+        should_allreduce_fusion = (
+            self.layer_communicator.should_fuse_mlp_allreduce_with_next_layer(
+                forward_batch
+            )
+        )
+
+        # For DP with padding, reduce scatter can be used instead of all-reduce.
+        use_reduce_scatter = self.layer_communicator.should_use_reduce_scatter(
+            forward_batch
+        )
+
+        hidden_states = self.mlp(
+            hidden_states, forward_batch, should_allreduce_fusion, use_reduce_scatter
+        )
+
+        if should_allreduce_fusion:
+            hidden_states._sglang_needs_allreduce_fusion = True
+        else:
+            hidden_states, residual = self.layer_communicator.postprocess_layer(
+                hidden_states, residual, forward_batch
+            )
+
         return hidden_states, residual
 
+    def op_comm_prepare_attn(
+        self,
+        state,
+        positions: torch.Tensor,
+        hidden_states: torch.Tensor,
+        forward_batch: ForwardBatch,
+        residual: Optional[torch.Tensor],
+        tbo_subbatch_index: Optional[int] = None,
+    ):
+        state.hidden_states_after_comm_pre_attn, state.residual_after_input_ln = (
+            self.layer_communicator.prepare_attn(hidden_states, residual, forward_batch)
+        )
+        state.update(
+            dict(
+                forward_batch=forward_batch,
+                positions=positions,
+                tbo_subbatch_index=tbo_subbatch_index,
+            )
+        )
+
+    def op_comm_prepare_mlp(self, state):
+        state.hidden_states_mlp_input, state.residual_after_comm_pre_mlp = (
+            self.layer_communicator.prepare_mlp(
+                state.pop("hidden_states_after_attn"),
+                state.pop("residual_after_input_ln"),
+                state.forward_batch,
+            )
+        )
+
+    def op_mlp(self, state):
+        hidden_states = state.pop("hidden_states_mlp_input")
+        if not (
+            enable_moe_dense_fully_dp()
+            and (not self.is_layer_sparse)
+            and hidden_states.shape[0] == 0
+        ):
+            state.hidden_states_mlp_output = self.mlp(
+                hidden_states, state.forward_batch
+            )
+        else:
+            state.hidden_states_mlp_output = hidden_states
+
+    def op_comm_postprocess_layer(self, state):
+        hidden_states, residual = self.layer_communicator.postprocess_layer(
+            state.pop("hidden_states_mlp_output"),
+            state.pop("residual_after_comm_pre_mlp"),
+            state.forward_batch,
+        )
+        output = dict(
+            positions=state.positions,
+            hidden_states=hidden_states,
+            residual=residual,
+            forward_batch=state.forward_batch,
+            tbo_subbatch_index=state.tbo_subbatch_index,
+        )
+        state.clear(
+            expect_keys={
+                "positions",
+                "forward_batch",
+                "tbo_subbatch_index",
+            }
+        )
+        return output
+
 
 class Glm4MoeModel(nn.Module):
     def __init__(
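
Note on the `use_reduce_scatter` flag threaded through the decoder layer above: per the patch's comment, when data parallelism pads ranks to equal token counts, the layer communicator can finish the MLP reduction with a reduce-scatter instead of an all-reduce, which is why `down_proj` is told to skip its all-reduce (first hunks above). The toy below only illustrates the collective identity this relies on (summing partials and taking one shard equals sharding the all-reduced sum); the rank and token layout is a simplifying assumption, not taken from the patch:

import torch

tp_size = 2           # tensor-parallel ranks (illustrative)
tokens_per_rank = 3   # tokens each rank keeps after padding (illustrative)
hidden = 4

# Partial MLP outputs that each TP rank would hold before the final reduction.
partials = [torch.randn(tp_size * tokens_per_rank, hidden) for _ in range(tp_size)]

# all_reduce: every rank would end up with the full summed activation.
all_reduced = sum(partials)

# reduce_scatter: rank r would end up with only chunk r of that same sum.
for rank in range(tp_size):
    scattered = sum(p.chunk(tp_size, dim=0)[rank] for p in partials)
    expected = all_reduced[rank * tokens_per_rank : (rank + 1) * tokens_per_rank]
    assert torch.allclose(scattered, expected)
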
@@ -704,6 +828,7 @@ class Glm4MoeModel(nn.Module):
         self.pp_group = get_pp_group()
         self.config = config
         self.vocab_size = config.vocab_size
+        self.first_k_dense_replace = config.first_k_dense_replace
         self.embed_dim = config.hidden_size
         if self.pp_group.is_first_rank:
             self.embed_tokens = VocabParallelEmbedding(
@@ -733,6 +858,8 @@ class Glm4MoeModel(nn.Module):
         else:
             self.norm = PPMissingLayer(return_tuple=True)
 
+        self.layers_to_capture = []
+
     def get_input_embeddings(self) -> torch.Tensor:
         return self.embed_tokens
@@ -766,8 +893,11 @@ class Glm4MoeModel(nn.Module):
         elif self.first_k_dense_replace < normal_start_layer:
             normal_end_layer = normal_start_layer = 0
 
+        aux_hidden_states = []
         for i in range(normal_start_layer, normal_end_layer):
             with get_global_expert_distribution_recorder().with_current_layer(i):
+                if i in self.layers_to_capture:
+                    aux_hidden_states.append(hidden_states + residual)
                 layer = self.layers[i]
                 hidden_states, residual = layer(
                     positions,
@@ -802,7 +932,9 @@ class Glm4MoeModel(nn.Module):
             hidden_states = self.norm(hidden_states)
         else:
             hidden_states, _ = self.norm(hidden_states, residual)
-        return hidden_states
+        if len(aux_hidden_states) == 0:
+            return hidden_states
+        return hidden_states, aux_hidden_states
 
 
 class Glm4MoeForCausalLM(nn.Module):
@@ -813,10 +945,10 @@ class Glm4MoeForCausalLM(nn.Module):
         prefix: str = "",
     ) -> None:
         nn.Module.__init__(self)
+        self.pp_group = get_pp_group()
         self.config = config
         self.tp_size = get_tensor_model_parallel_world_size()
         self.quant_config = quant_config
-        self.pp_group = get_pp_group()
         self.model = Glm4MoeModel(
             config, quant_config, prefix=add_prefix("model", prefix)
         )
@@ -847,10 +979,13 @@ class Glm4MoeForCausalLM(nn.Module):
         hidden_states = self.model(
             input_ids, positions, forward_batch, input_embeds, pp_proxy_tensors
         )
+        aux_hidden_states = None
+        if self.capture_aux_hidden_states:
+            hidden_states, aux_hidden_states = hidden_states
         if self.pp_group.is_last_rank:
             return self.logits_processor(
-                input_ids, hidden_states, self.lm_head, forward_batch
+                input_ids, hidden_states, self.lm_head, forward_batch, aux_hidden_states
             )
         else:
             return hidden_states
@@ -1027,5 +1162,19 @@ class Glm4MoeForCausalLM(nn.Module):
             num_groups=config.n_group,
         )
 
+    def set_eagle3_layers_to_capture(self, layer_ids: Optional[List[int]] = None):
+        if not self.pp_group.is_last_rank:
+            return
+
+        if layer_ids is None:
+            self.capture_aux_hidden_states = True
+            num_layers = self.config.num_hidden_layers
+            self.model.layers_to_capture = [2, num_layers // 2, num_layers - 3]
+        else:
+            self.capture_aux_hidden_states = True
+            # we add 1 here because in sglang, the i-th layer takes the output
+            # of the (i-1)-th layer as its aux hidden state
+            self.model.layers_to_capture = [val + 1 for val in layer_ids]
+
 
 EntryClass = [Glm4MoeForCausalLM]
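
Note on `set_eagle3_layers_to_capture` above: in `Glm4MoeModel.forward`, aux hidden states are appended as `hidden_states + residual` at the top of iteration `i`, i.e. before layer `i` runs, so index `i` captures the output of layer `i - 1`; that is why requested EAGLE3 layer ids are shifted by one. A toy loop illustrating the off-by-one; the numbers are hypothetical:

# Pretend each "layer" just appends its index; after layer i runs, the output is [0..i].
num_layers = 6
requested = [1, 3]                              # want the outputs of layers 1 and 3
layers_to_capture = [i + 1 for i in requested]  # shift by one, as the patch does

hidden = []
captured = {}
for i in range(num_layers):
    if i in layers_to_capture:
        # Captured before layer i runs, i.e. the output of layer i - 1
        # (in the model this is hidden_states + residual).
        captured[i - 1] = list(hidden)
    hidden.append(i)  # "run" layer i

assert captured == {1: [0, 1], 3: [0, 1, 2, 3]}
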