Unverified commit a88b006e authored by Yuxuan Zhang, committed by GitHub

GLM-4-0414 and GLM-4.1V Code Refactor (#12117)

parent ce112c07
@@ -1070,6 +1070,7 @@ def _triton_mrope_forward(
mrope_section_h: tl.constexpr,
mrope_section_w: tl.constexpr,
is_interleaved: tl.constexpr,
is_neox_style: tl.constexpr,
):
# Adapted from
# https://github.com/linkedin/Liger-Kernel/blob/main/src/liger_kernel/ops/qwen2vl_mrope.py
@@ -1124,6 +1125,7 @@ def _triton_mrope_forward(
# program instance (i.e. for the current token) separately
# ####################################################################
# left half of the head
if is_neox_style:
first_half_q_offsets = (
tl.arange(0, pad_n_qh)[:, None] * hd + tl.arange(0, pad_hd // 2)[None, :]
)
@@ -1150,12 +1152,12 @@ def _triton_mrope_forward(
second_q_mask = first_q_mask
second_k_mask = first_k_mask
q_tile_2 = tl.load(q_ptr + second_half_q_offsets, mask=second_q_mask, other=0).to(
sin_row.dtype
)
k_tile_2 = tl.load(k_ptr + second_half_k_offsets, mask=second_k_mask, other=0).to(
sin_row.dtype
)
q_tile_2 = tl.load(
q_ptr + second_half_q_offsets, mask=second_q_mask, other=0
).to(sin_row.dtype)
k_tile_2 = tl.load(
k_ptr + second_half_k_offsets, mask=second_k_mask, other=0
).to(sin_row.dtype)
# y = [x1, x2] * [cos, cos] + [-x2, x1] * [sin, sin]
# Since cos and sin are now half-size,
@@ -1169,6 +1171,53 @@ def _triton_mrope_forward(
tl.store(k_ptr + first_half_k_offsets, new_k_tile_1, mask=first_k_mask)
new_k_tile_2 = k_tile_2 * cos_row + k_tile_1 * sin_row
tl.store(k_ptr + second_half_k_offsets, new_k_tile_2, mask=second_k_mask)
else:
base_q = tl.arange(0, pad_n_qh)[:, None] * hd
base_k = tl.arange(0, pad_n_kh)[:, None] * hd
even_idx = 2 * tl.arange(0, pad_hd // 2)[None, :]
odd_idx = even_idx + 1
even_q_offsets = base_q + even_idx
odd_q_offsets = base_q + odd_idx
even_k_offsets = base_k + even_idx
odd_k_offsets = base_k + odd_idx
idx_mask = tl.arange(0, pad_hd // 2)[None, :] < (rd // 2)
qn_mask = tl.arange(0, pad_n_qh)[:, None] < n_qh
kn_mask = tl.arange(0, pad_n_kh)[:, None] < n_kh
even_q_mask = qn_mask & idx_mask
odd_q_mask = qn_mask & idx_mask
even_k_mask = kn_mask & idx_mask
odd_k_mask = kn_mask & idx_mask
q_tile_1 = tl.load(q_ptr + even_q_offsets, mask=even_q_mask, other=0).to(
sin_row.dtype
)
k_tile_1 = tl.load(k_ptr + even_k_offsets, mask=even_k_mask, other=0).to(
sin_row.dtype
)
q_tile_2 = tl.load(q_ptr + odd_q_offsets, mask=odd_q_mask, other=0).to(
sin_row.dtype
)
k_tile_2 = tl.load(k_ptr + odd_k_offsets, mask=odd_k_mask, other=0).to(
sin_row.dtype
)
# y = [x_even, x_odd] * [cos, cos] + [-x_odd, x_even] * [sin, sin]
# Interleaved (non-NeoX) rotary embedding:
# Each (even, odd) channel pair forms one rotation arm.
# cos_row and sin_row each have length rd // 2: one value per channel pair,
# shared across all heads.
new_q_tile_1 = q_tile_1 * cos_row - q_tile_2 * sin_row
tl.store(q_ptr + even_q_offsets, new_q_tile_1, mask=even_q_mask)
new_q_tile_2 = q_tile_2 * cos_row + q_tile_1 * sin_row
tl.store(q_ptr + odd_q_offsets, new_q_tile_2, mask=odd_q_mask)
new_k_tile_1 = k_tile_1 * cos_row - k_tile_2 * sin_row
tl.store(k_ptr + even_k_offsets, new_k_tile_1, mask=even_k_mask)
new_k_tile_2 = k_tile_2 * cos_row + k_tile_1 * sin_row
tl.store(k_ptr + odd_k_offsets, new_k_tile_2, mask=odd_k_mask)
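Editor's note: as a reference for the two branches above (not part of the diff), a minimal PyTorch sketch of the pairing schemes. The NeoX path rotates (x[i], x[i + rd/2]) pairs across the half-split head dimension, while the non-NeoX path rotates adjacent (x[2i], x[2i+1]) pairs; cos and sin are assumed to have length rd // 2, matching cos_row/sin_row in the kernel.

import torch

def apply_rotary_ref(
    x: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor, is_neox_style: bool
) -> torch.Tensor:
    # x: [..., rd] (the rotary slice of a head); cos/sin: [rd // 2].
    # Reference math only; the Triton kernel applies the same rotation in place.
    if is_neox_style:
        # NeoX style: pair the first half of the head dim with the second half.
        x1, x2 = x.chunk(2, dim=-1)
        return torch.cat([x1 * cos - x2 * sin, x2 * cos + x1 * sin], dim=-1)
    # Interleaved (GPT-J) style: pair even and odd channels.
    x1, x2 = x[..., 0::2], x[..., 1::2]
    out = torch.empty_like(x)
    out[..., 0::2] = x1 * cos - x2 * sin
    out[..., 1::2] = x2 * cos + x1 * sin
    return out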
def triton_mrope(
@@ -1180,6 +1229,7 @@ def triton_mrope(
head_size: int,
rotary_dim: int,
mrope_interleaved: bool,
is_neox_style: bool,
) -> tuple[torch.Tensor, torch.Tensor]:
"""The mrope triton kernel.
@@ -1230,6 +1280,7 @@ def triton_mrope(
mrope_section[1],
mrope_section[2],
mrope_interleaved,
is_neox_style,
)
return q, k
@@ -1400,6 +1451,7 @@ class MRotaryEmbedding(RotaryEmbedding):
self.head_size,
self.rotary_dim,
self.mrope_interleaved,
self.is_neox_style,
)
return q.reshape(query_shape), k.reshape(key_shape)
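Editor's note (hedged sketch, not part of the diff): MRotaryEmbedding consumes position ids of shape (3, seq_len), one row each for the temporal, height, and width components. For a hypothetical text-only prompt all three rows are simply the ordinary token positions; image and video tokens receive distinct per-axis indices that are not shown here.

import torch

seq_len = 8  # illustrative only
text_positions = torch.arange(seq_len)
# For pure-text tokens the three mrope rows coincide.
mrope_positions = text_positions.unsqueeze(0).expand(3, -1)
assert mrope_positions.shape == (3, seq_len)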
......
# Copyright 2023-2024 SGLang Team
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
# Modeling from:
# ./llama.py and
# https://github.com/huggingface/transformers/blob/main/src/transformers/models/glm4v/modular_glm4v.py
"""Inference-only GLM-4.1V model compatible with HuggingFace weights."""
import logging
from functools import lru_cache, partial
from functools import lru_cache
from typing import Iterable, List, Optional, Tuple
import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange
from transformers.models.glm4v.configuration_glm4v import Glm4vConfig, Glm4vVisionConfig
from sglang.srt.layers.activation import SiluAndMul
from sglang.srt.layers.attention import vision_utils
from sglang.srt.layers.dp_attention import get_attention_tp_size
from sglang.srt.layers.attention.vision import VisionAttention
from sglang.srt.layers.layernorm import RMSNorm
from sglang.srt.layers.linear import (
ColumnParallelLinear,
@@ -20,13 +40,14 @@ from sglang.srt.layers.logits_processor import LogitsProcessor
from sglang.srt.layers.pooler import Pooler, PoolingType
from sglang.srt.layers.quantization.base_config import QuantizationConfig
from sglang.srt.layers.vocab_parallel_embedding import ParallelLMHead
from sglang.srt.managers.schedule_batch import MultimodalDataItem
from sglang.srt.managers.mm_utils import (
MultiModalityDataPaddingPatternMultimodalTokens,
general_mm_embed_routine,
)
from sglang.srt.managers.schedule_batch import MultimodalDataItem, MultimodalInputs
from sglang.srt.model_executor.forward_batch_info import ForwardBatch
from sglang.srt.model_loader.weight_utils import default_weight_loader
from sglang.srt.models.glm4 import Glm4Model
from sglang.srt.models.qwen2_5_vl import (
Qwen2_5_VisionBlock,
Qwen2_5_VLForConditionalGeneration,
)
from sglang.srt.utils import add_prefix
from sglang.srt.utils.hf_transformers_utils import get_processor
@@ -56,7 +77,7 @@ class Glm4vVisionMLP(nn.Module):
super().__init__()
self.gate_up_proj = MergedColumnParallelLinear(
input_size=in_features,
output_sizes=[hidden_features] * 2,
output_sizes=[hidden_features] * 2, # [gate_proj, up_proj]
bias=bias,
quant_config=quant_config,
prefix=add_prefix("gate_up_proj", prefix),
@@ -77,34 +98,95 @@ class Glm4vVisionMLP(nn.Module):
return x
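Editor's note: the merged gate_up_proj above packs gate_proj and up_proj into a single matmul; SiluAndMul then splits the packed output and gates it. A minimal sketch of that split, assuming SGLang's usual [gate | up] packing convention (reference only, not the diff's code):

import torch
import torch.nn.functional as F

def silu_and_mul_ref(gate_up: torch.Tensor) -> torch.Tensor:
    # gate_up: [..., 2 * hidden_features], packed as [gate_proj | up_proj].
    gate, up = gate_up.chunk(2, dim=-1)
    return F.silu(gate) * up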
class Glm4vVisionBlock(Qwen2_5_VisionBlock):
class Glm4vVisionBlock(nn.Module):
def __init__(
self,
config: Glm4vVisionConfig,
norm_layer: Optional[nn.Module] = None,
dim: int,
intermediate_dim: int,
num_heads: int,
attn_implementation: Optional[str] = None,
quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
num_dummy_heads: int = 0,
rms_norm_eps: float = 1e-5,
) -> None:
super().__init__(
dim=config.hidden_size,
intermediate_dim=config.out_hidden_size,
num_heads=config.num_heads,
hidden_act=config.hidden_act,
norm_layer=norm_layer,
super().__init__()
self.norm1 = RMSNorm(dim, eps=rms_norm_eps)
self.norm2 = RMSNorm(dim, eps=rms_norm_eps)
if attn_implementation is None:
softmax_in_single_precision = False
qkv_backend = None
flatten_batch = True
elif attn_implementation == "sdpa":
softmax_in_single_precision = False
qkv_backend = "sdpa"
flatten_batch = True
elif attn_implementation == "flash_attention_2":
softmax_in_single_precision = False
qkv_backend = "triton_attn"
flatten_batch = True
elif attn_implementation == "eager":
softmax_in_single_precision = True
qkv_backend = "sdpa"
flatten_batch = True
elif attn_implementation == "flash_attention_3":
softmax_in_single_precision = False
qkv_backend = "fa3"
flatten_batch = True
self.attn = VisionAttention(
embed_dim=dim,
num_heads=num_heads,
projection_size=dim,
use_qkv_parallel=True,
rotary_embed="normal",
proj_bias=True,
qkv_backend=qkv_backend,
softmax_in_single_precision=softmax_in_single_precision,
flatten_batch=flatten_batch,
quant_config=quant_config,
prefix=prefix,
num_dummy_heads=config.num_dummy_heads,
rms_norm_eps=config.rms_norm_eps,
prefix=add_prefix("attn", prefix),
num_dummy_heads=num_dummy_heads,
)
self.mlp = Glm4vVisionMLP(
config.hidden_size,
config.out_hidden_size,
bias=False,
dim,
intermediate_dim,
quant_config=quant_config,
prefix=add_prefix("mlp", prefix),
)
def forward(
self,
x: torch.Tensor,
cu_seqlens: torch.Tensor,
position_embeddings: torch.Tensor,
) -> torch.Tensor:
S, B, H = x.shape
# norm1: flatten to 2D -> [S*B, H], then reshape back
x2d = x.reshape(-1, H)
hidden_states = self.norm1(x2d).reshape(S, B, H)
# Attention expects [B, S, H]
hidden_states = rearrange(hidden_states, "s b h -> b s h")
attn = self.attn(
hidden_states,
cu_seqlens=cu_seqlens,
position_embeddings=position_embeddings,
)
attn = rearrange(attn, "b s h -> s b h")
# norm2 with fused residual-add: also 2D
attn2d = attn.reshape(-1, H)
x_norm_2d, x_after_add_2d = self.norm2(x2d, residual=attn2d)
x_norm = x_norm_2d.reshape(S, B, H)
x_after_add = x_after_add_2d.reshape(S, B, H)
# MLP and final residual
mlp_out = self.mlp(x_norm)
x = x_after_add + mlp_out
return x
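Editor's note: the block keeps activations sequence-first ([S, B, H]) and only flips to batch-first ([B, S, H]) around the attention call. A tiny standalone check of that layout round trip (illustrative shapes, not part of the diff):

import torch
from einops import rearrange

S, B, H = 4, 2, 8  # illustrative sizes
x = torch.randn(S, B, H)
x_bsh = rearrange(x, "s b h -> b s h")   # layout the attention call expects
x_back = rearrange(x_bsh, "b s h -> s b h")
assert torch.equal(x, x_back)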
class Glm4vVisionPatchEmbed(nn.Module):
def __init__(
@@ -320,7 +402,6 @@ class Glm4vVisionModel(nn.Module):
def __init__(
self,
vision_config: Glm4vVisionConfig,
norm_eps: float = 1e-6,
quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
) -> None:
@@ -344,17 +425,18 @@ class Glm4vVisionModel(nn.Module):
hidden_size=self.hidden_size,
)
norm_layer = partial(Glm4vRMSNorm, eps=norm_eps)
head_dim = self.hidden_size // self.num_heads
self.rotary_pos_emb = Glm4vVisionRotaryEmbedding(head_dim // 2)
self.blocks = nn.ModuleList(
[
Glm4vVisionBlock(
config=vision_config,
norm_layer=norm_layer,
dim=self.hidden_size,
intermediate_dim=self.out_hidden_size,
num_heads=self.num_heads,
quant_config=quant_config,
prefix=add_prefix(f"blocks.{layer_idx}", prefix),
rms_norm_eps=vision_config.rms_norm_eps,
)
for layer_idx in range(depth)
]
@@ -461,29 +543,30 @@ class Glm4vVisionModel(nn.Module):
return x
class Glm4vForConditionalGeneration(Qwen2_5_VLForConditionalGeneration):
class Glm4vForConditionalGeneration(nn.Module):
def __init__(
self,
config: Glm4vConfig,
quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
) -> None:
nn.Module.__init__(self)
super().__init__()
self.config = config
vision_utils.update_vit_attn_dummy_heads_config(self.config)
self.model = Glm4Model(
config,
quant_config,
prefix=add_prefix("model", prefix),
)
self.visual = Glm4vVisionModel(
config.vision_config,
norm_eps=getattr(config, "rms_norm_eps", 1e-5),
quant_config=quant_config,
prefix=add_prefix("visual", prefix),
)
vision_utils.update_vit_attn_dummy_heads_config(self.config)
self.model = Glm4Model(
config,
quant_config=quant_config,
prefix=add_prefix("model", prefix),
)
if config.tie_word_embeddings:
self.lm_head = self.model.embed_tokens
else:
@@ -494,13 +577,18 @@ class Glm4vForConditionalGeneration(Qwen2_5_VLForConditionalGeneration):
prefix=add_prefix("lm_head", prefix),
)
self.is_mrope_enabled = "mrope_section" in self.config.rope_scaling
self.logits_processor = LogitsProcessor(config)
self.pooler = Pooler(pooling_type=PoolingType.LAST, normalize=True)
self.is_mrope_enabled = "mrope_section" in self.config.rope_scaling
# For EAGLE3 support
self.capture_aux_hidden_states = False
def pad_input_ids(self, input_ids: List[int], mm_inputs: MultimodalInputs):
pattern = MultiModalityDataPaddingPatternMultimodalTokens()
return pattern.pad_input_tokens(input_ids, mm_inputs)
def get_image_feature(self, items: List[MultimodalDataItem]) -> torch.Tensor:
pixel_values = torch.cat(
[item.feature.squeeze(0) for item in items], dim=0
@@ -542,20 +630,60 @@ class Glm4vForConditionalGeneration(Qwen2_5_VLForConditionalGeneration):
video_embeds = torch.split(video_embeds, split_sizes)
return torch.cat(video_embeds)
def _update_hf_config(self):
"""update hf config to ensure vision attention num_attention_heads is divisible by tp_size"""
tp_size = get_attention_tp_size()
num_heads = self.config.vision_config.num_heads
head_dim = self.config.vision_config.hidden_size // num_heads
num_dummy_heads = 0
def get_input_embeddings(self):
return self.model.embed_tokens
@torch.no_grad()
def forward(
self,
input_ids: torch.Tensor,
positions: torch.Tensor,
forward_batch: ForwardBatch,
get_embedding: bool = False,
):
"""Run forward pass for GLM-4.1V.
Args:
input_ids: Flattened (concatenated) input_ids corresponding to a
batch.
positions: Flattened (concatenated) position ids corresponding to a
batch.
**NOTE**: If mrope is enabled (the default for the open-source
GLM-4.1V models), the shape will be `(3, seq_len)`;
otherwise it will be `(seq_len,)`.
(When mrope is enabled, forward_batch.mrope_positions is used instead.)
"""
if self.is_mrope_enabled:
positions = forward_batch.mrope_positions
if not (
forward_batch.forward_mode.is_decode()
or not forward_batch.contains_image_inputs()
):
if self.is_mrope_enabled:
assert positions.ndim == 2 and positions.size(0) == 3, (
"multimodal section rotary embedding requires "
f"(3, seq_len) positions, but got {positions.size()}"
)
hidden_states = general_mm_embed_routine(
input_ids=input_ids,
forward_batch=forward_batch,
language_model=self.model,
multimodal_model=self,
positions=positions,
)
if num_heads % tp_size != 0:
num_dummy_heads = (
(num_heads + tp_size - 1) // tp_size
) * tp_size - num_heads
aux_hidden_states = None
if self.capture_aux_hidden_states:
hidden_states, aux_hidden_states = hidden_states
setattr(self.config.vision_config, "head_dim", head_dim)
setattr(self.config.vision_config, "num_dummy_heads", num_dummy_heads)
if not get_embedding:
return self.logits_processor(
input_ids, hidden_states, self.lm_head, forward_batch, aux_hidden_states
)
else:
return self.pooler(hidden_states, forward_batch)
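Editor's note: the removed _update_hf_config body above (now superseded by vision_utils.update_vit_attn_dummy_heads_config) pads the vision head count up to the next multiple of the attention TP size. A standalone sketch of that arithmetic with illustrative numbers, not taken from the PR:

def dummy_heads_needed(num_heads: int, tp_size: int) -> int:
    # Extra heads required so num_heads becomes divisible by tp_size.
    if num_heads % tp_size == 0:
        return 0
    return ((num_heads + tp_size - 1) // tp_size) * tp_size - num_heads

# e.g. 16 heads sharded over tp_size=3 -> ceil(16/3) * 3 - 16 = 2 dummy heads
assert dummy_heads_needed(16, 3) == 2
assert dummy_heads_needed(16, 4) == 0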
def _pad_vit_attn_dummy_heads(self, name: str, loaded_weight: torch.Tensor):
"""pad attn qkv weights for dummy heads"""
@@ -598,13 +726,12 @@ class Glm4vForConditionalGeneration(Qwen2_5_VLForConditionalGeneration):
]
params_dict = dict(self.named_parameters(remove_duplicate=False))
for name, loaded_weight in weights:
if "language_model." in name:
name = name.replace("language_model.", "")
if "model.visual." in name:
name = name.replace("model.visual.", "visual.")
if "rotary_emb.inv_freq" in name:
continue
if "language_model" in name:
name = name.replace(r"model.language_model.", r"model.")
if "model.visual." in name:
name = name.replace("model.visual.", "visual.")
for param_name, weight_name, shard_id in stacked_params_mapping:
if weight_name not in name:
@@ -639,5 +766,19 @@ class Glm4vForConditionalGeneration(Qwen2_5_VLForConditionalGeneration):
)
weight_loader(param, loaded_weight)
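Editor's note (reference sketch, assuming the HF checkpoint key layout implied by the replacements above): load_weights renames model.language_model.* keys to model.* and model.visual.* keys to visual.* before matching them against this module's named_parameters.

def remap_checkpoint_key(name: str) -> str:
    # Mirrors the string replacements in load_weights above.
    if "language_model" in name:
        name = name.replace("model.language_model.", "model.")
    if "model.visual." in name:
        name = name.replace("model.visual.", "visual.")
    return name

assert (
    remap_checkpoint_key("model.language_model.layers.0.mlp.gate_proj.weight")
    == "model.layers.0.mlp.gate_proj.weight"
)
assert (
    remap_checkpoint_key("model.visual.blocks.0.attn.qkv.weight")
    == "visual.blocks.0.attn.qkv.weight"
)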
def get_embed_and_head(self):
return self.model.embed_tokens.weight, self.lm_head.weight
def set_embed_and_head(self, embed, head):
del self.model.embed_tokens.weight
self.model.embed_tokens.weight = embed
if self.config.tie_word_embeddings:
self.lm_head = self.model.embed_tokens
else:
del self.lm_head.weight
self.lm_head.weight = head
torch.cuda.empty_cache()
torch.cuda.synchronize()
EntryClass = [Glm4vForConditionalGeneration]
@@ -53,7 +53,6 @@ class Glm4vMoeForConditionalGeneration(Glm4vForConditionalGeneration):
)
self.visual = Glm4vVisionModel(
config.vision_config,
norm_eps=getattr(config, "rms_norm_eps", 1e-5),
quant_config=quant_config,
prefix=add_prefix("visual", prefix),
)
......