Unverified Commit b2c85669 authored by Zheng Wengang's avatar Zheng Wengang Committed by GitHub
Browse files

[BugFix][Qwen3-VL]: fix cu_seqlens in qwen3-vl (#11458)

parent 32803fb2
......@@ -452,13 +452,15 @@ class Qwen3_VisionTransformer(nn.Module):
position_embeddings = (emb.cos(), emb.sin())
# compute cu_seqlens
cu_seqlens = torch.repeat_interleave(
grid_thw[:, 1] * grid_thw[:, 2], grid_thw[:, 0]
).cumsum(dim=0)
cu_seqlens = torch.cat(
[
torch.tensor([0], device=grid_thw.device),
(grid_thw[:, 0] * grid_thw[:, 1] * grid_thw[:, 2]).cumsum(dim=0),
torch.zeros(1, dtype=torch.int32, device=cu_seqlens.device),
cu_seqlens.to(torch.int32),
]
)
cu_seqlens = torch.cat([cu_seqlens.new_zeros(1), cu_seqlens])
# max_seqlen, seqlens = self.compute_attn_mask_seqlen(cu_seqlens)
x = x.unsqueeze(1)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment