Unverified Commit 2c1c7dfb authored by Lukas Geiger's avatar Lukas Geiger Committed by GitHub
Browse files

[Models][Qwen] Replace `pad` with `cat` for better performance (#26486)


Signed-off-by: default avatarLukas Geiger <lukas.geiger94@gmail.com>
parent e246ad6f
......@@ -680,7 +680,7 @@ class DotsVisionTransformer(nn.Module):
dim=0,
dtype=grid_thw.dtype if torch.jit.is_tracing() else torch.int32,
)
cu_seqlens = F.pad(cu_seqlens, (1, 0), value=0)
cu_seqlens = torch.cat([cu_seqlens.new_zeros(1), cu_seqlens])
max_seqlen, seqlens = self.compute_attn_mask_seqlen(cu_seqlens)
for blk in self.blocks:
......
......@@ -574,11 +574,12 @@ class Ernie4_5_VisionTransformer(nn.Module):
grid_thw[:, 1] * grid_thw[:, 2], grid_thw[:, 0]
).cumsum(dim=0, dtype=torch.int32)
zeros = cu_seqlens.new_zeros(1)
if num_pad > 0:
cu_seqlens = F.pad(cu_seqlens, (1, 1), value=0)
cu_seqlens = torch.cat([zeros, cu_seqlens, zeros])
cu_seqlens[-1] = cu_seqlens[-2] + num_pad
else:
cu_seqlens = F.pad(cu_seqlens, (1, 0), value=0)
cu_seqlens = torch.cat([zeros, cu_seqlens])
# add batch size
if hidden_states.ndim == 2:
......
......@@ -539,7 +539,7 @@ class Qwen3_VisionTransformer(nn.Module):
dim=0,
dtype=grid_thw_tensor.dtype if torch.jit.is_tracing() else torch.int32,
)
cu_seqlens = F.pad(cu_seqlens, (1, 0), value=0)
cu_seqlens = torch.cat([cu_seqlens.new_zeros(1), cu_seqlens])
hidden_states = hidden_states.unsqueeze(1)
rotary_pos_emb = rotary_pos_emb.to(hidden_states.device)
......
......@@ -592,7 +592,7 @@ class Siglip2Encoder(nn.Module):
# for more information
dtype=grid_thws.dtype if torch.jit.is_tracing() else torch.int32,
)
cu_seqlens = F.pad(cu_seqlens, (1, 0), value=0)
cu_seqlens = torch.cat([cu_seqlens.new_zeros(1), cu_seqlens])
reverse_indices = torch.argsort(window_index)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment