"tests/vscode:/vscode.git/clone" did not exist on "1329be960638223df45ec4d2f679aa4df2556b6e"
Unverified Commit 6d6a8bc2 authored by Yuxuan Zhang's avatar Yuxuan Zhang Committed by GitHub
Browse files

GLM-4.5 Model Support (#8224)


Co-authored-by: default avatarLifu Huang <lifu.hlf@gmail.com>
Co-authored-by: default avatarBinyao Jiang <byjiang1996@gmail.com>
Co-authored-by: default avatarStefan He <hebiaobuaa@gmail.com>
parent 2fd5c704
...@@ -33,7 +33,11 @@ def get_model_config(model_name: str, tp_size: int): ...@@ -33,7 +33,11 @@ def get_model_config(model_name: str, tp_size: int):
topk = config.num_experts_per_tok topk = config.num_experts_per_tok
intermediate_size = config.moe_intermediate_size intermediate_size = config.moe_intermediate_size
shard_intermediate_size = 2 * intermediate_size // tp_size shard_intermediate_size = 2 * intermediate_size // tp_size
elif config.architectures[0] in ["DeepseekV2ForCausalLM", "DeepseekV3ForCausalLM"]: elif config.architectures[0] in [
"DeepseekV2ForCausalLM",
"DeepseekV3ForCausalLM",
"Glm4MoeForCausalLM",
]:
E = ( E = (
config.n_routed_experts + 1 config.n_routed_experts + 1
if config.architectures[0] in ["DeepseekV3ForCausalLM"] if config.architectures[0] in ["DeepseekV3ForCausalLM"]
......
...@@ -42,7 +42,11 @@ def get_model_config(model_name: str, tp_size: int): ...@@ -42,7 +42,11 @@ def get_model_config(model_name: str, tp_size: int):
topk = config.num_experts_per_tok topk = config.num_experts_per_tok
intermediate_size = config.moe_intermediate_size intermediate_size = config.moe_intermediate_size
shard_intermediate_size = 2 * intermediate_size // tp_size shard_intermediate_size = 2 * intermediate_size // tp_size
elif config.architectures[0] in ["DeepseekV2ForCausalLM", "DeepseekV3ForCausalLM"]: elif config.architectures[0] in [
"DeepseekV2ForCausalLM",
"DeepseekV3ForCausalLM",
"Glm4MoeForCausalLM",
]:
E = ( E = (
config.n_routed_experts + 1 config.n_routed_experts + 1
if config.architectures[0] in ["DeepseekV3ForCausalLM"] if config.architectures[0] in ["DeepseekV3ForCausalLM"]
......
...@@ -127,6 +127,9 @@ class ModelConfig: ...@@ -127,6 +127,9 @@ class ModelConfig:
): ):
self.hf_config.architectures[0] = "DeepseekV3ForCausalLMNextN" self.hf_config.architectures[0] = "DeepseekV3ForCausalLMNextN"
if is_draft_model and self.hf_config.architectures[0] == "Glm4MoeForCausalLM":
self.hf_config.architectures[0] = "Glm4MoeForCausalLMNextN"
if is_draft_model and self.hf_config.architectures[0] == "MiMoForCausalLM": if is_draft_model and self.hf_config.architectures[0] == "MiMoForCausalLM":
self.hf_config.architectures[0] = "MiMoMTP" self.hf_config.architectures[0] = "MiMoMTP"
# Check model type # Check model type
......
...@@ -165,6 +165,7 @@ class EBNFComposer: ...@@ -165,6 +165,7 @@ class EBNFComposer:
tool_call_separator: Optional[str] = None, tool_call_separator: Optional[str] = None,
call_rule_fmt: Optional[str] = None, call_rule_fmt: Optional[str] = None,
key_value_rule_fmt: Optional[str] = None, key_value_rule_fmt: Optional[str] = None,
key_value_separator: str = ",",
): ):
""" """
Generalized EBNF builder for all detectors. Generalized EBNF builder for all detectors.
...@@ -279,7 +280,11 @@ class EBNFComposer: ...@@ -279,7 +280,11 @@ class EBNFComposer:
# Add required properties joined by commas # Add required properties joined by commas
if required: if required:
rule_parts.append(' "," '.join(prop_kv_pairs[k] for k in required)) rule_parts.append(
f' "{key_value_separator}" '.join(
prop_kv_pairs[k] for k in required
)
)
# Add optional properties with flexible ordering # Add optional properties with flexible ordering
if optional: if optional:
...@@ -292,13 +297,15 @@ class EBNFComposer: ...@@ -292,13 +297,15 @@ class EBNFComposer:
if j == i: if j == i:
opt_parts.append(prop_kv_pairs[optional[j]]) opt_parts.append(prop_kv_pairs[optional[j]])
else: else:
opt_parts.append(f' ( "," {prop_kv_pairs[optional[j]]} )?') opt_parts.append(
f' ( "{key_value_separator}" {prop_kv_pairs[optional[j]]} )?'
)
opt_alternatives.append("".join(opt_parts)) opt_alternatives.append("".join(opt_parts))
# Wrap with appropriate comma handling based on whether we have required properties # Wrap with appropriate comma handling based on whether we have required properties
if required: if required:
# Required properties exist, so optional group needs outer comma # Required properties exist, so optional group needs outer comma
rule_parts.append(' ( "," ( ') rule_parts.append(f' ( "{key_value_separator}" ( ')
rule_parts.append(" | ".join(opt_alternatives)) rule_parts.append(" | ".join(opt_alternatives))
rule_parts.append(" ) )?") rule_parts.append(" ) )?")
else: else:
......
...@@ -10,6 +10,7 @@ from sglang.srt.entrypoints.openai.protocol import ( ...@@ -10,6 +10,7 @@ from sglang.srt.entrypoints.openai.protocol import (
from sglang.srt.function_call.base_format_detector import BaseFormatDetector from sglang.srt.function_call.base_format_detector import BaseFormatDetector
from sglang.srt.function_call.core_types import ToolCallItem from sglang.srt.function_call.core_types import ToolCallItem
from sglang.srt.function_call.deepseekv3_detector import DeepSeekV3Detector from sglang.srt.function_call.deepseekv3_detector import DeepSeekV3Detector
from sglang.srt.function_call.glm4_moe_detector import Glm4MoeDetector
from sglang.srt.function_call.kimik2_detector import KimiK2Detector from sglang.srt.function_call.kimik2_detector import KimiK2Detector
from sglang.srt.function_call.llama32_detector import Llama32Detector from sglang.srt.function_call.llama32_detector import Llama32Detector
from sglang.srt.function_call.mistral_detector import MistralDetector from sglang.srt.function_call.mistral_detector import MistralDetector
...@@ -37,6 +38,7 @@ class FunctionCallParser: ...@@ -37,6 +38,7 @@ class FunctionCallParser:
"pythonic": PythonicDetector, "pythonic": PythonicDetector,
"kimi_k2": KimiK2Detector, "kimi_k2": KimiK2Detector,
"qwen3_coder": Qwen3CoderDetector, "qwen3_coder": Qwen3CoderDetector,
"glm45": Glm4MoeDetector,
} }
def __init__(self, tools: List[Tool], tool_call_parser: str): def __init__(self, tools: List[Tool], tool_call_parser: str):
......
import ast
import json
import logging
import re
from typing import List
from sglang.srt.entrypoints.openai.protocol import Tool
from sglang.srt.function_call.base_format_detector import BaseFormatDetector
from sglang.srt.function_call.core_types import (
StreamingParseResult,
StructureInfo,
_GetInfoFunc,
)
from sglang.srt.function_call.ebnf_composer import EBNFComposer
logger = logging.getLogger(__name__)
def get_argument_type(func_name: str, arg_key: str, defined_tools: list):
name2tool = {tool.function.name: tool for tool in defined_tools}
if func_name not in name2tool:
return None
tool = name2tool[func_name]
if arg_key not in tool.function.parameters["properties"]:
return None
return tool.function.parameters["properties"][arg_key].get("type", None)
def parse_arguments(json_value):
try:
try:
parsed_value = json.loads(json_value)
except:
parsed_value = ast.literal_eval(json_value)
return parsed_value, True
except:
return json_value, False
class Glm4MoeDetector(BaseFormatDetector):
"""
Detector for GLM-4.5 models.
Assumes function call format:
<tool_call>get_weather\n<arg_key>city</arg_key>\n<arg_value>北京</arg_value>\n<arg_key>date</arg_key>\n<arg_value>2024-06-27</arg_value>\n</tool_call>\n<tool_call>get_weather\n<arg_key>city</arg_key>\n<arg_value>上海</arg_value>\n<arg_key>date</arg_key>\n<arg_value>2024-06-27</arg_value>\n</tool_call>
"""
def __init__(self):
super().__init__()
self.bot_token = "<tool_call>"
self.eot_token = "</tool_call>"
self.func_call_regex = r"<tool_call>.*?</tool_call>"
self.func_detail_regex = r"<tool_call>([^\n]*)\n(.*)</tool_call>"
self.func_arg_regex = r"<arg_key>(.*?)</arg_key>\s*<arg_value>(.*?)</arg_value>"
def has_tool_call(self, text: str) -> bool:
"""Check if the text contains a glm-4.5 format tool call."""
return self.bot_token in text
def detect_and_parse(self, text: str, tools: List[Tool]) -> StreamingParseResult:
"""
One-time parsing: Detects and parses tool calls in the provided text.
:param text: The complete text to parse.
:param tools: List of available tools.
:return: ParseResult indicating success or failure, consumed text, leftover text, and parsed calls.
"""
idx = text.find(self.bot_token)
normal_text = text[:idx].strip() if idx != -1 else text
if self.bot_token not in text:
return StreamingParseResult(normal_text=normal_text, calls=[])
match_result_list = re.findall(self.func_call_regex, text, re.DOTALL)
calls = []
try:
for match_result in match_result_list:
# Get function name
func_detail = re.search(self.func_detail_regex, match_result, re.DOTALL)
func_name = func_detail.group(1)
func_args = func_detail.group(2)
pairs = re.findall(
r"<arg_key>(.*?)</arg_key>\s*<arg_value>(.*?)</arg_value>",
func_args,
re.DOTALL,
)
arguments = {}
for arg_key, arg_value in pairs:
arg_key = arg_key.strip()
arg_value = arg_value.strip()
arg_type = get_argument_type(func_name, arg_key, tools)
if arg_type != "string":
arg_value, is_good_json = parse_arguments(arg_value)
arguments[arg_key] = arg_value
# construct match_result for parse_base_json
match_result = {"name": func_name, "parameters": arguments}
calls.extend(self.parse_base_json(match_result, tools))
return StreamingParseResult(normal_text=normal_text, calls=calls)
except Exception as e:
logger.error(f"Error in detect_and_parse: {e}")
# return the normal text if parsing fails
return StreamingParseResult(normal_text=text)
def parse_streaming_increment(
self, new_text: str, tools: List[Tool]
) -> StreamingParseResult:
"""
Streaming incremental parsing tool calls for GLM-4.5 format.
"""
self._buffer += new_text
current_text = self._buffer
start = current_text.find(self.bot_token)
if start == -1:
self._buffer = ""
if self.current_tool_id > 0:
current_text = ""
return StreamingParseResult(normal_text=current_text)
# find ensures we find the first self.eot_token so there will be at most one tool_call in current_text[:end+len(self.eot_token)
end = current_text.find(self.eot_token)
if end != -1:
# Initialize state if this is the first tool call
if self.current_tool_id == -1:
self.current_tool_id = 0
self.prev_tool_call_arr = []
self.streamed_args_for_tool = [""]
# Ensure we have enough entries in our tracking arrays
while len(self.prev_tool_call_arr) <= self.current_tool_id:
self.prev_tool_call_arr.append({})
while len(self.streamed_args_for_tool) <= self.current_tool_id:
self.streamed_args_for_tool.append("")
result = self.detect_and_parse(
current_text[: end + len(self.eot_token)], tools=tools
)
if result.calls:
self.prev_tool_call_arr[self.current_tool_id] = {
"name": result.calls[0].name,
"arguments": json.loads(result.calls[0].parameters),
}
self.streamed_args_for_tool[self.current_tool_id] = result.calls[
0
].parameters
result.calls[0].tool_index = self.current_tool_id
self.current_tool_id += 1
self._buffer = current_text[end + len(self.eot_token) :]
return result
normal_text = current_text[:start]
self._buffer = current_text[start:]
return StreamingParseResult(normal_text=normal_text)
def supports_structural_tag(self) -> bool:
return False
def structure_info(self) -> _GetInfoFunc:
raise NotImplementedError()
def build_ebnf(self, tools: List[Tool]):
return EBNFComposer.build_ebnf(
tools,
individual_call_start_token=self.bot_token,
individual_call_end_token=self.eot_token,
# GLM4Moe is not compatible with multiple tool_calls under tool_choice condition: it will output unlimited tool_calls...
# tool_call_separator="\\n",
function_format="xml",
call_rule_fmt='"{name}" "\\n" {arguments_rule} "\\n"',
key_value_rule_fmt='"<arg_key>{key}</arg_key>" "\\n" "<arg_value>" {valrule} "</arg_value>"',
key_value_separator="\\n",
)
# Copyright 2025-2026 SGLang Team
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Inference-only GLM-4.5 model compatible with HuggingFace weights"""
import logging
from typing import Any, Dict, Iterable, Optional, Tuple
import torch
import torch.nn.functional as F
from torch import nn
from transformers import PretrainedConfig
from sglang.srt.distributed import (
get_tensor_model_parallel_rank,
get_tensor_model_parallel_world_size,
parallel_state,
tensor_model_parallel_all_reduce,
)
from sglang.srt.layers.activation import SiluAndMul
from sglang.srt.layers.amx_utils import PackWeightMethod
from sglang.srt.layers.communicator import (
LayerCommunicator,
LayerScatterModes,
enable_moe_dense_fully_dp,
)
from sglang.srt.layers.dp_attention import (
get_attention_tp_rank,
get_attention_tp_size,
get_local_attention_dp_size,
)
from sglang.srt.layers.layernorm import RMSNorm
from sglang.srt.layers.linear import (
ColumnParallelLinear,
MergedColumnParallelLinear,
QKVParallelLinear,
ReplicatedLinear,
RowParallelLinear,
)
from sglang.srt.layers.logits_processor import LogitsProcessor
from sglang.srt.layers.moe.ep_moe.layer import (
DeepEPMoE,
get_moe_impl_class,
use_flashinfer_trtllm_moe,
)
from sglang.srt.layers.moe.topk import TopK
from sglang.srt.layers.quantization.base_config import QuantizationConfig
from sglang.srt.layers.quantization.fp8_kernel import (
is_fp8_fnuz,
per_tensor_quant_mla_fp8,
per_token_group_quant_mla_deep_gemm_masked_fp8,
)
from sglang.srt.layers.radix_attention import RadixAttention
from sglang.srt.layers.rotary_embedding import get_rope
from sglang.srt.layers.vocab_parallel_embedding import (
ParallelLMHead,
VocabParallelEmbedding,
)
from sglang.srt.managers.schedule_batch import global_server_args_dict
from sglang.srt.model_executor.cuda_graph_runner import get_is_capture_mode
from sglang.srt.model_executor.forward_batch_info import ForwardBatch
from sglang.srt.model_loader.weight_utils import default_weight_loader
from sglang.srt.models.deepseek_v2 import (
DeepseekV2DecoderLayer,
DeepseekV2ForCausalLM,
DeepseekV2Model,
DeepseekV2MoE,
)
from sglang.srt.two_batch_overlap import (
MaybeTboDeepEPDispatcher,
model_forward_maybe_tbo,
)
from sglang.srt.utils import (
BumpAllocator,
DeepEPMode,
LazyValue,
add_prefix,
bind_or_assign,
cpu_has_amx_support,
get_bool_env_var,
get_device_sm,
get_int_env_var,
is_cpu,
is_cuda,
is_flashinfer_available,
is_hip,
is_non_idle_and_non_empty,
log_info_on_rank0,
use_intel_amx_backend,
)
_is_hip = is_hip()
_is_cuda = is_cuda()
_is_fp8_fnuz = is_fp8_fnuz()
_use_aiter = get_bool_env_var("SGLANG_USE_AITER") and _is_hip
_is_cpu_amx_available = cpu_has_amx_support()
_is_cpu = is_cpu()
_device_sm = get_device_sm()
if _is_cuda:
from sgl_kernel import dsv3_router_gemm
elif _is_cpu and _is_cpu_amx_available:
pass
logger = logging.getLogger(__name__)
class Glm4MoeMLP(nn.Module):
def __init__(
self,
hidden_size: int,
intermediate_size: int,
hidden_act: str,
quant_config: Optional[QuantizationConfig] = None,
reduce_results: bool = True,
prefix: str = "",
tp_rank: Optional[int] = None,
tp_size: Optional[int] = None,
) -> None:
super().__init__()
self.tp_size = tp_size
self.gate_up_proj = MergedColumnParallelLinear(
hidden_size,
[intermediate_size] * 2,
bias=False,
quant_config=quant_config,
prefix=add_prefix("gate_up_proj", prefix),
tp_rank=tp_rank,
tp_size=tp_size,
)
self.down_proj = RowParallelLinear(
intermediate_size,
hidden_size,
bias=False,
quant_config=quant_config,
reduce_results=reduce_results,
prefix=add_prefix("down_proj", prefix),
tp_rank=tp_rank,
tp_size=tp_size,
)
if hidden_act != "silu":
raise ValueError(
f"Unsupported activation: {hidden_act}. "
"Only silu is supported for now."
)
self.act_fn = SiluAndMul()
def forward(self, x, forward_batch=None, can_fuse_mlp_allreduce=False):
if (self.tp_size == 1) and x.shape[0] == 0:
return x
gate_up, _ = self.gate_up_proj(x)
x = self.act_fn(gate_up)
x, _ = self.down_proj(x, can_fuse_mlp_allreduce=can_fuse_mlp_allreduce)
return x
class Glm4MoeAttention(nn.Module):
def __init__(
self,
hidden_size: int,
num_heads: int,
num_kv_heads: int,
layer_id: int = 0,
rope_theta: float = 10000,
partial_rotary_factor: float = 0.5,
rope_scaling: Optional[Dict[str, Any]] = None,
max_position_embeddings: int = 8192,
head_dim: Optional[int] = None,
rms_norm_eps: float = 1e-05,
attention_bias: bool = True,
quant_config: Optional[QuantizationConfig] = None,
use_qk_norm: bool = False,
prefix: str = "",
alt_stream: Optional[torch.cuda.Stream] = None,
) -> None:
super().__init__()
self.hidden_size = hidden_size
attn_tp_rank = get_attention_tp_rank()
attn_tp_size = get_attention_tp_size()
self.total_num_heads = num_heads
assert self.total_num_heads % attn_tp_size == 0
self.num_heads = self.total_num_heads // attn_tp_size
self.total_num_kv_heads = num_kv_heads
if self.total_num_kv_heads >= attn_tp_size:
# Number of KV heads is greater than TP size, so we partition
# the KV heads across multiple tensor parallel GPUs.
assert self.total_num_kv_heads % attn_tp_size == 0
else:
# Number of KV heads is less than TP size, so we replicate
# the KV heads across multiple tensor parallel GPUs.
assert attn_tp_size % self.total_num_kv_heads == 0
self.num_kv_heads = max(1, self.total_num_kv_heads // attn_tp_size)
self.head_dim = head_dim or hidden_size // self.total_num_heads
self.q_size = self.num_heads * self.head_dim
self.kv_size = self.num_kv_heads * self.head_dim
self.scaling = self.head_dim**-0.5
self.rope_theta = rope_theta
self.use_qk_norm = use_qk_norm
self.max_position_embeddings = max_position_embeddings
self.tp_rank = get_tensor_model_parallel_rank()
self.qkv_proj = QKVParallelLinear(
hidden_size,
self.head_dim,
self.total_num_heads,
self.total_num_kv_heads,
bias=attention_bias,
quant_config=quant_config,
tp_rank=attn_tp_rank,
tp_size=attn_tp_size,
prefix=add_prefix("qkv_proj", prefix),
)
self.o_proj = RowParallelLinear(
self.total_num_heads * self.head_dim,
hidden_size,
bias=False,
quant_config=quant_config,
tp_rank=attn_tp_rank,
tp_size=attn_tp_size,
reduce_results=False,
prefix=add_prefix("o_proj", prefix),
)
self.rotary_emb = get_rope(
self.head_dim,
rotary_dim=self.head_dim,
max_position=max_position_embeddings,
partial_rotary_factor=partial_rotary_factor,
base=rope_theta,
rope_scaling=rope_scaling,
)
self.attn = RadixAttention(
self.num_heads,
self.head_dim,
self.scaling,
num_kv_heads=self.num_kv_heads,
layer_id=layer_id,
prefix=add_prefix("attn", prefix),
)
if self.use_qk_norm:
self.q_norm = RMSNorm(self.head_dim, eps=rms_norm_eps)
self.k_norm = RMSNorm(self.head_dim, eps=rms_norm_eps)
self.alt_stream = alt_stream
def _apply_qk_norm(
self, q: torch.Tensor, k: torch.Tensor
) -> Tuple[torch.Tensor, torch.Tensor]:
# overlap qk norm
if self.alt_stream is not None and get_is_capture_mode():
current_stream = torch.cuda.current_stream()
self.alt_stream.wait_stream(current_stream)
q_by_head = q.reshape(-1, self.head_dim)
q_by_head = self.q_norm(q_by_head)
with torch.cuda.stream(self.alt_stream):
k_by_head = k.reshape(-1, self.head_dim)
k_by_head = self.k_norm(k_by_head)
current_stream.wait_stream(self.alt_stream)
else:
q_by_head = q.reshape(-1, self.head_dim)
q_by_head = self.q_norm(q_by_head)
k_by_head = k.reshape(-1, self.head_dim)
k_by_head = self.k_norm(k_by_head)
q = q_by_head.view(q.shape)
k = k_by_head.view(k.shape)
return q, k
def op_prepare(self, state):
state.attn_intermediate_state = self.forward_prepare(
positions=state.positions,
hidden_states=state.pop("hidden_states_after_comm_pre_attn"),
forward_batch=state.forward_batch,
)
def op_core(self, state):
state.hidden_states_after_attn = self.forward_core(
state.pop("attn_intermediate_state")
)
def forward_prepare(
self,
positions: torch.Tensor,
hidden_states: torch.Tensor,
forward_batch: ForwardBatch,
):
if hidden_states.shape[0] == 0:
return hidden_states, forward_batch, None
qkv, _ = self.qkv_proj(hidden_states)
q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
if self.use_qk_norm:
q, k = self._apply_qk_norm(q, k)
q, k = self.rotary_emb(positions, q, k)
inner_state = q, k, v, forward_batch
return None, forward_batch, inner_state
def forward_core(self, intermediate_state):
hidden_states, forward_batch, inner_state = intermediate_state
if inner_state is None:
return hidden_states
attn_output = self.attn(*inner_state)
output, _ = self.o_proj(attn_output)
return output
def forward(
self,
positions: torch.Tensor,
hidden_states: torch.Tensor,
forward_batch: ForwardBatch,
) -> torch.Tensor:
s = self.forward_prepare(
positions=positions,
hidden_states=hidden_states,
forward_batch=forward_batch,
)
return self.forward_core(s)
class Glm4MoeGate(nn.Module):
def __init__(
self,
config,
prefix: str = "",
is_nextn: bool = False,
):
super().__init__()
self.is_nextn = is_nextn
self.weight = nn.Parameter(
torch.empty((config.n_routed_experts, config.hidden_size))
)
self.e_score_correction_bias = nn.Parameter(
torch.empty((config.n_routed_experts))
)
if _is_cpu and _is_cpu_amx_available:
self.quant_method = PackWeightMethod(weight_names=["weight"])
def forward(self, hidden_states):
if use_intel_amx_backend(self):
return torch.ops.sgl_kernel.weight_packed_linear(
hidden_states,
self.weight,
None, # bias
True, # is_vnni
)
# NOTE: For some unknown reason, router_gemm seems degrade accept length.
if (
_is_cuda
and not self.is_nextn
and hidden_states.shape[0] < 4
and hidden_states.shape[1] == 7168
and self.weight.shape[0] == 256
and _device_sm >= 90
):
logits = dsv3_router_gemm(hidden_states, self.weight).to(
hidden_states.dtype
)
else:
logits = F.linear(hidden_states, self.weight, None)
return logits
class Glm4MoeSparseMoeBlock(DeepseekV2MoE):
def __init__(
self,
config: PretrainedConfig,
layer_id: int,
quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
alt_stream: Optional[torch.cuda.Stream] = None,
is_nextn: bool = False,
):
nn.Module.__init__(self)
self.tp_size = get_tensor_model_parallel_world_size()
self.routed_scaling_factor = config.routed_scaling_factor
self.n_shared_experts = config.n_shared_experts
self.num_fused_shared_experts = (
0
if global_server_args_dict["disable_shared_experts_fusion"]
else config.n_shared_experts
)
self.config = config
self.layer_id = layer_id
self.alt_stream = alt_stream
if self.tp_size > config.n_routed_experts:
raise ValueError(
f"Tensor parallel size {self.tp_size} is greater than "
f"the number of experts {config.n_routed_experts}."
)
if config.hidden_act != "silu":
raise ValueError(
f"Unsupported activation: {config.hidden_act}. "
"Only silu is supported for now."
)
self.gate = Glm4MoeGate(
config=config, prefix=add_prefix("gate", prefix), is_nextn=is_nextn
)
self.topk = (
TopK(
top_k=config.num_experts_per_tok + self.num_fused_shared_experts,
renormalize=config.norm_topk_prob,
use_grouped_topk=True,
num_expert_group=config.n_group,
num_fused_shared_experts=self.num_fused_shared_experts,
topk_group=config.topk_group,
correction_bias=self.gate.e_score_correction_bias,
routed_scaling_factor=self.routed_scaling_factor,
)
if not use_flashinfer_trtllm_moe
else None
)
self.experts = get_moe_impl_class()(
num_experts=config.n_routed_experts
+ self.num_fused_shared_experts
+ global_server_args_dict["ep_num_redundant_experts"],
top_k=config.num_experts_per_tok + self.num_fused_shared_experts,
hidden_size=config.hidden_size,
intermediate_size=config.moe_intermediate_size,
layer_id=self.layer_id,
quant_config=quant_config,
routed_scaling_factor=self.routed_scaling_factor,
prefix=add_prefix("experts", prefix),
**(
dict(deepep_mode=DeepEPMode[global_server_args_dict["deepep_mode"]])
if global_server_args_dict["enable_deepep_moe"]
else {}
),
# Additional args for FusedMoE
**(
dict(
enable_flashinfer_cutlass_moe=True,
enable_ep_moe=global_server_args_dict["enable_ep_moe"],
)
if global_server_args_dict["enable_flashinfer_cutlass_moe"]
else {}
),
**(
dict(
renormalize=config.norm_topk_prob,
use_grouped_topk=True,
num_expert_group=config.n_group,
num_fused_shared_experts=self.num_fused_shared_experts,
topk_group=config.topk_group,
correction_bias=self.gate.e_score_correction_bias,
)
if use_flashinfer_trtllm_moe
else {}
),
)
self.shared_experts_is_int8 = False
self.shared_experts_is_fp8 = False
# self.shared_experts_weight_block_size = None
if config.n_shared_experts is not None and self.num_fused_shared_experts == 0:
intermediate_size = config.moe_intermediate_size * config.n_shared_experts
self.shared_experts = Glm4MoeMLP(
hidden_size=config.hidden_size,
intermediate_size=intermediate_size,
hidden_act=config.hidden_act,
quant_config=quant_config,
reduce_results=False,
prefix=add_prefix("shared_experts", prefix),
**(
dict(tp_rank=0, tp_size=1)
if global_server_args_dict["enable_deepep_moe"]
else {}
),
)
is_packed_weight = hasattr(
self.shared_experts.gate_up_proj.quant_method, "quant_config"
)
self.shared_experts_is_int8 = (
not is_packed_weight
and self.shared_experts.gate_up_proj.weight.dtype == torch.int8
)
self.shared_experts_is_fp8 = (
not is_packed_weight
and self.shared_experts.gate_up_proj.weight.dtype == torch.float8_e4m3fn
)
self.top_k = config.num_experts_per_tok
if global_server_args_dict["enable_deepep_moe"]:
# TODO: we will support tp < ep in the future
self.ep_size = get_tensor_model_parallel_world_size()
self.num_experts = (
config.n_routed_experts
+ global_server_args_dict["ep_num_redundant_experts"]
)
self.renormalize = config.norm_topk_prob
self.topk_group = config.topk_group
self.num_expert_group = config.n_group
self.correction_bias = (
self.gate.e_score_correction_bias.data
if self.gate.e_score_correction_bias is not None
else None
)
self.deepep_dispatcher = MaybeTboDeepEPDispatcher(
group=parallel_state.get_tp_group().device_group,
router_topk=self.top_k,
permute_fusion=True,
num_experts=self.num_experts,
num_local_experts=config.n_routed_experts // self.tp_size,
hidden_size=config.hidden_size,
params_dtype=config.torch_dtype,
deepep_mode=DeepEPMode[global_server_args_dict["deepep_mode"]],
async_finish=True,
return_recv_hook=True,
)
self._enable_deepep_moe = global_server_args_dict["enable_deepep_moe"]
class Glm4MoeDecoderLayer(DeepseekV2DecoderLayer):
def __init__(
self,
config: PretrainedConfig,
layer_id: int,
quant_config: Optional[QuantizationConfig] = None,
is_nextn: bool = False,
prefix: str = "",
alt_stream: Optional[torch.cuda.Stream] = None,
) -> None:
nn.Module.__init__(self)
self.hidden_size = config.hidden_size
self.config = config
rope_theta = getattr(config, "rope_theta", 10000)
rope_scaling = getattr(config, "rope_scaling", None)
partial_rotary_factor = getattr(config, "partial_rotary_factor", 0.5)
max_position_embeddings = getattr(config, "max_position_embeddings", 8192)
head_dim = getattr(
config, "head_dim", config.hidden_size // config.num_attention_heads
)
rms_norm_eps = config.rms_norm_eps
attention_bias = config.attention_bias
self.enable_dp_attention = global_server_args_dict["enable_dp_attention"]
self.layer_id = layer_id
self.self_attn = Glm4MoeAttention(
hidden_size=self.hidden_size,
num_heads=config.num_attention_heads,
num_kv_heads=config.num_key_value_heads,
layer_id=layer_id,
rope_theta=rope_theta,
rope_scaling=rope_scaling,
partial_rotary_factor=partial_rotary_factor,
max_position_embeddings=max_position_embeddings,
head_dim=head_dim,
rms_norm_eps=rms_norm_eps,
attention_bias=attention_bias,
quant_config=quant_config,
prefix=add_prefix("self_attn", prefix),
use_qk_norm=config.use_qk_norm,
)
self.is_layer_sparse = self._is_layer_sparse(layer_id, is_nextn=is_nextn)
is_previous_layer_sparse = self._is_layer_sparse(layer_id - 1, is_nextn=False)
num_layers = 1 if is_nextn else config.num_hidden_layers
self.layer_scatter_modes = LayerScatterModes.init_new(
layer_id=layer_id,
num_layers=num_layers,
is_layer_sparse=self.is_layer_sparse,
is_previous_layer_sparse=is_previous_layer_sparse,
)
if self.is_layer_sparse:
self.mlp = Glm4MoeSparseMoeBlock(
config=config,
quant_config=quant_config,
prefix=add_prefix("mlp", prefix),
layer_id=self.layer_id,
)
else:
if enable_moe_dense_fully_dp():
mlp_tp_rank, mlp_tp_size = 0, 1
else:
mlp_tp_rank, mlp_tp_size = None, None
self.mlp = Glm4MoeMLP(
hidden_size=config.hidden_size,
intermediate_size=config.intermediate_size,
hidden_act=config.hidden_act,
quant_config=quant_config,
prefix=add_prefix("mlp", prefix),
tp_rank=mlp_tp_rank,
tp_size=mlp_tp_size,
)
self.input_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
self.post_attention_layernorm = RMSNorm(
config.hidden_size, eps=config.rms_norm_eps
)
self.layer_communicator = LayerCommunicator(
layer_scatter_modes=self.layer_scatter_modes,
input_layernorm=self.input_layernorm,
post_attention_layernorm=self.post_attention_layernorm,
)
def forward(
self,
positions: torch.Tensor,
hidden_states: torch.Tensor,
forward_batch: ForwardBatch,
residual: Optional[torch.Tensor],
zero_allocator: BumpAllocator,
) -> torch.Tensor:
hidden_states, residual = self.layer_communicator.prepare_attn(
hidden_states, residual, forward_batch
)
hidden_states = self.self_attn(
positions=positions,
hidden_states=hidden_states,
forward_batch=forward_batch,
)
hidden_states, residual = self.layer_communicator.prepare_mlp(
hidden_states, residual, forward_batch
)
hidden_states = self.mlp(hidden_states, forward_batch)
hidden_states, residual = self.layer_communicator.postprocess_layer(
hidden_states, residual, forward_batch
)
return hidden_states, residual
class Glm4MoeModel(DeepseekV2Model):
def __init__(
self,
config: PretrainedConfig,
quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
) -> None:
nn.Module.__init__(self)
self.padding_id = config.pad_token_id
self.vocab_size = config.vocab_size
self.first_k_dense_replace = config.first_k_dense_replace
self.embed_tokens = VocabParallelEmbedding(
config.vocab_size,
config.hidden_size,
enable_tp=not global_server_args_dict["enable_dp_attention"],
)
self.alt_stream = torch.cuda.Stream() if _is_cuda else None
self.layers = nn.ModuleList(
[
Glm4MoeDecoderLayer(
config,
layer_id,
quant_config=quant_config,
prefix=add_prefix(f"layers.{layer_id}", prefix),
alt_stream=self.alt_stream,
)
for layer_id in range(config.num_hidden_layers)
]
)
self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
self.dp_size = get_local_attention_dp_size()
class Glm4MoeForCausalLM(DeepseekV2ForCausalLM):
def __init__(
self,
config: PretrainedConfig,
quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
) -> None:
nn.Module.__init__(self)
config.moe_layer_freq = 1
self.config = config
self.tp_size = get_tensor_model_parallel_world_size()
self.quant_config = quant_config
self.determine_num_fused_shared_experts("Glm4MoeForCausalLM")
self.model = Glm4MoeModel(
config, quant_config, prefix=add_prefix("model", prefix)
)
self.lm_head = ParallelLMHead(
config.vocab_size,
config.hidden_size,
quant_config=quant_config,
prefix=add_prefix("lm_head", prefix),
use_attn_tp_group=global_server_args_dict["enable_dp_lm_head"],
)
self.logits_processor = LogitsProcessor(config)
self.dp_size = get_local_attention_dp_size()
self._routed_experts_weights_of_layer = LazyValue(
lambda: {
layer_id: layer.mlp.get_moe_weights()
for layer_id, layer in enumerate(self.model.layers)
if isinstance(layer.mlp, DeepseekV2MoE)
}
)
def determine_num_fused_shared_experts(
self, architecture: str = "DeepseekV3ForCausalLM"
):
self.num_fused_shared_experts = 0
if global_server_args_dict["disable_shared_experts_fusion"]:
return
# Only Deepseek V3/R1 can use shared experts fusion optimization now.
disable_reason = None
if (
not _is_cuda
or torch.cuda.get_device_capability("cuda") < (8, 0)
or self.config.architectures[0] != architecture
or self.config.n_routed_experts != 128
or self.config.n_shared_experts != 1
):
disable_reason = "Only GLM-4.5 on NV-platform with capability >= 80 can use shared experts fusion optimization."
elif (
global_server_args_dict["enable_deepep_moe"]
or global_server_args_dict["enable_ep_moe"]
):
disable_reason = "Deepseek GLM-4.5 can not use shared experts fusion optimization when in deepep_moe or ep_moe mode."
if disable_reason is not None:
global_server_args_dict["disable_shared_experts_fusion"] = True
log_info_on_rank0(
logger,
f"{disable_reason} Shared experts fusion optimization is disabled.",
)
return
self.num_fused_shared_experts = self.config.n_shared_experts
def get_input_embeddings(self) -> nn.Embedding:
return self.model.embed_tokens
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]], is_nextn=False):
if is_nextn:
if hasattr(self.config, "num_nextn_predict_layers"):
num_nextn_layers = self.config.num_nextn_predict_layers
assert num_nextn_layers == 1, "Only 1 nextn layer is supported"
# compatible with old design
nextn_layer_id = (
0
if self.config.num_hidden_layers == 1
else self.config.num_hidden_layers
)
else:
raise ValueError("num_nextn_predict_layers is not in the config")
stacked_params_mapping = [
# (param_name, shard_name, shard_id)
("qkv_proj", "q_proj", "q"),
("qkv_proj", "k_proj", "k"),
("qkv_proj", "v_proj", "v"),
("gate_up_proj", "gate_proj", 0),
("gate_up_proj", "up_proj", 1),
]
if self.num_fused_shared_experts > 0:
assert self.num_fused_shared_experts == 1
weights_list = list(weights)
weights_dict = dict(weights_list)
if self.quant_config is not None:
if self.quant_config.get_name() == "w8a8_int8":
suffix_list = [
"down_proj.weight",
"down_proj.weight_scale",
"gate_proj.weight",
"gate_proj.weight_scale",
"up_proj.weight",
"up_proj.weight_scale",
]
elif (
self.quant_config.get_name() == "fp8"
or self.quant_config.get_name() == "blockwise_int8"
):
suffix_list = [
"down_proj.weight",
"down_proj.weight_scale",
"gate_proj.weight",
"gate_proj.weight_scale",
"up_proj.weight",
"up_proj.weight_scale",
]
elif self.quant_config.get_name() == "awq":
suffix_list = [
"down_proj.qweight",
"down_proj.qzeros",
"down_proj.scales",
"gate_proj.qweight",
"gate_proj.qzeros",
"gate_proj.scales",
"up_proj.qweight",
"up_proj.qzeros",
"up_proj.scales",
]
elif self.quant_config.get_name() == "modelopt_fp4":
suffix_list = [
"down_proj.weight",
"down_proj.weight_scale",
"down_proj.weight_scale_2",
"down_proj.input_scale",
"gate_proj.weight",
"gate_proj.weight_scale",
"gate_proj.weight_scale_2",
"gate_proj.input_scale",
"up_proj.weight",
"up_proj.weight_scale",
"up_proj.weight_scale_2",
"up_proj.input_scale",
]
else:
raise ValueError(
f"Unsupported shared expert fusion for quantization: {self.quant_config.get_name()}."
)
else:
suffix_list = [
"down_proj.weight",
"gate_proj.weight",
"up_proj.weight",
]
names_to_remove = []
moe_layers = (
range(
self.config.first_k_dense_replace,
self.config.num_hidden_layers,
self.config.moe_layer_freq,
)
if not is_nextn
else [nextn_layer_id]
)
for moe_layer in moe_layers:
for suffix in suffix_list:
shared_expert_weight_name = (
f"model.layers.{moe_layer}.mlp.shared_experts.{suffix}"
)
# online fp8 quantization does not load weight_scale
if shared_expert_weight_name not in weights_dict:
continue
weights_list.append(
(
f"model.layers.{moe_layer}."
f"mlp.experts."
f"{self.config.n_routed_experts + 0}"
f".{suffix}",
weights_dict[shared_expert_weight_name],
)
)
names_to_remove += [shared_expert_weight_name]
weights = [w for w in weights_list if w[0] not in names_to_remove]
# Params for weights, fp8 weight scales, fp8 activation scales
# (param_name, weight_name, expert_id, shard_id)
expert_params_mapping = get_moe_impl_class().make_expert_params_mapping(
ckpt_gate_proj_name="gate_proj",
ckpt_down_proj_name="down_proj",
ckpt_up_proj_name="up_proj",
num_experts=self.config.n_routed_experts + self.num_fused_shared_experts,
)
# Fuse q_a_proj and kv_a_proj_with_mqa along output dimension when q_lora_rank is not None
fuse_qkv_a_proj = hasattr(self.config, "q_lora_rank") and (
self.config.q_lora_rank is not None
)
cached_a_proj = {} if fuse_qkv_a_proj else None
if is_nextn:
nextn_layer_prefix = f"model.layers.{nextn_layer_id}"
nextn_spec_weight_names = [
"shared_head.norm",
"eh_proj",
"enorm",
"hnorm",
]
params_dict = dict(self.named_parameters())
weight_names = []
for name, loaded_weight in weights:
weight_names.append(name)
if not is_nextn:
if hasattr(self.config, "num_nextn_predict_layers"):
num_nextn_layers = self.config.num_nextn_predict_layers
if num_nextn_layers > 0 and name.startswith("model.layers"):
name_list = name.split(".")
if (
len(name_list) >= 3
and int(name_list[2]) >= self.config.num_hidden_layers
):
continue
else:
if not name.startswith(nextn_layer_prefix):
continue
# Use shared head and embed weights from target model
if "shared_head.head" in name or "embed_tokens" in name:
continue
is_decoder = True
# For nextn specific weights
for weight_name in nextn_spec_weight_names:
if weight_name in name:
name = name.replace(nextn_layer_prefix, "model")
is_decoder = False
break
# For decoder layer weights
if is_decoder:
name = name.replace(nextn_layer_prefix, "model.decoder")
if "rotary_emb.inv_freq" in name:
continue
for param_name, weight_name, shard_id in stacked_params_mapping:
# Skip non-stacked layers and experts (experts handled below).
if weight_name not in name:
continue
# We have mlp.experts[0].gate_proj in the checkpoint.
# Since we handle the experts below in expert_params_mapping,
# we need to skip here BEFORE we update the name, otherwise
# name will be updated to mlp.experts[0].gate_up_proj, which
# will then be updated below in expert_params_mapping
# for mlp.experts[0].gate_gate_up_proj, which breaks load.
if ("mlp.experts." in name) and name not in params_dict:
continue
name = name.replace(weight_name, param_name)
# Skip loading extra bias for GPTQ models.
if name.endswith(".bias") and name not in params_dict:
continue
param = params_dict[name]
weight_loader = param.weight_loader
weight_loader(param, loaded_weight, shard_id)
break
else:
for mapping in expert_params_mapping:
param_name, weight_name, expert_id, shard_id = mapping
if weight_name not in name:
continue
name = name.replace(weight_name, param_name)
param = params_dict[name]
weight_loader = param.weight_loader
weight_loader(
param,
loaded_weight,
name,
shard_id=shard_id,
expert_id=expert_id,
)
break
else:
# Skip loading extra bias for GPTQ models.
if name.endswith(".bias") and name not in params_dict:
continue
if fuse_qkv_a_proj and (
"q_a_proj" in name or "kv_a_proj_with_mqa" in name
):
cached_a_proj[name] = loaded_weight
q_a_proj_name = (
name
if "q_a_proj" in name
else name.replace("kv_a_proj_with_mqa", "q_a_proj")
)
kv_a_proj_name = (
name
if "kv_a_proj_with_mqa" in name
else name.replace("q_a_proj", "kv_a_proj_with_mqa")
)
# When both q_a_proj and kv_a_proj_with_mqa has been cached, load the fused weight to parameter
if (
q_a_proj_name in cached_a_proj
and kv_a_proj_name in cached_a_proj
):
q_a_proj_weight = cached_a_proj[q_a_proj_name]
kv_a_proj_weight = cached_a_proj[kv_a_proj_name]
fused_weight = torch.cat(
[q_a_proj_weight, kv_a_proj_weight], dim=0
)
param_name = (
name.replace("q_a_proj", "fused_qkv_a_proj_with_mqa")
if "q_a_proj" in name
else name.replace(
"kv_a_proj_with_mqa", "fused_qkv_a_proj_with_mqa"
)
)
param = params_dict[param_name]
weight_loader = getattr(
param, "weight_loader", default_weight_loader
)
weight_loader(param, fused_weight)
cached_a_proj.pop(q_a_proj_name)
cached_a_proj.pop(kv_a_proj_name)
else:
if (
"k_scale" in name or "v_scale" in name
) and name not in params_dict:
# modelopt attn kv scale is named differently
if any(scale in name for scale in ["k_scale", "v_scale"]):
name = name.replace("_proj", "attn_mqa")
else:
logger.warning(
f"Unknown scale found in checkpoint: {name}"
)
param = params_dict[name]
weight_loader = getattr(
param, "weight_loader", default_weight_loader
)
weight_loader(param, loaded_weight)
EntryClass = [Glm4MoeForCausalLM]
# Copyright 2023-2024 SGLang Team
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Inference-only GLM-4.5 NextN Speculative Decoding."""
import logging
from typing import Iterable, Optional, Tuple
import torch
from torch import nn
from transformers import PretrainedConfig
from sglang.srt.distributed import get_tensor_model_parallel_world_size
from sglang.srt.eplb.expert_distribution import get_global_expert_distribution_recorder
from sglang.srt.layers.layernorm import RMSNorm
from sglang.srt.layers.logits_processor import LogitsProcessor
from sglang.srt.layers.quantization.base_config import QuantizationConfig
from sglang.srt.layers.vocab_parallel_embedding import (
ParallelLMHead,
VocabParallelEmbedding,
)
from sglang.srt.managers.schedule_batch import global_server_args_dict
from sglang.srt.model_executor.forward_batch_info import ForwardBatch
from sglang.srt.models.glm4_moe import Glm4MoeDecoderLayer, Glm4MoeForCausalLM
from sglang.srt.utils import BumpAllocator, add_prefix
logger = logging.getLogger(__name__)
class Glm4MoeModelNextN(nn.Module):
def __init__(
self,
config: PretrainedConfig,
quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
) -> None:
super().__init__()
if quant_config is not None and quant_config.get_name() == "modelopt_fp4":
logger.warning(
"Overriding Glm4MoeForCausalLMNextN quant config for modelopt_fp4 GLM-4.5 model."
)
quant_config = None
self.vocab_size = config.vocab_size
self.embed_tokens = VocabParallelEmbedding(
config.vocab_size,
config.hidden_size,
enable_tp=not global_server_args_dict["enable_dp_attention"],
prefix=add_prefix("embed_tokens", prefix),
)
self.enorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
self.hnorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
self.eh_proj = nn.Linear(2 * config.hidden_size, config.hidden_size, bias=False)
self.decoder = Glm4MoeDecoderLayer(
config,
0,
quant_config=quant_config,
is_nextn=True,
prefix=add_prefix("decoder", prefix),
)
self.shared_head = nn.Module()
self.shared_head.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
def forward(
self,
input_ids: torch.Tensor,
positions: torch.Tensor,
forward_batch: ForwardBatch,
input_embeds: torch.Tensor = None,
) -> torch.Tensor:
zero_allocator = BumpAllocator(
buffer_size=2,
dtype=torch.float32,
device=(
input_embeds.device if input_embeds is not None else input_ids.device
),
)
if input_embeds is None:
hidden_states = self.embed_tokens(input_ids)
else:
hidden_states = input_embeds
if hidden_states.shape[0] > 0:
hidden_states = self.eh_proj(
torch.cat(
(
self.enorm(hidden_states),
self.hnorm(forward_batch.spec_info.hidden_states),
),
dim=-1,
)
)
residual = None
with get_global_expert_distribution_recorder().disable_this_region():
hidden_states, residual = self.decoder(
positions, hidden_states, forward_batch, residual, zero_allocator
)
if not forward_batch.forward_mode.is_idle():
if residual is not None:
hidden_states, _ = self.shared_head.norm(hidden_states, residual)
else:
hidden_states = self.shared_head.norm(hidden_states)
return hidden_states
class Glm4MoeForCausalLMNextN(Glm4MoeForCausalLM):
def __init__(
self,
config: PretrainedConfig,
quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
) -> None:
nn.Module.__init__(self)
self.config = config
self.tp_size = get_tensor_model_parallel_world_size()
self.quant_config = quant_config
self.determine_num_fused_shared_experts("Glm4MoeForCausalLMNextN")
self.model = Glm4MoeModelNextN(
config, quant_config, prefix=add_prefix("model", prefix)
)
self.lm_head = ParallelLMHead(
config.vocab_size,
config.hidden_size,
quant_config=quant_config,
prefix=add_prefix("model.shared_head.head", prefix),
use_attn_tp_group=global_server_args_dict["enable_dp_lm_head"],
)
self.logits_processor = LogitsProcessor(config)
@torch.no_grad()
def forward(
self,
input_ids: torch.Tensor,
positions: torch.Tensor,
forward_batch: ForwardBatch,
) -> torch.Tensor:
hidden_states = self.model(input_ids, positions, forward_batch)
return self.logits_processor(
input_ids, hidden_states, self.lm_head, forward_batch
)
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
super().load_weights(weights, is_nextn=True)
EntryClass = [Glm4MoeForCausalLMNextN]
...@@ -231,6 +231,7 @@ class ReasoningParser: ...@@ -231,6 +231,7 @@ class ReasoningParser:
"deepseek-r1": DeepSeekR1Detector, "deepseek-r1": DeepSeekR1Detector,
"qwen3": Qwen3Detector, "qwen3": Qwen3Detector,
"qwen3-thinking": Qwen3ThinkingDetector, "qwen3-thinking": Qwen3ThinkingDetector,
"glm45": Qwen3Detector,
"kimi": KimiDetector, "kimi": KimiDetector,
} }
......
...@@ -513,7 +513,7 @@ class ServerArgs: ...@@ -513,7 +513,7 @@ class ServerArgs:
) )
model_arch = self.get_hf_config().architectures[0] model_arch = self.get_hf_config().architectures[0]
if model_arch == "DeepseekV3ForCausalLM": if model_arch in ["DeepseekV3ForCausalLM", "Glm4MoeForCausalLM"]:
# Auto set draft_model_path DeepSeek-V3/R1 # Auto set draft_model_path DeepSeek-V3/R1
if self.speculative_draft_model_path is None: if self.speculative_draft_model_path is None:
self.speculative_draft_model_path = self.model_path self.speculative_draft_model_path = self.model_path
...@@ -1108,6 +1108,7 @@ class ServerArgs: ...@@ -1108,6 +1108,7 @@ class ServerArgs:
"pythonic", "pythonic",
"kimi_k2", "kimi_k2",
"qwen3_coder", "qwen3_coder",
"glm45",
], ],
default=ServerArgs.tool_call_parser, default=ServerArgs.tool_call_parser,
help="Specify the parser for handling tool-call interactions. Options include: 'qwen25', 'mistral', 'llama3', 'deepseekv3', 'pythonic', 'kimi_k2', and 'qwen3_coder'.", help="Specify the parser for handling tool-call interactions. Options include: 'qwen25', 'mistral', 'llama3', 'deepseekv3', 'pythonic', 'kimi_k2', and 'qwen3_coder'.",
......
...@@ -2343,6 +2343,7 @@ def is_fa3_default_architecture(hf_config): ...@@ -2343,6 +2343,7 @@ def is_fa3_default_architecture(hf_config):
"Gemma3ForConditionalGeneration", "Gemma3ForConditionalGeneration",
"Qwen3ForCausalLM", "Qwen3ForCausalLM",
"Qwen3MoeForCausalLM", "Qwen3MoeForCausalLM",
"Glm4MoeForCausalLM",
} }
return architectures[0] in default_archs return architectures[0] in default_archs
......
...@@ -43,6 +43,7 @@ class TestEnableThinking(CustomTestCase): ...@@ -43,6 +43,7 @@ class TestEnableThinking(CustomTestCase):
"qwen3", "qwen3",
], ],
) )
cls.additional_chat_kwargs = {}
@classmethod @classmethod
def tearDownClass(cls): def tearDownClass(cls):
...@@ -59,6 +60,7 @@ class TestEnableThinking(CustomTestCase): ...@@ -59,6 +60,7 @@ class TestEnableThinking(CustomTestCase):
"temperature": 0, "temperature": 0,
"separate_reasoning": True, "separate_reasoning": True,
"chat_template_kwargs": {"enable_thinking": True}, "chat_template_kwargs": {"enable_thinking": True},
**self.additional_chat_kwargs,
}, },
) )
...@@ -82,6 +84,7 @@ class TestEnableThinking(CustomTestCase): ...@@ -82,6 +84,7 @@ class TestEnableThinking(CustomTestCase):
"temperature": 0, "temperature": 0,
"separate_reasoning": True, "separate_reasoning": True,
"chat_template_kwargs": {"enable_thinking": False}, "chat_template_kwargs": {"enable_thinking": False},
**self.additional_chat_kwargs,
}, },
) )
...@@ -107,6 +110,7 @@ class TestEnableThinking(CustomTestCase): ...@@ -107,6 +110,7 @@ class TestEnableThinking(CustomTestCase):
"separate_reasoning": True, "separate_reasoning": True,
"stream": True, "stream": True,
"chat_template_kwargs": {"enable_thinking": True}, "chat_template_kwargs": {"enable_thinking": True},
**self.additional_chat_kwargs,
}, },
stream=True, stream=True,
) )
...@@ -151,6 +155,7 @@ class TestEnableThinking(CustomTestCase): ...@@ -151,6 +155,7 @@ class TestEnableThinking(CustomTestCase):
"separate_reasoning": True, "separate_reasoning": True,
"stream": True, "stream": True,
"chat_template_kwargs": {"enable_thinking": False}, "chat_template_kwargs": {"enable_thinking": False},
**self.additional_chat_kwargs,
}, },
stream=True, stream=True,
) )
...@@ -184,5 +189,55 @@ class TestEnableThinking(CustomTestCase): ...@@ -184,5 +189,55 @@ class TestEnableThinking(CustomTestCase):
) )
## Skip for ci test
# class TestGLM45EnableThinking(TestEnableThinking):
# @classmethod
# def setUpClass(cls):
# # Replace with the model name needed for testing; if not required, reuse DEFAULT_SMALL_MODEL_NAME_FOR_TEST
# cls.model = "THUDM/GLM-4.5"
# cls.base_url = DEFAULT_URL_FOR_TEST
# cls.api_key = "sk-1234"
# cls.process = popen_launch_server(
# cls.model,
# cls.base_url,
# timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
# api_key=cls.api_key,
# other_args=[
# "--tool-call-parser",
# "glm45",
# "--reasoning-parser",
# "glm45",
# "--tp-size",
# "8"
# ],
# )
# # Validate whether enable-thinking conflict with tool_calls
# cls.additional_chat_kwargs = {
# "tools": [
# {
# "type": "function",
# "function": {
# "name": "add",
# "description": "Compute the sum of two numbers",
# "parameters": {
# "type": "object",
# "properties": {
# "a": {
# "type": "int",
# "description": "A number",
# },
# "b": {
# "type": "int",
# "description": "A number",
# },
# },
# "required": ["a", "b"],
# },
# },
# }
# ]
# }
if __name__ == "__main__": if __name__ == "__main__":
unittest.main() unittest.main()
...@@ -223,7 +223,10 @@ class TestOpenAIServerFunctionCalling(CustomTestCase): ...@@ -223,7 +223,10 @@ class TestOpenAIServerFunctionCalling(CustomTestCase):
messages = [ messages = [
{"role": "system", "content": self.SYSTEM_MESSAGE}, {"role": "system", "content": self.SYSTEM_MESSAGE},
{"role": "user", "content": "What is the temperature in Paris?"}, {
"role": "user",
"content": "What is the temperature in Paris in celsius??",
},
] ]
response_stream = client.chat.completions.create( response_stream = client.chat.completions.create(
...@@ -910,5 +913,40 @@ class TestOpenAIPythonicFunctionCalling(CustomTestCase): ...@@ -910,5 +913,40 @@ class TestOpenAIPythonicFunctionCalling(CustomTestCase):
) )
## Skip for ci test
# class TestGLM45ServerFunctionCalling(TestOpenAIServerFunctionCalling):
# @classmethod
# def setUpClass(cls):
# # Replace with the model name needed for testing; if not required, reuse DEFAULT_SMALL_MODEL_NAME_FOR_TEST
# cls.model = "THUDM/GLM-4.5"
# cls.base_url = DEFAULT_URL_FOR_TEST
# cls.api_key = "sk-123456"
# # Start the local OpenAI Server. If necessary, you can add other parameters such as --enable-tools.
# cls.process = popen_launch_server(
# cls.model,
# cls.base_url,
# timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
# api_key=cls.api_key,
# other_args=[
# # If your server needs extra parameters to test function calling, please add them here.
# "--tool-call-parser",
# "glm45",
# "--reasoning-parser",
# "glm45",
# "--tp-size",
# "8"
# ],
# )
# cls.base_url += "/v1"
# cls.tokenizer = get_tokenizer(cls.model)
# # This test is too difficult for GLM4-moe. Skip it from the UT
# def test_function_call_required(self):
# pass
# def test_function_calling_multiturn(self):
# self._test_function_calling_multiturn()
if __name__ == "__main__": if __name__ == "__main__":
unittest.main() unittest.main()
...@@ -6,6 +6,7 @@ from xgrammar import GrammarCompiler, TokenizerInfo ...@@ -6,6 +6,7 @@ from xgrammar import GrammarCompiler, TokenizerInfo
from sglang.srt.entrypoints.openai.protocol import Function, Tool from sglang.srt.entrypoints.openai.protocol import Function, Tool
from sglang.srt.function_call.base_format_detector import BaseFormatDetector from sglang.srt.function_call.base_format_detector import BaseFormatDetector
from sglang.srt.function_call.deepseekv3_detector import DeepSeekV3Detector from sglang.srt.function_call.deepseekv3_detector import DeepSeekV3Detector
from sglang.srt.function_call.glm4_moe_detector import Glm4MoeDetector
from sglang.srt.function_call.kimik2_detector import KimiK2Detector from sglang.srt.function_call.kimik2_detector import KimiK2Detector
from sglang.srt.function_call.llama32_detector import Llama32Detector from sglang.srt.function_call.llama32_detector import Llama32Detector
from sglang.srt.function_call.mistral_detector import MistralDetector from sglang.srt.function_call.mistral_detector import MistralDetector
...@@ -510,6 +511,7 @@ class TestEBNFGeneration(unittest.TestCase): ...@@ -510,6 +511,7 @@ class TestEBNFGeneration(unittest.TestCase):
self.qwen25_detector = Qwen25Detector() self.qwen25_detector = Qwen25Detector()
self.qwen3_coder_detector = Qwen3CoderDetector() self.qwen3_coder_detector = Qwen3CoderDetector()
self.kimik2_detector = KimiK2Detector() self.kimik2_detector = KimiK2Detector()
self.glm45_detector = Glm4MoeDetector()
def test_pythonic_detector_ebnf(self): def test_pythonic_detector_ebnf(self):
"""Test that the PythonicDetector generates valid EBNF.""" """Test that the PythonicDetector generates valid EBNF."""
...@@ -622,6 +624,29 @@ class TestEBNFGeneration(unittest.TestCase): ...@@ -622,6 +624,29 @@ class TestEBNFGeneration(unittest.TestCase):
except RuntimeError as e: except RuntimeError as e:
self.fail(f"Failed to compile EBNF: {e}") self.fail(f"Failed to compile EBNF: {e}")
def test_glm45_detector_ebnf(self):
"""Test that the Glm4MoeDetector generates valid EBNF."""
ebnf = self.glm45_detector.build_ebnf(self.tools)
self.assertIsNotNone(ebnf)
# Check that the EBNF contains expected patterns for XML format
self.assertIn('"<tool_call>" function_call "</tool_call>"', ebnf)
self.assertIn('"get_weather" "\\n" arguments_get_weather', ebnf)
self.assertIn(
'"<arg_key>location</arg_key>" "\\n" "<arg_value>" xml_text "</arg_value>" ( "\\n" ( "<arg_key>unit</arg_key>" "\\n" "<arg_value>" ("celsius" | "fahrenheit") "</arg_value>" ) )?',
ebnf,
)
self.assertIn('"search" "\\n" arguments_search', ebnf)
self.assertIn(
'"<arg_key>query</arg_key>" "\\n" "<arg_value>" xml_text "</arg_value>"',
ebnf,
)
# Validate that the EBNF can be compiled by GrammarCompiler
try:
ctx = self.grammar_compiler.compile_grammar(ebnf)
self.assertIsNotNone(ctx, "EBNF should be valid and compile successfully")
except RuntimeError as e:
self.fail(f"Failed to compile EBNF: {e}")
def test_qwen3_coder_detector_ebnf(self): def test_qwen3_coder_detector_ebnf(self):
"""Test that the Qwen3CoderDetector generates valid EBNF.""" """Test that the Qwen3CoderDetector generates valid EBNF."""
ebnf = self.qwen3_coder_detector.build_ebnf(self.tools) ebnf = self.qwen3_coder_detector.build_ebnf(self.tools)
...@@ -1919,5 +1944,164 @@ circle ...@@ -1919,5 +1944,164 @@ circle
self.assertEqual(params2["dimensions"], {"radius": 5}) self.assertEqual(params2["dimensions"], {"radius": 5})
class TestGlm4MoeDetector(unittest.TestCase):
def setUp(self):
self.tools = [
Tool(
type="function",
function=Function(
name="get_weather",
description="Get weather information",
parameters={
"type": "object",
"properties": {
"city": {"type": "string", "description": "City name"},
"date": {"type": "string", "description": "Date"},
},
"required": ["city", "date"],
},
),
),
]
self.detector = Glm4MoeDetector()
def test_single_tool_call(self):
text = (
"<tool_call>get_weather\n"
"<arg_key>city</arg_key>\n<arg_value>Beijing</arg_value>\n"
"<arg_key>date</arg_key>\n<arg_value>2024-06-27</arg_value>\n"
"</tool_call>"
)
result = self.detector.detect_and_parse(text, self.tools)
self.assertEqual(len(result.calls), 1)
self.assertEqual(result.calls[0].name, "get_weather")
self.assertEqual(
result.calls[0].parameters, '{"city": "Beijing", "date": "2024-06-27"}'
)
self.assertEqual(result.normal_text, "")
def test_multiple_tool_calls(self):
text = (
"<tool_call>get_weather\n"
"<arg_key>city</arg_key>\n<arg_value>Beijing</arg_value>\n"
"<arg_key>date</arg_key>\n<arg_value>2024-06-27</arg_value>\n"
"</tool_call>"
"<tool_call>get_weather\n"
"<arg_key>city</arg_key>\n<arg_value>Shanghai</arg_value>\n"
"<arg_key>date</arg_key>\n<arg_value>2024-06-28</arg_value>\n"
"</tool_call>"
)
result = self.detector.detect_and_parse(text, self.tools)
self.assertEqual(len(result.calls), 2)
self.assertEqual(result.calls[0].name, "get_weather")
self.assertEqual(
result.calls[0].parameters, '{"city": "Beijing", "date": "2024-06-27"}'
)
self.assertEqual(result.calls[1].name, "get_weather")
self.assertEqual(
result.calls[1].parameters, '{"city": "Shanghai", "date": "2024-06-28"}'
)
self.assertEqual(result.normal_text, "")
def test_streaming_tool_call(self):
"""Test streaming incremental parsing of a tool call."""
chunks = [
"<tool_call>get_weather\n",
"<arg_key>city</arg_key>\n<arg_value>Beijing</arg_value>\n",
"<arg_key>date</arg_key>\n<arg_value>2024-06-27</arg_value>\n",
"</tool_call>",
]
tool_calls = []
for chunk in chunks:
result = self.detector.parse_streaming_increment(chunk, self.tools)
for tool_call_chunk in result.calls:
if (
hasattr(tool_call_chunk, "tool_index")
and tool_call_chunk.tool_index is not None
):
while len(tool_calls) <= tool_call_chunk.tool_index:
tool_calls.append({"name": "", "parameters": {}})
tc = tool_calls[tool_call_chunk.tool_index]
if tool_call_chunk.name:
tc["name"] = tool_call_chunk.name
if tool_call_chunk.parameters:
tc["parameters"] = tool_call_chunk.parameters
self.assertEqual(len(tool_calls), 1)
self.assertEqual(tool_calls[0]["name"], "get_weather")
self.assertEqual(
tool_calls[0]["parameters"], '{"city": "Beijing", "date": "2024-06-27"}'
)
def test_streaming_multiple_tool_calls(self):
"""Test streaming incremental parsing of multiple tool calls."""
chunks = [
"<tool_call>get_weather\n",
"<arg_key>city</arg_key>\n<arg_value>Beijing</arg_value>\n",
"<arg_key>date</arg_key>\n<arg_value>2024-06-27</arg_value>\n",
"</tool_call><tool_call>get_weather\n",
"<arg_key>city</arg_key>\n<arg_value>Shanghai</arg_value>\n",
"<arg_key>date</arg_key>\n<arg_value>2024-06-28</arg_value>\n",
"</tool_call>",
]
tool_calls = []
for chunk in chunks:
result = self.detector.parse_streaming_increment(chunk, self.tools)
for tool_call_chunk in result.calls:
if (
hasattr(tool_call_chunk, "tool_index")
and tool_call_chunk.tool_index is not None
):
while len(tool_calls) <= tool_call_chunk.tool_index:
tool_calls.append({"name": "", "parameters": {}})
tc = tool_calls[tool_call_chunk.tool_index]
if tool_call_chunk.name:
tc["name"] = tool_call_chunk.name
if tool_call_chunk.parameters:
tc["parameters"] = tool_call_chunk.parameters
self.assertEqual(len(tool_calls), 2)
self.assertEqual(tool_calls[0]["name"], "get_weather")
self.assertEqual(
tool_calls[0]["parameters"], '{"city": "Beijing", "date": "2024-06-27"}'
)
self.assertEqual(tool_calls[1]["name"], "get_weather")
self.assertEqual(
tool_calls[1]["parameters"], '{"city": "Shanghai", "date": "2024-06-28"}'
)
def test_tool_call_completion(self):
"""Test that the buffer and state are reset after a tool call is completed."""
chunks = [
"<tool_call>get_weather\n",
"<arg_key>city</arg_key>\n<arg_value>Beijing</arg_value>\n",
"<arg_key>date</arg_key>\n<arg_value>2024-06-27</arg_value>\n",
"</tool_call>",
]
for chunk in chunks:
result = self.detector.parse_streaming_increment(chunk, self.tools)
self.assertEqual(self.detector.current_tool_id, 1)
def test_invalid_tool_call(self):
"""Test that invalid tool calls are handled correctly."""
text = "<tool_call>invalid_func\n<arg_key>city</arg_key>\n<arg_value>Beijing</arg_value>\n</tool_call>"
result = self.detector.detect_and_parse(text, self.tools)
self.assertEqual(len(result.calls), 0)
def test_partial_tool_call(self):
"""Test parsing a partial tool call that spans multiple chunks."""
text1 = "<tool_call>get_weather\n<arg_key>city</arg_key>\n"
result1 = self.detector.parse_streaming_increment(text1, self.tools)
self.assertEqual(result1.normal_text, "")
self.assertEqual(result1.calls, [])
self.assertEqual(self.detector._buffer, text1)
text2 = "<arg_value>Beijing</arg_value>\n<arg_key>date</arg_key>\n<arg_value>2024-06-27</arg_value>\n</tool_call>"
result2 = self.detector.parse_streaming_increment(text2, self.tools)
self.assertEqual(len(result2.calls), 1)
self.assertEqual(result2.calls[0].name, "get_weather")
self.assertEqual(
result2.calls[0].parameters, '{"city": "Beijing", "date": "2024-06-27"}'
)
self.assertEqual(self.detector._buffer, "")
if __name__ == "__main__": if __name__ == "__main__":
unittest.main() unittest.main()
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment