Unverified commit 3f1b0373, authored by TJian, committed by GitHub

[ROCm] [Bugfix] `compute_attn_mask_seqlen` for qwen3 omni (#29974)


Signed-off-by: tjtanaa <tunjian.tan@embeddedllm.com>
parent 9aa33a74
@@ -494,7 +494,10 @@ class Qwen3Omni_VisionTransformer(nn.Module):
         cu_seqlens: torch.Tensor,
     ) -> torch.Tensor:
         max_seqlen = torch.zeros([], device=cu_seqlens.device)
-        if self.attn_backend == AttentionBackendEnum.FLASH_ATTN:
+        if self.attn_backend in {
+            AttentionBackendEnum.FLASH_ATTN,
+            AttentionBackendEnum.ROCM_AITER_FA,
+        }:
             max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max()
         return max_seqlen
......
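The hunk above widens the backend check so that `max_seqlen` is also computed when the ROCm AITER FlashAttention backend is selected. As a rough illustration of that logic, here is a plain-Python sketch that needs no `torch`. The `Backend` enum, the `OTHER` member, and the `compute_max_seqlen` helper are hypothetical stand-ins, not vLLM's actual classes. Unlike the original method, it returns a plain `int` rather than a tensor.

```python
from enum import Enum, auto


class Backend(Enum):
    # Hypothetical stand-in for AttentionBackendEnum; OTHER represents
    # any backend not named in the diff.
    FLASH_ATTN = auto()
    ROCM_AITER_FA = auto()
    OTHER = auto()


# Backends for which, per the diff, max_seqlen is computed.
VARLEN_BACKENDS = {Backend.FLASH_ATTN, Backend.ROCM_AITER_FA}


def compute_max_seqlen(cu_seqlens: list[int], backend: Backend) -> int:
    """Return the longest sequence length for the listed backends, else 0."""
    if backend not in VARLEN_BACKENDS:
        return 0
    # Adjacent differences of the cumulative offsets are the per-sequence
    # lengths, mirroring (cu_seqlens[1:] - cu_seqlens[:-1]).max().
    return max(
        (end - start for start, end in zip(cu_seqlens, cu_seqlens[1:])),
        default=0,
    )


if __name__ == "__main__":
    cu = [0, 4, 10, 13]  # three sequences of lengths 4, 6 and 3
    for backend in Backend:
        print(backend.name, compute_max_seqlen(cu, backend))
```

Before the fix, the equivalent of `ROCM_AITER_FA` fell through to the zero default, which is what the set-membership test addresses.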