Unverified commit 4060ed37, authored by Yuxuan Zhang, committed by GitHub

Refactoring GLM-4.5 and GLM-4.5V related implementations (#11800)

parent 2342605e
@@ -15,7 +15,7 @@
"""Inference-only GLM-4.5, GLM-4.6 model compatible with HuggingFace weights"""
import logging
-from typing import Any, Dict, Iterable, Optional, Tuple
+from typing import Any, Dict, Iterable, Optional, Tuple, Union

import torch
import torch.nn.functional as F
@@ -27,10 +27,16 @@ from sglang.srt.distributed import (
    get_pp_group,
    get_tensor_model_parallel_rank,
    get_tensor_model_parallel_world_size,
+    parallel_state,
    tensor_model_parallel_all_reduce,
)
+from sglang.srt.distributed.device_communicators.pynccl_allocator import (
+    use_symmetric_memory,
+)
+from sglang.srt.eplb.expert_distribution import get_global_expert_distribution_recorder
+from sglang.srt.eplb.expert_location import ModelConfigForExpertLocation
+from sglang.srt.eplb.expert_location_dispatch import ExpertLocationDispatchInfo
from sglang.srt.layers.activation import SiluAndMul
-from sglang.srt.layers.amx_utils import PackWeightMethod
from sglang.srt.layers.communicator import (
    LayerCommunicator,
    LayerScatterModes,
@@ -48,7 +54,10 @@ from sglang.srt.layers.linear import (
    RowParallelLinear,
)
from sglang.srt.layers.logits_processor import LogitsProcessor
-from sglang.srt.layers.moe import get_moe_a2a_backend
+from sglang.srt.layers.moe import (
+    get_moe_a2a_backend,
+    should_use_flashinfer_cutlass_moe_fp4_allgather,
+)
from sglang.srt.layers.moe.ep_moe.layer import get_moe_impl_class
from sglang.srt.layers.moe.fused_moe_triton.layer import FusedMoE
from sglang.srt.layers.moe.topk import TopK
@@ -56,23 +65,17 @@ from sglang.srt.layers.quantization.base_config import QuantizationConfig
from sglang.srt.layers.quantization.fp8_kernel import is_fp8_fnuz
from sglang.srt.layers.radix_attention import RadixAttention
from sglang.srt.layers.rotary_embedding import get_rope
+from sglang.srt.layers.utils import PPMissingLayer
from sglang.srt.layers.vocab_parallel_embedding import (
    ParallelLMHead,
    VocabParallelEmbedding,
)
from sglang.srt.model_executor.cuda_graph_runner import get_is_capture_mode
-from sglang.srt.model_executor.forward_batch_info import ForwardBatch
+from sglang.srt.model_executor.forward_batch_info import ForwardBatch, PPProxyTensors
from sglang.srt.model_loader.weight_utils import default_weight_loader
-from sglang.srt.models.deepseek_v2 import (
-    DeepseekV2DecoderLayer,
-    DeepseekV2ForCausalLM,
-    DeepseekV2Model,
-    DeepseekV2MoE,
-)
from sglang.srt.server_args import get_global_server_args
+from sglang.srt.two_batch_overlap import model_forward_maybe_tbo
from sglang.srt.utils import (
-    BumpAllocator,
-    LazyValue,
    add_prefix,
    cpu_has_amx_support,
    get_bool_env_var,
@@ -80,8 +83,7 @@ from sglang.srt.utils import (
    is_cpu,
    is_cuda,
    is_hip,
-    log_info_on_rank0,
-    use_intel_amx_backend,
+    make_layers,
)

_is_hip = is_hip()
@@ -92,11 +94,6 @@ _is_cpu_amx_available = cpu_has_amx_support()
_is_cpu = is_cpu()
_device_sm = get_device_sm()

-if _is_cuda:
-    from sgl_kernel import dsv3_router_gemm
-elif _is_cpu and _is_cpu_amx_available:
-    pass
-
logger = logging.getLogger(__name__)
@@ -136,8 +133,7 @@ class Glm4MoeMLP(nn.Module):
        )
        if hidden_act != "silu":
            raise ValueError(
-                f"Unsupported activation: {hidden_act}. "
-                "Only silu is supported for now."
+                f"Unsupported activation: {hidden_act}. Only silu is supported for now."
            )
        self.act_fn = SiluAndMul()
@@ -146,7 +142,6 @@ class Glm4MoeMLP(nn.Module):
        x,
        forward_batch=None,
        should_allreduce_fusion=False,
-        gemm_output_zero_allocator: BumpAllocator = None,
    ):
        if (self.tp_size == 1) and x.shape[0] == 0:
            return x
@@ -326,47 +321,21 @@ class Glm4MoeGate(nn.Module):
        self,
        config,
        prefix: str = "",
-        is_nextn: bool = False,
    ):
        super().__init__()
-        self.is_nextn = is_nextn
        self.weight = nn.Parameter(
            torch.empty((config.n_routed_experts, config.hidden_size))
        )
        self.e_score_correction_bias = nn.Parameter(
            torch.empty((config.n_routed_experts), dtype=torch.float32)
        )
-        if _is_cpu and _is_cpu_amx_available:
-            self.quant_method = PackWeightMethod(weight_names=["weight"])

    def forward(self, hidden_states):
-        if use_intel_amx_backend(self):
-            return torch.ops.sgl_kernel.weight_packed_linear(
-                hidden_states,
-                self.weight,
-                None,  # bias
-                True,  # is_vnni
-            )
-
-        # NOTE: For some unknown reason, router_gemm seems degrade accept length.
-        if (
-            _is_cuda
-            and not self.is_nextn
-            and hidden_states.shape[0] < 4
-            and hidden_states.shape[1] == 7168
-            and self.weight.shape[0] == 256
-            and _device_sm >= 90
-        ):
-            logits = dsv3_router_gemm(hidden_states, self.weight).to(
-                hidden_states.dtype
-            )
-        else:
-            logits = F.linear(hidden_states, self.weight, None)
+        logits = F.linear(hidden_states, self.weight, None)
        return logits


-class Glm4MoeSparseMoeBlock(DeepseekV2MoE):
+class Glm4MoeSparseMoeBlock(nn.Module):
    def __init__(
        self,
        config: PretrainedConfig,
@@ -374,18 +343,12 @@ class Glm4MoeSparseMoeBlock(DeepseekV2MoE):
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
        alt_stream: Optional[torch.cuda.Stream] = None,
-        is_nextn: bool = False,
    ):
        nn.Module.__init__(self)
+        self.top_k = config.num_experts_per_tok
        self.tp_size = get_tensor_model_parallel_world_size()
-        self.ep_size = get_moe_expert_parallel_world_size()
        self.routed_scaling_factor = config.routed_scaling_factor
        self.n_shared_experts = config.n_shared_experts
-        self.num_fused_shared_experts = (
-            0
-            if get_global_server_args().disable_shared_experts_fusion
-            else config.n_shared_experts
-        )
        self.config = config
        self.layer_id = layer_id
        self.alt_stream = alt_stream
@@ -402,39 +365,31 @@ class Glm4MoeSparseMoeBlock(DeepseekV2MoE):
                "Only silu is supported for now."
            )

-        self.gate = Glm4MoeGate(
-            config=config, prefix=add_prefix("gate", prefix), is_nextn=is_nextn
-        )
+        self.gate = Glm4MoeGate(config=config, prefix=add_prefix("gate", prefix))

        self.topk = TopK(
-            top_k=config.num_experts_per_tok + self.num_fused_shared_experts,
+            top_k=self.top_k,
            renormalize=config.norm_topk_prob,
            use_grouped_topk=True,
            num_expert_group=config.n_group,
-            num_fused_shared_experts=self.num_fused_shared_experts,
            topk_group=config.topk_group,
            correction_bias=self.gate.e_score_correction_bias,
            routed_scaling_factor=self.routed_scaling_factor,
        )

        self.experts = get_moe_impl_class(quant_config)(
-            num_experts=config.n_routed_experts
-            + self.num_fused_shared_experts
-            + get_global_server_args().ep_num_redundant_experts,
-            num_fused_shared_experts=self.num_fused_shared_experts,
-            top_k=config.num_experts_per_tok + self.num_fused_shared_experts,
+            num_experts=config.n_routed_experts,
+            top_k=self.top_k,
+            layer_id=self.layer_id,
            hidden_size=config.hidden_size,
            intermediate_size=config.moe_intermediate_size,
-            layer_id=self.layer_id,
            quant_config=quant_config,
            routed_scaling_factor=self.routed_scaling_factor,
            prefix=add_prefix("experts", prefix),
        )

-        self.shared_experts_is_int8 = False
-        self.shared_experts_is_fp8 = False
-        # self.shared_experts_weight_block_size = None
-        if config.n_shared_experts is not None and self.num_fused_shared_experts == 0:
+        # shared expert
+        if config.n_shared_experts is not None:
            intermediate_size = config.moe_intermediate_size * config.n_shared_experts
            self.shared_experts = Glm4MoeMLP(
                hidden_size=config.hidden_size,
@@ -443,21 +398,14 @@ class Glm4MoeSparseMoeBlock(DeepseekV2MoE):
                quant_config=quant_config,
                reduce_results=False,
                prefix=add_prefix("shared_experts", prefix),
-                **(dict(tp_rank=0, tp_size=1) if self.ep_size > 1 else {}),
+                **(
+                    dict(tp_rank=0, tp_size=1)
+                    if get_moe_a2a_backend().is_deepep()
+                    or get_moe_a2a_backend().is_mooncake()
+                    or should_use_flashinfer_cutlass_moe_fp4_allgather()
+                    else {}
+                ),
            )
-            is_packed_weight = hasattr(
-                self.shared_experts.gate_up_proj.quant_method, "quant_config"
-            )
-            self.shared_experts_is_int8 = (
-                not is_packed_weight
-                and self.shared_experts.gate_up_proj.weight.dtype == torch.int8
-            )
-            self.shared_experts_is_fp8 = (
-                not is_packed_weight
-                and self.shared_experts.gate_up_proj.weight.dtype == torch.float8_e4m3fn
-            )
-
-        self.top_k = config.num_experts_per_tok

        if get_moe_a2a_backend().is_deepep() or get_moe_a2a_backend().is_mooncake():
            # TODO: we will support tp < ep in the future
@@ -479,12 +427,46 @@ class Glm4MoeSparseMoeBlock(DeepseekV2MoE):
            get_moe_a2a_backend().is_deepep() or get_moe_a2a_backend().is_mooncake()
        )

+    def get_moe_weights(self):
+        return [
+            x.data
+            for name, x in self.experts.named_parameters()
+            if name not in ["correction_bias"]
+        ]
+
+    def forward(
+        self,
+        hidden_states: torch.Tensor,
+        forward_batch: Optional[ForwardBatch] = None,
+        should_allreduce_fusion: bool = False,
+        use_reduce_scatter: bool = False,
+    ) -> torch.Tensor:
+        if not self._enable_a2a_moe:
+            DUAL_STREAM_TOKEN_THRESHOLD = 1024
+            if (
+                self.alt_stream is not None
+                and hidden_states.shape[0] > 0
+                and hidden_states.shape[0] <= DUAL_STREAM_TOKEN_THRESHOLD
+            ):
+                return self.forward_normal_dual_stream(
+                    hidden_states,
+                    should_allreduce_fusion,
+                    use_reduce_scatter,
+                )
+            else:
+                return self.forward_normal(
+                    hidden_states,
+                    should_allreduce_fusion,
+                    use_reduce_scatter,
+                )
+        else:
+            return self.forward_deepep(hidden_states, forward_batch)
+
    def forward_normal_dual_stream(
        self,
        hidden_states: torch.Tensor,
        should_allreduce_fusion: bool = False,
        use_reduce_scatter: bool = False,
-        gemm_output_zero_allocator: BumpAllocator = None,
    ) -> torch.Tensor:
        current_stream = torch.cuda.current_stream()
@@ -498,28 +480,21 @@ class Glm4MoeSparseMoeBlock(DeepseekV2MoE):
        final_hidden_states = self.experts(hidden_states, topk_output)
        if not _is_cuda:
            final_hidden_states *= self.routed_scaling_factor

        current_stream.wait_stream(self.alt_stream)

-        if self.ep_size > 1:
-            if (
-                self.tp_size > 1
-                and not should_allreduce_fusion
-                and not use_reduce_scatter
-            ):
-                final_hidden_states = tensor_model_parallel_all_reduce(
-                    final_hidden_states
-                )
-            final_hidden_states += shared_output
-        else:
-            final_hidden_states += shared_output
-            if (
-                self.tp_size > 1
-                and not should_allreduce_fusion
-                and not use_reduce_scatter
-            ):
-                final_hidden_states = tensor_model_parallel_all_reduce(
-                    final_hidden_states
-                )
+        with use_symmetric_memory(parallel_state.get_tp_group()) as sm:
+            final_hidden_states_out = torch.empty_like(final_hidden_states)
+            torch.add(final_hidden_states, shared_output, out=final_hidden_states_out)
+            final_hidden_states = final_hidden_states_out
+            sm.tag(final_hidden_states)
+
+        if (
+            self.tp_size > 1
+            and not should_allreduce_fusion
+            and not use_reduce_scatter
+            and not should_use_flashinfer_cutlass_moe_fp4_allgather()
+        ):
+            final_hidden_states = tensor_model_parallel_all_reduce(final_hidden_states)
        return final_hidden_states
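The dual-stream path above hides the cost of the shared-expert MLP behind the routed-expert computation by issuing one of them on a second CUDA stream and joining the streams before the two partial outputs are combined. The snippet below is a minimal, SGLang-independent sketch of that pattern; it assumes a CUDA device, and `routed_experts` / `shared_experts` are placeholder callables, not the real modules.

```python
import torch


def dual_stream_moe(
    x: torch.Tensor,
    routed_experts,                      # callable: routed-expert path
    shared_experts,                      # callable: shared-expert MLP
    alt_stream: torch.cuda.Stream,
) -> torch.Tensor:
    """Overlap the shared-expert MLP with the routed experts on a second CUDA stream."""
    current_stream = torch.cuda.current_stream()
    alt_stream.wait_stream(current_stream)   # alt stream must see x as ready
    with torch.cuda.stream(alt_stream):
        shared_out = shared_experts(x)       # runs concurrently on alt_stream
    routed_out = routed_experts(x)           # runs on the default stream
    current_stream.wait_stream(alt_stream)   # join before combining the two outputs
    return routed_out + shared_out
```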
    def forward_normal(
@@ -527,39 +502,69 @@ class Glm4MoeSparseMoeBlock(DeepseekV2MoE):
        hidden_states: torch.Tensor,
        should_allreduce_fusion: bool = False,
        use_reduce_scatter: bool = False,
-        gemm_output_zero_allocator: BumpAllocator = None,
    ) -> torch.Tensor:
-        if hasattr(self, "shared_experts") and use_intel_amx_backend(
-            self.shared_experts.gate_up_proj
-        ):
-            return self.forward_cpu(hidden_states, should_allreduce_fusion)
-
-        shared_output = self._forward_shared_experts(hidden_states)
-        # router_logits: (num_tokens, n_experts)
-        router_logits = self.gate(hidden_states)
-        topk_output = self.topk(hidden_states, router_logits)
+        if hidden_states.shape[0] > 0:
+            shared_output = self._forward_shared_experts(hidden_states)
+            # router_logits: (num_tokens, n_experts)
+            router_logits = self.gate(hidden_states)
+            topk_output = self.topk(hidden_states, router_logits)
+        else:
+            shared_output = None
+            topk_output = self.topk.empty_topk_output(hidden_states.device)
+
        final_hidden_states = self.experts(hidden_states, topk_output)
        if not _is_cuda and not _use_aiter:
            # fused in biased_grouped_topk so we can skip here
            final_hidden_states *= self.routed_scaling_factor
-        if self.ep_size > 1:
-            if self.tp_size > 1 and not should_allreduce_fusion:
-                final_hidden_states = tensor_model_parallel_all_reduce(
-                    final_hidden_states
-                )
-            if shared_output is not None:
-                final_hidden_states += shared_output
-        else:
-            if shared_output is not None:
-                final_hidden_states += shared_output
-            if self.tp_size > 1 and not should_allreduce_fusion:
-                final_hidden_states = tensor_model_parallel_all_reduce(
-                    final_hidden_states
-                )
+        if shared_output is not None:
+            with use_symmetric_memory(parallel_state.get_tp_group()) as sm:
+                final_hidden_states_out = torch.empty_like(final_hidden_states)
+                torch.add(final_hidden_states, shared_output, out=final_hidden_states_out)
+                final_hidden_states = final_hidden_states_out
+                sm.tag(final_hidden_states)
+        if (
+            self.tp_size > 1
+            and not should_allreduce_fusion
+            and not use_reduce_scatter
+            and not should_use_flashinfer_cutlass_moe_fp4_allgather()
+        ):
+            final_hidden_states = tensor_model_parallel_all_reduce(final_hidden_states)
        return final_hidden_states

+    def _forward_deepep(self, hidden_states: torch.Tensor, forward_batch: ForwardBatch):
+        shared_output = None
+        if hidden_states.shape[0] > 0:
+            # router_logits: (num_tokens, n_experts)
+            router_logits, _ = self.gate(hidden_states)
+            shared_output = self._forward_shared_experts(hidden_states)
+            topk_output = self.topk(
+                hidden_states,
+                router_logits,
+                num_token_non_padded=forward_batch.num_token_non_padded,
+                expert_location_dispatch_info=ExpertLocationDispatchInfo.init_new(
+                    layer_id=self.layer_id,
+                ),
+            )
+        else:
+            topk_output = self.topk.empty_topk_output(hidden_states.device)
+        final_hidden_states = self.experts(
+            hidden_states=hidden_states,
+            topk_output=topk_output,
+        )
+        if shared_output is not None:
+            final_hidden_states.add_(shared_output)
+        return final_hidden_states
+
+    def _forward_shared_experts(self, hidden_states: torch.Tensor):
+        shared_output = None
+        if hidden_states.shape[0] > 0:
+            shared_output = self.shared_experts(hidden_states)
+        return shared_output
-class Glm4MoeDecoderLayer(DeepseekV2DecoderLayer):
+class Glm4MoeDecoderLayer(nn.Module):
    def __init__(
        self,
        config: PretrainedConfig,
@@ -582,6 +587,7 @@ class Glm4MoeDecoderLayer(DeepseekV2DecoderLayer):
        rms_norm_eps = config.rms_norm_eps
        attention_bias = config.attention_bias
        self.layer_id = layer_id

        self.self_attn = Glm4MoeAttention(
            hidden_size=self.hidden_size,
            num_heads=config.num_attention_heads,
@@ -597,15 +603,15 @@ class Glm4MoeDecoderLayer(DeepseekV2DecoderLayer):
            quant_config=quant_config,
            prefix=add_prefix("self_attn", prefix),
            use_qk_norm=config.use_qk_norm,
+            alt_stream=alt_stream,
        )

        self.is_layer_sparse = self._is_layer_sparse(layer_id, is_nextn=is_nextn)
        is_previous_layer_sparse = self._is_layer_sparse(layer_id - 1, is_nextn=False)
-        num_layers = 1 if is_nextn else config.num_hidden_layers

        self.layer_scatter_modes = LayerScatterModes.init_new(
            layer_id=layer_id,
-            num_layers=num_layers,
+            num_layers=1 if is_nextn else config.num_hidden_layers,
            is_layer_sparse=self.is_layer_sparse,
            is_previous_layer_sparse=is_previous_layer_sparse,
        )
@@ -616,6 +622,7 @@ class Glm4MoeDecoderLayer(DeepseekV2DecoderLayer):
                quant_config=quant_config,
                prefix=add_prefix("mlp", prefix),
                layer_id=self.layer_id,
+                alt_stream=alt_stream,
            )
        else:
            if enable_moe_dense_fully_dp():
@@ -641,7 +648,16 @@ class Glm4MoeDecoderLayer(DeepseekV2DecoderLayer):
            layer_scatter_modes=self.layer_scatter_modes,
            input_layernorm=self.input_layernorm,
            post_attention_layernorm=self.post_attention_layernorm,
-            allow_reduce_scatter=False,
+            allow_reduce_scatter=True,
+            is_last_layer=(
+                is_nextn or (self.layer_id == self.config.num_hidden_layers - 1)
+            ),
        )

+    def _is_layer_sparse(self, layer_id: int, is_nextn: bool) -> bool:
+        return is_nextn or (
+            self.config.n_routed_experts is not None
+            and layer_id >= self.config.first_k_dense_replace
+        )
+
    def forward(
@@ -650,8 +666,6 @@ class Glm4MoeDecoderLayer(DeepseekV2DecoderLayer):
        hidden_states: torch.Tensor,
        forward_batch: ForwardBatch,
        residual: Optional[torch.Tensor],
-        zero_allocator: BumpAllocator,
-        gemm_output_zero_allocator: BumpAllocator = None,
    ) -> torch.Tensor:
        hidden_states, residual = self.layer_communicator.prepare_attn(
            hidden_states, residual, forward_batch
@@ -676,44 +690,119 @@ class Glm4MoeDecoderLayer(DeepseekV2DecoderLayer):
        return hidden_states, residual
-class Glm4MoeModel(DeepseekV2Model):
+class Glm4MoeModel(nn.Module):
    def __init__(
        self,
        config: PretrainedConfig,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
-    ) -> None:
-        nn.Module.__init__(self)
-        self.padding_id = config.pad_token_id
-        self.vocab_size = config.vocab_size
-        self.first_k_dense_replace = config.first_k_dense_replace
-        self.embed_tokens = VocabParallelEmbedding(
-            config.vocab_size,
-            config.hidden_size,
-            enable_tp=not is_dp_attention_enabled(),
-        )
-        self.alt_stream = torch.cuda.Stream() if _is_cuda else None
-        self.layers = nn.ModuleList(
-            [
-                Glm4MoeDecoderLayer(
-                    config,
-                    layer_id,
-                    quant_config=quant_config,
-                    prefix=add_prefix(f"layers.{layer_id}", prefix),
-                    alt_stream=self.alt_stream,
-                )
-                for layer_id in range(config.num_hidden_layers)
-            ]
-        )
-        self.pp_group = get_pp_group()
-        self.start_layer = 0
-        self.end_layer = config.num_hidden_layers
-        self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+    ):
+        super().__init__()
+        self.pp_group = get_pp_group()
+        self.config = config
+        self.vocab_size = config.vocab_size
+        self.first_k_dense_replace = config.first_k_dense_replace
+        self.embed_dim = config.hidden_size
+        if self.pp_group.is_first_rank:
+            self.embed_tokens = VocabParallelEmbedding(
+                config.vocab_size,
+                config.hidden_size,
+                enable_tp=not is_dp_attention_enabled(),
+            )
+        else:
+            self.embed_tokens = PPMissingLayer()
+
+        self.alt_stream = torch.cuda.Stream() if _is_cuda else None
+        self.layers, self.start_layer, self.end_layer = make_layers(
+            config.num_hidden_layers,
+            lambda idx, prefix: Glm4MoeDecoderLayer(
+                layer_id=idx,
+                config=config,
+                quant_config=quant_config,
+                prefix=prefix,
+                alt_stream=self.alt_stream,
+            ),
+            pp_rank=self.pp_group.rank_in_group,
+            pp_size=self.pp_group.world_size,
+            prefix=add_prefix("layers", prefix),
+        )
+        if self.pp_group.is_last_rank:
+            self.norm = RMSNorm(self.embed_dim, eps=config.rms_norm_eps)
+        else:
+            self.norm = PPMissingLayer(return_tuple=True)
+
+    def get_input_embeddings(self) -> torch.Tensor:
+        return self.embed_tokens
+
+    def forward(
+        self,
+        input_ids: torch.Tensor,
+        positions: torch.Tensor,
+        forward_batch: ForwardBatch,
+        input_embeds: torch.Tensor = None,
+        pp_proxy_tensors: Optional[PPProxyTensors] = None,
+    ) -> Union[torch.Tensor, PPProxyTensors]:
+        if self.pp_group.is_first_rank:
+            if input_embeds is None:
+                hidden_states = self.embed_tokens(input_ids)
+            else:
+                hidden_states = input_embeds
+            residual = None
+        else:
+            assert pp_proxy_tensors is not None
+            hidden_states = pp_proxy_tensors["hidden_states"]
+            residual = pp_proxy_tensors["residual"]
+
+        normal_start_layer = self.start_layer
+        normal_end_layer = self.end_layer
+        if forward_batch.can_run_tbo:
+            if (
+                self.first_k_dense_replace > normal_start_layer
+                and self.first_k_dense_replace < normal_end_layer
+            ):
+                normal_end_layer = self.first_k_dense_replace
+            elif self.first_k_dense_replace < normal_start_layer:
+                normal_end_layer = normal_start_layer = 0
+
+        for i in range(normal_start_layer, normal_end_layer):
+            with get_global_expert_distribution_recorder().with_current_layer(i):
+                layer = self.layers[i]
+                hidden_states, residual = layer(
+                    positions,
+                    hidden_states,
+                    forward_batch,
+                    residual,
+                )
+
+        if normal_end_layer != self.end_layer:
+            hidden_states, residual = model_forward_maybe_tbo(
+                layers=self.layers[normal_end_layer : self.end_layer],
+                enable_tbo=True,
+                positions=positions,
+                forward_batch=forward_batch,
+                hidden_states=hidden_states,
+                residual=residual,
+                input_data_scatter_mode=self.layers[
+                    normal_end_layer - 1
+                ].layer_scatter_modes.layer_output_mode,
+            )
+
+        if not self.pp_group.is_last_rank:
+            return PPProxyTensors(
+                {
+                    "hidden_states": hidden_states,
+                    "residual": residual,
+                }
+            )
+        else:
+            if not forward_batch.forward_mode.is_idle():
+                if residual is None:
+                    hidden_states = self.norm(hidden_states)
+                else:
+                    hidden_states, _ = self.norm(hidden_states, residual)
+            return hidden_states
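The new `Glm4MoeModel.forward` adds pipeline-parallel support: only the first rank embeds token ids, intermediate ranks receive and return `hidden_states`/`residual` via `PPProxyTensors`, and only the last rank applies the final norm. The sketch below illustrates that handoff in isolation; the layer, embedding, and norm callables and the plain dict standing in for `PPProxyTensors` are simplified placeholders, not the actual SGLang interfaces.

```python
from typing import Callable, Dict, List, Optional

import torch


def pp_stage_forward(
    layers: List[Callable],
    is_first_rank: bool,
    is_last_rank: bool,
    input_ids: Optional[torch.Tensor] = None,
    proxy: Optional[Dict[str, torch.Tensor]] = None,
    embed: Optional[Callable] = None,
    final_norm: Optional[Callable] = None,
):
    """Schematic per-rank forward: embed on the first rank, proxy in the middle, norm last."""
    if is_first_rank:
        hidden_states = embed(input_ids)
        residual = None
    else:
        assert proxy is not None
        hidden_states, residual = proxy["hidden_states"], proxy["residual"]

    for layer in layers:
        hidden_states, residual = layer(hidden_states, residual)

    if not is_last_rank:
        # Hand the running activations to the next pipeline rank.
        return {"hidden_states": hidden_states, "residual": residual}
    if residual is not None:
        hidden_states = hidden_states + residual
    return final_norm(hidden_states)
```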
-class Glm4MoeForCausalLM(DeepseekV2ForCausalLM):
+class Glm4MoeForCausalLM(nn.Module):
    def __init__(
        self,
        config: PretrainedConfig,
@@ -721,12 +810,10 @@ class Glm4MoeForCausalLM(DeepseekV2ForCausalLM):
        prefix: str = "",
    ) -> None:
        nn.Module.__init__(self)
-        config.moe_layer_freq = 1
        self.config = config
        self.tp_size = get_tensor_model_parallel_world_size()
        self.quant_config = quant_config
        self.pp_group = get_pp_group()
-        self.determine_num_fused_shared_experts("Glm4MoeForCausalLM")
        self.model = Glm4MoeModel(
            config, quant_config, prefix=add_prefix("model", prefix)
        )
@@ -739,49 +826,41 @@ class Glm4MoeForCausalLM(DeepseekV2ForCausalLM):
        )
        self.logits_processor = LogitsProcessor(config)
-        self._routed_experts_weights_of_layer = LazyValue(
-            lambda: {
-                layer_id: layer.mlp.get_moe_weights()
-                for layer_id, layer in enumerate(self.model.layers)
-                if isinstance(layer.mlp, DeepseekV2MoE)
-            }
-        )
-
-    def determine_num_fused_shared_experts(
-        self, architecture: str = "Glm4MoeForCausalLM"
-    ):
-        self.num_fused_shared_experts = 0
-        if get_global_server_args().disable_shared_experts_fusion:
-            return
-
-        # Only Deepseek V3/R1 can use shared experts fusion optimization now.
-        disable_reason = None
-        if (
-            not _is_cuda
-            or torch.cuda.get_device_capability("cuda") < (8, 0)
-            or self.config.architectures[0] != architecture
-            or self.config.n_shared_experts != 1
-        ):
-            disable_reason = "Only GLM-4.5 or GLM-4.6 on NV-platform with capability >= 80 can use shared experts fusion optimization."
-        elif get_moe_expert_parallel_world_size() > 1:
-            disable_reason = "Deepseek and GLM-4.5 or GLM-4.6 can not use shared experts fusion optimization under expert parallelism."
-
-        if disable_reason is not None:
-            get_global_server_args().disable_shared_experts_fusion = True
-            self.num_fused_shared_experts = 0
-            log_info_on_rank0(
-                logger,
-                f"{disable_reason} Shared experts fusion optimization is disabled.",
-            )
-            return
-
-        self.num_fused_shared_experts = self.config.n_shared_experts
-
-    def get_input_embeddings(self) -> nn.Embedding:
-        return self.model.embed_tokens
+        # For EAGLE3 support
+        self.capture_aux_hidden_states = False
+
+    def get_input_embeddings(self) -> nn.Embedding:
+        return self.model.embed_tokens
+
+    @torch.no_grad()
+    def forward(
+        self,
+        input_ids: torch.Tensor,
+        positions: torch.Tensor,
+        forward_batch: ForwardBatch,
+        input_embeds: torch.Tensor = None,
+        pp_proxy_tensors: Optional[PPProxyTensors] = None,
+    ) -> torch.Tensor:
+        hidden_states = self.model(
+            input_ids, positions, forward_batch, input_embeds, pp_proxy_tensors
+        )
+
+        if self.pp_group.is_last_rank:
+            return self.logits_processor(
+                input_ids, hidden_states, self.lm_head, forward_batch
+            )
+        else:
+            return hidden_states
+
+    @property
+    def start_layer(self):
+        return self.model.start_layer
+
+    @property
+    def end_layer(self):
+        return self.model.end_layer

    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]], is_nextn=False):
        if is_nextn:
            if hasattr(self.config, "num_nextn_predict_layers"):
                num_nextn_layers = self.config.num_nextn_predict_layers
@@ -803,117 +882,14 @@ class Glm4MoeForCausalLM(DeepseekV2ForCausalLM):
            ("gate_up_proj", "gate_proj", 0),
            ("gate_up_proj", "up_proj", 1),
        ]
if self.num_fused_shared_experts > 0:
assert self.num_fused_shared_experts == 1
weights_list = list(weights)
weights_dict = dict(weights_list)
if self.quant_config is not None:
if self.quant_config.get_name() == "w8a8_int8":
suffix_list = [
"down_proj.weight",
"down_proj.weight_scale",
"gate_proj.weight",
"gate_proj.weight_scale",
"up_proj.weight",
"up_proj.weight_scale",
]
elif (
self.quant_config.get_name() == "fp8"
or self.quant_config.get_name() == "blockwise_int8"
or self.quant_config.get_name() == "compressed_tensors"
):
suffix_list = [
"down_proj.weight",
"down_proj.weight_scale",
"gate_proj.weight",
"gate_proj.weight_scale",
"up_proj.weight",
"up_proj.weight_scale",
]
elif self.quant_config.get_name() == "awq":
suffix_list = [
"down_proj.qweight",
"down_proj.qzeros",
"down_proj.scales",
"gate_proj.qweight",
"gate_proj.qzeros",
"gate_proj.scales",
"up_proj.qweight",
"up_proj.qzeros",
"up_proj.scales",
]
elif self.quant_config.get_name() == "modelopt_fp4":
suffix_list = [
"down_proj.weight",
"down_proj.weight_scale",
"down_proj.weight_scale_2",
"down_proj.input_scale",
"gate_proj.weight",
"gate_proj.weight_scale",
"gate_proj.weight_scale_2",
"gate_proj.input_scale",
"up_proj.weight",
"up_proj.weight_scale",
"up_proj.weight_scale_2",
"up_proj.input_scale",
]
else:
raise ValueError(
f"Unsupported shared expert fusion for quantization: {self.quant_config.get_name()}."
)
else:
suffix_list = [
"down_proj.weight",
"gate_proj.weight",
"up_proj.weight",
]
names_to_remove = []
moe_layers = (
range(
self.config.first_k_dense_replace,
self.config.num_hidden_layers,
self.config.moe_layer_freq,
)
if not is_nextn
else [nextn_layer_id]
)
for moe_layer in moe_layers:
for suffix in suffix_list:
shared_expert_weight_name = (
f"model.layers.{moe_layer}.mlp.shared_experts.{suffix}"
)
# online fp8 quantization does not load weight_scale
if shared_expert_weight_name not in weights_dict:
continue
weights_list.append(
(
f"model.layers.{moe_layer}."
f"mlp.experts."
f"{self.config.n_routed_experts + 0}"
f".{suffix}",
weights_dict[shared_expert_weight_name],
)
)
names_to_remove += [shared_expert_weight_name]
weights = [w for w in weights_list if w[0] not in names_to_remove]
# Params for weights, fp8 weight scales, fp8 activation scales
# (param_name, weight_name, expert_id, shard_id)
expert_params_mapping = FusedMoE.make_expert_params_mapping( expert_params_mapping = FusedMoE.make_expert_params_mapping(
ckpt_gate_proj_name="gate_proj", ckpt_gate_proj_name="gate_proj",
ckpt_down_proj_name="down_proj", ckpt_down_proj_name="down_proj",
ckpt_up_proj_name="up_proj", ckpt_up_proj_name="up_proj",
num_experts=self.config.n_routed_experts + self.num_fused_shared_experts, num_experts=self.config.n_routed_experts,
) )
# Fuse q_a_proj and kv_a_proj_with_mqa along output dimension when q_lora_rank is not None
fuse_qkv_a_proj = hasattr(self.config, "q_lora_rank") and (
self.config.q_lora_rank is not None
)
cached_a_proj = {} if fuse_qkv_a_proj else None
if is_nextn: if is_nextn:
nextn_layer_prefix = f"model.layers.{nextn_layer_id}" nextn_layer_prefix = f"model.layers.{nextn_layer_id}"
nextn_spec_weight_names = [ nextn_spec_weight_names = [
@@ -969,22 +945,36 @@ class Glm4MoeForCausalLM(DeepseekV2ForCausalLM):
                # name will be updated to mlp.experts[0].gate_up_proj, which
                # will then be updated below in expert_params_mapping
                # for mlp.experts[0].gate_gate_up_proj, which breaks load.
-                if ("mlp.experts." in name) and name not in params_dict:
+                if "mlp.experts" in name:
                    continue
                name = name.replace(weight_name, param_name)
                # Skip loading extra bias for GPTQ models.
                if name.endswith(".bias") and name not in params_dict:
                    continue
+                if name not in params_dict:
+                    continue
+
                param = params_dict[name]
                weight_loader = param.weight_loader
                weight_loader(param, loaded_weight, shard_id)
                break
            else:
+                # Track if this is an expert weight to enable early skipping
+                is_expert_weight = False
+
                for mapping in expert_params_mapping:
                    param_name, weight_name, expert_id, shard_id = mapping
                    if weight_name not in name:
                        continue
+
+                    # Mark as expert weight regardless of whether we can process it
+                    is_expert_weight = True
+
                    name = name.replace(weight_name, param_name)
+                    if name not in params_dict:
+                        # Expert weight not on this rank, will be skipped below
+                        continue
+
                    param = params_dict[name]
                    weight_loader = param.weight_loader
                    weight_loader(
@@ -996,65 +986,43 @@ class Glm4MoeForCausalLM(DeepseekV2ForCausalLM):
                    )
                    break
                else:
+                    if is_expert_weight:
+                        # This is an expert weight but not mapped to this rank, skip all remaining processing
+                        continue
+
                    # Skip loading extra bias for GPTQ models.
                    if name.endswith(".bias") and name not in params_dict:
                        continue
if fuse_qkv_a_proj and ( if name not in params_dict:
"q_a_proj" in name or "kv_a_proj_with_mqa" in name continue
):
cached_a_proj[name] = loaded_weight
q_a_proj_name = (
name
if "q_a_proj" in name
else name.replace("kv_a_proj_with_mqa", "q_a_proj")
)
kv_a_proj_name = (
name
if "kv_a_proj_with_mqa" in name
else name.replace("q_a_proj", "kv_a_proj_with_mqa")
)
# When both q_a_proj and kv_a_proj_with_mqa has been cached, load the fused weight to parameter if name in params_dict.keys():
if (
q_a_proj_name in cached_a_proj
and kv_a_proj_name in cached_a_proj
):
q_a_proj_weight = cached_a_proj[q_a_proj_name]
kv_a_proj_weight = cached_a_proj[kv_a_proj_name]
fused_weight = torch.cat(
[q_a_proj_weight, kv_a_proj_weight], dim=0
)
param_name = (
name.replace("q_a_proj", "fused_qkv_a_proj_with_mqa")
if "q_a_proj" in name
else name.replace(
"kv_a_proj_with_mqa", "fused_qkv_a_proj_with_mqa"
)
)
param = params_dict[param_name]
weight_loader = getattr(
param, "weight_loader", default_weight_loader
)
weight_loader(param, fused_weight)
cached_a_proj.pop(q_a_proj_name)
cached_a_proj.pop(kv_a_proj_name)
else:
if (
"k_scale" in name or "v_scale" in name
) and name not in params_dict:
# modelopt attn kv scale is named differently
if any(scale in name for scale in ["k_scale", "v_scale"]):
name = name.replace("_proj", "attn_mqa")
else:
logger.warning(
f"Unknown scale found in checkpoint: {name}"
)
param = params_dict[name] param = params_dict[name]
weight_loader = getattr( weight_loader = getattr(
param, "weight_loader", default_weight_loader param, "weight_loader", default_weight_loader
) )
weight_loader(param, loaded_weight) weight_loader(param, loaded_weight)
else:
logger.warning(f"Parameter {name} not found in params_dict")
def get_embed_and_head(self):
return self.model.embed_tokens.weight, self.lm_head.weight
def set_embed_and_head(self, embed, head):
del self.model.embed_tokens.weight
del self.lm_head.weight
self.model.embed_tokens.weight = embed
self.lm_head.weight = head
torch.cuda.empty_cache()
torch.cuda.synchronize()
@classmethod
def get_model_config_for_expert_location(cls, config):
return ModelConfigForExpertLocation(
num_layers=config.num_hidden_layers,
num_logical_experts=config.n_routed_experts,
num_groups=config.n_group,
)
EntryClass = [Glm4MoeForCausalLM]
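The reworked `load_weights` drops the DeepSeek-style fused-shared-expert remapping and instead tags every expert weight with `is_expert_weight`, so weights belonging to experts that are not materialized on the current rank can be skipped cleanly. Below is a self-contained toy sketch of that control flow; `params_dict` and `expert_params_mapping` here are contrived stand-ins (the real ones come from the model parameters and `FusedMoE.make_expert_params_mapping`), and the copy into `.data[0]` is a deliberate simplification.

```python
from typing import Dict, Iterable, List, Tuple

import torch

# Toy stand-ins for this sketch only.
params_dict: Dict[str, torch.nn.Parameter] = {
    "mlp.experts.w13_weight": torch.nn.Parameter(torch.zeros(1, 4)),
}
expert_params_mapping: List[Tuple[str, str, int, str]] = [
    ("mlp.experts.w13_weight", "mlp.experts.0.gate_proj", 0, "w1"),
    ("mlp.experts.missing_w13_weight", "mlp.experts.1.gate_proj", 1, "w1"),
]


def load_expert_weights(weights: Iterable[Tuple[str, torch.Tensor]]) -> None:
    for name, loaded_weight in weights:
        # Track if this is an expert weight to enable early skipping.
        is_expert_weight = False
        for param_name, weight_name, expert_id, shard_id in expert_params_mapping:
            if weight_name not in name:
                continue
            # Mark as expert weight regardless of whether we can process it.
            is_expert_weight = True
            mapped = name.replace(weight_name, param_name)
            if mapped not in params_dict:
                # Expert lives on another rank; keep scanning, then skip below.
                continue
            params_dict[mapped].data[0].copy_(loaded_weight)
            break
        else:
            if is_expert_weight:
                # Known expert weight that this rank does not own: drop it.
                continue
            print(f"{name} would go through the normal (non-expert) loading path")


load_expert_weights(
    [
        ("mlp.experts.0.gate_proj", torch.ones(4)),  # loaded locally
        ("mlp.experts.1.gate_proj", torch.ones(4)),  # skipped: not on this rank
        ("embed_tokens.weight", torch.ones(4)),      # non-expert path
    ]
)
```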
@@ -12,7 +12,8 @@
# limitations under the License.
# ==============================================================================
-"""Inference-only GLM-4.5, GLM-4.6 NextN Speculative Decoding."""
+"""Inference-only GLM-4.5, GLM-4.6 Speculative Decoding."""

import logging
from typing import Iterable, Optional, Tuple
@@ -33,7 +34,7 @@ from sglang.srt.layers.vocab_parallel_embedding import (
from sglang.srt.model_executor.forward_batch_info import ForwardBatch
from sglang.srt.models.glm4_moe import Glm4MoeDecoderLayer, Glm4MoeForCausalLM
from sglang.srt.server_args import get_global_server_args
-from sglang.srt.utils import BumpAllocator, add_prefix
+from sglang.srt.utils import add_prefix

logger = logging.getLogger(__name__)
@@ -84,14 +85,6 @@ class Glm4MoeModelNextN(nn.Module):
        forward_batch: ForwardBatch,
        input_embeds: torch.Tensor = None,
    ) -> torch.Tensor:
-        zero_allocator = BumpAllocator(
-            buffer_size=2,
-            dtype=torch.float32,
-            device=(
-                input_embeds.device if input_embeds is not None else input_ids.device
-            ),
-        )
-
        if input_embeds is None:
            hidden_states = self.embed_tokens(input_ids)
        else:
@@ -111,7 +104,7 @@ class Glm4MoeModelNextN(nn.Module):
            residual = None

        with get_global_expert_distribution_recorder().disable_this_region():
            hidden_states, residual = self.decoder(
-                positions, hidden_states, forward_batch, residual, zero_allocator
+                positions, hidden_states, forward_batch, residual
            )

        if not forward_batch.forward_mode.is_idle():
@@ -124,7 +117,6 @@ class Glm4MoeForCausalLMNextN(Glm4MoeForCausalLM):
class Glm4MoeForCausalLMNextN(Glm4MoeForCausalLM):
    def __init__(
        self,
        config: PretrainedConfig,
@@ -135,8 +127,6 @@ class Glm4MoeForCausalLMNextN(Glm4MoeForCausalLM):
        self.config = config
        self.tp_size = get_tensor_model_parallel_world_size()
        self.quant_config = quant_config
-        self.determine_num_fused_shared_experts("Glm4MoeForCausalLMNextN")
        self.model = Glm4MoeModelNextN(
            config, quant_config, prefix=add_prefix("model", prefix)
        )
......
@@ -6,13 +6,10 @@ import torch
import torch.nn as nn
from transformers.models.glm4v_moe.configuration_glm4v_moe import Glm4vMoeConfig

-from sglang.srt.distributed import (
-    get_moe_expert_parallel_world_size,
-    get_tensor_model_parallel_world_size,
-)
+from sglang.srt.distributed import get_tensor_model_parallel_world_size
from sglang.srt.layers.attention import vision_utils
from sglang.srt.layers.logits_processor import LogitsProcessor
-from sglang.srt.layers.moe.fused_moe_triton import FusedMoE
+from sglang.srt.layers.moe.fused_moe_triton.layer import FusedMoE
from sglang.srt.layers.pooler import Pooler, PoolingType
from sglang.srt.layers.quantization.base_config import QuantizationConfig
from sglang.srt.layers.vocab_parallel_embedding import ParallelLMHead
@@ -20,7 +17,7 @@ from sglang.srt.model_loader.weight_utils import default_weight_loader
from sglang.srt.models.glm4_moe import Glm4MoeModel
from sglang.srt.models.glm4v import Glm4vForConditionalGeneration, Glm4vVisionModel
from sglang.srt.server_args import get_global_server_args
-from sglang.srt.utils import add_prefix, is_cuda, log_info_on_rank0
+from sglang.srt.utils import add_prefix, is_cuda
from sglang.srt.utils.hf_transformers_utils import get_processor

_is_cuda = is_cuda()
@@ -39,12 +36,10 @@ class Glm4vMoeForConditionalGeneration(Glm4vForConditionalGeneration):
    ) -> None:
        nn.Module.__init__(self)
-        config.moe_layer_freq = 1
        self.config = config
        vision_utils.update_vit_attn_dummy_heads_config(self.config)
        self.tp_size = get_tensor_model_parallel_world_size()
        self.quant_config = quant_config
-        self.determine_num_fused_shared_experts("Glm4MoeForCausalLM")
        self.num_fused_shared_experts = (
            0
            if get_global_server_args().disable_shared_experts_fusion
@@ -77,38 +72,7 @@ class Glm4vMoeForConditionalGeneration(Glm4vForConditionalGeneration):
        # For EAGLE3 support
        self.capture_aux_hidden_states = False

-    def determine_num_fused_shared_experts(
-        self, architecture: str = "Glm4MoeForCausalLM"
-    ):
-        self.num_fused_shared_experts = 0
-        if get_global_server_args().disable_shared_experts_fusion:
-            return
-
-        # Only Deepseek V3/R1 can use shared experts fusion optimization now.
-        disable_reason = None
-        if (
-            not _is_cuda
-            or torch.cuda.get_device_capability("cuda") < (8, 0)
-            or self.config.architectures[0] != architecture
-            or self.config.n_shared_experts != 1
-        ):
-            disable_reason = "Only GLM-4.5 on NV-platform with capability >= 80 can use shared experts fusion optimization."
-        elif get_moe_expert_parallel_world_size() > 1:
-            disable_reason = "Deepseek and GLM-4.5 can not use shared experts fusion optimization under expert parallelism."
-
-        if disable_reason is not None:
-            get_global_server_args().disable_shared_experts_fusion = True
-            self.num_fused_shared_experts = 0
-            log_info_on_rank0(
-                logger,
-                f"{disable_reason} Shared experts fusion optimization is disabled.",
-            )
-            return
-
-        self.num_fused_shared_experts = self.config.n_shared_experts
-
    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]], is_nextn=False):
        if is_nextn:
            if hasattr(self.config, "num_nextn_predict_layers"):
                num_nextn_layers = self.config.num_nextn_predict_layers
@@ -130,117 +94,14 @@ class Glm4vMoeForConditionalGeneration(Glm4vForConditionalGeneration):
            ("gate_up_proj", "gate_proj", 0),
            ("gate_up_proj", "up_proj", 1),
        ]
if self.num_fused_shared_experts > 0:
assert self.num_fused_shared_experts == 1
weights_list = list(weights)
weights_dict = dict(weights_list)
if self.quant_config is not None:
if self.quant_config.get_name() == "w8a8_int8":
suffix_list = [
"down_proj.weight",
"down_proj.weight_scale",
"gate_proj.weight",
"gate_proj.weight_scale",
"up_proj.weight",
"up_proj.weight_scale",
]
elif (
self.quant_config.get_name() == "fp8"
or self.quant_config.get_name() == "blockwise_int8"
or self.quant_config.get_name() == "compressed_tensors"
):
suffix_list = [
"down_proj.weight",
"down_proj.weight_scale",
"gate_proj.weight",
"gate_proj.weight_scale",
"up_proj.weight",
"up_proj.weight_scale",
]
elif self.quant_config.get_name() == "awq":
suffix_list = [
"down_proj.qweight",
"down_proj.qzeros",
"down_proj.scales",
"gate_proj.qweight",
"gate_proj.qzeros",
"gate_proj.scales",
"up_proj.qweight",
"up_proj.qzeros",
"up_proj.scales",
]
elif self.quant_config.get_name() == "modelopt_fp4":
suffix_list = [
"down_proj.weight",
"down_proj.weight_scale",
"down_proj.weight_scale_2",
"down_proj.input_scale",
"gate_proj.weight",
"gate_proj.weight_scale",
"gate_proj.weight_scale_2",
"gate_proj.input_scale",
"up_proj.weight",
"up_proj.weight_scale",
"up_proj.weight_scale_2",
"up_proj.input_scale",
]
else:
raise ValueError(
f"Unsupported shared expert fusion for quantization: {self.quant_config.get_name()}."
)
else:
suffix_list = [
"down_proj.weight",
"gate_proj.weight",
"up_proj.weight",
]
names_to_remove = []
moe_layers = (
range(
self.config.first_k_dense_replace,
self.config.num_hidden_layers,
self.config.moe_layer_freq,
)
if not is_nextn
else [nextn_layer_id]
)
for moe_layer in moe_layers:
for suffix in suffix_list:
shared_expert_weight_name = (
f"model.layers.{moe_layer}.mlp.shared_experts.{suffix}"
)
# online fp8 quantization does not load weight_scale
if shared_expert_weight_name not in weights_dict:
continue
weights_list.append(
(
f"model.layers.{moe_layer}."
f"mlp.experts."
f"{self.config.n_routed_experts + 0}"
f".{suffix}",
weights_dict[shared_expert_weight_name],
)
)
names_to_remove += [shared_expert_weight_name]
weights = [w for w in weights_list if w[0] not in names_to_remove]
# Params for weights, fp8 weight scales, fp8 activation scales
# (param_name, weight_name, expert_id, shard_id)
expert_params_mapping = FusedMoE.make_expert_params_mapping( expert_params_mapping = FusedMoE.make_expert_params_mapping(
ckpt_gate_proj_name="gate_proj", ckpt_gate_proj_name="gate_proj",
ckpt_down_proj_name="down_proj", ckpt_down_proj_name="down_proj",
ckpt_up_proj_name="up_proj", ckpt_up_proj_name="up_proj",
num_experts=self.config.n_routed_experts + self.num_fused_shared_experts, num_experts=self.config.n_routed_experts,
) )
# Fuse q_a_proj and kv_a_proj_with_mqa along output dimension when q_lora_rank is not None
fuse_qkv_a_proj = hasattr(self.config, "q_lora_rank") and (
self.config.q_lora_rank is not None
)
cached_a_proj = {} if fuse_qkv_a_proj else None
if is_nextn: if is_nextn:
nextn_layer_prefix = f"model.layers.{nextn_layer_id}" nextn_layer_prefix = f"model.layers.{nextn_layer_id}"
nextn_spec_weight_names = [ nextn_spec_weight_names = [
...@@ -300,23 +161,36 @@ class Glm4vMoeForConditionalGeneration(Glm4vForConditionalGeneration): ...@@ -300,23 +161,36 @@ class Glm4vMoeForConditionalGeneration(Glm4vForConditionalGeneration):
# name will be updated to mlp.experts[0].gate_up_proj, which # name will be updated to mlp.experts[0].gate_up_proj, which
# will then be updated below in expert_params_mapping # will then be updated below in expert_params_mapping
# for mlp.experts[0].gate_gate_up_proj, which breaks load. # for mlp.experts[0].gate_gate_up_proj, which breaks load.
if ("mlp.experts." in name) and name not in params_dict: if "mlp.experts" in name:
continue continue
name = name.replace(weight_name, param_name) name = name.replace(weight_name, param_name)
# Skip loading extra bias for GPTQ models. # Skip loading extra bias for GPTQ models.
if name.endswith(".bias") and name not in params_dict: if name.endswith(".bias") and name not in params_dict:
continue continue
param = params_dict[name] if name not in params_dict:
continue
param = params_dict[name]
weight_loader = param.weight_loader weight_loader = param.weight_loader
weight_loader(param, loaded_weight, shard_id) weight_loader(param, loaded_weight, shard_id)
break break
else: else:
# Track if this is an expert weight to enable early skipping
is_expert_weight = False
for mapping in expert_params_mapping: for mapping in expert_params_mapping:
param_name, weight_name, expert_id, shard_id = mapping param_name, weight_name, expert_id, shard_id = mapping
if weight_name not in name: if weight_name not in name:
continue continue
# Mark as expert weight regardless of whether we can process it
is_expert_weight = True
name = name.replace(weight_name, param_name) name = name.replace(weight_name, param_name)
if name not in params_dict:
# Expert weight not on this rank, will be skipped below
continue
param = params_dict[name] param = params_dict[name]
weight_loader = param.weight_loader weight_loader = param.weight_loader
weight_loader( weight_loader(
...@@ -328,64 +202,21 @@ class Glm4vMoeForConditionalGeneration(Glm4vForConditionalGeneration): ...@@ -328,64 +202,21 @@ class Glm4vMoeForConditionalGeneration(Glm4vForConditionalGeneration):
) )
break break
else: else:
if is_expert_weight:
# This is an expert weight but not mapped to this rank, skip all remaining processing
continue
if "visual" in name: if "visual" in name:
# adapt to VisionAttention # adapt to VisionAttention for GLM-V
name = name.replace(r"attn.qkv.", r"attn.qkv_proj.") name = name.replace(r"attn.qkv.", r"attn.qkv_proj.")
# Skip loading extra bias for GPTQ models. # Skip loading extra bias for GPTQ models.
if name.endswith(".bias") and name not in params_dict: if name.endswith(".bias") and name not in params_dict:
continue continue
if fuse_qkv_a_proj and ( if name not in params_dict:
"q_a_proj" in name or "kv_a_proj_with_mqa" in name continue
):
cached_a_proj[name] = loaded_weight
q_a_proj_name = (
name
if "q_a_proj" in name
else name.replace("kv_a_proj_with_mqa", "q_a_proj")
)
kv_a_proj_name = (
name
if "kv_a_proj_with_mqa" in name
else name.replace("q_a_proj", "kv_a_proj_with_mqa")
)
# When both q_a_proj and kv_a_proj_with_mqa has been cached, load the fused weight to parameter
if (
q_a_proj_name in cached_a_proj
and kv_a_proj_name in cached_a_proj
):
q_a_proj_weight = cached_a_proj[q_a_proj_name]
kv_a_proj_weight = cached_a_proj[kv_a_proj_name]
fused_weight = torch.cat(
[q_a_proj_weight, kv_a_proj_weight], dim=0
)
param_name = (
name.replace("q_a_proj", "fused_qkv_a_proj_with_mqa")
if "q_a_proj" in name
else name.replace(
"kv_a_proj_with_mqa", "fused_qkv_a_proj_with_mqa"
)
)
param = params_dict[param_name]
weight_loader = getattr( if name in params_dict.keys():
param, "weight_loader", default_weight_loader
)
weight_loader(param, fused_weight)
cached_a_proj.pop(q_a_proj_name)
cached_a_proj.pop(kv_a_proj_name)
else:
if (
"k_scale" in name or "v_scale" in name
) and name not in params_dict:
# modelopt attn kv scale is named differently
if any(scale in name for scale in ["k_scale", "v_scale"]):
name = name.replace("_proj", "attn_mqa")
else:
logger.warning(
f"Unknown scale found in checkpoint: {name}"
)
param = params_dict[name] param = params_dict[name]
weight_loader = getattr( weight_loader = getattr(
param, "weight_loader", default_weight_loader param, "weight_loader", default_weight_loader
...@@ -395,6 +226,8 @@ class Glm4vMoeForConditionalGeneration(Glm4vForConditionalGeneration): ...@@ -395,6 +226,8 @@ class Glm4vMoeForConditionalGeneration(Glm4vForConditionalGeneration):
self.config, name, loaded_weight self.config, name, loaded_weight
) )
weight_loader(param, loaded_weight) weight_loader(param, loaded_weight)
else:
logger.warning(f"Parameter {name} not found in params_dict")
EntryClass = [Glm4vMoeForConditionalGeneration]
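The vision-tower branch of `load_weights` keeps a small rename shim so HF checkpoint keys line up with SGLang's fused `VisionAttention` projection. A minimal illustration of that mapping, using a made-up checkpoint key:

```python
def remap_visual_weight_name(name: str) -> str:
    """Adapt HF GLM-V vision-tower names to the fused VisionAttention projection."""
    if "visual" in name:
        name = name.replace("attn.qkv.", "attn.qkv_proj.")
    return name


# Example with a hypothetical checkpoint key:
assert (
    remap_visual_weight_name("model.visual.blocks.0.attn.qkv.weight")
    == "model.visual.blocks.0.attn.qkv_proj.weight"
)
```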
@@ -17,7 +17,7 @@ class Glm4vImageProcessor(SGLangBaseProcessor):
    def __init__(self, hf_config, server_args, _processor, *args, **kwargs):
        super().__init__(hf_config, server_args, _processor, *args, **kwargs)
-        # GLM-4.1V and GLM-4.5V specific tokens
+        # GLM-V specific tokens
        self.IMAGE_TOKEN = "<|image|>"
        self.VIDEO_TOKEN = "<|video|>"
        self.IMAGE_START_TOKEN = "<|begin_of_image|>"
......