Unverified Commit 7041de43 authored by Lily Liu's avatar Lily Liu Committed by GitHub
Browse files

[Kernel] Flashinfer for prefill & decode, with Cudagraph support for decode (#4628)


Co-authored-by: default avatarLiuXiaoxuanPKU &lt;llilyliupku@gmail.com&gt;, bong-furiosa <bongwon.jang@furiosa.ai>
parent 6a62cb82
......@@ -211,3 +211,6 @@ steps:
- pytest -v -s distributed/test_custom_all_reduce.py
- TEST_DIST_MODEL=facebook/opt-125m DISTRIBUTED_EXECUTOR_BACKEND=ray pytest -v -s distributed/test_basic_distributed_correctness.py
- TEST_DIST_MODEL=facebook/opt-125m DISTRIBUTED_EXECUTOR_BACKEND=mp pytest -v -s distributed/test_basic_distributed_correctness.py
- pip install https://github.com/flashinfer-ai/flashinfer/releases/download/v0.0.5/flashinfer-0.0.5+cu121torch2.3-cp310-cp310-linux_x86_64.whl
- VLLM_ATTENTION_BACKEND=FLASHINFER TEST_DIST_MODEL=facebook/opt-125m DISTRIBUTED_EXECUTOR_BACKEND=ray pytest -v -s distributed/test_basic_distributed_correctness.py
- VLLM_ATTENTION_BACKEND=FLASHINFER TEST_DIST_MODEL=meta-llama/Meta-Llama-3-8B DISTRIBUTED_EXECUTOR_BACKEND=ray pytest -v -s distributed/test_basic_distributed_correctness.py
\ No newline at end of file
......@@ -19,4 +19,4 @@ sentence-transformers # required for embedding
aiohttp
# quantization
bitsandbytes==0.42.0
bitsandbytes==0.42.0
\ No newline at end of file
......@@ -2,7 +2,6 @@
Run `pytest tests/basic_correctness/test_basic_correctness.py`.
"""
import os
import weakref
import pytest
......@@ -13,7 +12,6 @@ MODELS = [
"facebook/opt-125m",
"meta-llama/Llama-2-7b-hf",
]
VLLM_ATTENTION_BACKEND = "VLLM_ATTENTION_BACKEND"
def test_vllm_gc_ed():
......@@ -39,10 +37,6 @@ def test_models(
max_tokens: int,
enforce_eager: bool,
) -> None:
backend_by_env_var = os.getenv(VLLM_ATTENTION_BACKEND)
if backend_by_env_var == "FLASHINFER" and enforce_eager is False:
pytest.skip("Skipping non-eager test for FlashInferBackend.")
with hf_runner(model, dtype=dtype) as hf_model:
hf_outputs = hf_model.generate_greedy(example_prompts, max_tokens)
......
......@@ -21,7 +21,6 @@ MODELS = [
os.environ["TEST_DIST_MODEL"],
]
DISTRIBUTED_EXECUTOR_BACKEND = "DISTRIBUTED_EXECUTOR_BACKEND"
VLLM_ATTENTION_BACKEND = "VLLM_ATTENTION_BACKEND"
@pytest.mark.skipif(torch.cuda.device_count() < 2,
......@@ -39,16 +38,12 @@ def test_models(
) -> None:
distributed_executor_backend = os.getenv(DISTRIBUTED_EXECUTOR_BACKEND)
backend_by_env_var = os.getenv(VLLM_ATTENTION_BACKEND)
enforce_eager = backend_by_env_var == "FLASHINFER"
with hf_runner(model, dtype=dtype) as hf_model:
hf_outputs = hf_model.generate_greedy(example_prompts, max_tokens)
with vllm_runner(model,
dtype=dtype,
tensor_parallel_size=2,
enforce_eager=enforce_eager,
distributed_executor_backend=distributed_executor_backend
) as vllm_model:
vllm_outputs = vllm_model.generate_greedy(example_prompts, max_tokens)
......
from dataclasses import dataclass
from typing import Any, Dict, List, Optional, Set, Tuple, Type
import flashinfer
try:
from flashinfer import BatchDecodeWithPagedKVCacheWrapper
from flashinfer.prefill import BatchPrefillWithPagedKVCacheWrapper
from vllm_flash_attn import flash_attn_varlen_func
except ImportError:
flash_attn_varlen_func = None
BatchDecodeWithPagedKVCacheWrapper = None
BatchPrefillWithPagedKVCacheWrapper = None
import torch
from flashinfer import BatchDecodeWithPagedKVCacheWrapper
from vllm_flash_attn import flash_attn_varlen_func
from vllm import _custom_ops as ops
from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
......@@ -60,19 +66,16 @@ class FlashInferMetadata(AttentionMetadata):
# requests only.
max_prefill_seq_len: int
use_cuda_graph: bool = False
use_cuda_graph: bool = True
prefill_wrapper: Optional[BatchPrefillWithPagedKVCacheWrapper] = None
decode_wrapper: Optional[BatchDecodeWithPagedKVCacheWrapper] = None
# Metadata for the prefill stage since we still
# use flash attention for prefill.
# Metadata for the prefill stage
seq_start_loc: Optional[torch.Tensor] = None
query_start_loc: Optional[torch.Tensor] = None
block_tables: Optional[torch.Tensor] = None
# Metadata for the decode stage
# Workspace buffer required by the kernel, the buffer should not
# be allocated/deacollated by the FalshInfermetadata object.
workspace_buffer: Optional[torch.Tensor] = None
# An example for paged_kv_indices, paged_kv_indptr:
# request 1, page indices [0, 5, 8]
# request 2, page indices [1, 6, 7]
......@@ -98,6 +101,7 @@ class FlashInferMetadata(AttentionMetadata):
page_size: Optional[int] = None
# The data type of the paged kv cache
data_type: torch.dtype = None
device: torch.device = torch.device("cuda")
def __post_init__(self):
# Refer to
......@@ -109,13 +113,35 @@ class FlashInferMetadata(AttentionMetadata):
f"Only {supported_head_sizes} are supported for head_dim,",
f"received {self.head_dim}.")
# When using flashinfer, we are also creating the FlashInferMetadata,
# which will also call post_init by default, here we want to skip the
# post_init if it's the prefill phase.
if self.num_prefills == 0:
assert self.num_decode_tokens > 0
self.decode_wrapper = flashinfer.BatchDecodeWithPagedKVCacheWrapper(
self.workspace_buffer, "NHD")
def begin_forward(self):
if self.num_prefill_tokens > 0:
if self.paged_kv_indices is None:
return
assert self.prefill_wrapper is not None
assert self.paged_kv_indices is not None
assert self.paged_kv_indptr is not None
assert self.paged_kv_last_page_len is not None
self.paged_kv_indices = self.paged_kv_indices.to(self.device)
self.paged_kv_indptr = self.paged_kv_indptr.to(self.device)
self.paged_kv_last_page_len = self.paged_kv_last_page_len.to(
self.device)
self.prefill_wrapper.begin_forward(
self.query_start_loc, self.paged_kv_indptr,
self.paged_kv_indices, self.paged_kv_last_page_len,
self.num_qo_heads, self.num_kv_heads, self.head_dim,
self.page_size)
else:
if not self.use_cuda_graph:
assert self.paged_kv_indices is not None
assert self.paged_kv_indptr is not None
assert self.paged_kv_last_page_len is not None
self.paged_kv_indices = self.paged_kv_indices.to(self.device)
self.paged_kv_indptr = self.paged_kv_indptr.to(self.device)
self.paged_kv_last_page_len = self.paged_kv_last_page_len.to(
self.device)
assert self.decode_wrapper is not None
self.decode_wrapper.begin_forward(
self.paged_kv_indptr,
self.paged_kv_indices,
......@@ -133,8 +159,9 @@ class FlashInferMetadata(AttentionMetadata):
) -> Dict[str, Any]:
if skip_fields is None:
skip_fields = set()
# We need to skip the decode_wrapper field since it cannot be
# We need to skip the prefill/decode_wrapper field since it cannot be
# broadcasted with nccl when TP is enabled.
skip_fields.add('prefill_wrapper')
skip_fields.add('decode_wrapper')
return super().asdict_zerocopy(skip_fields)
......@@ -168,6 +195,7 @@ class FlashInferImpl(AttentionImpl):
alibi_slopes: Optional[List[float]],
sliding_window: Optional[int],
kv_cache_dtype: str,
blocksparse_params: Optional[Dict[str, Any]] = None,
) -> None:
self.num_heads = num_heads
self.head_size = head_size
......@@ -217,10 +245,14 @@ class FlashInferImpl(AttentionImpl):
self.kv_cache_dtype,
)
query = query.contiguous(
) # Flashinfer requires query to be contiguous
if prefill_meta := attn_metadata.prefill_metadata:
# Prompt run.
assert prefill_meta.block_tables is not None
if kv_cache is None or prefill_meta.block_tables.numel() == 0:
# We will use flash attention for prefill
# when kv_cache is not provided.
# This happens when vllm runs the profiling to
# determine the number of blocks.
if kv_cache is None:
output = flash_attn_varlen_func(
q=query,
k=key,
......@@ -235,13 +267,14 @@ class FlashInferImpl(AttentionImpl):
alibi_slopes=self.alibi_slopes,
)
else:
raise NotImplementedError(
"Prefix caching is not supported with flashinfer yet.")
assert prefill_meta is not None
assert prefill_meta.prefill_wrapper is not None
output = prefill_meta.prefill_wrapper.forward(query,
kv_cache,
causal=True)
else:
assert attn_metadata.decode_metadata is not None
assert attn_metadata.decode_metadata.decode_wrapper is not None
query = query.contiguous(
) # Flashinfer requires query to be contiguous
output = attn_metadata.decode_metadata.decode_wrapper.forward(
query,
kv_cache,
......
......@@ -77,8 +77,9 @@ def get_attn_backend(
return IpexAttnBackend
elif backend == _Backend.FLASHINFER:
logger.info("Using Flashinfer backend.")
logger.warning("Eager mode is required for the Flashinfer backend. "
"Please make sure --enforce-eager is set.")
logger.warning(("Flashinfer will be stuck on llma-2-7b,"
" please avoid using Flashinfer as the"
"backend when running on llma-2-7b."))
from vllm.attention.backends.flashinfer import FlashInferBackend
return FlashInferBackend
elif backend == _Backend.PALLAS:
......
This diff is collapsed.
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment