Unverified Commit 56a63717 authored by Andrii Skliar's avatar Andrii Skliar Committed by GitHub
Browse files

[Update] Use FlashInfer fast_decode_plan directly instead of replication (#34687)


Signed-off-by: default avatarAndrii <askliar@nvidia.com>
Co-authored-by: default avatarAndrii <askliar@nvidia.com>
parent 62830211
...@@ -84,6 +84,209 @@ def ref_paged_attn( ...@@ -84,6 +84,209 @@ def ref_paged_attn(
return torch.cat(outputs, dim=0) return torch.cat(outputs, dim=0)
def _make_paged_kv_metadata(
kv_lens: list[int],
block_size: int,
num_blocks: int,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
"""Build paged-KV metadata tensors for fast_plan_decode tests.
Returns:
kv_indptr – CPU int32, shape [num_seqs + 1]
kv_indices – CUDA int32, shape [total_blocks]
kv_last_page_lens – CPU int32, shape [num_seqs]
block_tables – CUDA int32, shape [num_seqs, max_blocks_per_seq]
"""
num_seqs = len(kv_lens)
max_blocks = (max(kv_lens) + block_size - 1) // block_size
block_tables = torch.randint(
0, num_blocks, (num_seqs, max_blocks), dtype=torch.int32, device="cuda"
)
indptr_list = [0]
indices_list: list[int] = []
last_lens_list: list[int] = []
for i, seq_len in enumerate(kv_lens):
n = (seq_len + block_size - 1) // block_size
indices_list.extend(block_tables[i, :n].cpu().tolist())
indptr_list.append(indptr_list[-1] + n)
last_lens_list.append(seq_len % block_size or block_size)
return (
torch.tensor(indptr_list, dtype=torch.int32, device="cpu"),
torch.tensor(indices_list, dtype=torch.int32, device="cuda"),
torch.tensor(last_lens_list, dtype=torch.int32, device="cpu"),
block_tables,
)
def _make_cg_decode_wrapper(
num_seqs: int,
kv_indices_buffer: torch.Tensor,
workspace_buffer: torch.Tensor,
use_tensor_cores: bool = True,
) -> "flashinfer.BatchDecodeWithPagedKVCacheWrapper":
"""Create a cudagraph-enabled BatchDecodeWithPagedKVCacheWrapper.
*kv_indices_buffer* is shared with the caller so that fast_plan_decode
can avoid the device-to-device index copy on subsequent (cudagraph) calls.
"""
return flashinfer.BatchDecodeWithPagedKVCacheWrapper(
workspace_buffer,
"NHD",
use_cuda_graph=True,
paged_kv_indptr_buffer=torch.zeros(
num_seqs + 1, dtype=torch.int32, device="cuda"
),
paged_kv_indices_buffer=kv_indices_buffer,
paged_kv_last_page_len_buffer=torch.zeros(
num_seqs, dtype=torch.int32, device="cuda"
),
use_tensor_cores=use_tensor_cores,
)
def test_fast_decode_plan_importable() -> None:
"""fast_decode_plan must be importable from flashinfer.decode.
This is a forward-compatibility smoke test: if FlashInfer reorganises its
public API the import will fail before any other test does.
"""
from flashinfer.decode import fast_decode_plan # noqa: F401
assert callable(fast_decode_plan)
@pytest.mark.parametrize("dtype", DTYPES)
@torch.inference_mode
def test_fast_plan_decode_warmup_uses_full_plan(dtype: torch.dtype) -> None:
"""On the first call fast_plan_decode must route through self.plan() and
flip vllm_first_call to False on the wrapper object."""
from unittest.mock import patch
from vllm.v1.attention.backends.flashinfer import fast_plan_decode
torch.set_default_device("cuda")
set_random_seed(0)
kv_lens = [128, 64]
block_size = 16
num_seqs = len(kv_lens)
num_query_heads, num_kv_heads = 8, 2
head_size = 128
kv_indptr, kv_indices, kv_last_page_lens, _ = _make_paged_kv_metadata(
kv_lens, block_size, NUM_BLOCKS
)
workspace = torch.empty(128 * 1024 * 1024, dtype=torch.int8)
wrapper = _make_cg_decode_wrapper(num_seqs, kv_indices.clone(), workspace)
assert getattr(wrapper, "vllm_first_call", True) is True
with patch.object(wrapper, "plan", wraps=wrapper.plan) as mock_plan:
fast_plan_decode(
wrapper,
indptr_cpu=kv_indptr,
indices=kv_indices,
last_page_len_cpu=kv_last_page_lens,
num_qo_heads=num_query_heads,
num_kv_heads=num_kv_heads,
head_dim=head_size,
page_size=block_size,
q_data_type=dtype,
kv_data_type=dtype,
)
mock_plan.assert_called_once()
assert wrapper.vllm_first_call is False, (
"vllm_first_call should be False after the first fast_plan_decode call"
)
@pytest.mark.parametrize("kv_lens", [[1328, 18, 463], [1, 54, 293, 70]])
@pytest.mark.parametrize("num_heads", NUM_HEADS)
@pytest.mark.parametrize("head_size", HEAD_SIZES)
@pytest.mark.parametrize("block_size", BLOCK_SIZES)
@pytest.mark.parametrize("dtype", DTYPES)
@torch.inference_mode
def test_fast_plan_decode_matches_full_plan(
kv_lens: list[int],
num_heads: tuple[int, int],
head_size: int,
block_size: int,
dtype: torch.dtype,
) -> None:
"""fast_plan_decode's cudagraph path (delegating to FlashInfer's
fast_decode_plan) must produce attention output numerically identical to
a standard plan() call.
Both the warmup call (self.plan) and the subsequent fast call
(fast_decode_plan) are verified against the same reference.
"""
from vllm.v1.attention.backends.flashinfer import fast_plan_decode
torch.set_default_device("cuda")
set_random_seed(0)
num_seqs = len(kv_lens)
num_query_heads, num_kv_heads = num_heads
query = torch.randn(num_seqs, num_query_heads, head_size, dtype=dtype)
key_value_cache = torch.randn(
NUM_BLOCKS, 2, block_size, num_kv_heads, head_size, dtype=dtype
)
kv_indptr, kv_indices, kv_last_page_lens, _ = _make_paged_kv_metadata(
kv_lens, block_size, NUM_BLOCKS
)
# Reference output via the standard plan()
workspace_ref = torch.empty(128 * 1024 * 1024, dtype=torch.int8)
ref_wrapper = flashinfer.BatchDecodeWithPagedKVCacheWrapper(
workspace_ref, "NHD", use_tensor_cores=True
)
ref_wrapper.plan(
kv_indptr,
kv_indices,
kv_last_page_lens,
num_query_heads,
num_kv_heads,
head_size,
block_size,
"NONE",
q_data_type=dtype,
kv_data_type=dtype,
)
ref_output = ref_wrapper.run(query, key_value_cache)
# CUDAGraph wrapper exercised through fast_plan_decode
kv_indices_buf = kv_indices.clone()
workspace_cg = torch.empty(128 * 1024 * 1024, dtype=torch.int8)
cg_wrapper = _make_cg_decode_wrapper(num_seqs, kv_indices_buf, workspace_cg)
plan_kwargs: dict = dict(
indptr_cpu=kv_indptr,
indices=kv_indices_buf,
last_page_len_cpu=kv_last_page_lens,
num_qo_heads=num_query_heads,
num_kv_heads=num_kv_heads,
head_dim=head_size,
page_size=block_size,
q_data_type=dtype,
kv_data_type=dtype,
)
# First call – warmup path (routes through self.plan)
fast_plan_decode(cg_wrapper, **plan_kwargs)
warmup_output = cg_wrapper.run(query, key_value_cache)
torch.testing.assert_close(warmup_output, ref_output, atol=1e-2, rtol=1e-2)
# Second call – fast path (routes through fast_decode_plan from FlashInfer)
fast_plan_decode(cg_wrapper, **plan_kwargs)
fast_output = cg_wrapper.run(query, key_value_cache)
torch.testing.assert_close(fast_output, ref_output, atol=1e-2, rtol=1e-2)
@pytest.mark.parametrize("kv_lens", [[1328, 18, 463], [1, 54, 293, 70]]) @pytest.mark.parametrize("kv_lens", [[1328, 18, 463], [1, 54, 293, 70]])
@pytest.mark.parametrize("num_heads", NUM_HEADS) @pytest.mark.parametrize("num_heads", NUM_HEADS)
@pytest.mark.parametrize("head_size", HEAD_SIZES) @pytest.mark.parametrize("head_size", HEAD_SIZES)
......
...@@ -13,7 +13,7 @@ from flashinfer import ( ...@@ -13,7 +13,7 @@ from flashinfer import (
BatchPrefillWithRaggedKVCacheWrapper, BatchPrefillWithRaggedKVCacheWrapper,
MultiLevelCascadeAttentionWrapper, MultiLevelCascadeAttentionWrapper,
) )
from flashinfer.decode import _get_range_buf, trtllm_batch_decode_with_kv_cache from flashinfer.decode import fast_decode_plan, trtllm_batch_decode_with_kv_cache
from flashinfer.prefill import trtllm_batch_context_with_kv_cache from flashinfer.prefill import trtllm_batch_context_with_kv_cache
from flashinfer.utils import FP4Tensor from flashinfer.utils import FP4Tensor
from typing_extensions import override from typing_extensions import override
...@@ -199,14 +199,14 @@ class BatchDCPPrefillWrapper: ...@@ -199,14 +199,14 @@ class BatchDCPPrefillWrapper:
): ):
"""Plan the prefill operation with given parameters.""" """Plan the prefill operation with given parameters."""
self._context.plan( self._context.plan(
qo_indptr_cpu, qo_indptr=qo_indptr_cpu,
paged_kv_indptr_cpu, paged_kv_indptr=paged_kv_indptr_cpu,
paged_kv_indices, paged_kv_indices=paged_kv_indices,
paged_kv_last_page_len_cpu, paged_kv_last_page_len=paged_kv_last_page_len_cpu,
num_qo_heads * dcp_world_size, num_qo_heads=num_qo_heads * dcp_world_size,
num_kv_heads, num_kv_heads=num_kv_heads,
head_dim, head_dim_qk=head_dim,
page_size, page_size=page_size,
causal=False, # This is context run causal=False, # This is context run
sm_scale=sm_scale, sm_scale=sm_scale,
window_left=window_left, window_left=window_left,
...@@ -818,6 +818,9 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]): ...@@ -818,6 +818,9 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]):
page_size, page_size,
paged_kv_last_page_len_np, paged_kv_last_page_len_np,
) )
self.paged_kv_last_page_len.gpu[:num_reqs].copy_(
self.paged_kv_last_page_len.cpu[:num_reqs], non_blocking=True
)
return paged_kv_indices return paged_kv_indices
def build( def build(
...@@ -999,14 +1002,17 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]): ...@@ -999,14 +1002,17 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]):
attn_metadata.cascade_wrapper = self._get_cascade_wrapper() attn_metadata.cascade_wrapper = self._get_cascade_wrapper()
attn_metadata.cascade_wrapper.plan( attn_metadata.cascade_wrapper.plan(
[shared_qo_indptr_cpu, qo_indptr_cpu], qo_indptr_arr=[shared_qo_indptr_cpu, qo_indptr_cpu],
[shared_kv_page_indptr_cpu, paged_kv_indptr_cpu], paged_kv_indptr_arr=[shared_kv_page_indptr_cpu, paged_kv_indptr_cpu],
[shared_kv_page_indices_cpu, paged_kv_indices], paged_kv_indices_arr=[shared_kv_page_indices_cpu, paged_kv_indices],
[shared_kv_last_page_len_cpu, paged_kv_last_page_len_cpu], paged_kv_last_page_len=[
self.num_qo_heads, shared_kv_last_page_len_cpu,
self.num_kv_heads, paged_kv_last_page_len_cpu,
self.head_dim, ],
self.page_size, num_qo_heads=self.num_qo_heads,
num_kv_heads=self.num_kv_heads,
head_dim=self.head_dim,
page_size=self.page_size,
causal=True, causal=True,
sm_scale=self.sm_scale, sm_scale=self.sm_scale,
window_left=self.window_left, window_left=self.window_left,
...@@ -1084,14 +1090,14 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]): ...@@ -1084,14 +1090,14 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]):
BatchPrefillWithPagedKVCacheWrapper, BatchPrefillWithPagedKVCacheWrapper,
) )
prefill_wrapper.plan( prefill_wrapper.plan(
qo_indptr_prefill_cpu, qo_indptr=qo_indptr_prefill_cpu,
paged_kv_indptr_prefill_cpu, paged_kv_indptr=paged_kv_indptr_prefill_cpu,
paged_kv_indices, paged_kv_indices=paged_kv_indices,
paged_kv_last_page_len_prefill_cpu, paged_kv_last_page_len=paged_kv_last_page_len_prefill_cpu,
self.num_qo_heads, num_qo_heads=self.num_qo_heads,
self.num_kv_heads, num_kv_heads=self.num_kv_heads,
self.head_dim, head_dim_qk=self.head_dim,
self.page_size, page_size=self.page_size,
causal=True, causal=True,
sm_scale=self.sm_scale, sm_scale=self.sm_scale,
window_left=self.window_left, window_left=self.window_left,
...@@ -1132,14 +1138,15 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]): ...@@ -1132,14 +1138,15 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]):
# in atten_metadata when using cudagraph. # in atten_metadata when using cudagraph.
fast_plan_decode( fast_plan_decode(
decode_wrapper, decode_wrapper,
self.paged_kv_indptr.cpu[: num_input_tokens + 1], indptr_cpu=self.paged_kv_indptr.cpu[: num_input_tokens + 1],
paged_kv_indices, indices=paged_kv_indices,
self.paged_kv_last_page_len.cpu[:num_input_tokens], last_page_len_cpu=self.paged_kv_last_page_len.cpu[
seq_lens_cpu[:num_input_tokens], :num_input_tokens
self.num_qo_heads * self.dcp_world_size, ],
self.num_kv_heads, num_qo_heads=self.num_qo_heads * self.dcp_world_size,
self.head_dim, num_kv_heads=self.num_kv_heads,
self.page_size, head_dim=self.head_dim,
page_size=self.page_size,
# Disable flashinfer's pos encoding and use vllm's rope. # Disable flashinfer's pos encoding and use vllm's rope.
pos_encoding_mode="NONE", pos_encoding_mode="NONE",
sm_scale=self.sm_scale, sm_scale=self.sm_scale,
...@@ -1617,7 +1624,6 @@ def fast_plan_decode( ...@@ -1617,7 +1624,6 @@ def fast_plan_decode(
indptr_cpu: torch.Tensor, indptr_cpu: torch.Tensor,
indices: torch.Tensor, indices: torch.Tensor,
last_page_len_cpu: torch.Tensor, last_page_len_cpu: torch.Tensor,
seq_lens_cpu: torch.Tensor,
num_qo_heads: int, num_qo_heads: int,
num_kv_heads: int, num_kv_heads: int,
head_dim: int, head_dim: int,
...@@ -1654,111 +1660,57 @@ def fast_plan_decode( ...@@ -1654,111 +1660,57 @@ def fast_plan_decode(
# this warm up is to generate the _cached_module for the decode wrapper. # this warm up is to generate the _cached_module for the decode wrapper.
if not self.is_cuda_graph_enabled or getattr(self, "vllm_first_call", True): if not self.is_cuda_graph_enabled or getattr(self, "vllm_first_call", True):
self.plan( self.plan(
indptr_cpu, indptr=indptr_cpu,
indices, indices=indices,
last_page_len_cpu, last_page_len=last_page_len_cpu,
num_qo_heads, num_qo_heads=num_qo_heads,
num_kv_heads, num_kv_heads=num_kv_heads,
head_dim, head_dim=head_dim,
page_size, page_size=page_size,
pos_encoding_mode, pos_encoding_mode=pos_encoding_mode,
window_left, window_left=window_left,
logits_soft_cap, logits_soft_cap=logits_soft_cap,
q_data_type, q_data_type=q_data_type,
kv_data_type, kv_data_type=kv_data_type,
o_data_type, o_data_type=o_data_type,
data_type, data_type=data_type,
sm_scale, sm_scale=sm_scale,
rope_scale, rope_scale=rope_scale,
rope_theta, rope_theta=rope_theta,
non_blocking, non_blocking=non_blocking,
None, # block_tables block_tables=None,
None, # seq_lens seq_lens=None,
fixed_split_size, fixed_split_size=fixed_split_size,
disable_split_kv, disable_split_kv=disable_split_kv,
) )
self.vllm_first_call = False self.vllm_first_call = False
return return
assert self.is_cuda_graph_enabled, "Should be cudagraph only here" assert self.is_cuda_graph_enabled, "Should be cudagraph only here"
batch_size = len(last_page_len_cpu) fast_decode_plan(
if logits_soft_cap is None: self,
logits_soft_cap = 0.0 indptr=indptr_cpu,
indices=indices,
# Handle data types consistently last_page_len=last_page_len_cpu,
if data_type is not None: num_qo_heads=num_qo_heads,
if q_data_type is None: num_kv_heads=num_kv_heads,
q_data_type = data_type head_dim=head_dim,
if kv_data_type is None: page_size=page_size,
kv_data_type = data_type pos_encoding_mode=pos_encoding_mode,
elif q_data_type is None: window_left=window_left,
q_data_type = "float16" logits_soft_cap=logits_soft_cap,
q_data_type=q_data_type,
if kv_data_type is None: kv_data_type=kv_data_type,
kv_data_type = q_data_type data_type=data_type,
q_data_type = ( sm_scale=sm_scale,
getattr(torch, q_data_type) if isinstance(q_data_type, str) else q_data_type rope_scale=rope_scale,
) rope_theta=rope_theta,
kv_data_type = ( non_blocking=non_blocking,
getattr(torch, kv_data_type) if isinstance(kv_data_type, str) else kv_data_type fixed_split_size=fixed_split_size,
disable_split_kv=disable_split_kv,
) )
if batch_size != self._fixed_batch_size:
raise ValueError(
"The batch size should be fixed in cudagraph mode, the runtime "
"batch size {} mismatches the batch size set during "
"initialization {}".format(batch_size, self._fixed_batch_size)
)
if len(indices) > len(self._paged_kv_indices_buf):
raise ValueError(
"The size of indices should be less than or equal to the allocated buffer"
)
# host-to-device copy for the indptr buffer
self._paged_kv_indptr_buf.copy_(indptr_cpu, non_blocking=True)
# host-to-device copy for the last_page_len buffer
self._paged_kv_last_page_len_buf.copy_(last_page_len_cpu, non_blocking=True)
qo_indptr_host = _get_range_buf(batch_size + 1, "cpu")
try:
# Make sure we pass exactly 19 arguments for tensor core version
args = [
self._float_workspace_buffer,
self._int_workspace_buffer,
self._pin_memory_int_workspace_buffer,
qo_indptr_host,
indptr_cpu,
seq_lens_cpu,
batch_size, # total_num_rows
batch_size,
num_qo_heads,
num_kv_heads,
page_size,
self.is_cuda_graph_enabled,
head_dim,
head_dim,
False, # causal
window_left,
]
if self._backend == "fa2":
args.append(fixed_split_size)
args.append(disable_split_kv)
args.append(0) # num_colocated_ctas
self._plan_info = self._cached_module.plan(
*args,
)
except Exception as e:
raise RuntimeError(f"Error in tensor core plan: {e}") from e
self._pos_encoding_mode = pos_encoding_mode
self._window_left = window_left
self._logits_soft_cap = logits_soft_cap
self._sm_scale = sm_scale
self._rope_scale = rope_scale
self._rope_theta = rope_theta
@triton.jit @triton.jit
def _copy_page_indices_kernel( def _copy_page_indices_kernel(
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment