Unverified Commit 7041de43 authored by Lily Liu's avatar Lily Liu Committed by GitHub
Browse files

[Kernel] Flashinfer for prefill & decode, with Cudagraph support for decode (#4628)


Co-authored-by: default avatarLiuXiaoxuanPKU &lt;llilyliupku@gmail.com&gt;, bong-furiosa <bongwon.jang@furiosa.ai>
parent 6a62cb82
...@@ -211,3 +211,6 @@ steps: ...@@ -211,3 +211,6 @@ steps:
- pytest -v -s distributed/test_custom_all_reduce.py - pytest -v -s distributed/test_custom_all_reduce.py
- TEST_DIST_MODEL=facebook/opt-125m DISTRIBUTED_EXECUTOR_BACKEND=ray pytest -v -s distributed/test_basic_distributed_correctness.py - TEST_DIST_MODEL=facebook/opt-125m DISTRIBUTED_EXECUTOR_BACKEND=ray pytest -v -s distributed/test_basic_distributed_correctness.py
- TEST_DIST_MODEL=facebook/opt-125m DISTRIBUTED_EXECUTOR_BACKEND=mp pytest -v -s distributed/test_basic_distributed_correctness.py - TEST_DIST_MODEL=facebook/opt-125m DISTRIBUTED_EXECUTOR_BACKEND=mp pytest -v -s distributed/test_basic_distributed_correctness.py
- pip install https://github.com/flashinfer-ai/flashinfer/releases/download/v0.0.5/flashinfer-0.0.5+cu121torch2.3-cp310-cp310-linux_x86_64.whl
- VLLM_ATTENTION_BACKEND=FLASHINFER TEST_DIST_MODEL=facebook/opt-125m DISTRIBUTED_EXECUTOR_BACKEND=ray pytest -v -s distributed/test_basic_distributed_correctness.py
- VLLM_ATTENTION_BACKEND=FLASHINFER TEST_DIST_MODEL=meta-llama/Meta-Llama-3-8B DISTRIBUTED_EXECUTOR_BACKEND=ray pytest -v -s distributed/test_basic_distributed_correctness.py
\ No newline at end of file
...@@ -19,4 +19,4 @@ sentence-transformers # required for embedding ...@@ -19,4 +19,4 @@ sentence-transformers # required for embedding
aiohttp aiohttp
# quantization # quantization
bitsandbytes==0.42.0 bitsandbytes==0.42.0
\ No newline at end of file
...@@ -2,7 +2,6 @@ ...@@ -2,7 +2,6 @@
Run `pytest tests/basic_correctness/test_basic_correctness.py`. Run `pytest tests/basic_correctness/test_basic_correctness.py`.
""" """
import os
import weakref import weakref
import pytest import pytest
...@@ -13,7 +12,6 @@ MODELS = [ ...@@ -13,7 +12,6 @@ MODELS = [
"facebook/opt-125m", "facebook/opt-125m",
"meta-llama/Llama-2-7b-hf", "meta-llama/Llama-2-7b-hf",
] ]
VLLM_ATTENTION_BACKEND = "VLLM_ATTENTION_BACKEND"
def test_vllm_gc_ed(): def test_vllm_gc_ed():
...@@ -39,10 +37,6 @@ def test_models( ...@@ -39,10 +37,6 @@ def test_models(
max_tokens: int, max_tokens: int,
enforce_eager: bool, enforce_eager: bool,
) -> None: ) -> None:
backend_by_env_var = os.getenv(VLLM_ATTENTION_BACKEND)
if backend_by_env_var == "FLASHINFER" and enforce_eager is False:
pytest.skip("Skipping non-eager test for FlashInferBackend.")
with hf_runner(model, dtype=dtype) as hf_model: with hf_runner(model, dtype=dtype) as hf_model:
hf_outputs = hf_model.generate_greedy(example_prompts, max_tokens) hf_outputs = hf_model.generate_greedy(example_prompts, max_tokens)
......
...@@ -21,7 +21,6 @@ MODELS = [ ...@@ -21,7 +21,6 @@ MODELS = [
os.environ["TEST_DIST_MODEL"], os.environ["TEST_DIST_MODEL"],
] ]
DISTRIBUTED_EXECUTOR_BACKEND = "DISTRIBUTED_EXECUTOR_BACKEND" DISTRIBUTED_EXECUTOR_BACKEND = "DISTRIBUTED_EXECUTOR_BACKEND"
VLLM_ATTENTION_BACKEND = "VLLM_ATTENTION_BACKEND"
@pytest.mark.skipif(torch.cuda.device_count() < 2, @pytest.mark.skipif(torch.cuda.device_count() < 2,
...@@ -39,16 +38,12 @@ def test_models( ...@@ -39,16 +38,12 @@ def test_models(
) -> None: ) -> None:
distributed_executor_backend = os.getenv(DISTRIBUTED_EXECUTOR_BACKEND) distributed_executor_backend = os.getenv(DISTRIBUTED_EXECUTOR_BACKEND)
backend_by_env_var = os.getenv(VLLM_ATTENTION_BACKEND)
enforce_eager = backend_by_env_var == "FLASHINFER"
with hf_runner(model, dtype=dtype) as hf_model: with hf_runner(model, dtype=dtype) as hf_model:
hf_outputs = hf_model.generate_greedy(example_prompts, max_tokens) hf_outputs = hf_model.generate_greedy(example_prompts, max_tokens)
with vllm_runner(model, with vllm_runner(model,
dtype=dtype, dtype=dtype,
tensor_parallel_size=2, tensor_parallel_size=2,
enforce_eager=enforce_eager,
distributed_executor_backend=distributed_executor_backend distributed_executor_backend=distributed_executor_backend
) as vllm_model: ) as vllm_model:
vllm_outputs = vllm_model.generate_greedy(example_prompts, max_tokens) vllm_outputs = vllm_model.generate_greedy(example_prompts, max_tokens)
......
from dataclasses import dataclass from dataclasses import dataclass
from typing import Any, Dict, List, Optional, Set, Tuple, Type from typing import Any, Dict, List, Optional, Set, Tuple, Type
import flashinfer try:
from flashinfer import BatchDecodeWithPagedKVCacheWrapper
from flashinfer.prefill import BatchPrefillWithPagedKVCacheWrapper
from vllm_flash_attn import flash_attn_varlen_func
except ImportError:
flash_attn_varlen_func = None
BatchDecodeWithPagedKVCacheWrapper = None
BatchPrefillWithPagedKVCacheWrapper = None
import torch import torch
from flashinfer import BatchDecodeWithPagedKVCacheWrapper
from vllm_flash_attn import flash_attn_varlen_func
from vllm import _custom_ops as ops from vllm import _custom_ops as ops
from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl, from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
...@@ -60,19 +66,16 @@ class FlashInferMetadata(AttentionMetadata): ...@@ -60,19 +66,16 @@ class FlashInferMetadata(AttentionMetadata):
# requests only. # requests only.
max_prefill_seq_len: int max_prefill_seq_len: int
use_cuda_graph: bool = False use_cuda_graph: bool = True
prefill_wrapper: Optional[BatchPrefillWithPagedKVCacheWrapper] = None
decode_wrapper: Optional[BatchDecodeWithPagedKVCacheWrapper] = None decode_wrapper: Optional[BatchDecodeWithPagedKVCacheWrapper] = None
# Metadata for the prefill stage since we still # Metadata for the prefill stage
# use flash attention for prefill.
seq_start_loc: Optional[torch.Tensor] = None seq_start_loc: Optional[torch.Tensor] = None
query_start_loc: Optional[torch.Tensor] = None
block_tables: Optional[torch.Tensor] = None block_tables: Optional[torch.Tensor] = None
# Metadata for the decode stage
# Workspace buffer required by the kernel, the buffer should not
# be allocated/deacollated by the FalshInfermetadata object.
workspace_buffer: Optional[torch.Tensor] = None
# An example for paged_kv_indices, paged_kv_indptr: # An example for paged_kv_indices, paged_kv_indptr:
# request 1, page indices [0, 5, 8] # request 1, page indices [0, 5, 8]
# request 2, page indices [1, 6, 7] # request 2, page indices [1, 6, 7]
...@@ -98,6 +101,7 @@ class FlashInferMetadata(AttentionMetadata): ...@@ -98,6 +101,7 @@ class FlashInferMetadata(AttentionMetadata):
page_size: Optional[int] = None page_size: Optional[int] = None
# The data type of the paged kv cache # The data type of the paged kv cache
data_type: torch.dtype = None data_type: torch.dtype = None
device: torch.device = torch.device("cuda")
def __post_init__(self): def __post_init__(self):
# Refer to # Refer to
...@@ -109,13 +113,35 @@ class FlashInferMetadata(AttentionMetadata): ...@@ -109,13 +113,35 @@ class FlashInferMetadata(AttentionMetadata):
f"Only {supported_head_sizes} are supported for head_dim,", f"Only {supported_head_sizes} are supported for head_dim,",
f"received {self.head_dim}.") f"received {self.head_dim}.")
# When using flashinfer, we are also creating the FlashInferMetadata, def begin_forward(self):
# which will also call post_init by default, here we want to skip the if self.num_prefill_tokens > 0:
# post_init if it's the prefill phase. if self.paged_kv_indices is None:
if self.num_prefills == 0: return
assert self.num_decode_tokens > 0
self.decode_wrapper = flashinfer.BatchDecodeWithPagedKVCacheWrapper( assert self.prefill_wrapper is not None
self.workspace_buffer, "NHD") assert self.paged_kv_indices is not None
assert self.paged_kv_indptr is not None
assert self.paged_kv_last_page_len is not None
self.paged_kv_indices = self.paged_kv_indices.to(self.device)
self.paged_kv_indptr = self.paged_kv_indptr.to(self.device)
self.paged_kv_last_page_len = self.paged_kv_last_page_len.to(
self.device)
self.prefill_wrapper.begin_forward(
self.query_start_loc, self.paged_kv_indptr,
self.paged_kv_indices, self.paged_kv_last_page_len,
self.num_qo_heads, self.num_kv_heads, self.head_dim,
self.page_size)
else:
if not self.use_cuda_graph:
assert self.paged_kv_indices is not None
assert self.paged_kv_indptr is not None
assert self.paged_kv_last_page_len is not None
self.paged_kv_indices = self.paged_kv_indices.to(self.device)
self.paged_kv_indptr = self.paged_kv_indptr.to(self.device)
self.paged_kv_last_page_len = self.paged_kv_last_page_len.to(
self.device)
assert self.decode_wrapper is not None
self.decode_wrapper.begin_forward( self.decode_wrapper.begin_forward(
self.paged_kv_indptr, self.paged_kv_indptr,
self.paged_kv_indices, self.paged_kv_indices,
...@@ -133,8 +159,9 @@ class FlashInferMetadata(AttentionMetadata): ...@@ -133,8 +159,9 @@ class FlashInferMetadata(AttentionMetadata):
) -> Dict[str, Any]: ) -> Dict[str, Any]:
if skip_fields is None: if skip_fields is None:
skip_fields = set() skip_fields = set()
# We need to skip the decode_wrapper field since it cannot be # We need to skip the prefill/decode_wrapper field since it cannot be
# broadcasted with nccl when TP is enabled. # broadcasted with nccl when TP is enabled.
skip_fields.add('prefill_wrapper')
skip_fields.add('decode_wrapper') skip_fields.add('decode_wrapper')
return super().asdict_zerocopy(skip_fields) return super().asdict_zerocopy(skip_fields)
...@@ -168,6 +195,7 @@ class FlashInferImpl(AttentionImpl): ...@@ -168,6 +195,7 @@ class FlashInferImpl(AttentionImpl):
alibi_slopes: Optional[List[float]], alibi_slopes: Optional[List[float]],
sliding_window: Optional[int], sliding_window: Optional[int],
kv_cache_dtype: str, kv_cache_dtype: str,
blocksparse_params: Optional[Dict[str, Any]] = None,
) -> None: ) -> None:
self.num_heads = num_heads self.num_heads = num_heads
self.head_size = head_size self.head_size = head_size
...@@ -217,10 +245,14 @@ class FlashInferImpl(AttentionImpl): ...@@ -217,10 +245,14 @@ class FlashInferImpl(AttentionImpl):
self.kv_cache_dtype, self.kv_cache_dtype,
) )
query = query.contiguous(
) # Flashinfer requires query to be contiguous
if prefill_meta := attn_metadata.prefill_metadata: if prefill_meta := attn_metadata.prefill_metadata:
# Prompt run. # We will use flash attention for prefill
assert prefill_meta.block_tables is not None # when kv_cache is not provided.
if kv_cache is None or prefill_meta.block_tables.numel() == 0: # This happens when vllm runs the profiling to
# determine the number of blocks.
if kv_cache is None:
output = flash_attn_varlen_func( output = flash_attn_varlen_func(
q=query, q=query,
k=key, k=key,
...@@ -235,13 +267,14 @@ class FlashInferImpl(AttentionImpl): ...@@ -235,13 +267,14 @@ class FlashInferImpl(AttentionImpl):
alibi_slopes=self.alibi_slopes, alibi_slopes=self.alibi_slopes,
) )
else: else:
raise NotImplementedError( assert prefill_meta is not None
"Prefix caching is not supported with flashinfer yet.") assert prefill_meta.prefill_wrapper is not None
output = prefill_meta.prefill_wrapper.forward(query,
kv_cache,
causal=True)
else: else:
assert attn_metadata.decode_metadata is not None assert attn_metadata.decode_metadata is not None
assert attn_metadata.decode_metadata.decode_wrapper is not None assert attn_metadata.decode_metadata.decode_wrapper is not None
query = query.contiguous(
) # Flashinfer requires query to be contiguous
output = attn_metadata.decode_metadata.decode_wrapper.forward( output = attn_metadata.decode_metadata.decode_wrapper.forward(
query, query,
kv_cache, kv_cache,
......
...@@ -77,8 +77,9 @@ def get_attn_backend( ...@@ -77,8 +77,9 @@ def get_attn_backend(
return IpexAttnBackend return IpexAttnBackend
elif backend == _Backend.FLASHINFER: elif backend == _Backend.FLASHINFER:
logger.info("Using Flashinfer backend.") logger.info("Using Flashinfer backend.")
logger.warning("Eager mode is required for the Flashinfer backend. " logger.warning(("Flashinfer will be stuck on llma-2-7b,"
"Please make sure --enforce-eager is set.") " please avoid using Flashinfer as the"
"backend when running on llma-2-7b."))
from vllm.attention.backends.flashinfer import FlashInferBackend from vllm.attention.backends.flashinfer import FlashInferBackend
return FlashInferBackend return FlashInferBackend
elif backend == _Backend.PALLAS: elif backend == _Backend.PALLAS:
......
This diff is collapsed.
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment