"vscode:/vscode.git/clone" did not exist on "2c8b9182b5ced00d83bed15ef8bc0ac6e079b6ee"
Unverified Commit bb17e8f1 authored by Yuxuan Zhang's avatar Yuxuan Zhang Committed by GitHub
Browse files

[GLM-OCR] GLM-OCR with MTP Support (#33005)


Signed-off-by: default avatarzRzRzRzRzRzRzR <2448370773@qq.com>
Signed-off-by: default avatarIsotr0py <mozf@mail2.sysu.edu.cn>
Co-authored-by: default avatarIsotr0py <mozf@mail2.sysu.edu.cn>
parent dcd80206
......@@ -674,6 +674,7 @@ These models primarily accept the [`LLM.generate`](./generative_models.md#llmgen
| `GLM4VForCausalLM`<sup>^</sup> | GLM-4V | T + I | `zai-org/glm-4v-9b`, `zai-org/cogagent-9b-20241220`, etc. | ✅︎ | ✅︎ |
| `Glm4vForConditionalGeneration` | GLM-4.1V-Thinking | T + I<sup>E+</sup> + V<sup>E+</sup> | `zai-org/GLM-4.1V-9B-Thinking`, etc. | ✅︎ | ✅︎ |
| `Glm4vMoeForConditionalGeneration` | GLM-4.5V | T + I<sup>E+</sup> + V<sup>E+</sup> | `zai-org/GLM-4.5V`, etc. | ✅︎ | ✅︎ |
| `GlmOcrForConditionalGeneration` | GLM-OCR | T + I<sup>E+</sup> | `zai-org/GLM-OCR`, etc. | ✅︎ | ✅︎ |
| `GraniteSpeechForConditionalGeneration` | Granite Speech | T + A | `ibm-granite/granite-speech-3.3-8b` | ✅︎ | ✅︎ |
| `H2OVLChatModel` | H2OVL | T + I<sup>E+</sup> | `h2oai/h2ovl-mississippi-800m`, `h2oai/h2ovl-mississippi-2b`, etc. | | ✅︎ |
| `HunYuanVLForConditionalGeneration` | HunyuanOCR | T + I<sup>E+</sup> | `tencent/HunyuanOCR`, etc. | ✅︎ | ✅︎ |
......
......@@ -566,6 +566,42 @@ def run_glm4_5v_fp8(questions: list[str], modality: str) -> ModelRequestData:
)
# GLM-OCR
def run_glm_ocr(questions: list[str], modality: str) -> ModelRequestData:
model_name = "zai-org/GLM-OCR"
engine_args = EngineArgs(
model=model_name,
max_model_len=4096,
max_num_seqs=2,
mm_processor_kwargs={
"size": {"shortest_edge": 12544, "longest_edge": 47040000},
"fps": 1,
},
limit_mm_per_prompt={modality: 1},
enforce_eager=True,
)
if modality == "image":
placeholder = "<|begin_of_image|><|image|><|end_of_image|>"
elif modality == "video":
placeholder = "<|begin_of_video|><|video|><|end_of_video|>"
prompts = [
(
"[gMASK]<sop><|system|>\nYou are a helpful assistant.<|user|>\n"
f"{placeholder}"
f"{question}<|assistant|>assistant\n"
)
for question in questions
]
return ModelRequestData(
engine_args=engine_args,
prompts=prompts,
)
# H2OVL-Mississippi
def run_h2ovl(questions: list[str], modality: str) -> ModelRequestData:
assert modality == "image"
......@@ -1988,6 +2024,7 @@ model_example_map = {
"glm4_1v": run_glm4_1v,
"glm4_5v": run_glm4_5v,
"glm4_5v_fp8": run_glm4_5v_fp8,
"glm_ocr": run_glm_ocr,
"h2ovl_chat": run_h2ovl,
"hunyuan_vl": run_hunyuan_vl,
"hyperclovax_seed_vision": run_hyperclovax_seed_vision,
......@@ -2040,6 +2077,7 @@ model_example_map = {
MODELS_NEED_VIDEO_METADATA = [
"glm4_1v",
"glm_ocr",
"glm4_5v",
"glm4_5v_fp8",
"molmo2",
......
......@@ -458,6 +458,20 @@ VLM_TEST_SETTINGS = {
],
marks=[large_gpu_mark(min_gb=32)],
),
"glm_ocr": VLMTestInfo(
models=["zai-org/GLM-OCR"],
test_type=(VLMTestType.IMAGE, VLMTestType.MULTI_IMAGE),
prompt_formatter=lambda img_prompt: f"[gMASK]<|user|>\n{img_prompt}<|assistant|>\n", # noqa: E501
img_idx_to_prompt=lambda idx: "<|begin_of_image|><|image|><|end_of_image|>",
video_idx_to_prompt=lambda idx: "<|begin_of_video|><|video|><|end_of_video|>",
max_model_len=2048,
max_num_seqs=2,
get_stop_token_ids=lambda tok: [151329, 151336, 151338],
num_logprobs=10,
image_size_factors=[(), (0.25,), (0.25, 0.25, 0.25), (0.25, 0.2, 0.15)],
auto_cls=AutoModelForImageTextToText,
marks=[large_gpu_mark(min_gb=32)],
),
"h2ovl": VLMTestInfo(
models=[
"h2oai/h2ovl-mississippi-800m",
......
......@@ -91,6 +91,19 @@ MODEL_CONFIGS: dict[str, dict[str, Any]] = {
"use_processor": True,
"question": "What is the content of each image?",
},
"glm_ocr": {
"model_name": "zai-org/GLM-OCR",
"interface": "llm_generate",
"max_model_len": 131072,
"max_num_seqs": 2,
"sampling_params": {
"temperature": 0.0,
"max_tokens": 256,
"stop_token_ids": None,
},
"use_processor": True,
"question": "Text Recognition:",
},
"keye_vl": {
"model_name": "Kwai-Keye/Keye-VL-8B-Preview",
"interface": "llm_generate",
......
......@@ -122,6 +122,7 @@ MM_DATA_PATCHES = {
"ernie4_5_moe_vl": qwen3_vl_patch_mm_data,
"glm4v": glm4_1v_patch_mm_data,
"glm4v_moe": glm4_1v_patch_mm_data,
"glm_ocr": glm4_1v_patch_mm_data,
"glmasr": glmasr_patch_mm_data,
"molmo2": qwen3_vl_patch_mm_data,
"qwen3_vl": qwen3_vl_patch_mm_data,
......
......@@ -274,7 +274,6 @@ _TEXT_GENERATION_EXAMPLE_MODELS = {
"Glm4MoeLiteForCausalLM": _HfExamplesInfo(
"zai-org/GLM-4.7-Flash",
min_transformers_version="5.0.0.dev",
is_available_online=False,
),
"GPT2LMHeadModel": _HfExamplesInfo("openai-community/gpt2", {"alias": "gpt2"}),
"GPTBigCodeForCausalLM": _HfExamplesInfo(
......@@ -707,6 +706,11 @@ _MULTIMODAL_EXAMPLE_MODELS = {
),
"Glm4vForConditionalGeneration": _HfExamplesInfo("zai-org/GLM-4.1V-9B-Thinking"),
"Glm4vMoeForConditionalGeneration": _HfExamplesInfo("zai-org/GLM-4.5V"),
"GlmOcrForConditionalGeneration": _HfExamplesInfo(
"zai-org/GLM-OCR",
is_available_online=False,
min_transformers_version="5.0.0.dev",
),
"H2OVLChatModel": _HfExamplesInfo(
"h2oai/h2ovl-mississippi-800m",
trust_remote_code=True,
......@@ -1053,7 +1057,13 @@ _SPECULATIVE_DECODING_EXAMPLE_MODELS = {
"Glm4MoeLiteMTPModel": _HfExamplesInfo(
"zai-org/GLM-4.7-Flash",
speculative_model="zai-org/GLM-4.7-Flash",
min_transformers_version="5.0.0.dev",
),
"GlmOcrMTPModel": _HfExamplesInfo(
"zai-org/GLM-OCR",
speculative_model="zai-org/GLM-OCR",
is_available_online=False,
min_transformers_version="5.0.0.dev",
),
"LongCatFlashMTPModel": _HfExamplesInfo(
"meituan-longcat/LongCat-Flash-Chat",
......
......@@ -34,6 +34,7 @@ MTPModelTypes = Literal[
"mimo_mtp",
"glm4_moe_mtp",
"glm4_moe_lite_mtp",
"glm_ocr_mtp",
"ernie_mtp",
"exaone_moe_mtp",
"qwen3_next_mtp",
......@@ -221,6 +222,17 @@ class SpeculativeConfig:
}
)
if hf_config.architectures[0] == "GlmOcrForConditionalGeneration":
hf_config.model_type = "glm_ocr_mtp"
n_predict = getattr(hf_config, "num_nextn_predict_layers", None)
hf_config.update(
{
"num_hidden_layers": 0,
"n_predict": n_predict,
"architectures": ["GlmOcrMTPModel"],
}
)
if hf_config.model_type == "ernie4_5_moe":
hf_config.model_type = "ernie_mtp"
if hf_config.model_type == "ernie_mtp":
......
......@@ -39,13 +39,22 @@ from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.layers.rotary_embedding import get_rope
from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead
from vllm.model_executor.model_loader.weight_utils import (
default_weight_loader,
maybe_remap_kv_scale_name,
)
from vllm.sequence import IntermediateTensors
from vllm.v1.attention.backend import AttentionType
from .interfaces import SupportsLoRA, SupportsPP
from .llama import LlamaMLP as Glm4MLP
from .llama import LlamaModel
from .utils import AutoWeightsLoader, PPMissingLayer, maybe_prefix
from .utils import (
AutoWeightsLoader,
PPMissingLayer,
is_pp_missing_parameter,
maybe_prefix,
)
class Glm4Attention(nn.Module):
......@@ -78,7 +87,15 @@ class Glm4Attention(nn.Module):
# Number of KV heads is less than TP size, so we replicate
# the KV heads across multiple tensor parallel GPUs.
assert tp_size % self.total_num_kv_heads == 0
config.rope_parameters.setdefault("partial_rotary_factor", 0.5)
rope_params = getattr(config, "rope_parameters", None)
if isinstance(rope_params, dict) and "partial_rotary_factor" in rope_params:
config.rope_parameters.setdefault(
"partial_rotary_factor", rope_params["partial_rotary_factor"]
)
else:
config.rope_parameters.setdefault("partial_rotary_factor", 0.5)
self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size)
self.head_dim = head_dim or hidden_size // self.total_num_heads
self.q_size = self.num_heads * self.head_dim
......@@ -220,6 +237,73 @@ class Glm4Model(LlamaModel):
vllm_config=vllm_config, prefix=prefix, layer_type=Glm4DecoderLayer
)
def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
stacked_params_mapping = [
# (param_name, shard_name, shard_id)
(".qkv_proj", ".q_proj", "q"),
(".qkv_proj", ".k_proj", "k"),
(".qkv_proj", ".v_proj", "v"),
(".gate_up_proj", ".gate_proj", 0),
(".gate_up_proj", ".up_proj", 1),
]
params_dict = dict(self.named_parameters())
loaded_params: set[str] = set()
for name, loaded_weight in weights:
spec_layer = get_spec_layer_idx_from_weight_name(self.config, name)
if spec_layer is not None:
continue
if "rotary_emb.inv_freq" in name:
continue
if "rotary_emb.cos_cached" in name or "rotary_emb.sin_cached" in name:
# Models trained using ColossalAI may include these tensors in
# the checkpoint. Skip them.
continue
if self.quant_config is not None and (
scale_name := self.quant_config.get_cache_scale(name)
):
# Loading kv cache quantization scales
param = params_dict[scale_name]
weight_loader = getattr(param, "weight_loader", default_weight_loader)
loaded_weight = (
loaded_weight if loaded_weight.dim() == 0 else loaded_weight[0]
)
weight_loader(param, loaded_weight)
loaded_params.add(scale_name)
continue
if "scale" in name or "zero_point" in name:
# Remapping the name of FP8 kv-scale or zero point.
name = maybe_remap_kv_scale_name(name, params_dict)
if name is None:
continue
for param_name, weight_name, shard_id in stacked_params_mapping:
if weight_name not in name:
continue
name = name.replace(weight_name, param_name)
# Skip loading extra bias for GPTQ models.
if name.endswith(".bias") and name not in params_dict:
continue
if is_pp_missing_parameter(name, self):
continue
param = params_dict[name]
weight_loader = param.weight_loader
weight_loader(param, loaded_weight, shard_id)
break
else:
# Skip loading extra bias for GPTQ models.
if name.endswith(".bias") and name not in params_dict:
continue
if is_pp_missing_parameter(name, self):
continue
param = params_dict[name]
weight_loader = getattr(param, "weight_loader", default_weight_loader)
weight_loader(param, loaded_weight)
loaded_params.add(name)
return loaded_params
class Glm4ForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
packed_modules_mapping = {
......@@ -293,3 +377,16 @@ class Glm4ForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
skip_prefixes=(["lm_head."] if self.config.tie_word_embeddings else None),
)
return loader.load_weights(weights)
def get_spec_layer_idx_from_weight_name(
config: Glm4Config, weight_name: str
) -> int | None:
if hasattr(config, "num_nextn_predict_layers") and (
config.num_nextn_predict_layers > 0
):
layer_idx = config.num_hidden_layers
for i in range(config.num_nextn_predict_layers):
if f"layers.{layer_idx + i}." in weight_name:
return layer_idx + i
return None
......@@ -24,7 +24,8 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Inference-only GLM-4V model compatible with HuggingFace weights."""
"""Inference-only GLM-4.1V & GLM-4.6V-Flash, AutoGLM-Phone-9B model
compatible with HuggingFace weights."""
import itertools
import math
......@@ -1418,7 +1419,7 @@ class Glm4vForConditionalGeneration(
prefix=maybe_prefix(prefix, "visual"),
)
if config.model_type == "glm4v":
if config.model_type in ("glm4v", "glm_ocr"):
architectures = ["Glm4ForCausalLM"]
elif config.model_type == "glm4v_moe":
architectures = ["Glm4MoeForCausalLM"]
......
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
# Adapted from
# https://github.com/huggingface/transformers/blob/main/src/transformers/models/Glm4v/modeling_Glm4v.py
# Copyright 2026 The ZhipuAI Team.
# Copyright 2026 The vLLM team.
# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved.
# All rights reserved.
#
# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
# and OPT implementations in this library. It has been modified from its
# original forms to accommodate minor architectural differences compared
# to GPT-NeoX and OPT used by the Meta AI team that trained the model.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Inference-only GLM-OCR model compatible with HuggingFace weights."""
from collections.abc import Callable
from functools import partial
from typing import TYPE_CHECKING
import torch
import torch.nn as nn
from einops import rearrange
if TYPE_CHECKING:
from transformers.models.glm_ocr.configuration_glm_ocr import GlmOcrVisionConfig
from vllm.config import VllmConfig
from vllm.distributed import get_tensor_model_parallel_world_size, parallel_state
from vllm.distributed import utils as dist_utils
from vllm.logger import init_logger
from vllm.model_executor.layers.attention.mm_encoder_attention import (
MMEncoderAttention,
)
from vllm.model_executor.layers.conv import Conv2dLayer
from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.linear import (
QKVParallelLinear,
RowParallelLinear,
)
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.layers.rotary_embedding import get_rope
from vllm.model_executor.layers.rotary_embedding.common import (
ApplyRotaryEmb,
)
from vllm.model_executor.models.glm4_1v import (
Glm4vDummyInputsBuilder,
Glm4vForConditionalGeneration,
Glm4vMultiModalProcessor,
Glm4vPatchMerger,
Glm4vProcessingInfo,
Glm4vVisionBlock,
Glm4vVisionMLP,
Glm4vVisionPatchEmbed,
Glm4vVisionTransformer,
)
from vllm.multimodal import MULTIMODAL_REGISTRY
from .utils import (
maybe_prefix,
)
from .vision import (
get_vit_attn_backend,
is_vit_use_data_parallel,
)
logger = init_logger(__name__)
class GlmOcrVisionMLP(Glm4vVisionMLP):
pass
class GlmOcrVisionAttention(nn.Module):
def __init__(
self,
embed_dim: int,
num_heads: int,
projection_size: int,
quant_config: QuantizationConfig | None = None,
prefix: str = "",
) -> None:
super().__init__()
# Per attention head and per partition values.
use_data_parallel = is_vit_use_data_parallel()
self.tp_size = (
1 if use_data_parallel else get_tensor_model_parallel_world_size()
)
self.tp_rank = (
0 if use_data_parallel else parallel_state.get_tensor_model_parallel_rank()
)
self.hidden_size_per_attention_head = dist_utils.divide(
projection_size, num_heads
)
self.num_attention_heads_per_partition = dist_utils.divide(
num_heads, self.tp_size
)
self.head_dim = embed_dim // num_heads
self.q_norm = RMSNorm(self.head_dim, eps=1e-5)
self.k_norm = RMSNorm(self.head_dim, eps=1e-5)
self.qkv = QKVParallelLinear(
hidden_size=embed_dim,
head_size=self.hidden_size_per_attention_head,
total_num_heads=num_heads,
total_num_kv_heads=num_heads,
bias=True,
quant_config=quant_config,
# Change qkv prefix to align with GLM-4.5V-FP8 quantization cfg
prefix=f"{prefix}.qkv_proj" if quant_config else f"{prefix}.qkv",
disable_tp=use_data_parallel,
)
self.proj = RowParallelLinear(
input_size=projection_size,
output_size=embed_dim,
quant_config=quant_config,
prefix=f"{prefix}.proj",
bias=True,
disable_tp=use_data_parallel,
)
self.attn = MMEncoderAttention(
num_heads=self.num_attention_heads_per_partition,
head_size=self.hidden_size_per_attention_head,
scale=self.hidden_size_per_attention_head**-0.5,
)
self.apply_rotary_emb = ApplyRotaryEmb(enforce_enable=True)
def split_qkv(self, qkv: torch.Tensor) -> tuple[torch.Tensor, ...]:
# [s, b, 3 * head * head_dim]
seq_len, bs, _ = qkv.shape
# [s, b, 3 * head * head_dim] -> 3 * [s, b, head * head_dim]
q, k, v = qkv.chunk(3, dim=2)
# 3 * [s, b, head * head_dim] -> 3 * [s, b, head, head_dim]
new_shape = (
seq_len,
bs,
self.num_attention_heads_per_partition,
self.hidden_size_per_attention_head,
)
q, k, v = (x.view(*new_shape) for x in (q, k, v))
return q, k, v
def forward(
self,
x: torch.Tensor,
cu_seqlens: torch.Tensor,
rotary_pos_emb_cos: torch.Tensor,
rotary_pos_emb_sin: torch.Tensor,
max_seqlen: torch.Tensor | None = None, # Only used for Flash Attention
) -> torch.Tensor:
# [s, b, c] --> [s, b, head * 3 * head_dim]
x, _ = self.qkv(x)
# [s, b, 3 * head * head_dim] -> 3 * [s, b, head, head_dim]
q, k, v = self.split_qkv(x)
# RMSNorm on q, k
q_shape, k_shape = q.shape, k.shape
q = self.q_norm(q.reshape(-1, self.head_dim)).view(q_shape)
k = self.k_norm(k.reshape(-1, self.head_dim)).view(k_shape)
q, k, v = (rearrange(x, "s b ... -> b s ...").contiguous() for x in (q, k, v))
if rotary_pos_emb_cos is not None and rotary_pos_emb_sin is not None:
# [2 * b, s, heads, head_dim]
qk_concat = torch.cat([q, k], dim=0)
qk_rotated = self.apply_rotary_emb(
qk_concat,
rotary_pos_emb_cos,
rotary_pos_emb_sin,
)
q, k = torch.chunk(qk_rotated, 2, dim=0)
context_layer = self.attn(
query=q,
key=k,
value=v,
cu_seqlens=cu_seqlens,
max_seqlen=max_seqlen,
)
context_layer = rearrange(context_layer, "b s h d -> s b (h d)").contiguous()
output, _ = self.proj(context_layer)
return output
class GlmOcrVisionBlock(Glm4vVisionBlock):
def __init__(
self,
dim: int,
num_heads: int,
mlp_hidden_dim: int,
norm_layer: Callable[[int], nn.Module] | None = None,
quant_config: QuantizationConfig | None = None,
prefix: str = "",
) -> None:
super().__init__(
dim,
num_heads,
mlp_hidden_dim,
norm_layer,
quant_config,
prefix,
)
if norm_layer is None:
norm_layer = partial(nn.LayerNorm, eps=1e-6)
self.norm1 = norm_layer(dim)
self.norm2 = norm_layer(dim)
self.attn = GlmOcrVisionAttention(
embed_dim=dim,
num_heads=num_heads,
projection_size=dim,
quant_config=quant_config,
prefix=f"{prefix}.attn",
)
self.mlp = GlmOcrVisionMLP(
dim,
mlp_hidden_dim,
bias=True,
quant_config=quant_config,
prefix=f"{prefix}.mlp",
)
class GlmOcrVisionPatchEmbed(Glm4vVisionPatchEmbed):
pass
class GlmOcrPatchMerger(Glm4vPatchMerger):
pass
class GlmOcrVisionTransformer(Glm4vVisionTransformer):
def __init__(
self,
vision_config: GlmOcrVisionConfig,
norm_eps: float = 1e-5,
quant_config: QuantizationConfig | None = None,
prefix: str = "",
) -> None:
super().__init__(vision_config, norm_eps, quant_config, prefix)
del self.post_conv_layernorm
del self.embeddings
patch_size = vision_config.patch_size
temporal_patch_size = vision_config.temporal_patch_size
in_channels = vision_config.in_channels
depth = vision_config.depth
self.hidden_size = vision_config.hidden_size
self.num_heads = vision_config.num_heads
self.patch_size = vision_config.patch_size
self.spatial_merge_size = vision_config.spatial_merge_size
self.out_hidden_size = vision_config.out_hidden_size
self.patch_embed = Glm4vVisionPatchEmbed(
patch_size=patch_size,
temporal_patch_size=temporal_patch_size,
in_channels=in_channels,
hidden_size=self.hidden_size,
)
norm_layer = partial(RMSNorm, eps=norm_eps)
head_dim = self.hidden_size // self.num_heads
self.rotary_pos_emb = get_rope(
head_size=head_dim,
max_position=8192,
is_neox_style=True,
rope_parameters={"partial_rotary_factor": 0.5},
)
self.blocks = nn.ModuleList(
[
GlmOcrVisionBlock(
dim=self.hidden_size,
num_heads=self.num_heads,
mlp_hidden_dim=vision_config.intermediate_size,
norm_layer=norm_layer,
quant_config=quant_config,
prefix=f"{prefix}.blocks.{layer_idx}",
)
for layer_idx in range(depth)
]
)
self.merger = GlmOcrPatchMerger(
d_model=vision_config.out_hidden_size,
context_dim=vision_config.out_hidden_size * vision_config.in_channels,
quant_config=quant_config,
bias=False,
prefix=f"{prefix}.merger",
)
self.downsample = Conv2dLayer(
in_channels=vision_config.hidden_size,
out_channels=vision_config.out_hidden_size,
kernel_size=vision_config.spatial_merge_size,
stride=vision_config.spatial_merge_size,
)
self.post_layernorm = RMSNorm(
vision_config.hidden_size, eps=vision_config.rms_norm_eps
)
self.attn_backend = get_vit_attn_backend(
head_size=head_dim,
dtype=torch.get_default_dtype(),
)
def forward(
self,
x: torch.Tensor,
grid_thw: torch.Tensor | list[list[int]],
) -> torch.Tensor:
if isinstance(grid_thw, list):
grid_thw = torch.tensor(grid_thw, dtype=torch.int32)
# patchify
x = x.to(device=self.device, dtype=self.dtype)
x = self.patch_embed(x)
# compute position embedding
rotary_pos_emb_cos, rotary_pos_emb_sin, image_type_ids = self.rot_pos_emb(
grid_thw
)
# compute cu_seqlens
cu_seqlens = torch.repeat_interleave(
grid_thw[:, 1] * grid_thw[:, 2], grid_thw[:, 0]
).cumsum(dim=0, dtype=torch.int32)
cu_seqlens = torch.cat([cu_seqlens.new_zeros(1), cu_seqlens])
cu_seqlens = cu_seqlens.to(self.device, non_blocking=True)
# pre-compute max_seqlen for attn mask to reduce cuMemcpy operations
max_seqlen = self.compute_attn_mask_seqlen(cu_seqlens)
# transformers
x = x.unsqueeze(1)
for blk in self.blocks:
x = blk(
x,
cu_seqlens=cu_seqlens,
rotary_pos_emb_cos=rotary_pos_emb_cos,
rotary_pos_emb_sin=rotary_pos_emb_sin,
max_seqlen=max_seqlen,
)
# adapter
x = self.post_layernorm(x)
x = x.view(-1, self.spatial_merge_size, self.spatial_merge_size, x.shape[-1])
x = x.permute(0, 3, 1, 2)
x = self.downsample(x).view(-1, self.out_hidden_size)
x = self.merger(x)
return x
@MULTIMODAL_REGISTRY.register_processor(
Glm4vMultiModalProcessor,
info=Glm4vProcessingInfo,
dummy_inputs=Glm4vDummyInputsBuilder,
)
class GlmOcrForConditionalGeneration(Glm4vForConditionalGeneration):
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
super().__init__(vllm_config=vllm_config, prefix=prefix)
config = vllm_config.model_config.hf_config
quant_config = vllm_config.quant_config
with self._mark_tower_model(vllm_config, {"image", "video"}):
self.visual = GlmOcrVisionTransformer(
config.vision_config,
norm_eps=getattr(config, "rms_norm_eps", 1e-5),
quant_config=quant_config,
prefix=maybe_prefix(prefix, "visual"),
)
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
# Copyright 2026 The ZhipuAI Team.
# Copyright 2026 The vLLM team.
# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved.
#
# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
# and OPT implementations in this library. It has been modified from its
# original forms to accommodate minor architectural differences compared
# to GPT-NeoX and OPT used by the Meta AI team that trained the model.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Inference-only GLM-OCR MTP model compatible with HuggingFace weights."""
from collections.abc import Iterable
import torch
import torch.nn as nn
from vllm.config import VllmConfig
from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.vocab_parallel_embedding import (
VocabParallelEmbedding,
)
from vllm.model_executor.model_loader.weight_utils import (
default_weight_loader,
maybe_remap_kv_scale_name,
)
from vllm.platforms import current_platform
from vllm.sequence import IntermediateTensors
from .glm4 import Glm4DecoderLayer, get_spec_layer_idx_from_weight_name
from .glm4_moe_lite_mtp import (
Glm4MoeLiteMultiTokenPredictor,
SharedHead,
)
from .interfaces import SupportsPP
from .utils import (
is_pp_missing_parameter,
maybe_prefix,
)
class GlmOcrMultiTokenPredictorLayer(nn.Module):
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
nn.Module.__init__(self)
config = vllm_config.speculative_config.draft_model_config.hf_config.text_config
self.config = config
quant_config = vllm_config.quant_config
self.enorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
self.hnorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
self.eh_proj = nn.Linear(config.hidden_size * 2, config.hidden_size, bias=False)
self.device = current_platform.device_type
self.shared_head = SharedHead(
config=config, prefix=prefix, quant_config=quant_config
)
self.mtp_block = Glm4DecoderLayer(
vllm_config=vllm_config, prefix=prefix, config=self.config
)
def forward(
self,
input_ids: torch.Tensor,
positions: torch.Tensor,
previous_hidden_states: torch.Tensor,
inputs_embeds: torch.Tensor | None = None,
spec_step_index: int = 0,
) -> torch.Tensor:
assert inputs_embeds is not None
# masking inputs at position 0, as not needed by MTP
inputs_embeds[positions[0] == 0] = 0
inputs_embeds = self.enorm(inputs_embeds)
previous_hidden_states = self.hnorm(previous_hidden_states)
hidden_states = self.eh_proj(
torch.cat([inputs_embeds, previous_hidden_states], dim=-1)
)
hidden_states, residual = self.mtp_block(
positions=positions, hidden_states=hidden_states, residual=None
)
hidden_states = residual + hidden_states
return hidden_states
class GlmOcrMultiTokenPredictor(Glm4MoeLiteMultiTokenPredictor):
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
nn.Module.__init__(self)
config = vllm_config.model_config.hf_config.text_config
self.mtp_start_layer_idx = config.num_hidden_layers
self.num_mtp_layers = config.num_nextn_predict_layers
self.layers = torch.nn.ModuleDict(
{
str(idx): GlmOcrMultiTokenPredictorLayer(
vllm_config=vllm_config,
prefix=f"{prefix}.layers.{idx}",
)
for idx in range(
self.mtp_start_layer_idx,
self.mtp_start_layer_idx + self.num_mtp_layers,
)
}
)
self.embed_tokens = VocabParallelEmbedding(
config.vocab_size,
config.hidden_size,
)
self.logits_processor = LogitsProcessor(config.vocab_size)
class GlmOcrMTP(nn.Module, SupportsPP):
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
super().__init__()
self.config = vllm_config.model_config.hf_config.text_config
quant_config = vllm_config.quant_config
self.quant_config = quant_config
self.model = GlmOcrMultiTokenPredictor(
vllm_config=vllm_config, prefix=maybe_prefix(prefix, "model")
)
self.expert_weights = []
self.num_layers = self.config.num_nextn_predict_layers
for layer in self.model.layers.values():
assert isinstance(layer, GlmOcrMultiTokenPredictorLayer)
layer = layer.mtp_block
assert isinstance(layer, Glm4DecoderLayer)
def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor:
return self.model.embed_input_ids(input_ids)
def forward(
self,
input_ids: torch.Tensor,
positions: torch.Tensor,
hidden_states: torch.Tensor,
intermediate_tensors: IntermediateTensors | None = None,
inputs_embeds: torch.Tensor | None = None,
spec_step_idx: int = 0,
) -> torch.Tensor:
hidden_states = self.model(
input_ids, positions, hidden_states, inputs_embeds, spec_step_idx
)
return hidden_states
def compute_logits(
self,
hidden_states: torch.Tensor,
spec_step_idx: int = 0,
) -> torch.Tensor | None:
return self.model.compute_logits(hidden_states, spec_step_idx)
def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
stacked_params_mapping = [
# (param_name, shard_name, shard_id)
(".qkv_proj", ".q_proj", "q"),
(".qkv_proj", ".k_proj", "k"),
(".qkv_proj", ".v_proj", "v"),
(".gate_up_proj", ".gate_proj", 0),
(".gate_up_proj", ".up_proj", 1),
]
params_dict = dict(self.named_parameters())
loaded_params: set[str] = set()
for name, loaded_weight in weights:
if name == "lm_head.weight":
spec_layer = self.model.mtp_start_layer_idx
name = f"model.layers.{spec_layer}.shared_head.head.weight"
elif name == "model.embed_tokens.weight":
spec_layer = self.model.mtp_start_layer_idx
else:
spec_layer = get_spec_layer_idx_from_weight_name(self.config, name)
if spec_layer is None:
continue
name = self._rewrite_spec_layer_name(spec_layer, name)
if self.quant_config is not None and (
scale_name := self.quant_config.get_cache_scale(name)
):
# Loading kv cache quantization scales
param = params_dict[scale_name]
weight_loader = getattr(param, "weight_loader", default_weight_loader)
loaded_weight = (
loaded_weight if loaded_weight.dim() == 0 else loaded_weight[0]
)
weight_loader(param, loaded_weight)
loaded_params.add(scale_name)
continue
if "scale" in name or "zero_point" in name:
# Remapping the name of FP8 kv-scale or zero point.
name = maybe_remap_kv_scale_name(name, params_dict)
if name is None:
continue
for param_name, weight_name, shard_id in stacked_params_mapping:
if weight_name not in name:
continue
name = name.replace(weight_name, param_name)
# Skip loading extra bias for GPTQ models.
if name.endswith(".bias") and name not in params_dict:
continue
if is_pp_missing_parameter(name, self):
continue
param = params_dict[name]
weight_loader = param.weight_loader
weight_loader(param, loaded_weight, shard_id)
break
else:
# Skip loading extra bias for GPTQ models.
if name.endswith(".bias") and name not in params_dict:
continue
# Some checkpoints include weight scale tensors for the
# LM head even when the quantized head isn't built. Skip
# them if the model does not expose a matching parameter
# to avoid KeyError during load.
if name.endswith(".weight_scale") and name not in params_dict:
continue
# According to DeepSeek-V3 Technical Report, MTP modules
# shares embedding layer. We only load the first weights.
if (
spec_layer != self.model.mtp_start_layer_idx
and ".layers" not in name
):
continue
if is_pp_missing_parameter(name, self):
continue
param = params_dict[name]
weight_loader = getattr(param, "weight_loader", default_weight_loader)
weight_loader(param, loaded_weight)
loaded_params.add(name)
return loaded_params
def _rewrite_spec_layer_name(self, spec_layer: int, name: str) -> str:
"""
Rewrite the weight name to match the format of the original model.
Add .mtp_block for modules in transformer layer block for spec layer
and rename shared layer weights to be top level.
"""
name = name.replace("model.language_model.layers", "model.layers")
spec_layer_weight_names = [
"embed_tokens",
"enorm",
"hnorm",
"eh_proj",
"shared_head",
]
shared_weight_names = ["embed_tokens"]
spec_layer_weight = False
shared_weight = False
for weight_name in spec_layer_weight_names:
if weight_name in name:
spec_layer_weight = True
if weight_name in shared_weight_names:
shared_weight = True
break
if not spec_layer_weight:
# treat rest weights as weights for transformer layer block
name = name.replace(
f"model.layers.{spec_layer}.", f"model.layers.{spec_layer}.mtp_block."
)
elif shared_weight:
# treat shared weights as top level weights
name = name.replace(f"model.layers.{spec_layer}.", "model.")
return name
......@@ -319,8 +319,9 @@ _MULTIMODAL_MODELS = {
),
"GlmAsrForConditionalGeneration": ("glmasr", "GlmAsrForConditionalGeneration"),
"GLM4VForCausalLM": ("glm4v", "GLM4VForCausalLM"),
"Glm4vForConditionalGeneration": ("glm4_1v", "Glm4vForConditionalGeneration"), # noqa: E501
"Glm4vMoeForConditionalGeneration": ("glm4_1v", "Glm4vMoeForConditionalGeneration"), # noqa: E501
"Glm4vForConditionalGeneration": ("glm4_1v", "Glm4vForConditionalGeneration"),
"Glm4vMoeForConditionalGeneration": ("glm4_1v", "Glm4vMoeForConditionalGeneration"),
"GlmOcrForConditionalGeneration": ("glm_ocr", "GlmOcrForConditionalGeneration"), # noqa: E501
"GraniteSpeechForConditionalGeneration": (
"granite_speech",
"GraniteSpeechForConditionalGeneration",
......@@ -472,6 +473,7 @@ _SPECULATIVE_DECODING_MODELS = {
"LongCatFlashMTPModel": ("longcat_flash_mtp", "LongCatFlashMTP"),
"Glm4MoeMTPModel": ("glm4_moe_mtp", "Glm4MoeMTP"),
"Glm4MoeLiteMTPModel": ("glm4_moe_lite_mtp", "Glm4MoeLiteMTP"),
"GlmOcrMTPModel": ("glm_ocr_mtp", "GlmOcrMTP"),
"MedusaModel": ("medusa", "Medusa"),
"OpenPanguMTPModel": ("openpangu_mtp", "OpenPanguMTP"),
"Qwen3NextMTP": ("qwen3_next_mtp", "Qwen3NextMTP"),
......
......@@ -398,6 +398,7 @@ MODEL_ARCH_CONFIG_CONVERTORS = {
"qwen3_next_mtp": Qwen3NextMTPModelArchConfigConvertor,
"mimo_mtp": MimoMTPModelArchConfigConvertor,
"glm4_moe_mtp": GLM4MoeMTPModelArchConfigConvertor,
"glm_ocr_mtp": GLM4MoeMTPModelArchConfigConvertor,
"ernie_mtp": ErnieMTPModelArchConfigConvertor,
"pangu_ultra_moe_mtp": PanguUltraMoeMTPModelArchConfigConvertor,
"longcat_flash_mtp": LongCatFlashMTPModelArchConfigConvertor,
......
......@@ -403,7 +403,7 @@ class SpecDecodeBaseProposer:
return draft_token_ids.view(-1, 1)
if self.uses_mrope:
positions = self.positions[:, last_token_indices]
positions = self.mrope_positions[:, last_token_indices]
else:
positions = self.positions[last_token_indices]
if self.method in (
......@@ -1126,6 +1126,7 @@ class SpecDecodeBaseProposer:
"Qwen2_5_VLForConditionalGeneration",
"Qwen3VLForConditionalGeneration",
"Qwen3VLMoeForConditionalGeneration",
"GlmOcrForConditionalGeneration",
]:
self.model.config.image_token_index = target_model.config.image_token_id
elif self.get_model_name(target_model) == "PixtralForConditionalGeneration":
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment