"examples/pytorch/vscode:/vscode.git/clone" did not exist on "0c3b2b780f7ea6318ead8fcbe23e973467e4ca01"
Unverified commit 4060ed37 authored by Yuxuan Zhang, committed by GitHub

Refactoring GLM-4.5 and GLM-4.5V related implementations (#11800)

parent 2342605e
@@ -15,7 +15,7 @@
"""Inference-only GLM-4.5, GLM-4.6 model compatible with HuggingFace weights"""
import logging
from typing import Any, Dict, Iterable, Optional, Tuple
from typing import Any, Dict, Iterable, Optional, Tuple, Union
import torch
import torch.nn.functional as F
@@ -27,10 +27,16 @@ from sglang.srt.distributed import (
get_pp_group,
get_tensor_model_parallel_rank,
get_tensor_model_parallel_world_size,
parallel_state,
tensor_model_parallel_all_reduce,
)
from sglang.srt.distributed.device_communicators.pynccl_allocator import (
use_symmetric_memory,
)
from sglang.srt.eplb.expert_distribution import get_global_expert_distribution_recorder
from sglang.srt.eplb.expert_location import ModelConfigForExpertLocation
from sglang.srt.eplb.expert_location_dispatch import ExpertLocationDispatchInfo
from sglang.srt.layers.activation import SiluAndMul
from sglang.srt.layers.amx_utils import PackWeightMethod
from sglang.srt.layers.communicator import (
LayerCommunicator,
LayerScatterModes,
@@ -48,7 +54,10 @@ from sglang.srt.layers.linear import (
RowParallelLinear,
)
from sglang.srt.layers.logits_processor import LogitsProcessor
from sglang.srt.layers.moe import get_moe_a2a_backend
from sglang.srt.layers.moe import (
get_moe_a2a_backend,
should_use_flashinfer_cutlass_moe_fp4_allgather,
)
from sglang.srt.layers.moe.ep_moe.layer import get_moe_impl_class
from sglang.srt.layers.moe.fused_moe_triton.layer import FusedMoE
from sglang.srt.layers.moe.topk import TopK
@@ -56,23 +65,17 @@ from sglang.srt.layers.quantization.base_config import QuantizationConfig
from sglang.srt.layers.quantization.fp8_kernel import is_fp8_fnuz
from sglang.srt.layers.radix_attention import RadixAttention
from sglang.srt.layers.rotary_embedding import get_rope
from sglang.srt.layers.utils import PPMissingLayer
from sglang.srt.layers.vocab_parallel_embedding import (
ParallelLMHead,
VocabParallelEmbedding,
)
from sglang.srt.model_executor.cuda_graph_runner import get_is_capture_mode
from sglang.srt.model_executor.forward_batch_info import ForwardBatch
from sglang.srt.model_executor.forward_batch_info import ForwardBatch, PPProxyTensors
from sglang.srt.model_loader.weight_utils import default_weight_loader
from sglang.srt.models.deepseek_v2 import (
DeepseekV2DecoderLayer,
DeepseekV2ForCausalLM,
DeepseekV2Model,
DeepseekV2MoE,
)
from sglang.srt.server_args import get_global_server_args
from sglang.srt.two_batch_overlap import model_forward_maybe_tbo
from sglang.srt.utils import (
BumpAllocator,
LazyValue,
add_prefix,
cpu_has_amx_support,
get_bool_env_var,
@@ -80,8 +83,7 @@ from sglang.srt.utils import (
is_cpu,
is_cuda,
is_hip,
log_info_on_rank0,
use_intel_amx_backend,
make_layers,
)
_is_hip = is_hip()
@@ -92,11 +94,6 @@ _is_cpu_amx_available = cpu_has_amx_support()
_is_cpu = is_cpu()
_device_sm = get_device_sm()
if _is_cuda:
from sgl_kernel import dsv3_router_gemm
elif _is_cpu and _is_cpu_amx_available:
pass
logger = logging.getLogger(__name__)
@@ -136,8 +133,7 @@ class Glm4MoeMLP(nn.Module):
)
if hidden_act != "silu":
raise ValueError(
f"Unsupported activation: {hidden_act}. "
"Only silu is supported for now."
f"Unsupported activation: {hidden_act}. Only silu is supported for now."
)
self.act_fn = SiluAndMul()
@@ -146,7 +142,6 @@ class Glm4MoeMLP(nn.Module):
x,
forward_batch=None,
should_allreduce_fusion=False,
gemm_output_zero_allocator: BumpAllocator = None,
):
if (self.tp_size == 1) and x.shape[0] == 0:
return x
@@ -326,47 +321,21 @@ class Glm4MoeGate(nn.Module):
self,
config,
prefix: str = "",
is_nextn: bool = False,
):
super().__init__()
self.is_nextn = is_nextn
self.weight = nn.Parameter(
torch.empty((config.n_routed_experts, config.hidden_size))
)
self.e_score_correction_bias = nn.Parameter(
torch.empty((config.n_routed_experts), dtype=torch.float32)
)
if _is_cpu and _is_cpu_amx_available:
self.quant_method = PackWeightMethod(weight_names=["weight"])
def forward(self, hidden_states):
if use_intel_amx_backend(self):
return torch.ops.sgl_kernel.weight_packed_linear(
hidden_states,
self.weight,
None, # bias
True, # is_vnni
)
# NOTE: For some unknown reason, router_gemm seems degrade accept length.
if (
_is_cuda
and not self.is_nextn
and hidden_states.shape[0] < 4
and hidden_states.shape[1] == 7168
and self.weight.shape[0] == 256
and _device_sm >= 90
):
logits = dsv3_router_gemm(hidden_states, self.weight).to(
hidden_states.dtype
)
else:
logits = F.linear(hidden_states, self.weight, None)
logits = F.linear(hidden_states, self.weight, None)
return logits
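
For reference, a minimal standalone sketch (illustrative only, not part of this diff; names assumed) of the simplified routing path: after this change the gate is always a bias-free linear projection, and the per-expert correction bias is consumed by TopK during expert selection rather than inside the gate.

import torch
import torch.nn as nn
import torch.nn.functional as F

class SimpleRouterGate(nn.Module):
    """Bias-free router projection: (num_tokens, hidden_size) -> (num_tokens, n_routed_experts)."""

    def __init__(self, hidden_size: int, n_routed_experts: int):
        super().__init__()
        self.weight = nn.Parameter(torch.empty(n_routed_experts, hidden_size))
        # Kept as a separate parameter: applied during top-k selection,
        # not added to the logits here.
        self.e_score_correction_bias = nn.Parameter(
            torch.zeros(n_routed_experts, dtype=torch.float32)
        )

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        return F.linear(hidden_states, self.weight, None)
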
class Glm4MoeSparseMoeBlock(DeepseekV2MoE):
class Glm4MoeSparseMoeBlock(nn.Module):
def __init__(
self,
config: PretrainedConfig,
@@ -374,18 +343,12 @@ class Glm4MoeSparseMoeBlock(DeepseekV2MoE):
quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
alt_stream: Optional[torch.cuda.Stream] = None,
is_nextn: bool = False,
):
nn.Module.__init__(self)
self.top_k = config.num_experts_per_tok
self.tp_size = get_tensor_model_parallel_world_size()
self.ep_size = get_moe_expert_parallel_world_size()
self.routed_scaling_factor = config.routed_scaling_factor
self.n_shared_experts = config.n_shared_experts
self.num_fused_shared_experts = (
0
if get_global_server_args().disable_shared_experts_fusion
else config.n_shared_experts
)
self.config = config
self.layer_id = layer_id
self.alt_stream = alt_stream
@@ -402,39 +365,31 @@ class Glm4MoeSparseMoeBlock(DeepseekV2MoE):
"Only silu is supported for now."
)
self.gate = Glm4MoeGate(
config=config, prefix=add_prefix("gate", prefix), is_nextn=is_nextn
)
self.gate = Glm4MoeGate(config=config, prefix=add_prefix("gate", prefix))
self.topk = TopK(
top_k=config.num_experts_per_tok + self.num_fused_shared_experts,
top_k=self.top_k,
renormalize=config.norm_topk_prob,
use_grouped_topk=True,
num_expert_group=config.n_group,
num_fused_shared_experts=self.num_fused_shared_experts,
topk_group=config.topk_group,
correction_bias=self.gate.e_score_correction_bias,
routed_scaling_factor=self.routed_scaling_factor,
)
self.experts = get_moe_impl_class(quant_config)(
num_experts=config.n_routed_experts
+ self.num_fused_shared_experts
+ get_global_server_args().ep_num_redundant_experts,
num_fused_shared_experts=self.num_fused_shared_experts,
top_k=config.num_experts_per_tok + self.num_fused_shared_experts,
num_experts=config.n_routed_experts,
top_k=self.top_k,
layer_id=self.layer_id,
hidden_size=config.hidden_size,
intermediate_size=config.moe_intermediate_size,
layer_id=self.layer_id,
quant_config=quant_config,
routed_scaling_factor=self.routed_scaling_factor,
prefix=add_prefix("experts", prefix),
)
self.shared_experts_is_int8 = False
self.shared_experts_is_fp8 = False
# self.shared_experts_weight_block_size = None
if config.n_shared_experts is not None and self.num_fused_shared_experts == 0:
# shared expert
if config.n_shared_experts is not None:
intermediate_size = config.moe_intermediate_size * config.n_shared_experts
self.shared_experts = Glm4MoeMLP(
hidden_size=config.hidden_size,
@@ -443,21 +398,14 @@ class Glm4MoeSparseMoeBlock(DeepseekV2MoE):
quant_config=quant_config,
reduce_results=False,
prefix=add_prefix("shared_experts", prefix),
**(dict(tp_rank=0, tp_size=1) if self.ep_size > 1 else {}),
**(
dict(tp_rank=0, tp_size=1)
if get_moe_a2a_backend().is_deepep()
or get_moe_a2a_backend().is_mooncake()
or should_use_flashinfer_cutlass_moe_fp4_allgather()
else {}
),
)
is_packed_weight = hasattr(
self.shared_experts.gate_up_proj.quant_method, "quant_config"
)
self.shared_experts_is_int8 = (
not is_packed_weight
and self.shared_experts.gate_up_proj.weight.dtype == torch.int8
)
self.shared_experts_is_fp8 = (
not is_packed_weight
and self.shared_experts.gate_up_proj.weight.dtype == torch.float8_e4m3fn
)
self.top_k = config.num_experts_per_tok
if get_moe_a2a_backend().is_deepep() or get_moe_a2a_backend().is_mooncake():
# TODO: we will support tp < ep in the future
@@ -479,12 +427,46 @@ class Glm4MoeSparseMoeBlock(DeepseekV2MoE):
get_moe_a2a_backend().is_deepep() or get_moe_a2a_backend().is_mooncake()
)
def get_moe_weights(self):
return [
x.data
for name, x in self.experts.named_parameters()
if name not in ["correction_bias"]
]
def forward(
self,
hidden_states: torch.Tensor,
forward_batch: Optional[ForwardBatch] = None,
should_allreduce_fusion: bool = False,
use_reduce_scatter: bool = False,
) -> torch.Tensor:
if not self._enable_a2a_moe:
DUAL_STREAM_TOKEN_THRESHOLD = 1024
if (
self.alt_stream is not None
and hidden_states.shape[0] > 0
and hidden_states.shape[0] <= DUAL_STREAM_TOKEN_THRESHOLD
):
return self.forward_normal_dual_stream(
hidden_states,
should_allreduce_fusion,
use_reduce_scatter,
)
else:
return self.forward_normal(
hidden_states,
should_allreduce_fusion,
use_reduce_scatter,
)
else:
return self.forward_deepep(hidden_states, forward_batch)
def forward_normal_dual_stream(
self,
hidden_states: torch.Tensor,
should_allreduce_fusion: bool = False,
use_reduce_scatter: bool = False,
gemm_output_zero_allocator: BumpAllocator = None,
) -> torch.Tensor:
current_stream = torch.cuda.current_stream()
@@ -498,28 +480,21 @@ class Glm4MoeSparseMoeBlock(DeepseekV2MoE):
final_hidden_states = self.experts(hidden_states, topk_output)
if not _is_cuda:
final_hidden_states *= self.routed_scaling_factor
current_stream.wait_stream(self.alt_stream)
with use_symmetric_memory(parallel_state.get_tp_group()) as sm:
final_hidden_states_out = torch.empty_like(final_hidden_states)
if self.ep_size > 1:
if (
self.tp_size > 1
and not should_allreduce_fusion
and not use_reduce_scatter
):
final_hidden_states = tensor_model_parallel_all_reduce(
final_hidden_states
)
final_hidden_states += shared_output
else:
final_hidden_states += shared_output
if (
self.tp_size > 1
and not should_allreduce_fusion
and not use_reduce_scatter
):
final_hidden_states = tensor_model_parallel_all_reduce(
final_hidden_states
)
torch.add(final_hidden_states, shared_output, out=final_hidden_states_out)
final_hidden_states = final_hidden_states_out
sm.tag(final_hidden_states)
if (
self.tp_size > 1
and not should_allreduce_fusion
and not use_reduce_scatter
and not should_use_flashinfer_cutlass_moe_fp4_allgather()
):
final_hidden_states = tensor_model_parallel_all_reduce(final_hidden_states)
return final_hidden_states
def forward_normal(
@@ -527,39 +502,69 @@ class Glm4MoeSparseMoeBlock(DeepseekV2MoE):
hidden_states: torch.Tensor,
should_allreduce_fusion: bool = False,
use_reduce_scatter: bool = False,
gemm_output_zero_allocator: BumpAllocator = None,
) -> torch.Tensor:
if hasattr(self, "shared_experts") and use_intel_amx_backend(
self.shared_experts.gate_up_proj
):
return self.forward_cpu(hidden_states, should_allreduce_fusion)
if hidden_states.shape[0] > 0:
shared_output = self._forward_shared_experts(hidden_states)
# router_logits: (num_tokens, n_experts)
router_logits = self.gate(hidden_states)
topk_output = self.topk(hidden_states, router_logits)
else:
shared_output = None
topk_output = self.topk.empty_topk_output(hidden_states.device)
shared_output = self._forward_shared_experts(hidden_states)
# router_logits: (num_tokens, n_experts)
router_logits = self.gate(hidden_states)
topk_output = self.topk(hidden_states, router_logits)
final_hidden_states = self.experts(hidden_states, topk_output)
if not _is_cuda and not _use_aiter:
# fused in biased_grouped_topk so we can skip here
final_hidden_states *= self.routed_scaling_factor
if self.ep_size > 1:
if self.tp_size > 1 and not should_allreduce_fusion:
final_hidden_states = tensor_model_parallel_all_reduce(
final_hidden_states
)
if shared_output is not None:
final_hidden_states += shared_output
if shared_output is not None:
with use_symmetric_memory(parallel_state.get_tp_group()) as sm:
final_hidden_states_out = torch.empty_like(final_hidden_states)
torch.add(final_hidden_states, shared_output, out=final_hidden_states_out)
final_hidden_states = final_hidden_states_out
sm.tag(final_hidden_states)
if (
self.tp_size > 1
and not should_allreduce_fusion
and not use_reduce_scatter
and not should_use_flashinfer_cutlass_moe_fp4_allgather()
):
final_hidden_states = tensor_model_parallel_all_reduce(final_hidden_states)
return final_hidden_states
def _forward_deepep(self, hidden_states: torch.Tensor, forward_batch: ForwardBatch):
shared_output = None
if hidden_states.shape[0] > 0:
# router_logits: (num_tokens, n_experts)
router_logits, _ = self.gate(hidden_states)
shared_output = self._forward_shared_experts(hidden_states)
topk_output = self.topk(
hidden_states,
router_logits,
num_token_non_padded=forward_batch.num_token_non_padded,
expert_location_dispatch_info=ExpertLocationDispatchInfo.init_new(
layer_id=self.layer_id,
),
)
else:
if shared_output is not None:
final_hidden_states += shared_output
if self.tp_size > 1 and not should_allreduce_fusion:
final_hidden_states = tensor_model_parallel_all_reduce(
final_hidden_states
)
topk_output = self.topk.empty_topk_output(hidden_states.device)
final_hidden_states = self.experts(
hidden_states=hidden_states,
topk_output=topk_output,
)
if shared_output is not None:
final_hidden_states.add_(shared_output)
return final_hidden_states
def _forward_shared_experts(self, hidden_states: torch.Tensor):
shared_output = None
if hidden_states.shape[0] > 0:
shared_output = self.shared_experts(hidden_states)
return shared_output
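
A hedged sketch of the dual-stream overlap that forward_normal_dual_stream builds on (simplified; helper and module names are assumed, not from this diff): shared experts run on an alternate CUDA stream while the routed experts run on the current stream, and the branches are joined with wait_stream before the outputs are combined.

import torch

def overlapped_moe_forward(hidden_states, gate, topk, experts, shared_experts, alt_stream):
    # Illustrative only: overlap shared-expert GEMMs with routed-expert compute.
    current_stream = torch.cuda.current_stream()
    alt_stream.wait_stream(current_stream)
    with torch.cuda.stream(alt_stream):
        shared_output = shared_experts(hidden_states)
    router_logits = gate(hidden_states)
    topk_output = topk(hidden_states, router_logits)
    final_hidden_states = experts(hidden_states, topk_output)
    # Join the shared-expert branch before the elementwise add reads its output.
    current_stream.wait_stream(alt_stream)
    return final_hidden_states + shared_output
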
class Glm4MoeDecoderLayer(DeepseekV2DecoderLayer):
class Glm4MoeDecoderLayer(nn.Module):
def __init__(
self,
config: PretrainedConfig,
@@ -582,6 +587,7 @@ class Glm4MoeDecoderLayer(DeepseekV2DecoderLayer):
rms_norm_eps = config.rms_norm_eps
attention_bias = config.attention_bias
self.layer_id = layer_id
self.self_attn = Glm4MoeAttention(
hidden_size=self.hidden_size,
num_heads=config.num_attention_heads,
@@ -597,15 +603,15 @@ class Glm4MoeDecoderLayer(DeepseekV2DecoderLayer):
quant_config=quant_config,
prefix=add_prefix("self_attn", prefix),
use_qk_norm=config.use_qk_norm,
alt_stream=alt_stream,
)
self.is_layer_sparse = self._is_layer_sparse(layer_id, is_nextn=is_nextn)
is_previous_layer_sparse = self._is_layer_sparse(layer_id - 1, is_nextn=False)
num_layers = 1 if is_nextn else config.num_hidden_layers
self.layer_scatter_modes = LayerScatterModes.init_new(
layer_id=layer_id,
num_layers=num_layers,
num_layers=1 if is_nextn else config.num_hidden_layers,
is_layer_sparse=self.is_layer_sparse,
is_previous_layer_sparse=is_previous_layer_sparse,
)
@@ -616,6 +622,7 @@ class Glm4MoeDecoderLayer(DeepseekV2DecoderLayer):
quant_config=quant_config,
prefix=add_prefix("mlp", prefix),
layer_id=self.layer_id,
alt_stream=alt_stream,
)
else:
if enable_moe_dense_fully_dp():
@@ -641,7 +648,16 @@ class Glm4MoeDecoderLayer(DeepseekV2DecoderLayer):
layer_scatter_modes=self.layer_scatter_modes,
input_layernorm=self.input_layernorm,
post_attention_layernorm=self.post_attention_layernorm,
allow_reduce_scatter=False,
allow_reduce_scatter=True,
is_last_layer=(
is_nextn or (self.layer_id == self.config.num_hidden_layers - 1)
),
)
def _is_layer_sparse(self, layer_id: int, is_nextn: bool) -> bool:
return is_nextn or (
self.config.n_routed_experts is not None
and layer_id >= self.config.first_k_dense_replace
)
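# Illustrative example (not part of the diff): with first_k_dense_replace = 3,
# layers 0-2 use the dense Glm4MoeMLP and layers 3..num_hidden_layers-1 use
# Glm4MoeSparseMoeBlock; the NextN draft layer (is_nextn=True) is always sparse.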
def forward(
@@ -650,8 +666,6 @@ class Glm4MoeDecoderLayer(DeepseekV2DecoderLayer):
hidden_states: torch.Tensor,
forward_batch: ForwardBatch,
residual: Optional[torch.Tensor],
zero_allocator: BumpAllocator,
gemm_output_zero_allocator: BumpAllocator = None,
) -> torch.Tensor:
hidden_states, residual = self.layer_communicator.prepare_attn(
hidden_states, residual, forward_batch
@@ -676,44 +690,119 @@ class Glm4MoeDecoderLayer(DeepseekV2DecoderLayer):
return hidden_states, residual
class Glm4MoeModel(DeepseekV2Model):
class Glm4MoeModel(nn.Module):
def __init__(
self,
config: PretrainedConfig,
quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
) -> None:
nn.Module.__init__(self)
self.padding_id = config.pad_token_id
):
super().__init__()
self.pp_group = get_pp_group()
self.config = config
self.vocab_size = config.vocab_size
self.first_k_dense_replace = config.first_k_dense_replace
self.embed_dim = config.hidden_size
if self.pp_group.is_first_rank:
self.embed_tokens = VocabParallelEmbedding(
config.vocab_size,
config.hidden_size,
enable_tp=not is_dp_attention_enabled(),
)
else:
self.embed_tokens = PPMissingLayer()
self.embed_tokens = VocabParallelEmbedding(
config.vocab_size,
config.hidden_size,
enable_tp=not is_dp_attention_enabled(),
)
self.alt_stream = torch.cuda.Stream() if _is_cuda else None
self.layers = nn.ModuleList(
[
Glm4MoeDecoderLayer(
config,
layer_id,
quant_config=quant_config,
prefix=add_prefix(f"layers.{layer_id}", prefix),
alt_stream=self.alt_stream,
)
for layer_id in range(config.num_hidden_layers)
]
self.layers, self.start_layer, self.end_layer = make_layers(
config.num_hidden_layers,
lambda idx, prefix: Glm4MoeDecoderLayer(
layer_id=idx,
config=config,
quant_config=quant_config,
prefix=prefix,
alt_stream=self.alt_stream,
),
pp_rank=self.pp_group.rank_in_group,
pp_size=self.pp_group.world_size,
prefix=add_prefix("layers", prefix),
)
self.pp_group = get_pp_group()
self.start_layer = 0
self.end_layer = config.num_hidden_layers
self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
if self.pp_group.is_last_rank:
self.norm = RMSNorm(self.embed_dim, eps=config.rms_norm_eps)
else:
self.norm = PPMissingLayer(return_tuple=True)
def get_input_embeddings(self) -> torch.Tensor:
return self.embed_tokens
class Glm4MoeForCausalLM(DeepseekV2ForCausalLM):
def forward(
self,
input_ids: torch.Tensor,
positions: torch.Tensor,
forward_batch: ForwardBatch,
input_embeds: torch.Tensor = None,
pp_proxy_tensors: Optional[PPProxyTensors] = None,
) -> Union[torch.Tensor, PPProxyTensors]:
if self.pp_group.is_first_rank:
if input_embeds is None:
hidden_states = self.embed_tokens(input_ids)
else:
hidden_states = input_embeds
residual = None
else:
assert pp_proxy_tensors is not None
hidden_states = pp_proxy_tensors["hidden_states"]
residual = pp_proxy_tensors["residual"]
normal_start_layer = self.start_layer
normal_end_layer = self.end_layer
if forward_batch.can_run_tbo:
if (
self.first_k_dense_replace > normal_start_layer
and self.first_k_dense_replace < normal_end_layer
):
normal_end_layer = self.first_k_dense_replace
elif self.first_k_dense_replace < normal_start_layer:
normal_end_layer = normal_start_layer = 0
for i in range(normal_start_layer, normal_end_layer):
with get_global_expert_distribution_recorder().with_current_layer(i):
layer = self.layers[i]
hidden_states, residual = layer(
positions,
hidden_states,
forward_batch,
residual,
)
if normal_end_layer != self.end_layer:
hidden_states, residual = model_forward_maybe_tbo(
layers=self.layers[normal_end_layer : self.end_layer],
enable_tbo=True,
positions=positions,
forward_batch=forward_batch,
hidden_states=hidden_states,
residual=residual,
input_data_scatter_mode=self.layers[
normal_end_layer - 1
].layer_scatter_modes.layer_output_mode,
)
if not self.pp_group.is_last_rank:
return PPProxyTensors(
{
"hidden_states": hidden_states,
"residual": residual,
}
)
else:
if not forward_batch.forward_mode.is_idle():
if residual is None:
hidden_states = self.norm(hidden_states)
else:
hidden_states, _ = self.norm(hidden_states, residual)
return hidden_states
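
In short, the pipeline-parallel contract implemented by this forward (a paraphrase of the code above, not additional behavior):

# First PP rank:  embed input_ids -> hidden_states, residual = None
# Other PP ranks: unpack hidden_states / residual from the incoming PPProxyTensors
# Every rank:     run its [start_layer, end_layer) slice of layers, optionally
#                 handing the MoE tail to model_forward_maybe_tbo for two-batch overlap
# Non-last ranks: return PPProxyTensors({"hidden_states": ..., "residual": ...})
# Last rank:      apply the final RMSNorm (fused with the residual add) and return hidden_states
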
class Glm4MoeForCausalLM(nn.Module):
def __init__(
self,
config: PretrainedConfig,
@@ -721,12 +810,10 @@ class Glm4MoeForCausalLM(DeepseekV2ForCausalLM):
prefix: str = "",
) -> None:
nn.Module.__init__(self)
config.moe_layer_freq = 1
self.config = config
self.tp_size = get_tensor_model_parallel_world_size()
self.quant_config = quant_config
self.pp_group = get_pp_group()
self.determine_num_fused_shared_experts("Glm4MoeForCausalLM")
self.model = Glm4MoeModel(
config, quant_config, prefix=add_prefix("model", prefix)
)
@@ -739,49 +826,41 @@ class Glm4MoeForCausalLM(DeepseekV2ForCausalLM):
)
self.logits_processor = LogitsProcessor(config)
self._routed_experts_weights_of_layer = LazyValue(
lambda: {
layer_id: layer.mlp.get_moe_weights()
for layer_id, layer in enumerate(self.model.layers)
if isinstance(layer.mlp, DeepseekV2MoE)
}
)
# For EAGLE3 support
self.capture_aux_hidden_states = False
def determine_num_fused_shared_experts(
self, architecture: str = "Glm4MoeForCausalLM"
):
self.num_fused_shared_experts = 0
if get_global_server_args().disable_shared_experts_fusion:
return
def get_input_embeddings(self) -> nn.Embedding:
return self.model.embed_tokens
# Only Deepseek V3/R1 can use shared experts fusion optimization now.
disable_reason = None
if (
not _is_cuda
or torch.cuda.get_device_capability("cuda") < (8, 0)
or self.config.architectures[0] != architecture
or self.config.n_shared_experts != 1
):
disable_reason = "Only GLM-4.5 or GLM-4.6 on NV-platform with capability >= 80 can use shared experts fusion optimization."
elif get_moe_expert_parallel_world_size() > 1:
disable_reason = "Deepseek and GLM-4.5 or GLM-4.6 can not use shared experts fusion optimization under expert parallelism."
if disable_reason is not None:
get_global_server_args().disable_shared_experts_fusion = True
self.num_fused_shared_experts = 0
log_info_on_rank0(
logger,
f"{disable_reason} Shared experts fusion optimization is disabled.",
@torch.no_grad()
def forward(
self,
input_ids: torch.Tensor,
positions: torch.Tensor,
forward_batch: ForwardBatch,
input_embeds: torch.Tensor = None,
pp_proxy_tensors: Optional[PPProxyTensors] = None,
) -> torch.Tensor:
hidden_states = self.model(
input_ids, positions, forward_batch, input_embeds, pp_proxy_tensors
)
if self.pp_group.is_last_rank:
return self.logits_processor(
input_ids, hidden_states, self.lm_head, forward_batch
)
return
else:
return hidden_states
self.num_fused_shared_experts = self.config.n_shared_experts
@property
def start_layer(self):
return self.model.start_layer
def get_input_embeddings(self) -> nn.Embedding:
return self.model.embed_tokens
@property
def end_layer(self):
return self.model.end_layer
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]], is_nextn=False):
if is_nextn:
if hasattr(self.config, "num_nextn_predict_layers"):
num_nextn_layers = self.config.num_nextn_predict_layers
@@ -803,117 +882,14 @@ class Glm4MoeForCausalLM(DeepseekV2ForCausalLM):
("gate_up_proj", "gate_proj", 0),
("gate_up_proj", "up_proj", 1),
]
if self.num_fused_shared_experts > 0:
assert self.num_fused_shared_experts == 1
weights_list = list(weights)
weights_dict = dict(weights_list)
if self.quant_config is not None:
if self.quant_config.get_name() == "w8a8_int8":
suffix_list = [
"down_proj.weight",
"down_proj.weight_scale",
"gate_proj.weight",
"gate_proj.weight_scale",
"up_proj.weight",
"up_proj.weight_scale",
]
elif (
self.quant_config.get_name() == "fp8"
or self.quant_config.get_name() == "blockwise_int8"
or self.quant_config.get_name() == "compressed_tensors"
):
suffix_list = [
"down_proj.weight",
"down_proj.weight_scale",
"gate_proj.weight",
"gate_proj.weight_scale",
"up_proj.weight",
"up_proj.weight_scale",
]
elif self.quant_config.get_name() == "awq":
suffix_list = [
"down_proj.qweight",
"down_proj.qzeros",
"down_proj.scales",
"gate_proj.qweight",
"gate_proj.qzeros",
"gate_proj.scales",
"up_proj.qweight",
"up_proj.qzeros",
"up_proj.scales",
]
elif self.quant_config.get_name() == "modelopt_fp4":
suffix_list = [
"down_proj.weight",
"down_proj.weight_scale",
"down_proj.weight_scale_2",
"down_proj.input_scale",
"gate_proj.weight",
"gate_proj.weight_scale",
"gate_proj.weight_scale_2",
"gate_proj.input_scale",
"up_proj.weight",
"up_proj.weight_scale",
"up_proj.weight_scale_2",
"up_proj.input_scale",
]
else:
raise ValueError(
f"Unsupported shared expert fusion for quantization: {self.quant_config.get_name()}."
)
else:
suffix_list = [
"down_proj.weight",
"gate_proj.weight",
"up_proj.weight",
]
names_to_remove = []
moe_layers = (
range(
self.config.first_k_dense_replace,
self.config.num_hidden_layers,
self.config.moe_layer_freq,
)
if not is_nextn
else [nextn_layer_id]
)
for moe_layer in moe_layers:
for suffix in suffix_list:
shared_expert_weight_name = (
f"model.layers.{moe_layer}.mlp.shared_experts.{suffix}"
)
# online fp8 quantization does not load weight_scale
if shared_expert_weight_name not in weights_dict:
continue
weights_list.append(
(
f"model.layers.{moe_layer}."
f"mlp.experts."
f"{self.config.n_routed_experts + 0}"
f".{suffix}",
weights_dict[shared_expert_weight_name],
)
)
names_to_remove += [shared_expert_weight_name]
weights = [w for w in weights_list if w[0] not in names_to_remove]
# Params for weights, fp8 weight scales, fp8 activation scales
# (param_name, weight_name, expert_id, shard_id)
expert_params_mapping = FusedMoE.make_expert_params_mapping(
ckpt_gate_proj_name="gate_proj",
ckpt_down_proj_name="down_proj",
ckpt_up_proj_name="up_proj",
num_experts=self.config.n_routed_experts + self.num_fused_shared_experts,
num_experts=self.config.n_routed_experts,
)
# Fuse q_a_proj and kv_a_proj_with_mqa along output dimension when q_lora_rank is not None
fuse_qkv_a_proj = hasattr(self.config, "q_lora_rank") and (
self.config.q_lora_rank is not None
)
cached_a_proj = {} if fuse_qkv_a_proj else None
if is_nextn:
nextn_layer_prefix = f"model.layers.{nextn_layer_id}"
nextn_spec_weight_names = [
@@ -969,22 +945,36 @@ class Glm4MoeForCausalLM(DeepseekV2ForCausalLM):
# name will be updated to mlp.experts[0].gate_up_proj, which
# will then be updated below in expert_params_mapping
# for mlp.experts[0].gate_gate_up_proj, which breaks load.
if ("mlp.experts." in name) and name not in params_dict:
if "mlp.experts" in name:
continue
name = name.replace(weight_name, param_name)
# Skip loading extra bias for GPTQ models.
if name.endswith(".bias") and name not in params_dict:
continue
if name not in params_dict:
continue
param = params_dict[name]
weight_loader = param.weight_loader
weight_loader(param, loaded_weight, shard_id)
break
else:
# Track if this is an expert weight to enable early skipping
is_expert_weight = False
for mapping in expert_params_mapping:
param_name, weight_name, expert_id, shard_id = mapping
if weight_name not in name:
continue
# Mark as expert weight regardless of whether we can process it
is_expert_weight = True
name = name.replace(weight_name, param_name)
if name not in params_dict:
# Expert weight not on this rank, will be skipped below
continue
param = params_dict[name]
weight_loader = param.weight_loader
weight_loader(
@@ -996,65 +986,43 @@ class Glm4MoeForCausalLM(DeepseekV2ForCausalLM):
)
break
else:
if is_expert_weight:
# This is an expert weight but not mapped to this rank, skip all remaining processing
continue
# Skip loading extra bias for GPTQ models.
if name.endswith(".bias") and name not in params_dict:
continue
if fuse_qkv_a_proj and (
"q_a_proj" in name or "kv_a_proj_with_mqa" in name
):
cached_a_proj[name] = loaded_weight
q_a_proj_name = (
name
if "q_a_proj" in name
else name.replace("kv_a_proj_with_mqa", "q_a_proj")
)
kv_a_proj_name = (
name
if "kv_a_proj_with_mqa" in name
else name.replace("q_a_proj", "kv_a_proj_with_mqa")
)
if name not in params_dict:
continue
# When both q_a_proj and kv_a_proj_with_mqa has been cached, load the fused weight to parameter
if (
q_a_proj_name in cached_a_proj
and kv_a_proj_name in cached_a_proj
):
q_a_proj_weight = cached_a_proj[q_a_proj_name]
kv_a_proj_weight = cached_a_proj[kv_a_proj_name]
fused_weight = torch.cat(
[q_a_proj_weight, kv_a_proj_weight], dim=0
)
param_name = (
name.replace("q_a_proj", "fused_qkv_a_proj_with_mqa")
if "q_a_proj" in name
else name.replace(
"kv_a_proj_with_mqa", "fused_qkv_a_proj_with_mqa"
)
)
param = params_dict[param_name]
weight_loader = getattr(
param, "weight_loader", default_weight_loader
)
weight_loader(param, fused_weight)
cached_a_proj.pop(q_a_proj_name)
cached_a_proj.pop(kv_a_proj_name)
else:
if (
"k_scale" in name or "v_scale" in name
) and name not in params_dict:
# modelopt attn kv scale is named differently
if any(scale in name for scale in ["k_scale", "v_scale"]):
name = name.replace("_proj", "attn_mqa")
else:
logger.warning(
f"Unknown scale found in checkpoint: {name}"
)
if name in params_dict.keys():
param = params_dict[name]
weight_loader = getattr(
param, "weight_loader", default_weight_loader
)
weight_loader(param, loaded_weight)
else:
logger.warning(f"Parameter {name} not found in params_dict")
def get_embed_and_head(self):
return self.model.embed_tokens.weight, self.lm_head.weight
def set_embed_and_head(self, embed, head):
del self.model.embed_tokens.weight
del self.lm_head.weight
self.model.embed_tokens.weight = embed
self.lm_head.weight = head
torch.cuda.empty_cache()
torch.cuda.synchronize()
@classmethod
def get_model_config_for_expert_location(cls, config):
return ModelConfigForExpertLocation(
num_layers=config.num_hidden_layers,
num_logical_experts=config.n_routed_experts,
num_groups=config.n_group,
)
EntryClass = [Glm4MoeForCausalLM]
@@ -12,7 +12,8 @@
# limitations under the License.
# ==============================================================================
"""Inference-only GLM-4.5, GLM-4.6 NextN Speculative Decoding."""
"""Inference-only GLM-4.5, GLM-4.6 Speculative Decoding."""
import logging
from typing import Iterable, Optional, Tuple
@@ -33,7 +34,7 @@ from sglang.srt.layers.vocab_parallel_embedding import (
from sglang.srt.model_executor.forward_batch_info import ForwardBatch
from sglang.srt.models.glm4_moe import Glm4MoeDecoderLayer, Glm4MoeForCausalLM
from sglang.srt.server_args import get_global_server_args
from sglang.srt.utils import BumpAllocator, add_prefix
from sglang.srt.utils import add_prefix
logger = logging.getLogger(__name__)
@@ -84,14 +85,6 @@ class Glm4MoeModelNextN(nn.Module):
forward_batch: ForwardBatch,
input_embeds: torch.Tensor = None,
) -> torch.Tensor:
zero_allocator = BumpAllocator(
buffer_size=2,
dtype=torch.float32,
device=(
input_embeds.device if input_embeds is not None else input_ids.device
),
)
if input_embeds is None:
hidden_states = self.embed_tokens(input_ids)
else:
@@ -111,7 +104,7 @@ class Glm4MoeModelNextN(nn.Module):
residual = None
with get_global_expert_distribution_recorder().disable_this_region():
hidden_states, residual = self.decoder(
positions, hidden_states, forward_batch, residual, zero_allocator
positions, hidden_states, forward_batch, residual
)
if not forward_batch.forward_mode.is_idle():
@@ -124,7 +117,6 @@ class Glm4MoeModelNextN(nn.Module):
class Glm4MoeForCausalLMNextN(Glm4MoeForCausalLM):
def __init__(
self,
config: PretrainedConfig,
@@ -135,8 +127,6 @@ class Glm4MoeForCausalLMNextN(Glm4MoeForCausalLM):
self.config = config
self.tp_size = get_tensor_model_parallel_world_size()
self.quant_config = quant_config
self.determine_num_fused_shared_experts("Glm4MoeForCausalLMNextN")
self.model = Glm4MoeModelNextN(
config, quant_config, prefix=add_prefix("model", prefix)
)
@@ -6,13 +6,10 @@ import torch
import torch.nn as nn
from transformers.models.glm4v_moe.configuration_glm4v_moe import Glm4vMoeConfig
from sglang.srt.distributed import (
get_moe_expert_parallel_world_size,
get_tensor_model_parallel_world_size,
)
from sglang.srt.distributed import get_tensor_model_parallel_world_size
from sglang.srt.layers.attention import vision_utils
from sglang.srt.layers.logits_processor import LogitsProcessor
from sglang.srt.layers.moe.fused_moe_triton import FusedMoE
from sglang.srt.layers.moe.fused_moe_triton.layer import FusedMoE
from sglang.srt.layers.pooler import Pooler, PoolingType
from sglang.srt.layers.quantization.base_config import QuantizationConfig
from sglang.srt.layers.vocab_parallel_embedding import ParallelLMHead
@@ -20,7 +17,7 @@ from sglang.srt.model_loader.weight_utils import default_weight_loader
from sglang.srt.models.glm4_moe import Glm4MoeModel
from sglang.srt.models.glm4v import Glm4vForConditionalGeneration, Glm4vVisionModel
from sglang.srt.server_args import get_global_server_args
from sglang.srt.utils import add_prefix, is_cuda, log_info_on_rank0
from sglang.srt.utils import add_prefix, is_cuda
from sglang.srt.utils.hf_transformers_utils import get_processor
_is_cuda = is_cuda()
@@ -39,12 +36,10 @@ class Glm4vMoeForConditionalGeneration(Glm4vForConditionalGeneration):
) -> None:
nn.Module.__init__(self)
config.moe_layer_freq = 1
self.config = config
vision_utils.update_vit_attn_dummy_heads_config(self.config)
self.tp_size = get_tensor_model_parallel_world_size()
self.quant_config = quant_config
self.determine_num_fused_shared_experts("Glm4MoeForCausalLM")
self.num_fused_shared_experts = (
0
if get_global_server_args().disable_shared_experts_fusion
@@ -77,38 +72,7 @@ class Glm4vMoeForConditionalGeneration(Glm4vForConditionalGeneration):
# For EAGLE3 support
self.capture_aux_hidden_states = False
def determine_num_fused_shared_experts(
self, architecture: str = "Glm4MoeForCausalLM"
):
self.num_fused_shared_experts = 0
if get_global_server_args().disable_shared_experts_fusion:
return
# Only Deepseek V3/R1 can use shared experts fusion optimization now.
disable_reason = None
if (
not _is_cuda
or torch.cuda.get_device_capability("cuda") < (8, 0)
or self.config.architectures[0] != architecture
or self.config.n_shared_experts != 1
):
disable_reason = "Only GLM-4.5 on NV-platform with capability >= 80 can use shared experts fusion optimization."
elif get_moe_expert_parallel_world_size() > 1:
disable_reason = "Deepseek and GLM-4.5 can not use shared experts fusion optimization under expert parallelism."
if disable_reason is not None:
get_global_server_args().disable_shared_experts_fusion = True
self.num_fused_shared_experts = 0
log_info_on_rank0(
logger,
f"{disable_reason} Shared experts fusion optimization is disabled.",
)
return
self.num_fused_shared_experts = self.config.n_shared_experts
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]], is_nextn=False):
if is_nextn:
if hasattr(self.config, "num_nextn_predict_layers"):
num_nextn_layers = self.config.num_nextn_predict_layers
@@ -130,117 +94,14 @@ class Glm4vMoeForConditionalGeneration(Glm4vForConditionalGeneration):
("gate_up_proj", "gate_proj", 0),
("gate_up_proj", "up_proj", 1),
]
if self.num_fused_shared_experts > 0:
assert self.num_fused_shared_experts == 1
weights_list = list(weights)
weights_dict = dict(weights_list)
if self.quant_config is not None:
if self.quant_config.get_name() == "w8a8_int8":
suffix_list = [
"down_proj.weight",
"down_proj.weight_scale",
"gate_proj.weight",
"gate_proj.weight_scale",
"up_proj.weight",
"up_proj.weight_scale",
]
elif (
self.quant_config.get_name() == "fp8"
or self.quant_config.get_name() == "blockwise_int8"
or self.quant_config.get_name() == "compressed_tensors"
):
suffix_list = [
"down_proj.weight",
"down_proj.weight_scale",
"gate_proj.weight",
"gate_proj.weight_scale",
"up_proj.weight",
"up_proj.weight_scale",
]
elif self.quant_config.get_name() == "awq":
suffix_list = [
"down_proj.qweight",
"down_proj.qzeros",
"down_proj.scales",
"gate_proj.qweight",
"gate_proj.qzeros",
"gate_proj.scales",
"up_proj.qweight",
"up_proj.qzeros",
"up_proj.scales",
]
elif self.quant_config.get_name() == "modelopt_fp4":
suffix_list = [
"down_proj.weight",
"down_proj.weight_scale",
"down_proj.weight_scale_2",
"down_proj.input_scale",
"gate_proj.weight",
"gate_proj.weight_scale",
"gate_proj.weight_scale_2",
"gate_proj.input_scale",
"up_proj.weight",
"up_proj.weight_scale",
"up_proj.weight_scale_2",
"up_proj.input_scale",
]
else:
raise ValueError(
f"Unsupported shared expert fusion for quantization: {self.quant_config.get_name()}."
)
else:
suffix_list = [
"down_proj.weight",
"gate_proj.weight",
"up_proj.weight",
]
names_to_remove = []
moe_layers = (
range(
self.config.first_k_dense_replace,
self.config.num_hidden_layers,
self.config.moe_layer_freq,
)
if not is_nextn
else [nextn_layer_id]
)
for moe_layer in moe_layers:
for suffix in suffix_list:
shared_expert_weight_name = (
f"model.layers.{moe_layer}.mlp.shared_experts.{suffix}"
)
# online fp8 quantization does not load weight_scale
if shared_expert_weight_name not in weights_dict:
continue
weights_list.append(
(
f"model.layers.{moe_layer}."
f"mlp.experts."
f"{self.config.n_routed_experts + 0}"
f".{suffix}",
weights_dict[shared_expert_weight_name],
)
)
names_to_remove += [shared_expert_weight_name]
weights = [w for w in weights_list if w[0] not in names_to_remove]
# Params for weights, fp8 weight scales, fp8 activation scales
# (param_name, weight_name, expert_id, shard_id)
expert_params_mapping = FusedMoE.make_expert_params_mapping(
ckpt_gate_proj_name="gate_proj",
ckpt_down_proj_name="down_proj",
ckpt_up_proj_name="up_proj",
num_experts=self.config.n_routed_experts + self.num_fused_shared_experts,
num_experts=self.config.n_routed_experts,
)
# Fuse q_a_proj and kv_a_proj_with_mqa along output dimension when q_lora_rank is not None
fuse_qkv_a_proj = hasattr(self.config, "q_lora_rank") and (
self.config.q_lora_rank is not None
)
cached_a_proj = {} if fuse_qkv_a_proj else None
if is_nextn:
nextn_layer_prefix = f"model.layers.{nextn_layer_id}"
nextn_spec_weight_names = [
@@ -300,23 +161,36 @@ class Glm4vMoeForConditionalGeneration(Glm4vForConditionalGeneration):
# name will be updated to mlp.experts[0].gate_up_proj, which
# will then be updated below in expert_params_mapping
# for mlp.experts[0].gate_gate_up_proj, which breaks load.
if ("mlp.experts." in name) and name not in params_dict:
if "mlp.experts" in name:
continue
name = name.replace(weight_name, param_name)
# Skip loading extra bias for GPTQ models.
if name.endswith(".bias") and name not in params_dict:
continue
param = params_dict[name]
if name not in params_dict:
continue
param = params_dict[name]
weight_loader = param.weight_loader
weight_loader(param, loaded_weight, shard_id)
break
else:
# Track if this is an expert weight to enable early skipping
is_expert_weight = False
for mapping in expert_params_mapping:
param_name, weight_name, expert_id, shard_id = mapping
if weight_name not in name:
continue
# Mark as expert weight regardless of whether we can process it
is_expert_weight = True
name = name.replace(weight_name, param_name)
if name not in params_dict:
# Expert weight not on this rank, will be skipped below
continue
param = params_dict[name]
weight_loader = param.weight_loader
weight_loader(
@@ -328,64 +202,21 @@ class Glm4vMoeForConditionalGeneration(Glm4vForConditionalGeneration):
)
break
else:
if is_expert_weight:
# This is an expert weight but not mapped to this rank, skip all remaining processing
continue
if "visual" in name:
# adapt to VisionAttention
# adapt to VisionAttention for GLM-V
name = name.replace(r"attn.qkv.", r"attn.qkv_proj.")
# Skip loading extra bias for GPTQ models.
if name.endswith(".bias") and name not in params_dict:
continue
if fuse_qkv_a_proj and (
"q_a_proj" in name or "kv_a_proj_with_mqa" in name
):
cached_a_proj[name] = loaded_weight
q_a_proj_name = (
name
if "q_a_proj" in name
else name.replace("kv_a_proj_with_mqa", "q_a_proj")
)
kv_a_proj_name = (
name
if "kv_a_proj_with_mqa" in name
else name.replace("q_a_proj", "kv_a_proj_with_mqa")
)
# When both q_a_proj and kv_a_proj_with_mqa has been cached, load the fused weight to parameter
if (
q_a_proj_name in cached_a_proj
and kv_a_proj_name in cached_a_proj
):
q_a_proj_weight = cached_a_proj[q_a_proj_name]
kv_a_proj_weight = cached_a_proj[kv_a_proj_name]
fused_weight = torch.cat(
[q_a_proj_weight, kv_a_proj_weight], dim=0
)
param_name = (
name.replace("q_a_proj", "fused_qkv_a_proj_with_mqa")
if "q_a_proj" in name
else name.replace(
"kv_a_proj_with_mqa", "fused_qkv_a_proj_with_mqa"
)
)
param = params_dict[param_name]
if name not in params_dict:
continue
weight_loader = getattr(
param, "weight_loader", default_weight_loader
)
weight_loader(param, fused_weight)
cached_a_proj.pop(q_a_proj_name)
cached_a_proj.pop(kv_a_proj_name)
else:
if (
"k_scale" in name or "v_scale" in name
) and name not in params_dict:
# modelopt attn kv scale is named differently
if any(scale in name for scale in ["k_scale", "v_scale"]):
name = name.replace("_proj", "attn_mqa")
else:
logger.warning(
f"Unknown scale found in checkpoint: {name}"
)
if name in params_dict.keys():
param = params_dict[name]
weight_loader = getattr(
param, "weight_loader", default_weight_loader
@@ -395,6 +226,8 @@ class Glm4vMoeForConditionalGeneration(Glm4vForConditionalGeneration):
self.config, name, loaded_weight
)
weight_loader(param, loaded_weight)
else:
logger.warning(f"Parameter {name} not found in params_dict")
EntryClass = [Glm4vMoeForConditionalGeneration]
@@ -17,7 +17,7 @@ class Glm4vImageProcessor(SGLangBaseProcessor):
def __init__(self, hf_config, server_args, _processor, *args, **kwargs):
super().__init__(hf_config, server_args, _processor, *args, **kwargs)
# GLM-4.1V and GLM-4.5V specific tokens
# GLM-V specific tokens
self.IMAGE_TOKEN = "<|image|>"
self.VIDEO_TOKEN = "<|video|>"
self.IMAGE_START_TOKEN = "<|begin_of_image|>"