Unverified Commit a88b006e authored by Yuxuan Zhang's avatar Yuxuan Zhang Committed by GitHub
Browse files

GLM-4-0414 and GLM-4.1V Code Refactor (#12117)

parent ce112c07
...@@ -1070,6 +1070,7 @@ def _triton_mrope_forward( ...@@ -1070,6 +1070,7 @@ def _triton_mrope_forward(
mrope_section_h: tl.constexpr, mrope_section_h: tl.constexpr,
mrope_section_w: tl.constexpr, mrope_section_w: tl.constexpr,
is_interleaved: tl.constexpr, is_interleaved: tl.constexpr,
is_neox_style: tl.constexpr,
): ):
# Adapted from # Adapted from
# https://github.com/linkedin/Liger-Kernel/blob/main/src/liger_kernel/ops/qwen2vl_mrope.py # https://github.com/linkedin/Liger-Kernel/blob/main/src/liger_kernel/ops/qwen2vl_mrope.py
...@@ -1124,51 +1125,99 @@ def _triton_mrope_forward( ...@@ -1124,51 +1125,99 @@ def _triton_mrope_forward(
# program instance (i.e. for the current token) separately # program instance (i.e. for the current token) separately
# #################################################################### # ####################################################################
# left half of the head # left half of the head
first_half_q_offsets = ( if is_neox_style:
tl.arange(0, pad_n_qh)[:, None] * hd + tl.arange(0, pad_hd // 2)[None, :] first_half_q_offsets = (
) tl.arange(0, pad_n_qh)[:, None] * hd + tl.arange(0, pad_hd // 2)[None, :]
first_half_k_offsets = ( )
tl.arange(0, pad_n_kh)[:, None] * hd + tl.arange(0, pad_hd // 2)[None, :] first_half_k_offsets = (
) tl.arange(0, pad_n_kh)[:, None] * hd + tl.arange(0, pad_hd // 2)[None, :]
first_q_mask = (tl.arange(0, pad_n_qh)[:, None] < n_qh) & ( )
tl.arange(0, pad_hd // 2)[None, :] < rd // 2 first_q_mask = (tl.arange(0, pad_n_qh)[:, None] < n_qh) & (
) tl.arange(0, pad_hd // 2)[None, :] < rd // 2
first_k_mask = (tl.arange(0, pad_n_kh)[:, None] < n_kh) & ( )
tl.arange(0, pad_hd // 2)[None, :] < rd // 2 first_k_mask = (tl.arange(0, pad_n_kh)[:, None] < n_kh) & (
) tl.arange(0, pad_hd // 2)[None, :] < rd // 2
)
q_tile_1 = tl.load(q_ptr + first_half_q_offsets, mask=first_q_mask, other=0).to( q_tile_1 = tl.load(q_ptr + first_half_q_offsets, mask=first_q_mask, other=0).to(
sin_row.dtype sin_row.dtype
) )
k_tile_1 = tl.load(k_ptr + first_half_k_offsets, mask=first_k_mask, other=0).to( k_tile_1 = tl.load(k_ptr + first_half_k_offsets, mask=first_k_mask, other=0).to(
sin_row.dtype sin_row.dtype
) )
# right half of the head # right half of the head
second_half_q_offsets = first_half_q_offsets + (rd // 2) second_half_q_offsets = first_half_q_offsets + (rd // 2)
second_half_k_offsets = first_half_k_offsets + (rd // 2) second_half_k_offsets = first_half_k_offsets + (rd // 2)
second_q_mask = first_q_mask second_q_mask = first_q_mask
second_k_mask = first_k_mask second_k_mask = first_k_mask
q_tile_2 = tl.load(
q_ptr + second_half_q_offsets, mask=second_q_mask, other=0
).to(sin_row.dtype)
k_tile_2 = tl.load(
k_ptr + second_half_k_offsets, mask=second_k_mask, other=0
).to(sin_row.dtype)
# y = [x1, x2] * [cos, cos] + [-x2, x1] * [sin, sin]
# Since cos and sin are now half-size,
# we use the same cos_row and sin_row for both halves
new_q_tile_1 = q_tile_1 * cos_row - q_tile_2 * sin_row
tl.store(q_ptr + first_half_q_offsets, new_q_tile_1, mask=first_q_mask)
new_q_tile_2 = q_tile_2 * cos_row + q_tile_1 * sin_row
tl.store(q_ptr + second_half_q_offsets, new_q_tile_2, mask=second_q_mask)
new_k_tile_1 = k_tile_1 * cos_row - k_tile_2 * sin_row
tl.store(k_ptr + first_half_k_offsets, new_k_tile_1, mask=first_k_mask)
new_k_tile_2 = k_tile_2 * cos_row + k_tile_1 * sin_row
tl.store(k_ptr + second_half_k_offsets, new_k_tile_2, mask=second_k_mask)
else:
base_q = tl.arange(0, pad_n_qh)[:, None] * hd
base_k = tl.arange(0, pad_n_kh)[:, None] * hd
even_idx = 2 * tl.arange(0, pad_hd // 2)[None, :]
odd_idx = even_idx + 1
even_q_offsets = base_q + even_idx
odd_q_offsets = base_q + odd_idx
even_k_offsets = base_k + even_idx
odd_k_offsets = base_k + odd_idx
idx_mask = tl.arange(0, pad_hd // 2)[None, :] < (rd // 2)
qn_mask = tl.arange(0, pad_n_qh)[:, None] < n_qh
kn_mask = tl.arange(0, pad_n_kh)[:, None] < n_kh
even_q_mask = qn_mask & idx_mask
odd_q_mask = qn_mask & idx_mask
even_k_mask = kn_mask & idx_mask
odd_k_mask = kn_mask & idx_mask
q_tile_1 = tl.load(q_ptr + even_q_offsets, mask=even_q_mask, other=0).to(
sin_row.dtype
)
k_tile_1 = tl.load(k_ptr + even_k_offsets, mask=even_k_mask, other=0).to(
sin_row.dtype
)
q_tile_2 = tl.load(q_ptr + second_half_q_offsets, mask=second_q_mask, other=0).to( q_tile_2 = tl.load(q_ptr + odd_q_offsets, mask=odd_q_mask, other=0).to(
sin_row.dtype sin_row.dtype
) )
k_tile_2 = tl.load(k_ptr + second_half_k_offsets, mask=second_k_mask, other=0).to( k_tile_2 = tl.load(k_ptr + odd_k_offsets, mask=odd_k_mask, other=0).to(
sin_row.dtype sin_row.dtype
) )
# y = [x1, x2] * [cos, cos] + [-x2, x1] * [sin, sin] # y = [x_even, x_odd] * [cos, cos] + [-x_odd, x_even] * [sin, sin]
# Since cos and sin are now half-size, # Interleaved (non-NeoX) rotary embedding:
# we use the same cos_row and sin_row for both halves # Each (even, odd) channel pair forms one rotation arm.
new_q_tile_1 = q_tile_1 * cos_row - q_tile_2 * sin_row # cos_row and sin_row each have length rd//2, shared across all (even, odd) pairs.
tl.store(q_ptr + first_half_q_offsets, new_q_tile_1, mask=first_q_mask) new_q_tile_1 = q_tile_1 * cos_row - q_tile_2 * sin_row
new_q_tile_2 = q_tile_2 * cos_row + q_tile_1 * sin_row tl.store(q_ptr + even_q_offsets, new_q_tile_1, mask=even_q_mask)
tl.store(q_ptr + second_half_q_offsets, new_q_tile_2, mask=second_q_mask) new_q_tile_2 = q_tile_2 * cos_row + q_tile_1 * sin_row
tl.store(q_ptr + odd_q_offsets, new_q_tile_2, mask=odd_q_mask)
new_k_tile_1 = k_tile_1 * cos_row - k_tile_2 * sin_row new_k_tile_1 = k_tile_1 * cos_row - k_tile_2 * sin_row
tl.store(k_ptr + first_half_k_offsets, new_k_tile_1, mask=first_k_mask) tl.store(k_ptr + even_k_offsets, new_k_tile_1, mask=even_k_mask)
new_k_tile_2 = k_tile_2 * cos_row + k_tile_1 * sin_row new_k_tile_2 = k_tile_2 * cos_row + k_tile_1 * sin_row
tl.store(k_ptr + second_half_k_offsets, new_k_tile_2, mask=second_k_mask) tl.store(k_ptr + odd_k_offsets, new_k_tile_2, mask=odd_k_mask)
def triton_mrope( def triton_mrope(
...@@ -1180,6 +1229,7 @@ def triton_mrope( ...@@ -1180,6 +1229,7 @@ def triton_mrope(
head_size: int, head_size: int,
rotary_dim: int, rotary_dim: int,
mrope_interleaved: bool, mrope_interleaved: bool,
is_neox_style: bool,
) -> tuple[torch.Tensor, torch.Tensor]: ) -> tuple[torch.Tensor, torch.Tensor]:
"""The mrope triton kernel. """The mrope triton kernel.
...@@ -1230,6 +1280,7 @@ def triton_mrope( ...@@ -1230,6 +1280,7 @@ def triton_mrope(
mrope_section[1], mrope_section[1],
mrope_section[2], mrope_section[2],
mrope_interleaved, mrope_interleaved,
is_neox_style,
) )
return q, k return q, k
...@@ -1400,6 +1451,7 @@ class MRotaryEmbedding(RotaryEmbedding): ...@@ -1400,6 +1451,7 @@ class MRotaryEmbedding(RotaryEmbedding):
self.head_size, self.head_size,
self.rotary_dim, self.rotary_dim,
self.mrope_interleaved, self.mrope_interleaved,
self.is_neox_style,
) )
return q.reshape(query_shape), k.reshape(key_shape) return q.reshape(query_shape), k.reshape(key_shape)
......
...@@ -15,46 +15,119 @@ ...@@ -15,46 +15,119 @@
# Modeling from: # Modeling from:
# ./llama.py and # ./llama.py and
# https://github.com/huggingface/transformers/blob/main/src/transformers/models/glm4/modular_glm4.py # https://github.com/huggingface/transformers/blob/main/src/transformers/models/glm4/modular_glm4.py
"""Inference-only GLM4 model compatible with THUDM weights.""" """Inference-only GLM-4-0414 model compatible with HuggingFace weights."""
from typing import Iterable, List, Optional, Tuple, Union import logging
from typing import Any, Dict, Iterable, Optional, Tuple, Union
import torch import torch
from torch import nn from torch import nn
from transformers import Glm4Config
from sglang.srt.distributed import get_tensor_model_parallel_world_size from sglang.srt.distributed import (
get_pp_group,
get_tensor_model_parallel_rank,
get_tensor_model_parallel_world_size,
)
from sglang.srt.layers.activation import SiluAndMul
from sglang.srt.layers.dp_attention import is_dp_attention_enabled
from sglang.srt.layers.layernorm import RMSNorm from sglang.srt.layers.layernorm import RMSNorm
from sglang.srt.layers.linear import QKVParallelLinear, RowParallelLinear from sglang.srt.layers.linear import (
MergedColumnParallelLinear,
QKVParallelLinear,
RowParallelLinear,
)
from sglang.srt.layers.logits_processor import LogitsProcessor from sglang.srt.layers.logits_processor import LogitsProcessor
from sglang.srt.layers.pooler import Pooler, PoolingType
from sglang.srt.layers.quantization.base_config import QuantizationConfig from sglang.srt.layers.quantization.base_config import QuantizationConfig
from sglang.srt.layers.radix_attention import RadixAttention from sglang.srt.layers.radix_attention import RadixAttention
from sglang.srt.layers.rotary_embedding import get_rope from sglang.srt.layers.rotary_embedding import get_rope
from sglang.srt.layers.utils import PPMissingLayer, get_layer_id
from sglang.srt.layers.vocab_parallel_embedding import ( from sglang.srt.layers.vocab_parallel_embedding import (
ParallelLMHead, ParallelLMHead,
VocabParallelEmbedding, VocabParallelEmbedding,
) )
from sglang.srt.model_executor.forward_batch_info import ForwardBatch from sglang.srt.model_executor.forward_batch_info import ForwardBatch, PPProxyTensors
from sglang.srt.model_loader.weight_utils import default_weight_loader from sglang.srt.model_loader.weight_utils import (
from sglang.srt.models.llama import LlamaMLP as Glm4MLP default_weight_loader,
kv_cache_scales_loader,
)
from sglang.srt.utils import add_prefix, make_layers from sglang.srt.utils import add_prefix, make_layers
Glm4Config = None
logger = logging.getLogger(__name__)
class Glm4MLP(nn.Module):
    """Feed-forward block: merged gate/up projection, activation, down projection.

    The gate and up projections are fused into a single
    ``MergedColumnParallelLinear`` with two outputs of ``intermediate_size``
    each; ``SiluAndMul`` consumes that fused output and the result is projected
    back to ``hidden_size`` by a ``RowParallelLinear``.
    """

    def __init__(
        self,
        hidden_size: int,
        intermediate_size: int,
        hidden_act: str,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
        reduce_results: bool = True,
    ) -> None:
        """Build the projections and the activation.

        Args:
            hidden_size: Input/output width of the block.
            intermediate_size: Width of each of the gate and up projections.
            hidden_act: Activation name; only ``"silu"`` is accepted.
            quant_config: Optional quantization config forwarded to both linears.
            prefix: Parameter-name prefix used for the sub-layers.
            reduce_results: Forwarded to the down projection.

        Raises:
            ValueError: If ``hidden_act`` is not ``"silu"``.
        """
        super().__init__()

        # Both projections are bias-free and share the same quantization config.
        gate_up_prefix = add_prefix("gate_up_proj", prefix)
        down_prefix = add_prefix("down_proj", prefix)

        self.gate_up_proj = MergedColumnParallelLinear(
            hidden_size,
            [intermediate_size, intermediate_size],
            bias=False,
            quant_config=quant_config,
            prefix=gate_up_prefix,
        )
        self.down_proj = RowParallelLinear(
            intermediate_size,
            hidden_size,
            bias=False,
            quant_config=quant_config,
            prefix=down_prefix,
            reduce_results=reduce_results,
        )

        if hidden_act != "silu":
            raise ValueError(
                f"Unsupported activation: {hidden_act}. Only silu is supported for now."
            )
        self.act_fn = SiluAndMul()

    def forward(
        self,
        x,
        forward_batch=None,
        use_reduce_scatter: bool = False,
    ):
        """Run gate/up projection, activation and down projection on ``x``.

        ``forward_batch`` is accepted but not used here. ``use_reduce_scatter``
        is passed to the down projection as ``skip_all_reduce``.
        """
        fused_gate_up, _ = self.gate_up_proj(x)
        activated = self.act_fn(fused_gate_up)
        projected, _ = self.down_proj(
            activated,
            skip_all_reduce=use_reduce_scatter,
        )
        return projected
class Glm4Attention(nn.Module): class Glm4Attention(nn.Module):
def __init__( def __init__(
self, self,
config, hidden_size: int,
num_heads: int,
num_kv_heads: int,
head_dim: Optional[int] = None,
layer_id: int = 0, layer_id: int = 0,
rope_theta: float = 1000000,
rope_scaling: Optional[Dict[str, Any]] = None,
max_position_embeddings: int = 131072,
quant_config: Optional[QuantizationConfig] = None, quant_config: Optional[QuantizationConfig] = None,
dual_chunk_attention_config: Optional[dict[str, Any]] = None,
partial_rotary_factor: float = 0.5,
prefix: str = "", prefix: str = "",
): ) -> None:
super().__init__() super().__init__()
self.hidden_size = config.hidden_size self.hidden_size = hidden_size
tp_size = get_tensor_model_parallel_world_size() tp_size = get_tensor_model_parallel_world_size()
self.total_num_heads = config.num_attention_heads self.total_num_heads = num_heads
assert self.total_num_heads % tp_size == 0 assert self.total_num_heads % tp_size == 0
self.num_heads = self.total_num_heads // tp_size self.num_heads = self.total_num_heads // tp_size
self.total_num_kv_heads = config.num_key_value_heads self.total_num_kv_heads = num_kv_heads
if self.total_num_kv_heads >= tp_size: if self.total_num_kv_heads >= tp_size:
# Number of KV heads is greater than TP size, so we partition # Number of KV heads is greater than TP size, so we partition
# the KV heads across multiple tensor parallel GPUs. # the KV heads across multiple tensor parallel GPUs.
...@@ -63,27 +136,30 @@ class Glm4Attention(nn.Module): ...@@ -63,27 +136,30 @@ class Glm4Attention(nn.Module):
# Number of KV heads is less than TP size, so we replicate # Number of KV heads is less than TP size, so we replicate
# the KV heads across multiple tensor parallel GPUs. # the KV heads across multiple tensor parallel GPUs.
assert tp_size % self.total_num_kv_heads == 0 assert tp_size % self.total_num_kv_heads == 0
partial_rotary_factor = getattr(config, "partial_rotary_factor", 0.5)
self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size) self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size)
self.head_dim = config.hidden_size // self.total_num_heads if head_dim is not None:
self.head_dim = head_dim
else:
self.head_dim = hidden_size // self.total_num_heads
self.q_size = self.num_heads * self.head_dim self.q_size = self.num_heads * self.head_dim
self.kv_size = self.num_kv_heads * self.head_dim self.kv_size = self.num_kv_heads * self.head_dim
self.scaling = self.head_dim**-0.5 self.scaling = self.head_dim**-0.5
self.rope_theta = getattr(config, "rope_theta", 1000000) self.rope_theta = rope_theta
self.rope_scaling = getattr(config, "rope_scaling", None) self.max_position_embeddings = max_position_embeddings
self.partial_rotary_factor = partial_rotary_factor
self.qkv_proj = QKVParallelLinear( self.qkv_proj = QKVParallelLinear(
self.hidden_size, hidden_size,
self.head_dim, self.head_dim,
self.total_num_heads, self.total_num_heads,
self.total_num_kv_heads, self.total_num_kv_heads,
bias=config.attention_bias, bias=True,
quant_config=quant_config, quant_config=quant_config,
prefix=add_prefix("qkv_proj", prefix), prefix=add_prefix("qkv_proj", prefix),
) )
self.o_proj = RowParallelLinear( self.o_proj = RowParallelLinear(
self.total_num_heads * self.head_dim, self.total_num_heads * self.head_dim,
self.hidden_size, hidden_size,
bias=False, bias=False,
quant_config=quant_config, quant_config=quant_config,
prefix=add_prefix("o_proj", prefix), prefix=add_prefix("o_proj", prefix),
...@@ -92,9 +168,10 @@ class Glm4Attention(nn.Module): ...@@ -92,9 +168,10 @@ class Glm4Attention(nn.Module):
self.rotary_emb = get_rope( self.rotary_emb = get_rope(
self.head_dim, self.head_dim,
rotary_dim=self.head_dim, rotary_dim=self.head_dim,
max_position=config.max_position_embeddings, max_position=max_position_embeddings,
base=self.rope_theta, base=rope_theta,
rope_scaling=self.rope_scaling, rope_scaling=rope_scaling,
dual_chunk_attention_config=dual_chunk_attention_config,
partial_rotary_factor=partial_rotary_factor, partial_rotary_factor=partial_rotary_factor,
is_neox_style=False, is_neox_style=False,
) )
...@@ -117,14 +194,9 @@ class Glm4Attention(nn.Module): ...@@ -117,14 +194,9 @@ class Glm4Attention(nn.Module):
qkv, _ = self.qkv_proj(hidden_states) qkv, _ = self.qkv_proj(hidden_states)
q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1) q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
q, k = self.rotary_emb(positions, q, k) q, k = self.rotary_emb(positions, q, k)
context_layer = self.attn( attn_output = self.attn(q, k, v, forward_batch)
q, output, _ = self.o_proj(attn_output)
k, return output
v,
forward_batch,
)
attn_output, _ = self.o_proj(context_layer)
return attn_output
class Glm4DecoderLayer(nn.Module): class Glm4DecoderLayer(nn.Module):
...@@ -136,15 +208,35 @@ class Glm4DecoderLayer(nn.Module): ...@@ -136,15 +208,35 @@ class Glm4DecoderLayer(nn.Module):
def __init__( def __init__(
self, self,
config, config: Glm4Config,
layer_id: int, layer_id: int = 0,
quant_config: Optional[QuantizationConfig] = None, quant_config: Optional[QuantizationConfig] = None,
prefix: str = "", prefix: str = "",
): alt_stream: Optional[torch.cuda.Stream] = None,
) -> None:
super().__init__() super().__init__()
# Self attention. self.hidden_size = config.hidden_size
rope_theta = getattr(config, "rope_theta", 1000000)
rope_scaling = getattr(config, "rope_scaling", None)
max_position_embeddings = getattr(config, "max_position_embeddings", 32768)
head_dim = getattr(config, "head_dim", None)
partial_rotary_factor = getattr(config, "partial_rotary_factor", None)
dual_chunk_attention_config = getattr(
config, "dual_chunk_attention_config", None
)
self.self_attn = Glm4Attention( self.self_attn = Glm4Attention(
config, layer_id, quant_config, prefix=add_prefix("self_attn", prefix) hidden_size=self.hidden_size,
num_heads=config.num_attention_heads,
num_kv_heads=config.num_key_value_heads,
head_dim=head_dim,
layer_id=layer_id,
rope_theta=rope_theta,
rope_scaling=rope_scaling,
max_position_embeddings=max_position_embeddings,
quant_config=quant_config,
dual_chunk_attention_config=dual_chunk_attention_config,
partial_rotary_factor=partial_rotary_factor,
prefix=add_prefix("self_attn", prefix),
) )
# MLP # MLP
...@@ -199,54 +291,125 @@ class Glm4Model(nn.Module): ...@@ -199,54 +291,125 @@ class Glm4Model(nn.Module):
config: Glm4Config, config: Glm4Config,
quant_config: Optional[QuantizationConfig] = None, quant_config: Optional[QuantizationConfig] = None,
prefix: str = "", prefix: str = "",
decoder_layer_type: type[nn.Module] = Glm4DecoderLayer,
alt_stream: Optional[torch.cuda.Stream] = None,
) -> None: ) -> None:
super().__init__() super().__init__()
self.config = config self.config = config
self.embed_tokens = VocabParallelEmbedding( self.padding_idx = config.pad_token_id
config.vocab_size, self.vocab_size = config.vocab_size
config.hidden_size, self.pp_group = get_pp_group()
quant_config=quant_config,
prefix=add_prefix("embed_tokens", prefix), if self.pp_group.is_first_rank:
) self.embed_tokens = VocabParallelEmbedding(
self.layers = make_layers( config.vocab_size,
config.hidden_size,
quant_config=quant_config,
enable_tp=not is_dp_attention_enabled(),
prefix=add_prefix("embed_tokens", prefix),
)
else:
self.embed_tokens = PPMissingLayer()
# Use the provided decoder layer type or default to Glm4DecoderLayer
decoder_layer_type = decoder_layer_type or Glm4DecoderLayer
self.layers, self.start_layer, self.end_layer = make_layers(
config.num_hidden_layers, config.num_hidden_layers,
lambda idx, prefix: Glm4DecoderLayer( lambda idx, prefix: decoder_layer_type(
config=config, layer_id=idx, quant_config=quant_config, prefix=prefix layer_id=idx,
config=config,
quant_config=quant_config,
prefix=prefix,
alt_stream=alt_stream,
), ),
prefix="model.layers", pp_rank=self.pp_group.rank_in_group,
pp_size=self.pp_group.world_size,
prefix=add_prefix("layers", prefix),
) )
if self.pp_group.is_last_rank:
self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
else:
self.norm = PPMissingLayer(return_tuple=True)
self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) # For EAGLE3 support
self.layers_to_capture = []
def get_input_embeddings(self) -> nn.Embedding: def get_input_embeddings(self) -> nn.Embedding:
return self.embed_tokens return self.embed_tokens
def dtype(self) -> torch.dtype:
    """Return the dtype of the first parameter yielded by ``self.parameters()``."""
    first_param = next(self.parameters())
    return first_param.dtype
@torch.no_grad()
def forward( def forward(
self, self,
input_ids: torch.Tensor, input_ids: torch.Tensor,
positions: torch.Tensor, positions: torch.Tensor,
forward_batch: ForwardBatch, forward_batch: ForwardBatch,
input_embeds: torch.Tensor = None, input_embeds: torch.Tensor = None,
) -> Union[torch.Tensor, Tuple[torch.Tensor, List[torch.Tensor]]]: pp_proxy_tensors: Optional[PPProxyTensors] = None,
if input_embeds is None: ) -> Union[torch.Tensor, PPProxyTensors]:
hidden_states = self.embed_tokens(input_ids) if self.pp_group.is_first_rank:
if input_embeds is None:
hidden_states = self.embed_tokens(input_ids)
else:
hidden_states = input_embeds
residual = None
else: else:
hidden_states = input_embeds assert pp_proxy_tensors is not None
residual = None hidden_states = pp_proxy_tensors["hidden_states"]
for layer in self.layers: residual = pp_proxy_tensors["residual"]
aux_hidden_states = []
for i in range(self.start_layer, self.end_layer):
if i in self.layers_to_capture:
aux_hidden_states.append(
hidden_states + residual if residual is not None else hidden_states
)
layer = self.layers[i]
hidden_states, residual = layer( hidden_states, residual = layer(
positions, positions,
hidden_states, hidden_states,
forward_batch, forward_batch,
residual, residual,
) )
hidden_states, _ = self.norm(hidden_states, residual) if not self.pp_group.is_last_rank:
return PPProxyTensors(
{
"hidden_states": hidden_states,
"residual": residual,
}
)
else:
if hidden_states.shape[0] != 0:
if residual is None:
hidden_states = self.norm(hidden_states)
else:
hidden_states, _ = self.norm(hidden_states, residual)
if len(aux_hidden_states) == 0:
return hidden_states
return hidden_states return hidden_states, aux_hidden_states
# If this function is called, it should always initialize KV cache scale
# factors (or else raise an exception). Thus, handled exceptions should
# make sure to leave KV cache scale factors in a known good (dummy) state
def load_kv_cache_scales(self, quantization_param_path: str) -> None:
    """Assign per-layer KV cache scaling factors to the attention modules.

    Iterates over the ``(layer_idx, scaling_factor)`` pairs produced by
    ``kv_cache_scales_loader`` and writes each factor to both ``k_scale`` and
    ``v_scale`` of that layer's ``self_attn.attn``.

    Args:
        quantization_param_path: Path handed to ``kv_cache_scales_loader``.

    Raises:
        RuntimeError: If a layer's attention module has no ``k_scale``
            attribute.
    """
    tp_size = get_tensor_model_parallel_world_size()
    tp_rank = get_tensor_model_parallel_rank()
    for layer_idx, scaling_factor in kv_cache_scales_loader(
        quantization_param_path,
        tp_rank,
        tp_size,
        self.config.num_hidden_layers,
        self.config.__class__.model_type,
    ):
        layer = self.layers[layer_idx]
        # Placeholder layers have no attention module. Skip them explicitly so
        # we never touch an unbound name or reuse the attention module of a
        # previously visited layer.
        if isinstance(layer, nn.Identity):
            continue
        attn = layer.self_attn.attn
        if not hasattr(attn, "k_scale"):
            raise RuntimeError(
                "Self attention has no KV cache scaling factor attribute!"
            )
        attn.k_scale = scaling_factor
        attn.v_scale = scaling_factor
class Glm4ForCausalLM(nn.Module): class Glm4ForCausalLM(nn.Module):
...@@ -255,21 +418,54 @@ class Glm4ForCausalLM(nn.Module): ...@@ -255,21 +418,54 @@ class Glm4ForCausalLM(nn.Module):
config: Glm4Config, config: Glm4Config,
quant_config: Optional[QuantizationConfig] = None, quant_config: Optional[QuantizationConfig] = None,
prefix: str = "", prefix: str = "",
): ) -> None:
super().__init__() super().__init__()
self.config: Glm4Config = config self.pp_group = get_pp_group()
self.config = config
self.quant_config = quant_config self.quant_config = quant_config
self.model = Glm4Model(config, quant_config, add_prefix("model", prefix)) self.model = Glm4Model(
if config.tie_word_embeddings: config, quant_config=quant_config, prefix=add_prefix("model", prefix)
self.lm_head = self.model.embed_tokens )
# handle the lm head on different pp ranks
if self.pp_group.is_last_rank:
if self.pp_group.world_size == 1 and config.tie_word_embeddings:
self.lm_head = self.model.embed_tokens
else:
self.lm_head = ParallelLMHead(
config.vocab_size,
config.hidden_size,
quant_config=quant_config,
prefix=add_prefix("lm_head", prefix),
)
else: else:
self.lm_head = ParallelLMHead( # ranks other than the last rank will have a placeholder layer
config.vocab_size, self.lm_head = PPMissingLayer()
config.hidden_size,
quant_config=quant_config, # perform weight tying for PP
prefix="lm_head", if self.pp_group.world_size > 1 and config.tie_word_embeddings:
) if self.pp_group.is_first_rank:
self.pp_group.send(
self.model.embed_tokens.weight, dst=self.pp_group.last_rank
)
else:
emb_token_weight = self.pp_group.recv(
size=(config.vocab_size, config.hidden_size),
dtype=next(self.model.parameters()).dtype,
src=self.pp_group.first_rank,
)
self.lm_head.weight.copy_(emb_token_weight)
self.logits_processor = LogitsProcessor(config) self.logits_processor = LogitsProcessor(config)
self.pooler = Pooler(pooling_type=PoolingType.LAST, normalize=True)
# For EAGLE3 support
self.capture_aux_hidden_states = False
def get_input_embedding(self, input_ids: torch.Tensor) -> torch.Tensor:
    """Embed ``input_ids`` with the wrapped model's input embedding layer.

    The wrapped model exposes its embedding layer through
    ``get_input_embeddings()`` (plural, no arguments), so the layer is fetched
    through that accessor and then applied to ``input_ids``. Delegating to a
    ``get_input_embedding`` method on the model would raise ``AttributeError``
    when the model does not define one.
    """
    embedding_layer = self.model.get_input_embeddings()
    return embedding_layer(input_ids)
def get_input_embeddings(self) -> nn.Embedding:
    """Return the ``embed_tokens`` module of the wrapped model."""
    backbone = self.model
    return backbone.embed_tokens
@torch.no_grad() @torch.no_grad()
def forward( def forward(
...@@ -277,34 +473,138 @@ class Glm4ForCausalLM(nn.Module): ...@@ -277,34 +473,138 @@ class Glm4ForCausalLM(nn.Module):
input_ids: torch.Tensor, input_ids: torch.Tensor,
positions: torch.Tensor, positions: torch.Tensor,
forward_batch: ForwardBatch, forward_batch: ForwardBatch,
input_embeds: torch.Tensor = None,
get_embedding: bool = False,
pp_proxy_tensors: Optional[PPProxyTensors] = None,
) -> torch.Tensor: ) -> torch.Tensor:
hidden_states = self.model(input_ids, positions, forward_batch) hidden_states = self.model(
return self.logits_processor( input_ids,
input_ids, hidden_states, self.lm_head, forward_batch positions,
forward_batch,
input_embeds,
pp_proxy_tensors=pp_proxy_tensors,
) )
aux_hidden_states = None
if self.capture_aux_hidden_states:
hidden_states, aux_hidden_states = hidden_states
if self.pp_group.is_last_rank:
if not get_embedding:
return self.logits_processor(
input_ids,
hidden_states,
self.lm_head,
forward_batch,
aux_hidden_states,
)
else:
return self.pooler(hidden_states, forward_batch)
else:
return hidden_states
@torch.no_grad()
def forward_split_prefill(
    self,
    input_ids: torch.Tensor,
    positions: torch.Tensor,
    forward_batch: ForwardBatch,
    split_interval: Tuple[int, int],  # [start, end) 0-based
    input_embeds: torch.Tensor = None,
):
    """Run the decoder layers in ``[start, end)`` and keep state on ``forward_batch``.

    Intermediate ``hidden_states`` and ``residual`` are stored on
    ``forward_batch`` so that a later call can continue from ``end``.

    Args:
        input_ids: Token ids; embedded only when the interval starts at 0 and
            ``input_embeds`` is None.
        positions: Passed through to every decoder layer.
        forward_batch: Carries ``hidden_states`` / ``residual`` between calls.
        split_interval: Half-open, 0-based ``(start, end)`` layer interval.
        input_embeds: Optional embeddings used instead of ``embed_tokens``.

    Returns:
        The ``logits_processor`` output when ``end`` equals
        ``config.num_hidden_layers``; otherwise ``None``.
    """
    first_layer, stop_layer = split_interval
    backbone = self.model

    # Embedding happens only for the slice that begins at layer 0.
    if first_layer == 0:
        forward_batch.hidden_states = (
            backbone.embed_tokens(input_ids)
            if input_embeds is None
            else input_embeds
        )

    # Decoder layers of this slice.
    for layer_index in range(first_layer, stop_layer):
        decoder_layer = backbone.layers[layer_index]
        forward_batch.hidden_states, forward_batch.residual = decoder_layer(
            positions,
            forward_batch.hidden_states,
            forward_batch,
            forward_batch.residual,
        )

    # Not the final slice yet: nothing to return.
    if stop_layer != backbone.config.num_hidden_layers:
        return None

    # Final slice: apply the norm, then compute logits.
    normed_states, _ = backbone.norm(
        forward_batch.hidden_states, forward_batch.residual
    )
    forward_batch.hidden_states = normed_states
    return self.logits_processor(
        input_ids, forward_batch.hidden_states, self.lm_head, forward_batch
    )
@property
def start_layer(self):
    """Expose ``self.model.start_layer``."""
    backbone = self.model
    return backbone.start_layer
@property
def end_layer(self):
    """Expose ``self.model.end_layer``."""
    backbone = self.model
    return backbone.end_layer
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
stacked_params_mapping = [ stacked_params_mapping = [
# (param_name, weight_name, shard_id) # (param_name, shard_name, shard_id)
(".qkv_proj", ".q_proj", "q"), (".qkv_proj", ".q_proj", "q"),
(".qkv_proj", ".k_proj", "k"), (".qkv_proj", ".k_proj", "k"),
(".qkv_proj", ".v_proj", "v"), (".qkv_proj", ".v_proj", "v"),
(".gate_up_proj", ".gate_proj", 0),
(".gate_up_proj", ".up_proj", 1), (".gate_up_proj", ".up_proj", 1),
(".gate_up_proj", ".gate_proj", 0),
] ]
params_dict = dict(self.named_parameters()) params_dict = dict(self.named_parameters())
for name, loaded_weight in weights: for name, loaded_weight in weights:
if self.config.tie_word_embeddings and "lm_head.weight" in name: layer_id = get_layer_id(name)
if (
layer_id is not None
and hasattr(self.model, "start_layer")
and (
layer_id < self.model.start_layer
or layer_id >= self.model.end_layer
)
):
continue
if "rotary_emb.inv_freq" in name or "projector" in name:
continue continue
if self.config.tie_word_embeddings and "lm_head.weight" in name:
if self.pp_group.world_size > 1 and self.pp_group.is_last_rank:
# Handle pp weight tying here
# find the embed_tokens.weight in the weights
embed_token_weights = next(
filter(lambda x: x[0] == "model.embed_tokens.weight", weights)
)[1]
loaded_weight = embed_token_weights
else:
continue
for param_name, weight_name, shard_id in stacked_params_mapping: for param_name, weight_name, shard_id in stacked_params_mapping:
if weight_name not in name: if weight_name not in name:
continue continue
name = name.replace(weight_name, param_name) name = name.replace(weight_name, param_name)
# Skip loading extra bias for GPTQ models.
if name.endswith(".bias") and name not in params_dict:
continue
if name not in params_dict:
continue
param = params_dict[name] param = params_dict[name]
weight_loader = param.weight_loader weight_loader = param.weight_loader
weight_loader(param, loaded_weight, shard_id) weight_loader(param, loaded_weight, shard_id)
break break
else: else:
# Skip loading extra bias for GPTQ models.
if name.endswith(".bias") and name not in params_dict:
continue
if name in params_dict.keys(): if name in params_dict.keys():
param = params_dict[name] param = params_dict[name]
weight_loader = getattr( weight_loader = getattr(
...@@ -312,7 +612,21 @@ class Glm4ForCausalLM(nn.Module): ...@@ -312,7 +612,21 @@ class Glm4ForCausalLM(nn.Module):
) )
weight_loader(param, loaded_weight) weight_loader(param, loaded_weight)
else: else:
raise KeyError(f"Parameter '{name}' not found in model.") logger.warning(f"Parameter {name} not found in params_dict")
def get_embed_and_head(self):
    """Return ``(embed_tokens.weight, lm_head.weight)`` as a tuple."""
    embed_weight = self.model.embed_tokens.weight
    head_weight = self.lm_head.weight
    return embed_weight, head_weight
def set_embed_and_head(self, embed, head):
    """Replace the embedding and LM-head weights with ``embed`` and ``head``.

    The existing ``weight`` attributes are deleted before the new ones are
    assigned; afterwards the CUDA cache is emptied and the device synchronized.
    """
    embedding_layer = self.model.embed_tokens
    output_layer = self.lm_head

    # Drop the old weights first, then attach the replacements.
    del embedding_layer.weight
    del output_layer.weight
    embedding_layer.weight = embed
    output_layer.weight = head

    torch.cuda.empty_cache()
    torch.cuda.synchronize()
def load_kv_cache_scales(self, quantization_param_path: str) -> None:
    """Delegate KV cache scale loading to ``self.model.load_kv_cache_scales``."""
    backbone = self.model
    backbone.load_kv_cache_scales(quantization_param_path)
EntryClass = [Glm4ForCausalLM] EntryClass = [Glm4ForCausalLM]
# Copyright 2023-2024 SGLang Team
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
# Modeling from:
# ./llama.py and
# https://github.com/huggingface/transformers/blob/main/src/transformers/models/glm4v/modular_glm4v.py
"""Inference-only GLM-4.1V model compatible with HuggingFace weights."""
import logging import logging
from functools import lru_cache, partial from functools import lru_cache
from typing import Iterable, List, Optional, Tuple from typing import Iterable, List, Optional, Tuple
import torch import torch
import torch.nn as nn import torch.nn as nn
import torch.nn.functional as F import torch.nn.functional as F
from einops import rearrange
from transformers.models.glm4v.configuration_glm4v import Glm4vConfig, Glm4vVisionConfig from transformers.models.glm4v.configuration_glm4v import Glm4vConfig, Glm4vVisionConfig
from sglang.srt.layers.activation import SiluAndMul from sglang.srt.layers.activation import SiluAndMul
from sglang.srt.layers.attention import vision_utils from sglang.srt.layers.attention import vision_utils
from sglang.srt.layers.dp_attention import get_attention_tp_size from sglang.srt.layers.attention.vision import VisionAttention
from sglang.srt.layers.layernorm import RMSNorm from sglang.srt.layers.layernorm import RMSNorm
from sglang.srt.layers.linear import ( from sglang.srt.layers.linear import (
ColumnParallelLinear, ColumnParallelLinear,
...@@ -20,13 +40,14 @@ from sglang.srt.layers.logits_processor import LogitsProcessor ...@@ -20,13 +40,14 @@ from sglang.srt.layers.logits_processor import LogitsProcessor
from sglang.srt.layers.pooler import Pooler, PoolingType from sglang.srt.layers.pooler import Pooler, PoolingType
from sglang.srt.layers.quantization.base_config import QuantizationConfig from sglang.srt.layers.quantization.base_config import QuantizationConfig
from sglang.srt.layers.vocab_parallel_embedding import ParallelLMHead from sglang.srt.layers.vocab_parallel_embedding import ParallelLMHead
from sglang.srt.managers.schedule_batch import MultimodalDataItem from sglang.srt.managers.mm_utils import (
MultiModalityDataPaddingPatternMultimodalTokens,
general_mm_embed_routine,
)
from sglang.srt.managers.schedule_batch import MultimodalDataItem, MultimodalInputs
from sglang.srt.model_executor.forward_batch_info import ForwardBatch
from sglang.srt.model_loader.weight_utils import default_weight_loader from sglang.srt.model_loader.weight_utils import default_weight_loader
from sglang.srt.models.glm4 import Glm4Model from sglang.srt.models.glm4 import Glm4Model
from sglang.srt.models.qwen2_5_vl import (
Qwen2_5_VisionBlock,
Qwen2_5_VLForConditionalGeneration,
)
from sglang.srt.utils import add_prefix from sglang.srt.utils import add_prefix
from sglang.srt.utils.hf_transformers_utils import get_processor from sglang.srt.utils.hf_transformers_utils import get_processor
...@@ -56,7 +77,7 @@ class Glm4vVisionMLP(nn.Module): ...@@ -56,7 +77,7 @@ class Glm4vVisionMLP(nn.Module):
super().__init__() super().__init__()
self.gate_up_proj = MergedColumnParallelLinear( self.gate_up_proj = MergedColumnParallelLinear(
input_size=in_features, input_size=in_features,
output_sizes=[hidden_features] * 2, output_sizes=[hidden_features] * 2, # [gate_proj, up_proj]
bias=bias, bias=bias,
quant_config=quant_config, quant_config=quant_config,
prefix=add_prefix("gate_up_proj", prefix), prefix=add_prefix("gate_up_proj", prefix),
...@@ -77,34 +98,95 @@ class Glm4vVisionMLP(nn.Module): ...@@ -77,34 +98,95 @@ class Glm4vVisionMLP(nn.Module):
return x return x
class Glm4vVisionBlock(nn.Module):
    """One GLM-4.1V vision transformer block: RMSNorm -> attention -> RMSNorm -> MLP.

    The block owns two ``RMSNorm`` layers, a ``VisionAttention`` layer and a
    ``Glm4vVisionMLP``; ``forward`` wires them together with residual adds.
    """

    # attn_implementation name -> (qkv_backend, softmax_in_single_precision)
    # as passed to ``VisionAttention``.
    _ATTN_BACKENDS = {
        None: (None, False),
        "sdpa": ("sdpa", False),
        "flash_attention_2": ("triton_attn", False),
        "eager": ("sdpa", True),
        "flash_attention_3": ("fa3", False),
    }

    def __init__(
        self,
        dim: int,
        intermediate_dim: int,
        num_heads: int,
        attn_implementation: Optional[str] = None,
        quant_config: Optional["QuantizationConfig"] = None,
        prefix: str = "",
        num_dummy_heads: int = 0,
        rms_norm_eps: float = 1e-5,
    ) -> None:
        """Build the norms, attention and MLP sub-layers.

        Args:
            dim: Size given to both ``RMSNorm`` layers and used as the
                attention ``embed_dim`` / ``projection_size`` and the MLP
                input size.
            intermediate_dim: Second positional argument of
                ``Glm4vVisionMLP`` (its hidden size).
            num_heads: Number of attention heads.
            attn_implementation: One of ``None``, ``"sdpa"``,
                ``"flash_attention_2"``, ``"eager"`` or
                ``"flash_attention_3"``.
            quant_config: Quantization config forwarded to attention and MLP.
            prefix: Parameter-name prefix; ``"attn"`` and ``"mlp"`` are
                appended for the respective sub-layers.
            num_dummy_heads: Forwarded to ``VisionAttention``.
            rms_norm_eps: Epsilon for both ``RMSNorm`` layers.

        Raises:
            ValueError: If ``attn_implementation`` is not a supported name.
        """
        super().__init__()
        self.norm1 = RMSNorm(dim, eps=rms_norm_eps)
        self.norm2 = RMSNorm(dim, eps=rms_norm_eps)

        qkv_backend, softmax_in_single_precision = self._resolve_attn_backend(
            attn_implementation
        )
        self.attn = VisionAttention(
            embed_dim=dim,
            num_heads=num_heads,
            projection_size=dim,
            use_qkv_parallel=True,
            rotary_embed="normal",
            proj_bias=True,
            qkv_backend=qkv_backend,
            softmax_in_single_precision=softmax_in_single_precision,
            # Every supported implementation uses the flattened-batch layout.
            flatten_batch=True,
            quant_config=quant_config,
            prefix=add_prefix("attn", prefix),
            num_dummy_heads=num_dummy_heads,
        )
        self.mlp = Glm4vVisionMLP(
            dim,
            intermediate_dim,
            bias=False,
            quant_config=quant_config,
            prefix=add_prefix("mlp", prefix),
        )

    @staticmethod
    def _resolve_attn_backend(attn_implementation):
        """Map ``attn_implementation`` to ``(qkv_backend, softmax_in_single_precision)``.

        Raises:
            ValueError: If the name is not in the supported set. Without this
                check an unknown name would only surface later as an
                ``UnboundLocalError`` while building the attention layer.
        """
        backends = Glm4vVisionBlock._ATTN_BACKENDS
        if attn_implementation not in backends:
            supported = sorted(name for name in backends if name is not None)
            raise ValueError(
                f"Unsupported attn_implementation: {attn_implementation!r}; "
                f"expected None or one of {supported}"
            )
        return backends[attn_implementation]

    def forward(
        self,
        x: torch.Tensor,
        cu_seqlens: torch.Tensor,
        position_embeddings: torch.Tensor,
    ) -> torch.Tensor:
        """Apply attention and MLP with residual connections.

        Args:
            x: 3-D input whose shape is unpacked as ``(S, B, H)``.
            cu_seqlens: Forwarded unchanged to the attention layer.
            position_embeddings: Forwarded unchanged to the attention layer.

        Returns:
            Tensor with the same ``(S, B, H)`` shape as ``x``.
        """
        seq_len, batch, hidden = x.shape

        # The norm layers are fed 2-D input, so flatten the leading dims.
        x_flat = x.reshape(-1, hidden)
        normed = self.norm1(x_flat).reshape(seq_len, batch, hidden)

        # Attention is called with [B, S, H]; convert there and back.
        attn_out = self.attn(
            rearrange(normed, "s b h -> b s h"),
            cu_seqlens=cu_seqlens,
            position_embeddings=position_embeddings,
        )
        attn_flat = rearrange(attn_out, "b s h -> s b h").reshape(-1, hidden)

        # norm2 is called with a residual and returns a pair: the normalized
        # tensor and the post-residual-add tensor.
        mlp_in_flat, residual_flat = self.norm2(x_flat, residual=attn_flat)
        mlp_in = mlp_in_flat.reshape(seq_len, batch, hidden)
        residual = residual_flat.reshape(seq_len, batch, hidden)

        # MLP followed by the final residual add.
        return residual + self.mlp(mlp_in)
class Glm4vVisionPatchEmbed(nn.Module): class Glm4vVisionPatchEmbed(nn.Module):
def __init__( def __init__(
...@@ -320,7 +402,6 @@ class Glm4vVisionModel(nn.Module): ...@@ -320,7 +402,6 @@ class Glm4vVisionModel(nn.Module):
def __init__( def __init__(
self, self,
vision_config: Glm4vVisionConfig, vision_config: Glm4vVisionConfig,
norm_eps: float = 1e-6,
quant_config: Optional[QuantizationConfig] = None, quant_config: Optional[QuantizationConfig] = None,
prefix: str = "", prefix: str = "",
) -> None: ) -> None:
...@@ -344,17 +425,18 @@ class Glm4vVisionModel(nn.Module): ...@@ -344,17 +425,18 @@ class Glm4vVisionModel(nn.Module):
hidden_size=self.hidden_size, hidden_size=self.hidden_size,
) )
norm_layer = partial(Glm4vRMSNorm, eps=norm_eps)
head_dim = self.hidden_size // self.num_heads head_dim = self.hidden_size // self.num_heads
self.rotary_pos_emb = Glm4vVisionRotaryEmbedding(head_dim // 2) self.rotary_pos_emb = Glm4vVisionRotaryEmbedding(head_dim // 2)
self.blocks = nn.ModuleList( self.blocks = nn.ModuleList(
[ [
Glm4vVisionBlock( Glm4vVisionBlock(
config=vision_config, dim=self.hidden_size,
norm_layer=norm_layer, intermediate_dim=self.out_hidden_size,
num_heads=self.num_heads,
quant_config=quant_config, quant_config=quant_config,
prefix=add_prefix(f"blocks.{layer_idx}", prefix), prefix=add_prefix(f"blocks.{layer_idx}", prefix),
rms_norm_eps=vision_config.rms_norm_eps,
) )
for layer_idx in range(depth) for layer_idx in range(depth)
] ]
...@@ -461,29 +543,30 @@ class Glm4vVisionModel(nn.Module): ...@@ -461,29 +543,30 @@ class Glm4vVisionModel(nn.Module):
return x return x
class Glm4vForConditionalGeneration(Qwen2_5_VLForConditionalGeneration): class Glm4vForConditionalGeneration(nn.Module):
def __init__( def __init__(
self, self,
config: Glm4vConfig, config: Glm4vConfig,
quant_config: Optional[QuantizationConfig] = None, quant_config: Optional[QuantizationConfig] = None,
prefix: str = "", prefix: str = "",
) -> None: ) -> None:
nn.Module.__init__(self) super().__init__()
self.config = config self.config = config
vision_utils.update_vit_attn_dummy_heads_config(self.config)
self.model = Glm4Model(
config,
quant_config,
prefix=add_prefix("model", prefix),
)
self.visual = Glm4vVisionModel( self.visual = Glm4vVisionModel(
config.vision_config, config.vision_config,
norm_eps=getattr(config, "rms_norm_eps", 1e-5),
quant_config=quant_config, quant_config=quant_config,
prefix=add_prefix("visual", prefix), prefix=add_prefix("visual", prefix),
) )
vision_utils.update_vit_attn_dummy_heads_config(self.config)
self.model = Glm4Model(
config,
quant_config=quant_config,
prefix=add_prefix("model", prefix),
)
if config.tie_word_embeddings: if config.tie_word_embeddings:
self.lm_head = self.model.embed_tokens self.lm_head = self.model.embed_tokens
else: else:
...@@ -494,13 +577,18 @@ class Glm4vForConditionalGeneration(Qwen2_5_VLForConditionalGeneration): ...@@ -494,13 +577,18 @@ class Glm4vForConditionalGeneration(Qwen2_5_VLForConditionalGeneration):
prefix=add_prefix("lm_head", prefix), prefix=add_prefix("lm_head", prefix),
) )
self.is_mrope_enabled = "mrope_section" in self.config.rope_scaling
self.logits_processor = LogitsProcessor(config) self.logits_processor = LogitsProcessor(config)
self.pooler = Pooler(pooling_type=PoolingType.LAST, normalize=True) self.pooler = Pooler(pooling_type=PoolingType.LAST, normalize=True)
self.is_mrope_enabled = "mrope_section" in self.config.rope_scaling
# For EAGLE3 support # For EAGLE3 support
self.capture_aux_hidden_states = False self.capture_aux_hidden_states = False
def pad_input_ids(self, input_ids: List[int], mm_inputs: MultimodalInputs):
    """Pad ``input_ids`` for multimodal inputs.

    Delegates to ``MultiModalityDataPaddingPatternMultimodalTokens`` and
    returns whatever its ``pad_input_tokens`` returns.
    """
    return MultiModalityDataPaddingPatternMultimodalTokens().pad_input_tokens(
        input_ids, mm_inputs
    )
def get_image_feature(self, items: List[MultimodalDataItem]) -> torch.Tensor: def get_image_feature(self, items: List[MultimodalDataItem]) -> torch.Tensor:
pixel_values = torch.cat( pixel_values = torch.cat(
[item.feature.squeeze(0) for item in items], dim=0 [item.feature.squeeze(0) for item in items], dim=0
...@@ -542,20 +630,60 @@ class Glm4vForConditionalGeneration(Qwen2_5_VLForConditionalGeneration): ...@@ -542,20 +630,60 @@ class Glm4vForConditionalGeneration(Qwen2_5_VLForConditionalGeneration):
video_embeds = torch.split(video_embeds, split_sizes) video_embeds = torch.split(video_embeds, split_sizes)
return torch.cat(video_embeds) return torch.cat(video_embeds)
def get_input_embeddings(self):
    """Return the language model's token-embedding layer (``self.model.embed_tokens``)."""
    embed_tokens = self.model.embed_tokens
    return embed_tokens
@torch.no_grad()
def forward(
    self,
    input_ids: torch.Tensor,
    positions: torch.Tensor,
    forward_batch: ForwardBatch,
    get_embedding: bool = False,
):
    """Run the forward pass for GLM-4.1V.

    Args:
        input_ids: Flattened (concatenated) input ids for the batch.
        positions: Flattened (concatenated) position ids for the batch.
            When mrope is enabled this argument is replaced by
            ``forward_batch.mrope_positions``, which must have shape
            ``(3, seq_len)`` for non-decode batches that contain image
            inputs.
        forward_batch: Batch metadata; supplies the forward mode, the
            image-input flag and the mrope positions.
        get_embedding: If true, return the pooler output instead of the
            logits-processor output.

    Returns:
        The result of ``self.pooler`` when ``get_embedding`` is true,
        otherwise the result of ``self.logits_processor``.
    """
    if self.is_mrope_enabled:
        positions = forward_batch.mrope_positions

    # Same condition as `not (is_decode() or not contains_image_inputs())`,
    # with the same short-circuit order.
    is_image_prefill = (
        not forward_batch.forward_mode.is_decode()
        and forward_batch.contains_image_inputs()
    )
    if is_image_prefill and self.is_mrope_enabled:
        assert positions.ndim == 2 and positions.size(0) == 3, (
            "multimodal section rotary embedding requires "
            f"(3, seq_len) positions, but got {positions.size()}"
        )

    hidden_states = general_mm_embed_routine(
        input_ids=input_ids,
        forward_batch=forward_batch,
        language_model=self.model,
        multimodal_model=self,
        positions=positions,
    )

    # With aux capture on (EAGLE3), the routine's result is a pair.
    aux_hidden_states = None
    if self.capture_aux_hidden_states:
        hidden_states, aux_hidden_states = hidden_states

    if get_embedding:
        return self.pooler(hidden_states, forward_batch)
    return self.logits_processor(
        input_ids, hidden_states, self.lm_head, forward_batch, aux_hidden_states
    )
def _pad_vit_attn_dummy_heads(self, name: str, loaded_weight: torch.Tensor): def _pad_vit_attn_dummy_heads(self, name: str, loaded_weight: torch.Tensor):
"""pad attn qkv weights for dummy heads""" """pad attn qkv weights for dummy heads"""
...@@ -598,13 +726,12 @@ class Glm4vForConditionalGeneration(Qwen2_5_VLForConditionalGeneration): ...@@ -598,13 +726,12 @@ class Glm4vForConditionalGeneration(Qwen2_5_VLForConditionalGeneration):
] ]
params_dict = dict(self.named_parameters(remove_duplicate=False)) params_dict = dict(self.named_parameters(remove_duplicate=False))
for name, loaded_weight in weights: for name, loaded_weight in weights:
if "language_model." in name:
name = name.replace("language_model.", "")
if "model.visual." in name:
name = name.replace("model.visual.", "visual.")
if "rotary_emb.inv_freq" in name: if "rotary_emb.inv_freq" in name:
continue continue
if "language_model" in name:
name = name.replace(r"model.language_model.", r"model.")
if "model.visual." in name:
name = name.replace("model.visual.", "visual.")
for param_name, weight_name, shard_id in stacked_params_mapping: for param_name, weight_name, shard_id in stacked_params_mapping:
if weight_name not in name: if weight_name not in name:
...@@ -639,5 +766,19 @@ class Glm4vForConditionalGeneration(Qwen2_5_VLForConditionalGeneration): ...@@ -639,5 +766,19 @@ class Glm4vForConditionalGeneration(Qwen2_5_VLForConditionalGeneration):
) )
weight_loader(param, loaded_weight) weight_loader(param, loaded_weight)
def get_embed_and_head(self):
    """Return ``(embed_tokens.weight, lm_head.weight)`` as a tuple."""
    embed_weight = self.model.embed_tokens.weight
    head_weight = self.lm_head.weight
    return embed_weight, head_weight
def set_embed_and_head(self, embed, head):
    """Replace the token-embedding weight and the LM head weight.

    Args:
        embed: New value for ``self.model.embed_tokens.weight``.
        head: New value for ``self.lm_head.weight``. Ignored when
            ``config.tie_word_embeddings`` is set; in that case
            ``self.lm_head`` is pointed at ``self.model.embed_tokens``.
    """
    del self.model.embed_tokens.weight
    self.model.embed_tokens.weight = embed
    if self.config.tie_word_embeddings:
        self.lm_head = self.model.embed_tokens
    else:
        del self.lm_head.weight
        self.lm_head.weight = head
    # torch.cuda.synchronize() raises on hosts without CUDA, so only run
    # the CUDA cache cleanup when a CUDA device is actually available.
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
        torch.cuda.synchronize()
EntryClass = [Glm4vForConditionalGeneration] EntryClass = [Glm4vForConditionalGeneration]
...@@ -53,7 +53,6 @@ class Glm4vMoeForConditionalGeneration(Glm4vForConditionalGeneration): ...@@ -53,7 +53,6 @@ class Glm4vMoeForConditionalGeneration(Glm4vForConditionalGeneration):
) )
self.visual = Glm4vVisionModel( self.visual = Glm4vVisionModel(
config.vision_config, config.vision_config,
norm_eps=getattr(config, "rms_norm_eps", 1e-5),
quant_config=quant_config, quant_config=quant_config,
prefix=add_prefix("visual", prefix), prefix=add_prefix("visual", prefix),
) )
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment