Commit c7c3415a authored by zhuwenwen
Browse files

适配glm4_1v量化模型 (Adapt to the glm4_1v quantized model)

parent 4ec64732
...@@ -312,6 +312,8 @@ class Glm4vVisionAttention(nn.Module): ...@@ -312,6 +312,8 @@ class Glm4vVisionAttention(nn.Module):
def split_qkv(self, qkv: torch.Tensor) -> tuple[torch.Tensor, ...]: def split_qkv(self, qkv: torch.Tensor) -> tuple[torch.Tensor, ...]:
# [s, b, 3 * head * head_dim] # [s, b, 3 * head * head_dim]
if qkv.dim() == 2:
qkv = qkv.unsqueeze(1) # 在 dim=1 加 batch 维度
seq_len, bs, _ = qkv.shape seq_len, bs, _ = qkv.shape
# [s, b, 3 * head * head_dim] -> 3 * [s, b, head * head_dim] # [s, b, 3 * head * head_dim] -> 3 * [s, b, head * head_dim]
...@@ -413,8 +415,30 @@ class Glm4vVisionBlock(nn.Module): ...@@ -413,8 +415,30 @@ class Glm4vVisionBlock(nn.Module):
rotary_pos_emb_sin=rotary_pos_emb_sin, rotary_pos_emb_sin=rotary_pos_emb_sin,
max_seqlen=max_seqlen, max_seqlen=max_seqlen,
) )
if x_attn.dim() == 2:
x_attn = x_attn.unsqueeze(1)
elif x_attn.dim() == 1:
x_attn = x_attn.unsqueeze(1).unsqueeze(2)
assert x_attn.dim() == 3, f"x_attn must be 3D, got {x_attn.shape}"
x_fused_norm, residual = self.norm2(x, residual=x_attn) x_fused_norm, residual = self.norm2(x, residual=x_attn)
x = residual + self.mlp(x_fused_norm) if x_fused_norm.dim() == 3 and x_fused_norm.shape[1] == 1:
mlp_in = x_fused_norm.squeeze(1)
restore_3d = True
elif x_fused_norm.dim() == 2:
mlp_in = x_fused_norm
restore_3d = False
else:
raise RuntimeError(f"Unexpected x_fused_norm shape {x_fused_norm.shape}, expect (N,D) or (N,1,D)")
out = self.mlp(mlp_in)
if restore_3d:
out = out.unsqueeze(1)
assert out.shape == residual.shape, \
f"residual {residual.shape} vs mlp_out {out.shape} mismatch"
x = residual + out
return x return x
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment