Unverified commit 6578e873, authored by Woosuk Kwon, committed by GitHub
Browse files

Optimize input preparation for FlashInfer [2/N] (#23174)


Signed-off-by: Woosuk Kwon <woosuk.kwon@berkeley.edu>
parent 5bd9f841
...@@ -6,6 +6,7 @@ from __future__ import annotations ...@@ -6,6 +6,7 @@ from __future__ import annotations
from dataclasses import dataclass from dataclasses import dataclass
from typing import ClassVar, Optional, Union from typing import ClassVar, Optional, Union
import numpy as np
import torch import torch
from flashinfer import (BatchDecodeWithPagedKVCacheWrapper, from flashinfer import (BatchDecodeWithPagedKVCacheWrapper,
BatchPrefillWithPagedKVCacheWrapper, BatchPrefillWithPagedKVCacheWrapper,
...@@ -22,6 +23,7 @@ from vllm.logger import init_logger ...@@ -22,6 +23,7 @@ from vllm.logger import init_logger
from vllm.model_executor.layers.quantization.utils.quant_utils import ( from vllm.model_executor.layers.quantization.utils.quant_utils import (
QuantKey, kFp8StaticTensorSym, kNvfp4Quant) QuantKey, kFp8StaticTensorSym, kNvfp4Quant)
from vllm.platforms import current_platform from vllm.platforms import current_platform
from vllm.triton_utils import tl, triton
from vllm.utils import cdiv, is_pin_memory_available from vllm.utils import cdiv, is_pin_memory_available
from vllm.utils.flashinfer import (supports_trtllm_attention, from vllm.utils.flashinfer import (supports_trtllm_attention,
use_trtllm_attention) use_trtllm_attention)
...@@ -230,6 +232,7 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]): ...@@ -230,6 +232,7 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]):
dtype=torch.int32, dtype=torch.int32,
device="cpu", device="cpu",
pin_memory=pin_memory) pin_memory=pin_memory)
self.paged_kv_indptr_np = self.paged_kv_indptr_cpu.numpy()
self.paged_kv_indices_cpu = torch.zeros(max_num_pages, self.paged_kv_indices_cpu = torch.zeros(max_num_pages,
dtype=torch.int32, dtype=torch.int32,
device="cpu", device="cpu",
...@@ -238,10 +241,8 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]): ...@@ -238,10 +241,8 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]):
dtype=torch.int32, dtype=torch.int32,
device="cpu", device="cpu",
pin_memory=pin_memory) pin_memory=pin_memory)
self.paged_kv_last_page_len_np = (
self.block_table_arange = torch.arange(max_num_pages_per_req, self.paged_kv_last_page_len_cpu.numpy())
dtype=torch.int32,
device=self.device)
def _get_workspace_buffer(self): def _get_workspace_buffer(self):
if self._workspace_buffer is None: if self._workspace_buffer is None:
...@@ -317,9 +318,10 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]): ...@@ -317,9 +318,10 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]):
max_seq_len = common_attn_metadata.max_seq_len max_seq_len = common_attn_metadata.max_seq_len
seq_lens = common_attn_metadata.seq_lens seq_lens = common_attn_metadata.seq_lens
seq_lens_cpu = common_attn_metadata.seq_lens_cpu seq_lens_cpu = common_attn_metadata.seq_lens_cpu
seq_lens_np = seq_lens_cpu.numpy()
block_table_tensor = common_attn_metadata.block_table_tensor block_table_tensor = common_attn_metadata.block_table_tensor
block_table_bounds_cpu = (seq_lens_cpu + page_size - 1) // page_size num_blocks_np = (seq_lens_np + (page_size - 1)) // page_size
use_cascade = common_prefix_len > 0 use_cascade = common_prefix_len > 0
if use_cascade: if use_cascade:
...@@ -342,37 +344,41 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]): ...@@ -342,37 +344,41 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]):
# Remove the blocks of the shared prefix from all requests. # Remove the blocks of the shared prefix from all requests.
block_table_tensor = block_table_tensor[:, num_common_kv_blocks:] block_table_tensor = block_table_tensor[:, num_common_kv_blocks:]
block_table_bounds_cpu -= num_common_kv_blocks num_blocks_np -= num_common_kv_blocks
else: else:
shared_qo_indptr_cpu = None shared_qo_indptr_cpu = None
shared_kv_page_indptr_cpu = None shared_kv_page_indptr_cpu = None
shared_kv_page_indices_cpu = None shared_kv_page_indices_cpu = None
shared_kv_last_page_len_cpu = None shared_kv_last_page_len_cpu = None
max_num_blocks = block_table_bounds_cpu.max().item() # write self.paged_kv_indptr_cpu inplace (0-index is always 0)
block_table_bounds = block_table_bounds_cpu.to(self.device, np.cumsum(
non_blocking=True) num_blocks_np,
mask = (self.block_table_arange[:max_num_blocks].unsqueeze(0) dtype=np.int32,
< block_table_bounds.unsqueeze(1)) out=self.paged_kv_indptr_np[1:num_reqs + 1],
)
paged_kv_indptr = self.paged_kv_indptr[:num_reqs + 1]
paged_kv_indptr.copy_(self.paged_kv_indptr_cpu[:num_reqs + 1],
non_blocking=True)
# write self.paged_kv_indices inplace # write self.paged_kv_indices inplace
num_actual_pages = torch.sum(mask) num_actual_pages = num_blocks_np.sum().item()
paged_kv_indices = self.paged_kv_indices[:num_actual_pages] paged_kv_indices = self.paged_kv_indices[:num_actual_pages]
torch.masked_select(block_table_tensor[:, :max_num_blocks], _copy_page_indices_kernel[(num_reqs, )](
mask, paged_kv_indices,
out=paged_kv_indices) block_table_tensor,
block_table_tensor.stride(0),
# write self.paged_kv_indptr_cpu inplace (0-index is always 0) paged_kv_indptr,
torch.cumsum(block_table_bounds_cpu, BLOCK_SIZE=1024,
dim=0, )
dtype=torch.int32,
out=self.paged_kv_indptr_cpu[1:1 + num_reqs])
paged_kv_last_page_len_cpu = seq_lens_cpu % page_size
# write self.paged_kv_last_page_len_cpu inplace # write self.paged_kv_last_page_len_cpu inplace
torch.where(paged_kv_last_page_len_cpu == 0, paged_kv_last_page_len_np = seq_lens_np % page_size
torch.tensor(page_size), self.paged_kv_last_page_len_np[:num_reqs] = np.where(
paged_kv_last_page_len_cpu, paged_kv_last_page_len_np == 0,
out=self.paged_kv_last_page_len_cpu[:num_reqs]) page_size,
paged_kv_last_page_len_np,
)
# Check if any layer uses sinks (requires TRTLLM attention) # Check if any layer uses sinks (requires TRTLLM attention)
has_sinks = self.global_hyperparameters.has_sinks has_sinks = self.global_hyperparameters.has_sinks
...@@ -1002,3 +1008,25 @@ def fast_plan_decode( ...@@ -1002,3 +1008,25 @@ def fast_plan_decode(
self._sm_scale = sm_scale self._sm_scale = sm_scale
self._rope_scale = rope_scale self._rope_scale = rope_scale
self._rope_theta = rope_theta self._rope_theta = rope_theta
@triton.jit
def _copy_page_indices_kernel(
    page_indices,
    block_table,
    block_table_stride,
    cu_num_blocks,
    BLOCK_SIZE: tl.constexpr,
):
    """Copy one block-table row into a packed page-index buffer.

    Each program instance handles the row selected by ``tl.program_id(0)``.
    The destination range for that row is
    ``[cu_num_blocks[req_idx], cu_num_blocks[req_idx + 1])`` inside
    ``page_indices``; the same number of leading entries is read from the
    row starting at ``block_table + req_idx * block_table_stride``.

    Args:
        page_indices: Pointer to the packed destination buffer.
        block_table: Pointer to the first element of the block table.
        block_table_stride: Element stride between consecutive rows of
            ``block_table``.
        cu_num_blocks: Pointer to the per-row start offsets; entry
            ``req_idx + 1`` is also read, so it needs one more entry than
            there are program instances.
        BLOCK_SIZE: Compile-time number of elements copied per loop
            iteration.
    """
    req_idx = tl.program_id(0)

    # Destination range of this row inside the packed buffer.
    dst_begin = tl.load(cu_num_blocks + req_idx)
    dst_end = tl.load(cu_num_blocks + req_idx + 1)
    row_len = dst_end - dst_begin

    src_row = block_table + req_idx * block_table_stride
    dst_row = page_indices + dst_begin

    lane = tl.arange(0, BLOCK_SIZE)
    for chunk_start in tl.range(0, row_len, BLOCK_SIZE):
        cols = chunk_start + lane
        # The last chunk may be partial; mask off lanes past the row end so
        # neither the load nor the store touches out-of-range elements.
        in_row = cols < row_len
        chunk = tl.load(src_row + cols, mask=in_row)
        tl.store(dst_row + cols, chunk, mask=in_row)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment