Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
dbcb0376
Commit
dbcb0376
authored
Feb 24, 2026
by
laibao
Browse files
feat(kvpress): Runner 接入 KV 位置与注意力元数据
parent
3d4f8753
Changes
8
Show whitespace changes
Inline
Side-by-side
Showing
8 changed files
with
819 additions
and
3 deletions
+819
-3
vllm/v1/attention/backend.py
vllm/v1/attention/backend.py
+12
-0
vllm/v1/kv_compression/forward_context.py
vllm/v1/kv_compression/forward_context.py
+66
-0
vllm/v1/kv_compression/metadata.py
vllm/v1/kv_compression/metadata.py
+74
-0
vllm/v1/kv_compression/runner_buffers.py
vllm/v1/kv_compression/runner_buffers.py
+109
-0
vllm/v1/kv_compression/runner_prepare.py
vllm/v1/kv_compression/runner_prepare.py
+139
-0
vllm/v1/kv_compression/runner_prompt_compaction.py
vllm/v1/kv_compression/runner_prompt_compaction.py
+234
-0
vllm/v1/kv_compression/runner_step.py
vllm/v1/kv_compression/runner_step.py
+99
-0
vllm/v1/worker/gpu_model_runner.py
vllm/v1/worker/gpu_model_runner.py
+86
-3
No files found.
vllm/v1/attention/backend.py
View file @
dbcb0376
...
@@ -304,6 +304,8 @@ class CommonAttentionMetadata:
...
@@ -304,6 +304,8 @@ class CommonAttentionMetadata:
# TODO(lucas): rename to num_tokens since it may be padded and this is misleading
# TODO(lucas): rename to num_tokens since it may be padded and this is misleading
num_actual_tokens
:
int
num_actual_tokens
:
int
"""Total number of tokens in batch"""
"""Total number of tokens in batch"""
num_unpadded_tokens
:
int
|
None
=
None
"""Number of scheduled tokens excluding padding, if known."""
max_query_len
:
int
max_query_len
:
int
"""Longest query in batch"""
"""Longest query in batch"""
max_seq_len
:
int
max_seq_len
:
int
...
@@ -332,6 +334,16 @@ class CommonAttentionMetadata:
...
@@ -332,6 +334,16 @@ class CommonAttentionMetadata:
_num_computed_tokens_cache
:
torch
.
Tensor
|
None
=
None
_num_computed_tokens_cache
:
torch
.
Tensor
|
None
=
None
# KV compression metadata (experimental, v1 paged attention only).
kv_compression_must_keep
:
torch
.
Tensor
|
None
=
None
kv_compression_topk_budget
:
torch
.
Tensor
|
None
=
None
kv_compression_topk_budget_max
:
int
|
None
=
None
kv_compression_prompt_end
:
torch
.
Tensor
|
None
=
None
kv_compression_prompt_lens
:
torch
.
Tensor
|
None
=
None
kv_compression_prompt_topk_keep
:
torch
.
Tensor
|
None
=
None
kv_compression_prompt_topk_keep_max
:
int
|
None
=
None
def
batch_size
(
self
)
->
int
:
def
batch_size
(
self
)
->
int
:
return
self
.
seq_lens
.
shape
[
0
]
return
self
.
seq_lens
.
shape
[
0
]
...
...
vllm/v1/kv_compression/forward_context.py
0 → 100644
View file @
dbcb0376
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from
__future__
import
annotations
from
typing
import
Any
,
Optional
import
torch
_PROMPT_PAYLOAD_ATTR
=
"_kv_compression_prompt_payload"
_COMPACT_SLOTS_ATTR
=
"_kv_compression_compact_slots"
_COMPACT_SLOTS_BY_LAYER_ATTR
=
"_kv_compression_compact_slots_by_layer"
def
get_kv_compression_prompt_payload
(
forward_context
:
Any
,
)
->
Optional
[
dict
[
str
,
torch
.
Tensor
]]:
return
getattr
(
forward_context
,
_PROMPT_PAYLOAD_ATTR
,
None
)
def
set_kv_compression_prompt_payload
(
forward_context
:
Any
,
payload
:
dict
[
str
,
torch
.
Tensor
],
)
->
None
:
setattr
(
forward_context
,
_PROMPT_PAYLOAD_ATTR
,
payload
)
def
_kv_compression_layer_key
(
layer
:
Any
)
->
str
:
layer_name
=
getattr
(
layer
,
"layer_name"
,
None
)
if
layer_name
is
None
:
layer_name
=
str
(
id
(
layer
))
return
str
(
layer_name
)
def
get_kv_compression_compact_slots
(
forward_context
:
Any
,
*
,
per_layer_topk
:
bool
,
layer
:
Any
,
)
->
Optional
[
torch
.
Tensor
]:
if
per_layer_topk
:
dst_by_layer
=
getattr
(
forward_context
,
_COMPACT_SLOTS_BY_LAYER_ATTR
,
None
)
if
dst_by_layer
is
None
:
return
None
return
dst_by_layer
.
get
(
_kv_compression_layer_key
(
layer
))
return
getattr
(
forward_context
,
_COMPACT_SLOTS_ATTR
,
None
)
def
set_kv_compression_compact_slots
(
forward_context
:
Any
,
*
,
per_layer_topk
:
bool
,
layer
:
Any
,
dst
:
torch
.
Tensor
,
)
->
None
:
if
per_layer_topk
:
dst_by_layer
=
getattr
(
forward_context
,
_COMPACT_SLOTS_BY_LAYER_ATTR
,
None
)
if
dst_by_layer
is
None
:
dst_by_layer
=
{}
setattr
(
forward_context
,
_COMPACT_SLOTS_BY_LAYER_ATTR
,
dst_by_layer
)
dst_by_layer
[
_kv_compression_layer_key
(
layer
)]
=
dst
else
:
setattr
(
forward_context
,
_COMPACT_SLOTS_ATTR
,
dst
)
vllm/v1/kv_compression/metadata.py
0 → 100644
View file @
dbcb0376
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from
__future__
import
annotations
from
dataclasses
import
dataclass
from
typing
import
Any
,
Optional
import
torch
import
vllm.envs
as
envs
@
dataclass
class
KVCompressionAttentionMetadata
:
"""Per-batch KV compression metadata consumed by attention backends."""
must_keep
:
Optional
[
torch
.
Tensor
]
=
None
topk_budget
:
Optional
[
torch
.
Tensor
]
=
None
topk_budget_max
:
Optional
[
int
]
=
None
prompt_end
:
Optional
[
torch
.
Tensor
]
=
None
prompt_lens
:
Optional
[
torch
.
Tensor
]
=
None
prompt_topk_keep
:
Optional
[
torch
.
Tensor
]
=
None
prompt_topk_keep_max
:
Optional
[
int
]
=
None
def
build_kv_compression_attn_metadata
(
*
,
runner
:
Any
,
num_reqs
:
int
,
num_actual_tokens
:
int
,
)
->
KVCompressionAttentionMetadata
:
"""Build KV compression metadata for one attention step.
This helper keeps backend code thin and centralizes the logic for selecting
between per-step compaction (scheme 1/2) and prompt-end one-shot scoring
(scheme 3).
"""
meta
=
KVCompressionAttentionMetadata
()
if
not
envs
.
VLLM_ENABLE_KV_COMPRESSION
:
return
meta
# Scheme 1/2: compute compaction destinations every step.
if
getattr
(
runner
,
"kv_compression_needs_compaction"
,
False
):
meta
.
must_keep
=
runner
.
kv_compression_must_keep
[:
num_actual_tokens
]
meta
.
topk_budget
=
runner
.
kv_compression_topk_budget
[:
num_reqs
]
# Avoid device->host sync by reading from the CPU staging buffer.
if
num_reqs
>
0
:
meta
.
topk_budget_max
=
int
(
runner
.
kv_compression_topk_budget_np
[:
num_reqs
].
max
())
else
:
meta
.
topk_budget_max
=
0
return
meta
# Scheme 3: compute global prompt indices only on the last prefill chunk,
# and perform the actual cache compaction before the first decode step.
scheduler_config
=
getattr
(
runner
,
"scheduler_config"
,
None
)
if
scheduler_config
is
None
or
not
getattr
(
scheduler_config
,
"enable_chunked_prefill"
,
False
):
return
meta
if
num_reqs
<=
0
:
return
meta
if
not
runner
.
kv_compression_prompt_end_np
[:
num_reqs
].
any
():
return
meta
meta
.
prompt_end
=
runner
.
kv_compression_prompt_end
[:
num_reqs
]
meta
.
prompt_lens
=
runner
.
kv_compression_prompt_lens
[:
num_reqs
]
meta
.
prompt_topk_keep
=
runner
.
kv_compression_prompt_topk_keep
[:
num_reqs
]
meta
.
prompt_topk_keep_max
=
int
(
getattr
(
runner
,
"kv_compression_prompt_topk_keep_max"
,
0
)
or
0
)
return
meta
vllm/v1/kv_compression/runner_buffers.py
0 → 100644
View file @
dbcb0376
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from
__future__
import
annotations
from
typing
import
Any
,
Optional
import
torch
def
init_kv_compression_runner_buffers
(
*
,
runner
:
Any
,
max_num_tokens
:
int
,
max_num_reqs
:
int
,
device
:
torch
.
device
,
pin_memory
:
bool
,
)
->
None
:
"""Initialize per-runner buffers used by KV compression.
This helper keeps `gpu_model_runner.py` focused on orchestration while
preserving the existing attribute-based access patterns.
"""
# KV positions are decoupled from logical positions when KV compression is
# enabled. Keep a separate buffer to avoid recomputing or overwriting the
# logical `positions_np` (used for RoPE / token lookup).
runner
.
kv_positions_cpu
=
torch
.
zeros
(
max_num_tokens
,
dtype
=
torch
.
int64
,
device
=
"cpu"
,
pin_memory
=
pin_memory
,
)
runner
.
kv_positions_np
=
runner
.
kv_positions_cpu
.
numpy
()
# KV compression metadata buffers (used by the "topk" policy).
# Per-token: whether this scheduled token must be kept in KV cache.
runner
.
kv_compression_must_keep_cpu
=
torch
.
zeros
(
max_num_tokens
,
dtype
=
torch
.
bool
,
device
=
"cpu"
,
pin_memory
=
pin_memory
,
)
runner
.
kv_compression_must_keep_np
=
runner
.
kv_compression_must_keep_cpu
.
numpy
()
runner
.
kv_compression_must_keep
=
torch
.
zeros
(
max_num_tokens
,
dtype
=
torch
.
bool
,
device
=
device
,
)
# Per-request: how many additional prompt tokens to keep among
# non-protected candidates (budget from env; selection uses scores).
runner
.
kv_compression_topk_budget_cpu
=
torch
.
zeros
(
max_num_reqs
,
dtype
=
torch
.
int32
,
device
=
"cpu"
,
pin_memory
=
pin_memory
,
)
runner
.
kv_compression_topk_budget_np
=
runner
.
kv_compression_topk_budget_cpu
.
numpy
()
runner
.
kv_compression_topk_budget
=
torch
.
zeros
(
max_num_reqs
,
dtype
=
torch
.
int32
,
device
=
device
,
)
# Chunked-prefill prompt-end KV compression metadata (scheme 3).
# Per-request: whether this step finishes the prompt and should compute
# global prompt indices (score/topk) for a one-shot compaction.
runner
.
kv_compression_prompt_end_cpu
=
torch
.
zeros
(
max_num_reqs
,
dtype
=
torch
.
bool
,
device
=
"cpu"
,
pin_memory
=
pin_memory
,
)
runner
.
kv_compression_prompt_end_np
=
runner
.
kv_compression_prompt_end_cpu
.
numpy
()
runner
.
kv_compression_prompt_end
=
torch
.
zeros
(
max_num_reqs
,
dtype
=
torch
.
bool
,
device
=
device
,
)
# Per-request: prompt length (tokens) and Top-K keep count among prompt
# candidates (excluding protected prefix/suffix).
runner
.
kv_compression_prompt_lens_cpu
=
torch
.
zeros
(
max_num_reqs
,
dtype
=
torch
.
int32
,
device
=
"cpu"
,
pin_memory
=
pin_memory
,
)
runner
.
kv_compression_prompt_lens_np
=
runner
.
kv_compression_prompt_lens_cpu
.
numpy
()
runner
.
kv_compression_prompt_lens
=
torch
.
zeros
(
max_num_reqs
,
dtype
=
torch
.
int32
,
device
=
device
,
)
runner
.
kv_compression_prompt_topk_keep_cpu
=
torch
.
zeros
(
max_num_reqs
,
dtype
=
torch
.
int32
,
device
=
"cpu"
,
pin_memory
=
pin_memory
,
)
runner
.
kv_compression_prompt_topk_keep_np
=
runner
.
kv_compression_prompt_topk_keep_cpu
.
numpy
()
runner
.
kv_compression_prompt_topk_keep
=
torch
.
zeros
(
max_num_reqs
,
dtype
=
torch
.
int32
,
device
=
device
,
)
runner
.
kv_compression_prompt_topk_keep_max
=
None
# type: Optional[int]
vllm/v1/kv_compression/runner_prepare.py
0 → 100644
View file @
dbcb0376
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from
__future__
import
annotations
from
typing
import
Optional
import
numpy
as
np
import
vllm.envs
as
envs
from
vllm.v1.kv_compression.budget
import
(
compute_prompt_topk_keep_total
,
compute_topk_budget_step
)
def
prepare_kv_compression_for_step
(
*
,
num_reqs
:
int
,
total_num_scheduled_tokens
:
int
,
num_scheduled_tokens
:
np
.
ndarray
,
# [B] int32
cu_num_tokens
:
np
.
ndarray
,
# [B] int64/int32 cumulative scheduled tokens
req_indices
:
np
.
ndarray
,
# [T] int64, request index per token
arange
:
np
.
ndarray
,
# [T] int64, position-within-request per token
num_computed_tokens_cpu
:
np
.
ndarray
,
# [max_reqs] int32/int64
num_prompt_tokens
:
np
.
ndarray
,
# [max_reqs] int32/int64
num_kv_tokens_cpu
:
np
.
ndarray
,
# [max_reqs] int32/int64
kv_positions_np
:
np
.
ndarray
,
# [T] int64 (out)
must_keep_np
:
np
.
ndarray
,
# [T] bool (out; scheme 1/2 only)
topk_budget_np
:
np
.
ndarray
,
# [B] int32 (out; scheme 1/2 only)
prompt_end_np
:
np
.
ndarray
,
# [B] bool (out; scheme 3 only)
prompt_lens_np
:
np
.
ndarray
,
# [B] int32 (out; scheme 3 only)
prompt_topk_keep_np
:
np
.
ndarray
,
# [B] int32 (out; scheme 3 only)
chunked_prefill_enabled
:
bool
,
)
->
tuple
[
bool
,
Optional
[
int
]]:
"""Prepare KV compression metadata for a single model step (CPU-side).
Fills:
- `kv_positions_np`: per-token KV write positions (decoupled from logical
RoPE positions).
- Scheme 3 (chunked prefill): `prompt_end/prompt_lens/prompt_topk_keep`.
- Scheme 1/2 (non-chunked): `must_keep/topk_budget`.
Returns:
(needs_compaction, prompt_topk_keep_max)
"""
if
total_num_scheduled_tokens
<=
0
or
num_reqs
<=
0
:
return
False
,
None
# KV positions (where scheduled tokens are written before optional
# compaction).
np
.
add
(
num_kv_tokens_cpu
[
req_indices
],
arange
,
out
=
kv_positions_np
)
prompt_ratio
=
envs
.
VLLM_KV_COMPRESSION_PROMPT_RATIO
prompt_budget
=
envs
.
VLLM_KV_COMPRESSION_PROMPT_BUDGET
protected_prefix
=
envs
.
VLLM_KV_COMPRESSION_PROTECTED_PREFIX
protected_suffix
=
envs
.
VLLM_KV_COMPRESSION_PROTECTED_SUFFIX
keep_last
=
envs
.
VLLM_KV_COMPRESSION_KEEP_LAST_TOKEN
if
chunked_prefill_enabled
:
# Scheme 3: with chunked prefill, defer compaction until after the full
# prompt is ingested. Otherwise, the next prefill chunk would attend to
# a truncated history and quality can collapse.
prompt_end_np
.
fill
(
False
)
prompt_lens_np
.
fill
(
0
)
prompt_topk_keep_np
.
fill
(
0
)
for
req_idx
in
range
(
num_reqs
):
qlen
=
int
(
num_scheduled_tokens
[
req_idx
])
if
qlen
<=
0
:
continue
base_pos
=
int
(
num_computed_tokens_cpu
[
req_idx
])
prompt_len
=
int
(
num_prompt_tokens
[
req_idx
])
end_pos
=
base_pos
+
qlen
ends_prompt
=
(
base_pos
<
prompt_len
)
and
(
end_pos
>=
prompt_len
)
if
not
ends_prompt
:
continue
prompt_end_np
[
req_idx
]
=
True
prompt_lens_np
[
req_idx
]
=
prompt_len
prompt_topk_keep_np
[
req_idx
]
=
compute_prompt_topk_keep_total
(
prompt_len
=
prompt_len
,
protected_prefix
=
protected_prefix
,
protected_suffix
=
protected_suffix
,
keep_last_token
=
keep_last
,
prompt_ratio
=
prompt_ratio
,
prompt_budget
=
prompt_budget
,
)
prompt_topk_keep_max
=
int
(
prompt_topk_keep_np
[:
num_reqs
].
max
())
return
False
,
prompt_topk_keep_max
# Scheme 1/2: per-step compaction within the scheduled segment.
must_keep_np
.
fill
(
False
)
topk_budget_np
.
fill
(
0
)
for
req_idx
in
range
(
num_reqs
):
qlen
=
int
(
num_scheduled_tokens
[
req_idx
])
if
qlen
<=
0
:
continue
start
=
0
if
req_idx
==
0
else
int
(
cu_num_tokens
[
req_idx
-
1
])
end
=
int
(
cu_num_tokens
[
req_idx
])
assert
end
-
start
==
qlen
base_pos
=
int
(
num_computed_tokens_cpu
[
req_idx
])
prompt_len
=
int
(
num_prompt_tokens
[
req_idx
])
end_pos
=
base_pos
+
qlen
pos_in_req
=
arange
[
start
:
end
].
astype
(
np
.
int64
,
copy
=
False
)
pos
=
base_pos
+
pos_in_req
prompt_mask
=
pos
<
prompt_len
# Decode tokens are always kept.
must_keep
=
~
prompt_mask
if
np
.
any
(
prompt_mask
):
suffix_start
=
max
(
prompt_len
-
protected_suffix
,
0
)
must_keep
|=
prompt_mask
&
(
pos
<
protected_prefix
)
must_keep
|=
prompt_mask
&
(
pos
>=
suffix_start
)
if
keep_last
:
last
=
prompt_len
-
1
if
base_pos
<=
last
<
end_pos
:
must_keep
[
last
-
base_pos
]
=
True
topk_budget_np
[
req_idx
]
=
compute_topk_budget_step
(
prompt_len
=
prompt_len
,
start_pos
=
base_pos
,
end_pos
=
end_pos
,
protected_prefix
=
protected_prefix
,
protected_suffix
=
protected_suffix
,
keep_last_token
=
keep_last
,
prompt_ratio
=
prompt_ratio
,
prompt_budget
=
prompt_budget
,
)
must_keep_np
[
start
:
end
]
=
must_keep
# Decode-only fast path: if all scheduled tokens are unconditionally kept
# and there is no Top-K budget, KV compaction is a no-op and can be skipped.
needs_compaction
=
(
not
must_keep_np
.
all
())
or
(
topk_budget_np
>
0
).
any
()
return
bool
(
needs_compaction
),
None
vllm/v1/kv_compression/runner_prompt_compaction.py
0 → 100644
View file @
dbcb0376
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from
__future__
import
annotations
from
typing
import
Any
import
torch
import
vllm.envs
as
envs
from
vllm.forward_context
import
get_forward_context
from
vllm.platforms
import
current_platform
from
vllm.v1.kv_cache_interface
import
(
AttentionSpec
,
CrossAttentionSpec
,
UniformTypeKVCacheSpecs
,
)
from
vllm.v1.kv_compression.forward_context
import
get_kv_compression_prompt_payload
def
stash_kv_compression_prompt_payload_to_requests
(
*
,
runner
:
Any
)
->
None
:
"""Persist prompt-end compaction indices from the forward context.
This is the runner-side half of chunked-prefill scheme 3:
flash_attn -> forward_context payload -> request state stash ->
(next step) one-shot KV compaction.
"""
if
not
envs
.
VLLM_ENABLE_KV_COMPRESSION
:
return
scheduler_config
=
getattr
(
runner
,
"scheduler_config"
,
None
)
if
scheduler_config
is
None
or
not
getattr
(
scheduler_config
,
"enable_chunked_prefill"
,
False
):
return
forward_context
=
get_forward_context
()
payload
=
get_kv_compression_prompt_payload
(
forward_context
)
if
payload
is
None
:
return
req_indices
=
payload
.
get
(
"req_indices"
)
idx_sorted
=
payload
.
get
(
"idx_sorted"
)
keep_len
=
payload
.
get
(
"keep_len"
)
prompt_lens
=
payload
.
get
(
"prompt_lens"
)
if
(
req_indices
is
None
or
idx_sorted
is
None
or
keep_len
is
None
or
prompt_lens
is
None
):
return
input_batch
=
getattr
(
runner
,
"input_batch"
,
None
)
if
input_batch
is
None
:
return
req_ids
=
getattr
(
input_batch
,
"req_ids"
,
None
)
if
req_ids
is
None
:
return
requests
=
getattr
(
runner
,
"requests"
,
None
)
if
requests
is
None
:
return
req_indices_cpu
=
req_indices
.
to
(
device
=
"cpu"
,
dtype
=
torch
.
int64
).
tolist
()
keep_cpu
=
keep_len
.
to
(
device
=
"cpu"
,
dtype
=
torch
.
int64
).
tolist
()
prompt_cpu
=
prompt_lens
.
to
(
device
=
"cpu"
,
dtype
=
torch
.
int64
).
tolist
()
for
i
,
b
in
enumerate
(
req_indices_cpu
):
if
b
<
0
or
b
>=
len
(
req_ids
):
continue
req_id
=
req_ids
[
b
]
if
req_id
is
None
:
continue
rs
=
requests
.
get
(
req_id
)
if
rs
is
None
:
continue
rs
.
kv_compression_prompt_idx_sorted
=
idx_sorted
[
i
]
rs
.
kv_compression_prompt_keep_len
=
int
(
keep_cpu
[
i
])
rs
.
kv_compression_prompt_prompt_len
=
int
(
prompt_cpu
[
i
])
def
maybe_apply_kv_compression_prompt_compaction
(
*
,
runner
:
Any
)
->
None
:
"""Apply one-shot prompt KV compaction before the first decode step."""
if
not
envs
.
VLLM_ENABLE_KV_COMPRESSION
:
return
if
not
current_platform
.
is_cuda_alike
():
return
scheduler_config
=
getattr
(
runner
,
"scheduler_config"
,
None
)
if
scheduler_config
is
None
or
not
getattr
(
scheduler_config
,
"enable_chunked_prefill"
,
False
):
return
input_batch
=
getattr
(
runner
,
"input_batch"
,
None
)
if
input_batch
is
None
:
return
requests
=
getattr
(
runner
,
"requests"
,
None
)
if
requests
is
None
:
return
pending_req_ids
:
list
[
str
]
=
[]
for
req_id
in
input_batch
.
req_ids
:
if
req_id
is
None
:
continue
rs
=
requests
.
get
(
req_id
)
if
rs
is
None
:
continue
if
rs
.
kv_compression_prompt_idx_sorted
is
None
:
continue
# Only apply once the prompt is fully ingested (decode stage).
if
rs
.
num_computed_tokens
<
rs
.
num_prompt_tokens
:
continue
pending_req_ids
.
append
(
req_id
)
if
not
pending_req_ids
:
return
device
=
runner
.
device
pending_states
:
list
[
tuple
[
str
,
torch
.
Tensor
,
int
]]
=
[]
for
req_id
in
pending_req_ids
:
rs
=
requests
[
req_id
]
keep
=
rs
.
kv_compression_prompt_keep_len
idx
=
rs
.
kv_compression_prompt_idx_sorted
if
keep
is
None
or
idx
is
None
:
continue
keep_i
=
int
(
keep
)
if
keep_i
<=
0
:
# No prompt tokens kept; clear and skip.
rs
.
kv_compression_prompt_idx_sorted
=
None
rs
.
kv_compression_prompt_keep_len
=
None
rs
.
kv_compression_prompt_prompt_len
=
None
continue
pending_states
.
append
((
req_id
,
idx
,
keep_i
))
if
not
pending_states
:
return
B
=
len
(
pending_states
)
keep_list
=
[
k
for
_
,
_
,
k
in
pending_states
]
K_max
=
max
(
keep_list
)
idx_batch
=
torch
.
zeros
((
B
,
K_max
),
device
=
device
,
dtype
=
torch
.
int32
)
for
i
,
(
_
,
row
,
k
)
in
enumerate
(
pending_states
):
idx_batch
[
i
,
:
k
]
=
row
[:
k
].
to
(
device
=
device
,
dtype
=
torch
.
int32
)
keep_tensor
=
torch
.
tensor
(
keep_list
,
device
=
device
,
dtype
=
torch
.
int32
)
from
vllm.v1.kv_compression.kv_cache_triton
import
(
front_compact_inplace_fa_triton
,
make_fa_cache_view
)
kv_cache_config
=
getattr
(
runner
,
"kv_cache_config"
,
None
)
if
kv_cache_config
is
None
:
return
# Apply compaction to every attention layer's KV cache in-place.
for
group_id
,
kv_cache_group_spec
in
enumerate
(
kv_cache_config
.
kv_cache_groups
):
max_blocks
=
0
for
req_id
,
_
,
_
in
pending_states
:
rs
=
requests
[
req_id
]
if
group_id
>=
len
(
rs
.
block_ids
):
continue
max_blocks
=
max
(
max_blocks
,
len
(
rs
.
block_ids
[
group_id
]))
if
max_blocks
==
0
:
continue
block_table_cpu
=
torch
.
zeros
((
B
,
max_blocks
),
dtype
=
torch
.
int32
,
device
=
"cpu"
)
for
i
,
(
req_id
,
_
,
_
)
in
enumerate
(
pending_states
):
rs
=
requests
[
req_id
]
if
group_id
>=
len
(
rs
.
block_ids
):
continue
ids
=
rs
.
block_ids
[
group_id
]
if
ids
:
block_table_cpu
[
i
,
:
len
(
ids
)]
=
torch
.
tensor
(
ids
,
dtype
=
torch
.
int32
,
device
=
"cpu"
)
block_table
=
block_table_cpu
.
to
(
device
=
device
,
non_blocking
=
True
)
static_forward_context
=
getattr
(
getattr
(
runner
,
"compilation_config"
,
None
),
"static_forward_context"
,
None
,
)
if
static_forward_context
is
None
:
continue
seen_cache_ptrs
:
set
[
int
]
=
set
()
for
layer_name
in
kv_cache_group_spec
.
layer_names
:
# Skip non-self-attention caches (e.g., encoder/decoder cross-attn)
# and non-attention cache specs (e.g., Mamba).
kv_cache_spec
=
kv_cache_group_spec
.
kv_cache_spec
if
isinstance
(
kv_cache_spec
,
UniformTypeKVCacheSpecs
):
kv_cache_spec
=
kv_cache_spec
.
kv_cache_specs
.
get
(
layer_name
)
if
kv_cache_spec
is
None
or
not
isinstance
(
kv_cache_spec
,
AttentionSpec
):
continue
if
isinstance
(
kv_cache_spec
,
CrossAttentionSpec
):
continue
layer
=
static_forward_context
.
get
(
layer_name
)
if
layer
is
None
:
continue
kv_cache_list
=
getattr
(
layer
,
"kv_cache"
,
None
)
if
not
isinstance
(
kv_cache_list
,
list
)
or
not
kv_cache_list
:
continue
kv_cache
=
kv_cache_list
[
0
]
if
not
current_platform
.
is_rocm
():
if
not
isinstance
(
kv_cache
,
torch
.
Tensor
):
continue
cache_ptr
=
int
(
kv_cache
.
data_ptr
())
if
cache_ptr
in
seen_cache_ptrs
:
continue
seen_cache_ptrs
.
add
(
cache_ptr
)
key_cache
,
value_cache
=
kv_cache
.
unbind
(
0
)
else
:
if
(
not
isinstance
(
kv_cache
,
(
tuple
,
list
))
or
len
(
kv_cache
)
!=
2
):
continue
key_cache
,
value_cache
=
kv_cache
cache_ptr
=
int
(
key_cache
.
data_ptr
())
if
cache_ptr
in
seen_cache_ptrs
:
continue
seen_cache_ptrs
.
add
(
cache_ptr
)
k_view
,
v_view
=
make_fa_cache_view
(
key_cache
=
key_cache
,
value_cache
=
value_cache
)
front_compact_inplace_fa_triton
(
k_view
,
v_view
,
block_table
,
idx_batch
,
keep_tensor
,
)
# Clear pending state after successful compaction.
for
req_id
,
_
,
_
in
pending_states
:
rs
=
requests
.
get
(
req_id
)
if
rs
is
None
:
continue
rs
.
kv_compression_prompt_idx_sorted
=
None
rs
.
kv_compression_prompt_keep_len
=
None
rs
.
kv_compression_prompt_prompt_len
=
None
vllm/v1/kv_compression/runner_step.py
0 → 100644
View file @
dbcb0376
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from
__future__
import
annotations
from
typing
import
Any
,
Optional
import
numpy
as
np
import
vllm.envs
as
envs
from
vllm.v1.kv_compression.runner_prepare
import
prepare_kv_compression_for_step
def
maybe_prepare_kv_compression_for_runner_step
(
*
,
runner
:
Any
,
num_reqs
:
int
,
total_num_scheduled_tokens
:
int
,
num_scheduled_tokens
:
np
.
ndarray
,
# [B] int32
cu_num_tokens
:
np
.
ndarray
,
# [B] int64/int32
req_indices
:
np
.
ndarray
,
# [T] int64
arange
:
np
.
ndarray
,
# [T] int64
)
->
Optional
[
np
.
ndarray
]:
"""Prepare per-step KV compression metadata on CPU.
Returns the per-token KV positions (`kv_positions_np`) or None if KV
compression is disabled.
"""
if
not
envs
.
VLLM_ENABLE_KV_COMPRESSION
:
runner
.
kv_compression_needs_compaction
=
False
return
None
kv_positions_np
=
runner
.
kv_positions_np
[:
total_num_scheduled_tokens
]
must_keep_np
=
runner
.
kv_compression_must_keep_np
[:
total_num_scheduled_tokens
]
topk_budget_np
=
runner
.
kv_compression_topk_budget_np
[:
num_reqs
]
prompt_end_np
=
runner
.
kv_compression_prompt_end_np
[:
num_reqs
]
prompt_lens_np
=
runner
.
kv_compression_prompt_lens_np
[:
num_reqs
]
topk_keep_np
=
runner
.
kv_compression_prompt_topk_keep_np
[:
num_reqs
]
needs_compaction
,
prompt_topk_keep_max
=
prepare_kv_compression_for_step
(
num_reqs
=
num_reqs
,
total_num_scheduled_tokens
=
total_num_scheduled_tokens
,
num_scheduled_tokens
=
num_scheduled_tokens
,
cu_num_tokens
=
cu_num_tokens
,
req_indices
=
req_indices
,
arange
=
arange
,
num_computed_tokens_cpu
=
runner
.
input_batch
.
num_computed_tokens_cpu
,
num_prompt_tokens
=
runner
.
input_batch
.
num_prompt_tokens
,
num_kv_tokens_cpu
=
runner
.
input_batch
.
num_kv_tokens_cpu
,
kv_positions_np
=
kv_positions_np
,
must_keep_np
=
must_keep_np
,
topk_budget_np
=
topk_budget_np
,
prompt_end_np
=
prompt_end_np
,
prompt_lens_np
=
prompt_lens_np
,
prompt_topk_keep_np
=
topk_keep_np
,
chunked_prefill_enabled
=
runner
.
scheduler_config
.
enable_chunked_prefill
,
)
runner
.
kv_compression_needs_compaction
=
bool
(
needs_compaction
)
if
prompt_topk_keep_max
is
not
None
:
runner
.
kv_compression_prompt_topk_keep_max
=
int
(
prompt_topk_keep_max
)
return
kv_positions_np
def
maybe_copy_kv_compression_step_tensors_to_gpu
(
*
,
runner
:
Any
,
num_reqs
:
int
,
total_num_scheduled_tokens
:
int
,
non_blocking
:
bool
=
True
,
)
->
None
:
"""Stage per-step KV compression tensors to GPU if needed."""
if
not
envs
.
VLLM_ENABLE_KV_COMPRESSION
:
return
if
runner
.
scheduler_config
.
enable_chunked_prefill
:
runner
.
kv_compression_prompt_end
[:
num_reqs
].
copy_
(
runner
.
kv_compression_prompt_end_cpu
[:
num_reqs
],
non_blocking
=
non_blocking
,
)
runner
.
kv_compression_prompt_lens
[:
num_reqs
].
copy_
(
runner
.
kv_compression_prompt_lens_cpu
[:
num_reqs
],
non_blocking
=
non_blocking
,
)
runner
.
kv_compression_prompt_topk_keep
[:
num_reqs
].
copy_
(
runner
.
kv_compression_prompt_topk_keep_cpu
[:
num_reqs
],
non_blocking
=
non_blocking
,
)
return
if
runner
.
kv_compression_needs_compaction
:
runner
.
kv_compression_must_keep
[:
total_num_scheduled_tokens
].
copy_
(
runner
.
kv_compression_must_keep_cpu
[:
total_num_scheduled_tokens
],
non_blocking
=
non_blocking
,
)
runner
.
kv_compression_topk_budget
[:
num_reqs
].
copy_
(
runner
.
kv_compression_topk_budget_cpu
[:
num_reqs
],
non_blocking
=
non_blocking
,
)
vllm/v1/worker/gpu_model_runner.py
View file @
dbcb0376
...
@@ -97,6 +97,7 @@ from vllm.utils.torch_utils import (
...
@@ -97,6 +97,7 @@ from vllm.utils.torch_utils import (
get_dtype_size
,
get_dtype_size
,
kv_cache_dtype_str_to_dtype
,
kv_cache_dtype_str_to_dtype
,
)
)
from
vllm.platforms
import
current_platform
from
vllm.v1.attention.backend
import
(
from
vllm.v1.attention.backend
import
(
AttentionBackend
,
AttentionBackend
,
AttentionCGSupport
,
AttentionCGSupport
,
...
@@ -114,6 +115,16 @@ from vllm.v1.attention.backends.utils import (
...
@@ -114,6 +115,16 @@ from vllm.v1.attention.backends.utils import (
)
)
from
vllm.v1.core.sched.output
import
NewRequestData
from
vllm.v1.core.sched.output
import
NewRequestData
from
vllm.v1.cudagraph_dispatcher
import
CudagraphDispatcher
from
vllm.v1.cudagraph_dispatcher
import
CudagraphDispatcher
from
vllm.v1.kv_compression.runner_buffers
import
init_kv_compression_runner_buffers
from
vllm.v1.kv_compression.metadata
import
build_kv_compression_attn_metadata
from
vllm.v1.kv_compression.runner_prompt_compaction
import
(
maybe_apply_kv_compression_prompt_compaction
,
stash_kv_compression_prompt_payload_to_requests
,
)
from
vllm.v1.kv_compression.runner_step
import
(
maybe_copy_kv_compression_step_tensors_to_gpu
,
maybe_prepare_kv_compression_for_runner_step
,
)
from
vllm.v1.kv_cache_interface
import
(
from
vllm.v1.kv_cache_interface
import
(
AttentionSpec
,
AttentionSpec
,
ChunkedLocalAttentionSpec
,
ChunkedLocalAttentionSpec
,
...
@@ -562,6 +573,16 @@ class GPUModelRunner(
...
@@ -562,6 +573,16 @@ class GPUModelRunner(
# Persistent buffers for CUDA graphs.
# Persistent buffers for CUDA graphs.
self
.
input_ids
=
self
.
_make_buffer
(
self
.
max_num_tokens
,
dtype
=
torch
.
int32
)
self
.
input_ids
=
self
.
_make_buffer
(
self
.
max_num_tokens
,
dtype
=
torch
.
int32
)
self
.
positions
=
self
.
_make_buffer
(
self
.
max_num_tokens
,
dtype
=
torch
.
int64
)
self
.
positions
=
self
.
_make_buffer
(
self
.
max_num_tokens
,
dtype
=
torch
.
int64
)
if
envs
.
VLLM_ENABLE_KV_COMPRESSION
and
current_platform
.
is_cuda_alike
():
init_kv_compression_runner_buffers
(
runner
=
self
,
max_num_tokens
=
self
.
max_num_tokens
,
max_num_reqs
=
self
.
max_num_reqs
,
device
=
self
.
device
,
pin_memory
=
self
.
pin_memory
,
)
else
:
self
.
kv_compression_needs_compaction
=
False
self
.
query_start_loc
=
self
.
_make_buffer
(
self
.
query_start_loc
=
self
.
_make_buffer
(
self
.
max_num_reqs
+
1
,
dtype
=
torch
.
int32
self
.
max_num_reqs
+
1
,
dtype
=
torch
.
int32
)
)
...
@@ -953,6 +974,7 @@ class GPUModelRunner(
...
@@ -953,6 +974,7 @@ class GPUModelRunner(
generator
=
generator
,
generator
=
generator
,
block_ids
=
new_req_data
.
block_ids
,
block_ids
=
new_req_data
.
block_ids
,
num_computed_tokens
=
new_req_data
.
num_computed_tokens
,
num_computed_tokens
=
new_req_data
.
num_computed_tokens
,
num_kv_tokens
=
new_req_data
.
num_kv_tokens
,
output_token_ids
=
[],
output_token_ids
=
[],
lora_request
=
new_req_data
.
lora_request
,
lora_request
=
new_req_data
.
lora_request
,
)
)
...
@@ -987,6 +1009,7 @@ class GPUModelRunner(
...
@@ -987,6 +1009,7 @@ class GPUModelRunner(
for
i
,
req_id
in
enumerate
(
req_data
.
req_ids
):
for
i
,
req_id
in
enumerate
(
req_data
.
req_ids
):
req_state
=
self
.
requests
[
req_id
]
req_state
=
self
.
requests
[
req_id
]
num_computed_tokens
=
req_data
.
num_computed_tokens
[
i
]
num_computed_tokens
=
req_data
.
num_computed_tokens
[
i
]
num_kv_tokens
=
req_data
.
num_kv_tokens
[
i
]
new_block_ids
=
req_data
.
new_block_ids
[
i
]
new_block_ids
=
req_data
.
new_block_ids
[
i
]
resumed_from_preemption
=
req_id
in
req_data
.
resumed_req_ids
resumed_from_preemption
=
req_id
in
req_data
.
resumed_req_ids
num_output_tokens
=
req_data
.
num_output_tokens
[
i
]
num_output_tokens
=
req_data
.
num_output_tokens
[
i
]
...
@@ -1014,10 +1037,12 @@ class GPUModelRunner(
...
@@ -1014,10 +1037,12 @@ class GPUModelRunner(
num_accepted
=
valid_sampled_token_count
[
prev_req_index
]
-
1
num_accepted
=
valid_sampled_token_count
[
prev_req_index
]
-
1
num_rejected
=
req_state
.
prev_num_draft_len
-
num_accepted
num_rejected
=
req_state
.
prev_num_draft_len
-
num_accepted
num_computed_tokens
-=
num_rejected
num_computed_tokens
-=
num_rejected
num_kv_tokens
-=
num_rejected
req_state
.
output_token_ids
.
extend
([
-
1
]
*
num_accepted
)
req_state
.
output_token_ids
.
extend
([
-
1
]
*
num_accepted
)
# Update the cached states.
# Update the cached states.
req_state
.
num_computed_tokens
=
num_computed_tokens
req_state
.
num_computed_tokens
=
num_computed_tokens
req_state
.
num_kv_tokens
=
num_kv_tokens
if
not
is_last_rank
:
if
not
is_last_rank
:
# When using PP, the scheduler sends the sampled tokens back,
# When using PP, the scheduler sends the sampled tokens back,
...
@@ -1074,6 +1099,7 @@ class GPUModelRunner(
...
@@ -1074,6 +1099,7 @@ class GPUModelRunner(
# Update the persistent batch.
# Update the persistent batch.
self
.
input_batch
.
num_computed_tokens_cpu
[
req_index
]
=
num_computed_tokens
self
.
input_batch
.
num_computed_tokens_cpu
[
req_index
]
=
num_computed_tokens
self
.
input_batch
.
num_kv_tokens_cpu
[
req_index
]
=
num_kv_tokens
if
new_block_ids
is
not
None
:
if
new_block_ids
is
not
None
:
self
.
input_batch
.
block_table
.
append_row
(
new_block_ids
,
req_index
)
self
.
input_batch
.
block_table
.
append_row
(
new_block_ids
,
req_index
)
...
@@ -1183,6 +1209,7 @@ class GPUModelRunner(
...
@@ -1183,6 +1209,7 @@ class GPUModelRunner(
req_state
.
pooling_params
=
new_req_data
.
pooling_params
req_state
.
pooling_params
=
new_req_data
.
pooling_params
req_state
.
block_ids
=
new_req_data
.
block_ids
req_state
.
block_ids
=
new_req_data
.
block_ids
req_state
.
num_computed_tokens
=
new_req_data
.
num_computed_tokens
req_state
.
num_computed_tokens
=
new_req_data
.
num_computed_tokens
req_state
.
num_kv_tokens
=
new_req_data
.
num_kv_tokens
req_state
.
num_prompt_tokens
=
length_from_prompt_token_ids_or_embeds
(
req_state
.
num_prompt_tokens
=
length_from_prompt_token_ids_or_embeds
(
req_state
.
prompt_token_ids
,
req_state
.
prompt_embeds
req_state
.
prompt_token_ids
,
req_state
.
prompt_embeds
)
)
...
@@ -1482,6 +1509,16 @@ class GPUModelRunner(
...
@@ -1482,6 +1509,16 @@ class GPUModelRunner(
out
=
positions_np
,
out
=
positions_np
,
)
)
kv_positions_np
=
maybe_prepare_kv_compression_for_runner_step
(
runner
=
self
,
num_reqs
=
num_reqs
,
total_num_scheduled_tokens
=
total_num_scheduled_tokens
,
num_scheduled_tokens
=
num_scheduled_tokens
,
cu_num_tokens
=
cu_num_tokens
,
req_indices
=
req_indices
,
arange
=
arange
,
)
# Calculate M-RoPE positions.
# Calculate M-RoPE positions.
# Only relevant for models using M-RoPE (e.g, Qwen2-VL)
# Only relevant for models using M-RoPE (e.g, Qwen2-VL)
if
self
.
uses_mrope
:
if
self
.
uses_mrope
:
...
@@ -1557,8 +1594,19 @@ class GPUModelRunner(
...
@@ -1557,8 +1594,19 @@ class GPUModelRunner(
output_idx
+=
num_sched
output_idx
+=
num_sched
self
.
input_batch
.
block_table
.
compute_slot_mapping
(
req_indices
,
positions_np
)
positions_for_slot_mapping
=
(
kv_positions_np
if
kv_positions_np
is
not
None
else
positions_np
)
self
.
input_batch
.
block_table
.
compute_slot_mapping
(
req_indices
,
positions_for_slot_mapping
)
self
.
input_batch
.
block_table
.
commit_slot_mapping
(
total_num_scheduled_tokens
)
self
.
input_batch
.
block_table
.
commit_slot_mapping
(
total_num_scheduled_tokens
)
maybe_copy_kv_compression_step_tensors_to_gpu
(
runner
=
self
,
num_reqs
=
num_reqs
,
total_num_scheduled_tokens
=
total_num_scheduled_tokens
,
non_blocking
=
True
,
)
# Prepare the attention metadata.
# Prepare the attention metadata.
self
.
query_start_loc
.
np
[
0
]
=
0
self
.
query_start_loc
.
np
[
0
]
=
0
...
@@ -1569,9 +1617,15 @@ class GPUModelRunner(
...
@@ -1569,9 +1617,15 @@ class GPUModelRunner(
self
.
query_start_loc
.
copy_to_gpu
()
self
.
query_start_loc
.
copy_to_gpu
()
query_start_loc
=
self
.
query_start_loc
.
gpu
[:
num_reqs
+
1
]
query_start_loc
=
self
.
query_start_loc
.
gpu
[:
num_reqs
+
1
]
self
.
seq_lens
.
np
[:
num_reqs
]
=
(
logical_
seq_lens
_
np
=
(
self
.
input_batch
.
num_computed_tokens_cpu
[:
num_reqs
]
+
num_scheduled_tokens
self
.
input_batch
.
num_computed_tokens_cpu
[:
num_reqs
]
+
num_scheduled_tokens
)
)
if
envs
.
VLLM_ENABLE_KV_COMPRESSION
:
self
.
seq_lens
.
np
[:
num_reqs
]
=
(
self
.
input_batch
.
num_kv_tokens_cpu
[:
num_reqs
]
+
num_scheduled_tokens
)
else
:
self
.
seq_lens
.
np
[:
num_reqs
]
=
logical_seq_lens_np
# Fill unused with 0 for full cuda graph mode.
# Fill unused with 0 for full cuda graph mode.
self
.
seq_lens
.
np
[
num_reqs
:].
fill
(
0
)
self
.
seq_lens
.
np
[
num_reqs
:].
fill
(
0
)
self
.
seq_lens
.
copy_to_gpu
()
self
.
seq_lens
.
copy_to_gpu
()
...
@@ -1582,7 +1636,7 @@ class GPUModelRunner(
...
@@ -1582,7 +1636,7 @@ class GPUModelRunner(
# Record which requests should not be sampled,
# Record which requests should not be sampled,
# so that we could clear the sampled tokens before returning
# so that we could clear the sampled tokens before returning
self
.
discard_request_mask
.
np
[:
num_reqs
]
=
(
self
.
discard_request_mask
.
np
[:
num_reqs
]
=
(
self
.
seq_lens
.
np
[:
num_reqs
]
<
num_tokens_np
logical_
seq_lens
_
np
<
num_tokens_np
)
)
self
.
discard_request_mask
.
copy_to_gpu
(
num_reqs
)
self
.
discard_request_mask
.
copy_to_gpu
(
num_reqs
)
...
@@ -1749,12 +1803,25 @@ class GPUModelRunner(
...
@@ -1749,12 +1803,25 @@ class GPUModelRunner(
],
],
num_reqs
=
num_reqs_padded
,
num_reqs
=
num_reqs_padded
,
num_actual_tokens
=
num_tokens_padded
,
num_actual_tokens
=
num_tokens_padded
,
num_unpadded_tokens
=
num_tokens
,
max_query_len
=
max_query_len
,
max_query_len
=
max_query_len
,
max_seq_len
=
max_seq_len
,
max_seq_len
=
max_seq_len
,
block_table_tensor
=
block_table_gid_0
,
block_table_tensor
=
block_table_gid_0
,
slot_mapping
=
slot_mapping_gid_0
,
slot_mapping
=
slot_mapping_gid_0
,
causal
=
True
,
causal
=
True
,
)
)
kv_meta
=
build_kv_compression_attn_metadata
(
runner
=
self
,
num_reqs
=
num_reqs
,
num_actual_tokens
=
num_tokens
,
)
cm_base
.
kv_compression_must_keep
=
kv_meta
.
must_keep
cm_base
.
kv_compression_topk_budget
=
kv_meta
.
topk_budget
cm_base
.
kv_compression_topk_budget_max
=
kv_meta
.
topk_budget_max
cm_base
.
kv_compression_prompt_end
=
kv_meta
.
prompt_end
cm_base
.
kv_compression_prompt_lens
=
kv_meta
.
prompt_lens
cm_base
.
kv_compression_prompt_topk_keep
=
kv_meta
.
prompt_topk_keep
cm_base
.
kv_compression_prompt_topk_keep_max
=
kv_meta
.
prompt_topk_keep_max
if
self
.
dcp_world_size
>
1
:
if
self
.
dcp_world_size
>
1
:
self
.
dcp_local_seq_lens
.
cpu
[:
num_reqs
]
=
get_dcp_local_seq_lens
(
self
.
dcp_local_seq_lens
.
cpu
[:
num_reqs
]
=
get_dcp_local_seq_lens
(
...
@@ -3510,6 +3577,10 @@ class GPUModelRunner(
...
@@ -3510,6 +3577,10 @@ class GPUModelRunner(
self
.
model_config
.
is_encoder_decoder
and
num_encoder_reqs
>
0
self
.
model_config
.
is_encoder_decoder
and
num_encoder_reqs
>
0
)
)
# Chunked prefill (scheme 3): apply one-shot prompt KV compaction before
# the first decode step writes/reads KV at the compressed positions.
maybe_apply_kv_compression_prompt_compaction
(
runner
=
self
)
# Run the model.
# Run the model.
# Use persistent buffers for CUDA graphs.
# Use persistent buffers for CUDA graphs.
with
(
with
(
...
@@ -3534,6 +3605,7 @@ class GPUModelRunner(
...
@@ -3534,6 +3605,7 @@ class GPUModelRunner(
inputs_embeds
=
inputs_embeds
,
inputs_embeds
=
inputs_embeds
,
**
model_kwargs
,
**
model_kwargs
,
)
)
stash_kv_compression_prompt_payload_to_requests
(
runner
=
self
)
with
record_function_or_nullcontext
(
"gpu_model_runner: postprocess"
):
with
record_function_or_nullcontext
(
"gpu_model_runner: postprocess"
):
if
self
.
use_aux_hidden_state_outputs
:
if
self
.
use_aux_hidden_state_outputs
:
...
@@ -5273,6 +5345,17 @@ class GPUModelRunner(
...
@@ -5273,6 +5345,17 @@ class GPUModelRunner(
attention_backend_maps
.
append
(
attn_backends
[
0
])
attention_backend_maps
.
append
(
attn_backends
[
0
])
attention_backend_list
.
append
(
attn_backends
[
1
])
attention_backend_list
.
append
(
attn_backends
[
1
])
if
envs
.
VLLM_ENABLE_KV_COMPRESSION
and
current_platform
.
is_cuda_alike
():
for
attn_backend_set
in
attention_backend_list
:
for
attn_backend
in
attn_backend_set
:
if
attn_backend
.
get_name
()
!=
"FLASH_ATTN"
:
raise
ValueError
(
"KV compression currently requires the FLASH_ATTN "
"attention backend. "
f
"Got
{
attn_backend
.
get_name
()
}
"
f
"(
{
attn_backend
.
full_cls_name
()
}
)."
)
# Resolve cudagraph_mode before actually initialize metadata_builders
# Resolve cudagraph_mode before actually initialize metadata_builders
self
.
_check_and_update_cudagraph_mode
(
self
.
_check_and_update_cudagraph_mode
(
attention_backend_list
,
kv_cache_config
.
kv_cache_groups
attention_backend_list
,
kv_cache_config
.
kv_cache_groups
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment