Commit a3f8d5dd authored by zhuwenwen's avatar zhuwenwen
Browse files

Merge tag 'v0.13.0rc2' into v0.13.0rc2-ori

parents 8d75f22e f34eca5f
......@@ -21,7 +21,7 @@ from vllm.v1.attention.backends.utils import (
CommonAttentionMetadata,
split_decodes_and_prefills,
)
from vllm.v1.kv_cache_interface import AttentionSpec
from vllm.v1.kv_cache_interface import AttentionSpec, CrossAttentionSpec
logger = init_logger(__name__)
......@@ -50,11 +50,13 @@ class CPUAttentionBackend(AttentionBackend):
@classmethod
def supports_attn_type(cls, attn_type: str) -> bool:
"""CPU attention supports decoder and encoder-only attention."""
"""CPU attention supports decoder,
encoder-only and encoder-decoder attention."""
return attn_type in (
AttentionType.DECODER,
AttentionType.ENCODER,
AttentionType.ENCODER_ONLY,
AttentionType.ENCODER_DECODER,
)
@staticmethod
......@@ -136,6 +138,7 @@ class CPUAttentionMetadataBuilder(AttentionMetadataBuilder[CPUAttentionMetadata]
self.window_size = -1
self.block_size = vllm_config.cache_config.block_size
self.isa = _get_attn_isa(self.dtype, self.block_size)
self.is_cross_attention = isinstance(kv_cache_spec, CrossAttentionSpec)
def build(
self,
......@@ -151,7 +154,7 @@ class CPUAttentionMetadataBuilder(AttentionMetadataBuilder[CPUAttentionMetadata]
seq_lens = common_attn_metadata.seq_lens
block_table_tensor = common_attn_metadata.block_table_tensor
slot_mapping = common_attn_metadata.slot_mapping
causal = common_attn_metadata.causal
causal = False if self.is_cross_attention else common_attn_metadata.causal
sdpa_start_loc = query_start_loc
num_decode_tokens = 0
......@@ -171,9 +174,6 @@ class CPUAttentionMetadataBuilder(AttentionMetadataBuilder[CPUAttentionMetadata]
query_start_loc = query_start_loc[: num_decodes + 1]
block_table_tensor = block_table_tensor[:num_decodes]
sheduler_metadata = None
if causal:
# for decode batch, use the custom kernel
sheduler_metadata = ops.cpu_attn_get_scheduler_metadata(
num_reqs=num_reqs,
num_heads=self.num_heads,
......
......@@ -429,6 +429,7 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]):
super().__init__(kv_cache_spec, layer_names, vllm_config, device)
self.cache_config = vllm_config.cache_config
self.model_config = vllm_config.model_config
self.attention_config = vllm_config.attention_config
self._workspace_buffer = None
self._prefill_wrapper: (
BatchPrefillWithPagedKVCacheWrapper | BatchDCPPrefillWrapper | None
......@@ -563,7 +564,7 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]):
)
self.paged_kv_last_page_len_np = self.paged_kv_last_page_len_cpu.numpy()
if self.head_dim == 256 and current_platform.is_device_capability(100):
if self.head_dim == 256 and current_platform.is_device_capability_family(100):
# https://github.com/flashinfer-ai/flashinfer/issues/1993 reports that
# head size 256 and block size 16 is not supported on blackwell.
assert kv_cache_spec.block_size != 16, (
......@@ -779,6 +780,7 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]):
self.cache_dtype,
self.q_data_type,
is_prefill=True,
force_use_trtllm=self.attention_config.use_trtllm_attention,
has_sinks=self.has_sinks,
has_spec=uses_spec_reorder,
)
......
......@@ -211,7 +211,7 @@ class GDNAttentionMetadataBuilder(AttentionMetadataBuilder[GDNAttentionMetadata]
spec_token_masks = torch.repeat_interleave(
spec_sequence_masks, query_lens
)
index = torch.argsort(spec_token_masks)
index = torch.argsort(spec_token_masks, stable=True)
num_non_spec_tokens = num_prefill_tokens + num_decode_tokens
non_spec_token_indx = index[:num_non_spec_tokens]
spec_token_indx = index[num_non_spec_tokens:]
......
......@@ -446,7 +446,7 @@ def use_flashinfer_prefill() -> bool:
and flashinfer_available
and not vllm_config.attention_config.use_cudnn_prefill
and not vllm_config.attention_config.use_trtllm_ragged_deepseek_prefill
and current_platform.is_device_capability(100)
and current_platform.is_device_capability_family(100)
)
......@@ -457,7 +457,7 @@ def use_cudnn_prefill() -> bool:
return (
flashinfer_available
and vllm_config.attention_config.use_cudnn_prefill
and current_platform.is_device_capability(100)
and current_platform.is_device_capability_family(100)
and has_nvidia_artifactory()
)
......@@ -470,7 +470,7 @@ def use_trtllm_ragged_deepseek_prefill() -> bool:
return (
flashinfer_available
and vllm_config.attention_config.use_trtllm_ragged_deepseek_prefill
and current_platform.is_device_capability(100)
and current_platform.is_device_capability_family(100)
)
......@@ -1787,6 +1787,33 @@ class MLACommonImpl(MLACommonBaseImpl[M], Generic[M]):
# Convert from (L, N, P) to (N, P, L)
self.W_UK_T = W_UK.permute(1, 2, 0)
def _concat_k_nope_k_pe(
self, k_nope: torch.Tensor, k_pe: torch.Tensor
) -> torch.Tensor:
"""
Efficiently concatenate k_nope and k_pe tensors along the last dimension.
This function avoids the performance penalty of torch.cat with expanded
non-contiguous tensors by pre-allocating the output and using direct copies.
Args:
k_nope: Tensor of shape [..., nope_dim]
k_pe: Tensor to broadcast and concatenate, typically shape [..., 1, pe_dim]
or [..., pe_dim]
Returns:
Tensor of shape [..., nope_dim + pe_dim]
"""
k = torch.empty(
(*k_nope.shape[:-1], k_nope.shape[-1] + k_pe.shape[-1]),
dtype=k_nope.dtype,
device=k_nope.device,
)
# Direct copies with efficient broadcasting
k[..., : k_nope.shape[-1]] = k_nope
k[..., k_nope.shape[-1] :] = k_pe
return k
def _compute_prefill_context(
self,
q: torch.Tensor,
......@@ -1823,7 +1850,7 @@ class MLACommonImpl(MLACommonBaseImpl[M], Generic[M]):
)
k_nope, v = kv_nope.split([self.qk_nope_head_dim, self.v_head_dim], dim=-1)
k = torch.cat((k_nope, k_pe.expand((*k_nope.shape[:-1], -1))), dim=-1)
k = self._concat_k_nope_k_pe(k_nope, k_pe)
attn_output, attn_softmax_lse = self._run_prefill_context_chunk(
prefill=prefill_metadata,
......@@ -1927,7 +1954,7 @@ class MLACommonImpl(MLACommonBaseImpl[M], Generic[M]):
-1, self.num_heads, self.qk_nope_head_dim + self.v_head_dim
)
k_nope, v = kv_nope.split([self.qk_nope_head_dim, self.v_head_dim], dim=-1)
k = torch.cat((k_nope, k_pe.expand((*k_nope.shape[:-1], -1))), dim=-1)
k = self._concat_k_nope_k_pe(k_nope, k_pe)
attn_output, attn_softmax_lse = self._run_prefill_context_chunk(
prefill=prefill_metadata,
......@@ -1976,7 +2003,7 @@ class MLACommonImpl(MLACommonBaseImpl[M], Generic[M]):
)
k_nope, v = kv_nope.split([self.qk_nope_head_dim, self.v_head_dim], dim=-1)
k = torch.cat((k_nope, k_pe.expand((*k_nope.shape[:-1], -1))), dim=-1)
k = self._concat_k_nope_k_pe(k_nope, k_pe)
output_prefill = self._run_prefill_new_tokens(
prefill=attn_metadata.prefill,
......
......@@ -18,7 +18,7 @@ from vllm.attention.ops.flashmla import (
flash_mla_with_kvcache,
get_mla_metadata,
)
from vllm.config import VllmConfig
from vllm.config import VllmConfig, get_current_vllm_config
from vllm.config.cache import CacheDType
from vllm.logger import init_logger
from vllm.platforms import current_platform
......@@ -30,13 +30,31 @@ from vllm.v1.attention.backends.utils import (
AttentionCGSupport,
AttentionMetadataBuilder,
CommonAttentionMetadata,
reshape_attn_output_for_spec_decode,
reshape_query_for_spec_decode,
split_decodes_and_prefills,
split_prefill_chunks,
)
from vllm.v1.kv_cache_interface import AttentionSpec
from vllm.v1.worker.workspace import current_workspace_manager
if TYPE_CHECKING:
from vllm.model_executor.models.deepseek_v2 import Indexer
logger = init_logger(__name__)
# For FP8 sparse attention we have two impelementations:
# 1. Mixed batch mode: use the FP8 decode kernel for both prefill and decode this is
# done by treating all tokens as single batch.
# 2. Separate prefill and decode mode: use the BF16 prefill kernel for prefill
# (upconverting the FP8 cache to BF16 then calling the prefill kernel) and using
# the FP8 decode kernel for decode.
# Currently we use #1 when the number of heads per rank is low (i.e. TP) since the BF16
# prefill kernel requires padding the numer of heads to 128 while the decode does not
# so when the per ranke head count is below MIN_HEADS_FOR_BF16_PREFILL we use the mixed
# batch mode (#2).
MIN_HEADS_FOR_BF16_PREFILL = 32
"""
NOTE: FlashMLA Sparse uses an fp8 cache with the following format
......@@ -127,19 +145,72 @@ class FlashMLASparseMetadata:
dummy_block_table: torch.Tensor
cache_lens: torch.Tensor
fp8_extra_metadata: FP8KernelMetadata | None = None
@dataclass
class FP8SeperatePrefillDecode:
@dataclass
class Decode:
kernel_metadata: "FlashMLASparseMetadata.FP8KernelMetadata"
decode_query_len: int # needed for reshape in spec decode
@dataclass
class Prefill:
# Sequence lengths (context + query) for prefill requests
# Shape: [num_prefill_reqs]
seq_lens: torch.Tensor
# Request ID for each token: -1 for decode tokens, request index
# (0, 1, 2, ...) for prefill tokens.
# Shape: [num_actual_tokens]
request_ids: torch.Tensor
# Workspace start offsets for all prefill requests
# Shape: [num_prefill_reqs], adjusted in-place per chunk to be
# 0-indexed within each chunk. Used to map prefill tokens to workspace
# offsets in convert_logical_index_to_physical_index
workspace_starts: torch.Tensor
@dataclass
class Chunk:
"""Metadata for a chunk of prefill requests.
Prefill requests may be chunked to fit within the fixed workspace size.
"""
seq_lens: torch.Tensor
tokens_slice: slice
block_table: torch.Tensor
req_start_idx: int
workspace_starts: torch.Tensor
chunk_tot_seqlen: int
chunks: list[Chunk]
num_prefills: int = 0
num_decodes: int = 0
num_prefill_tokens: int = 0
num_decode_tokens: int = 0
decode: Decode | None = None
prefill: Prefill | None = None
fp8_extra_metadata: FP8SeperatePrefillDecode | FP8KernelMetadata | None = None
fp8_use_mixed_batch: bool = False
# Kernel with prefill workspace support
@triton.jit
def _convert_req_index_to_global_index_kernel(
req_id_ptr, # int32 [num_tokens]
block_table_ptr, # int32 [num_requests, max_num_blocks_per_req]
token_indices_ptr, # int32 [num_tokens, NUM_TOPK_TOKENS]
out_ptr, # int32 [num_tokens, NUM_TOPK_TOKENS]
prefill_request_id_ptr, # int32 [num_tokens], -1 for decode, >=0 for prefill
workspace_starts_ptr, # int32 [num_prefill_reqs+1] or nullptr
# shapes (compile-time where possible)
max_num_blocks_per_req: tl.constexpr,
BLOCK_SIZE: tl.constexpr,
BLOCK_N: tl.constexpr, # tile width along columns
HAS_PREFILL: tl.constexpr,
# strides (in elements)
bt_stride0,
bt_stride1,
......@@ -165,7 +236,10 @@ def _convert_req_index_to_global_index_kernel(
# Only token == -1 should propagate as -1
is_invalid_tok = tok < 0
is_prefill = False
if HAS_PREFILL:
prefill_req_id = tl.load(prefill_request_id_ptr + token_id)
is_prefill = prefill_req_id >= 0
# Compute block id and in-block offset
block_id = tok // BLOCK_SIZE
inblock_off = tok % BLOCK_SIZE
......@@ -173,12 +247,18 @@ def _convert_req_index_to_global_index_kernel(
# Guard block_table access
valid_block = (block_id < max_num_blocks_per_req) & (block_id >= 0)
bt_ptr = block_table_ptr + req * bt_stride0 + block_id * bt_stride1
base = tl.load(bt_ptr, mask=valid_block, other=0)
# If token == -1 OR block_id OOB, output -1; else base * BLOCK_SIZE + offset
out_val = tl.where(
is_invalid_tok | (~valid_block), -1, base * BLOCK_SIZE + inblock_off
is_invalid_tok |= ~valid_block
base = tl.load(bt_ptr, mask=valid_block & ~is_prefill, other=0)
out_val = base * BLOCK_SIZE + inblock_off
# Override with prefill output if prefill is enabled
if HAS_PREFILL:
workspace_start = tl.load(
workspace_starts_ptr + prefill_req_id, mask=is_prefill, other=0
)
prefill_out = workspace_start + tok
out_val = tl.where(is_prefill, prefill_out, out_val)
out_val = tl.where(is_invalid_tok, -1, out_val)
# Store results
out_ptr_ij = out_ptr + token_id * out_stride0 + indice_id * out_stride1
......@@ -192,6 +272,9 @@ def triton_convert_req_index_to_global_index(
BLOCK_SIZE: int = 64,
NUM_TOPK_TOKENS: int = 2048,
BLOCK_N: int = 128, # tile width along columns
HAS_PREFILL_WORKSPACE: bool = False,
prefill_workspace_request_ids: torch.Tensor | None = None,
prefill_workspace_starts: torch.Tensor | None = None,
):
"""
out[token_id, indice_id] =
......@@ -202,17 +285,32 @@ def triton_convert_req_index_to_global_index(
Only when token_indices[token_id, indice_id] == -1 do we output -1.
For safety, we also output -1 if the derived block_id would be
out-of-bounds.
When HAS_PREFILL_WORKSPACE is True, prefill tokens are mapped to workspace offsets
instead of global cache slots. prefill_workspace_request_ids and
prefill_workspace_starts must be provided.
prefill_workspace_request_ids: int32 [num_tokens], -1 for decode else
prefill request index (maps to prefill_workspace_starts)
prefill_workspace_starts: int32 [num_prefills], 0-indexed workspace
starts for each prefill request
"""
assert req_id.dtype == torch.int32
assert block_table.dtype == torch.int32
assert token_indices.dtype == torch.int32
assert token_indices.shape[1] == NUM_TOPK_TOKENS
assert NUM_TOPK_TOKENS % BLOCK_N == 0, (
f"NUM_TOPK_TOKENS ({NUM_TOPK_TOKENS}) must be divisible byBLOCK_N ({BLOCK_N})"
f"NUM_TOPK_TOKENS ({NUM_TOPK_TOKENS}) must be divisible by BLOCK_N ({BLOCK_N})"
)
if HAS_PREFILL_WORKSPACE:
assert prefill_workspace_request_ids is not None
assert prefill_workspace_starts is not None
assert prefill_workspace_request_ids.dtype == torch.int32
assert prefill_workspace_starts.dtype == torch.int32
num_tokens = req_id.shape[0]
num_requests, max_num_blocks_per_req = block_table.shape
max_num_blocks_per_req = block_table.shape[1]
tiles_per_row = NUM_TOPK_TOKENS // BLOCK_N
# Ensure contiguous tensors on the same device
......@@ -226,6 +324,13 @@ def triton_convert_req_index_to_global_index(
ti_stride0, ti_stride1 = token_indices_c.stride()
out_stride0, out_stride1 = out.stride()
# Prepare prefill pointers
if HAS_PREFILL_WORKSPACE:
assert prefill_workspace_request_ids is not None # for mypy
assert prefill_workspace_starts is not None # for mypy
assert prefill_workspace_request_ids.is_contiguous()
assert prefill_workspace_starts.is_contiguous()
# Exact 2D grid: tokens × column tiles
grid = (num_tokens, tiles_per_row)
......@@ -234,10 +339,13 @@ def triton_convert_req_index_to_global_index(
block_table_c,
token_indices_c,
out,
prefill_workspace_request_ids,
prefill_workspace_starts,
# shapes / constexprs
max_num_blocks_per_req,
BLOCK_SIZE,
BLOCK_N,
HAS_PREFILL_WORKSPACE,
# strides
bt_stride0,
bt_stride1,
......@@ -249,7 +357,16 @@ def triton_convert_req_index_to_global_index(
return out
@dataclass
def get_prefill_workspace_size(max_model_len: int):
# NOTE(Lucas): 5 is a magic number for controlling the prefill buffer size.
# May be tuned later.
# Memory usage: 5 * max_model_len * 576 * 2 bytes
# Example: DeepSeek-V3.2 with max_model_len=163840 ->
# 5 * 163840 * 576 * 2 = ~900 MB
# This fits nicely below the typical MoE workspace size of >2GB so this is "free"
return max_model_len * 5
class FlashMLASparseMetadataBuilder(AttentionMetadataBuilder[FlashMLASparseMetadata]):
_cudagraph_support: ClassVar[AttentionCGSupport] = AttentionCGSupport.UNIFORM_BATCH
......@@ -259,29 +376,42 @@ class FlashMLASparseMetadataBuilder(AttentionMetadataBuilder[FlashMLASparseMetad
layer_names: list[str],
vllm_config: VllmConfig,
device: torch.device,
):
) -> None:
self.vllm_config = vllm_config
self.layer_names = layer_names
cache_config = vllm_config.cache_config
self.kv_cache_spec = kv_cache_spec
self.model_config = vllm_config.model_config
parallel_config = vllm_config.parallel_config
self.device = device
# Treat requests with query length <= 1 as decodes to match the
# DeepGEMM indexer constraint (fp8_paged_mqa_logits only supports next_n <= 2)
self._init_reorder_batch_threshold(1, supports_spec_as_decode=True)
props = torch.cuda.get_device_properties(device)
sm_count = props.multi_processor_count
self.num_heads = self.model_config.get_num_attention_heads(parallel_config)
self.mla_dims = get_mla_dims(self.model_config)
self.topk_tokens = vllm_config.model_config.hf_config.index_topk
self.use_fp8_kv_cache = cache_config.cache_dtype == "fp8_ds_mla"
self.topk_tokens_tensor = torch.tensor(
[self.topk_tokens], device=device, dtype=torch.int32
max_num_seqs = vllm_config.scheduler_config.max_num_seqs
# Shape: [max_num_seqs], all elements = topk_tokens (constant for full-CG)
self.topk_tokens_tensor = torch.full(
(max_num_seqs,), self.topk_tokens, device=device, dtype=torch.int32
)
self.max_model_len_tensor = torch.tensor(
[self.model_config.max_model_len], device=device, dtype=torch.int32
# Shape: [max_num_seqs], all elements = max_model_len
self.max_model_len_tensor = torch.full(
(max_num_seqs,),
self.model_config.max_model_len,
device=device,
dtype=torch.int32,
)
# this is ignored by `flash_mla_with_kvcache` if indices not None
self.dummy_block_table = torch.empty(
(1, 1), dtype=torch.int32, device=self.device
(max_num_seqs, 1), dtype=torch.int32, device=self.device
)
# Equation taken from FlashMLA/csrc/pybind.cpp
......@@ -290,7 +420,7 @@ class FlashMLASparseMetadataBuilder(AttentionMetadataBuilder[FlashMLASparseMetad
max_num_sm_parts = int(
max((sm_count // 2) / h_k // (cdiv(h_q // h_k, 2 * 64) * s_q), 1)
)
if current_platform.is_device_capability(100):
if current_platform.is_device_capability_family(100):
max_num_sm_parts *= 2
self.tile_scheduler_metadata_buffer = torch.empty(
# TileSchedulerMetaDataSize = 8
......@@ -299,10 +429,9 @@ class FlashMLASparseMetadataBuilder(AttentionMetadataBuilder[FlashMLASparseMetad
dtype=torch.int32,
device=device,
)
# Sized for per-request batching (num_decodes + 1)
self.num_splits_buffer = torch.empty(
# We pack all the tokens into one batch for sparse attention.
# Otherwise, we can exceed the sm of `get_mla_metadata`.
(2,),
(max_num_seqs + 1,),
dtype=torch.int32,
device=device,
)
......@@ -312,30 +441,171 @@ class FlashMLASparseMetadataBuilder(AttentionMetadataBuilder[FlashMLASparseMetad
device=device,
)
def build(
def _build_fp8_mixed_decode_prefill(
self,
common_prefix_len: int,
common_attn_metadata: CommonAttentionMetadata,
fast_build: bool = False,
) -> FlashMLASparseMetadata:
) -> "FlashMLASparseMetadata.FP8KernelMetadata":
"""Build FP8 metadata treating all tokens as one mixed batch.
This matches main branch's approach and avoids the BF16 prefill kernel
which has head padding overhead when num_heads is small (high TP case).
"""
num_tokens = common_attn_metadata.num_actual_tokens
starts = np.asarray(common_attn_metadata.query_start_loc_cpu, dtype=np.int32)
seg_lengths = np.diff(starts)
req_id_per_token = np.repeat(
np.arange(seg_lengths.shape[0], dtype=np.int32), seg_lengths
# Build metadata for all tokens as a single batch
tile_scheduler_metadata, num_splits = get_mla_metadata(
cache_seqlens=self.topk_tokens_tensor[:1], # Single batch
num_q_tokens_per_head_k=num_tokens * self.num_heads,
topk=self.topk_tokens,
num_heads_q=self.num_heads,
num_heads_k=1,
is_fp8_kvcache=True,
)
# Zero-fill for cudagraphs
self.req_id_per_token_buffer.fill_(0)
self.req_id_per_token_buffer[: req_id_per_token.shape[0]].copy_(
torch.from_numpy(req_id_per_token), non_blocking=True
num_sm_parts = tile_scheduler_metadata.size(0)
tile_scheduler_metadata_buffer = self.tile_scheduler_metadata_buffer[
:num_sm_parts
]
tile_scheduler_metadata_buffer.copy_(tile_scheduler_metadata)
num_splits_view = self.num_splits_buffer[:2]
num_splits_view.copy_(num_splits)
fp8_metadata = FlashMLASparseMetadata.FP8KernelMetadata(
scheduler_metadata=tile_scheduler_metadata_buffer,
num_splits=num_splits_view,
cache_lens=self.max_model_len_tensor[:1],
dummy_block_table=self.dummy_block_table[:1],
)
req_id_per_token = self.req_id_per_token_buffer[:num_tokens]
fp8_extra_metadata = None
if self.use_fp8_kv_cache:
return fp8_metadata
def _build_fp8_separate_prefill_decode(
self,
common_attn_metadata: CommonAttentionMetadata,
) -> "FlashMLASparseMetadata.FP8SeperatePrefillDecode":
num_tokens = common_attn_metadata.num_actual_tokens
(num_decodes, num_prefills, num_decode_tokens, num_prefill_tokens) = (
split_decodes_and_prefills(
common_attn_metadata,
decode_threshold=self.reorder_batch_threshold or 1,
require_uniform=True,
)
)
FP8Meta = FlashMLASparseMetadata.FP8SeperatePrefillDecode
fp8_metadata = FP8Meta(
num_decodes=num_decodes,
num_prefills=num_prefills,
num_decode_tokens=num_decode_tokens,
num_prefill_tokens=num_prefill_tokens,
)
# Extract prefill sequence lengths (context + query, not just query)
# Decode requests come first in the batch, prefill requests follow
prefill_seq_lens = None
prefill_request_id = None
prefill_workspace_starts = None
prefill_chunks = None
# For pure decode batches, prefill_request_id will be None
# For mixed batches, it will have -1 for decode and request_id for prefill
if num_prefills > 0:
seq_lens_cpu = common_attn_metadata.seq_lens_cpu
seq_lens = common_attn_metadata.seq_lens
query_start_loc_cpu = common_attn_metadata.query_start_loc_cpu
prefill_seq_lens_cpu = seq_lens_cpu[num_decodes:]
prefill_seq_lens = seq_lens[num_decodes:]
# Build prefill_request_id: -1 for decode, request index for
# prefill. This enables a single
# convert_logical_index_to_physical_index call for all tokens
prefill_request_id = torch.full(
(num_tokens,), -1, dtype=torch.int32, device=self.device
)
# Map prefill tokens to their request IDs (0, 1, 2, ...)
for req_idx in range(num_prefills):
# Get query token range for this prefill request
global_req_idx = num_decodes + req_idx
req_query_start = query_start_loc_cpu[global_req_idx]
req_query_end = query_start_loc_cpu[global_req_idx + 1]
prefill_request_id[req_query_start:req_query_end] = req_idx
# will be adjusted by chunk loop
prefill_workspace_starts_cpu = torch.zeros(
num_prefills, dtype=torch.int32, pin_memory=True
)
prefill_workspace_starts_cpu[1:] = torch.cumsum(
prefill_seq_lens_cpu[:-1], dim=0
)
# populated by non-blocking copy after prefill_workspace_starts_cpu is
# updated by each chunk
prefill_workspace_starts = torch.empty(
num_prefills, dtype=torch.int32, device=self.device
)
# Chunk prefill requests to fit within workspace size
max_prefill_buffer_size = get_prefill_workspace_size(
self.vllm_config.model_config.max_model_len
)
chunk_bounds = split_prefill_chunks(
prefill_seq_lens_cpu, max_prefill_buffer_size
)
prefill_chunks = []
for chunk_start, chunk_end in chunk_bounds:
# Adjust workspace_starts in-place per chunk to be
# 0-indexed within each chunk
# Example: seq_lens=[10,15,20,5], chunks=[[0,2],[2,4]]
# Initial: workspace_starts=[0,10,25,45]
# After: workspace_starts=[0,10,0,20]
# (chunk 0 starts at 0, chunk 1 starts at 0)
offset = prefill_workspace_starts_cpu[chunk_start].item()
prefill_workspace_starts_cpu[chunk_start:chunk_end] -= offset
chunk_seq_lens = prefill_seq_lens[chunk_start:chunk_end]
chunk_tot_seqlen = prefill_seq_lens_cpu[chunk_start:chunk_end].sum()
token_start = query_start_loc_cpu[num_decodes + chunk_start].item()
token_end = query_start_loc_cpu[num_decodes + chunk_end].item()
tokens_slice = slice(token_start, token_end)
# Create chunk view of gpu tensor
chunk_workspace_starts = prefill_workspace_starts[chunk_start:chunk_end]
chunk_block_table = common_attn_metadata.block_table_tensor[
num_decodes + chunk_start : num_decodes + chunk_end
]
prefill_chunks.append(
FP8Meta.Prefill.Chunk(
seq_lens=chunk_seq_lens,
tokens_slice=tokens_slice,
block_table=chunk_block_table,
req_start_idx=chunk_start,
workspace_starts=chunk_workspace_starts,
chunk_tot_seqlen=chunk_tot_seqlen,
)
)
prefill_workspace_starts.copy_(
prefill_workspace_starts_cpu, non_blocking=True
)
fp8_metadata.prefill = FP8Meta.Prefill(
seq_lens=prefill_seq_lens,
request_ids=prefill_request_id,
workspace_starts=prefill_workspace_starts,
chunks=prefill_chunks,
)
if num_decodes > 0:
# Compute decode_query_len for spec decode (uniform due to require_uniform)
query_start_loc_cpu = common_attn_metadata.query_start_loc_cpu
decode_query_len = (query_start_loc_cpu[1] - query_start_loc_cpu[0]).item()
tile_scheduler_metadata, num_splits = get_mla_metadata(
cache_seqlens=self.topk_tokens_tensor,
num_q_tokens_per_head_k=num_tokens * self.num_heads,
cache_seqlens=self.topk_tokens_tensor[:num_decodes],
num_q_tokens_per_head_k=decode_query_len * self.num_heads,
topk=self.topk_tokens,
num_heads_q=self.num_heads,
num_heads_k=1,
......@@ -348,33 +618,70 @@ class FlashMLASparseMetadataBuilder(AttentionMetadataBuilder[FlashMLASparseMetad
:num_sm_parts
]
tile_scheduler_metadata_buffer.copy_(tile_scheduler_metadata)
self.num_splits_buffer.copy_(num_splits)
# num_splits has size [num_decodes + 1]
num_splits_view = self.num_splits_buffer[: num_decodes + 1]
num_splits_view.copy_(num_splits)
fp8_extra_metadata = FlashMLASparseMetadata.FP8KernelMetadata(
kernel_meta = FlashMLASparseMetadata.FP8KernelMetadata(
scheduler_metadata=tile_scheduler_metadata_buffer,
num_splits=self.num_splits_buffer,
# cache_lens and block_table are basically unused in sparse case
# but the decode kernel will treat -1 and indices >= cache_lens
# as invalid so we make sure cache_lens is large enough to not
# accidentally mark indices invalid, we will use -1 exclusively
# to mark invalid indices
cache_lens=self.max_model_len_tensor,
dummy_block_table=self.dummy_block_table,
num_splits=num_splits_view,
dummy_block_table=self.dummy_block_table[:num_decodes],
cache_lens=self.max_model_len_tensor[:num_decodes],
)
fp8_metadata.decode = FP8Meta.Decode(
kernel_metadata=kernel_meta,
decode_query_len=decode_query_len,
)
return fp8_metadata
def build(
self,
common_prefix_len: int,
common_attn_metadata: CommonAttentionMetadata,
fast_build: bool = False,
) -> FlashMLASparseMetadata:
cm = common_attn_metadata
num_tokens = cm.num_actual_tokens
starts = np.asarray(cm.query_start_loc_cpu, dtype=np.int32)
seg_lengths = np.diff(starts)
req_id_per_token = np.repeat(
np.arange(seg_lengths.shape[0], dtype=np.int32), seg_lengths
)
# Zero-fill for cudagraphs
self.req_id_per_token_buffer.fill_(0)
self.req_id_per_token_buffer[: req_id_per_token.shape[0]].copy_(
torch.from_numpy(req_id_per_token), non_blocking=True
)
req_id_per_token = self.req_id_per_token_buffer[:num_tokens]
fp8_extra_metadata: (
FlashMLASparseMetadata.FP8SeperatePrefillDecode
| FlashMLASparseMetadata.FP8KernelMetadata
| None
) = None
fp8_use_mixed_batch = self.num_heads < MIN_HEADS_FOR_BF16_PREFILL
if self.use_fp8_kv_cache:
if fp8_use_mixed_batch:
fp8_extra_metadata = self._build_fp8_mixed_decode_prefill(cm)
else:
fp8_extra_metadata = self._build_fp8_separate_prefill_decode(cm)
metadata = FlashMLASparseMetadata(
num_reqs=common_attn_metadata.num_reqs,
max_query_len=common_attn_metadata.max_query_len,
max_seq_len=common_attn_metadata.max_seq_len,
num_actual_tokens=common_attn_metadata.num_actual_tokens,
query_start_loc=common_attn_metadata.query_start_loc,
slot_mapping=common_attn_metadata.slot_mapping,
block_table=common_attn_metadata.block_table_tensor,
num_reqs=cm.num_reqs,
max_query_len=cm.max_query_len,
max_seq_len=cm.max_seq_len,
num_actual_tokens=cm.num_actual_tokens,
query_start_loc=cm.query_start_loc,
slot_mapping=cm.slot_mapping,
block_table=cm.block_table_tensor,
req_id_per_token=req_id_per_token,
block_size=self.kv_cache_spec.block_size,
topk_tokens=self.topk_tokens,
fp8_extra_metadata=fp8_extra_metadata,
fp8_use_mixed_batch=fp8_use_mixed_batch,
)
return metadata
......@@ -412,7 +719,21 @@ class FlashMLASparseImpl(MLACommonBaseImpl[FlashMLASparseMetadata]):
self.softmax_scale = scale
assert indexer is not None
self.topk_indices_buffer = indexer.topk_indices_buffer
self.padding = 128 if current_platform.is_device_capability(100) else 64
self.padding = 128 if current_platform.is_device_capability_family(100) else 64
if kv_cache_dtype == "fp8_ds_mla":
# Reserve workspace during initialization
vllm_config = get_current_vllm_config()
assert vllm_config is not None and vllm_config.model_config is not None
prefill_workspace_size = get_prefill_workspace_size(
vllm_config.model_config.max_model_len
)
self.prefill_workspace_shape = (prefill_workspace_size, head_size)
(self.prefill_bf16_workspace,) = (
current_workspace_manager().get_simultaneous(
(self.prefill_workspace_shape, torch.bfloat16)
)
)
def _forward_bf16_kv(
self,
......@@ -420,6 +741,184 @@ class FlashMLASparseImpl(MLACommonBaseImpl[FlashMLASparseMetadata]):
kv_c_and_k_pe_cache: torch.Tensor,
topk_indices: torch.Tensor,
attn_metadata: FlashMLASparseMetadata,
) -> torch.Tensor:
# Convert per-request indices to global slots (decode) or workspace
# offsets (prefill).
topk_indices = triton_convert_req_index_to_global_index(
attn_metadata.req_id_per_token,
attn_metadata.block_table,
topk_indices,
BLOCK_SIZE=attn_metadata.block_size,
NUM_TOPK_TOKENS=topk_indices.shape[1],
)
return self._bf16_flash_mla_kernel(q, kv_c_and_k_pe_cache, topk_indices)
def _forward_fp8_kv_separate_prefill_decode(
self,
q: torch.Tensor,
kv_c_and_k_pe_cache: torch.Tensor,
topk_indices: torch.Tensor,
attn_metadata: FlashMLASparseMetadata,
) -> torch.Tensor:
fp8_metadata = attn_metadata.fp8_extra_metadata
assert isinstance(fp8_metadata, FlashMLASparseMetadata.FP8SeperatePrefillDecode)
num_decodes = fp8_metadata.num_decodes
prefill_request_ids = None
prefill_workspace_starts = None
has_prefill_workspace = False
if fp8_metadata.prefill is not None:
prefill_request_ids = fp8_metadata.prefill.request_ids
prefill_workspace_starts = fp8_metadata.prefill.workspace_starts
has_prefill_workspace = True
# Convert per-request indices to global slots (decode) or workspace
# offsets (prefill).
# For FP8 cache: prefill uses workspace mapping (upconverted to BF16)
# For BF16 cache: always use global cache slots (no workspace)
# prefill_workspace_starts has been adjusted in-place per chunk so
# prefill indices automatically come out chunk-local
topk_indices = triton_convert_req_index_to_global_index(
attn_metadata.req_id_per_token,
attn_metadata.block_table,
topk_indices,
BLOCK_SIZE=attn_metadata.block_size,
NUM_TOPK_TOKENS=topk_indices.shape[1],
HAS_PREFILL_WORKSPACE=has_prefill_workspace,
prefill_workspace_request_ids=prefill_request_ids,
prefill_workspace_starts=prefill_workspace_starts,
)
fp8_metadata = attn_metadata.fp8_extra_metadata
assert isinstance(fp8_metadata, FlashMLASparseMetadata.FP8SeperatePrefillDecode)
def _fp8_decode(q: torch.Tensor, topk_indices: torch.Tensor) -> torch.Tensor:
# Reshape q: (num_decode_tokens, num_heads, head_dim)
# -> (num_decodes, seq_len, num_heads, head_dim)
q = reshape_query_for_spec_decode(q, num_decodes)
seq_len = q.shape[1]
# Reshape topk_indices: (num_decode_tokens, topk)
# -> (num_decodes, seq_len, topk)
topk_indices = topk_indices.view(num_decodes, seq_len, -1)
assert fp8_metadata.decode is not None
attn_out, _ = self._fp8_flash_mla_kernel(
q=q,
kv_c_and_k_pe_cache=kv_c_and_k_pe_cache,
topk_indices=topk_indices,
kernel_metadata=fp8_metadata.decode.kernel_metadata,
)
# Reshape output: (num_decodes, seq_len, num_heads, head_dim_v)
# -> (num_decode_tokens, num_heads, head_dim_v)
return reshape_attn_output_for_spec_decode(attn_out)
num_decode_tokens = fp8_metadata.num_decode_tokens
num_prefill_tokens = fp8_metadata.num_prefill_tokens
# Pure decode: direct call without allocation
if num_decode_tokens > 0 and num_prefill_tokens == 0:
assert fp8_metadata.decode is not None
attn_out = _fp8_decode(q, topk_indices)
else:
# Mixed or pure prefill: allocate output tensor
attn_out = q.new_empty(
(attn_metadata.num_actual_tokens, self.num_heads, self.kv_lora_rank),
dtype=q.dtype,
device=q.device,
)
if num_decode_tokens > 0:
attn_out[:num_decode_tokens] = _fp8_decode(
q[:num_decode_tokens], topk_indices[:num_decode_tokens]
)
assert fp8_metadata.prefill is not None
for chunk in fp8_metadata.prefill.chunks:
chunk_workspace = self.prefill_bf16_workspace[: chunk.chunk_tot_seqlen]
ops.cp_gather_and_upconvert_fp8_kv_cache(
kv_c_and_k_pe_cache,
chunk_workspace,
chunk.block_table,
chunk.seq_lens,
chunk.workspace_starts,
len(chunk.block_table),
)
chunk_q = q[chunk.tokens_slice]
chunk_topk_indices_workspace = topk_indices[chunk.tokens_slice]
attn_out[chunk.tokens_slice] = self._bf16_flash_mla_kernel(
chunk_q,
chunk_workspace,
chunk_topk_indices_workspace,
)
return attn_out
def _forward_fp8_kv_mixed_batch(
self,
q: torch.Tensor,
kv_c_and_k_pe_cache: torch.Tensor,
topk_indices: torch.Tensor,
attn_metadata: FlashMLASparseMetadata,
) -> torch.Tensor:
"""Mixed batch FP8 forward path that treats all tokens as one batch.
This is equivalent to main branch's approach and avoids the BF16
prefill kernel which has head padding overhead when num_heads is small.
Used when use_mixed_batch is True.
"""
# Convert per-request indices to global slots (decode) or workspace
# offsets (prefill).
topk_indices = triton_convert_req_index_to_global_index(
attn_metadata.req_id_per_token,
attn_metadata.block_table,
topk_indices,
BLOCK_SIZE=attn_metadata.block_size,
NUM_TOPK_TOKENS=topk_indices.shape[1],
)
assert attn_metadata.fp8_extra_metadata is not None
assert isinstance(
attn_metadata.fp8_extra_metadata, FlashMLASparseMetadata.FP8KernelMetadata
)
fp8_metadata = attn_metadata.fp8_extra_metadata
_attn_out, _ = self._fp8_flash_mla_kernel(
q=q.unsqueeze(0), # unsqueeze to add batch_dim: (T, H, D) -> (1, T, H, D)
kv_c_and_k_pe_cache=kv_c_and_k_pe_cache,
topk_indices=topk_indices.unsqueeze(0), # (T, topk) -> (1, T, topk)
kernel_metadata=fp8_metadata,
)
# Output is (1, T, H, D_v), squeeze back to (T, H, D_v)
return _attn_out.squeeze(0)
def _fp8_flash_mla_kernel(
self,
q: torch.Tensor,
kv_c_and_k_pe_cache: torch.Tensor,
topk_indices: torch.Tensor,
kernel_metadata: FlashMLASparseMetadata.FP8KernelMetadata,
) -> torch.Tensor:
return flash_mla_with_kvcache(
q=q,
k_cache=kv_c_and_k_pe_cache.view(torch.uint8).unsqueeze(-2),
block_table=kernel_metadata.dummy_block_table,
head_dim_v=512,
cache_seqlens=kernel_metadata.cache_lens,
tile_scheduler_metadata=kernel_metadata.scheduler_metadata,
num_splits=kernel_metadata.num_splits,
is_fp8_kvcache=True,
indices=topk_indices,
softmax_scale=self.softmax_scale,
)
def _bf16_flash_mla_kernel(
self,
q: torch.Tensor,
kv_c_and_k_pe_cache: torch.Tensor,
topk_indices: torch.Tensor,
) -> torch.Tensor:
num_tokens = q.shape[0]
kv_c_and_k_pe_cache = kv_c_and_k_pe_cache.view(
......@@ -445,31 +944,6 @@ class FlashMLASparseImpl(MLACommonBaseImpl[FlashMLASparseMetadata]):
output = output[:, : self.num_heads, :]
return output
def _forward_fp8_kv(
self,
q: torch.Tensor,
kv_c_and_k_pe_cache: torch.Tensor,
topk_indices: torch.Tensor,
attn_metadata: FlashMLASparseMetadata,
) -> torch.Tensor:
assert attn_metadata.fp8_extra_metadata is not None
extra_metadata = attn_metadata.fp8_extra_metadata
_attn_out, _ = flash_mla_with_kvcache(
q=q.unsqueeze(0), # unsqueeze to add batch_dim
k_cache=kv_c_and_k_pe_cache.view(torch.uint8).unsqueeze(-2),
block_table=extra_metadata.dummy_block_table,
head_dim_v=512,
cache_seqlens=extra_metadata.cache_lens,
tile_scheduler_metadata=extra_metadata.scheduler_metadata,
num_splits=extra_metadata.num_splits,
is_fp8_kvcache=True,
indices=topk_indices.unsqueeze(0), # unsqueeze to add batch_dim
softmax_scale=self.softmax_scale,
)
return _attn_out
def forward(
self,
layer: AttentionLayer,
......@@ -477,7 +951,7 @@ class FlashMLASparseImpl(MLACommonBaseImpl[FlashMLASparseMetadata]):
k_c_normed: torch.Tensor, # key in unified attn
k_pe: torch.Tensor, # value in unified attn
kv_cache: torch.Tensor,
attn_metadata: FlashMLASparseMetadata,
attn_metadata: FlashMLASparseMetadata | None,
output: torch.Tensor | None = None,
output_scale: torch.Tensor | None = None,
output_block_scale: torch.Tensor | None = None,
......@@ -493,6 +967,7 @@ class FlashMLASparseImpl(MLACommonBaseImpl[FlashMLASparseMetadata]):
)
if attn_metadata is None:
# Dummy run - no need to allocate buffers
# The zero fill is required when used with DP + EP
# to ensure all ranks within a DP group compute the
# same expert outputs.
......@@ -505,6 +980,7 @@ class FlashMLASparseImpl(MLACommonBaseImpl[FlashMLASparseMetadata]):
q = q[:num_actual_toks, ...]
k_c_normed = k_c_normed[:num_actual_toks, ...]
k_pe = k_pe[:num_actual_toks, ...]
topk_indices = self.topk_indices_buffer[:num_actual_toks]
q_nope, q_pe = q.split([self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1)
# Convert from (B, N, P) to (N, B, P)
......@@ -514,16 +990,7 @@ class FlashMLASparseImpl(MLACommonBaseImpl[FlashMLASparseMetadata]):
# Convert from (N, B, L) to (B, N, L)
ql_nope = ql_nope.transpose(0, 1)
topk_indices = self.topk_indices_buffer[:num_actual_toks]
# TODO: handle index / kv_cache correctly
topk_indices_global = triton_convert_req_index_to_global_index(
attn_metadata.req_id_per_token,
attn_metadata.block_table,
topk_indices,
BLOCK_SIZE=attn_metadata.block_size,
NUM_TOPK_TOKENS=attn_metadata.topk_tokens,
)
use_fp8_cache = self.kv_cache_dtype == "fp8_ds_mla"
q = torch.cat([ql_nope, q_pe], dim=-1)
......@@ -538,13 +1005,15 @@ class FlashMLASparseImpl(MLACommonBaseImpl[FlashMLASparseMetadata]):
scale=layer._k_scale,
)
if self.kv_cache_dtype != "fp8_ds_mla":
attn_out = self._forward_bf16_kv(
q, kv_cache, topk_indices_global, attn_metadata
if not use_fp8_cache:
attn_out = self._forward_bf16_kv(q, kv_cache, topk_indices, attn_metadata)
elif attn_metadata.fp8_use_mixed_batch:
attn_out = self._forward_fp8_kv_mixed_batch(
q, kv_cache, topk_indices, attn_metadata
)
else:
attn_out = self._forward_fp8_kv(
q, kv_cache, topk_indices_global, attn_metadata
attn_out = self._forward_fp8_kv_separate_prefill_decode(
q, kv_cache, topk_indices, attn_metadata
)
self._v_up_proj(attn_out, out=output[:num_actual_toks])
......
......@@ -18,6 +18,7 @@ from vllm.v1.attention.backends.utils import (
AttentionMetadataBuilder,
CommonAttentionMetadata,
split_decodes_and_prefills,
split_prefill_chunks,
)
logger = init_logger(__name__)
......@@ -176,40 +177,15 @@ def kv_spans_from_batches(
def get_max_prefill_buffer_size(vllm_config: VllmConfig):
max_model_len = vllm_config.model_config.max_model_len
# NOTE(Chen): 2 is a magic number for controlling the prefill buffer size.
# May be tuned later.
return max_model_len * 2
def split_prefill_chunks(
seq_lens_cpu: torch.Tensor, max_prefill_buffer_size: int, reqs_start: int
) -> list[tuple[int, int]]:
"""
Split the prefill chunks into a list of tuples of (reqs_start, reqs_end)
such that the total sequence length of each chunk is less than the
maximum prefill buffer size.
Args:
seq_lens_cpu: The sequence lengths of the prefill requests.
max_prefill_buffer_size: The maximum prefill buffer size.
reqs_start: The start index of the prefill requests.
Returns:
A list of tuples of (reqs_start, reqs_end).
"""
chunk_seq_ids = []
total_seq_lens = 0
for i in range(reqs_start, len(seq_lens_cpu)):
cur_seq_len = seq_lens_cpu[i].item()
assert cur_seq_len <= max_prefill_buffer_size
total_seq_lens += cur_seq_len
if total_seq_lens > max_prefill_buffer_size:
chunk_seq_ids.append((reqs_start, i))
reqs_start = i
total_seq_lens = cur_seq_len
if total_seq_lens > 0:
chunk_seq_ids.append((reqs_start, len(seq_lens_cpu)))
return chunk_seq_ids
# NOTE(Chen): 40 is a magic number for controlling the prefill buffer size.
# Each entry is 128 fp8 bytes and 4 scale bytes for a total of 132 bytes.
# The flashmla_sparse backend uses a workspace size of 5 * max_model_len.
# The memory usage of the workspace there is 576 * 2 bytes; so we size this as
# (576 * 2 // 132) * 5 = 40 to maximize this workspace size while still fitting
# within the flashmla_sparse workspace.
# For DeepSeek-V3.2, the max_model_len is 163840.
# 40 * 163840 * 132 = 865075200 bytes = 825 MB
return max_model_len * 40
class DeepseekV32IndexerMetadataBuilder(AttentionMetadataBuilder):
......@@ -302,9 +278,9 @@ class DeepseekV32IndexerMetadataBuilder(AttentionMetadataBuilder):
prefill_metadata = None
if num_prefills > 0:
chunk_seq_ids = split_prefill_chunks(
common_attn_metadata.seq_lens_cpu,
common_attn_metadata.seq_lens_cpu[num_decodes:],
self.max_prefill_buffer_size,
num_decodes,
request_offset=num_decodes,
)
chunks = [
self.build_one_prefill_chunk(
......
......@@ -17,7 +17,7 @@ from vllm.attention.ops.triton_reshape_and_cache_flash import (
triton_reshape_and_cache_flash,
)
from vllm.attention.ops.triton_unified_attention import unified_attention
from vllm.config import VllmConfig
from vllm.config import CUDAGraphMode, VllmConfig
from vllm.config.cache import CacheDType
from vllm.logger import init_logger
from vllm.model_executor.layers.quantization.utils.quant_utils import (
......@@ -26,6 +26,7 @@ from vllm.model_executor.layers.quantization.utils.quant_utils import (
)
from vllm.platforms import current_platform
from vllm.platforms.interface import DeviceCapability
from vllm.utils.math_utils import next_power_of_2
from vllm.v1.attention.backends.utils import (
AttentionCGSupport,
AttentionMetadataBuilder,
......@@ -36,6 +37,11 @@ from vllm.v1.kv_cache_interface import AttentionSpec
logger = init_logger(__name__)
# constants
MIN_LAUNCH_GRID_SIZE_2D = 128 # Minimum launch grid size of 2D kernel
NUM_PAR_SOFTMAX_SEGMENTS = 16 # Number of parallel tiled softmax segments
@dataclass
class TritonAttentionMetadata:
# NOTE(sang): Definition of context_len, query_len, and seq_len.
......@@ -54,6 +60,12 @@ class TritonAttentionMetadata:
block_table: torch.Tensor
slot_mapping: torch.Tensor
seq_threshold_3D: int
num_par_softmax_segments: int
softmax_segm_output: torch.Tensor
softmax_segm_max: torch.Tensor
softmax_segm_expsum: torch.Tensor
# For cascade attention.
use_cascade: bool
common_prefix_len: int
......@@ -87,6 +99,60 @@ class TritonAttentionMetadataBuilder(AttentionMetadataBuilder[TritonAttentionMet
self.num_heads_kv = model_config.get_num_kv_heads(vllm_config.parallel_config)
self.headdim = model_config.get_head_size()
# Check if CUDA Graphs are enabled for decode
self.decode_cudagraph_enabled = (
self.vllm_config.compilation_config.cudagraph_mode
in (
CUDAGraphMode.FULL_AND_PIECEWISE,
CUDAGraphMode.FULL_DECODE_ONLY,
CUDAGraphMode.FULL,
)
)
# The launch grid for the 2D kernel is defined as (num_q_blocks, num_heads_kv).
# A lower bound for num_q_blocks is the number of sequences.
# To ensure the minimum launch grid size is achieved, the number of sequences
# must be at least equal to the threshold below.
# If this threshold is not reached (i.e., the batch size is not large enough),
# the 3D kernel will be selected instead.
self.seq_threshold_3D = MIN_LAUNCH_GRID_SIZE_2D // self.num_heads_kv
# Modify the threshold if needed.
if self.decode_cudagraph_enabled:
capture_sizes = self.vllm_config.compilation_config.cudagraph_capture_sizes
assert capture_sizes, "CUDA Graphs enabled but no capture sizes specified."
# Select the CUDA Graph capture size closest to self.seq_threshold_3D
# as threshold. This ensures that each captured graph covers the
# correct execution path.
self.seq_threshold_3D = min(
capture_sizes,
key=lambda x: abs(x - self.seq_threshold_3D),
)
self.num_par_softmax_segments = NUM_PAR_SOFTMAX_SEGMENTS
headdim_padded = next_power_of_2(self.headdim)
self.softmax_segm_output = torch.empty(
(
self.seq_threshold_3D,
self.num_heads_q,
self.num_par_softmax_segments,
headdim_padded,
),
dtype=torch.float32,
device=device,
)
self.softmax_segm_max = torch.empty(
(self.seq_threshold_3D, self.num_heads_q, self.num_par_softmax_segments),
dtype=torch.float32,
device=device,
)
self.softmax_segm_expsum = torch.empty(
(self.seq_threshold_3D, self.num_heads_q, self.num_par_softmax_segments),
dtype=torch.float32,
device=device,
)
def build_for_cudagraph_capture(
self, common_attn_metadata: CommonAttentionMetadata
) -> TritonAttentionMetadata:
......@@ -143,6 +209,11 @@ class TritonAttentionMetadataBuilder(AttentionMetadataBuilder[TritonAttentionMet
prefix_kv_lens=prefix_kv_lens,
suffix_kv_lens=suffix_kv_lens,
prefix_scheduler_metadata=prefix_scheduler_metadata,
seq_threshold_3D=self.seq_threshold_3D,
num_par_softmax_segments=self.num_par_softmax_segments,
softmax_segm_output=self.softmax_segm_output,
softmax_segm_max=self.softmax_segm_max,
softmax_segm_expsum=self.softmax_segm_expsum,
)
return attn_metadata
......@@ -349,6 +420,12 @@ class TritonAttentionImpl(AttentionImpl):
max_seqlen_k = attn_metadata.max_seq_len
block_table = attn_metadata.block_table
seq_threshold_3D = attn_metadata.seq_threshold_3D
num_par_softmax_segments = attn_metadata.num_par_softmax_segments
softmax_segm_output = attn_metadata.softmax_segm_output
softmax_segm_max = attn_metadata.softmax_segm_max
softmax_segm_expsum = attn_metadata.softmax_segm_expsum
descale_shape = (cu_seqlens_q.shape[0] - 1, key_cache.shape[2])
unified_attention(
......@@ -369,6 +446,11 @@ class TritonAttentionImpl(AttentionImpl):
q_descale=None, # Not supported
k_descale=layer._k_scale.expand(descale_shape),
v_descale=layer._v_scale.expand(descale_shape),
seq_threshold_3D=seq_threshold_3D,
num_par_softmax_segments=num_par_softmax_segments,
softmax_segm_output=softmax_segm_output,
softmax_segm_max=softmax_segm_max,
softmax_segm_expsum=softmax_segm_expsum,
sinks=self.sinks,
output_scale=output_scale,
)
......
......@@ -937,6 +937,33 @@ def split_decodes_and_prefills(
return (num_decodes, num_prefills, num_decode_tokens, num_prefill_tokens)
def split_prefill_chunks(
seq_lens_cpu: torch.Tensor, workspace_size: int, request_offset: int = 0
) -> list[tuple[int, int]]:
"""
Split the prefill requests into chunks such that the total sequence length
of each chunk is less than or equal to the workspace size.
Args:
seq_lens_cpu: The sequence lengths of the prefill requests on CPU.
workspace_size: The maximum workspace size (in tokens) per chunk.
request_offset: The offset to add to the request indices.
Returns:
A list of tuples of (reqs_start, reqs_end) representing chunk boundaries.
"""
chunk_bounds = []
i, n = 0, len(seq_lens_cpu)
assert torch.all(seq_lens_cpu <= workspace_size).item()
while i < n:
start, chunk_total = i, 0
while i < n and (chunk_total + (s := seq_lens_cpu[i].item())) <= workspace_size:
chunk_total += s
i += 1
chunk_bounds.append((start + request_offset, i + request_offset))
return chunk_bounds
def reorder_batch_to_split_decodes_and_prefills(
input_batch: "InputBatch",
scheduler_output: "SchedulerOutput",
......
......@@ -397,6 +397,25 @@ class BlockPool:
[block for block in blocks_list if block.ref_cnt == 0 and not block.is_null]
)
def evict_blocks(self, block_ids: set[int]) -> None:
"""evict blocks from the prefix cache by their block IDs.
only evicts blocks that are currently cached (have a hash). blocks
with ref_cnt > 0 are not freed from the block pool, only evicted
from the prefix cache hash table.
Args:
block_ids: Set of block IDs to evict from cache.
"""
for block_id in block_ids:
assert block_id < len(self.blocks), (
f"Invalid block_id {block_id} >= {len(self.blocks)}. "
f"This indicates a bug in the KV connector - workers should "
f"only report block IDs that were allocated by the scheduler."
)
block = self.blocks[block_id]
self._maybe_evict_cached_block(block)
def reset_prefix_cache(self) -> bool:
"""Reset prefix cache. This function may be used in RLHF
flows to invalid prefix caching after the weights are updated,
......
......@@ -39,20 +39,26 @@ class EncoderCacheManager:
space for new embeddings.
Oldest cached embeddings with no request referenced will be first evicted.
NOTE: The EncoderCacheManager operates on the level of multimodal embeddings
instead of encoder tokens (i.e. all tokens that represent the multimodal data
in the input sequence). This means all break/text tokens in-between multimodal
embeddings are not considered with respect to the cache size and the number
of free slots.
Args:
cache_size: Limit the size of the cache, measured by the number of
tokens from the input sequence.
encoder embeddings from the input sequence.
Attributes:
cache_size: Total cache capacity in encoder tokens.
num_free_slots: Current available cache capacity in encoder tokens.
cache_size: Total cache capacity in encoder embeddings.
num_free_slots: Current available cache capacity in encoder embeddings.
num_freeable_slots: Capacity that can be immediately reclaimed by
evicting entries with zero references (in encoder tokens).
evicting entries with zero references (in encoder embeddings).
cached: Mapping from mm_hash to a set of request IDs that currently
reference the cached entry. If the set is empty, the entry exists
but is not referenced by any request and is eligible for
reclamation.
freeable: List of tuples (mm_hash, num_tokens) representing entries
freeable: List of tuples (mm_hash, num_encoder_embeds) representing entries
whose no current running request is needed and that can be freed to
make space when needed.
freed: List of mm_hash strings that were actually evicted since the
......@@ -67,7 +73,7 @@ class EncoderCacheManager:
# mm_hash of mm_data => ids of requests that reference the mm_data
self.cached: dict[str, set[str]] = {}
# mm_hash of mm_data => num_encoder_tokens of the mm_data
# mm_hash of mm_data => num_encoder_embeds of the mm_data
self.freeable: OrderedDict[str, int] = OrderedDict()
self.freed: list[str] = []
......@@ -93,8 +99,8 @@ class EncoderCacheManager:
# Cached but currently not referenced by any request
if not self.cached[mm_hash]:
num_tokens = self.freeable.pop(mm_hash)
self.num_freeable_slots -= num_tokens
num_encoder_embeds = self.freeable.pop(mm_hash)
self.num_freeable_slots -= num_encoder_embeds
self.cached[mm_hash].add(request.request_id)
return True
......@@ -104,7 +110,7 @@ class EncoderCacheManager:
request: Request,
input_id: int,
encoder_compute_budget: int,
num_tokens_to_schedule: int,
num_embeds_to_schedule: int,
) -> bool:
"""Check if there's sufficient cache space for a multimodal input.
If there is, return True and update EncoderCacheManager state.
......@@ -121,9 +127,9 @@ class EncoderCacheManager:
Args:
request: The request containing the multimodal input.
input_id: Index of the multimodal input within the request.
encoder_compute_budget: Number of encoder tokens allowed to be
encoder_compute_budget: Number of encoder embeddings allowed to be
computed when this method is invoked.
num_tokens_to_schedule: Number of tokens already scheduled to be
num_embeds_to_schedule: Number of encoder embeddings already scheduled to be
allocated with cache space when this method is invoked.
Returns:
......@@ -134,30 +140,30 @@ class EncoderCacheManager:
Note: This method does not allocate physical memory for the encoder
output but only the state of EncoderCacheManager.
"""
num_tokens = request.get_num_encoder_tokens(input_id)
num_embeds = request.get_num_encoder_embeds(input_id)
# Not enough compute budget
if num_tokens > encoder_compute_budget:
if num_embeds > encoder_compute_budget:
return False
num_tokens += num_tokens_to_schedule
num_embeds += num_embeds_to_schedule
# Enough free slots
if num_tokens <= self.num_free_slots:
if num_embeds <= self.num_free_slots:
return True
# Not enough reclaimable slots
if num_tokens > self.num_freeable_slots:
if num_embeds > self.num_freeable_slots:
return False
# Not enough free slots but enough reclaimable slots
# NOTE: Eviction takes place here, but physical memory is not freed
# until model runner is notified by the scheduler output.
while num_tokens > self.num_free_slots:
mm_hash, num_free_token = self.freeable.popitem(last=False)
while num_embeds > self.num_free_slots:
mm_hash, num_free_embeds = self.freeable.popitem(last=False)
del self.cached[mm_hash]
self.freed.append(mm_hash)
self.num_free_slots += num_free_token
self.num_free_slots += num_free_embeds
return True
def allocate(self, request: Request, input_id: int) -> None:
......@@ -176,16 +182,16 @@ class EncoderCacheManager:
if mm_hash not in self.cached:
self.cached[mm_hash] = set()
num_encoder_tokens = request.get_num_encoder_tokens(input_id)
num_encoder_embeds = request.get_num_encoder_embeds(input_id)
# NOTE: Encoder cache should always have enough space for encoder inputs
# that are scheduled since eviction takes place at can_allocate().
assert self.num_free_slots >= num_encoder_tokens
assert self.num_freeable_slots >= num_encoder_tokens
assert self.num_free_slots >= num_encoder_embeds
assert self.num_freeable_slots >= num_encoder_embeds
self.cached[mm_hash].add(request_id)
self.num_free_slots -= num_encoder_tokens
self.num_freeable_slots -= num_encoder_tokens
self.num_free_slots -= num_encoder_embeds
self.num_freeable_slots -= num_encoder_embeds
def get_cached_input_ids(self, request: Request) -> set[int]:
"""Get all cached multimodal input IDs for a request.
......@@ -206,7 +212,7 @@ class EncoderCacheManager:
When the reference set for the corresponding `mm_hash` becomes empty,
the entry is appended to `freeable` and `num_freeable_slots` is
increased by the number of encoder tokens for that input.
increased by the number of encoder embeddings for that input.
The entry is NOT physically freed until capacity is needed (e.g., by
`can_allocate`).
......@@ -218,9 +224,9 @@ class EncoderCacheManager:
return
self.cached[mm_hash].discard(req_id)
if not self.cached[mm_hash]:
num_tokens = request.get_num_encoder_tokens(input_id)
self.freeable[mm_hash] = num_tokens
self.num_freeable_slots += num_tokens
num_encoder_embeds = request.get_num_encoder_embeds(input_id)
self.freeable[mm_hash] = num_encoder_embeds
self.num_freeable_slots += num_encoder_embeds
def free(self, request: Request) -> None:
"""Free all encoder input cache reference held by *request*.
......@@ -341,3 +347,56 @@ def compute_mm_encoder_budget(
)
return encoder_compute_budget, encoder_cache_size
# NOTE (NickLucche): Temporary implementation for encoder-decoder models that only
# use the manager for scheduling purposes. Encoder-decoder models will eventually
# utilize the cache and this class will fold into EncoderCacheManager, as
# differences with MM models shrink.
class EncoderDecoderCacheManager(EncoderCacheManager):
def __init__(self, cache_size: int):
self.cache_size = cache_size
self.num_free_slots = cache_size
self.freed: list[str] = []
def check_and_update_cache(self, request: Request, input_id: int) -> bool:
return False
def can_allocate(
self,
request: Request,
input_id: int,
encoder_compute_budget: int,
num_embeds_to_schedule: int,
) -> bool:
num_encoder_embeds = request.get_num_encoder_embeds(input_id)
# Not enough compute budget
if num_encoder_embeds > encoder_compute_budget:
return False
num_encoder_embeds += num_embeds_to_schedule
# Enough free slots
return num_encoder_embeds <= self.num_free_slots
def allocate(self, request: Request, input_id: int) -> None:
num_encoder_embeds = request.get_num_encoder_embeds(input_id)
self.num_free_slots -= num_encoder_embeds
mm_hash = request.mm_features[input_id].identifier
self.freed.append(mm_hash)
def free(self, request: Request) -> None:
for input_id in range(len(request.mm_features)):
self.free_encoder_input(request, input_id)
def get_cached_input_ids(self, request: Request) -> set[int]:
return set(range(len(request.mm_features)))
def get_freed_mm_hashes(self) -> list[str]:
freed = self.freed
self.freed = []
return freed
def free_encoder_input(self, request: Request, input_id: int) -> None:
num_encoder_embeds = request.get_num_encoder_embeds(input_id)
self.num_free_slots += num_encoder_embeds
......@@ -333,6 +333,14 @@ class KVCacheManager:
"""
self.coordinator.free(request.request_id)
def evict_blocks(self, block_ids: set[int]) -> None:
"""evict blocks from the prefix cache by their block IDs.
Args:
block_ids: Set of block IDs to evict from cache.
"""
self.block_pool.evict_blocks(block_ids)
def reset_prefix_cache(self) -> bool:
"""Reset prefix cache. This function may be used in RLHF
flows to invalidate prefix caching after the weights are updated,
......
......@@ -687,7 +687,9 @@ def check_enough_kv_cache_memory(
raise ValueError(
"No available memory for the cache blocks. "
"Try increasing `gpu_memory_utilization` when "
"initializing the engine."
"initializing the engine. "
"See https://docs.vllm.ai/en/latest/configuration/conserving_memory/ "
"for more details."
)
max_model_len = vllm_config.model_config.max_model_len
......@@ -711,8 +713,10 @@ def check_enough_kv_cache_memory(
f"cache is needed, which is larger than the available KV cache "
f"memory ({available_memory / GiB_bytes:.2f} GiB). "
f"{estimated_msg} "
f"Try increasing `gpu_memory_utilization` or decreasing "
f"`max_model_len` when initializing the engine."
f"Try increasing `gpu_memory_utilization` or decreasing `max_model_len` "
f"when initializing the engine. "
f"See https://docs.vllm.ai/en/latest/configuration/conserving_memory/ "
f"for more details."
)
......
......@@ -27,6 +27,7 @@ from vllm.logger import init_logger
from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalRegistry
from vllm.v1.core.encoder_cache_manager import (
EncoderCacheManager,
EncoderDecoderCacheManager,
compute_encoder_budget,
)
from vllm.v1.core.kv_cache_manager import KVCacheBlocks, KVCacheManager
......@@ -106,6 +107,7 @@ class Scheduler(SchedulerInterface):
# KV Connector pushes/pull of remote KVs for P/D and offloading.
self.connector = None
self.connector_prefix_cache_stats: PrefixCacheStats | None = None
self.recompute_kv_load_failures = True
if self.vllm_config.kv_transfer_config is not None:
assert not self.is_encoder_decoder, (
"Encoder-decoder models are not currently supported with KV connectors"
......@@ -117,6 +119,10 @@ class Scheduler(SchedulerInterface):
)
if self.log_stats:
self.connector_prefix_cache_stats = PrefixCacheStats()
kv_load_failure_policy = (
self.vllm_config.kv_transfer_config.kv_load_failure_policy
)
self.recompute_kv_load_failures = kv_load_failure_policy == "recompute"
self.kv_event_publisher = EventPublisherFactory.create(
self.kv_events_config,
......@@ -176,7 +182,11 @@ class Scheduler(SchedulerInterface):
# NOTE: For the models without encoder (e.g., text-only models),
# the encoder cache will not be initialized because cache size is 0
# for these models.
self.encoder_cache_manager = EncoderCacheManager(cache_size=encoder_cache_size)
self.encoder_cache_manager = (
EncoderDecoderCacheManager(cache_size=encoder_cache_size)
if self.is_encoder_decoder
else EncoderCacheManager(cache_size=encoder_cache_size)
)
speculative_config = vllm_config.speculative_config
self.use_eagle = False
......@@ -339,11 +349,11 @@ class Scheduler(SchedulerInterface):
if preempted_encoder_inputs:
# Restore encoder compute budget if the preempted
# request had encoder inputs scheduled in this step.
num_tokens_to_restore = sum(
preempted_req.get_num_encoder_tokens(i)
num_embeds_to_restore = sum(
preempted_req.get_num_encoder_embeds(i)
for i in preempted_encoder_inputs
)
encoder_compute_budget += num_tokens_to_restore
encoder_compute_budget += num_embeds_to_restore
req_index -= 1
else:
preempted_req = self.running.pop()
......@@ -901,10 +911,11 @@ class Scheduler(SchedulerInterface):
# multiple encoder inputs per request), we need to create temporary
# trackers for accounting at the encoder input level.
mm_hashes_to_schedule = set()
num_tokens_to_schedule = 0
num_embeds_to_schedule = 0
for i, mm_feature in enumerate(mm_features):
start_pos = mm_feature.mm_position.offset
num_encoder_tokens = mm_feature.mm_position.length
num_encoder_embeds = mm_feature.mm_position.get_num_embeds
# The encoder output is needed if the two ranges overlap:
# [num_computed_tokens, num_computed_tokens + num_new_tokens) and
......@@ -960,9 +971,8 @@ class Scheduler(SchedulerInterface):
):
num_new_tokens = start_pos - num_computed_tokens
break
if not self.encoder_cache_manager.can_allocate(
request, i, encoder_compute_budget, num_tokens_to_schedule
request, i, encoder_compute_budget, num_embeds_to_schedule
):
# The encoder cache is full or the encoder budget is exhausted.
# NOTE(woosuk): We assume that the encoder input tokens should
......@@ -982,14 +992,31 @@ class Scheduler(SchedulerInterface):
num_new_tokens = 0
break
# Calculate the number of embeddings to schedule in the current range
# of scheduled encoder placholder tokens.
start_idx_rel = max(0, num_computed_tokens - start_pos)
end_idx_rel = min(
num_encoder_tokens, num_computed_tokens + num_new_tokens - start_pos
)
curr_embeds_start, curr_embeds_end = (
mm_feature.mm_position.get_embeds_indices_in_range(
start_idx_rel,
end_idx_rel,
)
)
# There's no embeddings in the current range of encoder placeholder tokens
# so we can skip the encoder input.
if curr_embeds_end - curr_embeds_start == 0:
continue
if self.ec_connector is not None and remote_cache_has_item[i]:
mm_hashes_to_schedule.add(request.mm_features[i].identifier)
external_load_encoder_input.append(i)
num_tokens_to_schedule += num_encoder_tokens
num_embeds_to_schedule += num_encoder_embeds
continue
num_tokens_to_schedule += num_encoder_tokens
encoder_compute_budget -= num_encoder_tokens
num_embeds_to_schedule += num_encoder_embeds
encoder_compute_budget -= num_encoder_embeds
mm_hashes_to_schedule.add(request.mm_features[i].identifier)
encoder_inputs_to_schedule.append(i)
......@@ -1066,7 +1093,7 @@ class Scheduler(SchedulerInterface):
for req_id, num_tokens_scheduled in num_scheduled_tokens.items():
assert num_tokens_scheduled > 0
if failed_kv_load_req_ids and req_id in failed_kv_load_req_ids:
# Skip requests that were recovered from KV load failure
# skip failed or rescheduled requests from KV load failure
continue
request = self.requests.get(req_id)
if request is None:
......@@ -1107,6 +1134,7 @@ class Scheduler(SchedulerInterface):
stopped = False
new_logprobs = None
new_token_ids = generated_token_ids
pooler_output = pooler_outputs[req_index] if pooler_outputs else None
kv_transfer_params = None
status_before_stop = request.status
......@@ -1115,12 +1143,10 @@ class Scheduler(SchedulerInterface):
new_token_ids, stopped = self._update_request_with_output(
request, new_token_ids
)
# Stop checking for pooler models.
pooler_output = None
if pooler_outputs:
pooler_output = pooler_outputs[req_index]
stopped = check_stop(request, self.max_model_len, pooler_output)
elif request.pooling_params and pooler_output is not None:
# Pooling stops as soon as there is output.
request.status = RequestStatus.FINISHED_STOPPED
stopped = True
if stopped:
kv_transfer_params = self._free_request(request)
......@@ -1177,6 +1203,21 @@ class Scheduler(SchedulerInterface):
# This is a rare case and unlikely to impact performance.
self.waiting.remove_requests(stopped_preempted_reqs)
if failed_kv_load_req_ids and not self.recompute_kv_load_failures:
requests = [self.requests[req_id] for req_id in failed_kv_load_req_ids]
self.finish_requests(failed_kv_load_req_ids, RequestStatus.FINISHED_ERROR)
for request in requests:
outputs[request.client_index].append(
EngineCoreOutput(
request_id=request.request_id,
new_token_ids=[],
finish_reason=request.get_finished_reason(),
events=request.take_events(),
trace_headers=request.trace_headers,
num_cached_tokens=request.num_cached_tokens,
)
)
# KV Connector: update state for finished KV Transfers.
if kv_connector_output:
self._update_from_kv_xfer_finished(kv_connector_output)
......@@ -1610,8 +1651,11 @@ class Scheduler(SchedulerInterface):
self._free_blocks(self.requests[req_id])
def _update_requests_with_invalid_blocks(
self, requests: Iterable[Request], invalid_block_ids: set[int]
) -> tuple[set[str], int]:
self,
requests: Iterable[Request],
invalid_block_ids: set[int],
evict_blocks: bool = True,
) -> tuple[set[str], int, set[int]]:
"""
Identify and update requests affected by invalid KV cache blocks.
......@@ -1623,16 +1667,21 @@ class Scheduler(SchedulerInterface):
Args:
requests: The set of requests to scan for invalid blocks.
invalid_block_ids: IDs of invalid blocks.
evict_blocks: Whether to collect blocks for eviction (False for
async requests which aren't cached yet).
Returns:
tuple:
- affected_req_ids (set[str]): IDs of requests impacted by
invalid blocks.
- total_affected_tokens (int): Total number of tokens that must
be recomputed across all affected requests (for observability).
be recomputed across all affected requests.
- blocks_to_evict (set[int]): Block IDs to evict from cache,
including invalid blocks and downstream dependent blocks.
"""
affected_req_ids: set[str] = set()
total_affected_tokens = 0
blocks_to_evict: set[int] = set()
# If a block is invalid and shared by multiple requests in the batch,
# these requests must be rescheduled, but only the first will recompute
# it. This set tracks blocks already marked for recomputation.
......@@ -1690,6 +1739,9 @@ class Scheduler(SchedulerInterface):
)
total_affected_tokens += num_affected_tokens
request.num_external_computed_tokens -= num_affected_tokens
# collect invalid block and all downstream dependent blocks
if evict_blocks:
blocks_to_evict.update(req_block_ids[idx:])
if is_affected:
if not marked_invalid_block:
......@@ -1705,47 +1757,70 @@ class Scheduler(SchedulerInterface):
affected_req_ids.add(request.request_id)
return affected_req_ids, total_affected_tokens
return affected_req_ids, total_affected_tokens, blocks_to_evict
def _handle_invalid_blocks(self, invalid_block_ids: set[int]) -> set[str]:
total_requests_to_reschedule = 0
total_tokens_to_reschedule = 0
"""
Handle requests affected by invalid KV cache blocks.
Returns:
Set of affected request IDs to skip in update_from_output main loop.
"""
should_fail = not self.recompute_kv_load_failures
# --- Handle async KV loads (WAITING_FOR_REMOTE_KVS) ---
# handle async KV loads (not cached yet, evict_blocks=False)
async_load_reqs = (
req
for req in self.waiting
if req.status == RequestStatus.WAITING_FOR_REMOTE_KVS
)
async_affected_req_ids, num_tokens_to_reschedule = (
async_failed_req_ids, num_failed_tokens, _ = (
self._update_requests_with_invalid_blocks(
async_load_reqs, invalid_block_ids
async_load_reqs, invalid_block_ids, evict_blocks=False
)
)
total_requests_to_reschedule += len(async_affected_req_ids)
total_tokens_to_reschedule += num_tokens_to_reschedule
total_failed_requests = len(async_failed_req_ids)
total_failed_tokens = num_failed_tokens
# Mark requests with async KV load failures; they will be rescheduled
# once loading completes.
self.failed_recving_kv_req_ids |= async_affected_req_ids
# --- Handle sync KV loads (running requests) ---
sync_affected_req_ids, num_tokens_to_reschedule = (
self._update_requests_with_invalid_blocks(self.running, invalid_block_ids)
# handle sync loads (may be cached, collect blocks for eviction)
sync_failed_req_ids, num_failed_tokens, sync_blocks_to_evict = (
self._update_requests_with_invalid_blocks(
self.running, invalid_block_ids, evict_blocks=True
)
)
total_failed_requests += len(sync_failed_req_ids)
total_failed_tokens += num_failed_tokens
total_requests_to_reschedule += len(sync_affected_req_ids)
total_tokens_to_reschedule += num_tokens_to_reschedule
if not total_failed_requests:
return set()
# evict invalid blocks and downstream dependent blocks from cache
# only when not using recompute policy (where blocks will be recomputed
# and reused by other requests sharing them)
if sync_blocks_to_evict and not self.recompute_kv_load_failures:
self.kv_cache_manager.evict_blocks(sync_blocks_to_evict)
if should_fail:
all_failed_req_ids = async_failed_req_ids | sync_failed_req_ids
logger.error(
"Failing %d request(s) due to KV load failure "
"(failure_policy=fail, %d tokens affected). Request IDs: %s",
total_failed_requests,
total_failed_tokens,
all_failed_req_ids,
)
return all_failed_req_ids
if total_requests_to_reschedule:
logger.warning(
"Recovered from KV load failure: "
"%d request(s) rescheduled (%d tokens affected).",
total_requests_to_reschedule,
total_tokens_to_reschedule,
total_failed_requests,
total_failed_tokens,
)
# Return the IDs of affected running requests to skip in
# update_from_output.
return sync_affected_req_ids
# Mark async requests with KV load failures for retry once loading completes
self.failed_recving_kv_req_ids |= async_failed_req_ids
# Return sync affected IDs to skip in update_from_output
return sync_failed_req_ids
......@@ -2,8 +2,6 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import contextlib
import torch
from vllm.v1.request import Request, RequestStatus
......@@ -39,14 +37,8 @@ def remove_all(lst: list, items_to_remove: set) -> list:
return [item for item in lst if item not in items_to_remove]
def check_stop(
request: Request, max_model_len: int, pooler_output: torch.Tensor | None = None
) -> bool:
if request.pooling_params:
if pooler_output is not None:
request.status = RequestStatus.FINISHED_STOPPED
return True
return False
def check_stop(request: Request, max_model_len: int) -> bool:
assert not request.pooling_params
sampling_params = request.sampling_params
assert sampling_params is not None
......
......@@ -19,24 +19,27 @@ from vllm.v1.serial_utils import UtilityResult
# These are possible values of RequestOutput.finish_reason,
# so form part of the external API.
FINISH_REASON_STRINGS = ("stop", "length", "abort")
FINISH_REASON_STRINGS = ("stop", "length", "abort", "error")
class FinishReason(enum.IntEnum):
"""
Reason a request finished - stop, length, or abort.
Reason a request finished - stop, length, abort, or error.
Int rather than Str for more compact serialization.
stop - a stop string was emitted
length - max_tokens was consumed, or max_model_len was reached
abort - aborted for another reason
abort - aborted by client
error - retryable request-level internal error (e.g., KV load failure).
Invariant: always converted to 500 Internal Server Error.
"""
STOP = 0
LENGTH = 1
ABORT = 2
ERROR = 3
def __str__(self):
return FINISH_REASON_STRINGS[self.value]
......
......@@ -26,7 +26,7 @@ from vllm.plugins.io_processors import get_io_processor
from vllm.pooling_params import PoolingParams
from vllm.sampling_params import SamplingParams
from vllm.tasks import SupportedTask
from vllm.tokenizers import TokenizerLike, init_tokenizer_from_config
from vllm.tokenizers import TokenizerLike, cached_tokenizer_from_config
from vllm.tracing import init_tracer
from vllm.transformers_utils.config import maybe_register_config_serialize_by_value
from vllm.usage.usage_lib import UsageContext
......@@ -111,7 +111,7 @@ class AsyncLLM(EngineClient):
if self.model_config.skip_tokenizer_init:
tokenizer = None
else:
tokenizer = init_tokenizer_from_config(self.model_config)
tokenizer = cached_tokenizer_from_config(self.model_config)
self.input_processor = InputProcessor(self.vllm_config, tokenizer)
self.io_processor = get_io_processor(
......@@ -192,7 +192,7 @@ class AsyncLLM(EngineClient):
@property
@deprecated(
"`AsyncLLM.processor` has been renamed to `AsyncLLM.input_processor`. "
"The old name will be removed in v0.13."
"The old name will be removed in v0.14."
)
def processor(self):
return self.input_processor
......@@ -701,10 +701,6 @@ class AsyncLLM(EngineClient):
def tokenizer(self) -> TokenizerLike | None:
return self.input_processor.tokenizer
@tokenizer.setter
def tokenizer(self, tokenizer: TokenizerLike | None) -> None:
self.input_processor.tokenizer = tokenizer
async def get_tokenizer(self) -> TokenizerLike:
if self.tokenizer is None:
raise ValueError(
......
......@@ -211,6 +211,9 @@ class EngineCore:
freeze_gc_heap()
# If enable, attach GC debugger after static variable freeze.
maybe_attach_gc_debug_callback()
# Enable environment variable cache (e.g. assume no more
# environment variable overrides after this point)
enable_envs_cache()
def _initialize_kv_caches(
self, vllm_config: VllmConfig
......@@ -672,10 +675,6 @@ class EngineCoreProc(EngineCore):
assert addresses.coordinator_input is not None
logger.info("Waiting for READY message from DP Coordinator...")
# Enable environment variable cache (e.g. assume no more
# environment variable overrides after this point)
enable_envs_cache()
@contextmanager
def _perform_handshakes(
self,
......
......@@ -19,7 +19,8 @@ from vllm.multimodal.processing import EncDecMultiModalProcessor
from vllm.multimodal.utils import argsort_mm_positions
from vllm.pooling_params import PoolingParams
from vllm.sampling_params import SamplingParams
from vllm.tokenizers import MistralTokenizer, TokenizerLike
from vllm.tokenizers import TokenizerLike
from vllm.tokenizers.mistral import MistralTokenizer
from vllm.utils import length_from_prompt_token_ids_or_embeds
from vllm.v1.engine import EngineCoreRequest
from vllm.v1.metrics.stats import MultiModalCacheStats
......@@ -64,10 +65,6 @@ class InputProcessor:
def tokenizer(self) -> TokenizerLike | None:
return self.input_preprocessor.tokenizer
@tokenizer.setter
def tokenizer(self, tokenizer: TokenizerLike | None) -> None:
self.input_preprocessor.tokenizer = tokenizer
def _validate_logprobs(
self,
params: SamplingParams,
......@@ -192,29 +189,39 @@ class InputProcessor:
def _validate_single_prompt(single_prompt: dict | str) -> None:
if not isinstance(single_prompt, dict):
return
mm_data = single_prompt.get("multi_modal_data")
mm_uuids = single_prompt.get("multi_modal_uuids")
if not mm_data or not mm_uuids:
return
import torch
def _get_len(items: object):
if isinstance(items, dict): # Embedding inputs
return _get_len(next(iter(items.values()))) if items else 1
if isinstance(items, list):
return len(items)
if isinstance(items, torch.Tensor):
# To keep backwards compatibility for single item embedding input
return 1 if getattr(items, "_is_single_item", False) else len(items)
return 1
for modality, items in mm_data.items():
if modality in mm_uuids:
data_len = len(items) if isinstance(items, list) else 1
uuid_len = (
len(mm_uuids[modality])
if isinstance(mm_uuids[modality], list)
else 1
)
data_len = _get_len(items)
uuid_len = _get_len(mm_uuids[modality])
if uuid_len != data_len:
raise ValueError(
f"multi_modal_uuids for modality '{modality}' "
f"multi_modal_uuids for modality {modality!r} "
"must have same length as data: got "
f"{uuid_len} uuids vs "
f"{data_len} items."
f"{uuid_len} uuids vs {data_len} items."
)
else:
raise ValueError(
f"multi_modal_uuids for modality '{modality}' must "
f"multi_modal_uuids for modality {modality!r} must "
"be provided if multi_modal_data is provided."
)
......
......@@ -23,7 +23,7 @@ from vllm.plugins.io_processors import get_io_processor
from vllm.pooling_params import PoolingParams
from vllm.sampling_params import SamplingParams
from vllm.tasks import SupportedTask
from vllm.tokenizers import TokenizerLike, init_tokenizer_from_config
from vllm.tokenizers import TokenizerLike, cached_tokenizer_from_config
from vllm.tracing import init_tracer
from vllm.usage.usage_lib import UsageContext
from vllm.v1.engine import EngineCoreRequest
......@@ -86,7 +86,7 @@ class LLMEngine:
if self.model_config.skip_tokenizer_init:
tokenizer = None
else:
tokenizer = init_tokenizer_from_config(self.model_config)
tokenizer = cached_tokenizer_from_config(self.model_config)
self.input_processor = InputProcessor(self.vllm_config, tokenizer)
self.io_processor = get_io_processor(
......@@ -139,7 +139,7 @@ class LLMEngine:
@property
@deprecated(
"`LLMEngine.processor` has been renamed to `LLMEngine.input_processor`. "
"The old name will be removed in v0.13."
"The old name will be removed in v0.14."
)
def processor(self):
return self.input_processor
......@@ -358,10 +358,6 @@ class LLMEngine:
def tokenizer(self) -> TokenizerLike | None:
return self.input_processor.tokenizer
@tokenizer.setter
def tokenizer(self, tokenizer: TokenizerLike | None) -> None:
self.input_processor.tokenizer = tokenizer
def get_tokenizer(self) -> TokenizerLike:
if self.tokenizer is None:
raise ValueError(
......
......@@ -10,7 +10,7 @@ def __getattr__(name: str):
warnings.warn(
"`vllm.v1.engine.processor.Processor` has been moved to "
"`vllm.v1.engine.input_processor.InputProcessor`. "
"The old name will be removed in v0.13.",
"The old name will be removed in v0.14.",
DeprecationWarning,
stacklevel=2,
)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment