Unverified commit f0d3658c, authored by Shanshan Shen and committed by GitHub
Browse files

[MM][OOT] Support CPU `seq_lens` for OOT MMEncoderAttention kernels (#36605)


Signed-off-by: shen-shanshan <467638484@qq.com>
Signed-off-by: Isotr0py <mozf@mail2.sysu.edu.cn>
Co-authored-by: Isotr0py <mozf@mail2.sysu.edu.cn>
parent 57431d82
...@@ -297,11 +297,10 @@ def test_mha_attn_varlen_forward_flashinfer( ...@@ -297,11 +297,10 @@ def test_mha_attn_varlen_forward_flashinfer(
hidden_size = num_heads * head_size hidden_size = num_heads * head_size
tp_size = 1 tp_size = 1
sequence_lengths_np = MMEncoderAttention.maybe_compute_sequence_lengths( sequence_lengths = MMEncoderAttention.maybe_compute_seq_lens(
AttentionBackendEnum.FLASHINFER, cu_seqlens_np AttentionBackendEnum.FLASHINFER,
) cu_seqlens_np,
sequence_lengths = torch.from_numpy(sequence_lengths_np).to( device,
device, dtype=torch.int32, non_blocking=True
) )
max_seqlen_val = MMEncoderAttention.compute_max_seqlen( max_seqlen_val = MMEncoderAttention.compute_max_seqlen(
...@@ -309,14 +308,12 @@ def test_mha_attn_varlen_forward_flashinfer( ...@@ -309,14 +308,12 @@ def test_mha_attn_varlen_forward_flashinfer(
) )
max_seqlen = torch.tensor(max_seqlen_val, device=device, dtype=torch.int32) max_seqlen = torch.tensor(max_seqlen_val, device=device, dtype=torch.int32)
cu_seqlens_np = MMEncoderAttention.maybe_recompute_cu_seqlens( cu_seqlens = MMEncoderAttention.maybe_recompute_cu_seqlens(
AttentionBackendEnum.FLASHINFER, AttentionBackendEnum.FLASHINFER,
cu_seqlens_np, cu_seqlens_np,
hidden_size, hidden_size,
tp_size, tp_size,
) device,
cu_seqlens = torch.from_numpy(cu_seqlens_np).to(
device, dtype=torch.int32, non_blocking=True
) )
scale = 1.0 / head_size**0.5 scale = 1.0 / head_size**0.5
......
...@@ -22,6 +22,12 @@ op_registry: dict[str, type["CustomOp"] | type["PluggableLayer"]] = {} ...@@ -22,6 +22,12 @@ op_registry: dict[str, type["CustomOp"] | type["PluggableLayer"]] = {}
op_registry_oot: dict[str, type["CustomOp"] | type["PluggableLayer"]] = {} op_registry_oot: dict[str, type["CustomOp"] | type["PluggableLayer"]] = {}
def get_oot_class_by_name(class_name: str) -> type | None:
    """Return the class registered under ``class_name`` in the OOT registry.

    Args:
        class_name: Key to look up in ``op_registry_oot``.

    Returns:
        The entry stored in ``op_registry_oot`` for ``class_name``, or
        ``None`` when no entry with that name exists.
    """
    # A single ``dict.get`` covers both the hit and the miss case.
    return op_registry_oot.get(class_name)
class PluggableLayer(nn.Module): class PluggableLayer(nn.Module):
""" """
Base class for pluggable layers. Base class for pluggable layers.
......
...@@ -6,7 +6,7 @@ import numpy as np ...@@ -6,7 +6,7 @@ import numpy as np
import torch import torch
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.model_executor.custom_op import CustomOp from vllm.model_executor.custom_op import CustomOp, get_oot_class_by_name
from vllm.model_executor.models.vision import get_vit_attn_backend from vllm.model_executor.models.vision import get_vit_attn_backend
from vllm.utils.math_utils import round_up from vllm.utils.math_utils import round_up
from vllm.v1.attention.backends.fa_utils import get_flash_attn_version from vllm.v1.attention.backends.fa_utils import get_flash_attn_version
...@@ -119,17 +119,25 @@ class MMEncoderAttention(CustomOp): ...@@ -119,17 +119,25 @@ class MMEncoderAttention(CustomOp):
return max_seqlen return max_seqlen
@classmethod @classmethod
def maybe_compute_sequence_lengths( def maybe_compute_seq_lens(
cls, cls,
attn_backend: AttentionBackendEnum, attn_backend: AttentionBackendEnum,
cu_seqlens: np.ndarray, cu_seqlens: np.ndarray,
) -> np.ndarray | None: device: torch.device,
) -> torch.Tensor | None:
if (oot_class := get_oot_class_by_name(cls.__name__)) is not None:
return oot_class.maybe_compute_seq_lens(attn_backend, cu_seqlens, device) # type: ignore[attr-defined]
if attn_backend != AttentionBackendEnum.FLASHINFER: if attn_backend != AttentionBackendEnum.FLASHINFER:
return None return None
sequence_lengths = cu_seqlens[1:] - cu_seqlens[:-1] sequence_lengths = cu_seqlens[1:] - cu_seqlens[:-1]
sequence_lengths = add_padding_to_seqlens( sequence_lengths = add_padding_to_seqlens(
sequence_lengths, len(sequence_lengths), 0 sequence_lengths, len(sequence_lengths), 0
) )
sequence_lengths = torch.from_numpy(sequence_lengths).to(
device, non_blocking=True
)
return sequence_lengths return sequence_lengths
@classmethod @classmethod
...@@ -139,10 +147,14 @@ class MMEncoderAttention(CustomOp): ...@@ -139,10 +147,14 @@ class MMEncoderAttention(CustomOp):
cu_seqlens: np.ndarray, cu_seqlens: np.ndarray,
hidden_size: int, hidden_size: int,
tp_size: int, tp_size: int,
) -> np.ndarray: device: torch.device,
if attn_backend != AttentionBackendEnum.FLASHINFER: ) -> torch.Tensor:
return cu_seqlens if (oot_class := get_oot_class_by_name(cls.__name__)) is not None:
return oot_class.maybe_recompute_cu_seqlens( # type: ignore[attr-defined]
attn_backend, cu_seqlens, hidden_size, tp_size, device
)
if attn_backend == AttentionBackendEnum.FLASHINFER:
batch_size = len(cu_seqlens) - 1 batch_size = len(cu_seqlens) - 1
scale = hidden_size // tp_size scale = hidden_size // tp_size
cu_seqlens = cu_seqlens * scale cu_seqlens = cu_seqlens * scale
...@@ -156,7 +168,10 @@ class MMEncoderAttention(CustomOp): ...@@ -156,7 +168,10 @@ class MMEncoderAttention(CustomOp):
cu_seqlens_v = add_padding_to_seqlens( cu_seqlens_v = add_padding_to_seqlens(
cu_seqlens_v, batch_size, cu_seqlens_v[-1] cu_seqlens_v, batch_size, cu_seqlens_v[-1]
) )
return np.concatenate([cu_seqlens_qko, cu_seqlens_v]) cu_seqlens = np.concatenate([cu_seqlens_qko, cu_seqlens_v])
cu_seqlens = torch.from_numpy(cu_seqlens).to(device, non_blocking=True)
return cu_seqlens
def __init__( def __init__(
self, self,
......
...@@ -983,12 +983,10 @@ class Qwen3Omni_VisionTransformer(nn.Module): ...@@ -983,12 +983,10 @@ class Qwen3Omni_VisionTransformer(nn.Module):
grid_thw_np[:, 1] * grid_thw_np[:, 2], grid_thw_np[:, 0] grid_thw_np[:, 1] * grid_thw_np[:, 2], grid_thw_np[:, 0]
).cumsum(axis=0, dtype=np.int32) ).cumsum(axis=0, dtype=np.int32)
cu_seqlens_np = np.concatenate([np.zeros(1, dtype=np.int32), cu_seqlens_np]) cu_seqlens_np = np.concatenate([np.zeros(1, dtype=np.int32), cu_seqlens_np])
sequence_lengths = MMEncoderAttention.maybe_compute_sequence_lengths( sequence_lengths = MMEncoderAttention.maybe_compute_seq_lens(
self.attn_backend, cu_seqlens_np self.attn_backend,
) cu_seqlens_np,
if sequence_lengths is not None: self.device,
sequence_lengths = torch.from_numpy(sequence_lengths).to(
self.device, non_blocking=True
) )
hidden_states_list = [] hidden_states_list = []
......
...@@ -550,12 +550,8 @@ class Qwen3_VisionTransformer(nn.Module): ...@@ -550,12 +550,8 @@ class Qwen3_VisionTransformer(nn.Module):
axis=0, dtype=np.int32 axis=0, dtype=np.int32
) )
cu_seqlens = np.concatenate([np.zeros(1, dtype=np.int32), cu_seqlens]) cu_seqlens = np.concatenate([np.zeros(1, dtype=np.int32), cu_seqlens])
sequence_lengths = MMEncoderAttention.maybe_compute_sequence_lengths( sequence_lengths = MMEncoderAttention.maybe_compute_seq_lens(
self.attn_backend, cu_seqlens self.attn_backend, cu_seqlens, self.device
)
if sequence_lengths is not None:
sequence_lengths = torch.from_numpy(sequence_lengths).to(
self.device, non_blocking=True
) )
max_seqlen = torch.tensor( max_seqlen = torch.tensor(
MMEncoderAttention.compute_max_seqlen(self.attn_backend, cu_seqlens), MMEncoderAttention.compute_max_seqlen(self.attn_backend, cu_seqlens),
...@@ -567,8 +563,8 @@ class Qwen3_VisionTransformer(nn.Module): ...@@ -567,8 +563,8 @@ class Qwen3_VisionTransformer(nn.Module):
cu_seqlens, cu_seqlens,
self.hidden_size, self.hidden_size,
self.tp_size, self.tp_size,
self.device,
) )
cu_seqlens = torch.from_numpy(cu_seqlens).to(self.device, non_blocking=True)
hidden_states = hidden_states.unsqueeze(1) hidden_states = hidden_states.unsqueeze(1)
deepstack_feature_lists = [] deepstack_feature_lists = []
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment