Unverified Commit 4060ed37 authored by Yuxuan Zhang's avatar Yuxuan Zhang Committed by GitHub
Browse files

Refactoring GLM-4.5 and GLM-4.5V related implementations (#11800)

parent 2342605e
This diff is collapsed.
......@@ -12,7 +12,8 @@
# limitations under the License.
# ==============================================================================
"""Inference-only GLM-4.5, GLM-4.6 NextN Speculative Decoding."""
"""Inference-only GLM-4.5, GLM-4.6 Speculative Decoding."""
import logging
from typing import Iterable, Optional, Tuple
......@@ -33,7 +34,7 @@ from sglang.srt.layers.vocab_parallel_embedding import (
from sglang.srt.model_executor.forward_batch_info import ForwardBatch
from sglang.srt.models.glm4_moe import Glm4MoeDecoderLayer, Glm4MoeForCausalLM
from sglang.srt.server_args import get_global_server_args
from sglang.srt.utils import BumpAllocator, add_prefix
from sglang.srt.utils import add_prefix
logger = logging.getLogger(__name__)
......@@ -84,14 +85,6 @@ class Glm4MoeModelNextN(nn.Module):
forward_batch: ForwardBatch,
input_embeds: torch.Tensor = None,
) -> torch.Tensor:
zero_allocator = BumpAllocator(
buffer_size=2,
dtype=torch.float32,
device=(
input_embeds.device if input_embeds is not None else input_ids.device
),
)
if input_embeds is None:
hidden_states = self.embed_tokens(input_ids)
else:
......@@ -111,7 +104,7 @@ class Glm4MoeModelNextN(nn.Module):
residual = None
with get_global_expert_distribution_recorder().disable_this_region():
hidden_states, residual = self.decoder(
positions, hidden_states, forward_batch, residual, zero_allocator
positions, hidden_states, forward_batch, residual
)
if not forward_batch.forward_mode.is_idle():
......@@ -124,7 +117,6 @@ class Glm4MoeModelNextN(nn.Module):
class Glm4MoeForCausalLMNextN(Glm4MoeForCausalLM):
def __init__(
self,
config: PretrainedConfig,
......@@ -135,8 +127,6 @@ class Glm4MoeForCausalLMNextN(Glm4MoeForCausalLM):
self.config = config
self.tp_size = get_tensor_model_parallel_world_size()
self.quant_config = quant_config
self.determine_num_fused_shared_experts("Glm4MoeForCausalLMNextN")
self.model = Glm4MoeModelNextN(
config, quant_config, prefix=add_prefix("model", prefix)
)
......
......@@ -6,13 +6,10 @@ import torch
import torch.nn as nn
from transformers.models.glm4v_moe.configuration_glm4v_moe import Glm4vMoeConfig
from sglang.srt.distributed import (
get_moe_expert_parallel_world_size,
get_tensor_model_parallel_world_size,
)
from sglang.srt.distributed import get_tensor_model_parallel_world_size
from sglang.srt.layers.attention import vision_utils
from sglang.srt.layers.logits_processor import LogitsProcessor
from sglang.srt.layers.moe.fused_moe_triton import FusedMoE
from sglang.srt.layers.moe.fused_moe_triton.layer import FusedMoE
from sglang.srt.layers.pooler import Pooler, PoolingType
from sglang.srt.layers.quantization.base_config import QuantizationConfig
from sglang.srt.layers.vocab_parallel_embedding import ParallelLMHead
......@@ -20,7 +17,7 @@ from sglang.srt.model_loader.weight_utils import default_weight_loader
from sglang.srt.models.glm4_moe import Glm4MoeModel
from sglang.srt.models.glm4v import Glm4vForConditionalGeneration, Glm4vVisionModel
from sglang.srt.server_args import get_global_server_args
from sglang.srt.utils import add_prefix, is_cuda, log_info_on_rank0
from sglang.srt.utils import add_prefix, is_cuda
from sglang.srt.utils.hf_transformers_utils import get_processor
_is_cuda = is_cuda()
......@@ -39,12 +36,10 @@ class Glm4vMoeForConditionalGeneration(Glm4vForConditionalGeneration):
) -> None:
nn.Module.__init__(self)
config.moe_layer_freq = 1
self.config = config
vision_utils.update_vit_attn_dummy_heads_config(self.config)
self.tp_size = get_tensor_model_parallel_world_size()
self.quant_config = quant_config
self.determine_num_fused_shared_experts("Glm4MoeForCausalLM")
self.num_fused_shared_experts = (
0
if get_global_server_args().disable_shared_experts_fusion
......@@ -77,38 +72,7 @@ class Glm4vMoeForConditionalGeneration(Glm4vForConditionalGeneration):
# For EAGLE3 support
self.capture_aux_hidden_states = False
def determine_num_fused_shared_experts(
self, architecture: str = "Glm4MoeForCausalLM"
):
self.num_fused_shared_experts = 0
if get_global_server_args().disable_shared_experts_fusion:
return
# Only Deepseek V3/R1 can use shared experts fusion optimization now.
disable_reason = None
if (
not _is_cuda
or torch.cuda.get_device_capability("cuda") < (8, 0)
or self.config.architectures[0] != architecture
or self.config.n_shared_experts != 1
):
disable_reason = "Only GLM-4.5 on NV-platform with capability >= 80 can use shared experts fusion optimization."
elif get_moe_expert_parallel_world_size() > 1:
disable_reason = "Deepseek and GLM-4.5 can not use shared experts fusion optimization under expert parallelism."
if disable_reason is not None:
get_global_server_args().disable_shared_experts_fusion = True
self.num_fused_shared_experts = 0
log_info_on_rank0(
logger,
f"{disable_reason} Shared experts fusion optimization is disabled.",
)
return
self.num_fused_shared_experts = self.config.n_shared_experts
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]], is_nextn=False):
if is_nextn:
if hasattr(self.config, "num_nextn_predict_layers"):
num_nextn_layers = self.config.num_nextn_predict_layers
......@@ -130,116 +94,13 @@ class Glm4vMoeForConditionalGeneration(Glm4vForConditionalGeneration):
("gate_up_proj", "gate_proj", 0),
("gate_up_proj", "up_proj", 1),
]
if self.num_fused_shared_experts > 0:
assert self.num_fused_shared_experts == 1
weights_list = list(weights)
weights_dict = dict(weights_list)
if self.quant_config is not None:
if self.quant_config.get_name() == "w8a8_int8":
suffix_list = [
"down_proj.weight",
"down_proj.weight_scale",
"gate_proj.weight",
"gate_proj.weight_scale",
"up_proj.weight",
"up_proj.weight_scale",
]
elif (
self.quant_config.get_name() == "fp8"
or self.quant_config.get_name() == "blockwise_int8"
or self.quant_config.get_name() == "compressed_tensors"
):
suffix_list = [
"down_proj.weight",
"down_proj.weight_scale",
"gate_proj.weight",
"gate_proj.weight_scale",
"up_proj.weight",
"up_proj.weight_scale",
]
elif self.quant_config.get_name() == "awq":
suffix_list = [
"down_proj.qweight",
"down_proj.qzeros",
"down_proj.scales",
"gate_proj.qweight",
"gate_proj.qzeros",
"gate_proj.scales",
"up_proj.qweight",
"up_proj.qzeros",
"up_proj.scales",
]
elif self.quant_config.get_name() == "modelopt_fp4":
suffix_list = [
"down_proj.weight",
"down_proj.weight_scale",
"down_proj.weight_scale_2",
"down_proj.input_scale",
"gate_proj.weight",
"gate_proj.weight_scale",
"gate_proj.weight_scale_2",
"gate_proj.input_scale",
"up_proj.weight",
"up_proj.weight_scale",
"up_proj.weight_scale_2",
"up_proj.input_scale",
]
else:
raise ValueError(
f"Unsupported shared expert fusion for quantization: {self.quant_config.get_name()}."
)
else:
suffix_list = [
"down_proj.weight",
"gate_proj.weight",
"up_proj.weight",
]
names_to_remove = []
moe_layers = (
range(
self.config.first_k_dense_replace,
self.config.num_hidden_layers,
self.config.moe_layer_freq,
)
if not is_nextn
else [nextn_layer_id]
)
for moe_layer in moe_layers:
for suffix in suffix_list:
shared_expert_weight_name = (
f"model.layers.{moe_layer}.mlp.shared_experts.{suffix}"
)
# online fp8 quantization does not load weight_scale
if shared_expert_weight_name not in weights_dict:
continue
weights_list.append(
(
f"model.layers.{moe_layer}."
f"mlp.experts."
f"{self.config.n_routed_experts + 0}"
f".{suffix}",
weights_dict[shared_expert_weight_name],
)
)
names_to_remove += [shared_expert_weight_name]
weights = [w for w in weights_list if w[0] not in names_to_remove]
# Params for weights, fp8 weight scales, fp8 activation scales
# (param_name, weight_name, expert_id, shard_id)
expert_params_mapping = FusedMoE.make_expert_params_mapping(
ckpt_gate_proj_name="gate_proj",
ckpt_down_proj_name="down_proj",
ckpt_up_proj_name="up_proj",
num_experts=self.config.n_routed_experts + self.num_fused_shared_experts,
)
# Fuse q_a_proj and kv_a_proj_with_mqa along output dimension when q_lora_rank is not None
fuse_qkv_a_proj = hasattr(self.config, "q_lora_rank") and (
self.config.q_lora_rank is not None
num_experts=self.config.n_routed_experts,
)
cached_a_proj = {} if fuse_qkv_a_proj else None
if is_nextn:
nextn_layer_prefix = f"model.layers.{nextn_layer_id}"
......@@ -300,23 +161,36 @@ class Glm4vMoeForConditionalGeneration(Glm4vForConditionalGeneration):
# name will be updated to mlp.experts[0].gate_up_proj, which
# will then be updated below in expert_params_mapping
# for mlp.experts[0].gate_gate_up_proj, which breaks load.
if ("mlp.experts." in name) and name not in params_dict:
if "mlp.experts" in name:
continue
name = name.replace(weight_name, param_name)
# Skip loading extra bias for GPTQ models.
if name.endswith(".bias") and name not in params_dict:
continue
param = params_dict[name]
if name not in params_dict:
continue
param = params_dict[name]
weight_loader = param.weight_loader
weight_loader(param, loaded_weight, shard_id)
break
else:
# Track if this is an expert weight to enable early skipping
is_expert_weight = False
for mapping in expert_params_mapping:
param_name, weight_name, expert_id, shard_id = mapping
if weight_name not in name:
continue
# Mark as expert weight regardless of whether we can process it
is_expert_weight = True
name = name.replace(weight_name, param_name)
if name not in params_dict:
# Expert weight not on this rank, will be skipped below
continue
param = params_dict[name]
weight_loader = param.weight_loader
weight_loader(
......@@ -328,64 +202,21 @@ class Glm4vMoeForConditionalGeneration(Glm4vForConditionalGeneration):
)
break
else:
if is_expert_weight:
# This is an expert weight but not mapped to this rank, skip all remaining processing
continue
if "visual" in name:
# adapt to VisionAttention
# adapt to VisionAttention for GLM-V
name = name.replace(r"attn.qkv.", r"attn.qkv_proj.")
# Skip loading extra bias for GPTQ models.
if name.endswith(".bias") and name not in params_dict:
continue
if fuse_qkv_a_proj and (
"q_a_proj" in name or "kv_a_proj_with_mqa" in name
):
cached_a_proj[name] = loaded_weight
q_a_proj_name = (
name
if "q_a_proj" in name
else name.replace("kv_a_proj_with_mqa", "q_a_proj")
)
kv_a_proj_name = (
name
if "kv_a_proj_with_mqa" in name
else name.replace("q_a_proj", "kv_a_proj_with_mqa")
)
# When both q_a_proj and kv_a_proj_with_mqa has been cached, load the fused weight to parameter
if (
q_a_proj_name in cached_a_proj
and kv_a_proj_name in cached_a_proj
):
q_a_proj_weight = cached_a_proj[q_a_proj_name]
kv_a_proj_weight = cached_a_proj[kv_a_proj_name]
fused_weight = torch.cat(
[q_a_proj_weight, kv_a_proj_weight], dim=0
)
param_name = (
name.replace("q_a_proj", "fused_qkv_a_proj_with_mqa")
if "q_a_proj" in name
else name.replace(
"kv_a_proj_with_mqa", "fused_qkv_a_proj_with_mqa"
)
)
param = params_dict[param_name]
if name not in params_dict:
continue
weight_loader = getattr(
param, "weight_loader", default_weight_loader
)
weight_loader(param, fused_weight)
cached_a_proj.pop(q_a_proj_name)
cached_a_proj.pop(kv_a_proj_name)
else:
if (
"k_scale" in name or "v_scale" in name
) and name not in params_dict:
# modelopt attn kv scale is named differently
if any(scale in name for scale in ["k_scale", "v_scale"]):
name = name.replace("_proj", "attn_mqa")
else:
logger.warning(
f"Unknown scale found in checkpoint: {name}"
)
if name in params_dict.keys():
param = params_dict[name]
weight_loader = getattr(
param, "weight_loader", default_weight_loader
......@@ -395,6 +226,8 @@ class Glm4vMoeForConditionalGeneration(Glm4vForConditionalGeneration):
self.config, name, loaded_weight
)
weight_loader(param, loaded_weight)
else:
logger.warning(f"Parameter {name} not found in params_dict")
EntryClass = [Glm4vMoeForConditionalGeneration]
......@@ -17,7 +17,7 @@ class Glm4vImageProcessor(SGLangBaseProcessor):
def __init__(self, hf_config, server_args, _processor, *args, **kwargs):
super().__init__(hf_config, server_args, _processor, *args, **kwargs)
# GLM-4.1V and GLM-4.5V specific tokens
# GLM-V specific tokens
self.IMAGE_TOKEN = "<|image|>"
self.VIDEO_TOKEN = "<|video|>"
self.IMAGE_START_TOKEN = "<|begin_of_image|>"
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment