Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
a3f8d5dd
Commit
a3f8d5dd
authored
Dec 17, 2025
by
zhuwenwen
Browse files
Merge tag 'v0.13.0rc2' into v0.13.0rc2-ori
parents
8d75f22e
f34eca5f
Changes
499
Show whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
1030 additions
and
289 deletions
+1030
-289
vllm/v1/attention/backends/cpu_attn.py
vllm/v1/attention/backends/cpu_attn.py
+19
-19
vllm/v1/attention/backends/flashinfer.py
vllm/v1/attention/backends/flashinfer.py
+3
-1
vllm/v1/attention/backends/gdn_attn.py
vllm/v1/attention/backends/gdn_attn.py
+1
-1
vllm/v1/attention/backends/mla/common.py
vllm/v1/attention/backends/mla/common.py
+33
-6
vllm/v1/attention/backends/mla/flashmla_sparse.py
vllm/v1/attention/backends/mla/flashmla_sparse.py
+569
-100
vllm/v1/attention/backends/mla/indexer.py
vllm/v1/attention/backends/mla/indexer.py
+12
-36
vllm/v1/attention/backends/triton_attn.py
vllm/v1/attention/backends/triton_attn.py
+83
-1
vllm/v1/attention/backends/utils.py
vllm/v1/attention/backends/utils.py
+27
-0
vllm/v1/core/block_pool.py
vllm/v1/core/block_pool.py
+19
-0
vllm/v1/core/encoder_cache_manager.py
vllm/v1/core/encoder_cache_manager.py
+87
-28
vllm/v1/core/kv_cache_manager.py
vllm/v1/core/kv_cache_manager.py
+8
-0
vllm/v1/core/kv_cache_utils.py
vllm/v1/core/kv_cache_utils.py
+7
-3
vllm/v1/core/sched/scheduler.py
vllm/v1/core/sched/scheduler.py
+122
-47
vllm/v1/core/sched/utils.py
vllm/v1/core/sched/utils.py
+2
-10
vllm/v1/engine/__init__.py
vllm/v1/engine/__init__.py
+6
-3
vllm/v1/engine/async_llm.py
vllm/v1/engine/async_llm.py
+3
-7
vllm/v1/engine/core.py
vllm/v1/engine/core.py
+3
-4
vllm/v1/engine/input_processor.py
vllm/v1/engine/input_processor.py
+22
-15
vllm/v1/engine/llm_engine.py
vllm/v1/engine/llm_engine.py
+3
-7
vllm/v1/engine/processor.py
vllm/v1/engine/processor.py
+1
-1
No files found.
vllm/v1/attention/backends/cpu_attn.py
View file @
a3f8d5dd
...
...
@@ -21,7 +21,7 @@ from vllm.v1.attention.backends.utils import (
CommonAttentionMetadata
,
split_decodes_and_prefills
,
)
from
vllm.v1.kv_cache_interface
import
AttentionSpec
from
vllm.v1.kv_cache_interface
import
AttentionSpec
,
CrossAttentionSpec
logger
=
init_logger
(
__name__
)
...
...
@@ -50,11 +50,13 @@ class CPUAttentionBackend(AttentionBackend):
@
classmethod
def
supports_attn_type
(
cls
,
attn_type
:
str
)
->
bool
:
"""CPU attention supports decoder and encoder-only attention."""
"""CPU attention supports decoder,
encoder-only and encoder-decoder attention."""
return
attn_type
in
(
AttentionType
.
DECODER
,
AttentionType
.
ENCODER
,
AttentionType
.
ENCODER_ONLY
,
AttentionType
.
ENCODER_DECODER
,
)
@
staticmethod
...
...
@@ -136,6 +138,7 @@ class CPUAttentionMetadataBuilder(AttentionMetadataBuilder[CPUAttentionMetadata]
self
.
window_size
=
-
1
self
.
block_size
=
vllm_config
.
cache_config
.
block_size
self
.
isa
=
_get_attn_isa
(
self
.
dtype
,
self
.
block_size
)
self
.
is_cross_attention
=
isinstance
(
kv_cache_spec
,
CrossAttentionSpec
)
def
build
(
self
,
...
...
@@ -151,7 +154,7 @@ class CPUAttentionMetadataBuilder(AttentionMetadataBuilder[CPUAttentionMetadata]
seq_lens
=
common_attn_metadata
.
seq_lens
block_table_tensor
=
common_attn_metadata
.
block_table_tensor
slot_mapping
=
common_attn_metadata
.
slot_mapping
causal
=
common_attn_metadata
.
causal
causal
=
False
if
self
.
is_cross_attention
else
common_attn_metadata
.
causal
sdpa_start_loc
=
query_start_loc
num_decode_tokens
=
0
...
...
@@ -171,9 +174,6 @@ class CPUAttentionMetadataBuilder(AttentionMetadataBuilder[CPUAttentionMetadata]
query_start_loc
=
query_start_loc
[:
num_decodes
+
1
]
block_table_tensor
=
block_table_tensor
[:
num_decodes
]
sheduler_metadata
=
None
if
causal
:
# for decode batch, use the custom kernel
sheduler_metadata
=
ops
.
cpu_attn_get_scheduler_metadata
(
num_reqs
=
num_reqs
,
num_heads
=
self
.
num_heads
,
...
...
vllm/v1/attention/backends/flashinfer.py
View file @
a3f8d5dd
...
...
@@ -429,6 +429,7 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]):
super
().
__init__
(
kv_cache_spec
,
layer_names
,
vllm_config
,
device
)
self
.
cache_config
=
vllm_config
.
cache_config
self
.
model_config
=
vllm_config
.
model_config
self
.
attention_config
=
vllm_config
.
attention_config
self
.
_workspace_buffer
=
None
self
.
_prefill_wrapper
:
(
BatchPrefillWithPagedKVCacheWrapper
|
BatchDCPPrefillWrapper
|
None
...
...
@@ -563,7 +564,7 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]):
)
self
.
paged_kv_last_page_len_np
=
self
.
paged_kv_last_page_len_cpu
.
numpy
()
if
self
.
head_dim
==
256
and
current_platform
.
is_device_capability
(
100
):
if
self
.
head_dim
==
256
and
current_platform
.
is_device_capability
_family
(
100
):
# https://github.com/flashinfer-ai/flashinfer/issues/1993 reports that
# head size 256 and block size 16 is not supported on blackwell.
assert
kv_cache_spec
.
block_size
!=
16
,
(
...
...
@@ -779,6 +780,7 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]):
self
.
cache_dtype
,
self
.
q_data_type
,
is_prefill
=
True
,
force_use_trtllm
=
self
.
attention_config
.
use_trtllm_attention
,
has_sinks
=
self
.
has_sinks
,
has_spec
=
uses_spec_reorder
,
)
...
...
vllm/v1/attention/backends/gdn_attn.py
View file @
a3f8d5dd
...
...
@@ -211,7 +211,7 @@ class GDNAttentionMetadataBuilder(AttentionMetadataBuilder[GDNAttentionMetadata]
spec_token_masks
=
torch
.
repeat_interleave
(
spec_sequence_masks
,
query_lens
)
index
=
torch
.
argsort
(
spec_token_masks
)
index
=
torch
.
argsort
(
spec_token_masks
,
stable
=
True
)
num_non_spec_tokens
=
num_prefill_tokens
+
num_decode_tokens
non_spec_token_indx
=
index
[:
num_non_spec_tokens
]
spec_token_indx
=
index
[
num_non_spec_tokens
:]
...
...
vllm/v1/attention/backends/mla/common.py
View file @
a3f8d5dd
...
...
@@ -446,7 +446,7 @@ def use_flashinfer_prefill() -> bool:
and
flashinfer_available
and
not
vllm_config
.
attention_config
.
use_cudnn_prefill
and
not
vllm_config
.
attention_config
.
use_trtllm_ragged_deepseek_prefill
and
current_platform
.
is_device_capability
(
100
)
and
current_platform
.
is_device_capability
_family
(
100
)
)
...
...
@@ -457,7 +457,7 @@ def use_cudnn_prefill() -> bool:
return
(
flashinfer_available
and
vllm_config
.
attention_config
.
use_cudnn_prefill
and
current_platform
.
is_device_capability
(
100
)
and
current_platform
.
is_device_capability
_family
(
100
)
and
has_nvidia_artifactory
()
)
...
...
@@ -470,7 +470,7 @@ def use_trtllm_ragged_deepseek_prefill() -> bool:
return
(
flashinfer_available
and
vllm_config
.
attention_config
.
use_trtllm_ragged_deepseek_prefill
and
current_platform
.
is_device_capability
(
100
)
and
current_platform
.
is_device_capability
_family
(
100
)
)
...
...
@@ -1787,6 +1787,33 @@ class MLACommonImpl(MLACommonBaseImpl[M], Generic[M]):
# Convert from (L, N, P) to (N, P, L)
self
.
W_UK_T
=
W_UK
.
permute
(
1
,
2
,
0
)
def
_concat_k_nope_k_pe
(
self
,
k_nope
:
torch
.
Tensor
,
k_pe
:
torch
.
Tensor
)
->
torch
.
Tensor
:
"""
Efficiently concatenate k_nope and k_pe tensors along the last dimension.
This function avoids the performance penalty of torch.cat with expanded
non-contiguous tensors by pre-allocating the output and using direct copies.
Args:
k_nope: Tensor of shape [..., nope_dim]
k_pe: Tensor to broadcast and concatenate, typically shape [..., 1, pe_dim]
or [..., pe_dim]
Returns:
Tensor of shape [..., nope_dim + pe_dim]
"""
k
=
torch
.
empty
(
(
*
k_nope
.
shape
[:
-
1
],
k_nope
.
shape
[
-
1
]
+
k_pe
.
shape
[
-
1
]),
dtype
=
k_nope
.
dtype
,
device
=
k_nope
.
device
,
)
# Direct copies with efficient broadcasting
k
[...,
:
k_nope
.
shape
[
-
1
]]
=
k_nope
k
[...,
k_nope
.
shape
[
-
1
]
:]
=
k_pe
return
k
def
_compute_prefill_context
(
self
,
q
:
torch
.
Tensor
,
...
...
@@ -1823,7 +1850,7 @@ class MLACommonImpl(MLACommonBaseImpl[M], Generic[M]):
)
k_nope
,
v
=
kv_nope
.
split
([
self
.
qk_nope_head_dim
,
self
.
v_head_dim
],
dim
=-
1
)
k
=
torch
.
cat
((
k_nope
,
k_pe
.
expand
((
*
k_nope
.
shape
[:
-
1
],
-
1
))),
dim
=-
1
)
k
=
self
.
_con
cat
_
k_nope
_
k_pe
(
k_nope
,
k_pe
)
attn_output
,
attn_softmax_lse
=
self
.
_run_prefill_context_chunk
(
prefill
=
prefill_metadata
,
...
...
@@ -1927,7 +1954,7 @@ class MLACommonImpl(MLACommonBaseImpl[M], Generic[M]):
-
1
,
self
.
num_heads
,
self
.
qk_nope_head_dim
+
self
.
v_head_dim
)
k_nope
,
v
=
kv_nope
.
split
([
self
.
qk_nope_head_dim
,
self
.
v_head_dim
],
dim
=-
1
)
k
=
torch
.
cat
((
k_nope
,
k_pe
.
expand
((
*
k_nope
.
shape
[:
-
1
],
-
1
))),
dim
=-
1
)
k
=
self
.
_con
cat
_
k_nope
_
k_pe
(
k_nope
,
k_pe
)
attn_output
,
attn_softmax_lse
=
self
.
_run_prefill_context_chunk
(
prefill
=
prefill_metadata
,
...
...
@@ -1976,7 +2003,7 @@ class MLACommonImpl(MLACommonBaseImpl[M], Generic[M]):
)
k_nope
,
v
=
kv_nope
.
split
([
self
.
qk_nope_head_dim
,
self
.
v_head_dim
],
dim
=-
1
)
k
=
torch
.
cat
((
k_nope
,
k_pe
.
expand
((
*
k_nope
.
shape
[:
-
1
],
-
1
))),
dim
=-
1
)
k
=
self
.
_con
cat
_
k_nope
_
k_pe
(
k_nope
,
k_pe
)
output_prefill
=
self
.
_run_prefill_new_tokens
(
prefill
=
attn_metadata
.
prefill
,
...
...
vllm/v1/attention/backends/mla/flashmla_sparse.py
View file @
a3f8d5dd
...
...
@@ -18,7 +18,7 @@ from vllm.attention.ops.flashmla import (
flash_mla_with_kvcache
,
get_mla_metadata
,
)
from
vllm.config
import
VllmConfig
from
vllm.config
import
VllmConfig
,
get_current_vllm_config
from
vllm.config.cache
import
CacheDType
from
vllm.logger
import
init_logger
from
vllm.platforms
import
current_platform
...
...
@@ -30,13 +30,31 @@ from vllm.v1.attention.backends.utils import (
AttentionCGSupport
,
AttentionMetadataBuilder
,
CommonAttentionMetadata
,
reshape_attn_output_for_spec_decode
,
reshape_query_for_spec_decode
,
split_decodes_and_prefills
,
split_prefill_chunks
,
)
from
vllm.v1.kv_cache_interface
import
AttentionSpec
from
vllm.v1.worker.workspace
import
current_workspace_manager
if
TYPE_CHECKING
:
from
vllm.model_executor.models.deepseek_v2
import
Indexer
logger
=
init_logger
(
__name__
)
# For FP8 sparse attention we have two implementations:
# 1. Mixed batch mode: use the FP8 decode kernel for both prefill and decode; this is
# done by treating all tokens as a single batch.
# 2. Separate prefill and decode mode: use the BF16 prefill kernel for prefill
# (upconverting the FP8 cache to BF16 then calling the prefill kernel) and using
# the FP8 decode kernel for decode.
# Currently we use #1 when the number of heads per rank is low (i.e. high TP) since the
# BF16 prefill kernel requires padding the number of heads to 128 while the decode does
# not, so when the per-rank head count is below MIN_HEADS_FOR_BF16_PREFILL we use the
# mixed batch mode (#1).
MIN_HEADS_FOR_BF16_PREFILL
=
32
"""
NOTE: FlashMLA Sparse uses an fp8 cache with the following format
...
...
@@ -127,19 +145,72 @@ class FlashMLASparseMetadata:
dummy_block_table
:
torch
.
Tensor
cache_lens
:
torch
.
Tensor
fp8_extra_metadata
:
FP8KernelMetadata
|
None
=
None
@
dataclass
class
FP8SeperatePrefillDecode
:
@
dataclass
class
Decode
:
kernel_metadata
:
"FlashMLASparseMetadata.FP8KernelMetadata"
decode_query_len
:
int
# needed for reshape in spec decode
@
dataclass
class
Prefill
:
# Sequence lengths (context + query) for prefill requests
# Shape: [num_prefill_reqs]
seq_lens
:
torch
.
Tensor
# Request ID for each token: -1 for decode tokens, request index
# (0, 1, 2, ...) for prefill tokens.
# Shape: [num_actual_tokens]
request_ids
:
torch
.
Tensor
# Workspace start offsets for all prefill requests
# Shape: [num_prefill_reqs], adjusted in-place per chunk to be
# 0-indexed within each chunk. Used to map prefill tokens to workspace
# offsets in convert_logical_index_to_physical_index
workspace_starts
:
torch
.
Tensor
@
dataclass
class
Chunk
:
"""Metadata for a chunk of prefill requests.
Prefill requests may be chunked to fit within the fixed workspace size.
"""
seq_lens
:
torch
.
Tensor
tokens_slice
:
slice
block_table
:
torch
.
Tensor
req_start_idx
:
int
workspace_starts
:
torch
.
Tensor
chunk_tot_seqlen
:
int
chunks
:
list
[
Chunk
]
num_prefills
:
int
=
0
num_decodes
:
int
=
0
num_prefill_tokens
:
int
=
0
num_decode_tokens
:
int
=
0
decode
:
Decode
|
None
=
None
prefill
:
Prefill
|
None
=
None
fp8_extra_metadata
:
FP8SeperatePrefillDecode
|
FP8KernelMetadata
|
None
=
None
fp8_use_mixed_batch
:
bool
=
False
# Kernel with prefill workspace support
@
triton
.
jit
def
_convert_req_index_to_global_index_kernel
(
req_id_ptr
,
# int32 [num_tokens]
block_table_ptr
,
# int32 [num_requests, max_num_blocks_per_req]
token_indices_ptr
,
# int32 [num_tokens, NUM_TOPK_TOKENS]
out_ptr
,
# int32 [num_tokens, NUM_TOPK_TOKENS]
prefill_request_id_ptr
,
# int32 [num_tokens], -1 for decode, >=0 for prefill
workspace_starts_ptr
,
# int32 [num_prefill_reqs+1] or nullptr
# shapes (compile-time where possible)
max_num_blocks_per_req
:
tl
.
constexpr
,
BLOCK_SIZE
:
tl
.
constexpr
,
BLOCK_N
:
tl
.
constexpr
,
# tile width along columns
HAS_PREFILL
:
tl
.
constexpr
,
# strides (in elements)
bt_stride0
,
bt_stride1
,
...
...
@@ -165,7 +236,10 @@ def _convert_req_index_to_global_index_kernel(
# Only token == -1 should propagate as -1
is_invalid_tok
=
tok
<
0
is_prefill
=
False
if
HAS_PREFILL
:
prefill_req_id
=
tl
.
load
(
prefill_request_id_ptr
+
token_id
)
is_prefill
=
prefill_req_id
>=
0
# Compute block id and in-block offset
block_id
=
tok
//
BLOCK_SIZE
inblock_off
=
tok
%
BLOCK_SIZE
...
...
@@ -173,12 +247,18 @@ def _convert_req_index_to_global_index_kernel(
# Guard block_table access
valid_block
=
(
block_id
<
max_num_blocks_per_req
)
&
(
block_id
>=
0
)
bt_ptr
=
block_table_ptr
+
req
*
bt_stride0
+
block_id
*
bt_stride1
base
=
tl
.
load
(
bt_ptr
,
mask
=
valid_block
,
other
=
0
)
# If token == -1 OR block_id OOB, output -1; else base * BLOCK_SIZE + offset
out_val
=
tl
.
where
(
is_invalid_tok
|
(
~
valid_block
),
-
1
,
base
*
BLOCK_SIZE
+
inblock_off
is_invalid_tok
|=
~
valid_block
base
=
tl
.
load
(
bt_ptr
,
mask
=
valid_block
&
~
is_prefill
,
other
=
0
)
out_val
=
base
*
BLOCK_SIZE
+
inblock_off
# Override with prefill output if prefill is enabled
if
HAS_PREFILL
:
workspace_start
=
tl
.
load
(
workspace_starts_ptr
+
prefill_req_id
,
mask
=
is_prefill
,
other
=
0
)
prefill_out
=
workspace_start
+
tok
out_val
=
tl
.
where
(
is_prefill
,
prefill_out
,
out_val
)
out_val
=
tl
.
where
(
is_invalid_tok
,
-
1
,
out_val
)
# Store results
out_ptr_ij
=
out_ptr
+
token_id
*
out_stride0
+
indice_id
*
out_stride1
...
...
@@ -192,6 +272,9 @@ def triton_convert_req_index_to_global_index(
BLOCK_SIZE
:
int
=
64
,
NUM_TOPK_TOKENS
:
int
=
2048
,
BLOCK_N
:
int
=
128
,
# tile width along columns
HAS_PREFILL_WORKSPACE
:
bool
=
False
,
prefill_workspace_request_ids
:
torch
.
Tensor
|
None
=
None
,
prefill_workspace_starts
:
torch
.
Tensor
|
None
=
None
,
):
"""
out[token_id, indice_id] =
...
...
@@ -202,17 +285,32 @@ def triton_convert_req_index_to_global_index(
Only when token_indices[token_id, indice_id] == -1 do we output -1.
For safety, we also output -1 if the derived block_id would be
out-of-bounds.
When HAS_PREFILL_WORKSPACE is True, prefill tokens are mapped to workspace offsets
instead of global cache slots. prefill_workspace_request_ids and
prefill_workspace_starts must be provided.
prefill_workspace_request_ids: int32 [num_tokens], -1 for decode else
prefill request index (maps to prefill_workspace_starts)
prefill_workspace_starts: int32 [num_prefills], 0-indexed workspace
starts for each prefill request
"""
assert
req_id
.
dtype
==
torch
.
int32
assert
block_table
.
dtype
==
torch
.
int32
assert
token_indices
.
dtype
==
torch
.
int32
assert
token_indices
.
shape
[
1
]
==
NUM_TOPK_TOKENS
assert
NUM_TOPK_TOKENS
%
BLOCK_N
==
0
,
(
f
"NUM_TOPK_TOKENS (
{
NUM_TOPK_TOKENS
}
) must be divisible byBLOCK_N (
{
BLOCK_N
}
)"
f
"NUM_TOPK_TOKENS (
{
NUM_TOPK_TOKENS
}
) must be divisible by
BLOCK_N (
{
BLOCK_N
}
)"
)
if
HAS_PREFILL_WORKSPACE
:
assert
prefill_workspace_request_ids
is
not
None
assert
prefill_workspace_starts
is
not
None
assert
prefill_workspace_request_ids
.
dtype
==
torch
.
int32
assert
prefill_workspace_starts
.
dtype
==
torch
.
int32
num_tokens
=
req_id
.
shape
[
0
]
num_requests
,
max_num_blocks_per_req
=
block_table
.
shape
max_num_blocks_per_req
=
block_table
.
shape
[
1
]
tiles_per_row
=
NUM_TOPK_TOKENS
//
BLOCK_N
# Ensure contiguous tensors on the same device
...
...
@@ -226,6 +324,13 @@ def triton_convert_req_index_to_global_index(
ti_stride0
,
ti_stride1
=
token_indices_c
.
stride
()
out_stride0
,
out_stride1
=
out
.
stride
()
# Prepare prefill pointers
if
HAS_PREFILL_WORKSPACE
:
assert
prefill_workspace_request_ids
is
not
None
# for mypy
assert
prefill_workspace_starts
is
not
None
# for mypy
assert
prefill_workspace_request_ids
.
is_contiguous
()
assert
prefill_workspace_starts
.
is_contiguous
()
# Exact 2D grid: tokens × column tiles
grid
=
(
num_tokens
,
tiles_per_row
)
...
...
@@ -234,10 +339,13 @@ def triton_convert_req_index_to_global_index(
block_table_c
,
token_indices_c
,
out
,
prefill_workspace_request_ids
,
prefill_workspace_starts
,
# shapes / constexprs
max_num_blocks_per_req
,
BLOCK_SIZE
,
BLOCK_N
,
HAS_PREFILL_WORKSPACE
,
# strides
bt_stride0
,
bt_stride1
,
...
...
@@ -249,7 +357,16 @@ def triton_convert_req_index_to_global_index(
return
out
@
dataclass
def
get_prefill_workspace_size
(
max_model_len
:
int
):
# NOTE(Lucas): 5 is a magic number for controlling the prefill buffer size.
# May be tuned later.
# Memory usage: 5 * max_model_len * 576 * 2 bytes
# Example: DeepSeek-V3.2 with max_model_len=163840 ->
# 5 * 163840 * 576 * 2 = ~900 MB
# This fits nicely below the typical MoE workspace size of >2GB so this is "free"
return
max_model_len
*
5
class
FlashMLASparseMetadataBuilder
(
AttentionMetadataBuilder
[
FlashMLASparseMetadata
]):
_cudagraph_support
:
ClassVar
[
AttentionCGSupport
]
=
AttentionCGSupport
.
UNIFORM_BATCH
...
...
@@ -259,29 +376,42 @@ class FlashMLASparseMetadataBuilder(AttentionMetadataBuilder[FlashMLASparseMetad
layer_names
:
list
[
str
],
vllm_config
:
VllmConfig
,
device
:
torch
.
device
,
):
)
->
None
:
self
.
vllm_config
=
vllm_config
self
.
layer_names
=
layer_names
cache_config
=
vllm_config
.
cache_config
self
.
kv_cache_spec
=
kv_cache_spec
self
.
model_config
=
vllm_config
.
model_config
parallel_config
=
vllm_config
.
parallel_config
self
.
device
=
device
# Treat requests with query length <= 1 as decodes to match the
# DeepGEMM indexer constraint (fp8_paged_mqa_logits only supports next_n <= 2)
self
.
_init_reorder_batch_threshold
(
1
,
supports_spec_as_decode
=
True
)
props
=
torch
.
cuda
.
get_device_properties
(
device
)
sm_count
=
props
.
multi_processor_count
self
.
num_heads
=
self
.
model_config
.
get_num_attention_heads
(
parallel_config
)
self
.
mla_dims
=
get_mla_dims
(
self
.
model_config
)
self
.
topk_tokens
=
vllm_config
.
model_config
.
hf_config
.
index_topk
self
.
use_fp8_kv_cache
=
cache_config
.
cache_dtype
==
"fp8_ds_mla"
self
.
topk_tokens_tensor
=
torch
.
tensor
(
[
self
.
topk_tokens
],
device
=
device
,
dtype
=
torch
.
int32
max_num_seqs
=
vllm_config
.
scheduler_config
.
max_num_seqs
# Shape: [max_num_seqs], all elements = topk_tokens (constant for full-CG)
self
.
topk_tokens_tensor
=
torch
.
full
(
(
max_num_seqs
,),
self
.
topk_tokens
,
device
=
device
,
dtype
=
torch
.
int32
)
self
.
max_model_len_tensor
=
torch
.
tensor
(
[
self
.
model_config
.
max_model_len
],
device
=
device
,
dtype
=
torch
.
int32
# Shape: [max_num_seqs], all elements = max_model_len
self
.
max_model_len_tensor
=
torch
.
full
(
(
max_num_seqs
,),
self
.
model_config
.
max_model_len
,
device
=
device
,
dtype
=
torch
.
int32
,
)
# this is ignored by `flash_mla_with_kvcache` if indices not None
self
.
dummy_block_table
=
torch
.
empty
(
(
1
,
1
),
dtype
=
torch
.
int32
,
device
=
self
.
device
(
max_num_seqs
,
1
),
dtype
=
torch
.
int32
,
device
=
self
.
device
)
# Equation taken from FlashMLA/csrc/pybind.cpp
...
...
@@ -290,7 +420,7 @@ class FlashMLASparseMetadataBuilder(AttentionMetadataBuilder[FlashMLASparseMetad
max_num_sm_parts
=
int
(
max
((
sm_count
//
2
)
/
h_k
//
(
cdiv
(
h_q
//
h_k
,
2
*
64
)
*
s_q
),
1
)
)
if
current_platform
.
is_device_capability
(
100
):
if
current_platform
.
is_device_capability
_family
(
100
):
max_num_sm_parts
*=
2
self
.
tile_scheduler_metadata_buffer
=
torch
.
empty
(
# TileSchedulerMetaDataSize = 8
...
...
@@ -299,10 +429,9 @@ class FlashMLASparseMetadataBuilder(AttentionMetadataBuilder[FlashMLASparseMetad
dtype
=
torch
.
int32
,
device
=
device
,
)
# Sized for per-request batching (num_decodes + 1)
self
.
num_splits_buffer
=
torch
.
empty
(
# We pack all the tokens into one batch for sparse attention.
# Otherwise, we can exceed the sm of `get_mla_metadata`.
(
2
,),
(
max_num_seqs
+
1
,),
dtype
=
torch
.
int32
,
device
=
device
,
)
...
...
@@ -312,30 +441,171 @@ class FlashMLASparseMetadataBuilder(AttentionMetadataBuilder[FlashMLASparseMetad
device
=
device
,
)
def
build
(
def
_
build
_fp8_mixed_decode_prefill
(
self
,
common_prefix_len
:
int
,
common_attn_metadata
:
CommonAttentionMetadata
,
fast_build
:
bool
=
False
,
)
->
FlashMLASparseMetadata
:
)
->
"FlashMLASparseMetadata.FP8KernelMetadata"
:
"""Build FP8 metadata treating all tokens as one mixed batch.
This matches main branch's approach and avoids the BF16 prefill kernel
which has head padding overhead when num_heads is small (high TP case).
"""
num_tokens
=
common_attn_metadata
.
num_actual_tokens
starts
=
np
.
asarray
(
common_attn_metadata
.
query_start_loc_cpu
,
dtype
=
np
.
int32
)
seg_lengths
=
np
.
diff
(
starts
)
req_id_per_token
=
np
.
repeat
(
np
.
arange
(
seg_lengths
.
shape
[
0
],
dtype
=
np
.
int32
),
seg_lengths
# Build metadata for all tokens as a single batch
tile_scheduler_metadata
,
num_splits
=
get_mla_metadata
(
cache_seqlens
=
self
.
topk_tokens_tensor
[:
1
],
# Single batch
num_q_tokens_per_head_k
=
num_tokens
*
self
.
num_heads
,
topk
=
self
.
topk_tokens
,
num_heads_q
=
self
.
num_heads
,
num_heads_k
=
1
,
is_fp8_kvcache
=
True
,
)
# Zero-fill for cudagraphs
self
.
req_id_per_token_buffer
.
fill_
(
0
)
self
.
req_id_per_token_buffer
[:
req_id_per_token
.
shape
[
0
]].
copy_
(
torch
.
from_numpy
(
req_id_per_token
),
non_blocking
=
True
num_sm_parts
=
tile_scheduler_metadata
.
size
(
0
)
tile_scheduler_metadata_buffer
=
self
.
tile_scheduler_metadata_buffer
[
:
num_sm_parts
]
tile_scheduler_metadata_buffer
.
copy_
(
tile_scheduler_metadata
)
num_splits_view
=
self
.
num_splits_buffer
[:
2
]
num_splits_view
.
copy_
(
num_splits
)
fp8_metadata
=
FlashMLASparseMetadata
.
FP8KernelMetadata
(
scheduler_metadata
=
tile_scheduler_metadata_buffer
,
num_splits
=
num_splits_view
,
cache_lens
=
self
.
max_model_len_tensor
[:
1
],
dummy_block_table
=
self
.
dummy_block_table
[:
1
],
)
req_id_per_token
=
self
.
req_id_per_token_buffer
[:
num_tokens
]
fp8_extra_metadata
=
None
if
self
.
use_fp8_kv_cache
:
return
fp8_metadata
def
_build_fp8_separate_prefill_decode
(
self
,
common_attn_metadata
:
CommonAttentionMetadata
,
)
->
"FlashMLASparseMetadata.FP8SeperatePrefillDecode"
:
num_tokens
=
common_attn_metadata
.
num_actual_tokens
(
num_decodes
,
num_prefills
,
num_decode_tokens
,
num_prefill_tokens
)
=
(
split_decodes_and_prefills
(
common_attn_metadata
,
decode_threshold
=
self
.
reorder_batch_threshold
or
1
,
require_uniform
=
True
,
)
)
FP8Meta
=
FlashMLASparseMetadata
.
FP8SeperatePrefillDecode
fp8_metadata
=
FP8Meta
(
num_decodes
=
num_decodes
,
num_prefills
=
num_prefills
,
num_decode_tokens
=
num_decode_tokens
,
num_prefill_tokens
=
num_prefill_tokens
,
)
# Extract prefill sequence lengths (context + query, not just query)
# Decode requests come first in the batch, prefill requests follow
prefill_seq_lens
=
None
prefill_request_id
=
None
prefill_workspace_starts
=
None
prefill_chunks
=
None
# For pure decode batches, prefill_request_id will be None
# For mixed batches, it will have -1 for decode and request_id for prefill
if
num_prefills
>
0
:
seq_lens_cpu
=
common_attn_metadata
.
seq_lens_cpu
seq_lens
=
common_attn_metadata
.
seq_lens
query_start_loc_cpu
=
common_attn_metadata
.
query_start_loc_cpu
prefill_seq_lens_cpu
=
seq_lens_cpu
[
num_decodes
:]
prefill_seq_lens
=
seq_lens
[
num_decodes
:]
# Build prefill_request_id: -1 for decode, request index for
# prefill. This enables a single
# convert_logical_index_to_physical_index call for all tokens
prefill_request_id
=
torch
.
full
(
(
num_tokens
,),
-
1
,
dtype
=
torch
.
int32
,
device
=
self
.
device
)
# Map prefill tokens to their request IDs (0, 1, 2, ...)
for
req_idx
in
range
(
num_prefills
):
# Get query token range for this prefill request
global_req_idx
=
num_decodes
+
req_idx
req_query_start
=
query_start_loc_cpu
[
global_req_idx
]
req_query_end
=
query_start_loc_cpu
[
global_req_idx
+
1
]
prefill_request_id
[
req_query_start
:
req_query_end
]
=
req_idx
# will be adjusted by chunk loop
prefill_workspace_starts_cpu
=
torch
.
zeros
(
num_prefills
,
dtype
=
torch
.
int32
,
pin_memory
=
True
)
prefill_workspace_starts_cpu
[
1
:]
=
torch
.
cumsum
(
prefill_seq_lens_cpu
[:
-
1
],
dim
=
0
)
# populated by non-blocking copy after prefill_workspace_starts_cpu is
# updated by each chunk
prefill_workspace_starts
=
torch
.
empty
(
num_prefills
,
dtype
=
torch
.
int32
,
device
=
self
.
device
)
# Chunk prefill requests to fit within workspace size
max_prefill_buffer_size
=
get_prefill_workspace_size
(
self
.
vllm_config
.
model_config
.
max_model_len
)
chunk_bounds
=
split_prefill_chunks
(
prefill_seq_lens_cpu
,
max_prefill_buffer_size
)
prefill_chunks
=
[]
for
chunk_start
,
chunk_end
in
chunk_bounds
:
# Adjust workspace_starts in-place per chunk to be
# 0-indexed within each chunk
# Example: seq_lens=[10,15,20,5], chunks=[[0,2],[2,4]]
# Initial: workspace_starts=[0,10,25,45]
# After: workspace_starts=[0,10,0,20]
# (chunk 0 starts at 0, chunk 1 starts at 0)
offset
=
prefill_workspace_starts_cpu
[
chunk_start
].
item
()
prefill_workspace_starts_cpu
[
chunk_start
:
chunk_end
]
-=
offset
chunk_seq_lens
=
prefill_seq_lens
[
chunk_start
:
chunk_end
]
chunk_tot_seqlen
=
prefill_seq_lens_cpu
[
chunk_start
:
chunk_end
].
sum
()
token_start
=
query_start_loc_cpu
[
num_decodes
+
chunk_start
].
item
()
token_end
=
query_start_loc_cpu
[
num_decodes
+
chunk_end
].
item
()
tokens_slice
=
slice
(
token_start
,
token_end
)
# Create chunk view of gpu tensor
chunk_workspace_starts
=
prefill_workspace_starts
[
chunk_start
:
chunk_end
]
chunk_block_table
=
common_attn_metadata
.
block_table_tensor
[
num_decodes
+
chunk_start
:
num_decodes
+
chunk_end
]
prefill_chunks
.
append
(
FP8Meta
.
Prefill
.
Chunk
(
seq_lens
=
chunk_seq_lens
,
tokens_slice
=
tokens_slice
,
block_table
=
chunk_block_table
,
req_start_idx
=
chunk_start
,
workspace_starts
=
chunk_workspace_starts
,
chunk_tot_seqlen
=
chunk_tot_seqlen
,
)
)
prefill_workspace_starts
.
copy_
(
prefill_workspace_starts_cpu
,
non_blocking
=
True
)
fp8_metadata
.
prefill
=
FP8Meta
.
Prefill
(
seq_lens
=
prefill_seq_lens
,
request_ids
=
prefill_request_id
,
workspace_starts
=
prefill_workspace_starts
,
chunks
=
prefill_chunks
,
)
if
num_decodes
>
0
:
# Compute decode_query_len for spec decode (uniform due to require_uniform)
query_start_loc_cpu
=
common_attn_metadata
.
query_start_loc_cpu
decode_query_len
=
(
query_start_loc_cpu
[
1
]
-
query_start_loc_cpu
[
0
]).
item
()
tile_scheduler_metadata
,
num_splits
=
get_mla_metadata
(
cache_seqlens
=
self
.
topk_tokens_tensor
,
num_q_tokens_per_head_k
=
num_tok
en
s
*
self
.
num_heads
,
cache_seqlens
=
self
.
topk_tokens_tensor
[:
num_decodes
]
,
num_q_tokens_per_head_k
=
decode_query_l
en
*
self
.
num_heads
,
topk
=
self
.
topk_tokens
,
num_heads_q
=
self
.
num_heads
,
num_heads_k
=
1
,
...
...
@@ -348,33 +618,70 @@ class FlashMLASparseMetadataBuilder(AttentionMetadataBuilder[FlashMLASparseMetad
:
num_sm_parts
]
tile_scheduler_metadata_buffer
.
copy_
(
tile_scheduler_metadata
)
self
.
num_splits_buffer
.
copy_
(
num_splits
)
# num_splits has size [num_decodes + 1]
num_splits_view
=
self
.
num_splits_buffer
[:
num_decodes
+
1
]
num_splits_view
.
copy_
(
num_splits
)
fp8_extra_metada
ta
=
FlashMLASparseMetadata
.
FP8KernelMetadata
(
kernel_me
ta
=
FlashMLASparseMetadata
.
FP8KernelMetadata
(
scheduler_metadata
=
tile_scheduler_metadata_buffer
,
num_splits
=
self
.
num_splits_buffer
,
# cache_lens and block_table are basically unused in sparse case
# but the decode kernel will treat -1 and indices >= cache_lens
# as invalid so we make sure cache_lens is large enough to not
# accidentally mark indices invalid, we will use -1 exclusively
# to mark invalid indices
cache_lens
=
self
.
max_model_len_tensor
,
dummy_block_table
=
self
.
dummy_block_table
,
num_splits
=
num_splits_view
,
dummy_block_table
=
self
.
dummy_block_table
[:
num_decodes
],
cache_lens
=
self
.
max_model_len_tensor
[:
num_decodes
],
)
fp8_metadata
.
decode
=
FP8Meta
.
Decode
(
kernel_metadata
=
kernel_meta
,
decode_query_len
=
decode_query_len
,
)
return
fp8_metadata
def
build
(
self
,
common_prefix_len
:
int
,
common_attn_metadata
:
CommonAttentionMetadata
,
fast_build
:
bool
=
False
,
)
->
FlashMLASparseMetadata
:
cm
=
common_attn_metadata
num_tokens
=
cm
.
num_actual_tokens
starts
=
np
.
asarray
(
cm
.
query_start_loc_cpu
,
dtype
=
np
.
int32
)
seg_lengths
=
np
.
diff
(
starts
)
req_id_per_token
=
np
.
repeat
(
np
.
arange
(
seg_lengths
.
shape
[
0
],
dtype
=
np
.
int32
),
seg_lengths
)
# Zero-fill for cudagraphs
self
.
req_id_per_token_buffer
.
fill_
(
0
)
self
.
req_id_per_token_buffer
[:
req_id_per_token
.
shape
[
0
]].
copy_
(
torch
.
from_numpy
(
req_id_per_token
),
non_blocking
=
True
)
req_id_per_token
=
self
.
req_id_per_token_buffer
[:
num_tokens
]
fp8_extra_metadata
:
(
FlashMLASparseMetadata
.
FP8SeperatePrefillDecode
|
FlashMLASparseMetadata
.
FP8KernelMetadata
|
None
)
=
None
fp8_use_mixed_batch
=
self
.
num_heads
<
MIN_HEADS_FOR_BF16_PREFILL
if
self
.
use_fp8_kv_cache
:
if
fp8_use_mixed_batch
:
fp8_extra_metadata
=
self
.
_build_fp8_mixed_decode_prefill
(
cm
)
else
:
fp8_extra_metadata
=
self
.
_build_fp8_separate_prefill_decode
(
cm
)
metadata
=
FlashMLASparseMetadata
(
num_reqs
=
c
ommon_attn_metadata
.
num_reqs
,
max_query_len
=
c
ommon_attn_metadata
.
max_query_len
,
max_seq_len
=
c
ommon_attn_metadata
.
max_seq_len
,
num_actual_tokens
=
c
ommon_attn_metadata
.
num_actual_tokens
,
query_start_loc
=
c
ommon_attn_metadata
.
query_start_loc
,
slot_mapping
=
c
ommon_attn_metadata
.
slot_mapping
,
block_table
=
c
ommon_attn_metadata
.
block_table_tensor
,
num_reqs
=
c
m
.
num_reqs
,
max_query_len
=
c
m
.
max_query_len
,
max_seq_len
=
c
m
.
max_seq_len
,
num_actual_tokens
=
c
m
.
num_actual_tokens
,
query_start_loc
=
c
m
.
query_start_loc
,
slot_mapping
=
c
m
.
slot_mapping
,
block_table
=
c
m
.
block_table_tensor
,
req_id_per_token
=
req_id_per_token
,
block_size
=
self
.
kv_cache_spec
.
block_size
,
topk_tokens
=
self
.
topk_tokens
,
fp8_extra_metadata
=
fp8_extra_metadata
,
fp8_use_mixed_batch
=
fp8_use_mixed_batch
,
)
return
metadata
...
...
@@ -412,7 +719,21 @@ class FlashMLASparseImpl(MLACommonBaseImpl[FlashMLASparseMetadata]):
self
.
softmax_scale
=
scale
assert
indexer
is
not
None
self
.
topk_indices_buffer
=
indexer
.
topk_indices_buffer
self
.
padding
=
128
if
current_platform
.
is_device_capability
(
100
)
else
64
self
.
padding
=
128
if
current_platform
.
is_device_capability_family
(
100
)
else
64
if
kv_cache_dtype
==
"fp8_ds_mla"
:
# Reserve workspace during initialization
vllm_config
=
get_current_vllm_config
()
assert
vllm_config
is
not
None
and
vllm_config
.
model_config
is
not
None
prefill_workspace_size
=
get_prefill_workspace_size
(
vllm_config
.
model_config
.
max_model_len
)
self
.
prefill_workspace_shape
=
(
prefill_workspace_size
,
head_size
)
(
self
.
prefill_bf16_workspace
,)
=
(
current_workspace_manager
().
get_simultaneous
(
(
self
.
prefill_workspace_shape
,
torch
.
bfloat16
)
)
)
def
_forward_bf16_kv
(
self
,
...
...
@@ -420,6 +741,184 @@ class FlashMLASparseImpl(MLACommonBaseImpl[FlashMLASparseMetadata]):
kv_c_and_k_pe_cache
:
torch
.
Tensor
,
topk_indices
:
torch
.
Tensor
,
attn_metadata
:
FlashMLASparseMetadata
,
)
->
torch
.
Tensor
:
# Convert per-request indices to global slots (decode) or workspace
# offsets (prefill).
topk_indices
=
triton_convert_req_index_to_global_index
(
attn_metadata
.
req_id_per_token
,
attn_metadata
.
block_table
,
topk_indices
,
BLOCK_SIZE
=
attn_metadata
.
block_size
,
NUM_TOPK_TOKENS
=
topk_indices
.
shape
[
1
],
)
return
self
.
_bf16_flash_mla_kernel
(
q
,
kv_c_and_k_pe_cache
,
topk_indices
)
def _forward_fp8_kv_separate_prefill_decode(
    self,
    q: torch.Tensor,
    kv_c_and_k_pe_cache: torch.Tensor,
    topk_indices: torch.Tensor,
    attn_metadata: FlashMLASparseMetadata,
) -> torch.Tensor:
    """FP8 KV-cache forward path that handles decode and prefill separately.

    The first ``num_decode_tokens`` rows of ``q`` / ``topk_indices`` are run
    through the FP8 FlashMLA kernel. Prefill tokens are processed chunk by
    chunk: each chunk's KV entries are gathered from the FP8 cache and
    upconverted into ``self.prefill_bf16_workspace``, then attended to with
    the BF16 kernel.

    Args:
        q: Query tensor; rows are tokens, decode tokens first.
        kv_c_and_k_pe_cache: The FP8 KV cache.
        topk_indices: Per-request top-k indices, one row per token.
        attn_metadata: Metadata whose ``fp8_extra_metadata`` must be a
            ``FlashMLASparseMetadata.FP8SeperatePrefillDecode``.

    Returns:
        The attention output. For a pure-decode batch this is the kernel
        output reshaped back to one row per token; otherwise a freshly
        allocated ``(num_actual_tokens, num_heads, kv_lora_rank)`` tensor.
    """
    # Fetched and type-checked once; the nested helper below closes over it.
    fp8_metadata = attn_metadata.fp8_extra_metadata
    assert isinstance(fp8_metadata, FlashMLASparseMetadata.FP8SeperatePrefillDecode)
    num_decodes = fp8_metadata.num_decodes

    prefill_request_ids = None
    prefill_workspace_starts = None
    has_prefill_workspace = False
    if fp8_metadata.prefill is not None:
        prefill_request_ids = fp8_metadata.prefill.request_ids
        prefill_workspace_starts = fp8_metadata.prefill.workspace_starts
        has_prefill_workspace = True

    # Convert per-request indices to global slots (decode) or workspace
    # offsets (prefill).
    # For FP8 cache: prefill uses workspace mapping (upconverted to BF16)
    # For BF16 cache: always use global cache slots (no workspace)
    # prefill_workspace_starts has been adjusted in-place per chunk so
    # prefill indices automatically come out chunk-local
    topk_indices = triton_convert_req_index_to_global_index(
        attn_metadata.req_id_per_token,
        attn_metadata.block_table,
        topk_indices,
        BLOCK_SIZE=attn_metadata.block_size,
        NUM_TOPK_TOKENS=topk_indices.shape[1],
        HAS_PREFILL_WORKSPACE=has_prefill_workspace,
        prefill_workspace_request_ids=prefill_request_ids,
        prefill_workspace_starts=prefill_workspace_starts,
    )

    def _fp8_decode(q: torch.Tensor, topk_indices: torch.Tensor) -> torch.Tensor:
        """Run the FP8 kernel on decode tokens and flatten the result."""
        # Reshape q: (num_decode_tokens, num_heads, head_dim)
        #         -> (num_decodes, seq_len, num_heads, head_dim)
        q = reshape_query_for_spec_decode(q, num_decodes)
        seq_len = q.shape[1]
        # Reshape topk_indices: (num_decode_tokens, topk)
        #                    -> (num_decodes, seq_len, topk)
        topk_indices = topk_indices.view(num_decodes, seq_len, -1)
        assert fp8_metadata.decode is not None
        attn_out, _ = self._fp8_flash_mla_kernel(
            q=q,
            kv_c_and_k_pe_cache=kv_c_and_k_pe_cache,
            topk_indices=topk_indices,
            kernel_metadata=fp8_metadata.decode.kernel_metadata,
        )
        # Reshape output: (num_decodes, seq_len, num_heads, head_dim_v)
        #              -> (num_decode_tokens, num_heads, head_dim_v)
        return reshape_attn_output_for_spec_decode(attn_out)

    num_decode_tokens = fp8_metadata.num_decode_tokens
    num_prefill_tokens = fp8_metadata.num_prefill_tokens

    # Pure decode: direct call without allocation
    if num_decode_tokens > 0 and num_prefill_tokens == 0:
        assert fp8_metadata.decode is not None
        attn_out = _fp8_decode(q, topk_indices)
    else:
        # Mixed or pure prefill: allocate output tensor
        attn_out = q.new_empty(
            (attn_metadata.num_actual_tokens, self.num_heads, self.kv_lora_rank),
            dtype=q.dtype,
            device=q.device,
        )

        if num_decode_tokens > 0:
            attn_out[:num_decode_tokens] = _fp8_decode(
                q[:num_decode_tokens], topk_indices[:num_decode_tokens]
            )

        assert fp8_metadata.prefill is not None
        for chunk in fp8_metadata.prefill.chunks:
            # The same workspace buffer is reused for every chunk; only the
            # leading chunk_tot_seqlen rows are used for this chunk.
            chunk_workspace = self.prefill_bf16_workspace[: chunk.chunk_tot_seqlen]
            ops.cp_gather_and_upconvert_fp8_kv_cache(
                kv_c_and_k_pe_cache,
                chunk_workspace,
                chunk.block_table,
                chunk.seq_lens,
                chunk.workspace_starts,
                len(chunk.block_table),
            )

            chunk_q = q[chunk.tokens_slice]
            chunk_topk_indices_workspace = topk_indices[chunk.tokens_slice]

            attn_out[chunk.tokens_slice] = self._bf16_flash_mla_kernel(
                chunk_q,
                chunk_workspace,
                chunk_topk_indices_workspace,
            )

    return attn_out
def _forward_fp8_kv_mixed_batch(
    self,
    q: torch.Tensor,
    kv_c_and_k_pe_cache: torch.Tensor,
    topk_indices: torch.Tensor,
    attn_metadata: FlashMLASparseMetadata,
) -> torch.Tensor:
    """FP8 forward path that runs every token through one kernel call.

    All tokens are treated as a single batch, matching the main branch's
    approach. This sidesteps the BF16 prefill kernel, whose head padding
    is costly when num_heads is small. Selected when use_mixed_batch is
    True.
    """
    # Map per-request indices to global slots (decode) or workspace
    # offsets (prefill).
    global_indices = triton_convert_req_index_to_global_index(
        attn_metadata.req_id_per_token,
        attn_metadata.block_table,
        topk_indices,
        BLOCK_SIZE=attn_metadata.block_size,
        NUM_TOPK_TOKENS=topk_indices.shape[1],
    )

    kernel_metadata = attn_metadata.fp8_extra_metadata
    assert kernel_metadata is not None
    assert isinstance(kernel_metadata, FlashMLASparseMetadata.FP8KernelMetadata)

    # The kernel expects a leading batch dimension:
    #   q:       (T, H, D)  -> (1, T, H, D)
    #   indices: (T, topk)  -> (1, T, topk)
    batched_q = q.unsqueeze(0)
    batched_indices = global_indices.unsqueeze(0)
    batched_out, _ = self._fp8_flash_mla_kernel(
        q=batched_q,
        kv_c_and_k_pe_cache=kv_c_and_k_pe_cache,
        topk_indices=batched_indices,
        kernel_metadata=kernel_metadata,
    )

    # Drop the batch dimension again: (1, T, H, D_v) -> (T, H, D_v)
    return batched_out.squeeze(0)
def _fp8_flash_mla_kernel(
    self,
    q: torch.Tensor,
    kv_c_and_k_pe_cache: torch.Tensor,
    topk_indices: torch.Tensor,
    kernel_metadata: FlashMLASparseMetadata.FP8KernelMetadata,
) -> tuple[torch.Tensor, torch.Tensor]:
    """Invoke ``flash_mla_with_kvcache`` on an FP8 KV cache.

    Args:
        q: Query tensor, already carrying the leading batch dimension the
            kernel expects (callers unsqueeze or reshape before calling).
        kv_c_and_k_pe_cache: KV cache; reinterpreted as ``uint8`` with an
            extra axis inserted at position -2 before being passed on.
        topk_indices: Indices forwarded unchanged as the kernel's
            ``indices`` argument.
        kernel_metadata: Supplies the dummy block table, cache lengths,
            tile-scheduler metadata and split count for the kernel.

    Returns:
        The kernel's result unchanged. Callers unpack it as
        ``(attn_out, _)``, so it is a 2-tuple whose first element is the
        attention output.
        NOTE(review): the second element is presumably the softmax LSE;
        confirm against the FlashMLA interface.
    """
    return flash_mla_with_kvcache(
        q=q,
        k_cache=kv_c_and_k_pe_cache.view(torch.uint8).unsqueeze(-2),
        block_table=kernel_metadata.dummy_block_table,
        head_dim_v=512,
        cache_seqlens=kernel_metadata.cache_lens,
        tile_scheduler_metadata=kernel_metadata.scheduler_metadata,
        num_splits=kernel_metadata.num_splits,
        is_fp8_kvcache=True,
        indices=topk_indices,
        softmax_scale=self.softmax_scale,
    )
def
_bf16_flash_mla_kernel
(
self
,
q
:
torch
.
Tensor
,
kv_c_and_k_pe_cache
:
torch
.
Tensor
,
topk_indices
:
torch
.
Tensor
,
)
->
torch
.
Tensor
:
num_tokens
=
q
.
shape
[
0
]
kv_c_and_k_pe_cache
=
kv_c_and_k_pe_cache
.
view
(
...
...
@@ -445,31 +944,6 @@ class FlashMLASparseImpl(MLACommonBaseImpl[FlashMLASparseMetadata]):
output
=
output
[:,
:
self
.
num_heads
,
:]
return
output
def
_forward_fp8_kv
(
self
,
q
:
torch
.
Tensor
,
kv_c_and_k_pe_cache
:
torch
.
Tensor
,
topk_indices
:
torch
.
Tensor
,
attn_metadata
:
FlashMLASparseMetadata
,
)
->
torch
.
Tensor
:
assert
attn_metadata
.
fp8_extra_metadata
is
not
None
extra_metadata
=
attn_metadata
.
fp8_extra_metadata
_attn_out
,
_
=
flash_mla_with_kvcache
(
q
=
q
.
unsqueeze
(
0
),
# unsqueeze to add batch_dim
k_cache
=
kv_c_and_k_pe_cache
.
view
(
torch
.
uint8
).
unsqueeze
(
-
2
),
block_table
=
extra_metadata
.
dummy_block_table
,
head_dim_v
=
512
,
cache_seqlens
=
extra_metadata
.
cache_lens
,
tile_scheduler_metadata
=
extra_metadata
.
scheduler_metadata
,
num_splits
=
extra_metadata
.
num_splits
,
is_fp8_kvcache
=
True
,
indices
=
topk_indices
.
unsqueeze
(
0
),
# unsqueeze to add batch_dim
softmax_scale
=
self
.
softmax_scale
,
)
return
_attn_out
def
forward
(
self
,
layer
:
AttentionLayer
,
...
...
@@ -477,7 +951,7 @@ class FlashMLASparseImpl(MLACommonBaseImpl[FlashMLASparseMetadata]):
k_c_normed
:
torch
.
Tensor
,
# key in unified attn
k_pe
:
torch
.
Tensor
,
# value in unified attn
kv_cache
:
torch
.
Tensor
,
attn_metadata
:
FlashMLASparseMetadata
,
attn_metadata
:
FlashMLASparseMetadata
|
None
,
output
:
torch
.
Tensor
|
None
=
None
,
output_scale
:
torch
.
Tensor
|
None
=
None
,
output_block_scale
:
torch
.
Tensor
|
None
=
None
,
...
...
@@ -493,6 +967,7 @@ class FlashMLASparseImpl(MLACommonBaseImpl[FlashMLASparseMetadata]):
)
if
attn_metadata
is
None
:
# Dummy run - no need to allocate buffers
# The zero fill is required when used with DP + EP
# to ensure all ranks within a DP group compute the
# same expert outputs.
...
...
@@ -505,6 +980,7 @@ class FlashMLASparseImpl(MLACommonBaseImpl[FlashMLASparseMetadata]):
q
=
q
[:
num_actual_toks
,
...]
k_c_normed
=
k_c_normed
[:
num_actual_toks
,
...]
k_pe
=
k_pe
[:
num_actual_toks
,
...]
topk_indices
=
self
.
topk_indices_buffer
[:
num_actual_toks
]
q_nope
,
q_pe
=
q
.
split
([
self
.
qk_nope_head_dim
,
self
.
qk_rope_head_dim
],
dim
=-
1
)
# Convert from (B, N, P) to (N, B, P)
...
...
@@ -514,16 +990,7 @@ class FlashMLASparseImpl(MLACommonBaseImpl[FlashMLASparseMetadata]):
# Convert from (N, B, L) to (B, N, L)
ql_nope
=
ql_nope
.
transpose
(
0
,
1
)
topk_indices
=
self
.
topk_indices_buffer
[:
num_actual_toks
]
# TODO: handle index / kv_cache correctly
topk_indices_global
=
triton_convert_req_index_to_global_index
(
attn_metadata
.
req_id_per_token
,
attn_metadata
.
block_table
,
topk_indices
,
BLOCK_SIZE
=
attn_metadata
.
block_size
,
NUM_TOPK_TOKENS
=
attn_metadata
.
topk_tokens
,
)
use_fp8_cache
=
self
.
kv_cache_dtype
==
"fp8_ds_mla"
q
=
torch
.
cat
([
ql_nope
,
q_pe
],
dim
=-
1
)
...
...
@@ -538,13 +1005,15 @@ class FlashMLASparseImpl(MLACommonBaseImpl[FlashMLASparseMetadata]):
scale
=
layer
.
_k_scale
,
)
if
self
.
kv_cache_dtype
!=
"fp8_ds_mla"
:
attn_out
=
self
.
_forward_bf16_kv
(
q
,
kv_cache
,
topk_indices_global
,
attn_metadata
if
not
use_fp8_cache
:
attn_out
=
self
.
_forward_bf16_kv
(
q
,
kv_cache
,
topk_indices
,
attn_metadata
)
elif
attn_metadata
.
fp8_use_mixed_batch
:
attn_out
=
self
.
_forward_fp8_kv_mixed_batch
(
q
,
kv_cache
,
topk_indices
,
attn_metadata
)
else
:
attn_out
=
self
.
_forward_fp8_kv
(
q
,
kv_cache
,
topk_indices
_global
,
attn_metadata
attn_out
=
self
.
_forward_fp8_kv
_separate_prefill_decode
(
q
,
kv_cache
,
topk_indices
,
attn_metadata
)
self
.
_v_up_proj
(
attn_out
,
out
=
output
[:
num_actual_toks
])
...
...
vllm/v1/attention/backends/mla/indexer.py
View file @
a3f8d5dd
...
...
@@ -18,6 +18,7 @@ from vllm.v1.attention.backends.utils import (
AttentionMetadataBuilder
,
CommonAttentionMetadata
,
split_decodes_and_prefills
,
split_prefill_chunks
,
)
logger
=
init_logger
(
__name__
)
...
...
@@ -176,40 +177,15 @@ def kv_spans_from_batches(
def
get_max_prefill_buffer_size
(
vllm_config
:
VllmConfig
):
max_model_len
=
vllm_config
.
model_config
.
max_model_len
# NOTE(Chen): 2 is a magic number for controlling the prefill buffer size.
# May be tuned later.
return
max_model_len
*
2
def
split_prefill_chunks
(
seq_lens_cpu
:
torch
.
Tensor
,
max_prefill_buffer_size
:
int
,
reqs_start
:
int
)
->
list
[
tuple
[
int
,
int
]]:
"""
Split the prefill chunks into a list of tuples of (reqs_start, reqs_end)
such that the total sequence length of each chunk is less than the
maximum prefill buffer size.
Args:
seq_lens_cpu: The sequence lengths of the prefill requests.
max_prefill_buffer_size: The maximum prefill buffer size.
reqs_start: The start index of the prefill requests.
Returns:
A list of tuples of (reqs_start, reqs_end).
"""
chunk_seq_ids
=
[]
total_seq_lens
=
0
for
i
in
range
(
reqs_start
,
len
(
seq_lens_cpu
)):
cur_seq_len
=
seq_lens_cpu
[
i
].
item
()
assert
cur_seq_len
<=
max_prefill_buffer_size
total_seq_lens
+=
cur_seq_len
if
total_seq_lens
>
max_prefill_buffer_size
:
chunk_seq_ids
.
append
((
reqs_start
,
i
))
reqs_start
=
i
total_seq_lens
=
cur_seq_len
if
total_seq_lens
>
0
:
chunk_seq_ids
.
append
((
reqs_start
,
len
(
seq_lens_cpu
)))
return
chunk_seq_ids
# NOTE(Chen): 40 is a magic number for controlling the prefill buffer size.
# Each entry is 128 fp8 bytes and 4 scale bytes for a total of 132 bytes.
# The flashmla_sparse backend uses a workspace size of 5 * max_model_len.
# The memory usage of the workspace there is 576 * 2 bytes; so we size this as
# (576 * 2 // 132) * 5 = 40 to maximize this workspace size while still fitting
# within the flashmla_sparse workspace.
# For DeepSeek-V3.2, the max_model_len is 163840.
# 40 * 163840 * 132 = 865075200 bytes = 825 MB
return
max_model_len
*
40
class
DeepseekV32IndexerMetadataBuilder
(
AttentionMetadataBuilder
):
...
...
@@ -302,9 +278,9 @@ class DeepseekV32IndexerMetadataBuilder(AttentionMetadataBuilder):
prefill_metadata
=
None
if
num_prefills
>
0
:
chunk_seq_ids
=
split_prefill_chunks
(
common_attn_metadata
.
seq_lens_cpu
,
common_attn_metadata
.
seq_lens_cpu
[
num_decodes
:]
,
self
.
max_prefill_buffer_size
,
num_decodes
,
request_offset
=
num_decodes
,
)
chunks
=
[
self
.
build_one_prefill_chunk
(
...
...
vllm/v1/attention/backends/triton_attn.py
View file @
a3f8d5dd
...
...
@@ -17,7 +17,7 @@ from vllm.attention.ops.triton_reshape_and_cache_flash import (
triton_reshape_and_cache_flash
,
)
from
vllm.attention.ops.triton_unified_attention
import
unified_attention
from
vllm.config
import
VllmConfig
from
vllm.config
import
CUDAGraphMode
,
VllmConfig
from
vllm.config.cache
import
CacheDType
from
vllm.logger
import
init_logger
from
vllm.model_executor.layers.quantization.utils.quant_utils
import
(
...
...
@@ -26,6 +26,7 @@ from vllm.model_executor.layers.quantization.utils.quant_utils import (
)
from
vllm.platforms
import
current_platform
from
vllm.platforms.interface
import
DeviceCapability
from
vllm.utils.math_utils
import
next_power_of_2
from
vllm.v1.attention.backends.utils
import
(
AttentionCGSupport
,
AttentionMetadataBuilder
,
...
...
@@ -36,6 +37,11 @@ from vllm.v1.kv_cache_interface import AttentionSpec
logger
=
init_logger
(
__name__
)
# constants
MIN_LAUNCH_GRID_SIZE_2D
=
128
# Minimum launch grid size of 2D kernel
NUM_PAR_SOFTMAX_SEGMENTS
=
16
# Number of parallel tiled softmax segments
@
dataclass
class
TritonAttentionMetadata
:
# NOTE(sang): Definition of context_len, query_len, and seq_len.
...
...
@@ -54,6 +60,12 @@ class TritonAttentionMetadata:
block_table
:
torch
.
Tensor
slot_mapping
:
torch
.
Tensor
seq_threshold_3D
:
int
num_par_softmax_segments
:
int
softmax_segm_output
:
torch
.
Tensor
softmax_segm_max
:
torch
.
Tensor
softmax_segm_expsum
:
torch
.
Tensor
# For cascade attention.
use_cascade
:
bool
common_prefix_len
:
int
...
...
@@ -87,6 +99,60 @@ class TritonAttentionMetadataBuilder(AttentionMetadataBuilder[TritonAttentionMet
self
.
num_heads_kv
=
model_config
.
get_num_kv_heads
(
vllm_config
.
parallel_config
)
self
.
headdim
=
model_config
.
get_head_size
()
# Check if CUDA Graphs are enabled for decode
self
.
decode_cudagraph_enabled
=
(
self
.
vllm_config
.
compilation_config
.
cudagraph_mode
in
(
CUDAGraphMode
.
FULL_AND_PIECEWISE
,
CUDAGraphMode
.
FULL_DECODE_ONLY
,
CUDAGraphMode
.
FULL
,
)
)
# The launch grid for the 2D kernel is defined as (num_q_blocks, num_heads_kv).
# A lower bound for num_q_blocks is the number of sequences.
# To ensure the minimum launch grid size is achieved, the number of sequences
# must be at least equal to the threshold below.
# If this threshold is not reached (i.e., the batch size is not large enough),
# the 3D kernel will be selected instead.
self
.
seq_threshold_3D
=
MIN_LAUNCH_GRID_SIZE_2D
//
self
.
num_heads_kv
# Modify the threshold if needed.
if
self
.
decode_cudagraph_enabled
:
capture_sizes
=
self
.
vllm_config
.
compilation_config
.
cudagraph_capture_sizes
assert
capture_sizes
,
"CUDA Graphs enabled but no capture sizes specified."
# Select the CUDA Graph capture size closest to self.seq_threshold_3D
# as threshold. This ensures that each captured graph covers the
# correct execution path.
self
.
seq_threshold_3D
=
min
(
capture_sizes
,
key
=
lambda
x
:
abs
(
x
-
self
.
seq_threshold_3D
),
)
self
.
num_par_softmax_segments
=
NUM_PAR_SOFTMAX_SEGMENTS
headdim_padded
=
next_power_of_2
(
self
.
headdim
)
self
.
softmax_segm_output
=
torch
.
empty
(
(
self
.
seq_threshold_3D
,
self
.
num_heads_q
,
self
.
num_par_softmax_segments
,
headdim_padded
,
),
dtype
=
torch
.
float32
,
device
=
device
,
)
self
.
softmax_segm_max
=
torch
.
empty
(
(
self
.
seq_threshold_3D
,
self
.
num_heads_q
,
self
.
num_par_softmax_segments
),
dtype
=
torch
.
float32
,
device
=
device
,
)
self
.
softmax_segm_expsum
=
torch
.
empty
(
(
self
.
seq_threshold_3D
,
self
.
num_heads_q
,
self
.
num_par_softmax_segments
),
dtype
=
torch
.
float32
,
device
=
device
,
)
def
build_for_cudagraph_capture
(
self
,
common_attn_metadata
:
CommonAttentionMetadata
)
->
TritonAttentionMetadata
:
...
...
@@ -143,6 +209,11 @@ class TritonAttentionMetadataBuilder(AttentionMetadataBuilder[TritonAttentionMet
prefix_kv_lens
=
prefix_kv_lens
,
suffix_kv_lens
=
suffix_kv_lens
,
prefix_scheduler_metadata
=
prefix_scheduler_metadata
,
seq_threshold_3D
=
self
.
seq_threshold_3D
,
num_par_softmax_segments
=
self
.
num_par_softmax_segments
,
softmax_segm_output
=
self
.
softmax_segm_output
,
softmax_segm_max
=
self
.
softmax_segm_max
,
softmax_segm_expsum
=
self
.
softmax_segm_expsum
,
)
return
attn_metadata
...
...
@@ -349,6 +420,12 @@ class TritonAttentionImpl(AttentionImpl):
max_seqlen_k
=
attn_metadata
.
max_seq_len
block_table
=
attn_metadata
.
block_table
seq_threshold_3D
=
attn_metadata
.
seq_threshold_3D
num_par_softmax_segments
=
attn_metadata
.
num_par_softmax_segments
softmax_segm_output
=
attn_metadata
.
softmax_segm_output
softmax_segm_max
=
attn_metadata
.
softmax_segm_max
softmax_segm_expsum
=
attn_metadata
.
softmax_segm_expsum
descale_shape
=
(
cu_seqlens_q
.
shape
[
0
]
-
1
,
key_cache
.
shape
[
2
])
unified_attention
(
...
...
@@ -369,6 +446,11 @@ class TritonAttentionImpl(AttentionImpl):
q_descale
=
None
,
# Not supported
k_descale
=
layer
.
_k_scale
.
expand
(
descale_shape
),
v_descale
=
layer
.
_v_scale
.
expand
(
descale_shape
),
seq_threshold_3D
=
seq_threshold_3D
,
num_par_softmax_segments
=
num_par_softmax_segments
,
softmax_segm_output
=
softmax_segm_output
,
softmax_segm_max
=
softmax_segm_max
,
softmax_segm_expsum
=
softmax_segm_expsum
,
sinks
=
self
.
sinks
,
output_scale
=
output_scale
,
)
...
...
vllm/v1/attention/backends/utils.py
View file @
a3f8d5dd
...
...
@@ -937,6 +937,33 @@ def split_decodes_and_prefills(
return
(
num_decodes
,
num_prefills
,
num_decode_tokens
,
num_prefill_tokens
)
def split_prefill_chunks(
    seq_lens_cpu: torch.Tensor, workspace_size: int, request_offset: int = 0
) -> list[tuple[int, int]]:
    """
    Split the prefill requests into chunks such that the total sequence length
    of each chunk is less than or equal to the workspace size.

    Requests are packed greedily in order: a chunk is closed as soon as adding
    the next request would exceed ``workspace_size``.

    Args:
        seq_lens_cpu: The sequence lengths of the prefill requests on CPU.
        workspace_size: The maximum workspace size (in tokens) per chunk.
        request_offset: The offset to add to the request indices.

    Returns:
        A list of tuples of (reqs_start, reqs_end) representing chunk boundaries.
        Empty when ``seq_lens_cpu`` is empty.

    Raises:
        AssertionError: If any single sequence is longer than ``workspace_size``.
    """
    # Convert once instead of calling .item() on every element in the loop.
    seq_lens = seq_lens_cpu.tolist()

    # Every request must fit in a chunk on its own; otherwise the inner loop
    # below could never advance past it.
    assert all(s <= workspace_size for s in seq_lens), (
        f"A prefill sequence is longer than the workspace size {workspace_size}."
    )

    chunk_bounds: list[tuple[int, int]] = []
    i, n = 0, len(seq_lens)
    while i < n:
        start, chunk_total = i, 0
        while i < n and chunk_total + seq_lens[i] <= workspace_size:
            chunk_total += seq_lens[i]
            i += 1
        chunk_bounds.append((start + request_offset, i + request_offset))
    return chunk_bounds
def
reorder_batch_to_split_decodes_and_prefills
(
input_batch
:
"InputBatch"
,
scheduler_output
:
"SchedulerOutput"
,
...
...
vllm/v1/core/block_pool.py
View file @
a3f8d5dd
...
...
@@ -397,6 +397,25 @@ class BlockPool:
[
block
for
block
in
blocks_list
if
block
.
ref_cnt
==
0
and
not
block
.
is_null
]
)
def evict_blocks(self, block_ids: set[int]) -> None:
    """Evict blocks from the prefix cache by their block IDs.

    Each ID is looked up in ``self.blocks`` and handed to
    ``_maybe_evict_cached_block``. Only blocks that are currently cached
    (have a hash) are evicted. Blocks with ref_cnt > 0 are not freed from
    the block pool, only evicted from the prefix cache hash table.

    Args:
        block_ids: Set of block IDs to evict from cache.

    Raises:
        AssertionError: If an ID is outside ``[0, len(self.blocks))``.
    """
    for block_id in block_ids:
        # The lower bound matters too: a negative ID would otherwise index
        # from the end of the list and silently evict an unrelated block.
        assert 0 <= block_id < len(self.blocks), (
            f"Invalid block_id {block_id}: expected "
            f"0 <= block_id < {len(self.blocks)}. "
            f"This indicates a bug in the KV connector - workers should "
            f"only report block IDs that were allocated by the scheduler."
        )
        block = self.blocks[block_id]
        self._maybe_evict_cached_block(block)
def
reset_prefix_cache
(
self
)
->
bool
:
"""Reset prefix cache. This function may be used in RLHF
flows to invalid prefix caching after the weights are updated,
...
...
vllm/v1/core/encoder_cache_manager.py
View file @
a3f8d5dd
...
...
@@ -39,20 +39,26 @@ class EncoderCacheManager:
space for new embeddings.
Oldest cached embeddings with no request referenced will be first evicted.
NOTE: The EncoderCacheManager operates on the level of multimodal embeddings
instead of encoder tokens (i.e. all tokens that represent the multimodal data
in the input sequence). This means all break/text tokens in-between multimodal
embeddings are not considered with respect to the cache size and the number
of free slots.
Args:
cache_size: Limit the size of the cache, measured by the number of
token
s from the input sequence.
encoder embedding
s from the input sequence.
Attributes:
cache_size: Total cache capacity in encoder
token
s.
num_free_slots: Current available cache capacity in encoder
token
s.
cache_size: Total cache capacity in encoder
embedding
s.
num_free_slots: Current available cache capacity in encoder
embedding
s.
num_freeable_slots: Capacity that can be immediately reclaimed by
evicting entries with zero references (in encoder
token
s).
evicting entries with zero references (in encoder
embedding
s).
cached: Mapping from mm_hash to a set of request IDs that currently
reference the cached entry. If the set is empty, the entry exists
but is not referenced by any request and is eligible for
reclamation.
freeable: List of tuples (mm_hash, num_
token
s) representing entries
freeable: List of tuples (mm_hash, num_
encoder_embed
s) representing entries
whose no current running request is needed and that can be freed to
make space when needed.
freed: List of mm_hash strings that were actually evicted since the
...
...
@@ -67,7 +73,7 @@ class EncoderCacheManager:
# mm_hash of mm_data => ids of requests that reference the mm_data
self
.
cached
:
dict
[
str
,
set
[
str
]]
=
{}
# mm_hash of mm_data => num_encoder_
token
s of the mm_data
# mm_hash of mm_data => num_encoder_
embed
s of the mm_data
self
.
freeable
:
OrderedDict
[
str
,
int
]
=
OrderedDict
()
self
.
freed
:
list
[
str
]
=
[]
...
...
@@ -93,8 +99,8 @@ class EncoderCacheManager:
# Cached but currently not referenced by any request
if
not
self
.
cached
[
mm_hash
]:
num_
token
s
=
self
.
freeable
.
pop
(
mm_hash
)
self
.
num_freeable_slots
-=
num_
token
s
num_
encoder_embed
s
=
self
.
freeable
.
pop
(
mm_hash
)
self
.
num_freeable_slots
-=
num_
encoder_embed
s
self
.
cached
[
mm_hash
].
add
(
request
.
request_id
)
return
True
...
...
@@ -104,7 +110,7 @@ class EncoderCacheManager:
request
:
Request
,
input_id
:
int
,
encoder_compute_budget
:
int
,
num_
token
s_to_schedule
:
int
,
num_
embed
s_to_schedule
:
int
,
)
->
bool
:
"""Check if there's sufficient cache space for a multimodal input.
If there is, return True and update EncoderCacheManager state.
...
...
@@ -121,9 +127,9 @@ class EncoderCacheManager:
Args:
request: The request containing the multimodal input.
input_id: Index of the multimodal input within the request.
encoder_compute_budget: Number of encoder
token
s allowed to be
encoder_compute_budget: Number of encoder
embedding
s allowed to be
computed when this method is invoked.
num_
token
s_to_schedule: Number of
token
s already scheduled to be
num_
embed
s_to_schedule: Number of
encoder embedding
s already scheduled to be
allocated with cache space when this method is invoked.
Returns:
...
...
@@ -134,30 +140,30 @@ class EncoderCacheManager:
Note: This method does not allocate physical memory for the encoder
output but only the state of EncoderCacheManager.
"""
num_
token
s
=
request
.
get_num_encoder_
token
s
(
input_id
)
num_
embed
s
=
request
.
get_num_encoder_
embed
s
(
input_id
)
# Not enough compute budget
if
num_
token
s
>
encoder_compute_budget
:
if
num_
embed
s
>
encoder_compute_budget
:
return
False
num_
token
s
+=
num_
token
s_to_schedule
num_
embed
s
+=
num_
embed
s_to_schedule
# Enough free slots
if
num_
token
s
<=
self
.
num_free_slots
:
if
num_
embed
s
<=
self
.
num_free_slots
:
return
True
# Not enough reclaimable slots
if
num_
token
s
>
self
.
num_freeable_slots
:
if
num_
embed
s
>
self
.
num_freeable_slots
:
return
False
# Not enough free slots but enough reclaimable slots
# NOTE: Eviction takes place here, but physical memory is not freed
# until model runner is notified by the scheduler output.
while
num_
token
s
>
self
.
num_free_slots
:
mm_hash
,
num_free_
token
=
self
.
freeable
.
popitem
(
last
=
False
)
while
num_
embed
s
>
self
.
num_free_slots
:
mm_hash
,
num_free_
embeds
=
self
.
freeable
.
popitem
(
last
=
False
)
del
self
.
cached
[
mm_hash
]
self
.
freed
.
append
(
mm_hash
)
self
.
num_free_slots
+=
num_free_
token
self
.
num_free_slots
+=
num_free_
embeds
return
True
def
allocate
(
self
,
request
:
Request
,
input_id
:
int
)
->
None
:
...
...
@@ -176,16 +182,16 @@ class EncoderCacheManager:
if
mm_hash
not
in
self
.
cached
:
self
.
cached
[
mm_hash
]
=
set
()
num_encoder_
token
s
=
request
.
get_num_encoder_
token
s
(
input_id
)
num_encoder_
embed
s
=
request
.
get_num_encoder_
embed
s
(
input_id
)
# NOTE: Encoder cache should always have enough space for encoder inputs
# that are scheduled since eviction takes place at can_allocate().
assert
self
.
num_free_slots
>=
num_encoder_
token
s
assert
self
.
num_freeable_slots
>=
num_encoder_
token
s
assert
self
.
num_free_slots
>=
num_encoder_
embed
s
assert
self
.
num_freeable_slots
>=
num_encoder_
embed
s
self
.
cached
[
mm_hash
].
add
(
request_id
)
self
.
num_free_slots
-=
num_encoder_
token
s
self
.
num_freeable_slots
-=
num_encoder_
token
s
self
.
num_free_slots
-=
num_encoder_
embed
s
self
.
num_freeable_slots
-=
num_encoder_
embed
s
def
get_cached_input_ids
(
self
,
request
:
Request
)
->
set
[
int
]:
"""Get all cached multimodal input IDs for a request.
...
...
@@ -206,7 +212,7 @@ class EncoderCacheManager:
When the reference set for the corresponding `mm_hash` becomes empty,
the entry is appended to `freeable` and `num_freeable_slots` is
increased by the number of encoder
token
s for that input.
increased by the number of encoder
embedding
s for that input.
The entry is NOT physically freed until capacity is needed (e.g., by
`can_allocate`).
...
...
@@ -218,9 +224,9 @@ class EncoderCacheManager:
return
self
.
cached
[
mm_hash
].
discard
(
req_id
)
if
not
self
.
cached
[
mm_hash
]:
num_
token
s
=
request
.
get_num_encoder_
token
s
(
input_id
)
self
.
freeable
[
mm_hash
]
=
num_
token
s
self
.
num_freeable_slots
+=
num_
token
s
num_
encoder_embed
s
=
request
.
get_num_encoder_
embed
s
(
input_id
)
self
.
freeable
[
mm_hash
]
=
num_
encoder_embed
s
self
.
num_freeable_slots
+=
num_
encoder_embed
s
def
free
(
self
,
request
:
Request
)
->
None
:
"""Free all encoder input cache reference held by *request*.
...
...
@@ -341,3 +347,56 @@ def compute_mm_encoder_budget(
)
return
encoder_compute_budget
,
encoder_cache_size
# NOTE (NickLucche): Temporary implementation for encoder-decoder models that only
# use the manager for scheduling purposes. Encoder-decoder models will eventually
# utilize the cache and this class will fold into EncoderCacheManager, as
# differences with MM models shrink.
class EncoderDecoderCacheManager(EncoderCacheManager):
    """Slot-counting stand-in for ``EncoderCacheManager``.

    Tracks only a free-slot counter and a list of mm hashes to report as
    freed. The base-class initializer is not invoked, so none of the base
    class's other bookkeeping attributes are set up here.
    """

    def __init__(self, cache_size: int):
        """Start with every one of the ``cache_size`` slots free."""
        self.cache_size = cache_size
        self.num_free_slots = cache_size
        self.freed: list[str] = []

    def check_and_update_cache(self, request: Request, input_id: int) -> bool:
        """Always report a cache miss."""
        return False

    def can_allocate(
        self,
        request: Request,
        input_id: int,
        encoder_compute_budget: int,
        num_embeds_to_schedule: int,
    ) -> bool:
        """Return whether the input fits the compute budget and the free slots.

        Unlike ``allocate``, this does not modify any state.
        """
        required = request.get_num_encoder_embeds(input_id)

        # The input alone must fit into the remaining compute budget.
        if required > encoder_compute_budget:
            return False

        # Together with what is already scheduled it must fit the free slots.
        return required + num_embeds_to_schedule <= self.num_free_slots

    def allocate(self, request: Request, input_id: int) -> None:
        """Reserve slots for the input and queue its hash as freed."""
        self.num_free_slots -= request.get_num_encoder_embeds(input_id)

        # NOTE(review): the hash is reported as freed at allocation time,
        # presumably because nothing is kept cached here; confirm with the
        # consumer of get_freed_mm_hashes().
        feature = request.mm_features[input_id]
        self.freed.append(feature.identifier)

    def free(self, request: Request) -> None:
        """Release the slots of every multimodal input of ``request``."""
        num_inputs = len(request.mm_features)
        for input_id in range(num_inputs):
            self.free_encoder_input(request, input_id)

    def get_cached_input_ids(self, request: Request) -> set[int]:
        """Return the index of every multimodal input of ``request``."""
        num_inputs = len(request.mm_features)
        return set(range(num_inputs))

    def get_freed_mm_hashes(self) -> list[str]:
        """Return the pending freed hashes and reset the pending list."""
        pending, self.freed = self.freed, []
        return pending

    def free_encoder_input(self, request: Request, input_id: int) -> None:
        """Give the slots used by one input back to the free pool."""
        self.num_free_slots += request.get_num_encoder_embeds(input_id)
vllm/v1/core/kv_cache_manager.py
View file @
a3f8d5dd
...
...
@@ -333,6 +333,14 @@ class KVCacheManager:
"""
self
.
coordinator
.
free
(
request
.
request_id
)
def evict_blocks(self, block_ids: set[int]) -> None:
    """Remove the given blocks from the prefix cache.

    Thin pass-through to the block pool's ``evict_blocks``.

    Args:
        block_ids: IDs of the blocks to evict from the cache.
    """
    pool = self.block_pool
    pool.evict_blocks(block_ids)
def
reset_prefix_cache
(
self
)
->
bool
:
"""Reset prefix cache. This function may be used in RLHF
flows to invalidate prefix caching after the weights are updated,
...
...
vllm/v1/core/kv_cache_utils.py
View file @
a3f8d5dd
...
...
@@ -687,7 +687,9 @@ def check_enough_kv_cache_memory(
raise
ValueError
(
"No available memory for the cache blocks. "
"Try increasing `gpu_memory_utilization` when "
"initializing the engine."
"initializing the engine. "
"See https://docs.vllm.ai/en/latest/configuration/conserving_memory/ "
"for more details."
)
max_model_len
=
vllm_config
.
model_config
.
max_model_len
...
...
@@ -711,8 +713,10 @@ def check_enough_kv_cache_memory(
f
"cache is needed, which is larger than the available KV cache "
f
"memory (
{
available_memory
/
GiB_bytes
:.
2
f
}
GiB). "
f
"
{
estimated_msg
}
"
f
"Try increasing `gpu_memory_utilization` or decreasing "
f
"`max_model_len` when initializing the engine."
f
"Try increasing `gpu_memory_utilization` or decreasing `max_model_len` "
f
"when initializing the engine. "
f
"See https://docs.vllm.ai/en/latest/configuration/conserving_memory/ "
f
"for more details."
)
...
...
vllm/v1/core/sched/scheduler.py
View file @
a3f8d5dd
...
...
@@ -27,6 +27,7 @@ from vllm.logger import init_logger
from
vllm.multimodal
import
MULTIMODAL_REGISTRY
,
MultiModalRegistry
from
vllm.v1.core.encoder_cache_manager
import
(
EncoderCacheManager
,
EncoderDecoderCacheManager
,
compute_encoder_budget
,
)
from
vllm.v1.core.kv_cache_manager
import
KVCacheBlocks
,
KVCacheManager
...
...
@@ -106,6 +107,7 @@ class Scheduler(SchedulerInterface):
# KV Connector pushes/pull of remote KVs for P/D and offloading.
self
.
connector
=
None
self
.
connector_prefix_cache_stats
:
PrefixCacheStats
|
None
=
None
self
.
recompute_kv_load_failures
=
True
if
self
.
vllm_config
.
kv_transfer_config
is
not
None
:
assert
not
self
.
is_encoder_decoder
,
(
"Encoder-decoder models are not currently supported with KV connectors"
...
...
@@ -117,6 +119,10 @@ class Scheduler(SchedulerInterface):
)
if
self
.
log_stats
:
self
.
connector_prefix_cache_stats
=
PrefixCacheStats
()
kv_load_failure_policy
=
(
self
.
vllm_config
.
kv_transfer_config
.
kv_load_failure_policy
)
self
.
recompute_kv_load_failures
=
kv_load_failure_policy
==
"recompute"
self
.
kv_event_publisher
=
EventPublisherFactory
.
create
(
self
.
kv_events_config
,
...
...
@@ -176,7 +182,11 @@ class Scheduler(SchedulerInterface):
# NOTE: For the models without encoder (e.g., text-only models),
# the encoder cache will not be initialized because cache size is 0
# for these models.
self
.
encoder_cache_manager
=
EncoderCacheManager
(
cache_size
=
encoder_cache_size
)
self
.
encoder_cache_manager
=
(
EncoderDecoderCacheManager
(
cache_size
=
encoder_cache_size
)
if
self
.
is_encoder_decoder
else
EncoderCacheManager
(
cache_size
=
encoder_cache_size
)
)
speculative_config
=
vllm_config
.
speculative_config
self
.
use_eagle
=
False
...
...
@@ -339,11 +349,11 @@ class Scheduler(SchedulerInterface):
if
preempted_encoder_inputs
:
# Restore encoder compute budget if the preempted
# request had encoder inputs scheduled in this step.
num_
token
s_to_restore
=
sum
(
preempted_req
.
get_num_encoder_
token
s
(
i
)
num_
embed
s_to_restore
=
sum
(
preempted_req
.
get_num_encoder_
embed
s
(
i
)
for
i
in
preempted_encoder_inputs
)
encoder_compute_budget
+=
num_
token
s_to_restore
encoder_compute_budget
+=
num_
embed
s_to_restore
req_index
-=
1
else
:
preempted_req
=
self
.
running
.
pop
()
...
...
@@ -901,10 +911,11 @@ class Scheduler(SchedulerInterface):
# multiple encoder inputs per request), we need to create temporary
# trackers for accounting at the encoder input level.
mm_hashes_to_schedule
=
set
()
num_
token
s_to_schedule
=
0
num_
embed
s_to_schedule
=
0
for
i
,
mm_feature
in
enumerate
(
mm_features
):
start_pos
=
mm_feature
.
mm_position
.
offset
num_encoder_tokens
=
mm_feature
.
mm_position
.
length
num_encoder_embeds
=
mm_feature
.
mm_position
.
get_num_embeds
# The encoder output is needed if the two ranges overlap:
# [num_computed_tokens, num_computed_tokens + num_new_tokens) and
...
...
@@ -960,9 +971,8 @@ class Scheduler(SchedulerInterface):
):
num_new_tokens
=
start_pos
-
num_computed_tokens
break
if
not
self
.
encoder_cache_manager
.
can_allocate
(
request
,
i
,
encoder_compute_budget
,
num_
token
s_to_schedule
request
,
i
,
encoder_compute_budget
,
num_
embed
s_to_schedule
):
# The encoder cache is full or the encoder budget is exhausted.
# NOTE(woosuk): We assume that the encoder input tokens should
...
...
@@ -982,14 +992,31 @@ class Scheduler(SchedulerInterface):
num_new_tokens
=
0
break
# Calculate the number of embeddings to schedule in the current range
# of scheduled encoder placeholder tokens.
start_idx_rel
=
max
(
0
,
num_computed_tokens
-
start_pos
)
end_idx_rel
=
min
(
num_encoder_tokens
,
num_computed_tokens
+
num_new_tokens
-
start_pos
)
curr_embeds_start
,
curr_embeds_end
=
(
mm_feature
.
mm_position
.
get_embeds_indices_in_range
(
start_idx_rel
,
end_idx_rel
,
)
)
# There are no embeddings in the current range of encoder placeholder tokens
# so we can skip the encoder input.
if
curr_embeds_end
-
curr_embeds_start
==
0
:
continue
if
self
.
ec_connector
is
not
None
and
remote_cache_has_item
[
i
]:
mm_hashes_to_schedule
.
add
(
request
.
mm_features
[
i
].
identifier
)
external_load_encoder_input
.
append
(
i
)
num_
token
s_to_schedule
+=
num_encoder_
token
s
num_
embed
s_to_schedule
+=
num_encoder_
embed
s
continue
num_
token
s_to_schedule
+=
num_encoder_
token
s
encoder_compute_budget
-=
num_encoder_
token
s
num_
embed
s_to_schedule
+=
num_encoder_
embed
s
encoder_compute_budget
-=
num_encoder_
embed
s
mm_hashes_to_schedule
.
add
(
request
.
mm_features
[
i
].
identifier
)
encoder_inputs_to_schedule
.
append
(
i
)
...
...
@@ -1066,7 +1093,7 @@ class Scheduler(SchedulerInterface):
for
req_id
,
num_tokens_scheduled
in
num_scheduled_tokens
.
items
():
assert
num_tokens_scheduled
>
0
if
failed_kv_load_req_ids
and
req_id
in
failed_kv_load_req_ids
:
#
S
kip
requests that were recovered
from KV load failure
#
s
kip
failed or rescheduled requests
from KV load failure
continue
request
=
self
.
requests
.
get
(
req_id
)
if
request
is
None
:
...
...
@@ -1107,6 +1134,7 @@ class Scheduler(SchedulerInterface):
stopped
=
False
new_logprobs
=
None
new_token_ids
=
generated_token_ids
pooler_output
=
pooler_outputs
[
req_index
]
if
pooler_outputs
else
None
kv_transfer_params
=
None
status_before_stop
=
request
.
status
...
...
@@ -1115,12 +1143,10 @@ class Scheduler(SchedulerInterface):
new_token_ids
,
stopped
=
self
.
_update_request_with_output
(
request
,
new_token_ids
)
# Stop checking for pooler models.
pooler_output
=
None
if
pooler_outputs
:
pooler_output
=
pooler_outputs
[
req_index
]
stopped
=
check_stop
(
request
,
self
.
max_model_len
,
pooler_output
)
elif
request
.
pooling_params
and
pooler_output
is
not
None
:
# Pooling stops as soon as there is output.
request
.
status
=
RequestStatus
.
FINISHED_STOPPED
stopped
=
True
if
stopped
:
kv_transfer_params
=
self
.
_free_request
(
request
)
...
...
@@ -1177,6 +1203,21 @@ class Scheduler(SchedulerInterface):
# This is a rare case and unlikely to impact performance.
self
.
waiting
.
remove_requests
(
stopped_preempted_reqs
)
if
failed_kv_load_req_ids
and
not
self
.
recompute_kv_load_failures
:
requests
=
[
self
.
requests
[
req_id
]
for
req_id
in
failed_kv_load_req_ids
]
self
.
finish_requests
(
failed_kv_load_req_ids
,
RequestStatus
.
FINISHED_ERROR
)
for
request
in
requests
:
outputs
[
request
.
client_index
].
append
(
EngineCoreOutput
(
request_id
=
request
.
request_id
,
new_token_ids
=
[],
finish_reason
=
request
.
get_finished_reason
(),
events
=
request
.
take_events
(),
trace_headers
=
request
.
trace_headers
,
num_cached_tokens
=
request
.
num_cached_tokens
,
)
)
# KV Connector: update state for finished KV Transfers.
if
kv_connector_output
:
self
.
_update_from_kv_xfer_finished
(
kv_connector_output
)
...
...
@@ -1610,8 +1651,11 @@ class Scheduler(SchedulerInterface):
self
.
_free_blocks
(
self
.
requests
[
req_id
])
def
_update_requests_with_invalid_blocks
(
self
,
requests
:
Iterable
[
Request
],
invalid_block_ids
:
set
[
int
]
)
->
tuple
[
set
[
str
],
int
]:
self
,
requests
:
Iterable
[
Request
],
invalid_block_ids
:
set
[
int
],
evict_blocks
:
bool
=
True
,
)
->
tuple
[
set
[
str
],
int
,
set
[
int
]]:
"""
Identify and update requests affected by invalid KV cache blocks.
...
...
@@ -1623,16 +1667,21 @@ class Scheduler(SchedulerInterface):
Args:
requests: The set of requests to scan for invalid blocks.
invalid_block_ids: IDs of invalid blocks.
evict_blocks: Whether to collect blocks for eviction (False for
async requests which aren't cached yet).
Returns:
tuple:
- affected_req_ids (set[str]): IDs of requests impacted by
invalid blocks.
- total_affected_tokens (int): Total number of tokens that must
be recomputed across all affected requests (for observability).
be recomputed across all affected requests.
- blocks_to_evict (set[int]): Block IDs to evict from cache,
including invalid blocks and downstream dependent blocks.
"""
affected_req_ids
:
set
[
str
]
=
set
()
total_affected_tokens
=
0
blocks_to_evict
:
set
[
int
]
=
set
()
# If a block is invalid and shared by multiple requests in the batch,
# these requests must be rescheduled, but only the first will recompute
# it. This set tracks blocks already marked for recomputation.
...
...
@@ -1690,6 +1739,9 @@ class Scheduler(SchedulerInterface):
)
total_affected_tokens
+=
num_affected_tokens
request
.
num_external_computed_tokens
-=
num_affected_tokens
# collect invalid block and all downstream dependent blocks
if
evict_blocks
:
blocks_to_evict
.
update
(
req_block_ids
[
idx
:])
if
is_affected
:
if
not
marked_invalid_block
:
...
...
@@ -1705,47 +1757,70 @@ class Scheduler(SchedulerInterface):
affected_req_ids
.
add
(
request
.
request_id
)
return
affected_req_ids
,
total_affected_tokens
return
affected_req_ids
,
total_affected_tokens
,
blocks_to_evict
def
_handle_invalid_blocks
(
self
,
invalid_block_ids
:
set
[
int
])
->
set
[
str
]:
total_requests_to_reschedule
=
0
total_tokens_to_reschedule
=
0
"""
Handle requests affected by invalid KV cache blocks.
Returns:
Set of affected request IDs to skip in update_from_output main loop.
"""
should_fail
=
not
self
.
recompute_kv_load_failures
#
--- H
andle async KV loads (
WAITING_FOR_REMOTE_KVS) ---
#
h
andle async KV loads (
not cached yet, evict_blocks=False)
async_load_reqs
=
(
req
for
req
in
self
.
waiting
if
req
.
status
==
RequestStatus
.
WAITING_FOR_REMOTE_KVS
)
async_
affect
ed_req_ids
,
num_
tokens_to_reschedule
=
(
async_
fail
ed_req_ids
,
num_
failed_tokens
,
_
=
(
self
.
_update_requests_with_invalid_blocks
(
async_load_reqs
,
invalid_block_ids
async_load_reqs
,
invalid_block_ids
,
evict_blocks
=
False
)
)
total_requests
_to_reschedule
+
=
len
(
async_
affect
ed_req_ids
)
total_
tokens_to_reschedule
+=
num_tokens_to_reschedule
total_
failed_
requests
=
len
(
async_
fail
ed_req_ids
)
total_
failed_tokens
=
num_failed_tokens
# Mark requests with async KV load failures; they will be rescheduled
# once loading completes.
self
.
failed_recving_kv_req_ids
|=
async_affected_req_ids
# --- Handle sync KV loads (running requests) ---
sync_affected_req_ids
,
num_tokens_to_reschedule
=
(
self
.
_update_requests_with_invalid_blocks
(
self
.
running
,
invalid_block_ids
)
# handle sync loads (may be cached, collect blocks for eviction)
sync_failed_req_ids
,
num_failed_tokens
,
sync_blocks_to_evict
=
(
self
.
_update_requests_with_invalid_blocks
(
self
.
running
,
invalid_block_ids
,
evict_blocks
=
True
)
)
total_failed_requests
+=
len
(
sync_failed_req_ids
)
total_failed_tokens
+=
num_failed_tokens
total_requests_to_reschedule
+=
len
(
sync_affected_req_ids
)
total_tokens_to_reschedule
+=
num_tokens_to_reschedule
if
not
total_failed_requests
:
return
set
()
# evict invalid blocks and downstream dependent blocks from cache
# only when not using recompute policy (where blocks will be recomputed
# and reused by other requests sharing them)
if
sync_blocks_to_evict
and
not
self
.
recompute_kv_load_failures
:
self
.
kv_cache_manager
.
evict_blocks
(
sync_blocks_to_evict
)
if
should_fail
:
all_failed_req_ids
=
async_failed_req_ids
|
sync_failed_req_ids
logger
.
error
(
"Failing %d request(s) due to KV load failure "
"(failure_policy=fail, %d tokens affected). Request IDs: %s"
,
total_failed_requests
,
total_failed_tokens
,
all_failed_req_ids
,
)
return
all_failed_req_ids
if
total_requests_to_reschedule
:
logger
.
warning
(
"Recovered from KV load failure: "
"%d request(s) rescheduled (%d tokens affected)."
,
total_requests
_to_reschedule
,
total_
tokens_to_reschedule
,
total_
failed_
requests
,
total_
failed_tokens
,
)
# Return the IDs of affected running requests to skip in
# update_from_output.
return
sync_affected_req_ids
# Mark async requests with KV load failures for retry once loading completes
self
.
failed_recving_kv_req_ids
|=
async_failed_req_ids
# Return sync affected IDs to skip in update_from_output
return
sync_failed_req_ids
vllm/v1/core/sched/utils.py
View file @
a3f8d5dd
...
...
@@ -2,8 +2,6 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import
contextlib
import
torch
from
vllm.v1.request
import
Request
,
RequestStatus
...
...
@@ -39,14 +37,8 @@ def remove_all(lst: list, items_to_remove: set) -> list:
return
[
item
for
item
in
lst
if
item
not
in
items_to_remove
]
def
check_stop
(
request
:
Request
,
max_model_len
:
int
,
pooler_output
:
torch
.
Tensor
|
None
=
None
)
->
bool
:
if
request
.
pooling_params
:
if
pooler_output
is
not
None
:
request
.
status
=
RequestStatus
.
FINISHED_STOPPED
return
True
return
False
def
check_stop
(
request
:
Request
,
max_model_len
:
int
)
->
bool
:
assert
not
request
.
pooling_params
sampling_params
=
request
.
sampling_params
assert
sampling_params
is
not
None
...
...
vllm/v1/engine/__init__.py
View file @
a3f8d5dd
...
...
@@ -19,24 +19,27 @@ from vllm.v1.serial_utils import UtilityResult
# These are possible values of RequestOutput.finish_reason,
# so form part of the external API.
FINISH_REASON_STRINGS
=
(
"stop"
,
"length"
,
"abort"
)
FINISH_REASON_STRINGS
=
(
"stop"
,
"length"
,
"abort"
,
"error"
)
class
FinishReason
(
enum
.
IntEnum
):
"""
Reason a request finished - stop, length,
or ab
or
t
.
Reason a request finished - stop, length,
abort, or err
or.
Int rather than Str for more compact serialization.
stop - a stop string was emitted
length - max_tokens was consumed, or max_model_len was reached
abort - aborted for another reason
abort - aborted by client
error - retryable request-level internal error (e.g., KV load failure).
Invariant: always converted to 500 Internal Server Error.
"""
STOP
=
0
LENGTH
=
1
ABORT
=
2
ERROR
=
3
def
__str__
(
self
):
return
FINISH_REASON_STRINGS
[
self
.
value
]
...
...
vllm/v1/engine/async_llm.py
View file @
a3f8d5dd
...
...
@@ -26,7 +26,7 @@ from vllm.plugins.io_processors import get_io_processor
from
vllm.pooling_params
import
PoolingParams
from
vllm.sampling_params
import
SamplingParams
from
vllm.tasks
import
SupportedTask
from
vllm.tokenizers
import
TokenizerLike
,
init
_tokenizer_from_config
from
vllm.tokenizers
import
TokenizerLike
,
cached
_tokenizer_from_config
from
vllm.tracing
import
init_tracer
from
vllm.transformers_utils.config
import
maybe_register_config_serialize_by_value
from
vllm.usage.usage_lib
import
UsageContext
...
...
@@ -111,7 +111,7 @@ class AsyncLLM(EngineClient):
if
self
.
model_config
.
skip_tokenizer_init
:
tokenizer
=
None
else
:
tokenizer
=
init
_tokenizer_from_config
(
self
.
model_config
)
tokenizer
=
cached
_tokenizer_from_config
(
self
.
model_config
)
self
.
input_processor
=
InputProcessor
(
self
.
vllm_config
,
tokenizer
)
self
.
io_processor
=
get_io_processor
(
...
...
@@ -192,7 +192,7 @@ class AsyncLLM(EngineClient):
@
property
@
deprecated
(
"`AsyncLLM.processor` has been renamed to `AsyncLLM.input_processor`. "
"The old name will be removed in v0.1
3
."
"The old name will be removed in v0.1
4
."
)
def
processor
(
self
):
return
self
.
input_processor
...
...
@@ -701,10 +701,6 @@ class AsyncLLM(EngineClient):
def
tokenizer
(
self
)
->
TokenizerLike
|
None
:
return
self
.
input_processor
.
tokenizer
@
tokenizer
.
setter
def
tokenizer
(
self
,
tokenizer
:
TokenizerLike
|
None
)
->
None
:
self
.
input_processor
.
tokenizer
=
tokenizer
async
def
get_tokenizer
(
self
)
->
TokenizerLike
:
if
self
.
tokenizer
is
None
:
raise
ValueError
(
...
...
vllm/v1/engine/core.py
View file @
a3f8d5dd
...
...
@@ -211,6 +211,9 @@ class EngineCore:
freeze_gc_heap
()
# If enable, attach GC debugger after static variable freeze.
maybe_attach_gc_debug_callback
()
# Enable environment variable cache (e.g. assume no more
# environment variable overrides after this point)
enable_envs_cache
()
def
_initialize_kv_caches
(
self
,
vllm_config
:
VllmConfig
...
...
@@ -672,10 +675,6 @@ class EngineCoreProc(EngineCore):
assert
addresses
.
coordinator_input
is
not
None
logger
.
info
(
"Waiting for READY message from DP Coordinator..."
)
# Enable environment variable cache (e.g. assume no more
# environment variable overrides after this point)
enable_envs_cache
()
@
contextmanager
def
_perform_handshakes
(
self
,
...
...
vllm/v1/engine/input_processor.py
View file @
a3f8d5dd
...
...
@@ -19,7 +19,8 @@ from vllm.multimodal.processing import EncDecMultiModalProcessor
from
vllm.multimodal.utils
import
argsort_mm_positions
from
vllm.pooling_params
import
PoolingParams
from
vllm.sampling_params
import
SamplingParams
from
vllm.tokenizers
import
MistralTokenizer
,
TokenizerLike
from
vllm.tokenizers
import
TokenizerLike
from
vllm.tokenizers.mistral
import
MistralTokenizer
from
vllm.utils
import
length_from_prompt_token_ids_or_embeds
from
vllm.v1.engine
import
EngineCoreRequest
from
vllm.v1.metrics.stats
import
MultiModalCacheStats
...
...
@@ -64,10 +65,6 @@ class InputProcessor:
def
tokenizer
(
self
)
->
TokenizerLike
|
None
:
return
self
.
input_preprocessor
.
tokenizer
@
tokenizer
.
setter
def
tokenizer
(
self
,
tokenizer
:
TokenizerLike
|
None
)
->
None
:
self
.
input_preprocessor
.
tokenizer
=
tokenizer
def
_validate_logprobs
(
self
,
params
:
SamplingParams
,
...
...
@@ -192,29 +189,39 @@ class InputProcessor:
def
_validate_single_prompt
(
single_prompt
:
dict
|
str
)
->
None
:
if
not
isinstance
(
single_prompt
,
dict
):
return
mm_data
=
single_prompt
.
get
(
"multi_modal_data"
)
mm_uuids
=
single_prompt
.
get
(
"multi_modal_uuids"
)
if
not
mm_data
or
not
mm_uuids
:
return
import
torch
def
_get_len
(
items
:
object
):
if
isinstance
(
items
,
dict
):
# Embedding inputs
return
_get_len
(
next
(
iter
(
items
.
values
())))
if
items
else
1
if
isinstance
(
items
,
list
):
return
len
(
items
)
if
isinstance
(
items
,
torch
.
Tensor
):
# To keep backwards compatibility for single item embedding input
return
1
if
getattr
(
items
,
"_is_single_item"
,
False
)
else
len
(
items
)
return
1
for
modality
,
items
in
mm_data
.
items
():
if
modality
in
mm_uuids
:
data_len
=
len
(
items
)
if
isinstance
(
items
,
list
)
else
1
uuid_len
=
(
len
(
mm_uuids
[
modality
])
if
isinstance
(
mm_uuids
[
modality
],
list
)
else
1
)
data_len
=
_get_len
(
items
)
uuid_len
=
_get_len
(
mm_uuids
[
modality
])
if
uuid_len
!=
data_len
:
raise
ValueError
(
f
"multi_modal_uuids for modality
'
{
modality
}
'
"
f
"multi_modal_uuids for modality
{
modality
!
r
}
"
"must have same length as data: got "
f
"
{
uuid_len
}
uuids vs "
f
"
{
data_len
}
items."
f
"
{
uuid_len
}
uuids vs
{
data_len
}
items."
)
else
:
raise
ValueError
(
f
"multi_modal_uuids for modality
'
{
modality
}
'
must "
f
"multi_modal_uuids for modality
{
modality
!
r
}
must "
"be provided if multi_modal_data is provided."
)
...
...
vllm/v1/engine/llm_engine.py
View file @
a3f8d5dd
...
...
@@ -23,7 +23,7 @@ from vllm.plugins.io_processors import get_io_processor
from
vllm.pooling_params
import
PoolingParams
from
vllm.sampling_params
import
SamplingParams
from
vllm.tasks
import
SupportedTask
from
vllm.tokenizers
import
TokenizerLike
,
init
_tokenizer_from_config
from
vllm.tokenizers
import
TokenizerLike
,
cached
_tokenizer_from_config
from
vllm.tracing
import
init_tracer
from
vllm.usage.usage_lib
import
UsageContext
from
vllm.v1.engine
import
EngineCoreRequest
...
...
@@ -86,7 +86,7 @@ class LLMEngine:
if
self
.
model_config
.
skip_tokenizer_init
:
tokenizer
=
None
else
:
tokenizer
=
init
_tokenizer_from_config
(
self
.
model_config
)
tokenizer
=
cached
_tokenizer_from_config
(
self
.
model_config
)
self
.
input_processor
=
InputProcessor
(
self
.
vllm_config
,
tokenizer
)
self
.
io_processor
=
get_io_processor
(
...
...
@@ -139,7 +139,7 @@ class LLMEngine:
@
property
@
deprecated
(
"`LLMEngine.processor` has been renamed to `LLMEngine.input_processor`. "
"The old name will be removed in v0.1
3
."
"The old name will be removed in v0.1
4
."
)
def
processor
(
self
):
return
self
.
input_processor
...
...
@@ -358,10 +358,6 @@ class LLMEngine:
def
tokenizer
(
self
)
->
TokenizerLike
|
None
:
return
self
.
input_processor
.
tokenizer
@
tokenizer
.
setter
def
tokenizer
(
self
,
tokenizer
:
TokenizerLike
|
None
)
->
None
:
self
.
input_processor
.
tokenizer
=
tokenizer
def
get_tokenizer
(
self
)
->
TokenizerLike
:
if
self
.
tokenizer
is
None
:
raise
ValueError
(
...
...
vllm/v1/engine/processor.py
View file @
a3f8d5dd
...
...
@@ -10,7 +10,7 @@ def __getattr__(name: str):
warnings
.
warn
(
"`vllm.v1.engine.processor.Processor` has been moved to "
"`vllm.v1.engine.input_processor.InputProcessor`. "
"The old name will be removed in v0.1
3
."
,
"The old name will be removed in v0.1
4
."
,
DeprecationWarning
,
stacklevel
=
2
,
)
...
...
Prev
1
…
20
21
22
23
24
25
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment