Commit e03b1b33 authored by zhuwenwen
Browse files

适配glm4_1v量化模型

parent fea96436
......@@ -281,6 +281,8 @@ class Glm4vVisionAttention(nn.Module):
def split_qkv(self, qkv: torch.Tensor) -> tuple[torch.Tensor, ...]:
# [s, b, 3 * head * head_dim]
if qkv.dim() == 2:
qkv = qkv.unsqueeze(1) # 在 dim=1 加 batch 维度
seq_len, bs, _ = qkv.shape
# [s, b, 3 * head * head_dim] -> 3 * [s, b, head * head_dim]
......@@ -431,8 +433,30 @@ class Glm4vVisionBlock(nn.Module):
max_seqlen=max_seqlen,
seqlens=seqlens,
)
if x_attn.dim() == 2:
x_attn = x_attn.unsqueeze(1)
elif x_attn.dim() == 1:
x_attn = x_attn.unsqueeze(1).unsqueeze(2)
assert x_attn.dim() == 3, f"x_attn must be 3D, got {x_attn.shape}"
x_fused_norm, residual = self.norm2(x, residual=x_attn)
x = residual + self.mlp(x_fused_norm)
if x_fused_norm.dim() == 3 and x_fused_norm.shape[1] == 1:
mlp_in = x_fused_norm.squeeze(1)
restore_3d = True
elif x_fused_norm.dim() == 2:
mlp_in = x_fused_norm
restore_3d = False
else:
raise RuntimeError(f"Unexpected x_fused_norm shape {x_fused_norm.shape}, expect (N,D) or (N,1,D)")
out = self.mlp(mlp_in)
if restore_3d:
out = out.unsqueeze(1)
assert out.shape == residual.shape, \
f"residual {residual.shape} vs mlp_out {out.shape} mismatch"
x = residual + out
return x
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment