Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
2b7160c6
Commit
2b7160c6
authored
Apr 23, 2026
by
chenzk
Browse files
vllm kvprune:v1.0.0
parent
fa718036
Changes
305
Hide whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
4122 additions
and
0 deletions
+4122
-0
vllm/kvprune_legacy_save/compression/prefill_registry.py
vllm/kvprune_legacy_save/compression/prefill_registry.py
+201
-0
vllm/kvprune_legacy_save/compression/snapkv.py
vllm/kvprune_legacy_save/compression/snapkv.py
+545
-0
vllm/kvprune_legacy_save/compression/snapkv_origin.py
vllm/kvprune_legacy_save/compression/snapkv_origin.py
+449
-0
vllm/kvprune_legacy_save/config/__init__.py
vllm/kvprune_legacy_save/config/__init__.py
+7
-0
vllm/kvprune_legacy_save/config/constants.py
vllm/kvprune_legacy_save/config/constants.py
+7
-0
vllm/kvprune_legacy_save/config/engine_config.py
vllm/kvprune_legacy_save/config/engine_config.py
+129
-0
vllm/kvprune_legacy_save/config/sampling_params.py
vllm/kvprune_legacy_save/config/sampling_params.py
+11
-0
vllm/kvprune_legacy_save/core/__init__.py
vllm/kvprune_legacy_save/core/__init__.py
+45
-0
vllm/kvprune_legacy_save/core/block_budget.py
vllm/kvprune_legacy_save/core/block_budget.py
+69
-0
vllm/kvprune_legacy_save/core/compression_bridge.py
vllm/kvprune_legacy_save/core/compression_bridge.py
+60
-0
vllm/kvprune_legacy_save/core/flash_integration.py
vllm/kvprune_legacy_save/core/flash_integration.py
+92
-0
vllm/kvprune_legacy_save/core/llm_engine.py
vllm/kvprune_legacy_save/core/llm_engine.py
+441
-0
vllm/kvprune_legacy_save/core/memory_manager.py
vllm/kvprune_legacy_save/core/memory_manager.py
+237
-0
vllm/kvprune_legacy_save/core/model_runner.py
vllm/kvprune_legacy_save/core/model_runner.py
+794
-0
vllm/kvprune_legacy_save/core/runtime.py
vllm/kvprune_legacy_save/core/runtime.py
+130
-0
vllm/kvprune_legacy_save/core/scheduler.py
vllm/kvprune_legacy_save/core/scheduler.py
+259
-0
vllm/kvprune_legacy_save/integration/__init__.py
vllm/kvprune_legacy_save/integration/__init__.py
+7
-0
vllm/kvprune_legacy_save/integration/compactor_shared.py
vllm/kvprune_legacy_save/integration/compactor_shared.py
+140
-0
vllm/kvprune_legacy_save/integration/compressed_generate.py
vllm/kvprune_legacy_save/integration/compressed_generate.py
+447
-0
vllm/kvprune_legacy_save/integration/compression_params.py
vllm/kvprune_legacy_save/integration/compression_params.py
+52
-0
No files found.
Too many changes to show.
To preserve performance only
305 of 305+
files are displayed.
Plain diff
Email patch
vllm/kvprune_legacy_save/compression/prefill_registry.py
0 → 100644
View file @
2b7160c6
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""Map COMPRESSION_REGISTRY scoring to prefill top-k indices (with tail fallback)."""
from
__future__
import
annotations
import
logging
import
torch
from
vllm.kvprune.compression
import
COMPRESSION_REGISTRY
from
vllm.kvprune.compression.compression_config
import
CompressionMethod
from
vllm.kvprune.utils.context
import
CompressionContext
,
Context
logger
=
logging
.
getLogger
(
__name__
)
def
_scores_to_topk_pair_indices
(
cu_seqlens
:
torch
.
Tensor
,
num_reqs
:
int
,
hkv
:
int
,
scores
:
torch
.
Tensor
,
compression_ratio
:
float
|
torch
.
Tensor
,
max_sel
:
int
,
device
:
torch
.
device
,
)
->
tuple
[
torch
.
Tensor
,
torch
.
Tensor
]:
"""Select (token, head) pairs with highest scores per request up to budget."""
if
scores
.
dim
()
==
1
:
scores
=
scores
.
unsqueeze
(
-
1
).
expand
(
-
1
,
hkv
)
elif
scores
.
dim
()
>
2
:
scores
=
scores
.
reshape
(
scores
.
shape
[
0
],
-
1
)[:,
:
hkv
]
indices
=
torch
.
zeros
(
num_reqs
,
max_sel
,
dtype
=
torch
.
int32
,
device
=
device
)
n_pairs
=
torch
.
zeros
(
num_reqs
,
dtype
=
torch
.
int32
,
device
=
device
)
cu_cpu
=
cu_seqlens
[:
num_reqs
+
1
].
detach
()
for
b
in
range
(
num_reqs
):
start
=
int
(
cu_cpu
[
b
].
item
())
end
=
int
(
cu_cpu
[
b
+
1
].
item
())
chunk_len
=
end
-
start
if
chunk_len
<=
0
:
continue
if
isinstance
(
compression_ratio
,
torch
.
Tensor
):
r_b
=
float
(
compression_ratio
[
b
].
item
())
else
:
r_b
=
compression_ratio
k_tok
=
max
(
1
,
int
(
round
(
chunk_len
*
r_b
)))
k_tok
=
min
(
k_tok
,
chunk_len
)
budget
=
min
(
k_tok
*
hkv
,
max_sel
)
flat_scores
:
list
[
tuple
[
float
,
int
]]
=
[]
for
tok
in
range
(
start
,
end
):
for
h
in
range
(
hkv
):
if
scores
.
dim
()
==
2
:
s
=
float
(
scores
[
tok
,
h
].
item
())
else
:
s
=
float
(
scores
[
tok
].
item
())
idx
=
tok
*
hkv
+
h
flat_scores
.
append
((
s
,
idx
))
flat_scores
.
sort
(
key
=
lambda
x
:
-
x
[
0
])
n
=
min
(
budget
,
len
(
flat_scores
))
if
n
>
0
:
chosen
=
[
x
[
1
]
for
x
in
flat_scores
[:
n
]]
indices
[
b
,
:
n
]
=
torch
.
tensor
(
chosen
,
dtype
=
torch
.
int32
,
device
=
device
)
n_pairs
[
b
]
=
n
return
indices
,
n_pairs
def
try_topk_indices_from_registry
(
method
:
CompressionMethod
,
key
:
torch
.
Tensor
,
value
:
torch
.
Tensor
,
cu
:
torch
.
Tensor
,
num_reqs
:
int
,
compression_ratio
:
torch
.
Tensor
,
max_sel
:
int
,
device
:
torch
.
device
,
)
->
tuple
[
torch
.
Tensor
,
torch
.
Tensor
]
|
None
:
"""Return (indices, n_pairs) using registry scoring, or None to use tail fallback."""
if
method
==
CompressionMethod
.
NONE
:
return
None
num_kv_heads
=
key
.
shape
[
1
]
n_tokens
,
hkv
,
d
=
key
.
shape
if
n_tokens
<=
0
or
hkv
<=
0
:
return
None
k_flat
=
key
.
reshape
(
n_tokens
,
hkv
,
d
)
v_flat
=
value
.
reshape
(
n_tokens
,
hkv
,
d
)
context_lens
=
[]
cu_cpu
=
cu
[:
num_reqs
+
1
].
detach
().
cpu
()
for
b
in
range
(
num_reqs
):
context_lens
.
append
(
int
(
cu_cpu
[
b
+
1
].
item
()
-
cu_cpu
[
b
].
item
()))
max_seqlen_q
=
int
((
cu_cpu
[
1
:
num_reqs
+
1
]
-
cu_cpu
[:
num_reqs
]).
max
().
item
())
if
method
==
CompressionMethod
.
COMPACTOR
:
try
:
k_proj
=
min
(
64
,
d
)
phi
=
torch
.
randn
(
d
,
k_proj
,
device
=
key
.
device
,
dtype
=
torch
.
float32
)
cc
=
CompressionContext
(
compression_method
=
CompressionMethod
.
COMPACTOR
,
context_lens
=
context_lens
,
PHI
=
phi
,
compression_chunk_size
=
512
,
protected_first_tokens
=
[
0
]
*
num_reqs
,
protected_last_tokens
=
[
0
]
*
num_reqs
,
)
ctx
=
Context
(
is_prefill
=
True
,
do_compression
=
True
,
cu_seqlens_q
=
cu
,
max_seqlen_q
=
max_seqlen_q
,
compression_context
=
cc
,
)
cls
=
COMPRESSION_REGISTRY
[
CompressionMethod
.
COMPACTOR
]
q_dummy
=
torch
.
zeros_like
(
k_flat
)
scores
=
cls
.
pre_rope_scoring
(
q_dummy
,
k_flat
,
v_flat
,
context
=
ctx
,
)
if
scores
is
None
:
return
None
return
_scores_to_topk_pair_indices
(
cu
,
num_reqs
,
hkv
,
scores
,
compression_ratio
,
max_sel
,
device
)
except
Exception
:
logger
.
debug
(
"Compactor pre_rope scoring failed; using tail fallback"
,
exc_info
=
True
)
return
None
if
method
==
CompressionMethod
.
CRITICALADAKV
:
try
:
k_proj
=
min
(
64
,
d
)
phi
=
torch
.
randn
(
d
,
k_proj
,
device
=
key
.
device
,
dtype
=
torch
.
float32
)
cc
=
CompressionContext
(
compression_method
=
CompressionMethod
.
CRITICALADAKV
,
context_lens
=
context_lens
,
PHI
=
phi
,
compression_chunk_size
=
512
,
protected_first_tokens
=
[
0
]
*
num_reqs
,
protected_last_tokens
=
[
0
]
*
num_reqs
,
)
ctx
=
Context
(
is_prefill
=
True
,
do_compression
=
True
,
cu_seqlens_q
=
cu
,
max_seqlen_q
=
max_seqlen_q
,
compression_context
=
cc
,
)
cls_ada
=
COMPRESSION_REGISTRY
[
CompressionMethod
.
CRITICALADAKV
]
q_dummy
=
torch
.
zeros_like
(
k_flat
)
pre_scores
=
cls_ada
.
pre_rope_scoring
(
q_dummy
,
k_flat
,
v_flat
,
context
=
ctx
)
scores
=
cls_ada
.
post_rope_scoring
(
q_dummy
,
k_flat
,
v_flat
,
pre_scores
,
context
=
ctx
)
if
scores
is
None
:
return
None
return
_scores_to_topk_pair_indices
(
cu
,
num_reqs
,
hkv
,
scores
,
compression_ratio
,
max_sel
,
device
)
except
Exception
:
logger
.
debug
(
"CriticalAdaKV registry path failed; using tail fallback"
,
exc_info
=
True
)
return
None
if
method
==
CompressionMethod
.
SNAPKV
:
try
:
cc
=
CompressionContext
(
compression_method
=
CompressionMethod
.
SNAPKV
)
ctx
=
Context
(
is_prefill
=
True
,
do_compression
=
True
,
cu_seqlens_q
=
cu
,
cu_seqlens_k
=
cu
,
max_seqlen_q
=
max_seqlen_q
,
max_seqlen_k
=
max_seqlen_q
,
compression_context
=
cc
,
)
cls
=
COMPRESSION_REGISTRY
[
CompressionMethod
.
SNAPKV
]
q_dummy
=
torch
.
zeros_like
(
k_flat
)
scores
=
cls
.
post_rope_scoring
(
q_dummy
,
k_flat
,
v_flat
,
None
,
context
=
ctx
,
)
if
scores
is
None
:
return
None
return
_scores_to_topk_pair_indices
(
cu
,
num_reqs
,
hkv
,
scores
,
compression_ratio
,
max_sel
,
device
)
except
Exception
:
logger
.
debug
(
"SnapKV registry path failed; using tail fallback"
,
exc_info
=
True
)
return
None
return
None
__all__
=
[
"try_topk_indices_from_registry"
]
vllm/kvprune_legacy_save/compression/snapkv.py
0 → 100644
View file @
2b7160c6
import
math
from
typing
import
Optional
import
torch
import
triton
from
triton
import
language
as
tl
from
vllm.kvprune.compression.common
import
BaseCompressionMethod
from
vllm.kvprune.utils.helpers
import
maybe_execute_in_stream
from
vllm.kvprune.utils.triton_compat
import
autotune
as
triton_autotune
# SnapKV defaults aligned with kvpress `SnapKVPress` (snapkv_press.py).
DEFAULT_SNAPKV_WINDOW_SIZE
=
64
DEFAULT_SNAPKV_KERNEL_SIZE
=
5
class
SnapKVCompression
(
BaseCompressionMethod
):
@
staticmethod
def
pre_rope_scoring
(
q
:
torch
.
Tensor
,
k
:
torch
.
Tensor
,
v
:
torch
.
Tensor
,
context
)
->
Optional
[
torch
.
Tensor
]:
return
None
@
staticmethod
def
post_rope_scoring
(
q
:
torch
.
Tensor
,
k
:
torch
.
Tensor
,
v
:
torch
.
Tensor
,
pre_rope_scores
:
torch
.
Tensor
,
context
,
)
->
Optional
[
torch
.
Tensor
]:
scores
=
maybe_execute_in_stream
(
query_aware_key_scores
,
q
,
k
,
context
.
cu_seqlens_q
,
context
.
cu_seqlens_k
,
w
=
DEFAULT_SNAPKV_WINDOW_SIZE
,
kernel_size
=
DEFAULT_SNAPKV_KERNEL_SIZE
,
STORE_STREAM
=
context
.
STORE_STREAM
,
)
return
scores
@
triton_autotune
(
configs
=
[
triton
.
Config
(
{
"BLOCK_Q"
:
bq
,
"BLOCK_K"
:
bk
},
num_warps
=
num_warps
,
num_stages
=
num_stages
)
for
bq
in
[
32
,
64
]
for
bk
in
[
32
,
64
]
for
num_warps
in
[
4
,
8
]
for
num_stages
in
[
3
,
4
]
],
key
=
[
"QUERY_GROUP_SIZE"
,
"D"
,
"ROWS_MAX"
],
cache_results
=
True
,
)
@
triton
.
jit
def
_lse_and_store_logits_kernel
(
Q
,
K
,
cu_q
,
cu_k
,
w_b
,
# int32 pointers
out_m
,
out_S
,
# [B, Hk, ROWS_MAX] float32
LOGITS
,
# [Nk, Hk, ROWS_MAX] float32
sm_scale
,
# float
QUERY_GROUP_SIZE
:
tl
.
constexpr
,
D
:
tl
.
constexpr
,
STRIDE_Q_NQ
,
STRIDE_Q_HQ
,
STRIDE_K_NK
,
STRIDE_K_HK
,
STRIDE_M_B
,
STRIDE_M_H
,
STRIDE_M_R
,
STRIDE_S_B
,
STRIDE_S_H
,
STRIDE_S_R
,
STRIDE_LG_NK
,
STRIDE_LG_HK
,
STRIDE_LG_R
,
BLOCK_Q
:
tl
.
constexpr
,
BLOCK_K
:
tl
.
constexpr
,
ROWS_MAX
,
):
# program ids
b
=
tl
.
program_id
(
0
)
hk
=
tl
.
program_id
(
1
)
rid
=
tl
.
program_id
(
2
)
# row-tile id
# batch segment bounds
q_end
=
tl
.
load
(
cu_q
+
b
+
1
)
k_beg
=
tl
.
load
(
cu_k
+
b
)
k_end
=
tl
.
load
(
cu_k
+
b
+
1
)
win
=
tl
.
load
(
w_b
+
b
)
q_win_beg
=
q_end
-
win
k_eff_end
=
k_end
-
win
if
(
win
<=
0
)
or
(
k_eff_end
<=
k_beg
):
return
# rows for this (b,hk)
rows_b
=
win
*
QUERY_GROUP_SIZE
row0
=
rid
*
BLOCK_Q
if
row0
>=
rows_b
:
return
# exp(x) = exp2(x * 1/ln2)
qk_scale
=
sm_scale
*
1.4426950408889634
offs_qrow
=
row0
+
tl
.
arange
(
0
,
BLOCK_Q
)
row_mask
=
offs_qrow
<
rows_b
# map row -> (q_idx, hq_local)
hq_local
=
offs_qrow
%
QUERY_GROUP_SIZE
q_off
=
offs_qrow
//
QUERY_GROUP_SIZE
q_idx
=
q_win_beg
+
q_off
hq_glob
=
hk
*
QUERY_GROUP_SIZE
+
hq_local
offs_d
=
tl
.
arange
(
0
,
D
)
q_ptrs
=
(
Q
+
q_idx
[:,
None
]
*
STRIDE_Q_NQ
+
hq_glob
[:,
None
]
*
STRIDE_Q_HQ
+
offs_d
[
None
,
:]
)
q_rows
=
tl
.
load
(
q_ptrs
,
mask
=
row_mask
[:,
None
],
other
=
0.0
)
m
=
tl
.
zeros
([
BLOCK_Q
],
dtype
=
tl
.
float32
)
+
(
-
float
(
"inf"
))
S
=
tl
.
zeros
([
BLOCK_Q
],
dtype
=
tl
.
float32
)
# Full-sequence causal attention (matches kvpress softmax), then use prefix columns only.
for
ks
in
tl
.
range
(
k_beg
,
k_end
,
BLOCK_K
):
nk
=
ks
+
tl
.
arange
(
0
,
BLOCK_K
)
kmask
=
nk
<
k_end
k_ptrs
=
K
+
nk
[:,
None
]
*
STRIDE_K_NK
+
hk
*
STRIDE_K_HK
+
offs_d
[
None
,
:]
k_blk
=
tl
.
load
(
k_ptrs
,
mask
=
kmask
[:,
None
],
other
=
0.0
)
# [BK, D]
s
=
tl
.
dot
(
q_rows
,
k_blk
.
T
)
*
qk_scale
# [BQ, BK]
s
=
tl
.
where
(
kmask
[
None
,
:],
s
,
-
float
(
"inf"
))
# Causal: key j only if j <= q_idx (same as kvpress triu mask on the window×k_len grid).
causal_ok
=
nk
[
None
,
:]
<=
q_idx
[:,
None
]
s
=
tl
.
where
(
causal_ok
,
s
,
-
float
(
"inf"
))
# store prefix logits only (for marginal probs on prefix keys)
log_ptrs
=
(
LOGITS
+
nk
[:,
None
]
*
STRIDE_LG_NK
+
hk
*
STRIDE_LG_HK
+
(
row0
+
tl
.
arange
(
0
,
BLOCK_Q
))[
None
,
:]
*
STRIDE_LG_R
)
store_mask
=
kmask
&
(
nk
<
k_eff_end
)
tl
.
store
(
log_ptrs
,
s
.
T
,
mask
=
store_mask
[:,
None
]
&
row_mask
[
None
,
:])
# log2 streaming LSE over all keys in [k_beg, k_end) (after causal mask)
cur_max
=
tl
.
max
(
s
,
1
)
# [BQ]
n_m
=
tl
.
maximum
(
m
,
cur_max
)
rescale
=
tl
.
math
.
exp2
(
m
-
n_m
)
S
=
S
*
rescale
+
tl
.
sum
(
tl
.
math
.
exp2
(
s
-
n_m
[:,
None
]),
1
)
m
=
n_m
# store m,S for these rows
m_base
=
out_m
+
b
*
STRIDE_M_B
+
hk
*
STRIDE_M_H
+
row0
*
STRIDE_M_R
S_base
=
out_S
+
b
*
STRIDE_S_B
+
hk
*
STRIDE_S_H
+
row0
*
STRIDE_S_R
tl
.
store
(
m_base
+
tl
.
arange
(
0
,
BLOCK_Q
)
*
STRIDE_M_R
,
m
,
mask
=
row_mask
)
tl
.
store
(
S_base
+
tl
.
arange
(
0
,
BLOCK_Q
)
*
STRIDE_S_R
,
S
,
mask
=
row_mask
)
@
triton_autotune
(
configs
=
[
triton
.
Config
({
"BLOCK_Q"
:
bq
,
"BLOCK_K"
:
bk
})
for
bq
in
[
16
,
32
,
64
]
for
bk
in
[
32
,
64
,
128
]
],
key
=
[
"HK"
,
"HQ"
],
cache_results
=
True
,
)
@
triton
.
jit
def
_prefix_probs_kernel
(
cu_k
,
w_b
,
in_m
,
in_S
,
# [B, Hk, ROWS_MAX] f32
LOGITS
,
# [Nk, Hk, ROWS_MAX] f32, base-2 logits (prefix keys only)
PROBS
,
# [Nk, Hk, ROWS_MAX] f32 — per-row prefix marginal probs
#
QUERY_GROUP_SIZE
:
tl
.
constexpr
,
STRIDE_M_B
,
STRIDE_M_H
,
STRIDE_M_R
,
STRIDE_S_B
,
STRIDE_S_H
,
STRIDE_S_R
,
STRIDE_LG_NK
,
STRIDE_LG_HK
,
STRIDE_LG_R
,
STRIDE_PB_NK
,
STRIDE_PB_HK
,
STRIDE_PB_R
,
BLOCK_Q
:
tl
.
constexpr
,
BLOCK_K
:
tl
.
constexpr
,
):
b
=
tl
.
program_id
(
0
)
hk
=
tl
.
program_id
(
1
)
k_beg
=
tl
.
load
(
cu_k
+
b
)
k_end
=
tl
.
load
(
cu_k
+
b
+
1
)
win
=
tl
.
load
(
w_b
+
b
)
k_eff_end
=
k_end
-
win
if
(
win
<=
0
)
or
(
k_eff_end
<=
k_beg
):
return
rows_b
=
win
*
QUERY_GROUP_SIZE
for
ks
in
tl
.
range
(
k_beg
,
k_eff_end
,
BLOCK_K
):
nk
=
ks
+
tl
.
arange
(
0
,
BLOCK_K
)
kmask
=
nk
<
k_eff_end
for
row0
in
tl
.
range
(
0
,
rows_b
,
BLOCK_Q
):
r_idx
=
row0
+
tl
.
arange
(
0
,
BLOCK_Q
)
rmask
=
r_idx
<
rows_b
m_ptr
=
in_m
+
b
*
STRIDE_M_B
+
hk
*
STRIDE_M_H
+
row0
*
STRIDE_M_R
S_ptr
=
in_S
+
b
*
STRIDE_S_B
+
hk
*
STRIDE_S_H
+
row0
*
STRIDE_S_R
m
=
tl
.
load
(
m_ptr
+
tl
.
arange
(
0
,
BLOCK_Q
)
*
STRIDE_M_R
,
mask
=
rmask
,
other
=-
float
(
"inf"
),
)
S
=
tl
.
load
(
S_ptr
+
tl
.
arange
(
0
,
BLOCK_Q
)
*
STRIDE_S_R
,
mask
=
rmask
,
other
=
0.0
)
valid_row
=
S
>
0
m
=
tl
.
where
(
valid_row
,
m
,
0.0
)
S
=
tl
.
where
(
valid_row
,
S
,
1.0
)
log_ptrs
=
(
LOGITS
+
nk
[:,
None
]
*
STRIDE_LG_NK
+
hk
*
STRIDE_LG_HK
+
(
row0
+
tl
.
arange
(
0
,
BLOCK_Q
))[
None
,
:]
*
STRIDE_LG_R
)
s_T
=
tl
.
load
(
log_ptrs
,
mask
=
kmask
[:,
None
]
&
rmask
[
None
,
:],
other
=-
float
(
"inf"
)
)
# [BK, BQ]
probs_T
=
tl
.
math
.
exp2
(
s_T
-
m
[
None
,
:])
/
S
[
None
,
:]
probs_T
=
tl
.
where
(
valid_row
[
None
,
:],
probs_T
,
0.0
)
prob_ptrs
=
(
PROBS
+
nk
[:,
None
]
*
STRIDE_PB_NK
+
hk
*
STRIDE_PB_HK
+
(
row0
+
tl
.
arange
(
0
,
BLOCK_Q
))[
None
,
:]
*
STRIDE_PB_R
)
tl
.
store
(
prob_ptrs
,
probs_T
,
mask
=
kmask
[:,
None
]
&
rmask
[
None
,
:])
@
triton_autotune
(
configs
=
[
triton
.
Config
({
"BLOCK_K"
:
bk
})
for
bk
in
[
32
,
64
,
128
]],
key
=
[
"HK"
],
cache_results
=
True
,
)
@
triton
.
jit
def
_zscore_per_batch_epilogue
(
OUT
,
# [Nk, Hk], float32
cu_k
,
w_b
,
# [B+1], [B] int32
STRIDE_OUT_NK
,
STRIDE_OUT_HK
,
HK
:
tl
.
constexpr
,
# Hk
EPS
:
tl
.
constexpr
,
# e.g., 1e-12
BLOCK_K
:
tl
.
constexpr
,
# e.g., 128
):
b
=
tl
.
program_id
(
0
)
k_beg
=
tl
.
load
(
cu_k
+
b
)
k_end
=
tl
.
load
(
cu_k
+
b
+
1
)
win
=
tl
.
load
(
w_b
+
b
)
k_eff_end
=
k_end
-
win
if
k_eff_end
<=
k_beg
:
return
sumv
=
tl
.
zeros
([],
dtype
=
tl
.
float32
)
sumsq
=
tl
.
zeros
([],
dtype
=
tl
.
float32
)
count
=
((
k_eff_end
-
k_beg
)
*
HK
).
to
(
tl
.
float32
)
for
ks
in
tl
.
range
(
k_beg
,
k_eff_end
,
BLOCK_K
):
nk
=
ks
+
tl
.
arange
(
0
,
BLOCK_K
)
kmask
=
nk
<
k_eff_end
for
h
in
tl
.
range
(
0
,
HK
):
ptrs
=
OUT
+
nk
*
STRIDE_OUT_NK
+
h
*
STRIDE_OUT_HK
vals
=
tl
.
load
(
ptrs
,
mask
=
kmask
,
other
=
0.0
).
to
(
tl
.
float32
)
sumv
+=
tl
.
sum
(
vals
,
0
)
sumsq
+=
tl
.
sum
(
vals
*
vals
,
0
)
mean
=
sumv
/
count
var
=
tl
.
maximum
(
sumsq
/
count
-
mean
*
mean
,
0.0
)
invstd
=
1.0
/
tl
.
sqrt
(
var
+
EPS
)
for
ks
in
tl
.
range
(
k_beg
,
k_eff_end
,
BLOCK_K
):
nk
=
ks
+
tl
.
arange
(
0
,
BLOCK_K
)
kmask
=
nk
<
k_eff_end
for
h
in
tl
.
range
(
0
,
HK
):
ptrs
=
OUT
+
nk
*
STRIDE_OUT_NK
+
h
*
STRIDE_OUT_HK
vals
=
tl
.
load
(
ptrs
,
mask
=
kmask
,
other
=
0.0
).
to
(
tl
.
float32
)
vals
=
(
vals
-
mean
)
*
invstd
tl
.
store
(
ptrs
,
vals
,
mask
=
kmask
)
@
triton_autotune
(
configs
=
[
triton
.
Config
({
"BLOCK_T"
:
bt
})
for
bt
in
[
32
,
64
,
128
,
256
]],
key
=
[
"KERNEL_SIZE"
],
cache_results
=
True
,
)
@
triton
.
jit
def
_snapkv_avg_pool1d_kernel
(
IN
,
OUT
,
Lp
,
STRIDE_IN_C
,
STRIDE_IN_L
,
STRIDE_OUT_C
,
STRIDE_OUT_L
,
KERNEL_SIZE
:
tl
.
constexpr
,
PAD
:
tl
.
constexpr
,
BLOCK_T
:
tl
.
constexpr
,
):
"""
Symmetric 1D average pool on the last dimension, matching
`F.avg_pool1d(x, kernel_size=K, padding=K//2, stride=1)` on `x` shaped [C, Lp]
(equivalent to PyTorch [C, 1, Lp] avg_pool1d with divisor = kernel size).
"""
c
=
tl
.
program_id
(
0
)
t0
=
tl
.
program_id
(
1
)
*
BLOCK_T
+
tl
.
arange
(
0
,
BLOCK_T
)
mask
=
t0
<
Lp
acc
=
tl
.
zeros
([
BLOCK_T
],
dtype
=
tl
.
float32
)
for
j
in
tl
.
static_range
(
KERNEL_SIZE
):
idx
=
t0
-
PAD
+
j
valid
=
(
idx
>=
0
)
&
(
idx
<
Lp
)
ptrs
=
IN
+
c
*
STRIDE_IN_C
+
idx
*
STRIDE_IN_L
v
=
tl
.
load
(
ptrs
,
mask
=
valid
&
mask
,
other
=
0.0
).
to
(
tl
.
float32
)
acc
+=
v
acc
=
acc
/
tl
.
cast
(
KERNEL_SIZE
,
tl
.
float32
)
out_ptrs
=
OUT
+
c
*
STRIDE_OUT_C
+
t0
*
STRIDE_OUT_L
tl
.
store
(
out_ptrs
,
acc
,
mask
=
mask
)
def
_snapkv_avg_pool1d_triton
(
x
:
torch
.
Tensor
,
kernel_size
:
int
)
->
torch
.
Tensor
:
"""
kvpress-equivalent smoothing: same as `F.avg_pool1d` on [Hk*G, 1, Lp].
`x` must be float32 and contiguous along Lp (shape [Hk, G, Lp]).
"""
assert
x
.
dtype
==
torch
.
float32
Hk
,
G
,
Lp
=
x
.
shape
if
Lp
==
0
:
return
x
pad
=
kernel_size
//
2
x2
=
x
.
reshape
(
Hk
*
G
,
Lp
).
contiguous
()
out
=
torch
.
empty_like
(
x2
)
C
=
Hk
*
G
si_c
,
si_l
=
x2
.
stride
()
so_c
,
so_l
=
out
.
stride
()
def
grid
(
meta
):
return
(
C
,
triton
.
cdiv
(
Lp
,
meta
[
"BLOCK_T"
]))
_snapkv_avg_pool1d_kernel
[
grid
](
x2
,
out
,
Lp
,
si_c
,
si_l
,
so_c
,
so_l
,
KERNEL_SIZE
=
kernel_size
,
PAD
=
pad
,
)
return
out
.
view
(
Hk
,
G
,
Lp
)
def
_snapkv_kvpress_epilogue
(
probs_buf
:
torch
.
Tensor
,
out
:
torch
.
Tensor
,
cu_seqlens_k
:
torch
.
Tensor
,
w
:
torch
.
Tensor
,
G
:
int
,
Hk
:
int
,
kernel_size
:
int
,
)
->
None
:
"""
Match kvpress SnapKV order: mean over window queries → symmetric avg_pool1d
→ mean over GQA groups → pad tail with global max of prefix scores.
"""
B
=
cu_seqlens_k
.
numel
()
-
1
for
b
in
range
(
B
):
k_beg
=
int
(
cu_seqlens_k
[
b
].
item
())
k_end
=
int
(
cu_seqlens_k
[
b
+
1
].
item
())
win
=
int
(
w
[
b
].
item
())
k_eff_end
=
k_end
-
win
if
win
<=
0
or
k_eff_end
<=
k_beg
:
continue
Lp
=
k_eff_end
-
k_beg
rows_b
=
win
*
G
p
=
probs_buf
[
k_beg
:
k_eff_end
,
:,
:
rows_b
]
# [Lp, Hk, win, G] — rows are (q_off, g) order per Triton row layout
x
=
p
.
view
(
Lp
,
Hk
,
win
,
G
).
mean
(
dim
=
2
)
x
=
x
.
permute
(
1
,
2
,
0
).
contiguous
()
# [Hk, G, Lp]
x
=
_snapkv_avg_pool1d_triton
(
x
,
kernel_size
)
x
=
x
.
mean
(
dim
=
1
)
seg
=
x
.
permute
(
1
,
0
).
contiguous
()
out
[
k_beg
:
k_eff_end
,
:]
=
seg
pad_val
=
seg
.
max
()
out
[
k_eff_end
:
k_end
,
:]
=
pad_val
def
query_aware_key_scores
(
q
:
torch
.
Tensor
,
# [N_q, Hq, D]
k
:
torch
.
Tensor
,
# [N_k, Hk, D]
cu_seqlens_q
:
torch
.
Tensor
,
# [B+1], int32
cu_seqlens_k
:
torch
.
Tensor
,
# [B+1], int32
w
:
torch
.
Tensor
|
int
,
# [B], int32
sm_scale
:
float
=
None
,
# defaults to 1/sqrt(D)
*
,
kernel_size
:
int
=
DEFAULT_SNAPKV_KERNEL_SIZE
,
accum_scores
:
torch
.
Tensor
=
None
,
accum_blending
:
float
=
None
,
normalize
:
bool
=
False
,
)
->
Optional
[
torch
.
Tensor
]:
assert
q
.
stride
(
-
1
)
==
1
and
k
.
stride
(
-
1
)
==
1
,
"last dim must be contiguous"
device
=
q
.
device
N_q
,
Hq
,
D
=
q
.
shape
N_k
,
Hk
,
Dk
=
k
.
shape
assert
(
Hq
%
Hk
)
==
0
,
"Hq must be a multiple of Hk"
if
sm_scale
is
None
:
sm_scale
=
1.0
/
math
.
sqrt
(
D
)
B
=
cu_seqlens_q
.
numel
()
-
1
assert
B
==
cu_seqlens_k
.
numel
()
-
1
G
=
Hq
//
Hk
if
type
(
w
)
is
int
:
max_w
=
w
w
=
torch
.
full
((
B
,),
fill_value
=
w
,
device
=
device
,
dtype
=
torch
.
int32
)
else
:
max_w
=
int
(
w
.
max
().
item
())
assert
w
.
numel
()
==
B
ROWS_MAX
=
max_w
*
G
if
ROWS_MAX
==
0
:
return
torch
.
zeros
((
N_k
,
Hk
),
dtype
=
torch
.
float32
,
device
=
device
)
out
=
torch
.
zeros
((
N_k
,
Hk
),
dtype
=
torch
.
float32
,
device
=
device
)
m_scratch
=
torch
.
empty
((
B
,
Hk
,
ROWS_MAX
),
dtype
=
torch
.
float32
,
device
=
device
)
S_scratch
=
torch
.
empty
((
B
,
Hk
,
ROWS_MAX
),
dtype
=
torch
.
float32
,
device
=
device
)
logits_buf
=
torch
.
empty
((
N_k
,
Hk
,
ROWS_MAX
),
dtype
=
torch
.
float32
,
device
=
device
)
probs_buf
=
torch
.
empty
((
N_k
,
Hk
,
ROWS_MAX
),
dtype
=
torch
.
float32
,
device
=
device
)
# strides
STRIDE_Q_NQ
,
STRIDE_Q_HQ
,
_
=
q
.
stride
()
STRIDE_K_NK
,
STRIDE_K_HK
,
_
=
k
.
stride
()
STRIDE_M_B
,
STRIDE_M_H
,
STRIDE_M_R
=
m_scratch
.
stride
()
STRIDE_S_B
,
STRIDE_S_H
,
STRIDE_S_R
=
S_scratch
.
stride
()
STRIDE_LG_NK
,
STRIDE_LG_HK
,
STRIDE_LG_R
=
logits_buf
.
stride
()
STRIDE_PB_NK
,
STRIDE_PB_HK
,
STRIDE_PB_R
=
probs_buf
.
stride
()
STRIDE_OUT_NK
,
STRIDE_OUT_HK
=
out
.
stride
()
def
grid
(
META
):
return
B
,
Hk
,
triton
.
cdiv
(
ROWS_MAX
,
META
[
"BLOCK_Q"
])
_lse_and_store_logits_kernel
[
grid
](
q
,
k
,
cu_seqlens_q
,
cu_seqlens_k
,
w
,
m_scratch
,
S_scratch
,
logits_buf
,
sm_scale
,
QUERY_GROUP_SIZE
=
Hq
//
Hk
,
D
=
D
,
STRIDE_Q_NQ
=
STRIDE_Q_NQ
,
STRIDE_Q_HQ
=
STRIDE_Q_HQ
,
STRIDE_K_NK
=
STRIDE_K_NK
,
STRIDE_K_HK
=
STRIDE_K_HK
,
STRIDE_M_B
=
STRIDE_M_B
,
STRIDE_M_H
=
STRIDE_M_H
,
STRIDE_M_R
=
STRIDE_M_R
,
STRIDE_S_B
=
STRIDE_S_B
,
STRIDE_S_H
=
STRIDE_S_H
,
STRIDE_S_R
=
STRIDE_S_R
,
STRIDE_LG_NK
=
STRIDE_LG_NK
,
STRIDE_LG_HK
=
STRIDE_LG_HK
,
STRIDE_LG_R
=
STRIDE_LG_R
,
ROWS_MAX
=
ROWS_MAX
,
)
_prefix_probs_kernel
[(
B
,
Hk
)](
cu_seqlens_k
,
w
,
m_scratch
,
S_scratch
,
logits_buf
,
probs_buf
,
QUERY_GROUP_SIZE
=
Hq
//
Hk
,
STRIDE_M_B
=
STRIDE_M_B
,
STRIDE_M_H
=
STRIDE_M_H
,
STRIDE_M_R
=
STRIDE_M_R
,
STRIDE_S_B
=
STRIDE_S_B
,
STRIDE_S_H
=
STRIDE_S_H
,
STRIDE_S_R
=
STRIDE_S_R
,
STRIDE_LG_NK
=
STRIDE_LG_NK
,
STRIDE_LG_HK
=
STRIDE_LG_HK
,
STRIDE_LG_R
=
STRIDE_LG_R
,
STRIDE_PB_NK
=
STRIDE_PB_NK
,
STRIDE_PB_HK
=
STRIDE_PB_HK
,
STRIDE_PB_R
=
STRIDE_PB_R
,
)
_snapkv_kvpress_epilogue
(
probs_buf
,
out
,
cu_seqlens_k
,
w
,
G
,
Hk
,
kernel_size
)
if
normalize
:
_zscore_per_batch_epilogue
[(
B
,)](
out
,
cu_seqlens_k
,
w
,
STRIDE_OUT_NK
,
STRIDE_OUT_HK
,
HK
=
Hk
,
EPS
=
1e-12
,
)
if
accum_scores
is
not
None
:
if
accum_blending
is
not
None
:
accum_scores
.
mul_
(
accum_blending
)
accum_scores
.
add_
(
out
)
return
accum_scores
else
:
return
out
vllm/kvprune_legacy_save/compression/snapkv_origin.py
0 → 100644
View file @
2b7160c6
import
math
from
typing
import
Optional
import
torch
import
triton
from
triton
import
language
as
tl
from
compactor_vllm.compression.common
import
BaseCompressionMethod
from
compactor_vllm.utils.helpers
import
maybe_execute_in_stream
from
compactor_vllm.utils.triton_compat
import
autotune
as
triton_autotune
class
SnapKVCompression
(
BaseCompressionMethod
):
@
staticmethod
def
pre_rope_scoring
(
q
:
torch
.
Tensor
,
k
:
torch
.
Tensor
,
v
:
torch
.
Tensor
,
context
)
->
Optional
[
torch
.
Tensor
]:
return
None
@
staticmethod
def
post_rope_scoring
(
q
:
torch
.
Tensor
,
k
:
torch
.
Tensor
,
v
:
torch
.
Tensor
,
pre_rope_scores
:
torch
.
Tensor
,
context
,
)
->
Optional
[
torch
.
Tensor
]:
scores
=
maybe_execute_in_stream
(
query_aware_key_scores
,
q
,
k
,
context
.
cu_seqlens_q
,
context
.
cu_seqlens_k
,
w
=
32
,
STORE_STREAM
=
context
.
STORE_STREAM
,
)
return
scores
@
triton_autotune
(
configs
=
[
triton
.
Config
(
{
"BLOCK_Q"
:
bq
,
"BLOCK_K"
:
bk
},
num_warps
=
num_warps
,
num_stages
=
num_stages
)
for
bq
in
[
32
,
64
]
for
bk
in
[
32
,
64
]
for
num_warps
in
[
4
,
8
]
for
num_stages
in
[
3
,
4
]
],
key
=
[
"QUERY_GROUP_SIZE"
,
"D"
,
"ROWS_MAX"
],
cache_results
=
True
,
)
@
triton
.
jit
def
_lse_and_store_logits_kernel
(
Q
,
K
,
cu_q
,
cu_k
,
w_b
,
# int32 pointers
out_m
,
out_S
,
# [B, Hk, ROWS_MAX] float32
LOGITS
,
# [Nk, Hk, ROWS_MAX] float32
sm_scale
,
# float
QUERY_GROUP_SIZE
:
tl
.
constexpr
,
D
:
tl
.
constexpr
,
STRIDE_Q_NQ
,
STRIDE_Q_HQ
,
STRIDE_K_NK
,
STRIDE_K_HK
,
STRIDE_M_B
,
STRIDE_M_H
,
STRIDE_M_R
,
STRIDE_S_B
,
STRIDE_S_H
,
STRIDE_S_R
,
STRIDE_LG_NK
,
STRIDE_LG_HK
,
STRIDE_LG_R
,
BLOCK_Q
:
tl
.
constexpr
,
BLOCK_K
:
tl
.
constexpr
,
ROWS_MAX
,
):
# program ids
b
=
tl
.
program_id
(
0
)
hk
=
tl
.
program_id
(
1
)
rid
=
tl
.
program_id
(
2
)
# row-tile id
# batch segment bounds
q_end
=
tl
.
load
(
cu_q
+
b
+
1
)
k_beg
=
tl
.
load
(
cu_k
+
b
)
k_end
=
tl
.
load
(
cu_k
+
b
+
1
)
win
=
tl
.
load
(
w_b
+
b
)
q_win_beg
=
q_end
-
win
k_eff_end
=
k_end
-
win
if
(
win
<=
0
)
or
(
k_eff_end
<=
k_beg
):
return
# rows for this (b,hk)
rows_b
=
win
*
QUERY_GROUP_SIZE
row0
=
rid
*
BLOCK_Q
if
row0
>=
rows_b
:
return
# exp(x) = exp2(x * 1/ln2)
qk_scale
=
sm_scale
*
1.4426950408889634
offs_qrow
=
row0
+
tl
.
arange
(
0
,
BLOCK_Q
)
row_mask
=
offs_qrow
<
rows_b
# map row -> (q_idx, hq_local)
hq_local
=
offs_qrow
%
QUERY_GROUP_SIZE
q_off
=
offs_qrow
//
QUERY_GROUP_SIZE
q_idx
=
q_win_beg
+
q_off
hq_glob
=
hk
*
QUERY_GROUP_SIZE
+
hq_local
offs_d
=
tl
.
arange
(
0
,
D
)
q_ptrs
=
(
Q
+
q_idx
[:,
None
]
*
STRIDE_Q_NQ
+
hq_glob
[:,
None
]
*
STRIDE_Q_HQ
+
offs_d
[
None
,
:]
)
q_rows
=
tl
.
load
(
q_ptrs
,
mask
=
row_mask
[:,
None
],
other
=
0.0
)
m
=
tl
.
zeros
([
BLOCK_Q
],
dtype
=
tl
.
float32
)
+
(
-
float
(
"inf"
))
S
=
tl
.
zeros
([
BLOCK_Q
],
dtype
=
tl
.
float32
)
for
ks
in
tl
.
range
(
k_beg
,
k_eff_end
,
BLOCK_K
):
nk
=
ks
+
tl
.
arange
(
0
,
BLOCK_K
)
kmask
=
nk
<
k_eff_end
k_ptrs
=
K
+
nk
[:,
None
]
*
STRIDE_K_NK
+
hk
*
STRIDE_K_HK
+
offs_d
[
None
,
:]
k_blk
=
tl
.
load
(
k_ptrs
,
mask
=
kmask
[:,
None
],
other
=
0.0
)
# [BK, D]
s
=
tl
.
dot
(
q_rows
,
k_blk
.
T
)
*
qk_scale
# [BQ, BK]
s
=
tl
.
where
(
kmask
[
None
,
:],
s
,
-
float
(
"inf"
))
# store into LOGITS[nk, hk, row] -> [BK, BQ]
log_ptrs
=
(
LOGITS
+
nk
[:,
None
]
*
STRIDE_LG_NK
+
hk
*
STRIDE_LG_HK
+
(
row0
+
tl
.
arange
(
0
,
BLOCK_Q
))[
None
,
:]
*
STRIDE_LG_R
)
tl
.
store
(
log_ptrs
,
s
.
T
,
mask
=
kmask
[:,
None
]
&
row_mask
[
None
,
:])
# log2 streaming LSE update
cur_max
=
tl
.
max
(
s
,
1
)
# [BQ]
n_m
=
tl
.
maximum
(
m
,
cur_max
)
rescale
=
tl
.
math
.
exp2
(
m
-
n_m
)
S
=
S
*
rescale
+
tl
.
sum
(
tl
.
math
.
exp2
(
s
-
n_m
[:,
None
]),
1
)
m
=
n_m
# store m,S for these rows
m_base
=
out_m
+
b
*
STRIDE_M_B
+
hk
*
STRIDE_M_H
+
row0
*
STRIDE_M_R
S_base
=
out_S
+
b
*
STRIDE_S_B
+
hk
*
STRIDE_S_H
+
row0
*
STRIDE_S_R
tl
.
store
(
m_base
+
tl
.
arange
(
0
,
BLOCK_Q
)
*
STRIDE_M_R
,
m
,
mask
=
row_mask
)
tl
.
store
(
S_base
+
tl
.
arange
(
0
,
BLOCK_Q
)
*
STRIDE_S_R
,
S
,
mask
=
row_mask
)
@
triton_autotune
(
configs
=
[
triton
.
Config
({
"BLOCK_Q"
:
bq
,
"BLOCK_K"
:
bk
})
for
bq
in
[
16
,
32
,
64
]
for
bk
in
[
32
,
64
,
128
]
],
key
=
[
"HK"
,
"HQ"
],
cache_results
=
True
,
)
@
triton
.
jit
def
_scores_from_logits_kernel
(
cu_k
,
w_b
,
in_m
,
in_S
,
# [B, Hk, ROWS_MAX] f32
LOGITS
,
# [Nk, Hk, ROWS_MAX] f32, base-2 logits
OUT
,
# [Nk, Hk] f32
#
QUERY_GROUP_SIZE
:
tl
.
constexpr
,
STRIDE_M_B
,
STRIDE_M_H
,
STRIDE_M_R
,
STRIDE_S_B
,
STRIDE_S_H
,
STRIDE_S_R
,
STRIDE_LG_NK
,
STRIDE_LG_HK
,
STRIDE_LG_R
,
STRIDE_OUT_NK
,
STRIDE_OUT_HK
,
BLOCK_Q
:
tl
.
constexpr
,
BLOCK_K
:
tl
.
constexpr
,
#
DO_POOL
:
tl
.
constexpr
,
# set True to enable in-place avg pool
KPOOL
:
tl
.
constexpr
,
# kernel size for avg pool (stride=1)
):
b
=
tl
.
program_id
(
0
)
hk
=
tl
.
program_id
(
1
)
k_beg
=
tl
.
load
(
cu_k
+
b
)
k_end
=
tl
.
load
(
cu_k
+
b
+
1
)
win
=
tl
.
load
(
w_b
+
b
)
k_eff_end
=
k_end
-
win
if
(
win
<=
0
)
or
(
k_eff_end
<=
k_beg
):
return
rows_b
=
win
*
QUERY_GROUP_SIZE
# === scores over computed region ===
for
ks
in
tl
.
range
(
k_beg
,
k_eff_end
,
BLOCK_K
):
nk
=
ks
+
tl
.
arange
(
0
,
BLOCK_K
)
kmask
=
nk
<
k_eff_end
scores
=
tl
.
zeros
([
BLOCK_K
],
dtype
=
tl
.
float32
)
for
row0
in
tl
.
range
(
0
,
rows_b
,
BLOCK_Q
):
r_idx
=
row0
+
tl
.
arange
(
0
,
BLOCK_Q
)
rmask
=
r_idx
<
rows_b
# load m, S for rows
m_ptr
=
in_m
+
b
*
STRIDE_M_B
+
hk
*
STRIDE_M_H
+
row0
*
STRIDE_M_R
S_ptr
=
in_S
+
b
*
STRIDE_S_B
+
hk
*
STRIDE_S_H
+
row0
*
STRIDE_S_R
m
=
tl
.
load
(
m_ptr
+
tl
.
arange
(
0
,
BLOCK_Q
)
*
STRIDE_M_R
,
mask
=
rmask
,
other
=-
float
(
"inf"
),
)
S
=
tl
.
load
(
S_ptr
+
tl
.
arange
(
0
,
BLOCK_Q
)
*
STRIDE_S_R
,
mask
=
rmask
,
other
=
0.0
)
valid_row
=
S
>
0
m
=
tl
.
where
(
valid_row
,
m
,
0.0
)
S
=
tl
.
where
(
valid_row
,
S
,
1.0
)
# load stored logits^T: [BK, BQ]
log_ptrs
=
(
LOGITS
+
nk
[:,
None
]
*
STRIDE_LG_NK
+
hk
*
STRIDE_LG_HK
+
(
row0
+
tl
.
arange
(
0
,
BLOCK_Q
))[
None
,
:]
*
STRIDE_LG_R
)
s_T
=
tl
.
load
(
log_ptrs
,
mask
=
kmask
[:,
None
]
&
rmask
[
None
,
:],
other
=-
float
(
"inf"
)
)
# [BK, BQ]
# probs^T = exp2(s_T - m) / S, sum over rows
probs_T
=
tl
.
math
.
exp2
(
s_T
-
m
[
None
,
:])
/
S
[
None
,
:]
probs_T
=
tl
.
where
(
valid_row
[
None
,
:],
probs_T
,
0.0
)
scores
+=
tl
.
sum
(
probs_T
,
1
)
# [BK]
if
DO_POOL
and
(
KPOOL
>
1
):
i
=
tl
.
arange
(
0
,
BLOCK_K
)[:,
None
]
j
=
tl
.
arange
(
0
,
BLOCK_K
)[
None
,
:]
band
=
(
j
<=
i
)
&
((
i
-
j
)
<
KPOOL
)
band
=
band
&
kmask
[
None
,
:]
# sum within band
sums
=
tl
.
sum
(
tl
.
where
(
band
,
scores
[
None
,
:],
0.0
),
1
)
# [BK]
denom
=
tl
.
sum
(
band
,
1
).
to
(
tl
.
float32
)
# [BK]
denom
=
tl
.
where
(
denom
>
0
,
denom
,
1.0
)
scores
=
sums
/
denom
out_ptrs
=
OUT
+
nk
*
STRIDE_OUT_NK
+
hk
*
STRIDE_OUT_HK
tl
.
store
(
out_ptrs
,
scores
,
mask
=
kmask
)
pad_beg
=
k_eff_end
pad_end
=
k_end
if
pad_end
>
pad_beg
:
for
ks
in
tl
.
range
(
pad_beg
,
pad_end
,
BLOCK_K
):
nk
=
ks
+
tl
.
arange
(
0
,
BLOCK_K
)
kmask
=
nk
<
pad_end
out_ptrs
=
OUT
+
nk
*
STRIDE_OUT_NK
+
hk
*
STRIDE_OUT_HK
tl
.
store
(
out_ptrs
,
tl
.
full
([
BLOCK_K
],
float
(
"inf"
),
dtype
=
tl
.
float32
),
mask
=
kmask
)
@
triton_autotune
(
configs
=
[
triton
.
Config
({
"BLOCK_K"
:
bk
})
for
bk
in
[
32
,
64
,
128
]],
key
=
[
"HK"
],
cache_results
=
True
,
)
@
triton
.
jit
def
_zscore_per_batch_epilogue
(
OUT
,
# [Nk, Hk], float32
cu_k
,
w_b
,
# [B+1], [B] int32
STRIDE_OUT_NK
,
STRIDE_OUT_HK
,
HK
:
tl
.
constexpr
,
# Hk
EPS
:
tl
.
constexpr
,
# e.g., 1e-12
BLOCK_K
:
tl
.
constexpr
,
# e.g., 128
):
b
=
tl
.
program_id
(
0
)
k_beg
=
tl
.
load
(
cu_k
+
b
)
k_end
=
tl
.
load
(
cu_k
+
b
+
1
)
win
=
tl
.
load
(
w_b
+
b
)
k_eff_end
=
k_end
-
win
if
k_eff_end
<=
k_beg
:
return
sumv
=
tl
.
zeros
([],
dtype
=
tl
.
float32
)
sumsq
=
tl
.
zeros
([],
dtype
=
tl
.
float32
)
count
=
((
k_eff_end
-
k_beg
)
*
HK
).
to
(
tl
.
float32
)
for
ks
in
tl
.
range
(
k_beg
,
k_eff_end
,
BLOCK_K
):
nk
=
ks
+
tl
.
arange
(
0
,
BLOCK_K
)
kmask
=
nk
<
k_eff_end
for
h
in
tl
.
range
(
0
,
HK
):
ptrs
=
OUT
+
nk
*
STRIDE_OUT_NK
+
h
*
STRIDE_OUT_HK
vals
=
tl
.
load
(
ptrs
,
mask
=
kmask
,
other
=
0.0
).
to
(
tl
.
float32
)
sumv
+=
tl
.
sum
(
vals
,
0
)
sumsq
+=
tl
.
sum
(
vals
*
vals
,
0
)
mean
=
sumv
/
count
var
=
tl
.
maximum
(
sumsq
/
count
-
mean
*
mean
,
0.0
)
invstd
=
1.0
/
tl
.
sqrt
(
var
+
EPS
)
for
ks
in
tl
.
range
(
k_beg
,
k_eff_end
,
BLOCK_K
):
nk
=
ks
+
tl
.
arange
(
0
,
BLOCK_K
)
kmask
=
nk
<
k_eff_end
for
h
in
tl
.
range
(
0
,
HK
):
ptrs
=
OUT
+
nk
*
STRIDE_OUT_NK
+
h
*
STRIDE_OUT_HK
vals
=
tl
.
load
(
ptrs
,
mask
=
kmask
,
other
=
0.0
).
to
(
tl
.
float32
)
vals
=
(
vals
-
mean
)
*
invstd
tl
.
store
(
ptrs
,
vals
,
mask
=
kmask
)
def
query_aware_key_scores
(
q
:
torch
.
Tensor
,
# [N_q, Hq, D]
k
:
torch
.
Tensor
,
# [N_k, Hk, D]
cu_seqlens_q
:
torch
.
Tensor
,
# [B+1], int32
cu_seqlens_k
:
torch
.
Tensor
,
# [B+1], int32
w
:
torch
.
Tensor
|
int
,
# [B], int32
sm_scale
:
float
=
None
,
# defaults to 1/sqrt(D)
*
,
accum_scores
:
torch
.
Tensor
=
None
,
accum_blending
:
float
=
None
,
normalize
:
bool
=
False
,
)
->
Optional
[
torch
.
Tensor
]:
assert
q
.
stride
(
-
1
)
==
1
and
k
.
stride
(
-
1
)
==
1
,
"last dim must be contiguous"
device
=
q
.
device
N_q
,
Hq
,
D
=
q
.
shape
N_k
,
Hk
,
Dk
=
k
.
shape
assert
(
Hq
%
Hk
)
==
0
,
"Hq must be a multiple of Hk"
if
sm_scale
is
None
:
sm_scale
=
1.0
/
math
.
sqrt
(
D
)
B
=
cu_seqlens_q
.
numel
()
-
1
assert
B
==
cu_seqlens_k
.
numel
()
-
1
G
=
Hq
//
Hk
if
type
(
w
)
is
int
:
max_w
=
w
w
=
torch
.
full
((
B
,),
fill_value
=
w
,
device
=
device
,
dtype
=
torch
.
int32
)
else
:
max_w
=
int
(
w
.
max
().
item
())
assert
w
.
numel
()
==
B
ROWS_MAX
=
max_w
*
G
if
ROWS_MAX
==
0
:
return
torch
.
zeros
((
N_k
,
Hk
),
dtype
=
torch
.
float32
,
device
=
device
)
out
=
torch
.
empty
((
N_k
,
Hk
),
dtype
=
torch
.
float32
,
device
=
device
)
m_scratch
=
torch
.
empty
((
B
,
Hk
,
ROWS_MAX
),
dtype
=
torch
.
float32
,
device
=
device
)
S_scratch
=
torch
.
empty
((
B
,
Hk
,
ROWS_MAX
),
dtype
=
torch
.
float32
,
device
=
device
)
logits_buf
=
torch
.
empty
((
N_k
,
Hk
,
ROWS_MAX
),
dtype
=
torch
.
float32
,
device
=
device
)
# strides
STRIDE_Q_NQ
,
STRIDE_Q_HQ
,
_
=
q
.
stride
()
STRIDE_K_NK
,
STRIDE_K_HK
,
_
=
k
.
stride
()
STRIDE_M_B
,
STRIDE_M_H
,
STRIDE_M_R
=
m_scratch
.
stride
()
STRIDE_S_B
,
STRIDE_S_H
,
STRIDE_S_R
=
S_scratch
.
stride
()
STRIDE_LG_NK
,
STRIDE_LG_HK
,
STRIDE_LG_R
=
logits_buf
.
stride
()
STRIDE_OUT_NK
,
STRIDE_OUT_HK
=
out
.
stride
()
def
grid
(
META
):
return
B
,
Hk
,
triton
.
cdiv
(
ROWS_MAX
,
META
[
"BLOCK_Q"
])
_lse_and_store_logits_kernel
[
grid
](
q
,
k
,
cu_seqlens_q
,
cu_seqlens_k
,
w
,
m_scratch
,
S_scratch
,
logits_buf
,
sm_scale
,
QUERY_GROUP_SIZE
=
Hq
//
Hk
,
D
=
D
,
STRIDE_Q_NQ
=
STRIDE_Q_NQ
,
STRIDE_Q_HQ
=
STRIDE_Q_HQ
,
STRIDE_K_NK
=
STRIDE_K_NK
,
STRIDE_K_HK
=
STRIDE_K_HK
,
STRIDE_M_B
=
STRIDE_M_B
,
STRIDE_M_H
=
STRIDE_M_H
,
STRIDE_M_R
=
STRIDE_M_R
,
STRIDE_S_B
=
STRIDE_S_B
,
STRIDE_S_H
=
STRIDE_S_H
,
STRIDE_S_R
=
STRIDE_S_R
,
STRIDE_LG_NK
=
STRIDE_LG_NK
,
STRIDE_LG_HK
=
STRIDE_LG_HK
,
STRIDE_LG_R
=
STRIDE_LG_R
,
ROWS_MAX
=
ROWS_MAX
,
)
_scores_from_logits_kernel
[(
B
,
Hk
)](
cu_seqlens_k
,
w
,
m_scratch
,
S_scratch
,
logits_buf
,
out
,
QUERY_GROUP_SIZE
=
Hq
//
Hk
,
STRIDE_M_B
=
STRIDE_M_B
,
STRIDE_M_H
=
STRIDE_M_H
,
STRIDE_M_R
=
STRIDE_M_R
,
STRIDE_S_B
=
STRIDE_S_B
,
STRIDE_S_H
=
STRIDE_S_H
,
STRIDE_S_R
=
STRIDE_S_R
,
STRIDE_LG_NK
=
STRIDE_LG_NK
,
STRIDE_LG_HK
=
STRIDE_LG_HK
,
STRIDE_LG_R
=
STRIDE_LG_R
,
STRIDE_OUT_NK
=
STRIDE_OUT_NK
,
STRIDE_OUT_HK
=
STRIDE_OUT_HK
,
DO_POOL
=
True
,
KPOOL
=
5
,
)
if
normalize
:
_zscore_per_batch_epilogue
[(
B
,)](
out
,
cu_seqlens_k
,
w
,
STRIDE_OUT_NK
,
STRIDE_OUT_HK
,
HK
=
Hk
,
EPS
=
1e-12
,
)
if
accum_scores
is
not
None
:
if
accum_blending
is
not
None
:
accum_scores
.
mul_
(
accum_blending
)
accum_scores
.
add_
(
out
)
return
accum_scores
else
:
return
out
vllm/kvprune_legacy_save/config/__init__.py
0 → 100644
View file @
2b7160c6
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""Engine / sampling / kernel constants (compactor-compatible)."""
from
vllm.kvprune.config.constants
import
RESERVED_BATCH
,
TRITON_RESERVED_BATCH
__all__
=
[
"RESERVED_BATCH"
,
"TRITON_RESERVED_BATCH"
]
vllm/kvprune_legacy_save/config/constants.py
0 → 100644
View file @
2b7160c6
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
RESERVED_BATCH
=
0
# NOTE: Triton `tl.constexpr` is intended for use in kernel signatures/annotations.
# Some Triton builds reject passing `tl.constexpr(...)` objects as constexpr values.
# Keep the runtime value as a plain int and let kernel signatures declare constexpr.
TRITON_RESERVED_BATCH
=
RESERVED_BATCH
vllm/kvprune_legacy_save/config/engine_config.py
0 → 100644
View file @
2b7160c6
import
os
from
dataclasses
import
dataclass
from
enum
import
Enum
,
auto
from
typing
import
List
,
Optional
from
transformers
import
AutoConfig
class
AttentionBackend
(
Enum
):
"""Legacy coarse backend toggle (prefer :class:`KvpruneAttentionSchedule`)."""
FLASH_ATTENTION
=
auto
()
COMPACTOR_TRITON
=
auto
()
class
KvpruneAttentionSchedule
(
Enum
):
"""FlashAttention vs Triton split for prefill / decode (KV **writes** stay Triton)."""
# Default: FA varlen prefill; decode uses ``head_sparse_decode_attention`` (Triton).
FA_PREFILL_TRITON_DECODE
=
auto
()
# Prefill attention uses ``causal_sparse_varlen_with_cache`` (Triton); decode Triton.
TRITON_PREFILL_TRITON_DECODE
=
auto
()
# "PDFA": FA prefill + FA decode; paged KV **storage** (incl. pruned top-k) unchanged.
PDFA
=
auto
()
@
dataclass
class
LLMConfig
:
"""Configuration for the :class:`LLM` engine.
Parameters
----------
model : str
Hugging Face model identifier (e.g. ``"meta-llama/Meta-Llama-3-8B"``) or
a local model name that can be resolved by
:func:`transformers.AutoConfig.from_pretrained`.
path : str, optional
Local directory containing the model weights. If ``None``, the engine
will attempt to resolve a local snapshot for ``model`` using
:func:`huggingface_hub.snapshot_download`.
max_num_seqs : int, default 256
Upper bound on the number of concurrent batches that the scheduler and
KV-cache manager are allowed to handle. This affects the size of the
page table and some internal buffers.
max_model_len : int, default 40960
Maximum context length (in tokens) that the engine will allocate KV cache
and CUDA graphs for. During initialization this value is clamped to
``hf_config.max_position_embeddings`` for the chosen model.
gpu_memory_utilization : float, default 0.9
Fraction of the total GPU memory that may be used for KV cache and model
activations. Values should be in ``(0, 1]``. If this budget is too small,
the KV-cache manager may raise an error at warmup time due
to insufficient memory.
tensor_parallel_size : int, default 1
Number of tensor-parallel workers to shard the model
across. Must be between 1 and 8, and must evenly divide the model's
number of key/value heads.
enforce_eager : bool, default False
If ``True``, disable CUDA graph capture and always run the model in
eager mode during decoding. This reduces throughput. When ``False``,
the engine will capture and reuse CUDA graphs for supported
batch sizes and sequence lengths.
hf_config : transformers.AutoConfig, optional
Pre-loaded Hugging Face configuration for the model. If ``None``,
it will then be populated automatically based on ``model``.
eos : int, default -1
Primary stop token id (warmup / single-id paths). If ``-1``, the
:class:`LLM` constructor fills this and :attr:`eos_token_ids` from the
tokenizer.
eos_token_ids : list of int, optional
All token ids that terminate generation (e.g. HF tokenizers may expose
``eos_token_id`` as a list for chat models). If ``None``, inferred in
:class:`LLM` from the tokenizer and model type.
kvcache_page_size : int, default 128
Number of tokens stored in a single KV-cache page. Smaller pages improve
allocation flexibility but increase page-table overhead; larger pages
reduce overhead but have coarser granularity.
leverage_sketch_size : int, default 48
Sketch dimension used by the Compactor leverage-score estimator.
attention_schedule : KvpruneAttentionSchedule, default FA_PREFILL_TRITON_DECODE
Which **attention** implementation runs on prefill vs decode. KV **writes**
(``prefill_store_*``, ``decode_store_kv``, pruned top-k) always use the
existing Triton store kernels. Env ``VLLM_KVPRUNE_ATTENTION_SCHEDULE`` uses
short names: ``fa_triton`` (default), ``pdtriton``, ``pdfa``. Enum values:
``FA_PREFILL_TRITON_DECODE`` — FA prefill, Triton decode;
``TRITON_PREFILL_TRITON_DECODE`` — Triton prefill + decode;
``PDFA`` — FA prefill + FA decode (still Triton KV I/O).
attention_backend : AttentionBackend, optional
Deprecated. Ignored if ``attention_schedule`` is set; otherwise mapped
for backward compatibility.
"""
model
:
str
path
:
Optional
[
str
]
=
None
nccl_port
:
Optional
[
int
]
=
1218
max_num_seqs
:
int
=
256
max_model_len
:
int
=
40960
gpu_memory_utilization
:
float
=
0.9
tensor_parallel_size
:
int
=
1
enforce_eager
:
bool
=
False
hf_config
:
AutoConfig
|
None
=
None
eos
:
int
=
-
1
eos_token_ids
:
Optional
[
List
[
int
]]
=
None
kvcache_page_size
:
int
=
128
leverage_sketch_size
:
int
=
48
attention_schedule
:
KvpruneAttentionSchedule
=
(
KvpruneAttentionSchedule
.
FA_PREFILL_TRITON_DECODE
)
attention_backend
:
AttentionBackend
|
None
=
None
show_progress_bar
:
bool
=
True
def
__post_init__
(
self
):
if
self
.
attention_backend
is
not
None
:
if
self
.
attention_backend
==
AttentionBackend
.
FLASH_ATTENTION
:
self
.
attention_schedule
=
KvpruneAttentionSchedule
.
FA_PREFILL_TRITON_DECODE
else
:
self
.
attention_schedule
=
(
KvpruneAttentionSchedule
.
TRITON_PREFILL_TRITON_DECODE
)
if
self
.
path
is
not
None
and
not
os
.
path
.
isdir
(
self
.
path
):
raise
NotADirectoryError
(
f
"Engine config dir
{
self
.
path
}
does not exist"
)
if
self
.
tensor_parallel_size
<=
0
or
self
.
tensor_parallel_size
>
8
:
assert
1
<=
self
.
tensor_parallel_size
<=
8
raise
ValueError
(
"tensor_parallel_size must be >= 1 and <= 8"
)
if
self
.
hf_config
is
None
:
self
.
hf_config
=
AutoConfig
.
from_pretrained
(
self
.
model
)
self
.
max_model_len
=
min
(
self
.
max_model_len
,
self
.
hf_config
.
max_position_embeddings
)
vllm/kvprune_legacy_save/config/sampling_params.py
0 → 100644
View file @
2b7160c6
from
dataclasses
import
dataclass
@
dataclass
class
SamplingParams
:
temperature
:
float
=
1.0
max_new_tokens
:
int
=
256
def
__post_init__
(
self
):
if
self
.
temperature
<
0
:
raise
ValueError
(
"Temperature cannot be negative"
)
vllm/kvprune_legacy_save/core/__init__.py
0 → 100644
View file @
2b7160c6
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""
Core: compactor ``LLMEngine`` stack (``llm_engine``, ``scheduler``, …) plus helpers
(``runtime``, ``flash_integration``, ``block_budget``) used **inside** the compactor path.
v1 does not import these; use :meth:`vllm.LLM.generate` with ``compression=`` for the
``LLM`` + compactor integration.
"""
from
vllm.kvprune.core.block_budget
import
(
TailReclaimHint
,
build_tail_reclaim_hint
,
tail_blocks_if_logical_shorter
,
)
from
vllm.kvprune.core.compression_bridge
import
(
VALID_ALIASES_FOR_SAMPLING
,
compression_method_id_to_enum
,
compression_method_str_to_id
,
)
from
vllm.kvprune.core.flash_integration
import
(
do_kv_cache_update_kv_prune
,
merge_seq_lens_with_kv_prune
,
)
from
vllm.kvprune.core.runtime
import
(
KVPruneForwardState
,
build_kv_prune_forward_state
,
get_kv_prune_state
,
layer_index_from_layer_name
,
)
__all__
=
[
"KVPruneForwardState"
,
"TailReclaimHint"
,
"VALID_ALIASES_FOR_SAMPLING"
,
"build_kv_prune_forward_state"
,
"build_tail_reclaim_hint"
,
"compression_method_id_to_enum"
,
"compression_method_str_to_id"
,
"do_kv_cache_update_kv_prune"
,
"get_kv_prune_state"
,
"layer_index_from_layer_name"
,
"merge_seq_lens_with_kv_prune"
,
"tail_blocks_if_logical_shorter"
,
]
vllm/kvprune_legacy_save/core/block_budget.py
0 → 100644
View file @
2b7160c6
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""
Block budget helpers for compactor KV pruning (logical vs physical length).
Used by the **compactor** ``LLMEngine`` path (``PagedKVCache`` / logical lengths),
not by v1's scheduler. The helpers compare logical KV length to a physical token
count and return how many full tail blocks can be reclaimed when logical shrinks.
"""
from
__future__
import
annotations
from
dataclasses
import
dataclass
@
dataclass
(
frozen
=
True
)
class
TailReclaimHint
:
"""How many tail blocks could be freed if logical KV shrinks below allocation."""
request_id
:
str
allocated_tokens
:
int
logical_tokens
:
int
block_size
:
int
reclaimable_tail_blocks
:
int
def
tail_blocks_if_logical_shorter
(
allocated_tokens
:
int
,
logical_tokens
:
int
,
block_size
:
int
,
)
->
int
:
"""Return count of fully-unused tail blocks when ``logical < allocated``.
Block-granular: only counts whole blocks past the last block that still
contains a retained logical token index.
"""
if
block_size
<=
0
:
return
0
if
logical_tokens
>=
allocated_tokens
:
return
0
# Last logical token occupies block index floor((logical-1)/bs) if logical>0
if
logical_tokens
<=
0
:
return
(
allocated_tokens
+
block_size
-
1
)
//
block_size
last_logical_block
=
(
logical_tokens
-
1
)
//
block_size
last_alloc_block
=
(
allocated_tokens
-
1
)
//
block_size
return
max
(
0
,
last_alloc_block
-
last_logical_block
)
def
build_tail_reclaim_hint
(
request_id
:
str
,
allocated_tokens
:
int
,
logical_tokens
:
int
,
block_size
:
int
,
)
->
TailReclaimHint
:
n
=
tail_blocks_if_logical_shorter
(
allocated_tokens
,
logical_tokens
,
block_size
)
return
TailReclaimHint
(
request_id
=
request_id
,
allocated_tokens
=
allocated_tokens
,
logical_tokens
=
logical_tokens
,
block_size
=
block_size
,
reclaimable_tail_blocks
=
n
,
)
__all__
=
[
"TailReclaimHint"
,
"build_tail_reclaim_hint"
,
"tail_blocks_if_logical_shorter"
,
]
vllm/kvprune_legacy_save/core/compression_bridge.py
0 → 100644
View file @
2b7160c6
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""Map compression method strings (e.g. from :class:`~vllm.kvprune.integration.CompressionParams`) to kvprune GPU / enum IDs."""
from
__future__
import
annotations
from
vllm.kvprune.compression.compression_config
import
CompressionMethod
# IDs stored on device [num_reqs_padded] (int32). Order is stable for kernels.
COMPRESSION_METHOD_ID_NONE
=
0
COMPRESSION_METHOD_ID_CRITICALADAKV
=
1
COMPRESSION_METHOD_ID_COMPACTOR
=
2
COMPRESSION_METHOD_ID_SNAPKV
=
3
# Aliases accepted for method strings (case-insensitive after strip).
VALID_ALIASES_FOR_SAMPLING
:
frozenset
[
str
]
=
frozenset
(
{
"none"
,
"criticaladakv"
,
"compactor"
,
"snapkv"
}
)
_STR_TO_ID
:
dict
[
str
,
int
]
=
{
"none"
:
COMPRESSION_METHOD_ID_NONE
,
"criticaladakv"
:
COMPRESSION_METHOD_ID_CRITICALADAKV
,
"compactor"
:
COMPRESSION_METHOD_ID_COMPACTOR
,
"snapkv"
:
COMPRESSION_METHOD_ID_SNAPKV
,
}
_ID_TO_COMPRESSION_METHOD
:
dict
[
int
,
CompressionMethod
]
=
{
COMPRESSION_METHOD_ID_NONE
:
CompressionMethod
.
NONE
,
COMPRESSION_METHOD_ID_CRITICALADAKV
:
CompressionMethod
.
CRITICALADAKV
,
COMPRESSION_METHOD_ID_COMPACTOR
:
CompressionMethod
.
COMPACTOR
,
COMPRESSION_METHOD_ID_SNAPKV
:
CompressionMethod
.
SNAPKV
,
}
def
compression_method_str_to_id
(
s
:
str
)
->
int
:
"""Normalize and map user string to a stable int id (0..3)."""
key
=
(
s
or
"none"
).
strip
().
lower
()
if
key
not
in
_STR_TO_ID
:
raise
ValueError
(
f
"Unknown compression_method
{
s
!
r
}
; expected one of "
f
"
{
sorted
(
VALID_ALIASES_FOR_SAMPLING
)
}
"
)
return
_STR_TO_ID
[
key
]
def
compression_method_id_to_enum
(
method_id
:
int
)
->
CompressionMethod
:
if
method_id
not
in
_ID_TO_COMPRESSION_METHOD
:
return
CompressionMethod
.
NONE
return
_ID_TO_COMPRESSION_METHOD
[
method_id
]
__all__
=
[
"COMPRESSION_METHOD_ID_NONE"
,
"COMPRESSION_METHOD_ID_CRITICALADAKV"
,
"COMPRESSION_METHOD_ID_COMPACTOR"
,
"COMPRESSION_METHOD_ID_SNAPKV"
,
"VALID_ALIASES_FOR_SAMPLING"
,
"compression_method_id_to_enum"
,
"compression_method_str_to_id"
,
]
vllm/kvprune_legacy_save/core/flash_integration.py
0 → 100644
View file @
2b7160c6
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""FlashAttention + KV cache hooks for kvprune."""
from
__future__
import
annotations
import
torch
from
vllm.kvprune.core.runtime
import
KVPruneForwardState
,
get_kv_prune_state
_RATIO_ONE
=
1.0
-
1e-6
def
merge_seq_lens_with_kv_prune
(
base_seq_lens
:
torch
.
Tensor
,
layer_name
:
str
,
max_query_len
:
int
,
)
->
torch
.
Tensor
:
"""Blend scheduler seq_lens with per-layer logical lengths when pruning."""
state
=
get_kv_prune_state
()
if
state
is
None
:
return
base_seq_lens
# Prefill: only scheduler lengths are reliable unless compactor store ran for
# every layer (try_prefill_kv_store); when pruning is requested but ineligible
# (e.g. unsupported dtype), logical buffers may still be zero — do not override.
if
max_query_len
>
1
:
return
base_seq_lens
layer_idx
=
_layer_idx
(
layer_name
)
num_reqs
=
state
.
num_reqs
comp
=
state
.
compression_ratio_gpu
[:
num_reqs
]
logical
=
state
.
logical_seq_lens_gpu
[
layer_idx
,
:
num_reqs
]
if
logical
.
dim
()
==
2
:
logical
=
logical
.
max
(
dim
=-
1
).
values
out
=
base_seq_lens
.
clone
()
use_logical
=
comp
<
_RATIO_ONE
out
[:
num_reqs
]
=
torch
.
where
(
use_logical
,
logical
.
to
(
out
.
dtype
),
base_seq_lens
[:
num_reqs
],
)
return
out
def
_layer_idx
(
layer_name
:
str
)
->
int
:
from
vllm.kvprune.core.runtime
import
layer_index_from_layer_name
return
layer_index_from_layer_name
(
layer_name
)
def
do_kv_cache_update_kv_prune
(
layer
:
torch
.
nn
.
Module
,
key
:
torch
.
Tensor
,
value
:
torch
.
Tensor
,
kv_cache
:
torch
.
Tensor
,
slot_mapping
:
torch
.
Tensor
,
reshape_and_cache_flash
,
kv_cache_dtype
:
str
,
)
->
bool
:
"""If kvprune handles this step, return True (caller skips default path)."""
state
=
get_kv_prune_state
()
if
state
is
None
:
return
False
layer_idx
=
_layer_idx
(
layer
.
layer_name
)
num_reqs
=
state
.
num_reqs
if
state
.
is_prefill
:
from
vllm.kvprune.compression.prefill
import
try_prefill_kv_store
if
try_prefill_kv_store
(
layer
,
key
,
value
,
kv_cache
):
return
True
return
False
key_cache
,
value_cache
=
kv_cache
.
unbind
(
0
)
reshape_and_cache_flash
(
key
,
value
,
key_cache
,
value_cache
,
slot_mapping
,
kv_cache_dtype
,
layer
.
_k_scale
,
layer
.
_v_scale
,
)
comp
=
state
.
compression_ratio_gpu
[:
num_reqs
]
mask
=
(
comp
<
_RATIO_ONE
).
to
(
torch
.
int32
)
layer_buf
=
state
.
logical_seq_lens_gpu
[
layer_idx
,
:
num_reqs
]
if
layer_buf
.
dim
()
==
2
:
layer_buf
+=
mask
.
unsqueeze
(
-
1
)
else
:
layer_buf
+=
mask
return
True
vllm/kvprune_legacy_save/core/llm_engine.py
0 → 100644
View file @
2b7160c6
from
__future__
import
annotations
import
atexit
import
inspect
import
logging
from
pathlib
import
Path
from
typing
import
Any
,
List
,
Optional
,
Union
import
torch.nn
as
nn
import
torch.multiprocessing
as
mp
from
vllm.kvprune.compression.compression_config
import
(
BatchCompressionParams
,
SequenceCompressionParams
,
)
from
vllm.kvprune.config.engine_config
import
LLMConfig
from
vllm.kvprune.config.sampling_params
import
SamplingParams
from
vllm.kvprune.core.model_runner
import
ModelRunner
from
vllm.kvprune.models
import
MODEL_REGISTRY
from
vllm.kvprune.utils.sequence
import
Sequence
from
transformers
import
AutoTokenizer
logger
=
logging
.
getLogger
(
__name__
)
PromptLike
=
Union
[
str
,
List
[
int
]]
def
_infer_stop_token_ids
(
tokenizer
,
hf_config
)
->
list
[
int
]:
"""
Build the set of token ids that should end generation.
Newer HF chat tokenizers often expose ``eos_token_id`` as a *list* of ids.
The engine must not compare generated ids to that list as a single ``int``;
see :attr:`LLMConfig.eos_token_ids` and decode-time ``torch.isin``.
Qwen chat uses ``</think>`` (im_end) as the assistant turn boundary; include it
when present in ``additional_special_tokens`` / ``added_tokens_encoder``. We
avoid loose substring matches like ``
\"
end
\"
`` that can tag unrelated tokens.
"""
raw
=
tokenizer
.
eos_token_id
ids
:
list
[
int
]
=
[]
if
isinstance
(
raw
,
(
list
,
tuple
)):
ids
.
extend
(
int
(
x
)
for
x
in
raw
)
elif
raw
is
not
None
:
ids
.
append
(
int
(
raw
))
unk_id
=
getattr
(
tokenizer
,
"unk_token_id"
,
None
)
def
_maybe_add_tid
(
tid
:
int
)
->
None
:
if
not
isinstance
(
tid
,
int
)
or
tid
<
0
:
return
if
unk_id
is
not
None
and
tid
==
unk_id
:
return
if
tid
not
in
ids
:
ids
.
append
(
tid
)
model_type
=
getattr
(
hf_config
,
"model_type"
,
None
)
if
model_type
in
(
"qwen2"
,
"qwen3"
,
"qwen2_moe"
,
"qwen3_moe"
):
enc
=
getattr
(
tokenizer
,
"added_tokens_encoder"
,
None
)
if
isinstance
(
enc
,
dict
):
for
key
,
tid
in
enc
.
items
():
if
isinstance
(
key
,
str
)
and
"im_end"
in
key
:
_maybe_add_tid
(
int
(
tid
))
for
extra
in
getattr
(
tokenizer
,
"additional_special_tokens"
,
[])
or
[]:
if
not
isinstance
(
extra
,
str
)
or
"im_end"
not
in
extra
:
continue
try
:
tid
=
tokenizer
.
convert_tokens_to_ids
(
extra
)
except
(
TypeError
,
ValueError
,
KeyError
):
continue
_maybe_add_tid
(
tid
)
if
not
ids
:
raise
ValueError
(
"Could not infer stop token ids from the tokenizer; set "
"LLMConfig(eos_token_ids=[...]) explicitly."
)
return
ids
def
_merge_apply_chat_template_kwargs
(
tokenizer
,
user_kwargs
:
Optional
[
dict
[
str
,
Any
]],
)
->
dict
[
str
,
Any
]:
"""
Merge user kwargs with defaults for HF chat templates that support them.
Qwen3 (and similar) instruct models expect `add_generation_prompt=True` so
the first generated token continues the assistant turn; without it, output
can repeat punctuation / template fragments. `enable_thinking=False` avoids
the Qwen3 reasoning channel when the tokenizer supports it.
"""
out
=
dict
(
user_kwargs
or
{})
try
:
sig
=
inspect
.
signature
(
tokenizer
.
apply_chat_template
)
except
(
TypeError
,
ValueError
):
return
out
if
"add_generation_prompt"
in
sig
.
parameters
and
"add_generation_prompt"
not
in
out
:
out
[
"add_generation_prompt"
]
=
True
if
"enable_thinking"
in
sig
.
parameters
and
"enable_thinking"
not
in
out
:
out
[
"enable_thinking"
]
=
False
return
out
def
_runner_entry
(
config
:
LLMConfig
,
rank
:
int
,
evt
):
runner
=
None
try
:
runner
=
ModelRunner
(
config
,
rank
,
evt
)
runner
.
loop
()
except
Exception
as
e
:
logging
.
exception
(
f
"Rank
{
rank
}
:
{
repr
(
e
)
}
"
)
finally
:
if
runner
is
not
None
:
runner
.
exit
()
class
LLMEngine
:
"""High-level engine coordinating model runners and scheduling"""
def
__init__
(
self
,
config
:
LLMConfig
,
external_model
:
nn
.
Module
|
None
=
None
):
self
.
config
=
config
if
self
.
config
.
hf_config
.
model_type
not
in
MODEL_REGISTRY
:
raise
ValueError
(
f
"Unknown model
{
self
.
config
.
model
}
"
)
if
config
.
path
is
None
:
# Local directory: use it directly (no Hub round-trip).
try
:
mp
=
Path
(
config
.
model
)
if
mp
.
is_dir
()
and
(
mp
/
"config.json"
).
is_file
():
self
.
config
.
path
=
str
(
mp
.
resolve
())
logger
.
info
(
"Using local model directory for tokenizer: %s"
,
self
.
config
.
path
)
except
OSError
:
pass
if
config
.
path
is
None
:
from
huggingface_hub
import
snapshot_download
# Hub repo id: allow downloading missing shards/tokenizer files when cache
# is incomplete (local_files_only=False). Local dirs are handled above.
self
.
config
.
path
=
snapshot_download
(
repo_id
=
config
.
model
,
local_files_only
=
False
,
)
logger
.
info
(
"Resolved Hugging Face snapshot for %s @ %s"
,
self
.
config
.
model
,
self
.
config
.
path
,
)
assert
self
.
config
.
path
is
not
None
_trust
=
bool
(
getattr
(
self
.
config
.
hf_config
,
"trust_remote_code"
,
False
))
# Always load tokenizer from the resolved on-disk tree so we do not re-hit
# the Hub with the repo id (can re-download tokenizer / LFS shards).
self
.
tokenizer
=
AutoTokenizer
.
from_pretrained
(
self
.
config
.
path
,
use_fast
=
True
,
trust_remote_code
=
_trust
,
)
if
self
.
config
.
eos_token_ids
is
None
:
if
self
.
config
.
eos
!=
-
1
:
self
.
config
.
eos_token_ids
=
[
int
(
self
.
config
.
eos
)]
else
:
self
.
config
.
eos_token_ids
=
_infer_stop_token_ids
(
self
.
tokenizer
,
self
.
config
.
hf_config
)
else
:
self
.
config
.
eos_token_ids
=
[
int
(
x
)
for
x
in
self
.
config
.
eos_token_ids
]
self
.
config
.
eos_token_ids
=
sorted
(
set
(
self
.
config
.
eos_token_ids
))
if
self
.
config
.
eos
==
-
1
:
self
.
config
.
eos
=
int
(
self
.
config
.
eos_token_ids
[
0
])
else
:
self
.
config
.
eos
=
int
(
self
.
config
.
eos
)
if
self
.
config
.
eos
not
in
self
.
config
.
eos_token_ids
:
self
.
config
.
eos_token_ids
=
sorted
(
self
.
config
.
eos_token_ids
+
[
self
.
config
.
eos
]
)
if
external_model
is
not
None
and
int
(
self
.
config
.
tensor_parallel_size
)
!=
1
:
raise
ValueError
(
"external_model (shared-weight compactor path) only supports "
"tensor_parallel_size=1"
)
self
.
ps
=
[]
world_size
=
int
(
self
.
config
.
tensor_parallel_size
)
self
.
events
=
[]
if
world_size
>
1
:
ctx
=
mp
.
get_context
(
"spawn"
)
for
r
in
range
(
1
,
world_size
):
event
=
ctx
.
Event
()
p
=
ctx
.
Process
(
target
=
_runner_entry
,
args
=
(
self
.
config
,
r
,
event
),
daemon
=
True
,
)
p
.
start
()
self
.
ps
.
append
(
p
)
self
.
events
.
append
(
event
)
self
.
master_model_runner
=
ModelRunner
(
self
.
config
,
rank
=
0
,
peer_events
=
self
.
events
,
external_model
=
external_model
,
)
atexit
.
register
(
self
.
exit
)
def
exit
(
self
):
if
getattr
(
self
,
"_exited"
,
False
):
return
self
.
_exited
=
True
runner
=
getattr
(
self
,
"master_model_runner"
,
None
)
if
runner
is
not
None
:
try
:
runner
.
exit
()
except
Exception
:
logger
.
exception
(
"Failed to exit master ModelRunner cleanly"
)
for
p
in
self
.
ps
:
if
p
.
is_alive
():
p
.
terminate
()
p
.
join
(
timeout
=
1.0
)
if
hasattr
(
self
,
"events"
):
self
.
events
.
clear
()
def
tokenize_prompt
(
self
,
prompt
:
PromptLike
,
**
tokenizer_kwargs
)
->
List
[
int
]:
"""
Turn a raw prompt into token IDs.
"""
if
isinstance
(
prompt
,
str
):
return
self
.
tokenizer
(
prompt
,
**
tokenizer_kwargs
)[
"input_ids"
]
else
:
return
list
(
prompt
)
def
detokenize_prompt
(
self
,
sequences
:
List
[
Sequence
],
**
detokenizer_kwargs
)
->
List
[
str
]:
"""
Turn completed Sequences into strings.
"""
defaults
:
dict
[
str
,
Any
]
=
{
"skip_special_tokens"
:
True
}
merged
=
{
**
defaults
,
**
detokenizer_kwargs
}
return
self
.
tokenizer
.
batch_decode
(
[
s
.
completion_token_ids
for
s
in
sequences
],
**
merged
)
def
_build_sequences
(
self
,
prompts
:
List
[
PromptLike
]
|
PromptLike
,
sampling_params
:
SamplingParams
|
List
[
SamplingParams
],
per_sequence_compression_params
:
Optional
[
SequenceCompressionParams
|
List
[
SequenceCompressionParams
]
]
=
None
,
tokenizer_kwargs
:
Optional
[
dict
[
str
,
Any
]]
=
None
,
)
->
List
[
Sequence
]:
"""
Build Sequence objects from prompts, sampling params, and optional
per-sequence compression parameters.
"""
tokenizer_kwargs
=
{}
if
tokenizer_kwargs
is
None
else
tokenizer_kwargs
if
not
isinstance
(
prompts
,
list
):
prompts
=
[
prompts
]
if
isinstance
(
sampling_params
,
SamplingParams
):
sampling_params_list
:
List
[
SamplingParams
]
=
[
sampling_params
]
*
len
(
prompts
)
else
:
sampling_params_list
=
sampling_params
assert
len
(
sampling_params_list
)
==
len
(
prompts
),
(
"sampling_params list must match prompts length"
)
if
per_sequence_compression_params
is
None
:
compression_params_list
:
List
[
SequenceCompressionParams
]
=
[
SequenceCompressionParams
(
1.0
)
for
_
in
prompts
]
elif
isinstance
(
per_sequence_compression_params
,
SequenceCompressionParams
):
compression_params_list
=
[
per_sequence_compression_params
]
*
len
(
prompts
)
else
:
# list-like
assert
len
(
per_sequence_compression_params
)
==
len
(
prompts
),
(
"per_sequence_compression_params list must match prompts length"
)
compression_params_list
=
list
(
per_sequence_compression_params
)
seqs
:
List
[
Sequence
]
=
[]
for
prompt
,
sparams
,
cparams
in
zip
(
prompts
,
sampling_params_list
,
compression_params_list
):
token_ids
=
self
.
tokenize_prompt
(
prompt
,
**
tokenizer_kwargs
)
if
cparams
.
protected_first_tokens
+
cparams
.
protected_last_tokens
>=
len
(
token_ids
):
cparams
.
compression_ratio
=
1.0
seqs
.
append
(
Sequence
(
prompt_token_ids
=
token_ids
,
sampling_params
=
sparams
,
compression_params
=
cparams
,
)
)
return
seqs
def
generate
(
self
,
prompts
:
List
[
PromptLike
]
|
PromptLike
,
sampling_params
:
SamplingParams
|
List
[
SamplingParams
],
batch_compression_params
:
BatchCompressionParams
,
*
,
per_sequence_compression_params
:
Union
[
List
[
SequenceCompressionParams
],
SequenceCompressionParams
]
=
None
,
tokenizer_kwargs
:
Optional
[
dict
[
str
,
Any
]]
=
None
,
detokenizer_kwargs
:
Optional
[
dict
[
str
,
Any
]]
=
None
,
return_sequences
:
bool
=
False
,
)
->
List
[
str
]
|
tuple
[
List
[
str
],
List
[
Sequence
]]:
"""
Accept prompts and return completed Sequences.
Args:
:param prompts:
Single prompt or list of prompts, each either a raw text prompt,
or pre-tokenized input IDs.
:param sampling_params:
A single SamplingParams for all prompts in this batch or a list of
SamplingParams with the same length as ``prompts``.
:param batch_compression_params:
Compression settings for this batch.
:param per_sequence_compression_params:
Per-sequence compression parameters, including the compression
ratio to be applied and the size of the protected regions of the
sequence (how many start tokens and end tokens to keep uncompressed).
If a SequenceCompressionParams instance, the same params will be
applied to all sequences in this batch; if a list is provided,
each SequenceCompressionParams will be attached to the corresponding
prompt in the batch.
:param tokenizer_kwargs:
Extra kwargs forwarded to ``tokenizer(...)`` when tokenizing
string prompts.
:param detokenizer_kwargs:
Passed through to `tokenizer.batch_decode`.
:param return_sequences:
Whether to return sequence objects or not
Returns:
:return List[Sequence]:
One Sequence per input prompt, with `completion_token_ids`
filled in after generation.
"""
tokenizer_kwargs
=
{}
if
tokenizer_kwargs
is
None
else
tokenizer_kwargs
detokenizer_kwargs
=
{}
if
detokenizer_kwargs
is
None
else
detokenizer_kwargs
seqs
=
self
.
_build_sequences
(
prompts
,
sampling_params
=
sampling_params
,
per_sequence_compression_params
=
per_sequence_compression_params
,
tokenizer_kwargs
=
tokenizer_kwargs
,
)
self
.
master_model_runner
.
generate
(
seqs
,
batch_compression_params
)
output_strings
=
self
.
detokenize_prompt
(
seqs
,
**
detokenizer_kwargs
)
if
return_sequences
:
return
output_strings
,
seqs
return
output_strings
def
generate_chat
(
self
,
messages_batch
:
List
[
List
[
dict
]],
sampling_params
:
SamplingParams
|
List
[
SamplingParams
],
batch_compression_params
:
BatchCompressionParams
,
per_sequence_compression_params
:
Union
[
SequenceCompressionParams
,
List
[
SequenceCompressionParams
]
],
*
,
tokenizer_kwargs
:
Optional
[
dict
[
str
,
Any
]]
=
None
,
detokenizer_kwargs
:
Optional
[
dict
[
str
,
Any
]]
=
None
,
return_sequences
:
bool
=
False
,
)
->
List
[
str
]
|
tuple
[
List
[
str
],
List
[
Sequence
]]:
"""
Convenience API for chat-style prompts using HF `apply_chat_template`.
Args:
:param messages_batch:
List of conversations, where each conversation is a list of
message dicts like:
{"role": "system" | "user" | "assistant", "content": str}
:param sampling_params:
A single SamplingParams for all prompts in this batch or a list of
SamplingParams with the same length as ``prompts``.
:param batch_compression_params:
Batch Level compression settings. Can set compression_method.
:param per_sequence_compression_params:
Per-sequence compression parameters, including the compression
ratio to be applied and the size of the protected regions of the
sequence (how many start tokens and end tokens to keep uncompressed).
If a SequenceCompressionParams instance, the same params will be
applied to all sequences in this batch; if a list is provided,
each SequenceCompressionParams will be attached to the corresponding
conversation in the batch.
:param tokenizer_kwargs:
Passed through to `tokenizer.apply_chat_template`.
:param detokenizer_kwargs:
Passed through to `tokenizer.batch_decode`.
:param return_sequences:
Whether to return sequence objects or not
Returns:
:return List[str] or tuple[List[str], List[Sequence]]:
One string per conversation.
"""
prompts_token_ids
:
List
[
List
[
int
]]
=
[]
tokenizer_kwargs
=
_merge_apply_chat_template_kwargs
(
self
.
tokenizer
,
tokenizer_kwargs
)
detokenizer_kwargs
=
{}
if
detokenizer_kwargs
is
None
else
detokenizer_kwargs
for
messages
in
messages_batch
:
input_ids
=
self
.
tokenizer
.
apply_chat_template
(
messages
,
tokenize
=
True
,
**
tokenizer_kwargs
,
)
if
hasattr
(
input_ids
,
"tolist"
):
input_ids
=
input_ids
.
tolist
()
prompts_token_ids
.
append
(
input_ids
)
return
self
.
generate
(
prompts_token_ids
,
sampling_params
=
sampling_params
,
batch_compression_params
=
batch_compression_params
,
per_sequence_compression_params
=
per_sequence_compression_params
,
tokenizer_kwargs
=
tokenizer_kwargs
,
detokenizer_kwargs
=
detokenizer_kwargs
,
return_sequences
=
return_sequences
,
)
def
generate_from_sequences
(
self
,
seqs
:
List
[
Sequence
],
batch_compression_params
:
BatchCompressionParams
,
)
->
List
[
Sequence
]:
"""
Args:
:param seqs:
List of Sequence instances
:param batch_compression_params:
Compression settings.
Returns:
:return List[Sequence]:
Same list, mutated in-place with completions.
"""
self
.
master_model_runner
.
generate
(
seqs
,
batch_compression_params
)
return
seqs
vllm/kvprune_legacy_save/core/memory_manager.py
0 → 100644
View file @
2b7160c6
import
logging
import
os
from
typing
import
Iterable
,
List
,
Optional
import
torch
from
vllm.kvprune.config.engine_config
import
LLMConfig
from
vllm.kvprune.kv_cache.page_table
import
KVAllocationStatus
,
PagedKVCache
from
vllm.kvprune.utils.tp_utils
import
kv_heads_shard_divisor
from
torch
import
nn
logger
=
logging
.
getLogger
(
__name__
)
class
KVCacheManager
:
def
__init__
(
self
,
rank
:
int
,
config
:
LLMConfig
,
*
,
device
:
str
|
None
=
None
,
):
super
().
__init__
()
hf_config
=
config
.
hf_config
self
.
rank
=
rank
self
.
gpu_frac
=
config
.
gpu_memory_utilization
self
.
page_size
=
config
.
kvcache_page_size
self
.
world_size
=
config
.
tensor_parallel_size
self
.
max_num_batches
=
config
.
max_num_seqs
self
.
max_model_len
=
config
.
max_model_len
self
.
num_layers
=
hf_config
.
num_hidden_layers
self
.
model_dtype
=
hf_config
.
torch_dtype
self
.
head_dim
=
getattr
(
hf_config
,
"head_dim"
,
None
)
self
.
max_pages_per_batch
=
(
self
.
max_model_len
+
self
.
page_size
-
1
)
//
self
.
page_size
_ws
=
kv_heads_shard_divisor
()
self
.
num_kv_heads
=
hf_config
.
num_key_value_heads
//
_ws
assert
hf_config
.
num_key_value_heads
%
_ws
==
0
,
(
"tensor-parallel world size needs to divide num_kv_heads"
)
self
.
_cache_device
=
device
if
device
is
not
None
else
f
"cuda:
{
self
.
rank
}
"
self
.
num_pages
=
None
self
.
paged_cache
:
Optional
[
PagedKVCache
]
=
None
self
.
max_batched_tokens
=
None
self
.
seq_id_to_batch
=
{}
def
allocate_sequences
(
self
,
seq_ids
:
List
[
int
],
max_positions
:
List
[
int
]
)
->
(
bool
,
Optional
[
torch
.
Tensor
]):
batch_mapping
=
[]
for
seq_id
,
len_to_alloc
in
zip
(
seq_ids
,
max_positions
):
if
seq_id
not
in
self
.
seq_id_to_batch
:
batch_id
=
self
.
paged_cache
.
new_batch
()
if
batch_id
is
None
:
logger
.
warning
(
"Failed to allocate batch!"
)
return
False
,
None
self
.
seq_id_to_batch
[
seq_id
]
=
int
(
batch_id
)
batch_mapping
.
append
(
self
.
seq_id_to_batch
[
seq_id
])
if
(
alloc_status
:
=
self
.
paged_cache
.
reserve_tokens
(
self
.
seq_id_to_batch
[
seq_id
],
len_to_alloc
)
)
!=
KVAllocationStatus
.
SUCCESS
:
logger
.
warning
(
f
"Failed to allocate pages (
{
alloc_status
}
)!"
)
return
False
,
None
batch_mapping
=
torch
.
as_tensor
(
batch_mapping
,
dtype
=
torch
.
int32
,
device
=
"cuda"
)
return
True
,
batch_mapping
def
free_sequences
(
self
,
seq_ids
:
Iterable
[
int
]):
for
seq_id
in
seq_ids
:
global_batch_id
=
self
.
seq_id_to_batch
.
pop
(
seq_id
,
None
)
self
.
paged_cache
.
free_batch
(
global_batch_id
)
def
init_cache
(
self
,
model
:
nn
.
Module
):
self
.
num_pages
=
self
.
get_num_pages
(
self
.
gpu_frac
,
self
.
max_pages_per_batch
)
self
.
paged_cache
=
PagedKVCache
(
num_layers
=
self
.
num_layers
,
H_kv
=
self
.
num_kv_heads
,
head_dim
=
self
.
head_dim
,
page_size
=
self
.
page_size
,
num_pages
=
int
(
self
.
num_pages
),
max_num_batches
=
self
.
max_num_batches
,
device
=
self
.
_cache_device
,
dtype
=
self
.
model_dtype
,
max_logical_pages_per_head
=
int
(
self
.
max_pages_per_batch
),
)
self
.
_assign_cache_to_layers
(
model
)
def
_assign_cache_to_layers
(
self
,
model
)
->
None
:
for
layer_index
,
layer
in
enumerate
(
model
.
model
.
layers
):
attn
=
layer
.
self_attn
.
attn
k
,
v
,
pt
,
bh
=
self
.
paged_cache
.
layer_slices
(
layer_index
)
attn
.
k_cache
=
k
attn
.
v_cache
=
v
attn
.
page_table
=
pt
attn
.
bh_seq_lens
=
bh
attn
.
page_size
=
self
.
page_size
def
get_num_pages
(
self
,
frac
:
float
,
n_logical_pages_max
:
int
):
free
,
total
=
torch
.
cuda
.
mem_get_info
()
used
=
total
-
free
stats
=
torch
.
cuda
.
memory_stats
()
peak
=
int
(
stats
[
"allocated_bytes.all.peak"
])
current
=
int
(
stats
[
"allocated_bytes.all.current"
])
bytes_for_kv_budget
=
int
(
total
*
frac
*
0.9
)
-
used
-
peak
+
current
if
bytes_for_kv_budget
<=
0
:
# Standalone compactor: ``frac`` is a fraction of total VRAM. When a second
# engine shares the GPU with vLLM (shared weights), most VRAM is already
# committed; the formula above goes negative. Fall back to a slice of
# *currently free* memory for the compactor KV pool.
free_frac
=
float
(
os
.
environ
.
get
(
"VLLM_KVPRUNE_COMPACTOR_KV_FREE_FRAC"
,
"0.55"
)
)
free_frac
=
max
(
0.05
,
min
(
free_frac
,
0.95
))
bytes_for_kv_budget
=
int
(
free
*
free_frac
)
logger
.
warning
(
"KV cache budget from gpu_memory_utilization (%.2f) is exhausted "
"(%.2f MiB free on device); using %.0f%% of free memory (~%.2f MiB) "
"for compactor KV (set VLLM_KVPRUNE_COMPACTOR_KV_FREE_FRAC to adjust)."
,
frac
,
free
/
(
1024
**
2
),
free_frac
*
100
,
bytes_for_kv_budget
/
(
1024
**
2
),
)
if
bytes_for_kv_budget
<=
0
:
raise
RuntimeError
(
"Insufficient memory for compactor KV cache: no free GPU memory left "
"after the primary vLLM engine. Lower vLLM gpu_memory_utilization or "
"max_model_len, shorten prompts, or run compactor-only / vLLM-only "
"sessions. Raising gpu_memory_utilization here does not help."
)
# page_table[L, B, H_kv, N_LOGICAL_PAGES_MAX] + bh_seq_lens[L, B, H_kv]
int32_sz
=
torch
.
empty
((),
dtype
=
torch
.
int32
).
element_size
()
# 4
page_table_bytes_per_layer
=
(
self
.
max_num_batches
*
self
.
num_kv_heads
*
n_logical_pages_max
*
int32_sz
# page_table
+
self
.
max_num_batches
*
self
.
num_kv_heads
*
int32_sz
)
total_page_table_bytes
=
self
.
num_layers
*
page_table_bytes_per_layer
kv_bytes_net
=
bytes_for_kv_budget
-
total_page_table_bytes
if
kv_bytes_net
<=
0
:
# Tight VRAM: metadata alone can exceed the first budget; reserve page
# tables plus a slice of remaining free for KV tensors.
bytes_for_kv_budget
=
min
(
int
(
free
*
0.95
),
total_page_table_bytes
+
max
(
int
(
free
*
0.25
),
8
*
1024
*
1024
),
)
kv_bytes_net
=
bytes_for_kv_budget
-
total_page_table_bytes
if
kv_bytes_net
<=
0
:
raise
RuntimeError
(
"page-table footprint exceeds available GPU memory for compactor KV. "
f
"Reduce vLLM max_num_seqs (compactor uses
{
self
.
max_num_batches
}
) "
f
"or max_model_len (
{
self
.
max_model_len
}
), or free GPU memory."
)
dtype_sz
=
torch
.
empty
((),
dtype
=
self
.
model_dtype
).
element_size
()
bytes_per_page_across_layers
=
self
.
num_layers
*
(
2
*
self
.
page_size
*
self
.
head_dim
*
dtype_sz
)
return
max
(
1
,
kv_bytes_net
//
bytes_per_page_across_layers
)
def
estimate_max_batched_tokens
(
self
,
warmup_tokens
:
int
,
bytes_used_before_warmup
:
int
,
bytes_peak_after_warmup
:
int
,
)
->
int
:
"""
Estimate the max total number of tokens that can be processed concurrently
without OOM.
"""
assert
warmup_tokens
>
0
,
"warmup_tokens must be > 0"
# activation bytes per token
warmup_delta
=
max
(
0
,
int
(
bytes_peak_after_warmup
)
-
int
(
bytes_used_before_warmup
)
)
bytes_per_token
=
max
(
1
,
(
warmup_delta
+
warmup_tokens
-
1
)
//
warmup_tokens
)
free
,
total
=
torch
.
cuda
.
mem_get_info
()
target
=
int
(
total
*
self
.
gpu_frac
)
used_now
=
int
(
total
-
free
)
# reserve headroom equal to the gap between peak and current allocations seen so far
stats
=
torch
.
cuda
.
memory_stats
()
peak_cur
=
int
(
stats
.
get
(
"allocated_bytes.all.peak"
,
0
))
cur_now
=
int
(
stats
.
get
(
"allocated_bytes.all.current"
,
0
))
cushion
=
max
(
0
,
peak_cur
-
cur_now
)
activation_budget
=
int
(
max
(
0
,
target
-
used_now
-
cushion
)
*
0.95
)
max_tokens_per_batch
=
activation_budget
//
bytes_per_token
max_tokens_in_cache
=
(
self
.
num_pages
*
self
.
page_size
)
//
self
.
num_kv_heads
# round to lower multiple of page size
max_tokens_per_batch
=
(
max_tokens_per_batch
//
self
.
page_size
)
*
self
.
page_size
max_tokens_in_cache
=
(
max_tokens_in_cache
//
self
.
page_size
)
*
self
.
page_size
# When vLLM shares the same GPU, ``used_now`` often exceeds ``target`` (same
# situation as ``get_num_pages``), so activation_budget is ~0 and
# ``max_tokens_per_batch`` rounds to 0 or one page. The min(...) would then
# cap prefill at ~page_size tokens (e.g. 32) even though the compactor KV pool
# is large — no prompt longer than that can be scheduled. Prefer KV capacity
# (capped by max_model_len) whenever activation math yields only a token or two.
if
(
max_tokens_in_cache
>
0
and
max_tokens_per_batch
<=
self
.
page_size
and
max_tokens_in_cache
>
max_tokens_per_batch
):
max_tokens_per_batch
=
min
(
max_tokens_in_cache
,
self
.
max_model_len
)
self
.
max_batched_tokens
=
min
(
max_tokens_in_cache
,
max_tokens_per_batch
)
# Last resort: allow at least one page when KV exists but min(...) is still 0.
if
self
.
max_batched_tokens
==
0
and
self
.
num_pages
>
0
and
max_tokens_in_cache
>
0
:
self
.
max_batched_tokens
=
min
(
max_tokens_in_cache
,
self
.
page_size
)
return
self
.
max_batched_tokens
@
property
def
num_free_batches
(
self
)
->
int
:
return
len
(
self
.
paged_cache
.
free_batches
)
@
property
def
num_free_pages
(
self
)
->
int
:
return
min
(
len
(
fp
)
for
fp
in
self
.
paged_cache
.
free_pages
)
def
reclaim_pages
(
self
,
seq_ids_to_reclaim
:
Iterable
[
int
],
future_reserved_buffer
:
List
[
int
]
|
torch
.
Tensor
,
)
->
int
:
approximate_bytes_freed
=
0
for
i
,
seq_id
in
enumerate
(
seq_ids_to_reclaim
):
batch_idx
=
self
.
seq_id_to_batch
[
seq_id
]
approximate_bytes_freed
+=
self
.
paged_cache
.
reclaim_pages
(
batch_idx
,
future_reserved_buffer
[
i
]
)
return
approximate_bytes_freed
vllm/kvprune_legacy_save/core/model_runner.py
0 → 100644
View file @
2b7160c6
import
atexit
import
logging
import
os
import
inspect
from
typing
import
Any
,
List
,
Optional
import
torch
import
torch.nn
as
nn
import
torch.distributed
as
dist
from
vllm.kvprune.attention.sparse_decode_kernel
import
num_splits_heuristic
from
vllm.kvprune.compression.compression_config
import
BatchCompressionParams
from
vllm.kvprune.config.constants
import
RESERVED_BATCH
from
vllm.kvprune.config.engine_config
import
LLMConfig
,
KvpruneAttentionSchedule
from
vllm.kvprune.core.memory_manager
import
KVCacheManager
from
vllm.kvprune.core.scheduler
import
Scheduler
from
vllm.kvprune.layers.sampler
import
Sampler
from
vllm.kvprune.models
import
MODEL_REGISTRY
from
vllm.kvprune.utils.arguments
import
(
DecodeBatchArguments
,
DecodeBatchOutput
,
PackedTensorArguments
,
PrefillBatchArguments
,
)
from
vllm.kvprune.utils.context
import
CompressionContext
,
reset_context
,
set_context
from
vllm.kvprune.utils.kv_dist
import
barrier_sync
,
broadcast_from_tp_rank0
from
vllm.kvprune.utils.sequence
import
Sequence
from
torch.multiprocessing
import
Event
from
tqdm
import
tqdm
logger
=
logging
.
getLogger
(
__name__
)
class
ModelRunner
:
"""Per-rank execution loop. Manages model, sampler, KV cache, and warmup"""
def
__init__
(
self
,
config
:
LLMConfig
,
rank
:
int
,
batch_ready
:
Optional
[
Event
]
=
None
,
peer_events
:
List
[
Event
]
=
None
,
external_model
:
Optional
[
nn
.
Module
]
=
None
,
*
,
embedded_in_vllm_worker
:
bool
=
False
,
device
:
Optional
[
torch
.
device
]
=
None
,
):
self
.
config
=
config
self
.
embedded_in_vllm_worker
=
embedded_in_vllm_worker
if
embedded_in_vllm_worker
:
from
vllm.distributed.parallel_state
import
(
get_tensor_model_parallel_rank
,
get_tensor_model_parallel_world_size
,
)
tp_ws
=
get_tensor_model_parallel_world_size
()
tp_rank
=
get_tensor_model_parallel_rank
()
if
tp_ws
!=
config
.
tensor_parallel_size
:
raise
RuntimeError
(
f
"tensor parallel world size
{
tp_ws
}
!= "
f
"LLMConfig.tensor_parallel_size
{
config
.
tensor_parallel_size
}
"
)
self
.
rank
=
tp_rank
_dev
=
device
if
device
is
not
None
else
torch
.
device
(
f
"cuda:
{
torch
.
cuda
.
current_device
()
}
"
)
if
not
dist
.
is_initialized
():
raise
RuntimeError
(
"embedded_in_vllm_worker requires torch.distributed to be "
"initialized (vLLM worker)."
)
if
dist
.
get_world_size
()
!=
tp_ws
:
raise
NotImplementedError
(
"KV-prune compactor embedded in vLLM currently requires "
"dist.get_world_size() == tensor_parallel_size "
"(pipeline_parallel_size=1, data_parallel_size=1). "
f
"Got dist.get_world_size()=
{
dist
.
get_world_size
()
}
, "
f
"tp_ws=
{
tp_ws
}
."
)
else
:
self
.
rank
=
rank
_dev
=
device
if
device
is
not
None
else
torch
.
device
(
f
"cuda:
{
rank
}
"
)
self
.
_device
=
_dev
assert
config
.
eos_token_ids
is
not
None
and
len
(
config
.
eos_token_ids
)
>
0
,
(
"LLMConfig.eos_token_ids must be set (filled in LLMEngine from tokenizer)."
)
self
.
_stop_token_ids
=
torch
.
tensor
(
config
.
eos_token_ids
,
dtype
=
torch
.
int64
,
device
=
_dev
)
hf_config
=
config
.
hf_config
self
.
enforce_eager
=
config
.
enforce_eager
if
config
.
attention_schedule
==
KvpruneAttentionSchedule
.
PDFA
:
if
not
self
.
enforce_eager
and
self
.
rank
==
0
:
logger
.
info
(
"attention_schedule=PDFA: disabling compactor decode CUDA graphs "
"(FlashAttention decode path)."
)
self
.
enforce_eager
=
True
# Embedded in vLLM worker (TP>1): respect :attr:`LLMConfig.enforce_eager` from
# ``v1_tp_runner._apply_compactor_env_overrides``. Set
# ``VLLM_KVPRUNE_TP_EMBEDDED_GRAPH=0`` to force eager if graph replay is unstable
# with shared vLLM VRAM / streams / NCCL on your stack.
if
embedded_in_vllm_worker
:
_tp_graph
=
os
.
environ
.
get
(
"VLLM_KVPRUNE_TP_EMBEDDED_GRAPH"
,
"1"
).
strip
().
lower
()
if
_tp_graph
in
(
"0"
,
"false"
,
"no"
):
if
not
self
.
enforce_eager
:
logger
.
info
(
"embedded_in_vllm_worker: VLLM_KVPRUNE_TP_EMBEDDED_GRAPH=0 → "
"forcing compactor enforce_eager=True (skip compactor CUDA graph "
"capture)."
)
self
.
enforce_eager
=
True
self
.
world_size
=
config
.
tensor_parallel_size
self
.
leverage_sketch_size
=
config
.
leverage_sketch_size
self
.
show_progress_bar
=
config
.
show_progress_bar
self
.
max_num_batches
=
config
.
max_num_seqs
self
.
max_model_len
=
config
.
max_model_len
self
.
num_layers
=
hf_config
.
num_hidden_layers
self
.
model_dtype
=
hf_config
.
torch_dtype
self
.
head_dim
=
getattr
(
hf_config
,
"head_dim"
,
None
)
init_kwargs
=
{}
if
not
embedded_in_vllm_worker
:
if
"device_id"
in
inspect
.
signature
(
dist
.
init_process_group
).
parameters
:
init_kwargs
[
"device_id"
]
=
torch
.
device
(
f
"cuda:
{
rank
}
"
)
if
not
dist
.
is_initialized
():
dist
.
init_process_group
(
"nccl"
,
f
"tcp://localhost:
{
config
.
nccl_port
}
"
,
world_size
=
self
.
world_size
,
rank
=
rank
,
**
init_kwargs
,
)
else
:
ws
=
dist
.
get_world_size
()
if
ws
!=
self
.
world_size
:
raise
RuntimeError
(
"torch.distributed is already initialized with "
f
"world_size=
{
ws
}
, but compactor ModelRunner expects "
f
"tensor_parallel_size=
{
self
.
world_size
}
. "
"Use tensor_parallel_size matching the active process group "
"(typically 1 when sharing weights with vLLM)."
)
torch
.
cuda
.
set_device
(
_dev
)
default_dtype
=
torch
.
get_default_dtype
()
torch
.
set_default_dtype
(
hf_config
.
torch_dtype
)
torch
.
set_default_device
(
"cuda"
)
model_type
=
hf_config
.
model_type
if
external_model
is
not
None
:
self
.
model
=
external_model
else
:
self
.
model
=
MODEL_REGISTRY
[
model_type
](
hf_config
)
self
.
model
.
load_model
(
config
.
path
,
use_tqdm
=
self
.
is_master
and
self
.
show_progress_bar
)
self
.
sampler
=
Sampler
()
pre_warmup_mem
=
torch
.
cuda
.
memory_stats
().
get
(
"allocated_bytes.all.current"
,
0
)
# No paged KV yet: FA-only varlen path (see :meth:`warmup`).
self
.
warmup
(
num_warmup_tokens
=
self
.
max_model_len
,
with_kv
=
False
)
post_warmup_peak
=
torch
.
cuda
.
memory_stats
().
get
(
"allocated_bytes.all.peak"
,
0
)
self
.
kv_manager
=
KVCacheManager
(
self
.
rank
,
config
,
device
=
str
(
self
.
_device
)
)
self
.
kv_manager
.
init_cache
(
self
.
model
)
self
.
store_stream
:
Optional
[
torch
.
cuda
.
Stream
]
=
torch
.
cuda
.
Stream
()
torch
.
set_default_device
(
"cpu"
)
torch
.
set_default_dtype
(
default_dtype
)
self
.
batch_ready
=
batch_ready
self
.
peer_events
=
peer_events
if
peer_events
is
not
None
else
[]
# Embedded TP peers: session end is signaled via TP-group broadcast in
# maybe_release_peers (no multiprocessing.Event — not pickleable over RPC).
self
.
_embedded_peer_continue
=
True
self
.
captured_graphs
=
{}
self
.
min_captured_len
=
{}
self
.
max_batched_tokens
=
self
.
kv_manager
.
estimate_max_batched_tokens
(
self
.
max_model_len
,
pre_warmup_mem
,
post_warmup_peak
)
if
self
.
is_master
:
logger
.
info
(
f
"Estimated max batched tokens of
{
self
.
max_batched_tokens
}
"
)
self
.
warmup
(
num_warmup_tokens
=
self
.
max_model_len
,
with_kv
=
True
)
if
not
self
.
enforce_eager
:
bs
=
[
1
<<
i
for
i
in
range
(
self
.
max_num_batches
.
bit_length
())]
for
bs
in
(
tqdm
(
bs
,
desc
=
"Capturing CUDA Graphs"
)
if
self
.
is_master
and
self
.
show_progress_bar
else
bs
):
for
seq_len
in
[
1024
,
4096
,
8192
,
16384
]:
self
.
capture_cudagraph
(
bs
,
seq_len
)
if
not
self
.
captured_graphs
:
logger
.
warning
(
"No compactor CUDA graphs were captured (KV budget tight or "
"allocate_sequences failed during capture). Using eager decode "
"for this session."
)
self
.
enforce_eager
=
True
self
.
packed_args
=
PackedTensorArguments
(
rank
=
self
.
rank
,
max_batched_tokens
=
self
.
max_batched_tokens
,
config
=
self
.
config
,
device
=
self
.
_device
,
use_tp_group_for_collectives
=
embedded_in_vllm_worker
,
)
atexit
.
register
(
self
.
exit
)
@
torch
.
inference_mode
()
def
warmup
(
self
,
num_warmup_tokens
:
int
,
*
,
with_kv
:
bool
):
sched
=
(
self
.
config
.
attention_schedule
if
with_kv
else
KvpruneAttentionSchedule
.
FA_PREFILL_TRITON_DECODE
)
if
self
.
rank
==
0
:
logger
.
info
(
"Warming up compactor attention (%s KV init): schedule=%s"
,
"after"
if
with_kv
else
"before"
,
sched
.
name
,
)
device
=
self
.
_device
input_ids
=
torch
.
tensor
(
[
self
.
config
.
eos
]
*
num_warmup_tokens
,
device
=
device
,
dtype
=
torch
.
int64
)
positions
=
torch
.
arange
(
num_warmup_tokens
,
device
=
device
,
dtype
=
torch
.
int64
)
cu_seqlens_q
=
torch
.
tensor
(
[
0
,
num_warmup_tokens
],
device
=
device
,
dtype
=
torch
.
int32
)
cu_seqlens_k
=
torch
.
tensor
(
[
0
,
num_warmup_tokens
],
device
=
device
,
dtype
=
torch
.
int32
)
if
with_kv
:
success
,
batch_mapping
=
self
.
kv_manager
.
allocate_sequences
(
[
-
1
],
[
num_warmup_tokens
]
)
assert
success
else
:
batch_mapping
=
None
set_context
(
is_prefill
=
True
,
do_compression
=
False
,
cu_seqlens_q
=
cu_seqlens_q
,
cu_seqlens_k
=
cu_seqlens_k
,
cu_seqlens_q_host
=
(
0
,
num_warmup_tokens
),
cu_seqlens_k_host
=
(
0
,
num_warmup_tokens
),
max_seqlen_q
=
num_warmup_tokens
,
max_seqlen_k
=
num_warmup_tokens
,
batch_mapping
=
batch_mapping
,
attention_schedule
=
sched
,
)
for
_
in
range
(
2
):
torch
.
cuda
.
reset_peak_memory_stats
()
h
=
self
.
model
(
input_ids
,
positions
)
self
.
model
.
compute_logits
(
h
)
barrier_sync
(
use_tp_group
=
self
.
embedded_in_vllm_worker
)
if
with_kv
:
self
.
kv_manager
.
paged_cache
.
bh_seq_lens
.
index_fill_
(
1
,
batch_mapping
.
to
(
torch
.
long
),
0
)
reset_context
()
if
with_kv
:
self
.
kv_manager
.
free_sequences
([
-
1
])
def
exit
(
self
):
if
getattr
(
self
,
"_exited"
,
False
):
return
self
.
_exited
=
True
try
:
if
hasattr
(
self
,
"captured_graphs"
):
self
.
captured_graphs
.
clear
()
finally
:
if
getattr
(
self
,
"embedded_in_vllm_worker"
,
False
):
return
if
dist
.
is_initialized
():
dist
.
destroy_process_group
()
def
loop
(
self
):
while
True
:
if
self
.
batch_ready
.
wait
(
1.0
):
self
.
_process_batches_peer
()
@
torch
.
inference_mode
()
def
run_prefill
(
self
,
prefill_args
:
PrefillBatchArguments
,
batch_mapping
:
torch
.
Tensor
):
assert
prefill_args
.
B
>
0
and
prefill_args
.
N
>
0
max_bh_len
=
(
self
.
kv_manager
.
paged_cache
.
bh_seq_lens
.
index_select
(
1
,
index
=
batch_mapping
)
.
max
()
.
item
()
)
compression_context
=
CompressionContext
(
compression_method
=
prefill_args
.
compression_method
,
compression_chunk_size
=
prefill_args
.
compression_chunk_size
,
batch_tokens_to_retain
=
prefill_args
.
batch_tokens_to_retain
,
max_tokens_to_retain
=
prefill_args
.
max_tokens_to_retain
,
context_lens
=
prefill_args
.
context_lens
.
tolist
(),
PHI
=
prefill_args
.
PHI
,
sketch_dimension
=
self
.
leverage_sketch_size
,
protected_first_tokens
=
prefill_args
.
protected_first
,
protected_last_tokens
=
prefill_args
.
protected_last
,
compression_ratio
=
prefill_args
.
compression_ratio
,
)
cu_q_host
=
tuple
(
int
(
x
)
for
x
in
prefill_args
.
cu_seqlens_q
.
detach
().
cpu
().
view
(
-
1
).
tolist
()
)
cu_k_host
=
tuple
(
int
(
x
)
for
x
in
prefill_args
.
cu_seqlens_k
.
detach
().
cpu
().
view
(
-
1
).
tolist
()
)
set_context
(
is_prefill
=
True
,
do_compression
=
prefill_args
.
do_compression
,
cu_seqlens_q
=
prefill_args
.
cu_seqlens_q
,
cu_seqlens_k
=
prefill_args
.
cu_seqlens_k
,
cu_seqlens_q_host
=
cu_q_host
,
cu_seqlens_k_host
=
cu_k_host
,
max_seqlen_q
=
prefill_args
.
max_seqlen_q
,
max_seqlen_k
=
prefill_args
.
max_seqlen_k
,
batch_mapping
=
batch_mapping
,
max_bh_len
=
max_bh_len
,
compression_context
=
compression_context
,
STORE_STREAM
=
self
.
store_stream
,
attention_schedule
=
self
.
config
.
attention_schedule
,
)
# int32 token ids break vLLM-delegated embedding (expects long indices) on some paths.
_iid
=
(
prefill_args
.
input_ids
if
prefill_args
.
input_ids
.
dtype
==
torch
.
int64
else
prefill_args
.
input_ids
.
long
()
)
_pos
=
(
prefill_args
.
positions
if
prefill_args
.
positions
.
dtype
==
torch
.
int64
else
prefill_args
.
positions
.
long
()
)
hidden
=
self
.
model
(
_iid
,
_pos
)
logits
=
self
.
model
.
compute_logits
(
hidden
)
reset_context
()
return
logits
def
maybe_broadcast
(
self
,
tensor
:
torch
.
Tensor
,
*
,
label
:
str
=
"tensor"
)
->
None
:
if
self
.
world_size
>
1
:
broadcast_from_tp_rank0
(
tensor
,
use_tp_group
=
self
.
embedded_in_vllm_worker
)
return
None
def
maybe_release_peers
(
self
,
do_release
=
False
):
if
self
.
world_size
<=
1
:
return
if
self
.
embedded_in_vllm_worker
:
flag
=
torch
.
zeros
(
1
,
dtype
=
torch
.
int32
,
device
=
self
.
_device
)
if
self
.
is_master
:
flag
[
0
]
=
0
if
do_release
else
1
broadcast_from_tp_rank0
(
flag
,
use_tp_group
=
True
)
if
not
self
.
is_master
:
self
.
_embedded_peer_continue
=
bool
(
flag
[
0
].
item
())
barrier_sync
(
use_tp_group
=
True
)
return
if
self
.
is_master
:
if
do_release
:
for
event
in
self
.
peer_events
:
event
.
clear
()
barrier_sync
(
use_tp_group
=
False
)
else
:
barrier_sync
(
use_tp_group
=
False
)
def
_peer_outer_loop_active
(
self
)
->
bool
:
if
self
.
batch_ready
is
not
None
:
return
self
.
batch_ready
.
is_set
()
if
self
.
embedded_in_vllm_worker
:
return
self
.
_embedded_peer_continue
return
False
@
torch
.
inference_mode
()
def
generate
(
self
,
all_sequences
:
List
[
Sequence
],
batch_compression_params
:
Optional
[
BatchCompressionParams
]
=
None
,
):
assert
self
.
is_master
,
"generate can only be called on the master process"
if
not
self
.
embedded_in_vllm_worker
:
for
begin_execution_event
in
self
.
peer_events
:
begin_execution_event
.
set
()
if
batch_compression_params
is
None
:
batch_compression_params
=
BatchCompressionParams
()
self
.
_process_batches_master
(
all_sequences
,
batch_compression_params
)
@
property
def
is_master
(
self
):
return
self
.
rank
==
0
@
torch
.
inference_mode
()
def
_process_batches_master
(
self
,
all_sequences
:
List
[
Sequence
],
batch_compression_params
:
BatchCompressionParams
,
):
assert
self
.
is_master
compression_details
=
f
"Applying Compression Method:
{
batch_compression_params
.
compression_method
}
"
if
any
(
seq
.
compression_params
.
compression_ratio
<
1.0
for
seq
in
all_sequences
):
logger
.
info
(
compression_details
)
scheduler
=
Scheduler
(
all_sequences
=
all_sequences
,
kv_manager
=
self
.
kv_manager
,
use_tqdm
=
self
.
show_progress_bar
,
)
decode_batch
=
DecodeBatchArguments
()
decode_flags
=
torch
.
empty
(
2
,
dtype
=
torch
.
int32
,
device
=
self
.
_device
)
while
not
scheduler
.
is_finished
():
sequences
=
scheduler
.
get_prefill_batch
()
if
not
sequences
:
if
scheduler
.
pending_sequence_ids
:
raise
RuntimeError
(
"KV-prune compactor cannot schedule any prefill (KV/token budget). "
f
"max_batched_tokens=
{
self
.
kv_manager
.
max_batched_tokens
}
, "
f
"pending_sequences=
{
len
(
scheduler
.
pending_sequence_ids
)
}
. "
"Lower v1 gpu_memory_utilization / max_model_len, set "
"VLLM_KVPRUNE_RELEASE_V1_KV=1 to discard v1 KV (sleep+wake), "
"or free GPU memory. Diagnostics: "
f
"
{
scheduler
.
diagnose_prefill_failure
()
}
"
)
# Pending is empty: either finished or decode-only continuation.
if
decode_batch
.
token_ids
is
None
:
break
run_decode
=
True
occupancy
=
-
1
else
:
seq_ids_cpu
=
[
seq
.
seq_id
for
seq
in
sequences
]
scheduler
.
add_running_sequence_ids
(
seq_ids_cpu
,
update_status
=
True
)
temps
=
torch
.
tensor
(
[
s
.
sampling_params
.
temperature
for
s
in
sequences
],
dtype
=
torch
.
float32
,
pin_memory
=
True
,
).
to
(
device
=
self
.
_device
,
non_blocking
=
True
)
prefill_arguments
=
self
.
packed_args
.
build_prefill_args
(
sequences
,
batch_compression_params
=
batch_compression_params
)
max_ctx_lens
=
(
prefill_arguments
.
max_new_tokens
+
prefill_arguments
.
context_lens
)
success
,
batch_mapping
=
self
.
kv_manager
.
allocate_sequences
(
seq_ids_cpu
,
max_ctx_lens
.
tolist
()
)
assert
success
,
"failed to allocate pages for sequences"
logits
=
self
.
run_prefill
(
prefill_arguments
,
batch_mapping
)
# Must match prefill `positions` dtype (int64). `context_lens` is int32
# from the packed buffer; using int32 here breaks RoPE indexing
# (`cos_sin_cache[positions]`) on CUDA for decode vs prefill.
positions
=
prefill_arguments
.
context_lens
.
to
(
dtype
=
torch
.
int64
)
token_ids
=
self
.
sampler
(
logits
,
temps
)
# Prefill KV writes + bh_seq_lens updates run on STORE_STREAM; reclaim
# reads bh_seq_lens on the default stream and must not race.
if
self
.
store_stream
is
not
None
:
torch
.
cuda
.
default_stream
().
wait_stream
(
self
.
store_stream
)
# TODO: synchronize page counts accross dist
if
self
.
world_size
==
1
:
self
.
kv_manager
.
reclaim_pages
(
seq_ids_cpu
,
prefill_arguments
.
max_new_tokens
)
# with logging_redirect_tqdm():
# logger.info(
# f"Reclaimed {reclaimed_bytes / 1e6:.2f} MB from the KV cache"
# )
if
scheduler
.
any_pending_sequences
():
num_pending_batches
=
(
0
if
decode_batch
.
token_ids
is
None
else
decode_batch
.
token_ids
.
shape
[
0
]
)
occupancy
=
int
((
num_pending_batches
+
len
(
seq_ids_cpu
))
*
0.66
)
else
:
occupancy
=
-
1
run_decode
=
not
scheduler
.
can_prefill_another_batch
()
decode_batch
=
decode_batch
.
update
(
batch_mapping
,
token_ids
,
positions
,
max_ctx_lens
,
prefill_arguments
.
seq_ids
,
temps
,
occupancy
,
)
if
self
.
world_size
>
1
:
decode_flags
[
0
]
=
int
(
run_decode
)
decode_flags
[
1
]
=
occupancy
self
.
maybe_broadcast
(
decode_flags
,
label
=
"decode_flags"
)
if
not
run_decode
:
continue
if
self
.
store_stream
is
not
None
:
torch
.
cuda
.
default_stream
().
wait_stream
(
self
.
store_stream
)
decode_output
,
decode_batch
=
self
.
run_decode_loop
(
decode_batch
)
finished_sequence_ids
=
scheduler
.
get_finished_sequence_ids_from_unfinished
(
decode_batch
.
seq_ids
.
tolist
()
)
scheduler
.
record_finished_sequence_ids
(
finished_sequence_ids
,
update_status
=
True
)
self
.
kv_manager
.
free_sequences
(
finished_sequence_ids
)
self
.
maybe_release_peers
(
scheduler
.
is_finished
())
scheduler
.
update_sequences
(
decode_output
.
output_tokens
.
tolist
(),
decode_output
.
output_seq_ids
.
tolist
(),
)
scheduler
.
close
()
@
torch
.
inference_mode
()
def
run_peer_session
(
self
)
->
None
:
"""Non-master TP ranks: run one peer session (used when embedded in vLLM)."""
if
self
.
embedded_in_vllm_worker
:
self
.
_embedded_peer_continue
=
True
self
.
_process_batches_peer
()
@
torch
.
inference_mode
()
def
_process_batches_peer
(
self
):
assert
not
self
.
is_master
scheduler
=
Scheduler
([],
kv_manager
=
self
.
kv_manager
)
decode_batch
=
DecodeBatchArguments
()
decode_flags
=
torch
.
empty
(
2
,
dtype
=
torch
.
int32
,
device
=
self
.
_device
)
while
self
.
_peer_outer_loop_active
():
prefill_arguments
=
self
.
packed_args
.
build_prefill_args
()
B
=
prefill_arguments
.
B
max_ctx_lens
=
(
prefill_arguments
.
max_new_tokens
+
prefill_arguments
.
context_lens
)
seq_ids_cpu
=
prefill_arguments
.
seq_ids
.
tolist
()
scheduler
.
add_running_sequence_ids
(
seq_ids_cpu
)
success
,
batch_mapping
=
self
.
kv_manager
.
allocate_sequences
(
seq_ids_cpu
,
max_ctx_lens
.
tolist
()
)
assert
success
,
"failed to allocate pages for sequences"
self
.
run_prefill
(
prefill_arguments
,
batch_mapping
)
positions
=
prefill_arguments
.
context_lens
.
to
(
dtype
=
torch
.
int64
)
self
.
maybe_broadcast
(
decode_flags
,
label
=
"decode_flags"
)
run_decode
=
bool
(
decode_flags
[
0
].
item
())
occupancy
=
int
(
decode_flags
[
1
].
item
())
token_ids
=
torch
.
empty
(
B
,
dtype
=
torch
.
int64
,
device
=
self
.
_device
)
decode_batch
=
decode_batch
.
update
(
batch_mapping
,
token_ids
,
positions
,
max_ctx_lens
,
prefill_arguments
.
seq_ids
,
None
,
# temps not used in peer process
occupancy
,
)
if
not
run_decode
:
continue
if
self
.
store_stream
is
not
None
:
torch
.
cuda
.
default_stream
().
wait_stream
(
self
.
store_stream
)
_
,
decode_batch
=
self
.
run_decode_loop
(
decode_batch
)
finished_sequence_ids
=
scheduler
.
get_finished_sequence_ids_from_unfinished
(
decode_batch
.
seq_ids
.
tolist
()
)
scheduler
.
record_finished_sequence_ids
(
finished_sequence_ids
)
self
.
kv_manager
.
free_sequences
(
finished_sequence_ids
)
self
.
maybe_release_peers
()
scheduler
.
close
()
@
torch
.
inference_mode
()
def
run_decode_loop
(
self
,
decode_batch
:
DecodeBatchArguments
,
)
->
tuple
[
DecodeBatchOutput
,
DecodeBatchArguments
]:
if
self
.
is_master
:
num_stashed_batches
=
decode_batch
.
num_stashed_batches
tok_buffer
=
[
decode_batch
.
token_ids
[
num_stashed_batches
:].
to
(
"cpu"
,
non_blocking
=
True
)
]
seq_buffer
=
[
decode_batch
.
seq_ids
[
num_stashed_batches
:].
to
(
"cpu"
,
non_blocking
=
True
)
]
while
True
:
self
.
maybe_broadcast
(
decode_batch
.
token_ids
,
label
=
"decode_token_ids"
)
not_stopped
=
~
torch
.
isin
(
decode_batch
.
token_ids
,
self
.
_stop_token_ids
)
running_batches
=
(
decode_batch
.
positions
<
decode_batch
.
max_ctx_lens
)
&
(
not_stopped
)
decode_batch
.
token_ids
=
torch
.
masked_select
(
decode_batch
.
token_ids
,
running_batches
)
decode_batch
.
positions
=
torch
.
masked_select
(
decode_batch
.
positions
,
running_batches
)
decode_batch
.
batch_mapping
=
torch
.
masked_select
(
decode_batch
.
batch_mapping
,
running_batches
)
decode_batch
.
max_ctx_lens
=
torch
.
masked_select
(
decode_batch
.
max_ctx_lens
,
running_batches
)
decode_batch
.
seq_ids
=
torch
.
masked_select
(
decode_batch
.
seq_ids
,
running_batches
)
if
self
.
is_master
:
decode_batch
.
temps
=
torch
.
masked_select
(
decode_batch
.
temps
,
running_batches
)
num_remaining
=
decode_batch
.
token_ids
.
numel
()
if
(
num_remaining
==
0
or
num_remaining
<=
decode_batch
.
desired_batch_occupancy
):
decode_batch
.
num_stashed_batches
=
num_remaining
break
logits
=
self
.
_decode_step_logits
(
decode_batch
)
if
self
.
is_master
:
decode_batch
.
token_ids
=
self
.
sampler
(
logits
,
decode_batch
.
temps
)
tok_buffer
.
append
(
decode_batch
.
token_ids
.
to
(
"cpu"
,
non_blocking
=
True
))
seq_buffer
.
append
(
decode_batch
.
seq_ids
.
to
(
"cpu"
,
non_blocking
=
True
))
decode_batch
.
positions
+=
1
if
self
.
is_master
:
# non_blocking D2H copies must finish before cat/tolist read CPU data.
torch
.
cuda
.
synchronize
()
output
=
DecodeBatchOutput
(
output_tokens
=
torch
.
cat
(
tok_buffer
),
output_seq_ids
=
torch
.
cat
(
seq_buffer
),
)
else
:
output
=
DecodeBatchOutput
(
None
,
None
)
return
output
,
decode_batch
def
_decode_logits_eager
(
self
,
input_ids
:
torch
.
Tensor
,
positions
:
torch
.
Tensor
,
batch_mapping
:
torch
.
Tensor
,
):
set_context
(
is_prefill
=
False
,
do_compression
=
False
,
batch_mapping
=
batch_mapping
,
attention_schedule
=
self
.
config
.
attention_schedule
,
)
_iid
=
input_ids
if
input_ids
.
dtype
==
torch
.
int64
else
input_ids
.
long
()
_pos
=
positions
if
positions
.
dtype
==
torch
.
int64
else
positions
.
long
()
hidden
=
self
.
model
(
_iid
,
_pos
)
return
self
.
model
.
compute_logits
(
hidden
)
@
torch
.
inference_mode
()
def
_decode_step_logits
(
self
,
decode_batch
:
DecodeBatchArguments
):
"""Graph decode when possible; otherwise eager (never raises on missing graph)."""
if
self
.
enforce_eager
or
not
self
.
captured_graphs
:
return
self
.
_decode_logits_eager
(
decode_batch
.
token_ids
,
decode_batch
.
positions
,
decode_batch
.
batch_mapping
,
)
try
:
return
self
.
run_graph_decode
(
decode_batch
.
token_ids
,
decode_batch
.
positions
,
decode_batch
.
batch_mapping
,
)
except
Exception
as
e
:
logger
.
warning
(
"CUDA graph decode failed (%s); switching to eager decode for "
"remaining steps."
,
e
,
)
self
.
enforce_eager
=
True
return
self
.
_decode_logits_eager
(
decode_batch
.
token_ids
,
decode_batch
.
positions
,
decode_batch
.
batch_mapping
,
)
@
torch
.
inference_mode
()
def
run_graph_decode
(
self
,
input_ids
:
torch
.
Tensor
,
positions
:
torch
.
Tensor
,
batch_mapping
:
torch
.
Tensor
,
):
bs
=
input_ids
.
shape
[
0
]
max_k
=
int
(
positions
.
max
())
graph_dict
=
self
.
get_cuda_graph
(
bs
,
max_k
)
if
graph_dict
is
None
:
return
self
.
_decode_logits_eager
(
input_ids
,
positions
,
batch_mapping
)
set_context
(
is_prefill
=
False
,
do_compression
=
False
,
batch_mapping
=
batch_mapping
,
attention_schedule
=
self
.
config
.
attention_schedule
,
)
graph_dict
[
"input_ids"
][:
bs
]
=
input_ids
graph_dict
[
"positions"
][:
bs
]
=
positions
graph_dict
[
"batch_mapping"
].
fill_
(
RESERVED_BATCH
)
graph_dict
[
"batch_mapping"
][:
bs
]
=
batch_mapping
graph_dict
[
"graph"
].
replay
()
logits_out
=
graph_dict
[
"logits"
]
return
logits_out
[:
bs
].
contiguous
()
@
torch
.
inference_mode
()
def
capture_cudagraph
(
self
,
batch_size
:
int
,
max_seqlen_k
:
int
):
barrier_sync
(
use_tp_group
=
self
.
embedded_in_vllm_worker
)
device
=
torch
.
device
(
"cuda"
)
logger
.
debug
(
f
"Capturing CUDA graph for batch size
{
batch_size
}
(
{
max_seqlen_k
}
tokens)"
)
_g_input_ids
=
torch
.
zeros
(
batch_size
,
dtype
=
torch
.
int32
,
device
=
device
)
_g_positions
=
torch
.
zeros
(
batch_size
,
dtype
=
torch
.
int64
,
device
=
device
)
_g_hidden
=
None
key_split
=
num_splits_heuristic
(
batch_size
*
self
.
kv_manager
.
num_kv_heads
,
max_seq_len
=
max_seqlen_k
,
num_sms
=
torch
.
cuda
.
get_device_properties
(
device
).
multi_processor_count
,
max_splits
=
12
,
)
success
,
_g_batch_mapping
=
self
.
kv_manager
.
allocate_sequences
(
list
(
range
(
batch_size
)),
[
256
]
*
batch_size
)
if
not
success
:
# Shared GPU with vLLM: compactor KV pool is small; large batch capture
# often cannot reserve [256]*batch_size per sequence. Skip this graph.
logger
.
warning
(
"Skipping CUDA graph capture for batch_size=%s max_seqlen_k=%s "
"(KV allocate_sequences failed; decode will use eager or other graphs)."
,
batch_size
,
max_seqlen_k
,
)
barrier_sync
(
use_tp_group
=
self
.
embedded_in_vllm_worker
)
return
set_context
(
is_prefill
=
False
,
do_compression
=
False
,
batch_mapping
=
_g_batch_mapping
,
key_split
=
key_split
,
attention_schedule
=
self
.
config
.
attention_schedule
,
)
_gw
=
self
.
model
(
_g_input_ids
,
_g_positions
)
self
.
model
.
compute_logits
(
_gw
)
barrier_sync
(
use_tp_group
=
self
.
embedded_in_vllm_worker
)
decode_graph
=
torch
.
cuda
.
CUDAGraph
()
with
torch
.
cuda
.
graph
(
decode_graph
):
_g_hidden
=
self
.
model
(
_g_input_ids
,
_g_positions
)
_g_logits
=
self
.
model
.
compute_logits
(
_g_hidden
)
graph_vars
=
{
"graph"
:
decode_graph
,
"input_ids"
:
_g_input_ids
,
"positions"
:
_g_positions
,
"batch_mapping"
:
_g_batch_mapping
,
"hidden"
:
_g_hidden
,
"logits"
:
_g_logits
,
"key_split"
:
key_split
,
}
if
batch_size
not
in
self
.
captured_graphs
:
self
.
captured_graphs
[
batch_size
]
=
{}
self
.
min_captured_len
[
batch_size
]
=
float
(
"inf"
)
self
.
captured_graphs
[
batch_size
][
max_seqlen_k
]
=
graph_vars
self
.
min_captured_len
[
batch_size
]
=
min
(
max_seqlen_k
,
self
.
min_captured_len
[
batch_size
]
)
self
.
kv_manager
.
free_sequences
(
list
(
range
(
batch_size
)))
def
get_cuda_graph
(
self
,
batch_size
:
int
,
max_seqlen_k
:
int
)
->
Optional
[
dict
[
str
,
Any
]]:
"""Return a captured graph dict, or None if no compatible capture exists."""
if
not
self
.
captured_graphs
:
return
None
eligible_bs
=
[
x
for
x
in
self
.
captured_graphs
.
keys
()
if
x
>=
batch_size
]
if
not
eligible_bs
:
return
None
bs_key
=
min
(
eligible_bs
)
batch_size_graphs
=
self
.
captured_graphs
[
bs_key
]
candidates
=
[
sl
for
sl
in
batch_size_graphs
.
keys
()
if
sl
<=
max_seqlen_k
]
if
not
candidates
:
return
None
best_sl
=
max
(
candidates
)
return
batch_size_graphs
[
best_sl
]
vllm/kvprune_legacy_save/core/runtime.py
0 → 100644
View file @
2b7160c6
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from
__future__
import
annotations
from
dataclasses
import
dataclass
import
torch
from
vllm.forward_context
import
get_forward_context
from
vllm.kvprune.core.compression_bridge
import
(
COMPRESSION_METHOD_ID_NONE
,
compression_method_str_to_id
,
)
@
dataclass
class
KVPruneForwardState
:
"""Per-forward-pass state for KV pruning (per-layer logical lengths)."""
active
:
bool
compression_ratio_gpu
:
torch
.
Tensor
"""[num_reqs_padded] ratio in (0,1], 1.0 means no pruning for that row."""
compression_method_id_gpu
:
torch
.
Tensor
"""[num_reqs_padded] int32 — see ``compression_bridge`` ids (0=none)."""
query_start_loc
:
torch
.
Tensor
"""[num_reqs_padded + 1] int32 on device."""
num_reqs
:
int
num_reqs_padded
:
int
num_layers
:
int
logical_seq_lens_gpu
:
torch
.
Tensor
"""Logical KV length per layer (and optionally per KV head).
Shape ``[num_layers, num_reqs_padded]`` or, when ``num_kv_heads > 1``,
``[num_layers, num_reqs_padded, num_kv_heads]`` for per-head lengths.
"""
is_prefill
:
bool
device
:
torch
.
device
def
logical_seq_lens_for_layer
(
self
,
layer_idx
:
int
)
->
torch
.
Tensor
:
sl
=
self
.
logical_seq_lens_gpu
[
layer_idx
]
if
sl
.
dim
()
==
2
:
return
sl
.
max
(
dim
=-
1
).
values
return
sl
def
build_kv_prune_forward_state
(
*
,
req_ids
:
list
[
str
],
requests
:
dict
[
str
,
object
],
query_start_loc
:
torch
.
Tensor
,
num_reqs
:
int
,
num_reqs_padded
:
int
,
num_layers
:
int
,
max_num_scheduled_tokens
:
int
,
device
:
torch
.
device
,
logical_seq_lens_gpu
:
torch
.
Tensor
,
)
->
KVPruneForwardState
|
None
:
"""Build pruning state when any request uses compression_ratio < 1.0."""
if
num_reqs
<=
0
or
num_layers
<=
0
:
return
None
ratios
=
[]
method_ids
:
list
[
int
]
=
[]
active_req
=
False
for
rid
in
req_ids
[:
num_reqs
]:
req
=
requests
.
get
(
rid
)
sp
=
getattr
(
req
,
"sampling_params"
,
None
)
if
req
is
not
None
else
None
r
=
1.0
if
sp
is
None
else
float
(
getattr
(
sp
,
"compression_ratio"
,
1.0
))
if
r
<
1.0
-
1e-6
:
active_req
=
True
ratios
.
append
(
r
)
if
sp
is
None
or
r
>=
1.0
-
1e-6
:
mid
=
COMPRESSION_METHOD_ID_NONE
else
:
cm
=
getattr
(
sp
,
"compression_method"
,
"none"
)
or
"none"
mid
=
compression_method_str_to_id
(
str
(
cm
))
method_ids
.
append
(
mid
)
if
not
active_req
:
return
None
compression_ratio_gpu
=
torch
.
ones
(
(
num_reqs_padded
,),
dtype
=
torch
.
float32
,
device
=
device
)
compression_ratio_gpu
[:
num_reqs
]
=
torch
.
tensor
(
ratios
,
dtype
=
torch
.
float32
,
device
=
device
)
compression_method_id_gpu
=
torch
.
zeros
(
(
num_reqs_padded
,),
dtype
=
torch
.
int32
,
device
=
device
)
compression_method_id_gpu
[:
num_reqs
]
=
torch
.
tensor
(
method_ids
,
dtype
=
torch
.
int32
,
device
=
device
)
is_prefill
=
max_num_scheduled_tokens
>
1
return
KVPruneForwardState
(
active
=
True
,
compression_ratio_gpu
=
compression_ratio_gpu
,
compression_method_id_gpu
=
compression_method_id_gpu
,
query_start_loc
=
query_start_loc
,
num_reqs
=
num_reqs
,
num_reqs_padded
=
num_reqs_padded
,
num_layers
=
num_layers
,
logical_seq_lens_gpu
=
logical_seq_lens_gpu
,
is_prefill
=
is_prefill
,
device
=
device
,
)
def
layer_index_from_layer_name
(
layer_name
:
str
)
->
int
:
from
vllm.model_executor.models.utils
import
extract_layer_index
return
extract_layer_index
(
layer_name
)
def
get_kv_prune_state
()
->
KVPruneForwardState
|
None
:
try
:
fc
=
get_forward_context
()
except
AssertionError
:
return
None
state
=
fc
.
additional_kwargs
.
get
(
"kv_prune"
)
if
state
is
None
or
not
isinstance
(
state
,
KVPruneForwardState
)
or
not
state
.
active
:
return
None
return
state
vllm/kvprune_legacy_save/core/scheduler.py
0 → 100644
View file @
2b7160c6
import
time
from
typing
import
Iterable
,
List
from
vllm.kvprune.core.memory_manager
import
KVCacheManager
from
vllm.kvprune.utils.sequence
import
Sequence
,
SequenceStatus
from
tqdm
import
tqdm
def
cdiv
(
a
,
b
):
"""ceiling division"""
return
(
a
+
b
-
1
)
//
b
class
Scheduler
:
"""
Simple sequence scheduler for prefill + decode with a paged KV cache.
The scheduler tracks three disjoint sets of sequence IDs:
* ``pending_sequence_ids`` – sequences that have not yet been started.
* ``active_sequence_ids`` – sequences currently running.
* ``finished_sequence_ids`` – sequences that have generated all tokens.
At prefill time, :meth:`get_prefill_batch` selects a subset of pending
sequences that can fit into the available KV cache and per-step token
budget, given the constraints from the associated :class:`KVCacheManager`.
The class also handles basic bookkeeping of sequence statuses.
Args:
:param all_sequences:
Iterable of :class:`Sequence` objects to be scheduled. Each
sequence must have a unique ``seq_id``.
:param kv_manager:
A :class:`KVCacheManager` instance that this scheduler will use
to determine whether additional batches can be scheduled.
:param use_tqdm:
If True, two progress bars are created:
* "Started Batches" – increments when a sequence moves from
pending to running.
* "Finished Batches" – increments when a sequence finishes.
"""
def
__init__
(
self
,
all_sequences
:
Iterable
[
Sequence
],
kv_manager
:
KVCacheManager
,
*
,
use_tqdm
=
False
,
):
self
.
allseq_mapping
:
dict
[
int
,
Sequence
]
=
{
s
.
seq_id
:
s
for
s
in
all_sequences
}
self
.
pending_sequence_ids
:
set
[
int
]
=
set
([
s
.
seq_id
for
s
in
all_sequences
])
self
.
active_sequence_ids
:
set
[
int
]
=
set
()
self
.
finished_sequence_ids
:
set
[
int
]
=
set
()
self
.
manager
=
kv_manager
self
.
use_tqdm
=
use_tqdm
self
.
start_time
=
time
.
perf_counter
()
self
.
total_tokens_generated
=
0
self
.
total_tokens_input
=
0
self
.
pbar
=
None
if
use_tqdm
:
self
.
pbar
=
tqdm
(
total
=
len
(
self
.
pending_sequence_ids
),
desc
=
"Completed Batches"
,
)
def
get_prefill_batch
(
self
)
->
List
[
Sequence
]:
"""
Select a batch of pending sequences to prefill under KV/memory constraints.
The selection is greedy over ``pending_sequence_ids`` in iteration order.
A sequence is added to the batch if:
* The sum of its prompt length and the total prompt tokens selected so
far does not exceed ``manager.max_batched_tokens``, and
* There is at least one free KV "batch slot" left
(``manager.num_free_batches``), and
* The total number of KV pages required by the sequence's prompt +
max_new_tokens does not exceed the remaining free pages.
Returns:
:return List[Sequence]:
The list of :class:`Sequence` objects chosen for prefill in
this step. The caller is responsible for marking them as
active via :meth:`add_running_sequence_ids`.
"""
total_tok
,
sequences
=
0
,
[]
num_free_batches
,
num_free_pages
=
(
self
.
manager
.
num_free_batches
,
self
.
manager
.
num_free_pages
,
)
for
seq_id
in
self
.
pending_sequence_ids
:
seq
=
self
.
allseq_mapping
[
seq_id
]
prompt_length
=
seq
.
prompt_len
pages_needed
=
(
cdiv
(
prompt_length
+
seq
.
sampling_params
.
max_new_tokens
,
self
.
manager
.
page_size
,
)
*
self
.
manager
.
num_kv_heads
)
if
(
prompt_length
+
total_tok
<=
self
.
manager
.
max_batched_tokens
and
num_free_batches
>
0
and
pages_needed
<=
num_free_pages
):
sequences
.
append
(
seq
)
total_tok
+=
prompt_length
num_free_pages
-=
pages_needed
num_free_batches
-=
1
return
sequences
def
diagnose_prefill_failure
(
self
)
->
str
:
"""Explain why :meth:`get_prefill_batch` may return empty (debugging)."""
num_free_batches
=
self
.
manager
.
num_free_batches
num_free_pages
=
self
.
manager
.
num_free_pages
parts
=
[
f
"num_free_batches=
{
num_free_batches
}
"
,
f
"num_free_pages=
{
num_free_pages
}
"
,
f
"num_pages_per_layer=
{
getattr
(
self
.
manager
,
'num_pages'
,
None
)
}
"
,
]
seq_id
=
next
(
iter
(
self
.
pending_sequence_ids
),
None
)
if
seq_id
is
None
:
return
"; "
.
join
(
parts
)
seq
=
self
.
allseq_mapping
[
seq_id
]
pl
=
seq
.
prompt_len
mn
=
seq
.
sampling_params
.
max_new_tokens
pages_needed
=
(
cdiv
(
pl
+
mn
,
self
.
manager
.
page_size
)
*
self
.
manager
.
num_kv_heads
)
parts
.
append
(
f
"first_pending seq_id=
{
seq_id
}
prompt_len=
{
pl
}
max_new_tokens=
{
mn
}
"
f
"pages_needed~=
{
pages_needed
}
"
)
if
num_free_batches
==
0
:
parts
.
append
(
"likely_cause=no free batch slots (compactor max_num_seqs exhausted)"
)
elif
pl
>
self
.
manager
.
max_batched_tokens
:
parts
.
append
(
f
"likely_cause=prompt_len (
{
pl
}
) > max_batched_tokens "
f
"(
{
self
.
manager
.
max_batched_tokens
}
)"
)
elif
pages_needed
>
num_free_pages
:
parts
.
append
(
"likely_cause=KV pool too small: pages_needed exceeds num_free_pages "
"(raise VLLM_KVPRUNE_COMPACTOR_KV_FREE_FRAC / lower v1 memory, or cap "
"compactor max_num_seqs to shrink page-table overhead)"
)
else
:
parts
.
append
(
"likely_cause=batched token sum or greedy order (another sequence may "
"block first in set iteration)"
)
return
"; "
.
join
(
parts
)
def
is_finished
(
self
)
->
bool
:
"""
Check whether all sequences have completed.
"""
return
(
len
(
self
.
pending_sequence_ids
)
==
0
and
len
(
self
.
active_sequence_ids
)
==
0
)
def
any_pending_sequences
(
self
)
->
bool
:
"""
Check whether any sequences are still pending (not yet started).
"""
return
len
(
self
.
pending_sequence_ids
)
!=
0
def
add_running_sequence_ids
(
self
,
active_sequence_ids
:
Iterable
[
int
],
*
,
update_status
:
bool
=
False
):
"""
Mark a set of sequences as active / running. This moves sequence IDs
from ``pending_sequence_ids`` into ``active_sequence_ids``. Optionally,
it also updates the per-sequence status and progress bar.
Args:
:param active_sequence_ids:
Iterable of sequence IDs that have been scheduled for prefill
or decode and should now be considered running.
:param update_status:
If True, set each corresponding :class:`Sequence`'s
``status = SequenceStatus.RUNNING`` and increment the
"Started Batches" progress bar if ``use_tqdm`` is enabled.
"""
self
.
active_sequence_ids
.
update
(
active_sequence_ids
)
self
.
pending_sequence_ids
.
difference_update
(
self
.
active_sequence_ids
)
if
update_status
:
for
seq_id
in
active_sequence_ids
:
self
.
allseq_mapping
[
seq_id
].
status
=
SequenceStatus
.
RUNNING
self
.
total_tokens_input
+=
self
.
allseq_mapping
[
seq_id
].
prompt_len
def
get_finished_sequence_ids_from_unfinished
(
self
,
unfinished_sequence_ids
:
Iterable
[
int
]
)
->
set
[
int
]:
"""
Infer which active sequences have finished given the
unfinished set (for decode steps where the caller knows
which sequences are still generating but not necessarily
which have just completed).
Args:
:param unfinished_sequence_ids:
Iterable of sequence IDs that are still running
Returns:
:return set[int]:
The inferred set of sequence IDs that transitioned from active
to finished.
"""
return
self
.
active_sequence_ids
.
difference
(
unfinished_sequence_ids
)
def
record_finished_sequence_ids
(
self
,
finished_sequence_ids
:
Iterable
[
int
],
*
,
update_status
:
bool
=
False
):
"""
Record that a set of sequences has finished generation.
This moves IDs from ``active_sequence_ids`` into
``finished_sequence_ids``.
Args:
:param finished_sequence_ids:
Iterable of sequence IDs that have completed generation and
no longer require KV cache.
:param update_status:
If True, set each corresponding :class:`Sequence`'s
``status = SequenceStatus.FINISHED``
"""
self
.
active_sequence_ids
.
difference_update
(
finished_sequence_ids
)
self
.
finished_sequence_ids
.
update
(
finished_sequence_ids
)
if
update_status
:
for
seq_id
in
finished_sequence_ids
:
self
.
allseq_mapping
[
seq_id
].
status
=
SequenceStatus
.
FINISHED
if
self
.
pbar
is
not
None
:
self
.
pbar
.
update
(
1
)
def
update_sequences
(
self
,
tokens
:
Iterable
[
int
],
seq_ids
:
Iterable
[
int
]):
"""
Append newly generated tokens to their corresponding sequences.
Args:
:param tokens:
Iterable of generated token IDs, one per sequence.
:param seq_ids:
Iterable of sequence IDs aligned with ``tokens``.
"""
cur_time
=
time
.
perf_counter
()
for
tok
,
seq_id
in
zip
(
tokens
,
seq_ids
):
self
.
allseq_mapping
[
seq_id
].
add_new_token
(
tok
)
self
.
total_tokens_generated
+=
1
if
self
.
pbar
is
not
None
:
self
.
pbar
.
set_description
(
f
"Throughput:
{
(
self
.
total_tokens_generated
+
self
.
total_tokens_input
)
/
(
cur_time
-
self
.
start_time
):.
2
f
}
tok/s"
)
def
close
(
self
):
if
self
.
pbar
is
not
None
:
self
.
pbar
.
close
()
def
can_prefill_another_batch
(
self
)
->
bool
:
return
len
(
self
.
get_prefill_batch
())
>
0
vllm/kvprune_legacy_save/integration/__init__.py
0 → 100644
View file @
2b7160c6
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""KV-pruning integration: compactor ``LLMEngine`` sharing weights with :class:`~vllm.LLM`."""
from
vllm.kvprune.integration.compression_params
import
CompressionParams
__all__
=
[
"CompressionParams"
]
vllm/kvprune_legacy_save/integration/compactor_shared.py
0 → 100644
View file @
2b7160c6
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""Construct compactor :class:`LLMEngine` sharing weight tensors with an in-process vLLM ``LLM``."""
from
__future__
import
annotations
import
os
import
torch.nn
as
nn
from
vllm.config
import
VllmConfig
from
vllm.kvprune.config.engine_config
import
LLMConfig
from
vllm.kvprune.core.llm_engine
import
LLMEngine
from
vllm.kvprune.integration.config_adapter
import
vllm_config_to_llm_config
from
vllm.kvprune.integration.vllm_model_access
import
extract_vllm_causal_lm
from
vllm.kvprune.integration.weight_tie
import
(
delegate_kvprune_compute_logits_to_vllm
,
delegate_kvprune_embed_tokens_to_vllm
,
tie_kvprune_rope_buffers_from_vllm
,
tie_kvprune_weights_from_vllm
,
)
from
vllm.kvprune.models
import
MODEL_REGISTRY
from
vllm.logger
import
init_logger
logger
=
init_logger
(
__name__
)
def
build_llm_config_for_compactor
(
vc
:
VllmConfig
)
->
LLMConfig
:
"""Public helper: vLLM config → compactor :class:`LLMConfig`."""
return
vllm_config_to_llm_config
(
vc
)
def
create_compactor_engine_with_shared_weights
(
llm
:
object
)
->
LLMEngine
:
"""Single GPU, TP=1: compactor ``LLMEngine`` whose weights alias vLLM tensors.
Call after the vLLM ``LLM`` has loaded weights. Requires in-process executor
(``VLLM_ENABLE_V1_MULTIPROCESSING=0``).
"""
llm_engine
=
getattr
(
llm
,
"llm_engine"
,
None
)
if
llm_engine
is
None
:
raise
RuntimeError
(
"Expected ``llm.llm_engine``."
)
vc
:
VllmConfig
=
llm_engine
.
vllm_config
if
vc
.
parallel_config
.
tensor_parallel_size
!=
1
:
raise
ValueError
(
"Shared-weight compactor backend requires tensor_parallel_size=1"
)
cfg
=
vllm_config_to_llm_config
(
vc
)
# ``cfg.enforce_eager`` is for the compactor ``ModelRunner`` only (decode CUDA
# graphs), not v1. v1 graph capture is controlled solely by ``LLM(...,
# enforce_eager=...)`` / ``kvprune_compression=True`` on the entrypoint ``LLM``.
# Large vLLM max_num_seqs blows up compactor page-table GPU memory; sharing the GPU
# with v1 leaves little room for metadata + KV tensors. Default cap 32 so physical
# KV pages stay usable; set VLLM_KVPRUNE_COMPACTOR_MAX_NUM_SEQS=0 to disable cap,
# or raise (e.g. 128) if you have VRAM headroom.
_cap
=
os
.
environ
.
get
(
"VLLM_KVPRUNE_COMPACTOR_MAX_NUM_SEQS"
,
"32"
).
strip
()
if
_cap
:
lim
=
int
(
_cap
)
if
lim
>
0
:
cfg
.
max_num_seqs
=
min
(
cfg
.
max_num_seqs
,
lim
)
# Compactor decode graphs (``enforce_eager=False``): honored for non-shared-weight
# engines. **Shared-weight** path (below) forces ``enforce_eager=True`` after
# delegating ``compute_logits`` to vLLM unless ``VLLM_KVPRUNE_SHARED_WEIGHT_GRAPH=1``.
# Opt out of graphs for non-shared runs: ``VLLM_KVPRUNE_COMPACTOR_ENFORCE_EAGER=1`` or
# ``VLLM_KVPRUNE_COMPACTOR_CUDA_GRAPH=0``.
_ce
=
os
.
environ
.
get
(
"VLLM_KVPRUNE_COMPACTOR_ENFORCE_EAGER"
,
""
).
strip
().
lower
()
if
_ce
in
(
"1"
,
"true"
,
"yes"
):
cfg
.
enforce_eager
=
True
logger
.
info
(
"KV-prune compactor: VLLM_KVPRUNE_COMPACTOR_ENFORCE_EAGER=1 → "
"enforce_eager=True (skip compactor decode CUDA graphs)."
)
elif
_ce
in
(
"0"
,
"false"
,
"no"
):
cfg
.
enforce_eager
=
False
logger
.
info
(
"KV-prune compactor: VLLM_KVPRUNE_COMPACTOR_ENFORCE_EAGER=0 → "
"enforce_eager=False (try compactor CUDA graph capture)."
)
else
:
_dg
=
os
.
environ
.
get
(
"VLLM_KVPRUNE_COMPACTOR_CUDA_GRAPH"
,
"1"
).
strip
().
lower
()
if
_dg
in
(
"0"
,
"false"
,
"no"
):
cfg
.
enforce_eager
=
True
logger
.
info
(
"KV-prune compactor: VLLM_KVPRUNE_COMPACTOR_CUDA_GRAPH=0 → "
"enforce_eager=True (skip compactor decode CUDA graphs)."
)
else
:
cfg
.
enforce_eager
=
False
logger
.
info
(
"KV-prune compactor: default try decode CUDA graphs; ModelRunner "
"falls back to eager if capture yields none. Set "
"VLLM_KVPRUNE_COMPACTOR_ENFORCE_EAGER=1 or "
"VLLM_KVPRUNE_COMPACTOR_CUDA_GRAPH=0 to skip capture."
)
hf
=
cfg
.
hf_config
assert
hf
is
not
None
model_type
=
hf
.
model_type
if
model_type
not
in
MODEL_REGISTRY
:
raise
ValueError
(
f
"Compactor MODEL_REGISTRY has no entry for model_type=
{
model_type
!
r
}
; "
f
"supported:
{
sorted
(
MODEL_REGISTRY
)
}
"
)
vllm_model
=
extract_vllm_causal_lm
(
llm
)
device
=
next
(
vllm_model
.
parameters
()).
device
dtype
=
next
(
vllm_model
.
parameters
()).
dtype
# Build compactor shell on CPU first. **Do not** call ``.to(device)`` before tying:
# that allocates a full second copy of weights on GPU; tying then frees the
# duplicate but peak memory can OOM on large models. Tie first so parameters
# alias vLLM tensors directly (no extra weight VRAM).
kv_model
:
nn
.
Module
=
MODEL_REGISTRY
[
model_type
](
hf
)
tie_kvprune_weights_from_vllm
(
vllm_model
,
kv_model
)
# Buffers (e.g. RoPE tables) not in ``named_parameters`` may still be on CPU.
kv_model
.
to
(
device
=
device
,
dtype
=
dtype
)
tie_kvprune_rope_buffers_from_vllm
(
vllm_model
,
kv_model
)
delegate_kvprune_embed_tokens_to_vllm
(
vllm_model
,
kv_model
)
delegate_kvprune_compute_logits_to_vllm
(
vllm_model
,
kv_model
)
# Compactor decode CUDA graphs capture ``model.forward`` + ``compute_logits`` in one
# graph. Here ``compute_logits`` is delegated to vLLM's LM head / LogitsProcessor
# (cublas GEMM, padded vocab, etc.). Embedding that in a nested capture commonly
# fails with ``CUBLAS_STATUS_EXECUTION_FAILED`` and invalidates stream capture
# (``cudaErrorStreamCaptureInvalidated``). Default: skip graphs for this integration.
_sw_graph
=
os
.
environ
.
get
(
"VLLM_KVPRUNE_SHARED_WEIGHT_GRAPH"
,
"0"
).
strip
().
lower
()
in
(
"1"
,
"true"
,
"yes"
)
if
not
_sw_graph
:
cfg
.
enforce_eager
=
True
logger
.
info
(
"KV-prune shared-weight compactor: enforce_eager=True (skip compactor "
"decode CUDA graphs; logits delegated to vLLM). Set "
"VLLM_KVPRUNE_SHARED_WEIGHT_GRAPH=1 only to attempt capture (often fails)."
)
return
LLMEngine
(
cfg
,
external_model
=
kv_model
)
vllm/kvprune_legacy_save/integration/compressed_generate.py
0 → 100644
View file @
2b7160c6
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""KV-pruning (compactor) path invoked from :meth:`vllm.entrypoints.llm.LLM.generate`."""
from
__future__
import
annotations
import
os
from
collections.abc
import
Callable
,
Sequence
from
pathlib
import
Path
from
typing
import
Any
from
tqdm.auto
import
tqdm
from
transformers
import
AutoTokenizer
from
vllm.kvprune.compression.compression_config
import
(
BatchCompressionParams
,
SequenceCompressionParams
,
)
from
vllm.kvprune.config.sampling_params
import
SamplingParams
as
CompactorSamplingParams
from
vllm.kvprune.core.compression_bridge
import
(
compression_method_id_to_enum
,
compression_method_str_to_id
,
)
from
vllm.kvprune.core.llm_engine
import
LLMEngine
,
_infer_stop_token_ids
from
vllm.kvprune.integration.compactor_shared
import
create_compactor_engine_with_shared_weights
from
vllm.kvprune.integration.compression_params
import
CompressionParams
from
vllm.logger
import
init_logger
from
vllm.outputs
import
CompletionOutput
,
RequestOutput
from
vllm.sampling_params
import
SamplingParams
logger
=
init_logger
(
__name__
)
_MP_ENV
=
"VLLM_ENABLE_V1_MULTIPROCESSING"
_RELEASE_V1_KV_ENV
=
"VLLM_KVPRUNE_RELEASE_V1_KV"
def
_maybe_release_v1_kv_for_compactor
(
llm
:
Any
)
->
None
:
"""Optionally discard v1's KV cache so more GPU memory is free for compactor.
v1 reserves KV blocks at engine init; shared-weight compactor then competes for
the same VRAM. ``sleep(level=1)`` discards v1 KV and may offload tagged weights
per v1 sleep policy, then ``wake_up()`` reloads — compactor still ties the same
v1 tensors after.
**Default:** ``vllm.env_override`` sets ``VLLM_KVPRUNE_RELEASE_V1_KV=0`` (no
sleep/wake; v1 KV stays on GPU). Set ``=1`` if you need extra VRAM for compactor
before the first compressed step (then ``llm.sleep`` / ``CuMemAllocator`` /
``Sleep mode freed …`` logs are expected). This does **not** remove v1's KV
reservation at init; it only runs the optional sleep/wake cycle before compactor.
Tests keep ``VLLM_KVPRUNE_RELEASE_V1_KV=0`` in ``conftest``.
"""
if
os
.
environ
.
get
(
_RELEASE_V1_KV_ENV
,
"0"
).
strip
().
lower
()
not
in
(
"1"
,
"true"
,
"yes"
,
):
return
try
:
logger
.
info
(
"%s=1: discarding v1 KV via sleep(level=1) then wake_up() "
"(reloads model weights to GPU)."
,
_RELEASE_V1_KV_ENV
,
)
llm
.
sleep
(
level
=
1
,
mode
=
"abort"
)
llm
.
wake_up
()
except
Exception
as
e
:
logger
.
warning
(
"%s: sleep/wake failed: %s"
,
_RELEASE_V1_KV_ENV
,
e
)
def
ensure_inprocess_engine_for_weight_sharing
()
->
None
:
"""Compactor must see ``worker.get_model()`` in the same process as vLLM."""
if
os
.
environ
.
get
(
_MP_ENV
,
"1"
)
!=
"0"
:
os
.
environ
[
_MP_ENV
]
=
"0"
logger
.
info
(
"KV cache pruning: set %s=0 so the model stays in-process for "
"shared-weight compactor (no manual env needed)."
,
_MP_ENV
,
)
def
_normalize_prompt_list
(
prompts
:
Any
)
->
list
[
Any
]:
if
isinstance
(
prompts
,
str
):
return
[
prompts
]
if
isinstance
(
prompts
,
dict
):
return
[
prompts
]
return
list
(
prompts
)
def
_normalize_sampling_params
(
sampling_params
:
SamplingParams
|
Sequence
[
SamplingParams
]
|
None
,
n
:
int
,
)
->
list
[
SamplingParams
]:
if
sampling_params
is
None
:
return
[
SamplingParams
()
for
_
in
range
(
n
)]
if
isinstance
(
sampling_params
,
SamplingParams
):
return
[
sampling_params
]
*
n
sps
=
list
(
sampling_params
)
if
len
(
sps
)
!=
n
:
raise
ValueError
(
f
"sampling_params length
{
len
(
sps
)
}
!= prompts length
{
n
}
"
)
return
sps
def
_normalize_compression_params
(
compression
:
CompressionParams
|
Sequence
[
CompressionParams
]
|
None
,
n
:
int
,
)
->
list
[
CompressionParams
]:
if
compression
is
None
:
return
[
CompressionParams
(
compression_ratio
=
1.0
)
for
_
in
range
(
n
)]
if
isinstance
(
compression
,
CompressionParams
):
return
[
compression
]
*
n
comp
=
list
(
compression
)
if
len
(
comp
)
!=
n
:
raise
ValueError
(
f
"compression length
{
len
(
comp
)
}
!= prompts length
{
n
}
"
)
return
comp
def
_any_compactor
(
comps
:
list
[
CompressionParams
])
->
bool
:
return
any
(
c
.
compression_ratio
<
1.0
for
c
in
comps
)
_FORCE_COMPACTOR_PATH_ENV
=
"VLLM_KVPRUNE_FORCE_COMPACTOR_PATH"
def
_should_use_kvprune_compactor_path
(
comps
:
list
[
CompressionParams
])
->
bool
:
"""Use integrated compactor when any prompt requests compression, or when forced.
If all ``compression_ratio >= 1.0``, the default is to return ``None`` from
:func:`try_compressed_generate` and fall back to the standard v1 engine
(``Processed prompts`` loop). That hides TP/kvprune bugs behind a different
code path. Set ``VLLM_KVPRUNE_FORCE_COMPACTOR_PATH=1`` to run the same
compactor + collective RPC path as compression-on, with no KV pruning.
"""
if
_any_compactor
(
comps
):
return
True
return
os
.
environ
.
get
(
_FORCE_COMPACTOR_PATH_ENV
,
""
).
strip
().
lower
()
in
(
"1"
,
"true"
,
"yes"
,
)
def
_to_compactor_sampling
(
sp
:
SamplingParams
)
->
CompactorSamplingParams
:
mt
=
sp
.
max_tokens
if
mt
is
None
:
mt
=
16
return
CompactorSamplingParams
(
temperature
=
float
(
sp
.
temperature
),
max_new_tokens
=
int
(
mt
),
)
def
_to_sequence_compression
(
cp
:
CompressionParams
)
->
SequenceCompressionParams
:
return
SequenceCompressionParams
(
compression_ratio
=
float
(
cp
.
compression_ratio
),
protected_first_tokens
=
int
(
cp
.
protected_first_tokens
),
protected_last_tokens
=
int
(
cp
.
protected_last_tokens
),
)
def
_batch_compression_from_comps
(
comps
:
list
[
CompressionParams
])
->
BatchCompressionParams
:
for
c
in
comps
:
if
c
.
compression_ratio
<
1.0
:
mid
=
compression_method_str_to_id
(
c
.
compression_method
)
return
BatchCompressionParams
(
compression_method
=
compression_method_id_to_enum
(
mid
)
)
return
BatchCompressionParams
()
def
_kvprune_compactor_hf_tokenizer
(
llm
:
Any
):
"""HF tokenizer matching :meth:`vllm.kvprune.core.llm_engine.LLMEngine.__init__`.
Loads from the **resolved on-disk** model tree (local dir or HF cache snapshot), not
the bare repo id, to avoid redundant Hub downloads.
"""
cached
=
getattr
(
llm
,
"_kvprune_compactor_hf_tokenizer"
,
None
)
if
cached
is
not
None
:
return
cached
mc
=
llm
.
llm_engine
.
vllm_config
.
model_config
model_s
=
str
(
mc
.
model
)
src
=
model_s
try
:
p
=
Path
(
model_s
)
if
p
.
is_dir
()
and
(
p
/
"config.json"
).
is_file
():
src
=
str
(
p
.
resolve
())
else
:
from
huggingface_hub
import
snapshot_download
src
=
snapshot_download
(
repo_id
=
model_s
,
local_files_only
=
False
)
except
Exception
:
src
=
model_s
hf_cfg
=
mc
.
hf_config
_trust
=
bool
(
getattr
(
hf_cfg
,
"trust_remote_code"
,
False
))
if
hf_cfg
is
not
None
else
False
tok
=
AutoTokenizer
.
from_pretrained
(
src
,
use_fast
=
True
,
trust_remote_code
=
_trust
)
llm
.
_kvprune_compactor_hf_tokenizer
=
tok
return
tok
def
_prompt_to_compactor_input
(
prompt
:
Any
)
->
str
|
list
[
int
]:
if
isinstance
(
prompt
,
str
):
return
prompt
# Decoder-only `list[int]` token ids (see `vllm.inputs.PromptType`).
if
isinstance
(
prompt
,
list
):
if
not
prompt
:
raise
TypeError
(
"Empty token-id prompt is not supported for compactor path."
)
if
all
(
isinstance
(
t
,
int
)
for
t
in
prompt
):
return
list
(
prompt
)
if
isinstance
(
prompt
,
dict
):
if
"prompt_token_ids"
in
prompt
:
ids
=
prompt
[
"prompt_token_ids"
]
return
list
(
ids
)
if
not
isinstance
(
ids
,
list
)
else
ids
p
=
prompt
.
get
(
"prompt"
)
if
isinstance
(
p
,
str
):
return
p
raise
TypeError
(
f
"Unsupported prompt type for compactor path:
{
type
(
prompt
)
}
. "
"Use str, list[int] token ids, or dict with 'prompt_token_ids' or 'prompt'."
)
def
_prompt_to_token_ids_for_tp
(
llm
:
Any
,
prompt
:
Any
)
->
list
[
int
]:
"""Driver-side token ids for the TP collective path (same tokenizer as vLLM ``LLM``)."""
comp_in
=
_prompt_to_compactor_input
(
prompt
)
if
isinstance
(
comp_in
,
str
):
return
llm
.
get_tokenizer
().
encode
(
comp_in
)
return
list
(
comp_in
)
def
_compressed_generate_tp_collective
(
llm
:
Any
,
plist
:
list
[
Any
],
sps
:
list
[
SamplingParams
],
comps
:
list
[
CompressionParams
],
)
->
list
[
RequestOutput
]:
"""TP>1: run compactor on each worker via ``collective_rpc`` (all ranks)."""
vc
=
llm
.
llm_engine
.
vllm_config
pc
=
vc
.
parallel_config
if
pc
.
pipeline_parallel_size
!=
1
or
pc
.
data_parallel_size
!=
1
:
raise
NotImplementedError
(
"KV-prune TP compression requires pipeline_parallel_size=1 and "
f
"data_parallel_size=1 (got PP=
{
pc
.
pipeline_parallel_size
}
, "
f
"DP=
{
pc
.
data_parallel_size
}
)."
)
hf
=
vc
.
model_config
.
hf_config
tok
=
llm
.
get_tokenizer
()
eos_token_ids
=
_infer_stop_token_ids
(
tok
,
hf
)
prompt_token_ids
=
[
_prompt_to_token_ids_for_tp
(
llm
,
p
)
for
p
in
plist
]
max_len
=
int
(
vc
.
model_config
.
max_model_len
)
for
i
,
ids
in
enumerate
(
prompt_token_ids
):
if
len
(
ids
)
>
max_len
:
raise
ValueError
(
f
"KV-prune TP compressed generate: prompt
{
i
}
length
{
len
(
ids
)
}
"
f
"exceeds max_model_len (
{
max_len
}
). Shorten the prompt or raise "
"max_model_len when constructing LLM()."
)
# Payload must be picklable for multiproc/Ray RPC: do not pass multiprocessing
# synchronization primitives (workers are separate processes).
payload
:
dict
[
str
,
Any
]
=
{
"eos_token_ids"
:
eos_token_ids
,
"prompt_token_ids"
:
prompt_token_ids
,
"sampling_params"
:
[
{
"temperature"
:
float
(
sp
.
temperature
),
"max_new_tokens"
:
int
(
sp
.
max_tokens
if
sp
.
max_tokens
is
not
None
else
16
),
}
for
sp
in
sps
],
"compression_params"
:
[
{
"compression_ratio"
:
float
(
c
.
compression_ratio
),
"compression_method"
:
str
(
c
.
compression_method
),
"protected_first_tokens"
:
int
(
c
.
protected_first_tokens
),
"protected_last_tokens"
:
int
(
c
.
protected_last_tokens
),
}
for
c
in
comps
],
}
_maybe_release_v1_kv_for_compactor
(
llm
)
try
:
results
=
llm
.
llm_engine
.
collective_rpc
(
"kvprune_v1_compressed_generate"
,
args
=
(
payload
,),
)
except
RuntimeError
as
e
:
if
"cancelled"
in
str
(
e
).
lower
():
raise
RuntimeError
(
"collective_rpc was cancelled (a GPU worker likely crashed). "
"Scroll up for the first worker traceback — often NCCL/CUDA before "
"TCPStore/Broken pipe on the driver."
)
from
e
raise
master
:
dict
[
str
,
Any
]
|
None
=
None
for
r
in
results
:
if
isinstance
(
r
,
dict
)
and
r
.
get
(
"tensor_parallel_rank"
)
==
0
:
master
=
r
break
if
master
is
None
:
raise
RuntimeError
(
"collective_rpc did not return a dict from tensor parallel rank 0."
)
return
_tp_payload_to_request_outputs
(
llm
,
master
)
def
_tp_payload_to_request_outputs
(
llm
:
Any
,
master
:
dict
[
str
,
Any
])
->
list
[
RequestOutput
]:
tok
=
llm
.
get_tokenizer
()
out
:
list
[
RequestOutput
]
=
[]
pids_list
=
master
[
"prompt_token_ids"
]
cids_list
=
master
[
"completion_token_ids"
]
for
i
,
(
pids
,
cids
)
in
enumerate
(
zip
(
pids_list
,
cids_list
)):
text
=
tok
.
decode
(
cids
,
skip_special_tokens
=
True
)
co
=
CompletionOutput
(
index
=
0
,
text
=
text
,
token_ids
=
list
(
cids
),
cumulative_logprob
=
None
,
logprobs
=
None
,
finish_reason
=
"stop"
,
)
ro
=
RequestOutput
(
request_id
=
f
"kvprune-tp-
{
i
}
"
,
prompt
=
None
,
prompt_token_ids
=
list
(
pids
),
prompt_logprobs
=
None
,
outputs
=
[
co
],
finished
=
True
,
)
out
.
append
(
ro
)
return
out
def
_ensure_compactor_engine
(
llm
:
Any
)
->
LLMEngine
:
if
llm
.
_kvprune_compactor_engine
is
None
:
pc
=
llm
.
llm_engine
.
vllm_config
.
parallel_config
if
pc
.
tensor_parallel_size
!=
1
:
raise
ValueError
(
"KV-pruning compactor path requires tensor_parallel_size=1 "
"for shared weights."
)
llm
.
_kvprune_compactor_engine
=
create_compactor_engine_with_shared_weights
(
llm
)
logger
.
info
(
"Initialized compactor LLMEngine with weights shared from vLLM."
)
return
llm
.
_kvprune_compactor_engine
def
try_compressed_generate
(
llm
:
Any
,
prompts
:
Any
,
sampling_params
:
SamplingParams
|
Sequence
[
SamplingParams
]
|
None
,
*
,
compression
:
CompressionParams
|
Sequence
[
CompressionParams
]
|
None
,
use_tqdm
:
bool
|
Callable
[...,
tqdm
]
=
True
,
lora_request
:
Any
=
None
,
priority
:
list
[
int
]
|
None
=
None
,
tokenization_kwargs
:
dict
[
str
,
Any
]
|
None
=
None
,
)
->
list
[
RequestOutput
]
|
None
:
"""Return completions on the compactor engine, or ``None`` to use normal v1.
``lora_request`` / ``priority`` / ``tokenization_kwargs`` are accepted for API
parity with :meth:`~vllm.entrypoints.llm.LLM.generate` but are not passed to the
compactor engine yet.
"""
del
lora_request
,
priority
,
tokenization_kwargs
,
use_tqdm
plist
=
_normalize_prompt_list
(
prompts
)
sps
=
_normalize_sampling_params
(
sampling_params
,
len
(
plist
))
comps
=
_normalize_compression_params
(
compression
,
len
(
plist
))
pc
=
llm
.
llm_engine
.
vllm_config
.
parallel_config
# TP>1: every worker must run the same collective_rpc session. If all
# compression_ratio >= 1, the old code returned None and only the driver ran
# v1 _run_engine — other ranks never joined a matching collective, which can
# deadlock NCCL / leave workers unsynchronized (hang at "Processed prompts:").
if
pc
.
tensor_parallel_size
>
1
:
if
not
_should_use_kvprune_compactor_path
(
comps
):
comps
=
[
CompressionParams
(
compression_ratio
=
1.0
)
for
_
in
plist
]
elif
not
_should_use_kvprune_compactor_path
(
comps
):
return
None
v1_eager
=
bool
(
getattr
(
llm
.
llm_engine
.
vllm_config
.
model_config
,
"enforce_eager"
,
False
)
)
if
not
v1_eager
:
logger
.
warning
(
"KV-prune compression: v1 CUDA graphs are still enabled on this LLM. "
"The compactor does not reuse v1 graphs; capture wastes VRAM. "
"Set kvprune_compression=True, enforce_eager=True, or "
"VLLM_KVPRUNE_COMPRESSION_DEFAULT=1 before import vllm."
)
if
pc
.
tensor_parallel_size
>
1
:
return
_compressed_generate_tp_collective
(
llm
,
plist
,
sps
,
comps
)
ensure_inprocess_engine_for_weight_sharing
()
if
llm
.
_kvprune_compactor_engine
is
None
:
_maybe_release_v1_kv_for_compactor
(
llm
)
engine
=
_ensure_compactor_engine
(
llm
)
comp_sp
=
[
_to_compactor_sampling
(
sp
)
for
sp
in
sps
]
seq_c
=
[
_to_sequence_compression
(
c
)
for
c
in
comps
]
batch_c
=
_batch_compression_from_comps
(
comps
)
comp_in
=
[
_prompt_to_compactor_input
(
p
)
for
p
in
plist
]
_
,
seqs
=
engine
.
generate
(
comp_in
,
sampling_params
=
comp_sp
,
batch_compression_params
=
batch_c
,
per_sequence_compression_params
=
seq_c
,
return_sequences
=
True
,
)
return
_sequences_to_request_outputs
(
seqs
,
engine
)
def
_sequences_to_request_outputs
(
seqs
:
list
[
Any
],
engine
:
LLMEngine
)
->
list
[
RequestOutput
]:
tok
=
engine
.
tokenizer
out
:
list
[
RequestOutput
]
=
[]
for
i
,
seq
in
enumerate
(
seqs
):
text
=
tok
.
decode
(
seq
.
completion_token_ids
,
skip_special_tokens
=
True
)
# If every emitted id is “special” (e.g. EOS / chat boundary), the stripped
# string is empty while ``completion_token_ids`` is non-empty — avoid
# presenting a blank answer so users can see boundary tokens / debug.
if
not
text
.
strip
()
and
seq
.
completion_token_ids
:
text
=
tok
.
decode
(
seq
.
completion_token_ids
,
skip_special_tokens
=
False
)
co
=
CompletionOutput
(
index
=
0
,
text
=
text
,
token_ids
=
list
(
seq
.
completion_token_ids
),
cumulative_logprob
=
None
,
logprobs
=
None
,
finish_reason
=
"stop"
,
)
ro
=
RequestOutput
(
request_id
=
f
"kvprune-
{
i
}
"
,
prompt
=
None
,
prompt_token_ids
=
list
(
seq
.
prompt_token_ids
),
prompt_logprobs
=
None
,
outputs
=
[
co
],
finished
=
True
,
)
out
.
append
(
ro
)
return
out
vllm/kvprune_legacy_save/integration/compression_params.py
0 → 100644
View file @
2b7160c6
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""Per-request KV compression for :meth:`vllm.LLM.generate` (``compression=`` kwarg)."""
from
__future__
import
annotations
from
dataclasses
import
dataclass
@
dataclass
class
CompressionParams
:
"""Per-prompt compression intent for :meth:`vllm.LLM.generate`.
If **any** prompt in the batch has ``compression_ratio < 1.0``, the **whole** batch
is run on the compactor ``LLMEngine`` (same stack as standalone compactor-vllm:
``PagedKVCache`` + pruning kernels). If all prompts have ``compression_ratio >= 1.0``,
the batch stays on standard vLLM.
``compression_method`` follows :mod:`vllm.kvprune.core.compression_bridge` aliases:
``none``, ``criticaladakv``, ``compactor``, ``snapkv`` (ignored when
``compression_ratio`` is effectively 1).
``protected_*`` map to compactor :class:`~vllm.kvprune.compression.compression_config.SequenceCompressionParams`
(defaults match standalone compactor-vllm-style usage).
"""
compression_ratio
:
float
=
1.0
compression_method
:
str
=
"compactor"
protected_first_tokens
:
int
=
16
protected_last_tokens
:
int
=
64
def
__post_init__
(
self
)
->
None
:
if
not
0.0
<
self
.
compression_ratio
<=
1.0
:
raise
ValueError
(
f
"compression_ratio must be in (0, 1], got
{
self
.
compression_ratio
}
"
)
self
.
compression_method
=
(
self
.
compression_method
or
"compactor"
).
strip
().
lower
()
from
vllm.kvprune.core.compression_bridge
import
VALID_ALIASES_FOR_SAMPLING
if
self
.
compression_method
not
in
VALID_ALIASES_FOR_SAMPLING
:
raise
ValueError
(
f
"compression_method must be one of
{
sorted
(
VALID_ALIASES_FOR_SAMPLING
)
}
, "
f
"got
{
self
.
compression_method
!
r
}
"
)
if
self
.
compression_ratio
>=
1.0
-
1e-9
:
self
.
compression_method
=
"none"
elif
self
.
compression_method
==
"none"
:
raise
ValueError
(
"When compression_ratio < 1.0, compression_method cannot be 'none'."
)
Prev
1
…
11
12
13
14
15
16
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment