Unverified commit 986537f1, authored by Michael Goin, committed by GitHub
Browse files

[V1] V1 FlashInfer Attention (#16684)


Signed-off-by: mgoin <mgoin64@gmail.com>
Co-authored-by: Aurick Qiao <qiao@aurick.net>
parent 21020752
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
import pytest
from vllm import LLM, SamplingParams from vllm import LLM, SamplingParams
from ...utils import fork_new_process_for_each_test
def test_cascade_attention(example_system_message, monkeypatch): @fork_new_process_for_each_test
@pytest.mark.parametrize("attn_backend",
["FLASH_ATTN_VLLM_V1", "FLASHINFER_VLLM_V1"])
def test_cascade_attention(example_system_message, monkeypatch, attn_backend):
prompt = "\n<User>: Implement fibonacci sequence in Python.\n<Claude>:" prompt = "\n<User>: Implement fibonacci sequence in Python.\n<Claude>:"
with monkeypatch.context() as m: with monkeypatch.context() as m:
m.setenv("VLLM_USE_V1", "1") m.setenv("VLLM_USE_V1", "1")
m.setenv("VLLM_ATTENTION_BACKEND", attn_backend)
llm = LLM(model="Qwen/Qwen2-1.5B-Instruct") llm = LLM(model="Qwen/Qwen2-1.5B-Instruct")
sampling_params = SamplingParams(temperature=0.0, max_tokens=100) sampling_params = SamplingParams(temperature=0.0, max_tokens=100)
......
...@@ -1474,10 +1474,17 @@ class EngineArgs: ...@@ -1474,10 +1474,17 @@ class EngineArgs:
recommend_to_remove=False) recommend_to_remove=False)
return False return False
# No FlashInfer or XFormers so far. # No XFormers so far.
V1_BACKENDS = [ V1_BACKENDS = [
"FLASH_ATTN_VLLM_V1", "FLASH_ATTN", "PALLAS", "PALLAS_VLLM_V1", "FLASH_ATTN_VLLM_V1",
"TRITON_ATTN_VLLM_V1", "TRITON_MLA", "FLASHMLA" "FLASH_ATTN",
"PALLAS",
"PALLAS_VLLM_V1",
"TRITON_ATTN_VLLM_V1",
"TRITON_MLA",
"FLASHMLA",
"FLASHINFER",
"FLASHINFER_VLLM_V1",
] ]
if (envs.is_set("VLLM_ATTENTION_BACKEND") if (envs.is_set("VLLM_ATTENTION_BACKEND")
and envs.VLLM_ATTENTION_BACKEND not in V1_BACKENDS): and envs.VLLM_ATTENTION_BACKEND not in V1_BACKENDS):
......
...@@ -213,6 +213,9 @@ class CudaPlatformBase(Platform): ...@@ -213,6 +213,9 @@ class CudaPlatformBase(Platform):
return ("vllm.attention.backends." return ("vllm.attention.backends."
"flashmla.FlashMLABackend") "flashmla.FlashMLABackend")
if use_v1: if use_v1:
if selected_backend == _Backend.FLASHINFER:
logger.info_once("Using FlashInfer backend on V1 engine.")
return "vllm.v1.attention.backends.flashinfer.FlashInferBackend"
if selected_backend == _Backend.TRITON_ATTN_VLLM_V1: if selected_backend == _Backend.TRITON_ATTN_VLLM_V1:
logger.info_once("Using Triton backend on V1 engine.") logger.info_once("Using Triton backend on V1 engine.")
return ("vllm.v1.attention.backends." return ("vllm.v1.attention.backends."
......
...@@ -64,10 +64,6 @@ class FlashAttentionBackend(AttentionBackend): ...@@ -64,10 +64,6 @@ class FlashAttentionBackend(AttentionBackend):
raise ValueError("Block size must be a multiple of 16.") raise ValueError("Block size must be a multiple of 16.")
return (2, num_blocks, block_size, num_kv_heads, head_size) return (2, num_blocks, block_size, num_kv_heads, head_size)
@staticmethod
def use_cascade_attention(*args, **kwargs) -> bool:
return use_cascade_attention(*args, **kwargs)
@dataclass @dataclass
class FlashAttentionMetadata: class FlashAttentionMetadata:
...@@ -402,6 +398,9 @@ class FlashAttentionMetadataBuilder: ...@@ -402,6 +398,9 @@ class FlashAttentionMetadataBuilder:
) )
return attn_metadata return attn_metadata
def use_cascade_attention(self, *args, **kwargs) -> bool:
return use_cascade_attention(*args, **kwargs)
class FlashAttentionImpl(AttentionImpl): class FlashAttentionImpl(AttentionImpl):
......
This diff is collapsed.
...@@ -251,10 +251,6 @@ class MLACommonBackend(AttentionBackend): ...@@ -251,10 +251,6 @@ class MLACommonBackend(AttentionBackend):
def get_supported_head_sizes() -> list[int]: def get_supported_head_sizes() -> list[int]:
return [576] return [576]
@staticmethod
def use_cascade_attention(*args, **kwargs) -> bool:
return False
@dataclass @dataclass
class MLACommonPrefillMetadata: class MLACommonPrefillMetadata:
...@@ -574,6 +570,9 @@ class MLACommonMetadataBuilder(Generic[M]): ...@@ -574,6 +570,9 @@ class MLACommonMetadataBuilder(Generic[M]):
decode=decode_metadata, decode=decode_metadata,
) )
def use_cascade_attention(self, *args, **kwargs) -> bool:
return False
class MLACommonImpl(MLAAttentionImpl[M], Generic[M]): class MLACommonImpl(MLAAttentionImpl[M], Generic[M]):
""" """
......
...@@ -696,7 +696,7 @@ class GPUModelRunner(LoRAModelRunnerMixin): ...@@ -696,7 +696,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
# common_prefix_len should be a multiple of the block size. # common_prefix_len should be a multiple of the block size.
common_prefix_len = (common_prefix_len // self.block_size * common_prefix_len = (common_prefix_len // self.block_size *
self.block_size) self.block_size)
use_cascade = self.attn_backend.use_cascade_attention( use_cascade = self.attn_metadata_builder.use_cascade_attention(
common_prefix_len=common_prefix_len, common_prefix_len=common_prefix_len,
query_lens=num_scheduled_tokens, query_lens=num_scheduled_tokens,
num_query_heads=self.num_query_heads, num_query_heads=self.num_query_heads,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment