Unverified commit 986537f1, authored by Michael Goin and committed by GitHub
Browse files

[V1] V1 FlashInfer Attention (#16684)


Signed-off-by: mgoin <mgoin64@gmail.com>
Co-authored-by: Aurick Qiao <qiao@aurick.net>
parent 21020752
# SPDX-License-Identifier: Apache-2.0
import pytest
from vllm import LLM, SamplingParams
from ...utils import fork_new_process_for_each_test
def test_cascade_attention(example_system_message, monkeypatch):
@fork_new_process_for_each_test
@pytest.mark.parametrize("attn_backend",
["FLASH_ATTN_VLLM_V1", "FLASHINFER_VLLM_V1"])
def test_cascade_attention(example_system_message, monkeypatch, attn_backend):
prompt = "\n<User>: Implement fibonacci sequence in Python.\n<Claude>:"
with monkeypatch.context() as m:
m.setenv("VLLM_USE_V1", "1")
m.setenv("VLLM_ATTENTION_BACKEND", attn_backend)
llm = LLM(model="Qwen/Qwen2-1.5B-Instruct")
sampling_params = SamplingParams(temperature=0.0, max_tokens=100)
......
......@@ -1474,10 +1474,17 @@ class EngineArgs:
recommend_to_remove=False)
return False
# No FlashInfer or XFormers so far.
# No XFormers so far.
V1_BACKENDS = [
"FLASH_ATTN_VLLM_V1", "FLASH_ATTN", "PALLAS", "PALLAS_VLLM_V1",
"TRITON_ATTN_VLLM_V1", "TRITON_MLA", "FLASHMLA"
"FLASH_ATTN_VLLM_V1",
"FLASH_ATTN",
"PALLAS",
"PALLAS_VLLM_V1",
"TRITON_ATTN_VLLM_V1",
"TRITON_MLA",
"FLASHMLA",
"FLASHINFER",
"FLASHINFER_VLLM_V1",
]
if (envs.is_set("VLLM_ATTENTION_BACKEND")
and envs.VLLM_ATTENTION_BACKEND not in V1_BACKENDS):
......
......@@ -213,6 +213,9 @@ class CudaPlatformBase(Platform):
return ("vllm.attention.backends."
"flashmla.FlashMLABackend")
if use_v1:
if selected_backend == _Backend.FLASHINFER:
logger.info_once("Using FlashInfer backend on V1 engine.")
return "vllm.v1.attention.backends.flashinfer.FlashInferBackend"
if selected_backend == _Backend.TRITON_ATTN_VLLM_V1:
logger.info_once("Using Triton backend on V1 engine.")
return ("vllm.v1.attention.backends."
......
......@@ -64,10 +64,6 @@ class FlashAttentionBackend(AttentionBackend):
raise ValueError("Block size must be a multiple of 16.")
return (2, num_blocks, block_size, num_kv_heads, head_size)
@staticmethod
def use_cascade_attention(*args, **kwargs) -> bool:
    """Forward every argument to the outer ``use_cascade_attention``.

    A method body does not search class scope, so the name called below is
    not this static method; it is whatever ``use_cascade_attention`` is
    bound outside the class (presumably a module-level helper defined
    elsewhere in this file -- TODO confirm).

    Args:
        *args: Positional arguments, passed through unchanged.
        **kwargs: Keyword arguments, passed through unchanged.

    Returns:
        The callee's result, returned as-is.
    """
    return use_cascade_attention(*args, **kwargs)
@dataclass
class FlashAttentionMetadata:
......@@ -402,6 +398,9 @@ class FlashAttentionMetadataBuilder:
)
return attn_metadata
def use_cascade_attention(self, *args, **kwargs) -> bool:
    """Forward every argument except ``self`` to the outer helper.

    No instance state is read. The name called below is resolved outside
    the class (a method body does not search class scope), so this does not
    recurse; presumably it is a module-level ``use_cascade_attention``
    defined elsewhere in this file -- TODO confirm.

    Args:
        *args: Positional arguments, passed through unchanged.
        **kwargs: Keyword arguments, passed through unchanged.

    Returns:
        The callee's result, returned as-is.
    """
    return use_cascade_attention(*args, **kwargs)
class FlashAttentionImpl(AttentionImpl):
......
This diff is collapsed.
......@@ -251,10 +251,6 @@ class MLACommonBackend(AttentionBackend):
def get_supported_head_sizes() -> list[int]:
return [576]
@staticmethod
def use_cascade_attention(*args, **kwargs) -> bool:
    """Return ``False`` unconditionally.

    All positional and keyword arguments are accepted and ignored.
    NOTE(review): presumably this signals that cascade attention is not
    supported for this backend -- confirm against the caller.
    """
    return False
@dataclass
class MLACommonPrefillMetadata:
......@@ -574,6 +570,9 @@ class MLACommonMetadataBuilder(Generic[M]):
decode=decode_metadata,
)
def use_cascade_attention(self, *args, **kwargs) -> bool:
    """Return ``False`` unconditionally.

    Neither instance state nor any of the arguments is consulted.
    NOTE(review): presumably this signals that cascade attention is not
    supported by this metadata builder -- confirm against the caller.
    """
    return False
class MLACommonImpl(MLAAttentionImpl[M], Generic[M]):
"""
......
......@@ -696,7 +696,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
# common_prefix_len should be a multiple of the block size.
common_prefix_len = (common_prefix_len // self.block_size *
self.block_size)
use_cascade = self.attn_backend.use_cascade_attention(
use_cascade = self.attn_metadata_builder.use_cascade_attention(
common_prefix_len=common_prefix_len,
query_lens=num_scheduled_tokens,
num_query_heads=self.num_query_heads,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment