Commit 0144aef4 authored by Li Xiaohui's avatar Li Xiaohui
Browse files

Replace VL model implementations with updated version

parent 3a716536
......@@ -280,10 +280,9 @@ class Glm4vVisionAttention(nn.Module):
f"GLM-4V does not support {self.attn_backend} backend now.")
def split_qkv(self, qkv: torch.Tensor) -> tuple[torch.Tensor, ...]:
if qkv.dim() == 2:
qkv = qkv.unsqueeze(1) # dim加上batch维度
# [s, b, 3 * head * head_dim]
if qkv.dim() == 2:
qkv = qkv.unsqueeze(1) # 在 dim=1 加 batch 维度
seq_len, bs, _ = qkv.shape
# [s, b, 3 * head * head_dim] -> 3 * [s, b, head * head_dim]
......@@ -427,9 +426,6 @@ class Glm4vVisionBlock(nn.Module):
max_seqlen: Optional[int] = None, # Only used for Flash Attention
seqlens: Optional[list[int]] = None, # Only used for xFormers
) -> torch.Tensor:
# -------------------------
# 1) Attention
# -------------------------
normed_x = self.norm1(x)
x_attn = self.attn(
normed_x,
......@@ -438,22 +434,12 @@ class Glm4vVisionBlock(nn.Module):
max_seqlen=max_seqlen,
seqlens=seqlens,
)
# 保证 attn 输出为 3D tensor
if x_attn.dim() == 2:
x_attn = x_attn.unsqueeze(1)
elif x_attn.dim() == 1:
x_attn = x_attn.unsqueeze(1).unsqueeze(2)
assert x_attn.dim() == 3, f"x_attn must be 3D, got {x_attn.shape}"
# -------------------------
# 2) norm2 + residual
# -------------------------
x_fused_norm, residual = self.norm2(x, residual=x_attn)
# -------------------------
# 3) MLP 前形状检查(核心)
# ------------------------
if x_fused_norm.dim() == 3 and x_fused_norm.shape[1] == 1:
mlp_in = x_fused_norm.squeeze(1)
restore_3d = True
......@@ -462,18 +448,9 @@ class Glm4vVisionBlock(nn.Module):
restore_3d = False
else:
raise RuntimeError(f"Unexpected x_fused_norm shape {x_fused_norm.shape}, expect (N,D) or (N,1,D)")
# -------------------------
# 4) MLP
# ------------------------
out = self.mlp(mlp_in)
# MLP 可能返回 (N,D),恢复回三维
if restore_3d:
out = out.unsqueeze(1)
# -------------------------
# 5) residual + mlp_out
# -------------------------
assert out.shape == residual.shape, \
f"residual {residual.shape} vs mlp_out {out.shape} mismatch"
x = residual + out
......
......@@ -925,6 +925,7 @@ class Qwen2_5_VisionTransformer(nn.Module):
return loaded_params
class Qwen2_5_VLProcessingInfo(Qwen2VLProcessingInfo):
def get_hf_config(self):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment