Unverified Commit a88b006e authored by Yuxuan Zhang, committed by GitHub

GLM-4-0414 and GLM-4.1V Code Refactor (#12117)

parent ce112c07
@@ -1070,6 +1070,7 @@ def _triton_mrope_forward(
mrope_section_h: tl.constexpr,
mrope_section_w: tl.constexpr,
is_interleaved: tl.constexpr,
is_neox_style: tl.constexpr,
):
# Adapted from
# https://github.com/linkedin/Liger-Kernel/blob/main/src/liger_kernel/ops/qwen2vl_mrope.py
@@ -1124,51 +1125,99 @@ def _triton_mrope_forward(
# program instance (i.e. for the current token) separately
# ####################################################################
    if is_neox_style:
        # left half of the head
        first_half_q_offsets = (
            tl.arange(0, pad_n_qh)[:, None] * hd + tl.arange(0, pad_hd // 2)[None, :]
        )
        first_half_k_offsets = (
            tl.arange(0, pad_n_kh)[:, None] * hd + tl.arange(0, pad_hd // 2)[None, :]
        )
        first_q_mask = (tl.arange(0, pad_n_qh)[:, None] < n_qh) & (
            tl.arange(0, pad_hd // 2)[None, :] < rd // 2
        )
        first_k_mask = (tl.arange(0, pad_n_kh)[:, None] < n_kh) & (
            tl.arange(0, pad_hd // 2)[None, :] < rd // 2
        )
        q_tile_1 = tl.load(q_ptr + first_half_q_offsets, mask=first_q_mask, other=0).to(
            sin_row.dtype
        )
        k_tile_1 = tl.load(k_ptr + first_half_k_offsets, mask=first_k_mask, other=0).to(
            sin_row.dtype
        )

        # right half of the head
        second_half_q_offsets = first_half_q_offsets + (rd // 2)
        second_half_k_offsets = first_half_k_offsets + (rd // 2)
        second_q_mask = first_q_mask
        second_k_mask = first_k_mask
        q_tile_2 = tl.load(
            q_ptr + second_half_q_offsets, mask=second_q_mask, other=0
        ).to(sin_row.dtype)
        k_tile_2 = tl.load(
            k_ptr + second_half_k_offsets, mask=second_k_mask, other=0
        ).to(sin_row.dtype)

        # y = [x1, x2] * [cos, cos] + [-x2, x1] * [sin, sin]
        # Since cos and sin are now half-size,
        # we use the same cos_row and sin_row for both halves
        new_q_tile_1 = q_tile_1 * cos_row - q_tile_2 * sin_row
        tl.store(q_ptr + first_half_q_offsets, new_q_tile_1, mask=first_q_mask)
        new_q_tile_2 = q_tile_2 * cos_row + q_tile_1 * sin_row
        tl.store(q_ptr + second_half_q_offsets, new_q_tile_2, mask=second_q_mask)

        new_k_tile_1 = k_tile_1 * cos_row - k_tile_2 * sin_row
        tl.store(k_ptr + first_half_k_offsets, new_k_tile_1, mask=first_k_mask)
        new_k_tile_2 = k_tile_2 * cos_row + k_tile_1 * sin_row
        tl.store(k_ptr + second_half_k_offsets, new_k_tile_2, mask=second_k_mask)
    else:
        base_q = tl.arange(0, pad_n_qh)[:, None] * hd
        base_k = tl.arange(0, pad_n_kh)[:, None] * hd
        even_idx = 2 * tl.arange(0, pad_hd // 2)[None, :]
        odd_idx = even_idx + 1
        even_q_offsets = base_q + even_idx
        odd_q_offsets = base_q + odd_idx
        even_k_offsets = base_k + even_idx
        odd_k_offsets = base_k + odd_idx

        idx_mask = tl.arange(0, pad_hd // 2)[None, :] < (rd // 2)
        qn_mask = tl.arange(0, pad_n_qh)[:, None] < n_qh
        kn_mask = tl.arange(0, pad_n_kh)[:, None] < n_kh
        even_q_mask = qn_mask & idx_mask
        odd_q_mask = qn_mask & idx_mask
        even_k_mask = kn_mask & idx_mask
        odd_k_mask = kn_mask & idx_mask

        q_tile_1 = tl.load(q_ptr + even_q_offsets, mask=even_q_mask, other=0).to(
            sin_row.dtype
        )
        k_tile_1 = tl.load(k_ptr + even_k_offsets, mask=even_k_mask, other=0).to(
            sin_row.dtype
        )
        q_tile_2 = tl.load(q_ptr + odd_q_offsets, mask=odd_q_mask, other=0).to(
            sin_row.dtype
        )
        k_tile_2 = tl.load(k_ptr + odd_k_offsets, mask=odd_k_mask, other=0).to(
            sin_row.dtype
        )

        # y = [x_even, x_odd] * [cos, cos] + [-x_odd, x_even] * [sin, sin]
        # Interleaved (non-NeoX, GPT-J style) rotary embedding:
        # each (even, odd) channel pair forms one rotation arm.
        # cos_row and sin_row each have length rd // 2, shared across all pairs.
        new_q_tile_1 = q_tile_1 * cos_row - q_tile_2 * sin_row
        tl.store(q_ptr + even_q_offsets, new_q_tile_1, mask=even_q_mask)
        new_q_tile_2 = q_tile_2 * cos_row + q_tile_1 * sin_row
        tl.store(q_ptr + odd_q_offsets, new_q_tile_2, mask=odd_q_mask)

        new_k_tile_1 = k_tile_1 * cos_row - k_tile_2 * sin_row
        tl.store(k_ptr + even_k_offsets, new_k_tile_1, mask=even_k_mask)
        new_k_tile_2 = k_tile_2 * cos_row + k_tile_1 * sin_row
        tl.store(k_ptr + odd_k_offsets, new_k_tile_2, mask=odd_k_mask)

def triton_mrope(
@@ -1180,6 +1229,7 @@ def triton_mrope(
head_size: int,
rotary_dim: int,
mrope_interleaved: bool,
is_neox_style: bool,
) -> tuple[torch.Tensor, torch.Tensor]:
"""The mrope triton kernel.
@@ -1230,6 +1280,7 @@ def triton_mrope(
mrope_section[1],
mrope_section[2],
mrope_interleaved,
is_neox_style,
)
return q, k
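

# Reviewer sketch (not part of the diff): the two rotation layouts that the
# kernel now branches on, written in plain PyTorch. `rope_reference` is a
# hypothetical helper; it assumes `cos` and `sin` already have length rd // 2.
def rope_reference(x, cos, sin, is_neox_style: bool):
    import torch

    if is_neox_style:
        # NeoX style: channel i pairs with channel i + rd // 2.
        x1, x2 = x.chunk(2, dim=-1)
        return torch.cat([x1 * cos - x2 * sin, x2 * cos + x1 * sin], dim=-1)
    # Interleaved (GPT-J) style: channel 2i pairs with channel 2i + 1.
    x1, x2 = x[..., 0::2], x[..., 1::2]
    out = torch.empty_like(x)
    out[..., 0::2] = x1 * cos - x2 * sin
    out[..., 1::2] = x2 * cos + x1 * sin
    return out


# Both branches apply the same rotation (x1, x2) -> (x1*cos - x2*sin,
# x2*cos + x1*sin); only the channel pairing differs, which is why the Triton
# kernel above changes only the load/store offsets between its two branches.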
@@ -1400,6 +1451,7 @@ class MRotaryEmbedding(RotaryEmbedding):
self.head_size,
self.rotary_dim,
self.mrope_interleaved,
self.is_neox_style,
)
return q.reshape(query_shape), k.reshape(key_shape)
@@ -15,46 +15,119 @@
# Modeling from:
# ./llama.py and
# https://github.com/huggingface/transformers/blob/main/src/transformers/models/glm4/modular_glm4.py
"""Inference-only GLM4 model compatible with THUDM weights."""
"""Inference-only GLM-4-0414 model compatible with HuggingFace weights."""
from typing import Iterable, List, Optional, Tuple, Union
import logging
from typing import Any, Dict, Iterable, Optional, Tuple, Union
import torch
from torch import nn
from transformers import Glm4Config
from sglang.srt.distributed import get_tensor_model_parallel_world_size
from sglang.srt.distributed import (
get_pp_group,
get_tensor_model_parallel_rank,
get_tensor_model_parallel_world_size,
)
from sglang.srt.layers.activation import SiluAndMul
from sglang.srt.layers.dp_attention import is_dp_attention_enabled
from sglang.srt.layers.layernorm import RMSNorm
from sglang.srt.layers.linear import QKVParallelLinear, RowParallelLinear
from sglang.srt.layers.linear import (
MergedColumnParallelLinear,
QKVParallelLinear,
RowParallelLinear,
)
from sglang.srt.layers.logits_processor import LogitsProcessor
from sglang.srt.layers.pooler import Pooler, PoolingType
from sglang.srt.layers.quantization.base_config import QuantizationConfig
from sglang.srt.layers.radix_attention import RadixAttention
from sglang.srt.layers.rotary_embedding import get_rope
from sglang.srt.layers.utils import PPMissingLayer, get_layer_id
from sglang.srt.layers.vocab_parallel_embedding import (
ParallelLMHead,
VocabParallelEmbedding,
)
from sglang.srt.model_executor.forward_batch_info import ForwardBatch
from sglang.srt.model_loader.weight_utils import default_weight_loader
from sglang.srt.models.llama import LlamaMLP as Glm4MLP
from sglang.srt.model_executor.forward_batch_info import ForwardBatch, PPProxyTensors
from sglang.srt.model_loader.weight_utils import (
default_weight_loader,
kv_cache_scales_loader,
)
from sglang.srt.utils import add_prefix, make_layers
Glm4Config = None
logger = logging.getLogger(__name__)
class Glm4MLP(nn.Module):
def __init__(
self,
hidden_size: int,
intermediate_size: int,
hidden_act: str,
quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
reduce_results: bool = True,
) -> None:
super().__init__()
self.gate_up_proj = MergedColumnParallelLinear(
hidden_size,
[intermediate_size] * 2,
bias=False,
quant_config=quant_config,
prefix=add_prefix("gate_up_proj", prefix),
)
self.down_proj = RowParallelLinear(
intermediate_size,
hidden_size,
bias=False,
quant_config=quant_config,
prefix=add_prefix("down_proj", prefix),
reduce_results=reduce_results,
)
if hidden_act != "silu":
raise ValueError(
f"Unsupported activation: {hidden_act}. Only silu is supported for now."
)
self.act_fn = SiluAndMul()
def forward(
self,
x,
forward_batch=None,
use_reduce_scatter: bool = False,
):
gate_up, _ = self.gate_up_proj(x)
x = self.act_fn(gate_up)
x, _ = self.down_proj(
x,
skip_all_reduce=use_reduce_scatter,
)
return x
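

# Reviewer sketch (not part of the diff): `gate_up_proj` fuses gate_proj and
# up_proj into one GEMM, and `SiluAndMul` splits the concatenated activation
# in half and computes silu(gate) * up. A minimal single-GPU equivalent with
# illustrative weight names (w_gate/w_up: [intermediate, hidden],
# w_down: [hidden, intermediate]):
def glm4_mlp_reference(x, w_gate, w_up, w_down):
    import torch
    import torch.nn.functional as F

    gate_up = x @ torch.cat([w_gate, w_up], dim=0).T  # one fused projection
    gate, up = gate_up.chunk(2, dim=-1)
    return (F.silu(gate) * up) @ w_down.T  # SiluAndMul, then down_proj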
class Glm4Attention(nn.Module):
    def __init__(
        self,
        hidden_size: int,
        num_heads: int,
        num_kv_heads: int,
        head_dim: Optional[int] = None,
        layer_id: int = 0,
        rope_theta: float = 1000000,
        rope_scaling: Optional[Dict[str, Any]] = None,
        max_position_embeddings: int = 131072,
        quant_config: Optional[QuantizationConfig] = None,
        dual_chunk_attention_config: Optional[dict[str, Any]] = None,
        partial_rotary_factor: float = 0.5,
        prefix: str = "",
    ) -> None:
        super().__init__()
        self.hidden_size = hidden_size
        tp_size = get_tensor_model_parallel_world_size()
        self.total_num_heads = num_heads
        assert self.total_num_heads % tp_size == 0
        self.num_heads = self.total_num_heads // tp_size
        self.total_num_kv_heads = num_kv_heads
        if self.total_num_kv_heads >= tp_size:
            # Number of KV heads is greater than TP size, so we partition
            # the KV heads across multiple tensor parallel GPUs.
@@ -63,27 +136,30 @@ class Glm4Attention(nn.Module):
            # Number of KV heads is less than TP size, so we replicate
            # the KV heads across multiple tensor parallel GPUs.
            assert tp_size % self.total_num_kv_heads == 0
        self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size)
        if head_dim is not None:
            self.head_dim = head_dim
        else:
            self.head_dim = hidden_size // self.total_num_heads
        self.q_size = self.num_heads * self.head_dim
        self.kv_size = self.num_kv_heads * self.head_dim
        self.scaling = self.head_dim**-0.5
        self.rope_theta = rope_theta
        self.max_position_embeddings = max_position_embeddings
        self.partial_rotary_factor = partial_rotary_factor

        self.qkv_proj = QKVParallelLinear(
            hidden_size,
            self.head_dim,
            self.total_num_heads,
            self.total_num_kv_heads,
            bias=True,
            quant_config=quant_config,
            prefix=add_prefix("qkv_proj", prefix),
        )
        self.o_proj = RowParallelLinear(
            self.total_num_heads * self.head_dim,
            hidden_size,
            bias=False,
            quant_config=quant_config,
            prefix=add_prefix("o_proj", prefix),
@@ -92,9 +168,10 @@ class Glm4Attention(nn.Module):
        self.rotary_emb = get_rope(
            self.head_dim,
            rotary_dim=self.head_dim,
            max_position=max_position_embeddings,
            base=rope_theta,
            rope_scaling=rope_scaling,
            dual_chunk_attention_config=dual_chunk_attention_config,
            partial_rotary_factor=partial_rotary_factor,
            is_neox_style=False,
        )
@@ -117,14 +194,9 @@ class Glm4Attention(nn.Module):
        qkv, _ = self.qkv_proj(hidden_states)
        q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
        q, k = self.rotary_emb(positions, q, k)
        attn_output = self.attn(q, k, v, forward_batch)
        output, _ = self.o_proj(attn_output)
        return output
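

# Reviewer sketch (not part of the diff): GLM-4-0414 uses
# partial_rotary_factor=0.5 with is_neox_style=False, i.e. only the first
# rotary_dim = head_dim // 2 channels of each head are rotated (in interleaved
# pairs) while the remaining channels pass through unchanged. Illustrative
# helper; assumes cos/sin have length rotary_dim // 2:
def apply_partial_interleaved_rope(q, cos, sin, rotary_dim: int):
    import torch

    q_rot, q_pass = q[..., :rotary_dim], q[..., rotary_dim:]
    x1, x2 = q_rot[..., 0::2], q_rot[..., 1::2]
    rot = torch.empty_like(q_rot)
    rot[..., 0::2] = x1 * cos - x2 * sin
    rot[..., 1::2] = x2 * cos + x1 * sin
    return torch.cat([rot, q_pass], dim=-1)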
class Glm4DecoderLayer(nn.Module):
@@ -136,15 +208,35 @@ class Glm4DecoderLayer(nn.Module):
    def __init__(
        self,
        config: Glm4Config,
        layer_id: int = 0,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
        alt_stream: Optional[torch.cuda.Stream] = None,
    ) -> None:
        super().__init__()
        # Self attention.
        self.hidden_size = config.hidden_size
        rope_theta = getattr(config, "rope_theta", 1000000)
        rope_scaling = getattr(config, "rope_scaling", None)
        max_position_embeddings = getattr(config, "max_position_embeddings", 32768)
        head_dim = getattr(config, "head_dim", None)
        partial_rotary_factor = getattr(config, "partial_rotary_factor", None)
        dual_chunk_attention_config = getattr(
            config, "dual_chunk_attention_config", None
        )
        self.self_attn = Glm4Attention(
            hidden_size=self.hidden_size,
            num_heads=config.num_attention_heads,
            num_kv_heads=config.num_key_value_heads,
            head_dim=head_dim,
            layer_id=layer_id,
            rope_theta=rope_theta,
            rope_scaling=rope_scaling,
            max_position_embeddings=max_position_embeddings,
            quant_config=quant_config,
            dual_chunk_attention_config=dual_chunk_attention_config,
            partial_rotary_factor=partial_rotary_factor,
            prefix=add_prefix("self_attn", prefix),
        )
# MLP
@@ -199,54 +291,125 @@ class Glm4Model(nn.Module):
config: Glm4Config,
quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
decoder_layer_type: type[nn.Module] = Glm4DecoderLayer,
alt_stream: Optional[torch.cuda.Stream] = None,
) -> None:
        super().__init__()
        self.config = config
        self.padding_idx = config.pad_token_id
        self.vocab_size = config.vocab_size
        self.pp_group = get_pp_group()
        if self.pp_group.is_first_rank:
            self.embed_tokens = VocabParallelEmbedding(
                config.vocab_size,
                config.hidden_size,
                quant_config=quant_config,
                enable_tp=not is_dp_attention_enabled(),
                prefix=add_prefix("embed_tokens", prefix),
            )
        else:
            self.embed_tokens = PPMissingLayer()

        # Use the provided decoder layer type or default to Glm4DecoderLayer
        decoder_layer_type = decoder_layer_type or Glm4DecoderLayer
        self.layers, self.start_layer, self.end_layer = make_layers(
            config.num_hidden_layers,
            lambda idx, prefix: decoder_layer_type(
                layer_id=idx,
                config=config,
                quant_config=quant_config,
                prefix=prefix,
                alt_stream=alt_stream,
            ),
            pp_rank=self.pp_group.rank_in_group,
            pp_size=self.pp_group.world_size,
            prefix=add_prefix("layers", prefix),
        )
        if self.pp_group.is_last_rank:
            self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        else:
            self.norm = PPMissingLayer(return_tuple=True)

        # For EAGLE3 support
        self.layers_to_capture = []
def get_input_embeddings(self) -> nn.Embedding:
return self.embed_tokens
def dtype(self) -> torch.dtype:
return next(self.parameters()).dtype
    @torch.no_grad()
    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        forward_batch: ForwardBatch,
        input_embeds: torch.Tensor = None,
        pp_proxy_tensors: Optional[PPProxyTensors] = None,
    ) -> Union[torch.Tensor, PPProxyTensors]:
        if self.pp_group.is_first_rank:
            if input_embeds is None:
                hidden_states = self.embed_tokens(input_ids)
            else:
                hidden_states = input_embeds
            residual = None
        else:
            assert pp_proxy_tensors is not None
            hidden_states = pp_proxy_tensors["hidden_states"]
            residual = pp_proxy_tensors["residual"]

        aux_hidden_states = []
        for i in range(self.start_layer, self.end_layer):
            if i in self.layers_to_capture:
                aux_hidden_states.append(
                    hidden_states + residual if residual is not None else hidden_states
                )
            layer = self.layers[i]
            hidden_states, residual = layer(
                positions,
                hidden_states,
                forward_batch,
                residual,
            )

        if not self.pp_group.is_last_rank:
            return PPProxyTensors(
                {
                    "hidden_states": hidden_states,
                    "residual": residual,
                }
            )
        else:
            if hidden_states.shape[0] != 0:
                if residual is None:
                    hidden_states = self.norm(hidden_states)
                else:
                    hidden_states, _ = self.norm(hidden_states, residual)

        if len(aux_hidden_states) == 0:
            return hidden_states

        return hidden_states, aux_hidden_states
# If this function is called, it should always initialize KV cache scale
# factors (or else raise an exception). Thus, handled exceptions should
# make sure to leave KV cache scale factors in a known good (dummy) state
def load_kv_cache_scales(self, quantization_param_path: str) -> None:
tp_size = get_tensor_model_parallel_world_size()
tp_rank = get_tensor_model_parallel_rank()
for layer_idx, scaling_factor in kv_cache_scales_loader(
quantization_param_path,
tp_rank,
tp_size,
self.config.num_hidden_layers,
self.config.__class__.model_type,
):
if not isinstance(self.layers[layer_idx], nn.Identity):
layer_self_attn = self.layers[layer_idx].self_attn
if hasattr(layer_self_attn.attn, "k_scale"):
layer_self_attn.attn.k_scale = scaling_factor
layer_self_attn.attn.v_scale = scaling_factor
else:
raise RuntimeError(
"Self attention has no KV cache scaling factor attribute!"
)
class Glm4ForCausalLM(nn.Module):
@@ -255,21 +418,54 @@ class Glm4ForCausalLM(nn.Module):
config: Glm4Config,
quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
    ) -> None:
        super().__init__()
        self.pp_group = get_pp_group()
        self.config = config
        self.quant_config = quant_config
        self.model = Glm4Model(
            config, quant_config=quant_config, prefix=add_prefix("model", prefix)
        )
        # handle the lm head on different pp ranks
        if self.pp_group.is_last_rank:
            if self.pp_group.world_size == 1 and config.tie_word_embeddings:
                self.lm_head = self.model.embed_tokens
            else:
                self.lm_head = ParallelLMHead(
                    config.vocab_size,
                    config.hidden_size,
                    quant_config=quant_config,
                    prefix=add_prefix("lm_head", prefix),
                )
        else:
            # ranks other than the last rank will have a placeholder layer
            self.lm_head = PPMissingLayer()

        # perform weight tying for PP
        if self.pp_group.world_size > 1 and config.tie_word_embeddings:
            if self.pp_group.is_first_rank:
                self.pp_group.send(
                    self.model.embed_tokens.weight, dst=self.pp_group.last_rank
                )
            else:
                emb_token_weight = self.pp_group.recv(
                    size=(config.vocab_size, config.hidden_size),
                    dtype=next(self.model.parameters()).dtype,
                    src=self.pp_group.first_rank,
                )
                self.lm_head.weight.copy_(emb_token_weight)
self.logits_processor = LogitsProcessor(config)
self.pooler = Pooler(pooling_type=PoolingType.LAST, normalize=True)
# For EAGLE3 support
self.capture_aux_hidden_states = False
def get_input_embedding(self, input_ids: torch.Tensor) -> torch.Tensor:
return self.model.get_input_embedding(input_ids)
def get_input_embeddings(self) -> nn.Embedding:
return self.model.embed_tokens
@torch.no_grad()
def forward(
@@ -277,34 +473,138 @@ class Glm4ForCausalLM(nn.Module):
input_ids: torch.Tensor,
positions: torch.Tensor,
forward_batch: ForwardBatch,
input_embeds: torch.Tensor = None,
get_embedding: bool = False,
pp_proxy_tensors: Optional[PPProxyTensors] = None,
) -> torch.Tensor:
hidden_states = self.model(
input_ids,
positions,
forward_batch,
input_embeds,
pp_proxy_tensors=pp_proxy_tensors,
)
aux_hidden_states = None
if self.capture_aux_hidden_states:
hidden_states, aux_hidden_states = hidden_states
if self.pp_group.is_last_rank:
if not get_embedding:
return self.logits_processor(
input_ids,
hidden_states,
self.lm_head,
forward_batch,
aux_hidden_states,
)
else:
return self.pooler(hidden_states, forward_batch)
else:
return hidden_states
@torch.no_grad()
def forward_split_prefill(
self,
input_ids: torch.Tensor,
positions: torch.Tensor,
forward_batch: ForwardBatch,
split_interval: Tuple[int, int], # [start, end) 0-based
input_embeds: torch.Tensor = None,
):
start, end = split_interval
# embed
if start == 0:
if input_embeds is None:
forward_batch.hidden_states = self.model.embed_tokens(input_ids)
else:
forward_batch.hidden_states = input_embeds
# decoder layer
for i in range(start, end):
layer = self.model.layers[i]
forward_batch.hidden_states, forward_batch.residual = layer(
positions,
forward_batch.hidden_states,
forward_batch,
forward_batch.residual,
)
if end == self.model.config.num_hidden_layers:
# norm
hidden_states, _ = self.model.norm(
forward_batch.hidden_states, forward_batch.residual
)
forward_batch.hidden_states = hidden_states
# logits process
result = self.logits_processor(
input_ids, forward_batch.hidden_states, self.lm_head, forward_batch
)
else:
result = None
return result
@property
def start_layer(self):
return self.model.start_layer
@property
def end_layer(self):
return self.model.end_layer
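
    # Reviewer sketch (not part of the diff): `forward_split_prefill` is meant
    # to be driven in layer chunks, with activations carried on the
    # ForwardBatch between calls. Hypothetical driver loop (chunk size and
    # variable names are illustrative):
    #
    #   result = None
    #   num_layers = model.config.num_hidden_layers
    #   for start in range(0, num_layers, 8):
    #       end = min(start + 8, num_layers)
    #       result = model.forward_split_prefill(
    #           input_ids, positions, forward_batch, (start, end)
    #       )
    #   # Only the call that covers the final layer returns logits.
    #   assert result is not None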
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
stacked_params_mapping = [
# (param_name, shard_name, shard_id)
(".qkv_proj", ".q_proj", "q"),
(".qkv_proj", ".k_proj", "k"),
(".qkv_proj", ".v_proj", "v"),
(".gate_up_proj", ".gate_proj", 0),
(".gate_up_proj", ".up_proj", 1),
(".gate_up_proj", ".gate_proj", 0),
]
params_dict = dict(self.named_parameters())
for name, loaded_weight in weights:
layer_id = get_layer_id(name)
if (
layer_id is not None
and hasattr(self.model, "start_layer")
and (
layer_id < self.model.start_layer
or layer_id >= self.model.end_layer
)
):
continue
if "rotary_emb.inv_freq" in name or "projector" in name:
continue
if self.config.tie_word_embeddings and "lm_head.weight" in name:
if self.pp_group.world_size > 1 and self.pp_group.is_last_rank:
# Handle pp weight tying here
# find the embed_tokens.weight in the weights
embed_token_weights = next(
filter(lambda x: x[0] == "model.embed_tokens.weight", weights)
)[1]
loaded_weight = embed_token_weights
else:
continue
for param_name, weight_name, shard_id in stacked_params_mapping:
if weight_name not in name:
continue
name = name.replace(weight_name, param_name)
# Skip loading extra bias for GPTQ models.
if name.endswith(".bias") and name not in params_dict:
continue
if name not in params_dict:
continue
param = params_dict[name]
weight_loader = param.weight_loader
weight_loader(param, loaded_weight, shard_id)
break
            else:
                # Skip loading extra bias for GPTQ models.
                if name.endswith(".bias") and name not in params_dict:
                    continue
                if name in params_dict.keys():
                    param = params_dict[name]
                    weight_loader = getattr(
                        param, "weight_loader", default_weight_loader
                    )
                    weight_loader(param, loaded_weight)
                else:
                    logger.warning(f"Parameter {name} not found in params_dict")
def get_embed_and_head(self):
return self.model.embed_tokens.weight, self.lm_head.weight
def set_embed_and_head(self, embed, head):
del self.model.embed_tokens.weight
del self.lm_head.weight
self.model.embed_tokens.weight = embed
self.lm_head.weight = head
torch.cuda.empty_cache()
torch.cuda.synchronize()
def load_kv_cache_scales(self, quantization_param_path: str) -> None:
self.model.load_kv_cache_scales(quantization_param_path)
EntryClass = [Glm4ForCausalLM]
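

# Reviewer sketch (not part of the diff): how `stacked_params_mapping` folds
# per-projection checkpoint names into the fused parameters. Illustrative
# values only:
def _demo_stacked_rename():
    mapping = [
        (".qkv_proj", ".q_proj", "q"),
        (".qkv_proj", ".k_proj", "k"),
        (".qkv_proj", ".v_proj", "v"),
    ]
    name = "model.layers.0.self_attn.k_proj.weight"
    for fused_name, shard_name, shard_id in mapping:
        if shard_name in name:
            # -> ("model.layers.0.self_attn.qkv_proj.weight", "k");
            # the weight_loader then copies the tensor into the "k" shard.
            return name.replace(shard_name, fused_name), shard_id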
# Copyright 2023-2024 SGLang Team
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
# Modeling from:
# ./llama.py and
# https://github.com/huggingface/transformers/blob/main/src/transformers/models/glm4v/modular_glm4v.py
"""Inference-only GLM-4.1V model compatible with HuggingFace weights."""
import logging
from functools import lru_cache
from typing import Iterable, List, Optional, Tuple
import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange
from transformers.models.glm4v.configuration_glm4v import Glm4vConfig, Glm4vVisionConfig
from sglang.srt.layers.activation import SiluAndMul
from sglang.srt.layers.attention import vision_utils
from sglang.srt.layers.attention.vision import VisionAttention
from sglang.srt.layers.layernorm import RMSNorm
from sglang.srt.layers.linear import (
ColumnParallelLinear,
@@ -20,13 +40,14 @@ from sglang.srt.layers.logits_processor import LogitsProcessor
from sglang.srt.layers.pooler import Pooler, PoolingType
from sglang.srt.layers.quantization.base_config import QuantizationConfig
from sglang.srt.layers.vocab_parallel_embedding import ParallelLMHead
from sglang.srt.managers.mm_utils import (
MultiModalityDataPaddingPatternMultimodalTokens,
general_mm_embed_routine,
)
from sglang.srt.managers.schedule_batch import MultimodalDataItem, MultimodalInputs
from sglang.srt.model_executor.forward_batch_info import ForwardBatch
from sglang.srt.model_loader.weight_utils import default_weight_loader
from sglang.srt.models.glm4 import Glm4Model
from sglang.srt.utils import add_prefix
from sglang.srt.utils.hf_transformers_utils import get_processor
@@ -56,7 +77,7 @@ class Glm4vVisionMLP(nn.Module):
super().__init__()
self.gate_up_proj = MergedColumnParallelLinear(
input_size=in_features,
output_sizes=[hidden_features] * 2, # [gate_proj, up_proj]
bias=bias,
quant_config=quant_config,
prefix=add_prefix("gate_up_proj", prefix),
@@ -77,34 +98,95 @@ class Glm4vVisionMLP(nn.Module):
return x
class Glm4vVisionBlock(nn.Module):
    def __init__(
        self,
        dim: int,
        intermediate_dim: int,
        num_heads: int,
        attn_implementation: Optional[str] = None,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
        num_dummy_heads: int = 0,
        rms_norm_eps: float = 1e-5,
    ) -> None:
        super().__init__()
self.norm1 = RMSNorm(dim, eps=rms_norm_eps)
self.norm2 = RMSNorm(dim, eps=rms_norm_eps)
if attn_implementation is None:
softmax_in_single_precision = False
qkv_backend = None
flatten_batch = True
elif attn_implementation == "sdpa":
softmax_in_single_precision = False
qkv_backend = "sdpa"
flatten_batch = True
elif attn_implementation == "flash_attention_2":
softmax_in_single_precision = False
qkv_backend = "triton_attn"
flatten_batch = True
elif attn_implementation == "eager":
softmax_in_single_precision = True
qkv_backend = "sdpa"
flatten_batch = True
elif attn_implementation == "flash_attention_3":
softmax_in_single_precision = False
qkv_backend = "fa3"
flatten_batch = True
        self.attn = VisionAttention(
            embed_dim=dim,
            num_heads=num_heads,
            projection_size=dim,
            use_qkv_parallel=True,
            rotary_embed="normal",
            proj_bias=True,
            qkv_backend=qkv_backend,
            softmax_in_single_precision=softmax_in_single_precision,
            flatten_batch=flatten_batch,
            quant_config=quant_config,
            prefix=add_prefix("attn", prefix),
            num_dummy_heads=num_dummy_heads,
        )
        self.mlp = Glm4vVisionMLP(
            dim,
            intermediate_dim,
            quant_config=quant_config,
            prefix=add_prefix("mlp", prefix),
        )
def forward(
self,
x: torch.Tensor,
cu_seqlens: torch.Tensor,
position_embeddings: torch.Tensor,
) -> torch.Tensor:
S, B, H = x.shape
# norm1: flatten to 2D -> [S*B, H], then reshape back
x2d = x.reshape(-1, H)
hidden_states = self.norm1(x2d).reshape(S, B, H)
# Attention expects [B, S, H]
hidden_states = rearrange(hidden_states, "s b h -> b s h")
attn = self.attn(
hidden_states,
cu_seqlens=cu_seqlens,
position_embeddings=position_embeddings,
)
attn = rearrange(attn, "b s h -> s b h")
# norm2 with fused residual-add: also 2D
attn2d = attn.reshape(-1, H)
x_norm_2d, x_after_add_2d = self.norm2(x2d, residual=attn2d)
x_norm = x_norm_2d.reshape(S, B, H)
x_after_add = x_after_add_2d.reshape(S, B, H)
# MLP and final residual
mlp_out = self.mlp(x_norm)
x = x_after_add + mlp_out
return x
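

# Reviewer sketch (not part of the diff): the block above relies on SGLang's
# fused residual-add RMSNorm, where norm(x, residual) returns both
# norm(x + residual) and the pre-norm sum x + residual. A plain-PyTorch
# equivalent of that contract (eps value illustrative):
def _rmsnorm_with_residual_reference(x, residual, weight, eps=1e-5):
    import torch

    added = x + residual  # carried forward as the next residual branch
    normed = added * torch.rsqrt(added.pow(2).mean(-1, keepdim=True) + eps)
    return normed * weight, added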
class Glm4vVisionPatchEmbed(nn.Module):
def __init__(
@@ -320,7 +402,6 @@ class Glm4vVisionModel(nn.Module):
def __init__(
self,
vision_config: Glm4vVisionConfig,
quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
) -> None:
@@ -344,17 +425,18 @@ class Glm4vVisionModel(nn.Module):
hidden_size=self.hidden_size,
)
head_dim = self.hidden_size // self.num_heads
self.rotary_pos_emb = Glm4vVisionRotaryEmbedding(head_dim // 2)
        self.blocks = nn.ModuleList(
            [
                Glm4vVisionBlock(
                    dim=self.hidden_size,
                    intermediate_dim=self.out_hidden_size,
                    num_heads=self.num_heads,
                    quant_config=quant_config,
                    prefix=add_prefix(f"blocks.{layer_idx}", prefix),
                    rms_norm_eps=vision_config.rms_norm_eps,
                )
                for layer_idx in range(depth)
            ]
@@ -461,29 +543,30 @@ class Glm4vVisionModel(nn.Module):
return x
class Glm4vForConditionalGeneration(nn.Module):
def __init__(
self,
config: Glm4vConfig,
quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
) -> None:
        super().__init__()
        self.config = config

        self.visual = Glm4vVisionModel(
            config.vision_config,
            quant_config=quant_config,
            prefix=add_prefix("visual", prefix),
        )

        vision_utils.update_vit_attn_dummy_heads_config(self.config)
        self.model = Glm4Model(
            config,
            quant_config=quant_config,
            prefix=add_prefix("model", prefix),
        )
if config.tie_word_embeddings:
self.lm_head = self.model.embed_tokens
else:
@@ -494,13 +577,18 @@ class Glm4vForConditionalGeneration(Qwen2_5_VLForConditionalGeneration):
prefix=add_prefix("lm_head", prefix),
)
self.is_mrope_enabled = "mrope_section" in self.config.rope_scaling
self.logits_processor = LogitsProcessor(config)
self.pooler = Pooler(pooling_type=PoolingType.LAST, normalize=True)
self.is_mrope_enabled = "mrope_section" in self.config.rope_scaling
# For EAGLE3 support
self.capture_aux_hidden_states = False
def pad_input_ids(self, input_ids: List[int], mm_inputs: MultimodalInputs):
pattern = MultiModalityDataPaddingPatternMultimodalTokens()
return pattern.pad_input_tokens(input_ids, mm_inputs)
def get_image_feature(self, items: List[MultimodalDataItem]) -> torch.Tensor:
pixel_values = torch.cat(
[item.feature.squeeze(0) for item in items], dim=0
@@ -542,20 +630,60 @@ class Glm4vForConditionalGeneration(Qwen2_5_VLForConditionalGeneration):
video_embeds = torch.split(video_embeds, split_sizes)
return torch.cat(video_embeds)
    def get_input_embeddings(self):
        return self.model.embed_tokens

    @torch.no_grad()
    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        forward_batch: ForwardBatch,
        get_embedding: bool = False,
    ):
        """Run forward pass for GLM-4.1V.

        Args:
            input_ids: Flattened (concatenated) input_ids corresponding to a
                batch.
            positions: Flattened (concatenated) position ids corresponding to a
                batch.
                **NOTE**: If mrope is enabled (the default setting for the
                GLM-4.1V open-source models), the shape will be `(3, seq_len)`,
                otherwise it will be `(seq_len,)`.
                (Use forward_batch.mrope_positions to replace it.)
        """
        if self.is_mrope_enabled:
            positions = forward_batch.mrope_positions

        if not (
            forward_batch.forward_mode.is_decode()
            or not forward_batch.contains_image_inputs()
        ):
            if self.is_mrope_enabled:
                assert positions.ndim == 2 and positions.size(0) == 3, (
                    "multimodal section rotary embedding requires "
                    f"(3, seq_len) positions, but got {positions.size()}"
                )

        hidden_states = general_mm_embed_routine(
            input_ids=input_ids,
            forward_batch=forward_batch,
            language_model=self.model,
            multimodal_model=self,
            positions=positions,
        )

        aux_hidden_states = None
        if self.capture_aux_hidden_states:
            hidden_states, aux_hidden_states = hidden_states

        if not get_embedding:
            return self.logits_processor(
                input_ids, hidden_states, self.lm_head, forward_batch, aux_hidden_states
            )
        else:
            return self.pooler(hidden_states, forward_batch)
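
    # Reviewer sketch (not part of the diff): with mrope enabled, `positions`
    # carries one row of position ids per section (temporal, height, width)
    # instead of a single row, hence the (3, seq_len) assert above.
    # Illustrative layout for a 6-token sequence with no image spans:
    #
    #   mrope_positions = torch.stack([torch.arange(6)] * 3)  # shape [3, 6]
    #   assert mrope_positions.ndim == 2 and mrope_positions.size(0) == 3
    #
    # Inside image spans the three rows diverge: the temporal row stays
    # constant per frame while the height/width rows enumerate the grid.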
def _pad_vit_attn_dummy_heads(self, name: str, loaded_weight: torch.Tensor):
"""pad attn qkv weights for dummy heads"""
@@ -598,13 +726,12 @@ class Glm4vForConditionalGeneration(Qwen2_5_VLForConditionalGeneration):
]
params_dict = dict(self.named_parameters(remove_duplicate=False))
for name, loaded_weight in weights:
if "language_model." in name:
name = name.replace("language_model.", "")
if "model.visual." in name:
name = name.replace("model.visual.", "visual.")
if "rotary_emb.inv_freq" in name:
continue
if "language_model" in name:
name = name.replace(r"model.language_model.", r"model.")
if "model.visual." in name:
name = name.replace("model.visual.", "visual.")
for param_name, weight_name, shard_id in stacked_params_mapping:
if weight_name not in name:
@@ -639,5 +766,19 @@ class Glm4vForConditionalGeneration(Qwen2_5_VLForConditionalGeneration):
)
weight_loader(param, loaded_weight)
def get_embed_and_head(self):
return self.model.embed_tokens.weight, self.lm_head.weight
def set_embed_and_head(self, embed, head):
del self.model.embed_tokens.weight
self.model.embed_tokens.weight = embed
if self.config.tie_word_embeddings:
self.lm_head = self.model.embed_tokens
else:
del self.lm_head.weight
self.lm_head.weight = head
torch.cuda.empty_cache()
torch.cuda.synchronize()
EntryClass = [Glm4vForConditionalGeneration]
@@ -53,7 +53,6 @@ class Glm4vMoeForConditionalGeneration(Glm4vForConditionalGeneration):
)
self.visual = Glm4vVisionModel(
config.vision_config,
quant_config=quant_config,
prefix=add_prefix("visual", prefix),
)