Unverified Commit a88b006e authored by Yuxuan Zhang's avatar Yuxuan Zhang Committed by GitHub
Browse files

GLM-4-0414 and GLM-4.1V Code Refactor (#12117)

parent ce112c07
...@@ -1070,6 +1070,7 @@ def _triton_mrope_forward( ...@@ -1070,6 +1070,7 @@ def _triton_mrope_forward(
mrope_section_h: tl.constexpr, mrope_section_h: tl.constexpr,
mrope_section_w: tl.constexpr, mrope_section_w: tl.constexpr,
is_interleaved: tl.constexpr, is_interleaved: tl.constexpr,
is_neox_style: tl.constexpr,
): ):
# Adapted from # Adapted from
# https://github.com/linkedin/Liger-Kernel/blob/main/src/liger_kernel/ops/qwen2vl_mrope.py # https://github.com/linkedin/Liger-Kernel/blob/main/src/liger_kernel/ops/qwen2vl_mrope.py
...@@ -1124,51 +1125,99 @@ def _triton_mrope_forward( ...@@ -1124,51 +1125,99 @@ def _triton_mrope_forward(
# program instance (i.e. for the current token) separately # program instance (i.e. for the current token) separately
# #################################################################### # ####################################################################
# left half of the head # left half of the head
first_half_q_offsets = ( if is_neox_style:
tl.arange(0, pad_n_qh)[:, None] * hd + tl.arange(0, pad_hd // 2)[None, :] first_half_q_offsets = (
) tl.arange(0, pad_n_qh)[:, None] * hd + tl.arange(0, pad_hd // 2)[None, :]
first_half_k_offsets = ( )
tl.arange(0, pad_n_kh)[:, None] * hd + tl.arange(0, pad_hd // 2)[None, :] first_half_k_offsets = (
) tl.arange(0, pad_n_kh)[:, None] * hd + tl.arange(0, pad_hd // 2)[None, :]
first_q_mask = (tl.arange(0, pad_n_qh)[:, None] < n_qh) & ( )
tl.arange(0, pad_hd // 2)[None, :] < rd // 2 first_q_mask = (tl.arange(0, pad_n_qh)[:, None] < n_qh) & (
) tl.arange(0, pad_hd // 2)[None, :] < rd // 2
first_k_mask = (tl.arange(0, pad_n_kh)[:, None] < n_kh) & ( )
tl.arange(0, pad_hd // 2)[None, :] < rd // 2 first_k_mask = (tl.arange(0, pad_n_kh)[:, None] < n_kh) & (
) tl.arange(0, pad_hd // 2)[None, :] < rd // 2
)
q_tile_1 = tl.load(q_ptr + first_half_q_offsets, mask=first_q_mask, other=0).to( q_tile_1 = tl.load(q_ptr + first_half_q_offsets, mask=first_q_mask, other=0).to(
sin_row.dtype sin_row.dtype
) )
k_tile_1 = tl.load(k_ptr + first_half_k_offsets, mask=first_k_mask, other=0).to( k_tile_1 = tl.load(k_ptr + first_half_k_offsets, mask=first_k_mask, other=0).to(
sin_row.dtype sin_row.dtype
) )
# right half of the head # right half of the head
second_half_q_offsets = first_half_q_offsets + (rd // 2) second_half_q_offsets = first_half_q_offsets + (rd // 2)
second_half_k_offsets = first_half_k_offsets + (rd // 2) second_half_k_offsets = first_half_k_offsets + (rd // 2)
second_q_mask = first_q_mask second_q_mask = first_q_mask
second_k_mask = first_k_mask second_k_mask = first_k_mask
q_tile_2 = tl.load(
q_ptr + second_half_q_offsets, mask=second_q_mask, other=0
).to(sin_row.dtype)
k_tile_2 = tl.load(
k_ptr + second_half_k_offsets, mask=second_k_mask, other=0
).to(sin_row.dtype)
# y = [x1, x2] * [cos, cos] + [-x2, x1] * [sin, sin]
# Since cos and sin are now half-size,
# we use the same cos_row and sin_row for both halves
new_q_tile_1 = q_tile_1 * cos_row - q_tile_2 * sin_row
tl.store(q_ptr + first_half_q_offsets, new_q_tile_1, mask=first_q_mask)
new_q_tile_2 = q_tile_2 * cos_row + q_tile_1 * sin_row
tl.store(q_ptr + second_half_q_offsets, new_q_tile_2, mask=second_q_mask)
new_k_tile_1 = k_tile_1 * cos_row - k_tile_2 * sin_row
tl.store(k_ptr + first_half_k_offsets, new_k_tile_1, mask=first_k_mask)
new_k_tile_2 = k_tile_2 * cos_row + k_tile_1 * sin_row
tl.store(k_ptr + second_half_k_offsets, new_k_tile_2, mask=second_k_mask)
else:
base_q = tl.arange(0, pad_n_qh)[:, None] * hd
base_k = tl.arange(0, pad_n_kh)[:, None] * hd
even_idx = 2 * tl.arange(0, pad_hd // 2)[None, :]
odd_idx = even_idx + 1
even_q_offsets = base_q + even_idx
odd_q_offsets = base_q + odd_idx
even_k_offsets = base_k + even_idx
odd_k_offsets = base_k + odd_idx
idx_mask = tl.arange(0, pad_hd // 2)[None, :] < (rd // 2)
qn_mask = tl.arange(0, pad_n_qh)[:, None] < n_qh
kn_mask = tl.arange(0, pad_n_kh)[:, None] < n_kh
even_q_mask = qn_mask & idx_mask
odd_q_mask = qn_mask & idx_mask
even_k_mask = kn_mask & idx_mask
odd_k_mask = kn_mask & idx_mask
q_tile_1 = tl.load(q_ptr + even_q_offsets, mask=even_q_mask, other=0).to(
sin_row.dtype
)
k_tile_1 = tl.load(k_ptr + even_k_offsets, mask=even_k_mask, other=0).to(
sin_row.dtype
)
q_tile_2 = tl.load(q_ptr + second_half_q_offsets, mask=second_q_mask, other=0).to( q_tile_2 = tl.load(q_ptr + odd_q_offsets, mask=odd_q_mask, other=0).to(
sin_row.dtype sin_row.dtype
) )
k_tile_2 = tl.load(k_ptr + second_half_k_offsets, mask=second_k_mask, other=0).to( k_tile_2 = tl.load(k_ptr + odd_k_offsets, mask=odd_k_mask, other=0).to(
sin_row.dtype sin_row.dtype
) )
# y = [x1, x2] * [cos, cos] + [-x2, x1] * [sin, sin] # y = [x_even, x_odd] * [cos, cos] + [-x_odd, x_even] * [sin, sin]
# Since cos and sin are now half-size, # NeoX-style rotary embedding:
# we use the same cos_row and sin_row for both halves # Each (even, odd) channel pair forms one rotation arm.
new_q_tile_1 = q_tile_1 * cos_row - q_tile_2 * sin_row # cos_row and sin_row each have length rd//2, shared across all (even, odd) pairs.
tl.store(q_ptr + first_half_q_offsets, new_q_tile_1, mask=first_q_mask) new_q_tile_1 = q_tile_1 * cos_row - q_tile_2 * sin_row
new_q_tile_2 = q_tile_2 * cos_row + q_tile_1 * sin_row tl.store(q_ptr + even_q_offsets, new_q_tile_1, mask=even_q_mask)
tl.store(q_ptr + second_half_q_offsets, new_q_tile_2, mask=second_q_mask) new_q_tile_2 = q_tile_2 * cos_row + q_tile_1 * sin_row
tl.store(q_ptr + odd_q_offsets, new_q_tile_2, mask=odd_q_mask)
new_k_tile_1 = k_tile_1 * cos_row - k_tile_2 * sin_row new_k_tile_1 = k_tile_1 * cos_row - k_tile_2 * sin_row
tl.store(k_ptr + first_half_k_offsets, new_k_tile_1, mask=first_k_mask) tl.store(k_ptr + even_k_offsets, new_k_tile_1, mask=even_k_mask)
new_k_tile_2 = k_tile_2 * cos_row + k_tile_1 * sin_row new_k_tile_2 = k_tile_2 * cos_row + k_tile_1 * sin_row
tl.store(k_ptr + second_half_k_offsets, new_k_tile_2, mask=second_k_mask) tl.store(k_ptr + odd_k_offsets, new_k_tile_2, mask=odd_k_mask)
def triton_mrope( def triton_mrope(
...@@ -1180,6 +1229,7 @@ def triton_mrope( ...@@ -1180,6 +1229,7 @@ def triton_mrope(
head_size: int, head_size: int,
rotary_dim: int, rotary_dim: int,
mrope_interleaved: bool, mrope_interleaved: bool,
is_neox_style: bool,
) -> tuple[torch.Tensor, torch.Tensor]: ) -> tuple[torch.Tensor, torch.Tensor]:
"""The mrope triton kernel. """The mrope triton kernel.
...@@ -1230,6 +1280,7 @@ def triton_mrope( ...@@ -1230,6 +1280,7 @@ def triton_mrope(
mrope_section[1], mrope_section[1],
mrope_section[2], mrope_section[2],
mrope_interleaved, mrope_interleaved,
is_neox_style,
) )
return q, k return q, k
...@@ -1400,6 +1451,7 @@ class MRotaryEmbedding(RotaryEmbedding): ...@@ -1400,6 +1451,7 @@ class MRotaryEmbedding(RotaryEmbedding):
self.head_size, self.head_size,
self.rotary_dim, self.rotary_dim,
self.mrope_interleaved, self.mrope_interleaved,
self.is_neox_style,
) )
return q.reshape(query_shape), k.reshape(key_shape) return q.reshape(query_shape), k.reshape(key_shape)
......
This diff is collapsed.
# Copyright 2023-2024 SGLang Team
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
# Modeling from:
# ./llama.py and
# https://github.com/huggingface/transformers/blob/main/src/transformers/models/glm4v/modular_glm4v.py
"""Inference-only GLM-4.1V model compatible with HuggingFace weights."""
import logging import logging
from functools import lru_cache, partial from functools import lru_cache
from typing import Iterable, List, Optional, Tuple from typing import Iterable, List, Optional, Tuple
import torch import torch
import torch.nn as nn import torch.nn as nn
import torch.nn.functional as F import torch.nn.functional as F
from einops import rearrange
from transformers.models.glm4v.configuration_glm4v import Glm4vConfig, Glm4vVisionConfig from transformers.models.glm4v.configuration_glm4v import Glm4vConfig, Glm4vVisionConfig
from sglang.srt.layers.activation import SiluAndMul from sglang.srt.layers.activation import SiluAndMul
from sglang.srt.layers.attention import vision_utils from sglang.srt.layers.attention import vision_utils
from sglang.srt.layers.dp_attention import get_attention_tp_size from sglang.srt.layers.attention.vision import VisionAttention
from sglang.srt.layers.layernorm import RMSNorm from sglang.srt.layers.layernorm import RMSNorm
from sglang.srt.layers.linear import ( from sglang.srt.layers.linear import (
ColumnParallelLinear, ColumnParallelLinear,
...@@ -20,13 +40,14 @@ from sglang.srt.layers.logits_processor import LogitsProcessor ...@@ -20,13 +40,14 @@ from sglang.srt.layers.logits_processor import LogitsProcessor
from sglang.srt.layers.pooler import Pooler, PoolingType from sglang.srt.layers.pooler import Pooler, PoolingType
from sglang.srt.layers.quantization.base_config import QuantizationConfig from sglang.srt.layers.quantization.base_config import QuantizationConfig
from sglang.srt.layers.vocab_parallel_embedding import ParallelLMHead from sglang.srt.layers.vocab_parallel_embedding import ParallelLMHead
from sglang.srt.managers.schedule_batch import MultimodalDataItem from sglang.srt.managers.mm_utils import (
MultiModalityDataPaddingPatternMultimodalTokens,
general_mm_embed_routine,
)
from sglang.srt.managers.schedule_batch import MultimodalDataItem, MultimodalInputs
from sglang.srt.model_executor.forward_batch_info import ForwardBatch
from sglang.srt.model_loader.weight_utils import default_weight_loader from sglang.srt.model_loader.weight_utils import default_weight_loader
from sglang.srt.models.glm4 import Glm4Model from sglang.srt.models.glm4 import Glm4Model
from sglang.srt.models.qwen2_5_vl import (
Qwen2_5_VisionBlock,
Qwen2_5_VLForConditionalGeneration,
)
from sglang.srt.utils import add_prefix from sglang.srt.utils import add_prefix
from sglang.srt.utils.hf_transformers_utils import get_processor from sglang.srt.utils.hf_transformers_utils import get_processor
...@@ -56,7 +77,7 @@ class Glm4vVisionMLP(nn.Module): ...@@ -56,7 +77,7 @@ class Glm4vVisionMLP(nn.Module):
super().__init__() super().__init__()
self.gate_up_proj = MergedColumnParallelLinear( self.gate_up_proj = MergedColumnParallelLinear(
input_size=in_features, input_size=in_features,
output_sizes=[hidden_features] * 2, output_sizes=[hidden_features] * 2, # [gate_proj, up_proj]
bias=bias, bias=bias,
quant_config=quant_config, quant_config=quant_config,
prefix=add_prefix("gate_up_proj", prefix), prefix=add_prefix("gate_up_proj", prefix),
...@@ -77,34 +98,95 @@ class Glm4vVisionMLP(nn.Module): ...@@ -77,34 +98,95 @@ class Glm4vVisionMLP(nn.Module):
return x return x
class Glm4vVisionBlock(Qwen2_5_VisionBlock): class Glm4vVisionBlock(nn.Module):
def __init__( def __init__(
self, self,
config: Glm4vVisionConfig, dim: int,
norm_layer: Optional[nn.Module] = None, intermediate_dim: int,
num_heads: int,
attn_implementation: Optional[str] = None,
quant_config: Optional[QuantizationConfig] = None, quant_config: Optional[QuantizationConfig] = None,
prefix: str = "", prefix: str = "",
num_dummy_heads: int = 0,
rms_norm_eps: float = 1e-5,
) -> None: ) -> None:
super().__init__( super().__init__()
dim=config.hidden_size, self.norm1 = RMSNorm(dim, eps=rms_norm_eps)
intermediate_dim=config.out_hidden_size, self.norm2 = RMSNorm(dim, eps=rms_norm_eps)
num_heads=config.num_heads,
hidden_act=config.hidden_act, if attn_implementation is None:
norm_layer=norm_layer, softmax_in_single_precision = False
qkv_backend = None
flatten_batch = True
elif attn_implementation == "sdpa":
softmax_in_single_precision = False
qkv_backend = "sdpa"
flatten_batch = True
elif attn_implementation == "flash_attention_2":
softmax_in_single_precision = False
qkv_backend = "triton_attn"
flatten_batch = True
elif attn_implementation == "eager":
softmax_in_single_precision = True
qkv_backend = "sdpa"
flatten_batch = True
elif attn_implementation == "flash_attention_3":
softmax_in_single_precision = False
qkv_backend = "fa3"
flatten_batch = True
self.attn = VisionAttention(
embed_dim=dim,
num_heads=num_heads,
projection_size=dim,
use_qkv_parallel=True,
rotary_embed="normal",
proj_bias=True,
qkv_backend=qkv_backend,
softmax_in_single_precision=softmax_in_single_precision,
flatten_batch=flatten_batch,
quant_config=quant_config, quant_config=quant_config,
prefix=prefix, prefix=add_prefix("attn", prefix),
num_dummy_heads=config.num_dummy_heads, num_dummy_heads=num_dummy_heads,
rms_norm_eps=config.rms_norm_eps,
) )
self.mlp = Glm4vVisionMLP( self.mlp = Glm4vVisionMLP(
config.hidden_size, dim,
config.out_hidden_size, intermediate_dim,
bias=False,
quant_config=quant_config, quant_config=quant_config,
prefix=add_prefix("mlp", prefix), prefix=add_prefix("mlp", prefix),
) )
def forward(
self,
x: torch.Tensor,
cu_seqlens: torch.Tensor,
position_embeddings: torch.Tensor,
) -> torch.Tensor:
S, B, H = x.shape
# norm1: flatten to 2D -> [S*B, H], then reshape back
x2d = x.reshape(-1, H)
hidden_states = self.norm1(x2d).reshape(S, B, H)
# Attention expects [B, S, H]
hidden_states = rearrange(hidden_states, "s b h -> b s h")
attn = self.attn(
hidden_states,
cu_seqlens=cu_seqlens,
position_embeddings=position_embeddings,
)
attn = rearrange(attn, "b s h -> s b h")
# norm2 with fused residual-add: also 2D
attn2d = attn.reshape(-1, H)
x_norm_2d, x_after_add_2d = self.norm2(x2d, residual=attn2d)
x_norm = x_norm_2d.reshape(S, B, H)
x_after_add = x_after_add_2d.reshape(S, B, H)
# MLP and final residual
mlp_out = self.mlp(x_norm)
x = x_after_add + mlp_out
return x
class Glm4vVisionPatchEmbed(nn.Module): class Glm4vVisionPatchEmbed(nn.Module):
def __init__( def __init__(
...@@ -320,7 +402,6 @@ class Glm4vVisionModel(nn.Module): ...@@ -320,7 +402,6 @@ class Glm4vVisionModel(nn.Module):
def __init__( def __init__(
self, self,
vision_config: Glm4vVisionConfig, vision_config: Glm4vVisionConfig,
norm_eps: float = 1e-6,
quant_config: Optional[QuantizationConfig] = None, quant_config: Optional[QuantizationConfig] = None,
prefix: str = "", prefix: str = "",
) -> None: ) -> None:
...@@ -344,17 +425,18 @@ class Glm4vVisionModel(nn.Module): ...@@ -344,17 +425,18 @@ class Glm4vVisionModel(nn.Module):
hidden_size=self.hidden_size, hidden_size=self.hidden_size,
) )
norm_layer = partial(Glm4vRMSNorm, eps=norm_eps)
head_dim = self.hidden_size // self.num_heads head_dim = self.hidden_size // self.num_heads
self.rotary_pos_emb = Glm4vVisionRotaryEmbedding(head_dim // 2) self.rotary_pos_emb = Glm4vVisionRotaryEmbedding(head_dim // 2)
self.blocks = nn.ModuleList( self.blocks = nn.ModuleList(
[ [
Glm4vVisionBlock( Glm4vVisionBlock(
config=vision_config, dim=self.hidden_size,
norm_layer=norm_layer, intermediate_dim=self.out_hidden_size,
num_heads=self.num_heads,
quant_config=quant_config, quant_config=quant_config,
prefix=add_prefix(f"blocks.{layer_idx}", prefix), prefix=add_prefix(f"blocks.{layer_idx}", prefix),
rms_norm_eps=vision_config.rms_norm_eps,
) )
for layer_idx in range(depth) for layer_idx in range(depth)
] ]
...@@ -461,29 +543,30 @@ class Glm4vVisionModel(nn.Module): ...@@ -461,29 +543,30 @@ class Glm4vVisionModel(nn.Module):
return x return x
class Glm4vForConditionalGeneration(Qwen2_5_VLForConditionalGeneration): class Glm4vForConditionalGeneration(nn.Module):
def __init__( def __init__(
self, self,
config: Glm4vConfig, config: Glm4vConfig,
quant_config: Optional[QuantizationConfig] = None, quant_config: Optional[QuantizationConfig] = None,
prefix: str = "", prefix: str = "",
) -> None: ) -> None:
nn.Module.__init__(self) super().__init__()
self.config = config self.config = config
vision_utils.update_vit_attn_dummy_heads_config(self.config)
self.model = Glm4Model(
config,
quant_config,
prefix=add_prefix("model", prefix),
)
self.visual = Glm4vVisionModel( self.visual = Glm4vVisionModel(
config.vision_config, config.vision_config,
norm_eps=getattr(config, "rms_norm_eps", 1e-5),
quant_config=quant_config, quant_config=quant_config,
prefix=add_prefix("visual", prefix), prefix=add_prefix("visual", prefix),
) )
vision_utils.update_vit_attn_dummy_heads_config(self.config)
self.model = Glm4Model(
config,
quant_config=quant_config,
prefix=add_prefix("model", prefix),
)
if config.tie_word_embeddings: if config.tie_word_embeddings:
self.lm_head = self.model.embed_tokens self.lm_head = self.model.embed_tokens
else: else:
...@@ -494,13 +577,18 @@ class Glm4vForConditionalGeneration(Qwen2_5_VLForConditionalGeneration): ...@@ -494,13 +577,18 @@ class Glm4vForConditionalGeneration(Qwen2_5_VLForConditionalGeneration):
prefix=add_prefix("lm_head", prefix), prefix=add_prefix("lm_head", prefix),
) )
self.is_mrope_enabled = "mrope_section" in self.config.rope_scaling
self.logits_processor = LogitsProcessor(config) self.logits_processor = LogitsProcessor(config)
self.pooler = Pooler(pooling_type=PoolingType.LAST, normalize=True) self.pooler = Pooler(pooling_type=PoolingType.LAST, normalize=True)
self.is_mrope_enabled = "mrope_section" in self.config.rope_scaling
# For EAGLE3 support # For EAGLE3 support
self.capture_aux_hidden_states = False self.capture_aux_hidden_states = False
def pad_input_ids(self, input_ids: List[int], mm_inputs: MultimodalInputs):
pattern = MultiModalityDataPaddingPatternMultimodalTokens()
return pattern.pad_input_tokens(input_ids, mm_inputs)
def get_image_feature(self, items: List[MultimodalDataItem]) -> torch.Tensor: def get_image_feature(self, items: List[MultimodalDataItem]) -> torch.Tensor:
pixel_values = torch.cat( pixel_values = torch.cat(
[item.feature.squeeze(0) for item in items], dim=0 [item.feature.squeeze(0) for item in items], dim=0
...@@ -542,20 +630,60 @@ class Glm4vForConditionalGeneration(Qwen2_5_VLForConditionalGeneration): ...@@ -542,20 +630,60 @@ class Glm4vForConditionalGeneration(Qwen2_5_VLForConditionalGeneration):
video_embeds = torch.split(video_embeds, split_sizes) video_embeds = torch.split(video_embeds, split_sizes)
return torch.cat(video_embeds) return torch.cat(video_embeds)
def _update_hf_config(self): def get_input_embeddings(self):
"""update hf config to ensure vision attention num_attention_heads is divisible by tp_size""" return self.model.embed_tokens
tp_size = get_attention_tp_size()
num_heads = self.config.vision_config.num_heads
head_dim = self.config.vision_config.hidden_size // num_heads
num_dummy_heads = 0
if num_heads % tp_size != 0: @torch.no_grad()
num_dummy_heads = ( def forward(
(num_heads + tp_size - 1) // tp_size self,
) * tp_size - num_heads input_ids: torch.Tensor,
positions: torch.Tensor,
forward_batch: ForwardBatch,
get_embedding: bool = False,
):
"""Run forward pass for GLM-4.1V.
Args:
input_ids: Flattened (concatenated) input_ids corresponding to a
batch.
positions: Flattened (concatenated) position ids corresponding to a
batch.
**NOTE**: If mrope is enabled (default setting for GLM-4.1V
opensource models), the shape will be `(3, seq_len)`,
otherwise it will be `(seq_len,).
(Use input_metadata.mrope_positions to replace it)
"""
if self.is_mrope_enabled:
positions = forward_batch.mrope_positions
if not (
forward_batch.forward_mode.is_decode()
or not forward_batch.contains_image_inputs()
):
if self.is_mrope_enabled:
assert positions.ndim == 2 and positions.size(0) == 3, (
"multimodal section rotary embedding requires "
f"(3, seq_len) positions, but got {positions.size()}"
)
hidden_states = general_mm_embed_routine(
input_ids=input_ids,
forward_batch=forward_batch,
language_model=self.model,
multimodal_model=self,
positions=positions,
)
setattr(self.config.vision_config, "head_dim", head_dim) aux_hidden_states = None
setattr(self.config.vision_config, "num_dummy_heads", num_dummy_heads) if self.capture_aux_hidden_states:
hidden_states, aux_hidden_states = hidden_states
if not get_embedding:
return self.logits_processor(
input_ids, hidden_states, self.lm_head, forward_batch, aux_hidden_states
)
else:
return self.pooler(hidden_states, forward_batch)
def _pad_vit_attn_dummy_heads(self, name: str, loaded_weight: torch.Tensor): def _pad_vit_attn_dummy_heads(self, name: str, loaded_weight: torch.Tensor):
"""pad attn qkv weights for dummy heads""" """pad attn qkv weights for dummy heads"""
...@@ -598,13 +726,12 @@ class Glm4vForConditionalGeneration(Qwen2_5_VLForConditionalGeneration): ...@@ -598,13 +726,12 @@ class Glm4vForConditionalGeneration(Qwen2_5_VLForConditionalGeneration):
] ]
params_dict = dict(self.named_parameters(remove_duplicate=False)) params_dict = dict(self.named_parameters(remove_duplicate=False))
for name, loaded_weight in weights: for name, loaded_weight in weights:
if "language_model." in name:
name = name.replace("language_model.", "")
if "model.visual." in name:
name = name.replace("model.visual.", "visual.")
if "rotary_emb.inv_freq" in name: if "rotary_emb.inv_freq" in name:
continue continue
if "language_model" in name:
name = name.replace(r"model.language_model.", r"model.")
if "model.visual." in name:
name = name.replace("model.visual.", "visual.")
for param_name, weight_name, shard_id in stacked_params_mapping: for param_name, weight_name, shard_id in stacked_params_mapping:
if weight_name not in name: if weight_name not in name:
...@@ -639,5 +766,19 @@ class Glm4vForConditionalGeneration(Qwen2_5_VLForConditionalGeneration): ...@@ -639,5 +766,19 @@ class Glm4vForConditionalGeneration(Qwen2_5_VLForConditionalGeneration):
) )
weight_loader(param, loaded_weight) weight_loader(param, loaded_weight)
def get_embed_and_head(self):
return self.model.embed_tokens.weight, self.lm_head.weight
def set_embed_and_head(self, embed, head):
del self.model.embed_tokens.weight
self.model.embed_tokens.weight = embed
if self.config.tie_word_embeddings:
self.lm_head = self.model.embed_tokens
else:
del self.lm_head.weight
self.lm_head.weight = head
torch.cuda.empty_cache()
torch.cuda.synchronize()
EntryClass = [Glm4vForConditionalGeneration] EntryClass = [Glm4vForConditionalGeneration]
...@@ -53,7 +53,6 @@ class Glm4vMoeForConditionalGeneration(Glm4vForConditionalGeneration): ...@@ -53,7 +53,6 @@ class Glm4vMoeForConditionalGeneration(Glm4vForConditionalGeneration):
) )
self.visual = Glm4vVisionModel( self.visual = Glm4vVisionModel(
config.vision_config, config.vision_config,
norm_eps=getattr(config, "rms_norm_eps", 1e-5),
quant_config=quant_config, quant_config=quant_config,
prefix=add_prefix("visual", prefix), prefix=add_prefix("visual", prefix),
) )
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment