Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
f81ce56b
"tests/vscode:/vscode.git/clone" did not exist on "11ef7a611ec015523301930a25422cf68216b5c4"
Commit
f81ce56b
authored
Apr 23, 2026
by
chenzk
Browse files
vllm kvprune:v1.0.1
parent
2b7160c6
Changes
237
Hide whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
0 additions
and
5852 deletions
+0
-5852
vllm/compactor-vllm/src/compactor_vllm/attention/sparse_varlen_kernel.py
...vllm/src/compactor_vllm/attention/sparse_varlen_kernel.py
+0
-526
vllm/compactor-vllm/src/compactor_vllm/benchmark/__init__.py
vllm/compactor-vllm/src/compactor_vllm/benchmark/__init__.py
+0
-0
vllm/compactor-vllm/src/compactor_vllm/compression/__init__.py
...compactor-vllm/src/compactor_vllm/compression/__init__.py
+0
-41
vllm/compactor-vllm/src/compactor_vllm/compression/common.py
vllm/compactor-vllm/src/compactor_vllm/compression/common.py
+0
-243
vllm/compactor-vllm/src/compactor_vllm/compression/compactor.py
...ompactor-vllm/src/compactor_vllm/compression/compactor.py
+0
-704
vllm/compactor-vllm/src/compactor_vllm/compression/compactor_origin.py
...r-vllm/src/compactor_vllm/compression/compactor_origin.py
+0
-600
vllm/compactor-vllm/src/compactor_vllm/compression/compression_config.py
...vllm/src/compactor_vllm/compression/compression_config.py
+0
-45
vllm/compactor-vllm/src/compactor_vllm/compression/criticalkv-cursor.py
...-vllm/src/compactor_vllm/compression/criticalkv-cursor.py
+0
-459
vllm/compactor-vllm/src/compactor_vllm/compression/criticalkv.py
...mpactor-vllm/src/compactor_vllm/compression/criticalkv.py
+0
-451
vllm/compactor-vllm/src/compactor_vllm/compression/criticalkv_origin.py
...-vllm/src/compactor_vllm/compression/criticalkv_origin.py
+0
-502
vllm/compactor-vllm/src/compactor_vllm/compression/snapkv.py
vllm/compactor-vllm/src/compactor_vllm/compression/snapkv.py
+0
-546
vllm/compactor-vllm/src/compactor_vllm/compression/snapkv_origin.py
...ctor-vllm/src/compactor_vllm/compression/snapkv_origin.py
+0
-449
vllm/compactor-vllm/src/compactor_vllm/config/__init__.py
vllm/compactor-vllm/src/compactor_vllm/config/__init__.py
+0
-0
vllm/compactor-vllm/src/compactor_vllm/config/constants.py
vllm/compactor-vllm/src/compactor_vllm/config/constants.py
+0
-5
vllm/compactor-vllm/src/compactor_vllm/config/engine_config.py
...compactor-vllm/src/compactor_vllm/config/engine_config.py
+0
-100
vllm/compactor-vllm/src/compactor_vllm/config/sampling_params.py
...mpactor-vllm/src/compactor_vllm/config/sampling_params.py
+0
-11
vllm/compactor-vllm/src/compactor_vllm/core/__init__.py
vllm/compactor-vllm/src/compactor_vllm/core/__init__.py
+0
-0
vllm/compactor-vllm/src/compactor_vllm/core/llm_engine.py
vllm/compactor-vllm/src/compactor_vllm/core/llm_engine.py
+0
-404
vllm/compactor-vllm/src/compactor_vllm/core/memory_manager.py
.../compactor-vllm/src/compactor_vllm/core/memory_manager.py
+0
-182
vllm/compactor-vllm/src/compactor_vllm/core/model_runner.py
vllm/compactor-vllm/src/compactor_vllm/core/model_runner.py
+0
-584
No files found.
vllm/compactor-vllm/src/compactor_vllm/attention/sparse_varlen_kernel.py
deleted
100644 → 0
View file @
2b7160c6
import
logging
import
math
import
torch
import
triton
import
triton.language
as
tl
from
compactor_vllm.utils.triton_compat
import
(
autotune
as
triton_autotune
,
cuda_capability_geq
,
maybe_set_allocator
,
)
logger
=
logging
.
getLogger
(
__name__
)
def causal_sparse_varlen_with_cache(
    q,
    k,
    v,
    k_cache,
    v_cache,
    seq_lens_bh,
    global_page_table,
    batch_mapping,
    cu_seqlens_q,
    max_seqlen_q: int,
    max_seqlen_k_cache: int,
    HKV: int,
    PAGE_SIZE: int,
    sm_scale=None,
):
    """
    Causal prefill attention over a paged KV cache plus a block of newly
    appended tokens in a packed batch format.

    This function wraps the Triton kernel
    ``_causal_head_sparse_varlen_with_cache`` to compute prefill attention for
    a batch of variable-length sequences, where:

      • Past keys/values are stored in a paged global KV cache
        (``k_cache``, ``v_cache``) with a (per-layer) page table.
      • New tokens for this step are given as K/V blocks
        (``k``, ``v``), together with a packed query block ``q``.
      • The result is equivalent to applying causal attention over the
        concatenation of:
            [ cached KV prefix || (K_app, V_app) for this step ]
        for each sequence in the batch.

    Grouped-query attention (GQA / MQA) is supported by allowing more query
    heads than KV heads: ``HQ`` must be divisible by ``HKV``.

    Args:
        :param q:
            Query tensor of shape ``[N, HQ, D]`` (float16 / bfloat16 / float32).
            ``N`` is the total number of new tokens across the batch
            (i.e. ``N = sum_b seqlen_q[b]``), packed according to
            ``cu_seqlens_q``. ``HQ`` is the number of query heads, ``D`` the
            head dimension (must be a power of two).
        :param k:
            New key tensor of shape ``[N, HKV, D]`` for the same tokens as
            ``q``. These are the K values appended to the cache for this
            prefill step.
        :param v:
            New value tensor of shape ``[N, HKV, D]`` for the same tokens as
            ``q``.
        :param k_cache:
            Global key cache backing buffer of shape ``[CACHE_SIZE, D]``.
            Keys for all cached tokens and heads are stored here; the mapping
            from (batch, head, token index) to a row in this buffer is
            given by ``global_page_table``.
        :param v_cache:
            Global value cache of shape ``[CACHE_SIZE, D]``. Must have the
            same layout as ``k_cache`` (same ``CACHE_SIZE`` and ``D``).
        :param seq_lens_bh:
            Tensor of shape ``[B, HKV]`` (cast to int32 here) giving, for each
            local batch index and KV head, the number of cached tokens already
            present in the paged KV cache before this prefill step.
        :param global_page_table:
            Tensor of shape ``[MAX_NUM_BATCHES, HKV, N_LOGICAL_PAGES_MAX]`` (int32)
            mapping ``(true_batch_idx, kv_head, logical_page)`` to a physical
            page id in the global KV cache. A physical page id ``p`` refers to
            the slice ``k_cache[p * PAGE_SIZE : (p + 1) * PAGE_SIZE]``.
        :param batch_mapping:
            Tensor of shape ``[B]`` (int16 / int32) mapping the local batch
            index used in this kernel launch to the global batch index used
            to index ``global_page_table``. int16 input is passed through
            unchanged; any other dtype is converted to int32.
        :param cu_seqlens_q:
            Tensor of shape ``[B + 1]`` (cast to int32 here) with cumulative
            sequence lengths for the *new* tokens (q/k/v) in packed form. For
            batch element ``b``:
            ``seqlen_q[b] = cu_seqlens_q[b + 1] - cu_seqlens_q[b]``.
            The total number of tokens satisfies ``N = cu_seqlens_q[-1]``.
        :param max_seqlen_q:
            Maximum new query sequence length across the batch, i.e.
            ``max_b seqlen_q[b]``. Determines the number of query tiles launched.
        :param max_seqlen_k_cache:
            Maximum cached sequence length across (batch, KV head), i.e.
            ``max_{b,h} seq_lens_bh[b, h]``. Only used as an autotune key.
        :param HKV:
            Number of KV heads. Must divide ``HQ``.
        :param PAGE_SIZE:
            Number of tokens stored per physical page in the paged KV cache.
            ``CACHE_SIZE`` must be divisible by ``PAGE_SIZE``.
        :param sm_scale:
            Optional scaling factor applied to the attention logits before
            softmax. If ``None``, defaults to ``1.0 / sqrt(D)``.

    :returns torch.Tensor:
        Attention output of shape ``[N, HQ, D]``, with the same dtype and
        device as ``q``. The output is laid out in the same packed
        varlen format as the input queries, i.e. the first
        ``seqlen_q[0]`` rows correspond to batch 0, the next
        ``seqlen_q[1]`` rows to batch 1, etc.

    :raises AssertionError:
        If any of the shape / divisibility / contiguity preconditions above
        is violated.
    """
    assert q.ndim == 3, "q should be [N, HQ, D]"
    N, HQ, D = q.shape
    assert (D & (D - 1)) == 0, "D must be power of two"
    B = cu_seqlens_q.numel() - 1
    assert B > 0
    # The check is "HKV divides HQ"; the previous message stated it backwards.
    assert HQ % HKV == 0, (
        "Number of KV heads (HKV) must divide number of query heads (HQ)"
    )
    H_g = HQ // HKV

    # Allocate the output in the caller-visible [N, HQ, D] layout, then view
    # both Q and OUT as [HKV, N, QUERY_GROUP_SIZE, D] (strided, no copy).
    out = torch.empty_like(q)
    q = q.view(N, HKV, H_g, D).permute(1, 0, 2, 3)
    out = out.view(N, HKV, H_g, D).permute(1, 0, 2, 3)

    # K_app/V_app: [N, HKV, D] -> [HKV, N, D]
    k_app = k.view(N, HKV, D).permute(1, 0, 2)
    v_app = v.view(N, HKV, D).permute(1, 0, 2)

    cu_seqlens_q = cu_seqlens_q.to(dtype=torch.int32, device=q.device)
    seq_lens_bh = seq_lens_bh.to(dtype=torch.int32, device=q.device)
    # Keep int16 mappings untouched (no extra copy for existing callers), but
    # widen every other dtype to int32 instead of narrowing it to int16, which
    # would silently wrap global batch indices above 32767.
    mapping_dtype = torch.int16 if batch_mapping.dtype == torch.int16 else torch.int32
    batch_mapping = batch_mapping.to(dtype=mapping_dtype, device=q.device)

    N_LOGICAL_PAGES_MAX = global_page_table.shape[-1]
    CACHE_SIZE = k_cache.shape[0]
    assert v_cache.shape[0] == CACHE_SIZE
    assert k_cache.shape[1] == D and v_cache.shape[1] == D
    assert PAGE_SIZE > 0 and CACHE_SIZE % PAGE_SIZE == 0

    if sm_scale is None:
        sm_scale = 1.0 / math.sqrt(D)

    # strides for Q [G, N, QUERY_GROUP_SIZE, D]
    STRIDE_Q_G, STRIDE_Q_N, STRIDE_Q_H, STRIDE_Q_D = q.stride()
    STRIDE_KC, STRIDE_VC = k_cache.stride(0), v_cache.stride(0)
    # [G, N, D]
    STRIDE_KA_G, STRIDE_KA_N, STRIDE_KA_D = k_app.stride()
    STRIDE_VA_G, STRIDE_VA_N, STRIDE_VA_D = v_app.stride()
    # OUT [G, N, QUERY_GROUP_SIZE, D]
    STRIDE_OUT_G, STRIDE_OUT_N, STRIDE_OUT_H, STRIDE_OUT_D = out.stride()

    # Register a scratch allocator (int8 buffer on q's device) through the
    # project's compatibility helper before launching.
    maybe_set_allocator(
        lambda size, align, _: torch.empty(size, dtype=torch.int8, device=q.device)
    )
    # The kernel indexes the head dimension with unit stride.
    assert STRIDE_KA_D == STRIDE_VA_D == STRIDE_Q_D == STRIDE_OUT_D == 1, (
        "final dimension must be contiguous"
    )

    # launch grid: (kv head, batch element, query tile)
    def grid(META):
        return HKV, B, triton.cdiv(max_seqlen_q, META["BLOCK_M"])

    # On a fresh batch, max_seqlen_k_cache==0 (no KV prefix yet). Passing
    # `triton.next_power_of_2(0)` into autotune constexpr keys breaks
    # kernel selection / tuning and can yield garbage outputs.
    _k_max_autotune = max(int(max_seqlen_k_cache), 1)
    AUTOTUNE_MAX_Q_LEN = triton.next_power_of_2(max_seqlen_q)
    AUTOTUNE_MAX_K_LEN = triton.next_power_of_2(_k_max_autotune)

    _causal_head_sparse_varlen_with_cache[grid](
        Q=q,
        K_cache=k_cache,
        V_cache=v_cache,
        K_app=k_app,
        V_app=v_app,
        cu_seqlens_qk=cu_seqlens_q,
        seq_lens_bh=seq_lens_bh,
        page_table=global_page_table,
        batch_mapping=batch_mapping,
        OUT=out,
        HKV=HKV,
        QUERY_GROUP_SIZE=H_g,
        PAGE_SIZE=PAGE_SIZE,
        N_LOGICAL_PAGES_MAX=N_LOGICAL_PAGES_MAX,
        STRIDE_Q_G=STRIDE_Q_G,
        STRIDE_Q_N=STRIDE_Q_N,
        STRIDE_Q_H=STRIDE_Q_H,
        STRIDE_KC=STRIDE_KC,
        STRIDE_VC=STRIDE_VC,
        STRIDE_KA_G=STRIDE_KA_G,
        STRIDE_KA_N=STRIDE_KA_N,
        STRIDE_VA_G=STRIDE_VA_G,
        STRIDE_VA_N=STRIDE_VA_N,
        STRIDE_OUT_G=STRIDE_OUT_G,
        STRIDE_OUT_N=STRIDE_OUT_N,
        STRIDE_OUT_H=STRIDE_OUT_H,
        sm_scale=sm_scale,
        D=D,
        AUTOTUNE_MAX_Q_LEN=AUTOTUNE_MAX_Q_LEN,
        AUTOTUNE_MAX_K_LEN=AUTOTUNE_MAX_K_LEN,
    )
    # Undo the head-major permutation; the underlying buffer was allocated as
    # [N, HQ, D], so this view is already contiguous.
    return out.permute(1, 0, 2, 3).view(N, HQ, D)
# Autotune candidates returned by ``get_autotune_configs`` when
# ``cuda_capability_geq(9, 0)`` is true. Each row of the table is
# (BLOCK_N, BLOCK_M, WARPSPEC, num_warps, num_stages); order is preserved.
autotune_configs_cc9 = [
    triton.Config(
        {"BLOCK_N": block_n, "BLOCK_M": block_m, "WARPSPEC": warpspec},
        num_warps=warps,
        num_stages=stages,
    )
    for block_n, block_m, warpspec, warps, stages in (
        (64, 64, True, 16, 3),
        (64, 64, True, 8, 3),
        (64, 32, True, 8, 4),
        (64, 32, True, 8, 3),
        (64, 32, False, 4, 3),
        (64, 16, True, 8, 3),
        (64, 16, True, 8, 4),
        (64, 16, False, 4, 4),
        (32, 32, True, 8, 4),
        (32, 32, False, 8, 4),
        (32, 16, False, 8, 3),
        (32, 16, False, 4, 4),
    )
]
# Autotune candidates returned by ``get_autotune_configs`` when the device
# does not satisfy ``cuda_capability_geq(9, 0)``: the full cross product of
# the small parameter grids below, always with WARPSPEC enabled.
autotune_configs_cc8 = [
    triton.Config(
        {"BLOCK_N": block_n, "BLOCK_M": block_m, "WARPSPEC": True},
        num_warps=warps,
        num_stages=stages,
    )
    for block_n in (16, 32)
    for block_m in (64,)
    for warps in (4, 8)
    for stages in (2, 3)
]
def prune_invalid_configs(configs, _, **kwargs):
    """
    Drop autotune configurations that combine ``BLOCK_N == 32`` with
    ``num_stages == 4``.

    ``triton.Config`` keeps ``num_stages`` as an attribute of the config
    object, not inside its ``kwargs`` meta-parameter dict, so the stage count
    is read from the attribute first. The ``kwargs`` dict is still consulted
    as a fallback for config-like objects that carry it there.

    Args:
        :param configs:
            Iterable of config objects exposing a ``kwargs`` dict and
            (normally) a ``num_stages`` attribute.
        :param _:
            Unused positional argument (kept for the pruning-hook signature).
        :param kwargs:
            Unused extra keyword arguments.

    Returns:
        :return list:
            The configs that survive pruning, in their original order.
    """

    def _num_stages(conf):
        stages = getattr(conf, "num_stages", None)
        if stages is None:
            stages = conf.kwargs.get("num_stages")
        return stages

    return [
        conf
        for conf in configs
        if not (conf.kwargs.get("BLOCK_N") == 32 and _num_stages(conf) == 4)
    ]
def get_autotune_configs():
    """
    Pick the autotune candidate list for the current device.

    Returns ``autotune_configs_cc9`` when ``cuda_capability_geq(9, 0)`` is
    true and ``autotune_configs_cc8`` otherwise.
    """
    return autotune_configs_cc9 if cuda_capability_geq(9, 0) else autotune_configs_cc8
@triton_autotune(
    configs=get_autotune_configs(),
    key=[
        "HKV",
        "QUERY_GROUP_SIZE",
        "D",
        "PAGE_SIZE",
        "AUTOTUNE_MAX_K_LEN",
        "AUTOTUNE_MAX_Q_LEN",
    ],
    cache_results=True,
)
@triton.jit
def _causal_head_sparse_varlen_with_cache(
    Q,  # [HKV, N, QUERY_GROUP_SIZE, D] (non-contiguous)
    K_cache,
    V_cache,  # [CACHE_SIZE, D]
    K_app,
    V_app,  # [HKV, N, D]
    cu_seqlens_qk,  # [B+1]
    seq_lens_bh,  # [B, HKV]
    page_table,  # [B_total, HKV, N_LOGICAL_PAGES_MAX]
    batch_mapping,  # [B], maps local b -> global batch index
    OUT,  # [HKV, N, QUERY_GROUP_SIZE, D]
    #
    HKV: tl.constexpr,
    QUERY_GROUP_SIZE: tl.constexpr,
    PAGE_SIZE: tl.constexpr,
    N_LOGICAL_PAGES_MAX,
    STRIDE_Q_G,
    STRIDE_Q_N,
    STRIDE_Q_H,
    STRIDE_KC,
    STRIDE_VC,
    STRIDE_KA_G,
    STRIDE_KA_N,
    STRIDE_VA_G,
    STRIDE_VA_N,
    STRIDE_OUT_G,
    STRIDE_OUT_N,
    STRIDE_OUT_H,
    sm_scale,
    #
    D: tl.constexpr,
    BLOCK_M: tl.constexpr,
    BLOCK_N: tl.constexpr,
    WARPSPEC: tl.constexpr,
    AUTOTUNE_MAX_Q_LEN: tl.constexpr,  # used for autotune key
    AUTOTUNE_MAX_K_LEN: tl.constexpr,  # used for autotune key
):
    # Causal prefill attention for one (kv head, batch element, query tile).
    #
    # Each program handles BLOCK_M consecutive query tokens of one sequence
    # for ALL QUERY_GROUP_SIZE query heads that share KV head `pid_g`, and
    # accumulates an online (streaming) softmax over
    #   1) the cached keys/values reached through `page_table`, then
    #   2) the newly appended keys/values K_app/V_app, causally masked.
    # The normalised result is written to OUT with the same strided layout
    # as Q. The head dimension of every tensor is addressed with unit stride.
    #
    # WARPSPEC and the two AUTOTUNE_* constexprs are not referenced in the
    # body below; they only take part in config selection / autotune keying.
    TOTAL_N_QUERIES: tl.constexpr = BLOCK_M * QUERY_GROUP_SIZE
    pid_g = tl.program_id(0)  # kv_head id in [0, HKV)
    pid_b = tl.program_id(1)  # batch id
    pid_m = tl.program_id(2)  # query-tile id within batch
    # batch segment [off_b, off_b1) in the packed token axis N
    off_b = tl.load(cu_seqlens_qk + pid_b)
    off_b1 = tl.load(cu_seqlens_qk + pid_b + 1)
    seq_len_append = off_b1 - off_b
    q_start = off_b + pid_m * BLOCK_M
    q_end = tl.minimum(q_start + BLOCK_M, off_b1)
    # number of queries in this tile for this batch
    M = q_end - q_start
    # The grid is sized by the longest sequence, so shorter sequences get
    # tiles that start past their end; those programs have nothing to do.
    if M <= 0:
        return
    # cached length for (b, kv_head=pid_g)
    L_cache = tl.load(seq_lens_bh + pid_b * HKV + pid_g)
    # row indices flattened over [QUERY_GROUP_SIZE, BLOCK_M]:
    # row = head_in_group * BLOCK_M + token_in_tile
    offs_row = tl.arange(0, TOTAL_N_QUERIES)
    row_m = offs_row % BLOCK_M
    row_h = offs_row // BLOCK_M
    # valid rows: only those with row_m < M
    row_mask = row_m < M
    # global query index per row
    q_idx = q_start + row_m
    offs_d = tl.arange(0, D)
    # Q tile: [TOTAL_N_QUERIES, D]
    # Q layout: [HKV, N, QUERY_GROUP_SIZE, D]
    q_ptrs = (
        Q
        + pid_g * STRIDE_Q_G
        + q_idx[:, None] * STRIDE_Q_N
        + row_h[:, None] * STRIDE_Q_H
        + offs_d[None, :]
    )
    q = tl.load(q_ptrs, mask=row_mask[:, None], other=0.0)
    # Online-softmax state per query row: running max, running sum of
    # exponentials, and the un-normalised output accumulator.
    e_max = tl.zeros([TOTAL_N_QUERIES], dtype=tl.float32) - float("inf")
    e_sum = tl.zeros([TOTAL_N_QUERIES], dtype=tl.float32)
    acc = tl.zeros([TOTAL_N_QUERIES, D], dtype=tl.float32)
    offs_block_n = tl.arange(0, BLOCK_N)
    # 1.44269504 ~= log2(e): logits are pre-scaled so that exp2() below
    # evaluates exp(sm_scale * q.k).
    qk_scale = sm_scale * 1.44269504
    # 1) attend over cached K/V
    if L_cache > 0:
        # map local (b) to global batch index
        mapped_b = tl.load(batch_mapping + pid_b)
        pt_base = (mapped_b * HKV + pid_g) * N_LOGICAL_PAGES_MAX
        # iterate logical pages
        num_lp = tl.cdiv(L_cache, PAGE_SIZE)
        for lp in tl.range(0, num_lp):
            # can overflow in 32 bits so upcast
            phys = tl.load(page_table + pt_base + lp).to(tl.int64)
            page_start = phys * PAGE_SIZE
            # how many valid tokens in this page for this (b,g)
            remain = L_cache - lp * PAGE_SIZE
            page_len = tl.minimum(PAGE_SIZE, remain)
            # iterate over this page in BLOCK_N chunks
            for ks in tl.range(0, page_len, BLOCK_N):
                offs_n = ks + offs_block_n
                mask_n = offs_n < page_len
                key_idx = page_start + offs_n
                k_ptrs = K_cache + key_idx[:, None] * STRIDE_KC + offs_d[None, :]
                k = tl.load(k_ptrs, mask=mask_n[:, None], other=0.0)  # [BN, D]
                qk = tl.dot(q, k.T) * qk_scale  # [TOTAL_N_QUERIES, BN]
                # Masked entries get a large negative *finite* logit rather
                # than -inf (presumably so fully masked rows cannot produce
                # inf - inf = NaN in the rescaling below).
                qk = tl.where(row_mask[:, None] & mask_n[None, :], qk, -1.0e6)
                # softmax update
                cur_max = tl.max(qk, 1)
                n_e_max = tl.maximum(e_max, cur_max)
                re_scale = tl.math.exp2(e_max - n_e_max)
                p = tl.math.exp2(qk - n_e_max[:, None])
                v_ptrs = V_cache + key_idx[:, None] * STRIDE_VC + offs_d[None, :]
                v = tl.load(v_ptrs, mask=mask_n[:, None], other=0.0)  # [BN, D]
                # Rescale previous partial results to the new running max,
                # then add this chunk's contribution.
                acc = acc * re_scale[:, None]
                acc = tl.dot(p.to(v.dtype), v, acc)
                e_sum = e_sum * re_scale + tl.sum(p, 1)
                e_max = n_e_max
    # 2) attend over appended K_app/V_app (causal)
    # appended tokens for batch b are in [off_b, off_b1)
    # query tile is [q_start, q_end)
    # for each query at index q_idx, valid appended keys k satisfy off_b <= k <= q_idx
    if q_end > off_b:
        # exactly one appended token
        if seq_len_append == 1:
            # Single key/value row: use an elementwise multiply + reduction
            # instead of tl.dot.
            ka_ptrs = K_app + pid_g * STRIDE_KA_G + off_b * STRIDE_KA_N + offs_d
            k = tl.load(ka_ptrs)  # [D]
            qk = tl.sum(q * k[None, :], 1) * qk_scale
            qk = tl.where(row_mask, qk, -1.0e6)
            n_e_max = tl.maximum(e_max, qk)
            re_scale = tl.math.exp2(e_max - n_e_max)
            p = tl.math.exp2(qk - n_e_max)
            va_ptrs = V_app + pid_g * STRIDE_VA_G + off_b * STRIDE_VA_N + offs_d
            v = tl.load(va_ptrs)  # [D]
            acc = acc * re_scale[:, None] + p[:, None] * v[None, :]
            # e_max is not refreshed here: this is the last softmax update on
            # this path and only acc / e_sum are read afterwards.
            e_sum = e_sum * re_scale + p
        else:
            # off-band: k in [off_b, q_start)
            # for all queries t in [q_start, q_end), any k < q_start satisfies k <= t.
            # so no causal mask needed.
            off_band_start = off_b
            off_band_end = q_start
            if off_band_end > off_band_start:
                for ks in tl.range(off_band_start, off_band_end, BLOCK_N):
                    offs_n = ks + offs_block_n
                    mask_n = offs_n < off_band_end
                    ka_ptrs = (
                        K_app
                        + pid_g * STRIDE_KA_G
                        + offs_n[:, None] * STRIDE_KA_N
                        + offs_d[None, :]
                    )
                    k = tl.load(ka_ptrs, mask=mask_n[:, None], other=0.0)
                    qk = tl.dot(q, k.T) * qk_scale
                    qk = tl.where(row_mask[:, None] & mask_n[None, :], qk, -1.0e6)
                    cur_max = tl.max(qk, 1)
                    n_e_max = tl.maximum(e_max, cur_max)
                    re_scale = tl.math.exp2(e_max - n_e_max)
                    p = tl.math.exp2(qk - n_e_max[:, None])
                    va_ptrs = (
                        V_app
                        + pid_g * STRIDE_VA_G
                        + offs_n[:, None] * STRIDE_VA_N
                        + offs_d[None, :]
                    )
                    v = tl.load(va_ptrs, mask=mask_n[:, None], other=0.0)
                    acc = acc * re_scale[:, None]
                    acc = tl.dot(p.to(v.dtype), v, acc)
                    e_sum = e_sum * re_scale + tl.sum(p, 1)
                    e_max = n_e_max
            # on-band remaining k: keys inside this query tile, where the
            # per-element causal mask (key index <= query index) is required
            on_band_start = tl.maximum(q_start, off_b)
            if on_band_start < q_end:
                for ks in tl.range(on_band_start, q_end, BLOCK_N):
                    offs_n = ks + tl.arange(0, BLOCK_N)
                    mask_n = offs_n < q_end
                    ka_ptrs = (
                        K_app
                        + pid_g * STRIDE_KA_G
                        + offs_n[:, None] * STRIDE_KA_N
                        + offs_d[None, :]
                    )
                    k = tl.load(ka_ptrs, mask=mask_n[:, None], other=0.0)
                    qk = tl.dot(q, k.T) * qk_scale
                    caus_mask = offs_n[None, :] <= q_idx[:, None]
                    full_mask = row_mask[:, None] & mask_n[None, :] & caus_mask
                    qk = tl.where(full_mask, qk, -1.0e6)
                    cur_max = tl.max(qk, 1)
                    n_e_max = tl.maximum(e_max, cur_max)
                    re_scale = tl.math.exp2(e_max - n_e_max)
                    p = tl.math.exp2(qk - n_e_max[:, None])
                    va_ptrs = (
                        V_app
                        + pid_g * STRIDE_VA_G
                        + offs_n[:, None] * STRIDE_VA_N
                        + offs_d[None, :]
                    )
                    v = tl.load(va_ptrs, mask=mask_n[:, None], other=0.0)
                    acc = acc * re_scale[:, None]
                    acc = tl.dot(p.to(v.dtype), v, acc)
                    e_sum = e_sum * re_scale + tl.sum(p, 1)
                    e_max = n_e_max
    # 3) write outputs: normalise by the softmax denominator and cast back to
    # the query dtype; rows beyond the tile's valid length are not stored.
    o = (acc / e_sum[:, None]).to(q.dtype)
    out_ptrs = (
        OUT
        + pid_g * STRIDE_OUT_G
        + q_idx[:, None] * STRIDE_OUT_N
        + row_h[:, None] * STRIDE_OUT_H
        + offs_d[None, :]
    )
    tl.store(out_ptrs, o, mask=row_mask[:, None])
vllm/compactor-vllm/src/compactor_vllm/benchmark/__init__.py
deleted
100644 → 0
View file @
2b7160c6
vllm/compactor-vllm/src/compactor_vllm/compression/__init__.py
deleted
100644 → 0
View file @
2b7160c6
from
compactor_vllm.compression.common
import
(
BaseCompressionMethod
,
NoCompression
,
)
from
compactor_vllm.compression.criticalkv
import
CriticalAdaKVCompression
from
compactor_vllm.compression.compactor
import
CompactorCompression
from
compactor_vllm.compression.compression_config
import
(
BatchCompressionParams
,
CompressionMethod
,
SequenceCompressionParams
,
)
from
compactor_vllm.compression.snapkv
import
SnapKVCompression
# Maps each ``CompressionMethod`` enum member to the class implementing it.
# ``apply_prerope_compression`` / ``apply_postrope_compression`` look the
# class up here and call its static scoring methods.
COMPRESSION_REGISTRY: dict[CompressionMethod, type[BaseCompressionMethod]] = {
    CompressionMethod.CRITICALADAKV: CriticalAdaKVCompression,
    CompressionMethod.COMPACTOR: CompactorCompression,
    CompressionMethod.SNAPKV: SnapKVCompression,
    CompressionMethod.NONE: NoCompression,
}
def apply_prerope_compression(q, k, v, context):
    """
    Run the pre-RoPE scoring phase of the configured compression method.

    The method is read from ``context.compression_context.compression_method``
    and resolved through ``COMPRESSION_REGISTRY``; the result of that class's
    ``pre_rope_scoring(q, k, v, context=context)`` is returned unchanged.
    """
    compressor = COMPRESSION_REGISTRY[context.compression_context.compression_method]
    return compressor.pre_rope_scoring(q, k, v, context=context)
def apply_postrope_compression(q, k, v, prerope_scores, context):
    """
    Run the post-RoPE scoring phase of the configured compression method.

    The method is read from ``context.compression_context.compression_method``
    and resolved through ``COMPRESSION_REGISTRY``; the result of that class's
    ``post_rope_scoring(q, k, v, prerope_scores, context=context)`` is
    returned unchanged.
    """
    compressor = COMPRESSION_REGISTRY[context.compression_context.compression_method]
    return compressor.post_rope_scoring(q, k, v, prerope_scores, context=context)
# Names exported by ``from compactor_vllm.compression import *``.
__all__ = [
    "apply_prerope_compression",
    "apply_postrope_compression",
    "CompressionMethod",
    "BatchCompressionParams",
    "SequenceCompressionParams",
    "COMPRESSION_REGISTRY"
]
vllm/compactor-vllm/src/compactor_vllm/compression/common.py
deleted
100644 → 0
View file @
2b7160c6
from
abc
import
ABC
,
abstractmethod
from
typing
import
Optional
import
torch
from
compactor_vllm.kv_cache.store_kv_cache
import
prefill_store_topk_kv
class BaseCompressionMethod(ABC):
    """
    Abstract interface for KV cache compression methods.

    A compression method is implemented as a pair of optional scoring phases
    that run before and after rotary position embedding (RoPE) is applied:

    1. ``pre_rope_scoring`` operates on pre-RoPE Q/K.
    2. ``post_rope_scoring`` operates on post-RoPE Q/K and can either:
       - refine / reweight the pre-RoPE scores, or
       - compute new, potentially position-aware scores.

    Concrete subclasses are expected to implement both
    static methods and return a single tensor of scores (or ``None`` if the
    phase is a no-op), which the caller can then feed into the shared
    "scores → top-k indices → KV extraction" pipeline.
    """

    @staticmethod
    @abstractmethod
    def pre_rope_scoring(
        q: torch.Tensor,
        k: torch.Tensor,
        v: torch.Tensor,
        context,
    ) -> Optional[torch.Tensor]:
        """
        Compute per-token importance scores from pre-RoPE queries/keys.

        Args:
            :param q:
                Pre-RoPE query tensor. Shape ``[total_tokens, HQ, D]``.
            :param k:
                Pre-RoPE key tensor. Shape ``[total_tokens, HKV, D]``.
            :param v:
                Value tensor. Shape ``[total_tokens, HKV, D]``.
            :param context:
                compactor_vllm.utils.context.Context object carrying additional metadata,
                such as batch mappings or temporary buffers.

        Returns:
            :return Optional[torch.Tensor]:
                A tensor of scores (e.g. per-token, per-head importance values)
                to be passed to ``post_rope_scoring`` or directly into the
                top-k selection step. If this phase is a no-op, implementations
                should return ``None``. Shape ``[total_tokens, HKV]``.
        """
        pass

    @staticmethod
    @abstractmethod
    def post_rope_scoring(
        q: torch.Tensor,
        k: torch.Tensor,
        v: torch.Tensor,
        pre_rope_scores: Optional[torch.Tensor],
        context,
    ) -> Optional[torch.Tensor]:
        """
        Compute or refine importance scores from post-RoPE queries/keys.

        This method is called after rotary embeddings have been applied. It can
        optionally use both the post-RoPE Q/K and any scores produced by
        ``pre_rope_scoring`` to produce final scores used for token selection.

        Common patterns include:
          * Using ``pre_rope_scores`` as a base signal and applying a
            position-aware correction.
          * Only computing scores that depend on absolute or relative positions.
          * Simply passing through ``pre_rope_scores`` unchanged.

        Args:
            :param q:
                Post-RoPE query tensor. Shape ``[total_tokens, HQ, D]``.
            :param k:
                Post-RoPE key tensor. Shape ``[total_tokens, HKV, D]``.
            :param v:
                Value tensor. Shape ``[total_tokens, HKV, D]``.
            :param pre_rope_scores:
                Optional scores returned by ``pre_rope_scoring``. May be
                ``None`` if the pre-RoPE phase returned None.
            :param context:
                compactor_vllm.utils.context.Context object carrying additional metadata,
                such as batch mappings or temporary buffers.

        Returns:
            :return Optional[torch.Tensor]:
                Final importance scores to be consumed by the compression
                pipeline (for top-k token selection). If this phase is a
                no-op, implementations may return ``pre_rope_scores``. If
                None is returned, no compression will be applied.
        """
        pass
class NoCompression(BaseCompressionMethod):
    """
    Trivial compression method that disables KV cache compression.

    The pre-RoPE phase produces no scores and the post-RoPE phase passes its
    ``pre_rope_scores`` argument straight through.
    """

    @staticmethod
    def pre_rope_scoring(
        q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, context
    ) -> Optional[torch.Tensor]:
        # No scoring is performed; all arguments are ignored.
        return None

    @staticmethod
    def post_rope_scoring(
        q: torch.Tensor,
        k: torch.Tensor,
        v: torch.Tensor,
        pre_rope_scores: Optional[torch.Tensor],
        context,
    ) -> Optional[torch.Tensor]:
        # Pass-through: returns whatever the pre-RoPE phase produced.
        return pre_rope_scores
def extract_and_store_top_kv(
    scores: torch.Tensor,
    cu_seqlens_k: torch.Tensor,
    max_k_len: int,
    top_k: int,
    H: int,
    new_keys: torch.Tensor,  # [N_total, H, D]
    new_vals: torch.Tensor,  # [N_total, H, D]
    num_tokens_to_retain: torch.Tensor,  # [B] int32
    page_table: torch.Tensor,  # [B_total, H, N_LOGICAL_PAGES_MAX] int32
    batch_mapping: torch.Tensor,  # [B] int32 (local -> true batch rows)
    bh_lens: torch.Tensor,  # [B, H] int32 (contiguous), UPDATED atomically
    k_cache: torch.Tensor,  # [N_PAGES * PAGE_SIZE, D]
    v_cache: torch.Tensor,  # [N_PAGES * PAGE_SIZE, D]
    PAGE_SIZE: int,
    PAD_TO_PAGE_SIZE: bool = True,
    K_TILE: int = 16,
    padding: float = -float("inf"),
):
    """
    Helper that selects the top-k (token, head) indices from ``scores`` and
    stores the corresponding keys/values into the KV cache, so that both
    steps can be executed in a single stream.

    The selection is delegated to ``scores_to_retain_indices`` (using
    ``scores``, ``cu_seqlens_k``, ``max_k_len``, ``top_k``, ``H`` and
    ``padding``); the resulting indices and every remaining argument are
    forwarded unchanged, by keyword, to ``prefill_store_topk_kv``.

    Tensor shapes are those given by the inline parameter comments. The
    semantics of the store step (including how ``bh_lens`` is updated and how
    ``PAD_TO_PAGE_SIZE`` / ``K_TILE`` are used) are defined by
    ``prefill_store_topk_kv`` and are not visible here.

    Returns:
        None. The function is called for its side effects on the cache
        tensors.
    """
    indices_topk = scores_to_retain_indices(
        scores,
        cu_seqlens_k=cu_seqlens_k,
        max_k_len=max_k_len,
        top_k=top_k,
        H=H,
        padding=padding,
    )
    prefill_store_topk_kv(
        new_keys=new_keys,
        new_vals=new_vals,
        indices_topk=indices_topk,
        num_tokens_to_retain=num_tokens_to_retain,
        page_table=page_table,
        batch_mapping=batch_mapping,
        bh_lens=bh_lens,
        k_cache=k_cache,
        v_cache=v_cache,
        cu_seqlens_k=cu_seqlens_k,
        PAGE_SIZE=PAGE_SIZE,
        PAD_TO_PAGE_SIZE=PAD_TO_PAGE_SIZE,
        K_TILE=K_TILE,
    )
def scores_to_retain_indices(
    scores: torch.Tensor,
    cu_seqlens_k: torch.Tensor,
    max_k_len: int,
    top_k: int,
    H: int,
    padding: float = -float("inf"),
) -> torch.Tensor:
    """
    Select global top-k token–head indices per sequence from packed scores.

    This helper takes per-token, per-head scores in packed varlen form and
    returns, for each batch element, the indices of the top-k (token, head)
    pairs in the flattened global layout.

    Inputs are assumed to follow the usual packed varlen convention:
      • ``scores`` is laid out as ``[N_total, H]``, where
        ``N_total = sum_b seqlen_k[b]`` and ``H`` is the number of KV heads.
      • ``cu_seqlens_k`` is ``[B + 1]`` (integer), giving cumulative lengths
        for the keys per batch:
        ``seqlen_k[b] = cu_seqlens_k[b + 1] - cu_seqlens_k[b]``.
      • ``max_k_len`` is an upper bound on ``seqlen_k[b]`` across the batch.

    The function pads each sequence to length ``max_k_len`` with ``padding``
    (default: ``-inf``), flattens the per-sequence scores into shape
    ``[B, max_k_len * H]``, and runs a per-batch top-k. The returned indices
    are shifted so that they directly index into the flattened global
    score layout of shape ``[N_total * H]``:

        global_index = (token_global_offset * H) + head_index

    Args:
        :param scores:
            Tensor of shape ``[N_total, H]`` containing scores for each
            (token, head) pair in packed varlen format.
        :param cu_seqlens_k:
            Tensor of shape ``[B + 1]`` with cumulative key sequence
            lengths for each batch element. The total number of tokens
            satisfies ``N_total = cu_seqlens_k[-1]``.
        :param max_k_len:
            Maximum key sequence length across the batch (i.e.
            ``max_b seqlen_k[b]``). Used to allocate the padded buffer.
        :param top_k:
            Number of (token, head) entries to retain **per batch element**.
            If ``top_k > max_k_len * H``, it is clamped to ``max_k_len * H``.
        :param H:
            Number of key heads; must match ``scores.shape[1]``.
        :param padding:
            Padding value used when extending sequences shorter than
            ``max_k_len``. Defaults to ``-inf``, so that padded positions are
            only selected once a sequence's real entries are exhausted.

    Returns:
        :return torch.Tensor:
            Tensor of shape ``[B, k_eff]`` (int64) where
            ``k_eff = min(top_k, max_k_len * H)``. Each entry is a global
            index into the flattened score array of shape ``[N_total * H]``
            (i.e. scores viewed as ``scores.view(-1)``). When ``k_eff``
            exceeds ``seqlen_k[b] * H``, the trailing entries of row ``b``
            refer to padded slots and do not address that sequence's tokens.
    """
    # idea: pad and then select top-k.
    B, device = cu_seqlens_k.numel() - 1, scores.device
    # Fetch all sequence boundaries with a single device-to-host transfer
    # instead of synchronising twice per sequence inside the loop.
    bounds = [int(x) for x in cu_seqlens_k.tolist()]
    padded = torch.full(
        (B, max_k_len, H), fill_value=padding, dtype=scores.dtype, device=device
    )
    for b in range(B):
        s, e = bounds[b], bounds[b + 1]
        padded[b, : e - s, :].copy_(scores[s:e, :])
    flat = padded.view(B, max_k_len * H)
    idx = torch.topk(
        flat, k=min(top_k, max_k_len * H), dim=1, largest=True, sorted=True
    ).indices
    # Widen the per-sequence offsets to the index dtype (int64) *before*
    # multiplying by H so the product cannot overflow 32-bit arithmetic.
    offsets = cu_seqlens_k[:-1].to(device=idx.device, dtype=idx.dtype) * H
    return idx + offsets.unsqueeze(-1)
vllm/compactor-vllm/src/compactor_vllm/compression/compactor.py
deleted
100644 → 0
View file @
2b7160c6
"""
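Compactor compression: algorithmically aligned with kvpress ``CompactorPress`` / ``LeverageScorePress`` / ``NonCausalAttnPress``
(Cholesky leverage scores, right Gaussian sketch, non-causal chunked attention without 1/sqrt(d) scaling, ×||V||, avg_pool,
global z-score, blending, and head/tail sink padding).
On CUDA, the non-causal chunked attention and the ``×||V||`` + ``avg_pool1d(k=3)`` step run as Triton kernels; non-CUDA devices fall back to PyTorch.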
"""
from
__future__
import
annotations
import
math
from
typing
import
List
,
Optional
import
torch
import
triton
import
triton.language
as
tl
from
transformers.models.llama.modeling_llama
import
repeat_kv
from
compactor_vllm.compression.common
import
BaseCompressionMethod
from
compactor_vllm.utils.helpers
import
maybe_execute_in_stream
def resolve_kvpress_compactor_blending(compression_context) -> float:
    """
    Resolve the blending weight, following kvpress ``CompactorPress.score``.

    Lookup order: the context's ``compactor_blending`` attribute, then its
    ``compression_ratio`` attribute, and finally the default ``0.35``. An
    attribute that is missing or ``None`` is skipped. A ``None`` context
    yields the default directly.
    """
    default_blending = 0.35
    if compression_context is None:
        return default_blending
    for attr_name in ("compactor_blending", "compression_ratio"):
        candidate = getattr(compression_context, attr_name, None)
        if candidate is not None:
            return float(candidate)
    return default_blending
class CompactorCompression(BaseCompressionMethod):
    """
    Compactor compression method.

    Uses ``chunk_size=256``, stated by the original author to match the
    default of kvpress ``CompactorPress`` / ``NonCausalAttnPress``.
    """

    # Chunk length forwarded to the post-RoPE scoring routine.
    chunk_size: int = 256

    @staticmethod
    def pre_rope_scoring(
        q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, context
    ) -> Optional[torch.Tensor]:
        """
        Pre-RoPE phase: run ``kvpress_leverage_scores_packed`` on the keys.

        Only ``k``, ``context.cu_seqlens_q`` and the compression context are
        forwarded; ``q`` and ``v`` are unused. Execution goes through
        ``maybe_execute_in_stream`` with ``context.STORE_STREAM``.
        """
        cctx = context.compression_context
        scoring_args = (k, context.cu_seqlens_q, cctx)
        return maybe_execute_in_stream(
            kvpress_leverage_scores_packed,
            *scoring_args,
            STORE_STREAM=context.STORE_STREAM,
        )

    @staticmethod
    def post_rope_scoring(
        q: torch.Tensor,
        k: torch.Tensor,
        v: torch.Tensor,
        pre_rope_scores: torch.Tensor,
        context,
    ) -> Optional[torch.Tensor]:
        """
        Post-RoPE phase: run ``kvpress_compactor_post_rope``.

        The blending weight comes from ``resolve_kvpress_compactor_blending``
        and the chunk size from the class attribute. Execution goes through
        ``maybe_execute_in_stream`` with ``context.STORE_STREAM``.
        """
        cctx = context.compression_context
        blend_weight = float(resolve_kvpress_compactor_blending(cctx))
        scoring_args = (
            q,
            k,
            v,
            context.cu_seqlens_q,
            pre_rope_scores,
            cctx,
            context.max_seqlen_q,
        )
        return maybe_execute_in_stream(
            kvpress_compactor_post_rope,
            *scoring_args,
            chunk_size=CompactorCompression.chunk_size,
            blending=blend_weight,
            STORE_STREAM=context.STORE_STREAM,
        )
# ---------------------------------------------------------------------------
# Cholesky 杠杆分(kvpress ``LeverageScorePress``)
# ---------------------------------------------------------------------------
def chol_with_jitter(G: torch.Tensor, jitter: float = 0.0, max_tries: int = 5) -> torch.Tensor:
    """Lower Cholesky factor of ``G + shift * I``, growing ``shift`` until it succeeds.

    The first attempt uses ``shift = jitter``.  After a failed attempt the
    shift becomes ``1e-2`` if it was exactly zero, otherwise ten times its
    previous value (never below ``1e-8``).

    Args:
        G: Square matrix or batch of square matrices.
        jitter: Initial diagonal shift.
        max_tries: Maximum number of factorisation attempts.

    Returns:
        The lower-triangular factor from the first attempt for which every
        matrix in the batch factorised successfully.

    Raises:
        RuntimeError: If no attempt succeeds within ``max_tries``.
    """
    eye = torch.eye(G.shape[-1], device=G.device, dtype=G.dtype)
    shift = float(jitter)
    attempt = 0
    while attempt < max_tries:
        factor, status = torch.linalg.cholesky_ex(G + shift * eye, upper=False)
        # ``status`` is zero for every batch entry that factorised cleanly.
        if bool((status == 0).all()):
            return factor
        if shift == 0.0:
            shift = 1e-2
        else:
            shift = max(1e-8, 10.0 * shift)
        attempt += 1
    raise RuntimeError(f"Cholesky failed after {max_tries} tries.")
def compute_leverage_scores_mid(key_states: torch.Tensor, sketch_dimension: int) -> torch.Tensor:
    """Sketched leverage scores for one key segment.

    Per the original note this follows kvpress
    ``LeverageScorePress.compute_leverage_scores``.  The keys are rearranged
    to ``[1, H, L, D]``, centred along the sequence axis, projected with a
    fresh Gaussian matrix of shape ``[1, H, D, K]`` scaled by ``1/sqrt(K)``,
    and the scores are the row-wise quadratic forms ``x_i^T (X^T X + jitter I)^-1 x_i``
    obtained through a jittered Cholesky solve.

    Args:
        key_states: Keys of shape ``[L, H, D]``.
        sketch_dimension: Sketch width ``K``.

    Returns:
        Float32 tensor ``[L, H]`` of non-negative scores (not z-scored).
    """
    num_heads = key_states.shape[1]
    head_dim = key_states.shape[-1]
    sketch_scale = 1.0 / math.sqrt(sketch_dimension)
    projection = (
        torch.randn(
            1,
            num_heads,
            head_dim,
            sketch_dimension,
            device=key_states.device,
            dtype=key_states.dtype,
        )
        * sketch_scale
    )

    # [L, H, D] -> [1, H, L, D] so the sequence axis is dim=-2.
    keys_bhld = key_states.transpose(0, 1).unsqueeze(0).contiguous()
    centered = keys_bhld - keys_bhld.mean(dim=-2, keepdim=True)
    sketched = torch.matmul(centered, projection).to(torch.float32)
    sketched_t = sketched.transpose(-2, -1)

    gram = sketched_t @ sketched
    # Symmetrise before factorising to remove round-off asymmetry.
    sym_gram = 0.5 * (gram + gram.transpose(-2, -1))
    chol = chol_with_jitter(sym_gram, jitter=1e-2, max_tries=5)
    solved = torch.cholesky_solve(sketched_t, chol, upper=False)

    leverage = (sketched * solved.transpose(-2, -1)).sum(dim=-1).clamp_min(0)
    # [1, H, L] -> [L, H]
    return leverage.squeeze(0).transpose(0, 1).contiguous()
def kvpress_leverage_scores_packed(
    key_states: torch.Tensor,
    cu_seqlens: torch.Tensor,
    compression_ctx,
) -> torch.Tensor:
    """Leverage scores for a packed batch, z-scored jointly over all middle spans.

    For every sequence delimited by ``cu_seqlens`` the first ``sink_size_start``
    and last ``sink_size_end`` tokens are excluded; leverage scores are computed
    on the remaining "middle" span only.  The middle scores of all sequences
    are concatenated, z-scored together, and written back in place.  Positions
    outside a middle span stay ``0``.

    Args:
        key_states: Packed keys of shape ``[N, Hkv, D]``.
        cu_seqlens: 1-D cumulative sequence boundaries (``B + 1`` entries).
        compression_ctx: Object optionally providing ``sketch_dimension``
            (default 48), ``sink_size_start`` (default 8) and
            ``sink_size_end`` (default 4).

    Returns:
        Float32 tensor of shape ``[N, Hkv]``.
    """
    N, Hkv, _D = key_states.shape
    sketch_dim = int(getattr(compression_ctx, "sketch_dimension", 48))
    sink_start = int(getattr(compression_ctx, "sink_size_start", 8))
    sink_end = int(getattr(compression_ctx, "sink_size_end", 4))

    out = torch.zeros(N, Hkv, device=key_states.device, dtype=torch.float32)

    # One device->host transfer for all boundaries instead of two ``.item()``
    # synchronisations per sequence.
    bounds = [int(x) for x in cu_seqlens.tolist()]

    mid_spans: list[tuple[int, int]] = []
    mids_flat: list[torch.Tensor] = []
    for k_beg, k_end in zip(bounds[:-1], bounds[1:]):
        seq_len = k_end - k_beg
        if seq_len <= 0:
            continue
        left_keep = min(sink_start, seq_len)
        right_keep = min(sink_end, max(0, seq_len - left_keep))
        mid_start = k_beg + left_keep
        mid_end = k_end - right_keep
        if mid_start >= mid_end:
            # The sequence consists of sink tokens only.
            continue
        k_mid = key_states[mid_start:mid_end, :, :].contiguous()
        raw = compute_leverage_scores_mid(k_mid, sketch_dim)
        mids_flat.append(raw.reshape(-1))
        mid_spans.append((mid_start, mid_end))

    if not mids_flat:
        return out

    z = _zscore_flat_f32_global(torch.cat(mids_flat, dim=0))

    # Scatter the normalised scores back to their packed positions.
    offset = 0
    for mid_start, mid_end in mid_spans:
        rows = mid_end - mid_start
        n = rows * Hkv
        out[mid_start:mid_end, :] = z[offset : offset + n].view(rows, Hkv)
        offset += n
    return out
# ---------------------------------------------------------------------------
# 非因果分块注意力(kvpress ``NonCausalAttnPress.non_causal_chunked_attn``)— Triton
# ---------------------------------------------------------------------------
def
_non_causal_chunked_attn_pytorch
(
q
:
torch
.
Tensor
,
k
:
torch
.
Tensor
,
chunk_size
:
int
)
->
torch
.
Tensor
:
"""参考实现:与 kvpress 逐算子一致。"""
assert
chunk_size
>
0
and
q
.
shape
==
k
.
shape
L
,
H
,
d
=
q
.
shape
B
=
1
q
=
q
.
permute
(
1
,
0
,
2
).
unsqueeze
(
0
).
contiguous
()
k
=
k
.
permute
(
1
,
0
,
2
).
unsqueeze
(
0
).
contiguous
()
_B
,
H
,
S
,
_d
=
k
.
shape
S_pad
=
math
.
ceil
(
S
/
chunk_size
)
*
chunk_size
pad_len
=
S_pad
-
S
if
pad_len
>
0
:
q_padded
=
torch
.
cat
(
[
q
,
torch
.
zeros
(
B
,
H
,
pad_len
,
d
,
device
=
q
.
device
,
dtype
=
q
.
dtype
)],
dim
=
2
)
k_padded
=
torch
.
cat
(
[
k
,
torch
.
zeros
(
B
,
H
,
pad_len
,
d
,
device
=
k
.
device
,
dtype
=
k
.
dtype
)],
dim
=
2
)
last_chunk_start
=
(
S
//
chunk_size
)
*
chunk_size
in_valid
=
torch
.
arange
(
last_chunk_start
,
S_pad
,
device
=
q
.
device
)
>=
S
query_mask
=
key_mask
=
in_valid
.
view
(
1
,
1
,
chunk_size
).
expand
(
B
,
H
,
chunk_size
)
else
:
q_padded
,
k_padded
=
q
,
k
last_chunk_start
=
((
S
-
1
)
//
chunk_size
)
*
chunk_size
in_valid
=
torch
.
arange
(
last_chunk_start
,
S_pad
,
device
=
q
.
device
)
>=
S
query_mask
=
key_mask
=
in_valid
.
view
(
1
,
1
,
chunk_size
).
expand
(
B
,
H
,
chunk_size
)
num_chunks
=
S_pad
//
chunk_size
q_chunks
=
q_padded
.
view
(
B
,
H
,
num_chunks
,
chunk_size
,
d
)
k_chunks
=
k_padded
.
view
(
B
,
H
,
num_chunks
,
chunk_size
,
d
)
dots
=
torch
.
matmul
(
q_chunks
,
k_chunks
.
transpose
(
-
2
,
-
1
))
dots
[:,
:,
-
1
].
masked_fill_
(
query_mask
.
unsqueeze
(
-
1
),
0
)
dots
[:,
:,
-
1
].
masked_fill_
(
key_mask
.
unsqueeze
(
-
2
),
-
1e-9
)
attn
=
torch
.
softmax
(
dots
.
to
(
torch
.
float32
),
dim
=-
1
)
out
=
attn
.
sum
(
dim
=-
2
).
view
(
B
,
H
,
S_pad
)[...,
:
S
]
return
out
.
squeeze
(
0
).
transpose
(
0
,
1
).
contiguous
()
@triton.jit
def _non_causal_chunk_row_kernel(
    Q_ptr,
    K_ptr,
    Out_ptr,
    stride_qh,
    stride_qs,
    stride_qd,
    stride_kh,
    stride_ks,
    stride_kd,
    stride_oh,
    stride_os,
    S,
    S_pad,
    num_chunks,
    CHUNK_SIZE: tl.constexpr,
    D: tl.constexpr,
    BLOCK_D: tl.constexpr,
    ND: tl.constexpr,
):
    """
    One program handles one head, one chunk and one query row.

    It builds the row of logits against every key of the same chunk, applies
    softmax over that row, and atomically adds the probabilities into the
    per-key output (equivalent to summing the attention matrix over queries).

    ``S`` is the number of valid (unpadded) positions.  ``S_pad`` and
    ``num_chunks`` are accepted but not read inside the kernel.
    NOTE(review): ``tl.arange`` generally needs power-of-two extents - confirm
    ``CHUNK_SIZE`` and ``BLOCK_D`` always satisfy that.
    """
    h = tl.program_id(0)
    c = tl.program_id(1)
    iq = tl.program_id(2)
    # Global sequence index of this program's query row.
    g_i = c * CHUNK_SIZE + iq
    offs_j = tl.arange(0, CHUNK_SIZE)
    logits = tl.zeros([CHUNK_SIZE], dtype=tl.float32)
    # Accumulate the dot products over the feature dimension in BLOCK_D slices.
    for db in range(ND):
        offs_d = tl.arange(0, BLOCK_D) + db * BLOCK_D
        mask_d = offs_d < D
        q_off = (h * stride_qh + g_i * stride_qs + offs_d * stride_qd)
        qd = tl.load(Q_ptr + q_off, mask=mask_d, other=0.0).to(tl.float32)
        g_j = c * CHUNK_SIZE + offs_j
        k_row_off = h * stride_kh + g_j[:, None] * stride_ks + offs_d[None, :] * stride_kd
        # Loads are masked along the feature axis only; row indices are not
        # bounds-checked here, so Q/K must cover every row the grid touches.
        kj = tl.load(K_ptr + k_row_off, mask=mask_d[None, :], other=0.0).to(tl.float32)
        logits += tl.sum(qd[None, :] * kj, axis=1)
    row_invalid = g_i >= S
    g_j_all = c * CHUNK_SIZE + offs_j
    col_invalid = g_j_all >= S
    # Query rows beyond S get all-zero logits (uniform softmax).
    logits = tl.where(row_invalid, tl.zeros([CHUNK_SIZE], dtype=tl.float32), logits)
    # For valid rows, key columns beyond S get -1e-9 (not -inf), so they still
    # contribute to the softmax denominator.
    logits = tl.where(
        row_invalid,
        logits,
        tl.where(col_invalid, tl.full([CHUNK_SIZE], -1e-9, dtype=tl.float32), logits),
    )
    # Numerically stable softmax over the row.
    m = tl.max(logits)
    logits = logits - m
    exp_v = tl.exp(logits)
    denom = tl.sum(exp_v)
    p = exp_v / denom
    out_base = h * stride_oh + g_j_all * stride_os
    # Only key positions below S are written.
    tl.atomic_add(Out_ptr + out_base, p, mask=g_j_all < S)
def _non_causal_chunked_attn_triton(q: torch.Tensor, k: torch.Tensor, chunk_size: int) -> torch.Tensor:
    """CUDA/Triton counterpart of ``_non_causal_chunked_attn_pytorch``.

    Pads ``q``/``k`` (both ``[L, H, d]``) with zero rows up to a multiple of
    ``chunk_size``, converts them to head-major float32, launches one program
    per (head, chunk, query row) and returns the accumulated per-key attention
    mass as a float32 ``[L, H]`` tensor.
    """
    assert q.is_cuda and k.is_cuda and q.shape == k.shape
    seq_len, num_heads, head_dim = q.shape
    assert chunk_size > 0

    padded_len = math.ceil(seq_len / chunk_size) * chunk_size
    tail = padded_len - seq_len
    if tail > 0:
        q_fill = torch.zeros(
            tail, num_heads, head_dim, device=q.device, dtype=q.dtype, requires_grad=False
        )
        k_fill = torch.zeros(
            tail, num_heads, head_dim, device=k.device, dtype=k.dtype, requires_grad=False
        )
        q = torch.cat([q, q_fill], dim=0)
        k = torch.cat([k, k_fill], dim=0)

    # Head-major float32 copies: [H, S_pad, d].
    q_hsd = q.transpose(0, 1).contiguous().to(dtype=torch.float32)
    k_hsd = k.transpose(0, 1).contiguous().to(dtype=torch.float32)

    n_chunks = padded_len // chunk_size
    col_sums = torch.zeros(num_heads, padded_len, device=q.device, dtype=torch.float32)
    valid_len = int(seq_len)

    block_d = 64 if head_dim > 128 else 32
    n_d_blocks = (head_dim + block_d - 1) // block_d
    launch_grid = (num_heads, n_chunks, chunk_size)

    _non_causal_chunk_row_kernel[launch_grid](
        q_hsd,
        k_hsd,
        col_sums,
        *q_hsd.stride(),
        *k_hsd.stride(),
        *col_sums.stride(),
        valid_len,
        padded_len,
        int(n_chunks),
        CHUNK_SIZE=chunk_size,
        D=head_dim,
        BLOCK_D=block_d,
        ND=n_d_blocks,
        num_warps=4,
    )
    # Drop the padded columns and return token-major [L, H].
    return col_sums[:, :valid_len].transpose(0, 1).contiguous()
def non_causal_chunked_attn(q: torch.Tensor, k: torch.Tensor, chunk_size: int) -> torch.Tensor:
    """Per-key chunked non-causal attention mass: ``[L, H, d]`` inputs -> ``[L, H]``.

    No ``1/sqrt(d)`` scaling is applied to the logits.  The Triton
    implementation is used when both tensors live on CUDA, otherwise the
    PyTorch reference implementation.
    """
    both_on_gpu = q.is_cuda and k.is_cuda
    impl = _non_causal_chunked_attn_triton if both_on_gpu else _non_causal_chunked_attn_pytorch
    return impl(q, k, chunk_size)
# ---------------------------------------------------------------------------
# ×||V|| + avg_pool1d(k=3) — Triton(CUDA)
# ---------------------------------------------------------------------------
@triton.jit
def _mul_vnorm_avgpool3_kernel(
    A_ptr,
    V_ptr,
    OUT_ptr,
    stride_al,
    stride_ah,
    stride_vl,
    stride_vh,
    stride_vd,
    stride_ol,
    stride_oh,
    L,
    D: tl.constexpr,
):
    """Fused ``a * ||v||`` followed by a width-3 average along the sequence.

    One program per (position ``l``, head ``h``).  For each of the positions
    ``l-1``, ``l`` and ``l+1`` it computes ``A[pos, h] * sqrt(sum(V[pos, h, :]^2))``,
    treating out-of-range positions as ``0``, and stores the sum divided by 3.

    Triton does not support nested ``def``, so the per-position logic is
    written out three times instead of being a helper.
    NOTE(review): ``tl.arange(0, D)`` generally needs ``D`` to be a power of
    two - confirm for the head sizes in use.
    """
    l = tl.program_id(0)
    h = tl.program_id(1)
    offs = tl.arange(0, D)
    # ---- position l - 1 ----
    pos_m1 = l - 1
    inb_m1 = (pos_m1 >= 0) & (pos_m1 < L)
    # Clamp the index to 0 when out of range so the address stays valid.
    ps_m1 = tl.where(inb_m1, pos_m1, 0)
    a_m1 = tl.load(
        A_ptr + ps_m1 * stride_al + h * stride_ah,
        mask=inb_m1,
        other=0.0,
    ).to(tl.float32)
    v_m1 = tl.load(
        V_ptr + ps_m1 * stride_vl + h * stride_vh + offs * stride_vd,
        mask=inb_m1,
        other=0.0,
    ).to(tl.float32)
    s_m1 = tl.where(inb_m1, a_m1 * tl.sqrt(tl.sum(v_m1 * v_m1)), 0.0)
    # ---- position l ----
    inb_0 = (l >= 0) & (l < L)
    ps0 = tl.where(inb_0, l, 0)
    a0 = tl.load(
        A_ptr + ps0 * stride_al + h * stride_ah,
        mask=inb_0,
        other=0.0,
    ).to(tl.float32)
    v0 = tl.load(
        V_ptr + ps0 * stride_vl + h * stride_vh + offs * stride_vd,
        mask=inb_0,
        other=0.0,
    ).to(tl.float32)
    s_0 = tl.where(inb_0, a0 * tl.sqrt(tl.sum(v0 * v0)), 0.0)
    # ---- position l + 1 ----
    pos_p1 = l + 1
    inb_p1 = (pos_p1 >= 0) & (pos_p1 < L)
    ps_p1 = tl.where(inb_p1, pos_p1, 0)
    a_p1 = tl.load(
        A_ptr + ps_p1 * stride_al + h * stride_ah,
        mask=inb_p1,
        other=0.0,
    ).to(tl.float32)
    v_p1 = tl.load(
        V_ptr + ps_p1 * stride_vl + h * stride_vh + offs * stride_vd,
        mask=inb_p1,
        other=0.0,
    ).to(tl.float32)
    s_p1 = tl.where(inb_p1, a_p1 * tl.sqrt(tl.sum(v_p1 * v_p1)), 0.0)
    # Always divide by 3, so boundary positions average in a zero.
    out = (s_m1 + s_0 + s_p1) * (1.0 / 3.0)
    tl.store(OUT_ptr + l * stride_ol + h * stride_oh, out)
def
_mul_vnorm_avgpool3_fused
(
a
:
torch
.
Tensor
,
v
:
torch
.
Tensor
,
out
:
torch
.
Tensor
|
None
=
None
)
->
torch
.
Tensor
:
assert
a
.
dim
()
==
2
and
v
.
dim
()
==
3
and
a
.
shape
[
0
]
==
v
.
shape
[
0
]
and
a
.
shape
[
1
]
==
v
.
shape
[
1
]
L
,
H
,
D
=
v
.
shape
a
=
a
.
contiguous
()
v
=
v
.
contiguous
()
if
a
.
dtype
!=
torch
.
float32
:
a
=
a
.
float
()
if
out
is
None
:
out
=
torch
.
empty
((
L
,
H
),
device
=
v
.
device
,
dtype
=
torch
.
float32
)
if
L
==
0
or
H
==
0
:
return
out
grid
=
(
L
,
H
)
_mul_vnorm_avgpool3_kernel
[
grid
](
a
,
v
,
out
,
a
.
stride
(
0
),
a
.
stride
(
1
),
v
.
stride
(
0
),
v
.
stride
(
1
),
v
.
stride
(
2
),
out
.
stride
(
0
),
out
.
stride
(
1
),
L
,
D
=
D
,
num_warps
=
4
,
)
return
out
def
_maybe_mul_vnorm_avgpool3_fused
(
a
:
torch
.
Tensor
,
v
:
torch
.
Tensor
)
->
torch
.
Tensor
:
if
not
a
.
is_cuda
or
not
v
.
is_cuda
:
import
torch.nn.functional
as
F
s
=
a
*
v
.
norm
(
dim
=-
1
)
return
(
F
.
avg_pool1d
(
s
.
transpose
(
0
,
1
).
unsqueeze
(
0
),
kernel_size
=
3
,
padding
=
1
,
stride
=
1
)
.
squeeze
(
0
)
.
transpose
(
0
,
1
)
)
return
_mul_vnorm_avgpool3_fused
(
a
,
v
)
@triton.jit
def _zscore_elem_1d_kernel(
    X_ptr,
    OUT_ptr,
    n,
    mean,
    inv_std,
    BLOCK: tl.constexpr,
):
    """Element-wise ``(x - mean) * inv_std`` over a flat buffer of ``n`` values.

    ``mean`` and ``inv_std`` are scalars supplied by the caller; each program
    handles one ``BLOCK``-sized slice and masks off indices ``>= n``.
    """
    pid = tl.program_id(0)
    offs = pid * BLOCK + tl.arange(0, BLOCK)
    mask = offs < n
    x = tl.load(X_ptr + offs, mask=mask, other=0.0)
    tl.store(OUT_ptr + offs, (x - mean) * inv_std, mask=mask)
def
_zscore_flat_f32_global
(
x
:
torch
.
Tensor
)
->
torch
.
Tensor
:
"""
与 kvpress ``(t - t.mean()) / t.std()`` 一致的一维全局 z-score。
``mean/std`` 用 PyTorch;CUDA 上缩放阶段用 Triton 逐元素写入。
"""
if
x
.
numel
()
==
0
:
return
x
mu
=
x
.
mean
()
sig
=
x
.
std
().
clamp_min
(
1e-6
)
inv
=
1.0
/
sig
if
not
x
.
is_cuda
:
return
(
x
-
mu
)
*
inv
x
=
x
.
contiguous
()
out
=
torch
.
empty_like
(
x
)
n
=
x
.
numel
()
BLOCK
=
1024
grid
=
(
triton
.
cdiv
(
n
,
BLOCK
),)
_zscore_elem_1d_kernel
[
grid
](
x
,
out
,
n
,
float
(
mu
.
item
()),
float
(
inv
.
item
()),
BLOCK
=
BLOCK
,
num_warps
=
4
,
)
return
out
def _attn_scores_kvpress_middle(
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    cu_seqlens: torch.Tensor,
    sink_start: int,
    sink_end: int,
    chunk_size: int,
    do_zscore: bool = True,
) -> torch.Tensor:
    """Non-causal attention scores on the middle span of every packed sequence.

    For each sequence the first ``sink_start`` and last ``sink_end`` tokens are
    skipped.  On the remaining span the chunked non-causal attention mass is
    computed per query head, averaged over each KV head's query group,
    multiplied by ``||V||`` and smoothed with a width-3 average.  The spans of
    all sequences are then optionally z-scored together and written back.

    Args:
        q: Packed queries ``[N, HQ, D]``.
        k: Packed keys ``[N, Hkv, D]``.
        v: Packed values ``[N, Hkv, D]``.
        cu_seqlens: 1-D cumulative sequence boundaries.
        sink_start: Number of leading tokens excluded per sequence.
        sink_end: Number of trailing tokens excluded per sequence.
        chunk_size: Chunk length for the non-causal attention.
        do_zscore: Whether to z-score the concatenated middle scores.

    Returns:
        Float32 ``[N, Hkv]`` tensor; positions outside a middle span are 0.
    """
    N, HQ, _ = q.shape
    Hkv = k.shape[1]
    G = HQ // Hkv  # query heads per KV head
    attn_out = torch.zeros(N, Hkv, device=q.device, dtype=torch.float32)

    # Read all boundaries with a single device->host transfer and compute each
    # middle span once; both the scoring pass and the write-back reuse them.
    bounds = [int(x) for x in cu_seqlens.tolist()]
    mid_spans: list[tuple[int, int]] = []
    for k_beg, k_end in zip(bounds[:-1], bounds[1:]):
        seq_len = k_end - k_beg
        if seq_len <= 0:
            continue
        left_keep = min(sink_start, seq_len)
        right_keep = min(sink_end, max(0, seq_len - left_keep))
        mid_start = k_beg + left_keep
        mid_end = k_end - right_keep
        if mid_start < mid_end:
            mid_spans.append((mid_start, mid_end))

    if not mid_spans:
        return attn_out

    parts: list[torch.Tensor] = []
    for mid_start, mid_end in mid_spans:
        q_m = q[mid_start:mid_end, :, :].contiguous()
        k_m = k[mid_start:mid_end, :, :].contiguous()
        v_m = v[mid_start:mid_end, :, :].contiguous()
        # HF ``repeat_kv`` expects ``[batch, num_kv_heads, seq_len, head_dim]``.
        k_4d = k_m.unsqueeze(0).transpose(1, 2).contiguous()  # [1, Hkv, Lm, D]
        k_rep = repeat_kv(k_4d, G)[0].transpose(0, 1).contiguous()  # [Lm, HQ, D]
        A = non_causal_chunked_attn(q_m, k_rep, chunk_size)
        Lm, HQa = A.shape
        assert HQa == HQ
        # Average the query heads that share a KV head.
        A = A.view(Lm, Hkv, G).mean(dim=-1)
        scores = _maybe_mul_vnorm_avgpool3_fused(A, v_m)
        parts.append(scores.reshape(-1))

    flat_a = torch.cat(parts, dim=0)
    z_a = _zscore_flat_f32_global(flat_a) if do_zscore else flat_a

    offset = 0
    for mid_start, mid_end in mid_spans:
        rows = mid_end - mid_start
        n = rows * Hkv
        attn_out[mid_start:mid_end, :] = z_a[offset : offset + n].view(rows, Hkv)
        offset += n
    return attn_out
def non_causal_attn_scores(
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    cu_seqlens_qk: torch.Tensor,
    max_seqlen_qk: int,
    chunk_size: int,
    sm_scale: Optional[float] = None,
    normalize: bool = True,
    context_lens: Optional[List[int]] = None,
    protected_first_tokens: Optional[List[int]] = None,
    protected_last_tokens: Optional[List[int]] = None,
    *,
    accum_scores: Optional[torch.Tensor] = None,
    accum_blending: Optional[float] = None,
) -> torch.Tensor:
    """Non-causal attention scores with optional accumulation and protection.

    ``sm_scale`` and ``max_seqlen_qk`` are accepted for interface
    compatibility and ignored: the logits are not scaled by ``1/sqrt(d)``.
    The sink sizes are fixed at 8 leading and 4 trailing tokens here.

    Steps:
        1. Middle-span scores via ``_attn_scores_kvpress_middle``; with
           ``normalize=True`` they are z-scored jointly.
        2. If ``accum_scores`` is given, add ``accum_blending * accum_scores``
           (weight ``0.5`` when ``accum_blending`` is ``None``).
        3. If ``protected_first_tokens``, ``protected_last_tokens`` and
           ``context_lens`` are all provided, set the first/last protected
           tokens of each sequence to ``+inf``.  Counts are clamped to
           ``[0, sequence length]`` so an oversized count cannot spill into a
           neighbouring sequence or wrap around via a negative slice index.

    Returns:
        Float32 ``[N, Hkv]`` score tensor.
    """
    del sm_scale, max_seqlen_qk
    sink_start, sink_end = 8, 4
    out = _attn_scores_kvpress_middle(
        q,
        k,
        v,
        cu_seqlens_qk,
        sink_start,
        sink_end,
        chunk_size,
        do_zscore=normalize,
    )
    if accum_scores is not None:
        w = 0.5 if accum_blending is None else float(accum_blending)
        out = out + w * accum_scores.to(device=out.device, dtype=out.dtype)

    if protected_first_tokens is not None and protected_last_tokens is not None and context_lens:
        start = 0
        for first, last, Lc in zip(protected_first_tokens, protected_last_tokens, context_lens):
            seq_len = int(Lc)
            stop = start + seq_len
            n_first = min(max(int(first), 0), seq_len)
            n_last = min(max(int(last), 0), seq_len)
            out[start : start + n_first].fill_(torch.inf)
            out[stop - n_last : stop].fill_(torch.inf)
            start = stop
    return out
def kvpress_compactor_post_rope(
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    cu_seqlens: torch.Tensor,
    pre_rope_scores: torch.Tensor,
    compression_ctx,
    max_seqlen_q: int,
    chunk_size: int,
    blending: float,
) -> torch.Tensor:
    """Blend leverage and attention scores, then pad sinks and protected tokens.

    On each sequence's middle span the result is
    ``blending * pre_rope_scores + attention_scores``.  Sink tokens (the first
    ``sink_size_start`` and last ``sink_size_end`` of each sequence, defaults
    8 and 4) are set to the global maximum of the blended scores, or ``1.0``
    when that maximum is zero or not finite.  If the context provides
    ``protected_first_tokens``, ``protected_last_tokens`` and ``context_lens``,
    those tokens are finally set to ``+inf``.

    Args:
        q, k, v: Packed post-RoPE tensors (``[N, HQ, D]`` / ``[N, Hkv, D]``).
        cu_seqlens: 1-D cumulative sequence boundaries.
        pre_rope_scores: ``[N, Hkv]`` leverage scores.
        compression_ctx: Source of the optional attributes named above.
        max_seqlen_q: Unused; kept for interface compatibility.
        chunk_size: Chunk length for the non-causal attention.
        blending: Weight applied to ``pre_rope_scores``.

    Returns:
        Float32 ``[N, Hkv]`` blended scores.
    """
    del max_seqlen_q
    device = q.device
    sink_start = int(getattr(compression_ctx, "sink_size_start", 8))
    sink_end = int(getattr(compression_ctx, "sink_size_end", 4))
    context_lens: Optional[List[int]] = getattr(compression_ctx, "context_lens", None)
    protected_first: Optional[List[int]] = getattr(compression_ctx, "protected_first_tokens", None)
    protected_last: Optional[List[int]] = getattr(compression_ctx, "protected_last_tokens", None)

    attn_out = _attn_scores_kvpress_middle(q, k, v, cu_seqlens, sink_start, sink_end, chunk_size)
    lev = pre_rope_scores.to(device=device, dtype=torch.float32)
    blended = torch.zeros_like(lev)
    if blended.numel() == 0:
        # ``max()`` below is undefined on an empty tensor.
        return blended

    # (sequence start, middle start, middle end, sequence end) per non-empty
    # sequence, computed once from a single device->host transfer.
    bounds = [int(x) for x in cu_seqlens.tolist()]
    spans: list[tuple[int, int, int, int]] = []
    for k_beg, k_end in zip(bounds[:-1], bounds[1:]):
        seq_len = k_end - k_beg
        if seq_len <= 0:
            continue
        left_keep = min(sink_start, seq_len)
        right_keep = min(sink_end, max(0, seq_len - left_keep))
        spans.append((k_beg, k_beg + left_keep, k_end - right_keep, k_end))

    for _k_beg, mid_start, mid_end, _k_end in spans:
        if mid_start < mid_end:
            blended[mid_start:mid_end, :] = (
                blending * lev[mid_start:mid_end, :] + attn_out[mid_start:mid_end, :]
            )

    # Sinks get the largest blended score so they rank at least as high as
    # any middle token.
    pad_val = blended.max()
    if not torch.isfinite(pad_val) or pad_val == 0:
        pad_val = torch.tensor(1.0, device=device, dtype=torch.float32)

    for k_beg, mid_start, mid_end, k_end in spans:
        if mid_start > k_beg:
            blended[k_beg:mid_start, :] = pad_val
        if k_end > mid_end:
            blended[mid_end:k_end, :] = pad_val

    if protected_first is not None and protected_last is not None and context_lens:
        start = 0
        for first, last, Lc in zip(protected_first, protected_last, context_lens):
            seq_len = int(Lc)
            stop = start + seq_len
            # Clamp so oversized counts cannot touch neighbouring sequences or
            # wrap around through a negative slice index.
            n_first = min(max(int(first), 0), seq_len)
            n_last = min(max(int(last), 0), seq_len)
            blended[start : start + n_first].fill_(torch.inf)
            blended[stop - n_last : stop].fill_(torch.inf)
            start = stop
    return blended
vllm/compactor-vllm/src/compactor_vllm/compression/compactor_origin.py
deleted
100644 → 0
View file @
2b7160c6
import
logging
import
math
from
typing
import
List
,
Optional
import
torch
import
triton
from
tqdm.contrib.logging
import
logging_redirect_tqdm
from
triton
import
language
as
tl
from
compactor_vllm.compression.common
import
BaseCompressionMethod
from
compactor_vllm.utils.helpers
import
maybe_execute_in_stream
from
compactor_vllm.utils.triton_compat
import
autotune
as
triton_autotune
logger
=
logging
.
getLogger
(
__name__
)
class CompactorCompression(BaseCompressionMethod):
    """Compactor scoring: approximate leverage scores plus non-causal attention scores."""

    # Chunk length forwarded to the post-RoPE non-causal attention scoring.
    chunk_size: int = 128

    @staticmethod
    def pre_rope_scoring(
        q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, context
    ) -> Optional[torch.Tensor]:
        """Score pre-RoPE keys with ``approximate_leverage_scores``.

        Only ``k`` is used.  Context lengths, the sketch matrix ``PHI`` and the
        chunk size come from ``context.compression_context``; normalization is
        enabled.
        """
        ctx = context.compression_context
        return maybe_execute_in_stream(
            approximate_leverage_scores,
            k,
            ctx.context_lens,
            ctx.PHI,
            normalize=True,
            chunk_size=ctx.compression_chunk_size,
            STORE_STREAM=context.STORE_STREAM,
        )

    @staticmethod
    def post_rope_scoring(
        q: torch.Tensor,
        k: torch.Tensor,
        v: torch.Tensor,
        pre_rope_scores: torch.Tensor,
        context,
    ) -> Optional[torch.Tensor]:
        """Compute post-RoPE attention scores and accumulate ``pre_rope_scores``.

        ``pre_rope_scores`` is passed as ``accum_scores`` with blending weight
        ``0.5``; protected first/last token counts come from the compression
        context.
        NOTE(review): unlike ``pre_rope_scoring`` no ``STORE_STREAM`` is passed
        here - confirm that is intentional.
        """
        ctx = context.compression_context
        return maybe_execute_in_stream(
            non_causal_attn_scores,
            q,
            k,
            v,
            context.cu_seqlens_q,
            context.max_seqlen_q,
            chunk_size=CompactorCompression.chunk_size,
            sm_scale=1.0,
            normalize=True,
            accum_scores=pre_rope_scores,
            context_lens=ctx.context_lens,
            protected_first_tokens=ctx.protected_first_tokens,
            protected_last_tokens=ctx.protected_last_tokens,
            accum_blending=0.5,
        )
def split_into_chunks(xs, chunk_size):
    """
    Decompose sequence lengths into fixed-size chunks plus optional tails.

    Each length ``n`` in ``xs`` is split into ``n // chunk_size`` full chunks
    (the "prologue") and at most one shorter remainder (the "epilogue").
    Two parallel views of that decomposition are returned:

    * ``coalesced_chunks`` - one entry per prologue (all full chunks of a
      sequence merged into a single length) and one per non-empty epilogue.
    * ``chunks`` - every individual chunk length: ``chunk_size`` repeated for
      the prologue, followed by the epilogue length if it is non-zero.

    Example:
        xs = [257, 127], chunk_size = 128
        coalesced_chunks = [256, 1, 127]
        chunks = [128, 128, 1, 127]

    Args:
        :param xs:
            Iterable of non-negative integers
        :param chunk_size:
            Target chunk size

    Returns:
        :return Tuple[List[int], List[int]]:
            ``(coalesced_chunks, chunks)`` as described above.
    """
    merged, pieces = [], []
    for length in xs:
        full_count, tail = divmod(length, chunk_size)
        head = full_count * chunk_size
        if head > 0:
            merged.append(head)
            pieces += [chunk_size] * full_count
        if tail > 0:
            merged.append(tail)
            pieces.append(tail)
    return merged, pieces
def approximate_leverage_scores(
    key_states: torch.Tensor,  # [N, H, D]
    context_lens: List[int],  # [B]
    PHI: torch.Tensor,  # [D, k]
    regularizer: float = 5e-3,
    normalize: bool = False,
    chunk_size: int = 512,
) -> torch.Tensor:  # returns [N, H]
    """
    Approximate leverage scores for keys via randomized sketching.

    This implements a randomized approximation to per-token leverage scores for
    the key matrix, as described in Compactor: Calibrated Query-Agnostic KV Cache
    Compression with Approximate Leverage Scores (https://arxiv.org/abs/2507.08143).

    The keys are sketched with ``PHI``, split into chunks, centred per chunk
    (in place on the sketched tensor), and a regularized Gram matrix is formed
    per chunk and head.  Scores are the squared row norms of the sketched rows
    projected onto the SVD basis of the Gram matrix scaled by the inverse
    square root of its singular values.

    Args:
        :param key_states:
            Tensor of shape ``[N, H, D]`` containing pre-RoPE key states for
            all tokens across the batch, packed along the sequence dimension.
            ``N = sum(context_lens)``.
        :param context_lens:
            List of per-sequence context lengths, length ``B``.
        :param PHI:
            Random projection matrix of shape ``[D, k]`` used to sketch the
            keys into a lower-dimensional subspace (k < D).
        :param regularizer:
            Small positive scalar added to the diagonal of each Gram matrix
            before SVD to improve numerical stability. Defaults to ``5e-3``.
            If the SVD fails, ``10 * regularizer`` is added on top and the
            SVD is retried once before falling back to QR.
        :param normalize:
            If True, z-score the scores in place, separately for each entry
            of the chunk list (per chunk when ``chunk_size > 0``, per sequence
            otherwise) and jointly over heads.
        :param chunk_size:
            Target chunk size along the sequence dimension. If > 0, the
            concatenated sequence is split into chunks of at most this size
            before forming Gram matrices and SVD. If <= 0, the entire sequence
            for each context is treated as a single chunk.

    Returns:
        :return torch.Tensor:
            Approximate leverage scores of shape ``[N, H]``, where each row
            corresponds to a token and each column to a head.
    """
    if chunk_size > 0:
        coalesced_chunk_lens, chunks_lens = split_into_chunks(context_lens, chunk_size)
    else:
        coalesced_chunk_lens = chunks_lens = list(context_lens)

    # Keep the prefix base on the same device as the keys; a bare ``.cuda()``
    # targets the *current* device, which may differ on multi-GPU setups.
    chunk_lens_cuda = torch.tensor([0] + list(chunks_lens)).to(
        key_states.device, non_blocking=True
    )

    X = torch.matmul(key_states.transpose(0, 1), PHI)
    H, N, k = X.shape
    chunks = torch.split(X, coalesced_chunk_lens, dim=-2)

    def _is_single(length: int) -> bool:
        # True for an epilogue / whole-sequence chunk, False for a coalesced
        # run of full ``chunk_size`` chunks.
        return chunk_size <= 0 or length % chunk_size != 0

    gram_matrices = []
    for chunk, L in zip(chunks, coalesced_chunk_lens):
        if _is_single(L):
            # ``chunk`` is a view of X, so centring happens in place on X.
            chunk.sub_(chunk.mean(dim=-2, keepdim=True))
            g = torch.matmul(chunk.transpose(-1, -2), chunk)  # [H, k, k]
            g = g.unsqueeze(1)
        else:
            chunk = chunk.view(H, -1, chunk_size, k)  # [H, num_chunks, chunk_size, k]
            chunk.sub_(chunk.mean(dim=-2, keepdim=True))
            g = torch.matmul(chunk.transpose(-1, -2), chunk)  # [H, num_chunks, k, k]
        gram_matrices.append(g)

    G = torch.cat(gram_matrices, dim=1).to(torch.float32)
    diag = G.diagonal(dim1=-2, dim2=-1)

    # First attempt with ``regularizer``; on failure add ``10 * regularizer``
    # on top and retry once.
    svd_result = None
    for bump in (regularizer, regularizer * 10):
        diag.add_(bump)
        try:
            svd_result = torch.linalg.svd(G, full_matrices=False, driver="gesvda")
        except RuntimeError:
            continue
        break

    if svd_result is None:
        with logging_redirect_tqdm():
            logger.warning(
                "GESVDA failed, falling back to QR decomposition, which will be MUCH slower. "
                "Try increasing chunk_size if this issue persists."
            )
        # this is over 50 times slower than using GESVDA
        return _approximate_leverage_scores_qr_fallback(
            X=X,
            chunks_lens=chunks_lens,
            chunk_lens_cuda=chunk_lens_cuda,
            normalize=normalize,
            chunk_size=chunk_size,
        )

    V, S, _ = svd_result
    # Scale each basis column by 1/sqrt(singular value).
    SV = (V * S.rsqrt().unsqueeze(-2)).to(X.dtype)

    start = 0
    all_scores = []
    for chunk, L in zip(chunks, coalesced_chunk_lens):
        if _is_single(L):
            num_chunks = 1
            sv = SV[:, start]
        else:
            num_chunks = L // chunk_size
            chunk = chunk.view(H, -1, chunk_size, k)  # [H, NC, CS, k]
            sv = SV[:, start : start + num_chunks]
        U = torch.matmul(chunk, sv)
        scores = (U * U).sum(dim=-1).clamp_min_(0.0).view(H, -1)
        all_scores.append(scores.transpose(-1, -2))
        start += num_chunks

    scores = torch.cat(all_scores, dim=0)
    if normalize:
        grid = (len(chunks_lens),)
        cu_k = chunk_lens_cuda.cumsum(dim=0)
        _zscore_per_batch_epilogue_no_window[grid](
            scores, cu_k, scores.stride(0), scores.stride(1), H
        )
    return scores
@triton_autotune(
    configs=[triton.Config({"BLOCK_K": bk}) for bk in [32, 64, 128]],
    key=["HK"],
    cache_results=True,
)
@triton.jit
def _zscore_per_batch_epilogue_no_window(
    OUT,  # [Nk, Hk] scores, normalized in place
    cu_k,  # [B+1] cumulative segment boundaries
    STRIDE_OUT_NK,
    STRIDE_OUT_HK,
    HK: tl.constexpr,  # Hk
    BLOCK_K: tl.constexpr,  # e.g., 128
):
    """In-place z-score of ``OUT`` per segment, jointly over all heads.

    One program per segment ``[cu_k[b], cu_k[b+1])``.  A first pass
    accumulates the sum and sum of squares (in float32) over every row of the
    segment and all ``HK`` heads; a second pass rewrites each value as
    ``(x - mean) / std`` using the population variance.  The variance is
    floored at ``1e-12`` so a zero-variance segment (for example a one-token
    chunk whose scores are all equal) yields zeros instead of ``0 * inf = NaN``.
    """
    b = tl.program_id(0)
    k_beg = tl.load(cu_k + b)
    k_end = tl.load(cu_k + b + 1)
    if k_end <= k_beg:
        return

    sumv = tl.zeros([], dtype=tl.float32)
    sumsq = tl.zeros([], dtype=tl.float32)
    count = ((k_end - k_beg) * HK).to(tl.float32)

    # Pass 1: accumulate first and second moments.
    for ks in tl.range(k_beg, k_end, BLOCK_K):
        nk = ks + tl.arange(0, BLOCK_K)
        kmask = nk < k_end
        for h in tl.range(0, HK):
            ptrs = OUT + nk * STRIDE_OUT_NK + h * STRIDE_OUT_HK
            vals = tl.load(ptrs, mask=kmask, other=0.0).to(tl.float32)
            sumv += tl.sum(vals, 0)
            sumsq += tl.sum(vals * vals, 0)

    mean = sumv / count
    # E[x^2] - mean^2 can dip below zero from round-off; clamp at 0 first.
    var = tl.maximum(sumsq / count - mean * mean, 0.0)
    # Floor the variance so the reciprocal stays finite when var == 0.
    invstd = 1.0 / tl.sqrt(tl.maximum(var, 1e-12))

    # Pass 2: normalize in place.
    for ks in tl.range(k_beg, k_end, BLOCK_K):
        nk = ks + tl.arange(0, BLOCK_K)
        kmask = nk < k_end
        for h in tl.range(0, HK):
            ptrs = OUT + nk * STRIDE_OUT_NK + h * STRIDE_OUT_HK
            vals = tl.load(ptrs, mask=kmask, other=0.0).to(tl.float32)
            vals = (vals - mean) * invstd
            tl.store(ptrs, vals, mask=kmask)
def
_approximate_leverage_scores_qr_fallback
(
X
:
torch
.
Tensor
,
# [H, N, k], already sketched (KΦ) and centered in-place
chunks_lens
:
List
[
int
],
# [num_chunks]
chunk_lens_cuda
:
torch
.
Tensor
,
# [num_chunks + 1] (prefix base)
normalize
:
bool
,
chunk_size
:
int
,
)
->
torch
.
Tensor
:
H
,
N
,
k
=
X
.
shape
device
,
dtype
=
X
.
device
,
X
.
dtype
offsets
:
List
[
int
]
=
[]
offset
=
0
for
L
in
chunks_lens
:
offsets
.
append
(
offset
)
offset
+=
L
if
offset
!=
N
:
raise
RuntimeError
(
f
"QR fallback: sum(chunks_lens)=
{
offset
}
does not match N=
{
N
}
"
)
blocks
=
torch
.
split
(
X
,
chunks_lens
,
dim
=-
2
)
scores
=
torch
.
empty
(
N
,
H
,
device
=
device
,
dtype
=
dtype
)
if
chunk_size
>
0
:
full_indices
=
[
i
for
i
,
L
in
enumerate
(
chunks_lens
)
if
L
==
chunk_size
]
epi_indices
=
[
i
for
i
,
L
in
enumerate
(
chunks_lens
)
if
L
!=
chunk_size
]
if
full_indices
:
# stack full chunks
full_blocks
=
torch
.
stack
(
[
blocks
[
i
]
for
i
in
full_indices
],
dim
=
0
)
# [M, H, CS, k]
M
,
Hf
,
Lf
,
kf
=
full_blocks
.
shape
assert
Lf
==
chunk_size
# merge (M, H) into a single batch dim for torch.linalg.q
full_blocks_2d
=
full_blocks
.
view
(
M
*
Hf
,
Lf
,
kf
).
to
(
torch
.
float32
)
U_full
,
_
=
torch
.
linalg
.
qr
(
full_blocks_2d
,
mode
=
"reduced"
)
U_full
=
U_full
.
to
(
dtype
)
scores_full
=
(
U_full
*
U_full
).
sum
(
dim
=-
1
).
clamp_min
(
0.0
)
# [M * Hf, Lf]
scores_full
=
scores_full
.
view
(
M
,
Hf
,
Lf
).
transpose
(
-
1
,
-
2
)
# [M, H, CS]
for
m
,
chunk_idx
in
enumerate
(
full_indices
):
start
=
offsets
[
chunk_idx
]
Lc
=
chunks_lens
[
chunk_idx
]
scores
[
start
:
start
+
Lc
].
copy_
(
scores_full
[
m
])
else
:
epi_indices
=
list
(
range
(
len
(
chunks_lens
)))
for
chunk_idx
in
epi_indices
:
block
=
blocks
[
chunk_idx
]
_
,
Lc
,
_
=
block
.
shape
if
Lc
==
0
:
continue
U_epi
,
_
=
torch
.
linalg
.
qr
(
block
.
to
(
torch
.
float32
),
mode
=
"reduced"
)
scores_epi
=
(
U_epi
*
U_epi
).
sum
(
dim
=-
1
).
to
(
dtype
)
# [H, Lc]
start
=
offsets
[
chunk_idx
]
scores
[
start
:
start
+
Lc
]
=
scores_epi
.
transpose
(
0
,
1
)
# [Lc, H]
if
normalize
:
grid
=
(
len
(
chunks_lens
),)
cu_k
=
chunk_lens_cuda
.
cumsum
(
dim
=
0
)
_zscore_per_batch_epilogue_no_window
[
grid
](
scores
,
cu_k
,
scores
.
stride
(
0
),
scores
.
stride
(
1
),
H
)
return
scores
# Chunked non-causal attention scoring kernel.
#
# Launch grid is (KV head, sequence, chunk-within-sequence). Each program
# takes one chunk of at most CHUNK_SIZE consecutive tokens of one sequence,
# lets every query of that chunk (all QUERY_GROUP_SIZE query heads mapped to
# this KV head) attend to every key of the same chunk without a causal mask,
# and adds the per-key column sums of the softmax probabilities into
# ``accum_scores[key, kv_head]``.
#
# V and the STRIDE_V_* arguments, as well as WARPSPEC, are accepted but never
# read in the body below.
@triton_autotune(
    configs=[
        triton.Config({"BLOCK_M": BM, "BLOCK_K": BK, "WARPSPEC": False}, num_warps=w, num_stages=s)
        for BM in [64]
        for BK in [64]
        for w in [4]
        for s in [2]
    ],
    key=[
        "QUERY_GROUP_SIZE",
        "D",
        "CHUNK_SIZE",
    ],
    cache_results=True,
)
@triton.jit
def _non_causal_attn_kernel(
    Q,
    K,
    V,
    accum_scores,
    cu_seqlens_qk,
    #
    STRIDE_Q_G,
    STRIDE_Q_N,
    STRIDE_Q_H,
    STRIDE_Q_D,
    STRIDE_K_G,
    STRIDE_K_N,
    STRIDE_K_D,
    STRIDE_V_G,
    STRIDE_V_N,
    STRIDE_V_D,
    STRIDE_OUT_N,
    STRIDE_OUT_H,
    sm_scale,
    #
    CHUNK_SIZE: tl.constexpr,
    QUERY_GROUP_SIZE: tl.constexpr,
    BLOCK_M: tl.constexpr,
    BLOCK_K: tl.constexpr,
    D: tl.constexpr,
    WARPSPEC: tl.constexpr,
):
    # A query tile holds BLOCK_M tokens for each of the QUERY_GROUP_SIZE heads.
    TOTAL_QUERIES_PER_BLOCK: tl.constexpr = BLOCK_M * QUERY_GROUP_SIZE
    # Value written for padding query rows in pass 2 (see below).
    INVERSE_CHUNK: tl.constexpr = 1.0 / CHUNK_SIZE

    pid_g = tl.program_id(0)  # KV head in [0, HKV)
    pid_b = tl.program_id(1)  # batch id
    pid_m = tl.program_id(2)  # chunk id within batch

    # Token range [off_b, off_b1) of this sequence in the packed token axis.
    off_b = tl.load(cu_seqlens_qk + pid_b)
    off_b1 = tl.load(cu_seqlens_qk + pid_b + 1)

    # Token range of this chunk, clipped to the end of the sequence.
    chunk_start = off_b + pid_m * CHUNK_SIZE
    chunk_end = tl.minimum(chunk_start + CHUNK_SIZE, off_b1)
    M = chunk_end - chunk_start
    # The grid is sized for the longest sequence, so shorter sequences get
    # programs whose chunk lies entirely past their end: nothing to do.
    if M <= 0:
        return

    offs_d = tl.arange(0, D)
    offs_k = tl.arange(0, BLOCK_K)

    # Flattened query rows inside a [BLOCK_M, QUERY_GROUP_SIZE] tile
    offs_q = tl.arange(0, TOTAL_QUERIES_PER_BLOCK)
    row_m = offs_q % BLOCK_M  # token offset in this tile
    row_h = offs_q // BLOCK_M  # query-group index

    # exp(x) == exp2(x * log2(e)); folding log2(e) into the scale lets the
    # softmax below use exp2 throughout.
    qk_scale = sm_scale * 1.44269504  # convert to log2-domain
    # Large finite negative used instead of -inf for masked logits.
    NEG_INF = -1.0e9

    # Iterate over query tiles within this chunk
    for qs in tl.range(chunk_start, chunk_end, BLOCK_M):
        # Global query indices for rows in this tile
        q_idx = qs + row_m  # [TOTAL_QUERIES_PER_BLOCK]
        q_mask = q_idx < chunk_end  # mask for valid rows in this tile

        # Load Q tile: [TOTAL_QUERIES_PER_BLOCK, D]
        q_ptrs = (
            Q
            + pid_g * STRIDE_Q_G
            + q_idx[:, None] * STRIDE_Q_N
            + row_h[:, None] * STRIDE_Q_H
            + offs_d[None, :] * STRIDE_Q_D
        )
        q = tl.load(q_ptrs, mask=q_mask[:, None], other=0.0)

        # ---- Pass 1: per-row max and denominator over all keys in this chunk ----
        row_max = tl.full([TOTAL_QUERIES_PER_BLOCK], NEG_INF, tl.float32)
        row_sum = tl.zeros([TOTAL_QUERIES_PER_BLOCK], dtype=tl.float32)
        for ks in tl.range(chunk_start, chunk_end, BLOCK_K):
            k_idx = ks + offs_k  # [BLOCK_K]
            k_mask = k_idx < chunk_end  # which keys are valid in this tile
            k_ptrs = (
                K
                + pid_g * STRIDE_K_G
                + k_idx[:, None] * STRIDE_K_N
                + offs_d[None, :] * STRIDE_K_D
            )
            k = tl.load(k_ptrs, mask=k_mask[:, None], other=0.0)  # [BLOCK_K, D]

            # logits: [TOTAL_QUERIES_PER_BLOCK, BLOCK_K]
            qk = tl.dot(q, k.T) * qk_scale
            qk = tl.where(q_mask[:, None] & k_mask[None, :], qk, NEG_INF)

            # Online softmax update: track the running max and keep the
            # running sum expressed relative to that max.
            cur_max = tl.max(qk, 1)
            new_max = tl.maximum(row_max, cur_max)
            # rescale previous sum to new_max (base 2)
            rescale = tl.math.exp2(row_max - new_max)
            p = tl.math.exp2(qk - new_max[:, None])
            row_sum = row_sum * rescale + tl.sum(p, 1)
            row_max = new_max

        # Padding rows are overwritten after the division in pass 2; they get
        # a denominator of 1.0 here, independent of their accumulated row_sum.
        denom = tl.where(q_mask, row_sum, 1.0)

        # ---- Pass 2: recompute logits, normalise, and accumulate per-key sums ----
        for ks in tl.range(chunk_start, chunk_end, BLOCK_K):
            k_idx = ks + offs_k
            k_mask = k_idx < chunk_end
            k_ptrs = (
                K
                + pid_g * STRIDE_K_G
                + k_idx[:, None] * STRIDE_K_N
                + offs_d[None, :] * STRIDE_K_D
            )
            k = tl.load(k_ptrs, mask=k_mask[:, None], other=0.0)
            qk = tl.dot(q, k.T) * qk_scale
            qk = tl.where(q_mask[:, None] & k_mask[None, :], qk, NEG_INF)

            # p has shape [TOTAL_QUERIES_PER_BLOCK, BLOCK_K]
            p = tl.math.exp2(qk - row_max[:, None]) / denom[:, None]
            # Padding query rows (q_idx >= chunk_end) are NOT zeroed: every
            # entry of such a row is set to 1 / CHUNK_SIZE, so each padding row
            # adds that constant to every key of the tile. Per the original
            # author's note this is meant to "preserve attention mass in
            # shorter chunks". NOTE(review): confirm this is the intended
            # normalisation for partially filled tiles.
            p = tl.where(q_mask[:, None], p, INVERSE_CHUNK)
            contrib = tl.sum(p, 0)  # [BLOCK_K], sum over queries & query-groups

            # Read-modify-write without atomics: the keys of a
            # (KV head, sequence, chunk) triple are touched by this program
            # only, and its query tiles are processed one after another.
            out_ptrs = accum_scores + k_idx * STRIDE_OUT_N + pid_g * STRIDE_OUT_H
            old = tl.load(out_ptrs, mask=k_mask, other=0.0)
            new = old + contrib.to(old.dtype)
            tl.store(out_ptrs, new, mask=k_mask)
def non_causal_attn_scores(
    q: torch.Tensor,  # [N, HQ, D]
    k: torch.Tensor,  # [N, HKV, D]
    v: torch.Tensor,  # [N, HKV, D]
    cu_seqlens_qk: torch.Tensor,  # [B + 1]
    max_seqlen_qk: int,
    chunk_size: int,
    sm_scale: Optional[float] = None,
    normalize: bool = True,
    context_lens: Optional[List[int]] = None,
    protected_first_tokens: Optional[List[int]] = None,
    protected_last_tokens: Optional[List[int]] = None,
    *,
    accum_scores: Optional[torch.Tensor] = None,  # [N, HKV] (float32)
    accum_blending: Optional[float] = None,
) -> torch.Tensor:
    """
    Score every key by the chunk-local, non-causal attention mass it receives.

    :param q: Tensor of shape ``[N, HQ, D]`` containing post-rope queries
    :param k: Tensor of shape ``[N, HKV, D]`` containing post-rope keys
    :param v: Tensor of shape ``[N, HKV, D]`` containing values (only its strides are
        forwarded to the kernel)
    :param cu_seqlens_qk: Tensor of shape ``[B + 1]`` demarcating batch boundaries
    :param max_seqlen_qk: int containing the maximum sequence length
    :param chunk_size: int specifying the size of the chunk to perform non-causal attention over
    :param sm_scale: float specifying the scaling factor applied to attention scores (1/sqrt(D) if None)
    :param normalize: bool specifying whether to z-score normalize final attention scores
    :param context_lens: List[int] with the length of each sequence (host-side copy of the
        differences of ``cu_seqlens_qk``). If omitted while protection is requested, it is
        derived from ``cu_seqlens_qk``, which forces a device synchronisation.
    :param protected_first_tokens: List[int] specifying how many tokens should be protected at the
        start of each sequence (treated as all zeros when None)
    :param protected_last_tokens: List[int] specifying how many tokens should be protected at the
        end of each sequence (treated as all zeros when None)
    :param accum_scores: Tensor of shape ``[N, HKV]`` containing key scores that should be accumulated into
    :param accum_blending: float specifying the scaling of ``accum_scores`` prior to adding the new
        non-causal attention scores. Final output is equivalent to ``out + accum_blending * accum_scores``
    :return: float32 tensor of shape ``[N, HKV]``; protected tokens are set to ``+inf``
    """
    assert q.ndim == 3 and k.ndim == 3
    assert q.shape[0] == k.shape[0] and q.shape[-1] == k.shape[-1]
    N, HQ, D = q.shape
    HKV = k.shape[1]
    assert HQ % HKV == 0, "Number of KV heads must divide number of query heads"
    assert (D & (D - 1)) == 0, "D must be a power of two"
    B = cu_seqlens_qk.numel() - 1
    H_g = HQ // HKV  # query-group size per KV head

    if sm_scale is None:
        sm_scale = 1.0 / math.sqrt(D)

    out = torch.zeros(N, HKV, device=q.device, dtype=torch.float32)

    # Put the KV-head axis first so that one kernel program addresses one KV
    # head together with its group of H_g query heads.
    q = q.view(N, HKV, H_g, D).permute(1, 0, 2, 3)
    k = k.view(N, HKV, D).permute(1, 0, 2)
    # v is left in its original layout; the kernel receives its strides only.

    if cu_seqlens_qk.device != q.device:
        cu_seqlens_qk = cu_seqlens_qk.to(device=q.device)
    cu_seqlens_qk = cu_seqlens_qk.to(torch.int32)

    STRIDE_Q_G, STRIDE_Q_N, STRIDE_Q_H, STRIDE_Q_D = q.stride()
    STRIDE_K_G, STRIDE_K_N, STRIDE_K_D = k.stride()
    STRIDE_V_G, STRIDE_V_N, STRIDE_V_D = v.stride()
    STRIDE_OUT_N, STRIDE_OUT_H = out.stride()
    assert STRIDE_Q_D == 1 and STRIDE_K_D == 1, "last dim must be contiguous"

    def attn_grid(_):
        # (KV head, sequence, chunk index within the longest sequence)
        return (HKV, B, triton.cdiv(max_seqlen_qk, chunk_size))

    _non_causal_attn_kernel[attn_grid](
        q,
        k,
        v,
        out,
        cu_seqlens_qk,
        STRIDE_Q_G,
        STRIDE_Q_N,
        STRIDE_Q_H,
        STRIDE_Q_D,
        STRIDE_K_G,
        STRIDE_K_N,
        STRIDE_K_D,
        STRIDE_V_G,
        STRIDE_V_N,
        STRIDE_V_D,
        STRIDE_OUT_N,
        STRIDE_OUT_H,
        sm_scale,
        CHUNK_SIZE=chunk_size,
        QUERY_GROUP_SIZE=H_g,
        D=D,
    )

    if normalize:
        _zscore_per_batch_epilogue_no_window[(B,)](
            out, cu_seqlens_qk, out.stride(0), out.stride(1), HKV
        )

    if accum_scores is not None:
        if accum_blending is not None:
            out += accum_scores * accum_blending
        else:
            out += accum_scores

    if protected_first_tokens is not None or protected_last_tokens is not None:
        if context_lens is None:
            context_lens = torch.diff(cu_seqlens_qk).tolist()
        n_seqs = len(context_lens)
        # A missing list means "protect nothing on that side".
        firsts = protected_first_tokens if protected_first_tokens is not None else [0] * n_seqs
        lasts = protected_last_tokens if protected_last_tokens is not None else [0] * n_seqs

        start = 0
        for first, last, seq_len in zip(firsts, lasts, context_lens):
            seq_len = int(seq_len)
            # Clamp to the sequence length: an unclamped ``first`` spills into
            # the next sequence, and an unclamped ``last`` makes the slice start
            # before this sequence (or wrap around through a negative index).
            first = max(0, min(int(first), seq_len))
            last = max(0, min(int(last), seq_len))
            out[start : start + first].fill_(torch.inf)
            out[start + seq_len - last : start + seq_len].fill_(torch.inf)
            start += seq_len

    return out
vllm/compactor-vllm/src/compactor_vllm/compression/compression_config.py
deleted
100644 → 0
View file @
2b7160c6
import
logging
from
dataclasses
import
dataclass
from
enum
import
Enum
,
auto
logger
=
logging
.
getLogger
(
__name__
)
class CompressionMethod(Enum):
    """Identifiers of the available KV-cache compression strategies.

    Member values are assigned by ``auto()`` in declaration order, so the
    order of the members below determines their integer values.
    """

    CRITICALADAKV = auto()
    COMPACTOR = auto()
    SNAPKV = auto()
    NONE = auto()
# class CachingPolicy(Enum):
# CACHE_PROMPT = auto()
# DONT_CACHE = auto()
# class CompressionType(Enum):
# QUERY_AWARE = auto()
# QUERY_AGNOSTIC = auto()
@dataclass
class SequenceCompressionParams:
    """Compression settings attached to a single sequence."""

    # NOTE(review): presumably the fraction of the KV cache to retain
    # (1.0 = keep everything) - confirm against the consumer of this field.
    compression_ratio: float = 1.0
    # Number of tokens at the start of the sequence that are protected.
    protected_first_tokens: int = 16
    # Number of tokens at the end of the sequence that are protected.
    protected_last_tokens: int = 64
@dataclass
class BatchCompressionParams:
    """Compression settings shared by a whole batch.

    ``do_chunked_compression`` is forced to ``False`` when the method is
    ``CompressionMethod.SNAPKV``, because the two are incompatible.
    """

    # compression_type: CompressionType = CompressionType.QUERY_AGNOSTIC
    compression_method: CompressionMethod = CompressionMethod.COMPACTOR
    do_chunked_compression: bool = True
    chunk_size: int = 512

    def __post_init__(self):
        # Only warn when a setting is actually overridden; callers that already
        # disabled chunked compression themselves should not see a warning.
        if (
            self.compression_method == CompressionMethod.SNAPKV
            and self.do_chunked_compression
        ):
            self.do_chunked_compression = False
            logger.warning(
                "CompressionMethod.SNAPKV is not compatible with chunked compression. Disabling it."
            )
vllm/compactor-vllm/src/compactor_vllm/compression/criticalkv-cursor.py
deleted
100644 → 0
View file @
2b7160c6
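"""
CriticalAdaKV: built on top of Compactor (pre-RoPE leverage scores + post-RoPE non-causal attention fusion).
Stage-2 re-weights the scores by the L1 norm of the values projected through the output projection Wo; Stage-1 protects the within-budget top-k of the Compactor base scores.
The budget matches the compactor_vllm engine: it uses ``compression_context.batch_tokens_to_retain`` (the number of
flattened (token, head) pairs) together with the lengths of the protected first/last segments.
Note: do not load ``compactor_vllm.utils.context`` at import time (it imports ``CompressionMethod`` in turn,
which forms a cycle with ``compression/__init__.py`` importing this module). At runtime only a duck-typed object with the same fields as ``CompressionContext`` is used.
"""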
from
__future__
import
annotations
from
typing
import
Any
,
Optional
,
Tuple
import
torch
import
triton
from
triton
import
language
as
tl
from
compactor_vllm.compression.common
import
BaseCompressionMethod
from
compactor_vllm.compression.compactor
import
(
CompactorCompression
,
non_causal_attn_scores
,
)
from
compactor_vllm.compression.snapkv
import
SnapKVCompression
from
compactor_vllm.utils.helpers
import
maybe_execute_in_stream
from
compactor_vllm.utils.triton_compat
import
autotune
as
triton_autotune
# ============================================================================
# Triton Kernel 1: compute ||Wo @ V||_1 (L1 norm)
# ============================================================================
# For every key position and KV head, computes the L1 norm of the value vector
# projected through each query head's slice of Wo, averaged over the
# QUERY_GROUP_SIZE query heads that share the KV head:
#     OUT[n, hk] = mean_g sum_hidden |V[n, hk, :] @ WO[hk * G + g, :, hidden]|
#
# Launch grid is (sequence, KV head, key tile); each program handles BLOCK_K
# consecutive keys of one sequence.
@triton_autotune(
    configs=[
        triton.Config({"BLOCK_K": bk, "BLOCK_D": bd}, num_warps=nw, num_stages=ns)
        for bk in [32, 64, 128]
        for bd in [32, 64]
        for nw in [4, 8]
        for ns in [3, 4]
    ],
    key=["Hk", "D", "HIDDEN"],
    cache_results=True,
)
@triton.jit
def _compute_wo_v_l1_kernel(
    V,
    WO,
    cu_k,
    OUT,
    STRIDE_V_NK,
    STRIDE_V_HK,
    STRIDE_V_D,
    STRIDE_WO_HQ,
    STRIDE_WO_D,
    STRIDE_WO_HID,
    STRIDE_OUT_NK,
    STRIDE_OUT_HK,
    Hk: tl.constexpr,
    Hq: tl.constexpr,
    D: tl.constexpr,
    HIDDEN: tl.constexpr,
    QUERY_GROUP_SIZE: tl.constexpr,
    BLOCK_K: tl.constexpr,
    BLOCK_D: tl.constexpr,
):
    b = tl.program_id(0)
    hk = tl.program_id(1)
    ks = tl.program_id(2)

    k_beg = tl.load(cu_k + b)
    k_end = tl.load(cu_k + b + 1)

    # The tile axis of the grid is sized for the longest sequence, so shorter
    # sequences get tiles that start past their end. Those tiles would store
    # nothing (fully masked); skip their matrix products altogether.
    if k_beg + ks * BLOCK_K >= k_end:
        return

    nk = k_beg + ks * BLOCK_K + tl.arange(0, BLOCK_K)
    k_mask = nk < k_end
    offs_d = tl.arange(0, D)

    # The value tile depends only on the KV head, not on the query head inside
    # the group, so it is loaded once instead of once per group member.
    v_ptrs = (
        V
        + nk[:, None] * STRIDE_V_NK
        + hk * STRIDE_V_HK
        + offs_d[None, :] * STRIDE_V_D
    )
    v_blk = tl.load(v_ptrs, mask=k_mask[:, None], other=0.0).to(tl.float32)

    l1_sum = tl.zeros([BLOCK_K], dtype=tl.float32)
    for g in range(QUERY_GROUP_SIZE):
        hq = hk * QUERY_GROUP_SIZE + g
        # Walk the hidden dimension in BLOCK_D-wide column tiles of WO[hq].
        for hid_off in range(0, HIDDEN, BLOCK_D):
            hid_idx = hid_off + tl.arange(0, BLOCK_D)
            hid_mask = hid_idx < HIDDEN
            wo_ptrs = (
                WO
                + hq * STRIDE_WO_HQ
                + offs_d[:, None] * STRIDE_WO_D
                + hid_idx[None, :] * STRIDE_WO_HID
            )
            wo_tile = tl.load(wo_ptrs, mask=hid_mask[None, :], other=0.0).to(tl.float32)
            wov_tile = tl.dot(v_blk, wo_tile)  # [BLOCK_K, BLOCK_D]
            l1_sum += tl.sum(tl.abs(wov_tile), axis=1)

    # Mean over the query heads of the group.
    l1_sum = l1_sum / QUERY_GROUP_SIZE
    out_ptrs = OUT + nk * STRIDE_OUT_NK + hk * STRIDE_OUT_HK
    tl.store(out_ptrs, l1_sum, mask=k_mask)
# ============================================================================
# Triton Kernel 2: Stage-1 protection + Stage-2 weighted fusion
# ============================================================================
# Element-wise fusion of base scores with the Wo@V L1 norms:
#     OUT[n, h] = +inf                                   if STAGE1_MASK[n, h] == 1
#     OUT[n, h] = (BASE_SCORES[n, h] + EPSILON) * WO_V_NORM[n, h]   otherwise
#
# Launch grid is (sequence, KV head); each program walks its sequence in
# BLOCK_K-sized tiles. EPSILON sits between OUT and the stride arguments in the
# signature, so positional callers must pass it before the strides.
@triton_autotune(
    configs=[triton.Config({"BLOCK_K": block_k}) for block_k in (32, 64, 128, 256)],
    key=["Hk"],
    cache_results=True,
)
@triton.jit
def _critical_ada_fuse_kernel(
    BASE_SCORES,
    WO_V_NORM,
    STAGE1_MASK,
    cu_k,
    OUT,
    EPSILON: tl.constexpr,
    STRIDE_BS_NK,
    STRIDE_BS_HK,
    STRIDE_WN_NK,
    STRIDE_WN_HK,
    STRIDE_S1_NK,
    STRIDE_S1_HK,
    STRIDE_OUT_NK,
    STRIDE_OUT_HK,
    Hk: tl.constexpr,
    BLOCK_K: tl.constexpr,
):
    seq_id = tl.program_id(0)
    head_id = tl.program_id(1)

    seq_begin = tl.load(cu_k + seq_id)
    seq_end = tl.load(cu_k + seq_id + 1)

    lane = tl.arange(0, BLOCK_K)
    for tile_begin in tl.range(seq_begin, seq_end, BLOCK_K):
        rows = tile_begin + lane
        in_seq = rows < seq_end

        base = tl.load(
            BASE_SCORES + rows * STRIDE_BS_NK + head_id * STRIDE_BS_HK,
            mask=in_seq,
            other=0.0,
        )
        wnorm = tl.load(
            WO_V_NORM + rows * STRIDE_WN_NK + head_id * STRIDE_WN_HK,
            mask=in_seq,
            other=1.0,
        )
        protect = tl.load(
            STAGE1_MASK + rows * STRIDE_S1_NK + head_id * STRIDE_S1_HK,
            mask=in_seq,
            other=0,
        ).to(tl.int32)

        weighted = (base + EPSILON) * wnorm
        # Stage-1 protected entries are pinned to +inf.
        result = tl.where(protect == 1, float("inf"), weighted)

        tl.store(
            OUT + rows * STRIDE_OUT_NK + head_id * STRIDE_OUT_HK,
            result,
            mask=in_seq,
        )
def critical_ada_key_scores(
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    wo_weight: torch.Tensor,
    cu_seqlens: torch.Tensor,
    base_scores: torch.Tensor,
    compression_ctx: Any,
    *,
    store_stream: Optional[torch.cuda.Stream] = None,
) -> Tuple[torch.Tensor, Optional[Tuple[torch.Tensor, torch.Tensor, torch.Tensor]]]:
    """
    Apply the two-stage CriticalAda re-weighting to ``base_scores``.

    Uses the engine's retention budget ``batch_tokens_to_retain`` (number of
    (token, head) pairs per sequence) and, per sequence, follows the kvpress
    CriticalAdaKV scheme as closely as possible:

    1. alpha-safeguard budget (every head keeps at least a share of the budget);
    2. head-wise adaptive budget allocation from ``base_scores`` (``head_budgets``);
    3. Stage-1 protects ``head_budgets * first_stage_ratio`` entries per head;
    4. Stage-2 computes ``(base + eps) * ||Wo@V||_1`` and protects the per-head
       top-k by ``head_budgets``.

    Args:
        q: queries ``[N, Hq, D]``; only shape, device and last-dim stride are used.
        k: keys ``[N_k, Hk, D]``; only shape and last-dim stride are used.
        v: values, indexed as ``[N_k, Hk, D]`` by the Triton kernel.
        wo_weight: output projection, either 2-D ``[hidden, Hq * D]`` or already
            reshaped to ``[Hq, D, hidden]``.
        cu_seqlens: 1-D cumulative sequence offsets ``[B + 1]``. Must be the same
            offsets that were used to produce ``base_scores`` so rows line up.
        base_scores: base key scores ``[N_k, Hk]``.
        compression_ctx: any object with the fields of ``CompressionContext``
            (duck typing). Read here: ``batch_tokens_to_retain``,
            ``protected_first_tokens``, ``protected_last_tokens``,
            ``critical_ada_epsilon``, ``critical_ada_first_stage_ratio`` and
            optionally ``critical_ada_alpha_safeguard`` (default 0.2).
        store_stream: if given, ``final_scores`` is recorded on this stream.

    Returns:
        ``(final_scores, masked_key_indices)``. ``final_scores`` is float32
        ``[N_k, Hk]``. ``masked_key_indices`` is ``None`` when nothing is pruned,
        otherwise a ``(batch_idx, head_idx, seq_idx)`` tuple of index tensors for
        the lowest-scoring pairs beyond each sequence's budget.
    """
    assert q.stride(-1) == 1 and k.stride(-1) == 1 and v.stride(-1) == 1
    device = q.device
    _, Hq, D = q.shape
    N_k, Hk, Dk = k.shape
    assert D == Dk and Hq % Hk == 0

    B = cu_seqlens.numel() - 1
    G = Hq // Hk

    btr = compression_ctx.batch_tokens_to_retain
    assert btr is not None and btr.numel() == B

    # Pull offsets and budgets to the host once instead of issuing a blocking
    # ``.item()`` per sequence and per head.
    seq_bounds = cu_seqlens.tolist()
    keep_pairs_by_seq = btr.to(dtype=torch.int32).reshape(-1).tolist()
    max_k_len = max((seq_bounds[i + 1] - seq_bounds[i] for i in range(B)), default=0)

    prot_first = compression_ctx.protected_first_tokens or [0] * B
    prot_last = compression_ctx.protected_last_tokens or [0] * B
    epsilon = float(compression_ctx.critical_ada_epsilon)
    first_stage_ratio = compression_ctx.critical_ada_first_stage_ratio
    alpha_safeguard = float(getattr(compression_ctx, "critical_ada_alpha_safeguard", 0.2))
    alpha_safeguard = max(0.0, min(1.0, alpha_safeguard))

    def _compressible_range(b: int) -> Tuple[int, int]:
        """Row range of sequence ``b`` with the protected head/tail removed."""
        n_first = int(prot_first[b]) if b < len(prot_first) else 0
        n_last = int(prot_last[b]) if b < len(prot_last) else 0
        return seq_bounds[b] + n_first, seq_bounds[b + 1] - n_last

    if wo_weight.dim() == 2:
        hidden_size, _ = wo_weight.shape
        wo = wo_weight.transpose(0, 1).view(Hq, D, hidden_size).contiguous()
    else:
        wo = wo_weight.contiguous()
        hidden_size = wo.size(-1)

    wo_v_norm = torch.empty((N_k, Hk), dtype=torch.float32, device=device)

    if max_k_len > 0:

        def grid_wo(META):
            return (B, Hk, triton.cdiv(max_k_len, META["BLOCK_K"]))

        _compute_wo_v_l1_kernel[grid_wo](
            v,
            wo,
            cu_seqlens,
            wo_v_norm,
            *v.stride(),
            *wo.stride(),
            *wo_v_norm.stride(),
            Hk=Hk,
            Hq=Hq,
            D=D,
            HIDDEN=hidden_size,
            QUERY_GROUP_SIZE=G,
        )

    stage1_mask = torch.zeros((N_k, Hk), dtype=torch.int32, device=device)

    # Per-head budgets in the kvpress style (adaptive per sequence), shared by
    # Stage-1 and Stage-2. ``None`` marks sequences with nothing to compress.
    head_budgets_by_batch = []
    for b in range(B):
        if seq_bounds[b + 1] - seq_bounds[b] <= 0:
            head_budgets_by_batch.append(None)
            continue
        lo, hi = _compressible_range(b)
        compressible = max(0, hi - lo)
        if compressible <= 0:
            head_budgets_by_batch.append(None)
            continue

        # Token budget per head (kvpress ``n_kept``).
        n_kept_tokens = min(max(1, keep_pairs_by_seq[b] // Hk), compressible)

        # Safeguard budget: every head keeps at least ``n_safe`` tokens.
        n_safe = int(n_kept_tokens * alpha_safeguard)
        if n_safe > 0:
            tk_safe = min(n_safe, compressible)
            for hk in range(Hk):
                safe_idx = torch.topk(base_scores[lo:hi, hk], tk_safe, sorted=False).indices
                stage1_mask[lo + safe_idx, hk] = 1

        # Adaptive allocation: take the top ``n_kept_tokens * Hk`` entries in the
        # flattened (token, head) space and count how many fall on each head.
        budget_scores = base_scores[lo:hi, :].clone()
        if n_safe > 0:
            budget_scores[stage1_mask[lo:hi, :] == 1] = float("inf")
        top_pairs = min(n_kept_tokens * Hk, budget_scores.numel())
        if top_pairs <= 0:
            head_budgets_by_batch.append(None)
            continue
        top_idx_flat = torch.topk(budget_scores.reshape(-1), top_pairs, sorted=False).indices
        # Row-major [token, head] flattening, so the head is the index modulo Hk.
        head_budgets = torch.bincount(top_idx_flat % Hk, minlength=Hk).tolist()
        head_budgets_by_batch.append(head_budgets)

        # Stage-1: protect ``first_stage_ratio`` of each head's budget.
        for hk, head_budget in enumerate(head_budgets):
            phase1_budget = int(head_budget * first_stage_ratio)
            if phase1_budget <= 0:
                continue
            tk = min(phase1_budget, compressible)
            top_idx = torch.topk(base_scores[lo:hi, hk], tk, sorted=False).indices
            stage1_mask[lo + top_idx, hk] = 1

    final_scores = torch.empty((N_k, Hk), dtype=torch.float32, device=device)

    if max_k_len > 0:

        def grid_fuse(_META):
            return (B, Hk)

        # EPSILON is the 6th kernel parameter, ahead of the strides. It is
        # passed positionally so the unpacked strides bind to the stride
        # parameters. Passing it by keyword while the strides stay positional
        # binds the first stride to EPSILON and leaves the last stride unbound.
        _critical_ada_fuse_kernel[grid_fuse](
            base_scores,
            wo_v_norm,
            stage1_mask,
            cu_seqlens,
            final_scores,
            epsilon,
            *base_scores.stride(),
            *wo_v_norm.stride(),
            *stage1_mask.stride(),
            *final_scores.stride(),
            Hk=Hk,
        )

    # Stage-2 (kvpress semantics): after fusion, protect the per-head top-k
    # again according to each head's budget.
    for b in range(B):
        head_budgets = head_budgets_by_batch[b]
        if head_budgets is None:
            continue
        lo, hi = _compressible_range(b)
        if hi <= lo:
            continue
        region_len = hi - lo
        for hk, budget in enumerate(head_budgets):
            if budget <= 0:
                continue
            tk = min(budget, region_len)
            idx = torch.topk(final_scores[lo:hi, hk], tk, sorted=False).indices
            final_scores[lo + idx, hk] = float("inf")

    # Collect the pairs to prune: the lowest-scoring entries beyond the budget.
    batch_parts = []
    head_parts = []
    seq_parts = []
    for b in range(B):
        k_beg, k_end = seq_bounds[b], seq_bounds[b + 1]
        k_len = k_end - k_beg
        if k_len == 0:
            continue
        keep_pairs = keep_pairs_by_seq[b]
        total_pairs = k_len * Hk
        if keep_pairs >= total_pairs:
            continue
        n_prune_pairs = min(total_pairs - keep_pairs, total_pairs)
        if n_prune_pairs <= 0:
            continue

        flat_scores = final_scores[k_beg:k_end, :].reshape(-1)
        prune_idx = torch.topk(
            -flat_scores, min(n_prune_pairs, flat_scores.numel()), sorted=False
        ).indices
        batch_parts.append(torch.full_like(prune_idx, b, dtype=torch.int64))
        head_parts.append(prune_idx % Hk)
        seq_parts.append(prune_idx // Hk + k_beg)

    masked_key_indices = None
    if batch_parts:
        # Concatenate once rather than re-concatenating inside the loop.
        masked_key_indices = (
            torch.cat(batch_parts),
            torch.cat(head_parts),
            torch.cat(seq_parts),
        )

    if store_stream is not None:
        final_scores.record_stream(store_stream)

    return final_scores, masked_key_indices
class CriticalAdaKVCompression(BaseCompressionMethod):
    """
    CriticalAda two-stage re-weighting on top of a base scorer.

    The base scores come from ``CompactorCompression`` (pre-RoPE leverage
    scores + post-RoPE non-causal fusion) unless the compression context sets
    ``critical_ada_base_scorer`` to ``"snapkv"``. The Attention layer is expected
    to inject ``compression_context.wo_weight`` before post-RoPE scoring; when it
    is ``None`` the base scores are returned unchanged.
    """

    @staticmethod
    def pre_rope_scoring(
        q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, context
    ) -> Optional[torch.Tensor]:
        """Delegate pre-RoPE scoring to the configured base scorer."""
        cc = context.compression_context
        scorer_name = "compactor"
        if cc is not None:
            scorer_name = getattr(cc, "critical_ada_base_scorer", "compactor")
        use_snapkv = str(scorer_name).lower() == "snapkv"
        scorer = SnapKVCompression if use_snapkv else CompactorCompression
        return scorer.pre_rope_scoring(q, k, v, context)

    @staticmethod
    def post_rope_scoring(
        q: torch.Tensor,
        k: torch.Tensor,
        v: torch.Tensor,
        pre_rope_scores: Optional[torch.Tensor],
        context,
    ) -> Optional[torch.Tensor]:
        """Compute base scores, then fuse them with ``||Wo@V||_1`` if Wo is available."""
        cc = context.compression_context
        assert cc is not None
        scorer_name = str(getattr(cc, "critical_ada_base_scorer", "compactor")).lower()

        if scorer_name == "snapkv":
            base_scores = SnapKVCompression.post_rope_scoring(
                q, k, v, pre_rope_scores, context
            )
        else:
            # This call must stay argument-for-argument identical to
            # CompactorCompression.post_rope_scoring in compactor.py
            # (maybe_execute_in_stream(non_causal_attn_scores, q, k, v,
            # cu_seqlens_q, max_seqlen_q, ...)). Do not swap in another wrapper,
            # otherwise the scores diverge from running COMPACTOR on its own.
            pending_store = context.STORE_STREAM
            if pending_store is not None:
                torch.cuda.current_stream().wait_stream(pending_store)
            compactor_kwargs = dict(
                chunk_size=CompactorCompression.chunk_size,
                sm_scale=1.0,
                normalize=True,
                accum_scores=pre_rope_scores,
                context_lens=cc.context_lens,
                protected_first_tokens=cc.protected_first_tokens,
                protected_last_tokens=cc.protected_last_tokens,
                accum_blending=0.5,
            )
            base_scores = maybe_execute_in_stream(
                non_causal_attn_scores,
                q,
                k,
                v,
                context.cu_seqlens_q,
                context.max_seqlen_q,
                **compactor_kwargs,
            )

        wo_weight = cc.wo_weight
        if wo_weight is None:
            return base_scores

        # The pruning indices returned alongside the scores are not used here.
        scores, _ = maybe_execute_in_stream(
            critical_ada_key_scores,
            q,
            k,
            v,
            wo_weight,
            context.cu_seqlens_q,
            base_scores,
            cc,
            STORE_STREAM=context.STORE_STREAM,
            store_stream=context.STORE_STREAM,
        )
        return scores

    @staticmethod
    def prepare_layer(module: torch.nn.Module, device: torch.device, dtype: torch.dtype):
        """Optionally pre-compute and cache Wo on ``module``.

        The cached tensor is stored as ``module._critical_ada_wo_weight`` with
        shape ``[num_heads, head_dim, hidden_size]`` in float32 (``dtype`` is not
        used). At inference time the ``cc.wo_weight`` injected in
        ``Attention.forward`` takes precedence. Modules lacking ``o_proj``,
        ``num_heads`` or ``head_dim`` are left untouched.
        """
        o_proj_missing = not hasattr(module, "o_proj") or module.o_proj.weight is None
        if o_proj_missing:
            return
        if not (hasattr(module, "num_heads") and hasattr(module, "head_dim")):
            return

        weight = module.o_proj.weight.data
        hidden_size, _ = weight.shape
        per_head = weight.transpose(0, 1).view(module.num_heads, module.head_dim, hidden_size)
        module._critical_ada_wo_weight = per_head.to(device=device, dtype=torch.float32)
vllm/compactor-vllm/src/compactor_vllm/compression/criticalkv.py
deleted
100644 → 0
View file @
2b7160c6
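"""
CriticalAdaKV: built on top of Compactor (pre-RoPE leverage scores + post-RoPE non-causal attention fusion).
Stage-2 re-weights the scores by the L1 norm of the values projected through the output projection Wo; Stage-1 protects the within-budget top-k of the Compactor base scores.
The budget matches the vllm.kvprune engine: it uses ``compression_context.batch_tokens_to_retain`` (the number of
flattened (token, head) pairs). The main CriticalAda pipeline is implemented in **PyTorch** and aligned with kvpress ``CriticalAdaKVPress.compress``;
``||Wo@V||_1`` still uses the Triton ``_compute_wo_v_l1_kernel`` by default (same formula as ``CriticalKVPress.vwl1norm``).
Set ``_USE_WO_L1_REFERENCE_BACKEND`` to ``True`` to use ``_vwl1_norm_kvpress_reference`` instead.
Note: do not load ``vllm.kvprune.utils.context`` at import time (it imports ``CompressionMethod`` in turn,
which forms a cycle with ``compression/__init__.py`` importing this module). At runtime only a duck-typed object with the same fields as ``CompressionContext`` is used.
"""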
from
__future__
import
annotations
from
typing
import
Any
,
Optional
,
Tuple
import
torch
import
triton
from
triton
import
language
as
tl
from
transformers.models.llama.modeling_llama
import
repeat_kv
from
compactor_vllm.compression.common
import
BaseCompressionMethod
from
compactor_vllm.compression.compactor
import
(
CompactorCompression
,
kvpress_compactor_post_rope
,
resolve_kvpress_compactor_blending
,
)
from
compactor_vllm.compression.snapkv
import
SnapKVCompression
from
compactor_vllm.utils.helpers
import
maybe_execute_in_stream
from
compactor_vllm.utils.triton_compat
import
autotune
as
triton_autotune
# Backend for the L1 norm of Wo@V: False = Triton (default), True = PyTorch reference (debugging / alignment checks)
_USE_WO_L1_REFERENCE_BACKEND
=
False
def
_vwl1_norm_kvpress_reference
(
values_seg
:
torch
.
Tensor
,
wo
:
torch
.
Tensor
,
num_kv_heads
:
int
,
num_query_groups
:
int
,
)
->
torch
.
Tensor
:
"""
与 kvpress ``CriticalKVPress.vwl1norm`` 等价的 **可选参考实现**(PyTorch,仅用于核对;
将 ``_USE_WO_L1_REFERENCE_BACKEND`` 置为 ``True`` 时选用,默认走 Triton)。
算法:repeat_kv → 逐 query 头 ``|V @ Wo_h|_1`` → 在 GQA 组上 mean,与 Triton 路径同一公式。
"""
k_len
,
Hk
,
D
=
values_seg
.
shape
Hq
,
D_wo
,
hidden
=
wo
.
shape
assert
D
==
D_wo
and
Hk
==
num_kv_heads
and
Hq
==
Hk
*
num_query_groups
# [1, Hk, k_len, D] 与 HF repeat_kv 约定一致
v_4d
=
values_seg
.
permute
(
1
,
0
,
2
).
unsqueeze
(
0
).
contiguous
()
v_rep
=
repeat_kv
(
v_4d
,
num_query_groups
)
# [1, Hq, k_len, D]
# Wo 在 attention 里注入为 float32,V 常为 bf16/fp16,matmul 前对齐 dtype
wo_f
=
wo
head_list
=
[]
for
head
in
range
(
Hq
):
v_h
=
v_rep
[
0
,
head
,
:,
:].
to
(
dtype
=
wo_f
.
dtype
)
head_wov
=
v_h
.
matmul
(
wo_f
[
head
,
:,
:])
head_wov_norm
=
torch
.
norm
(
head_wov
,
p
=
1
,
dim
=-
1
)
head_list
.
append
(
head_wov_norm
)
stacked
=
torch
.
stack
(
head_list
,
dim
=
0
)
# [Hq, k_len]
stacked
=
stacked
.
view
(
Hk
,
num_query_groups
,
k_len
).
mean
(
dim
=
1
)
return
stacked
.
transpose
(
0
,
1
).
contiguous
()
# ============================================================================
# Triton: ||Wo @ V||_1 as defined by kvpress (per-query-head L1 norm, averaged over the GQA group)
# ============================================================================
@triton_autotune(
    configs=[
        triton.Config({"BLOCK_K": bk, "BLOCK_D": bd}, num_warps=nw, num_stages=ns)
        for bk in [32, 64, 128]
        for bd in [32, 64]
        for nw in [4, 8]
        for ns in [3, 4]
    ],
    key=["Hk", "D", "HIDDEN"],
    cache_results=True,
)
@triton.jit
def _compute_wo_v_l1_kernel(
    V,
    WO,
    cu_k,
    OUT,
    STRIDE_V_NK,
    STRIDE_V_HK,
    STRIDE_V_D,
    STRIDE_WO_HQ,
    STRIDE_WO_D,
    STRIDE_WO_HID,
    STRIDE_OUT_NK,
    STRIDE_OUT_HK,
    Hk: tl.constexpr,
    Hq: tl.constexpr,
    D: tl.constexpr,
    HIDDEN: tl.constexpr,
    QUERY_GROUP_SIZE: tl.constexpr,
    BLOCK_K: tl.constexpr,
    BLOCK_D: tl.constexpr,
):
    """Per KV head: compute ``sum(|V @ Wo|)`` for each of the G query heads, then divide by G (matches the kvpress mean)."""
    # Launch grid is (sequence, KV head, key tile); each program handles
    # BLOCK_K consecutive keys of one sequence.
    b = tl.program_id(0)
    hk = tl.program_id(1)
    ks = tl.program_id(2)

    k_beg = tl.load(cu_k + b)
    k_end = tl.load(cu_k + b + 1)

    # The tile axis of the grid is sized for the longest sequence, so shorter
    # sequences get tiles that start past their end. Those tiles would store
    # nothing (fully masked); skip their matrix products altogether.
    if k_beg + ks * BLOCK_K >= k_end:
        return

    nk = k_beg + ks * BLOCK_K + tl.arange(0, BLOCK_K)
    k_mask = nk < k_end
    offs_d = tl.arange(0, D)

    # The value tile depends only on the KV head, not on the query head inside
    # the group, so it is loaded once instead of once per group member.
    v_ptrs = (
        V
        + nk[:, None] * STRIDE_V_NK
        + hk * STRIDE_V_HK
        + offs_d[None, :] * STRIDE_V_D
    )
    v_blk = tl.load(v_ptrs, mask=k_mask[:, None], other=0.0).to(tl.float32)

    l1_sum = tl.zeros([BLOCK_K], dtype=tl.float32)
    for g in range(QUERY_GROUP_SIZE):
        hq = hk * QUERY_GROUP_SIZE + g
        # Walk the hidden dimension in BLOCK_D-wide column tiles of WO[hq].
        for hid_off in range(0, HIDDEN, BLOCK_D):
            hid_idx = hid_off + tl.arange(0, BLOCK_D)
            hid_mask = hid_idx < HIDDEN
            wo_ptrs = (
                WO
                + hq * STRIDE_WO_HQ
                + offs_d[:, None] * STRIDE_WO_D
                + hid_idx[None, :] * STRIDE_WO_HID
            )
            wo_tile = tl.load(wo_ptrs, mask=hid_mask[None, :], other=0.0).to(tl.float32)
            wov_tile = tl.dot(v_blk, wo_tile)  # [BLOCK_K, BLOCK_D]
            l1_sum += tl.sum(tl.abs(wov_tile), axis=1)

    # Mean over the query heads of the group.
    l1_sum = l1_sum / QUERY_GROUP_SIZE
    out_ptrs = OUT + nk * STRIDE_OUT_NK + hk * STRIDE_OUT_HK
    tl.store(out_ptrs, l1_sum, mask=k_mask)
def
critical_ada_key_scores
(
q
:
torch
.
Tensor
,
k
:
torch
.
Tensor
,
v
:
torch
.
Tensor
,
wo_weight
:
torch
.
Tensor
,
cu_seqlens
:
torch
.
Tensor
,
base_scores
:
torch
.
Tensor
,
compression_ctx
:
Any
,
*
,
store_stream
:
Optional
[
torch
.
cuda
.
Stream
]
=
None
,
)
->
Tuple
[
torch
.
Tensor
,
Optional
[
Tuple
[
torch
.
Tensor
,
torch
.
Tensor
,
torch
.
Tensor
]]]:
"""
使用与引擎一致的保留预算 ``batch_tokens_to_retain``(每条序列的 (token, head) 对数),
按 kvpress ``CriticalAdaKVPress.compress`` 的顺序实现:safeguard scatter →
head-major 展平做 head_budgets → Stage1 在 **已抬高** 的分数上 top-k →
``(scores + ε) * ||WoV||₁`` → Stage2 scatter → 最终按 head-major 展平做 bottom-k。
``||Wo@V||₁`` 仍用 Triton(``_compute_wo_v_l1_kernel``);中间 CriticalAda 步骤用 PyTorch
与 kvpress 逐句对齐。仅 base 分数来自 Compactor/SnapKV。
Args:
compression_ctx: 与 ``CompressionContext`` 相同字段即可(duck typing),须含
``batch_tokens_to_retain``;可选 ``critical_ada_epsilon``、
``critical_ada_first_stage_ratio``、``critical_ada_alpha_safeguard``。
"""
assert
q
.
stride
(
-
1
)
==
1
and
k
.
stride
(
-
1
)
==
1
and
v
.
stride
(
-
1
)
==
1
device
=
q
.
device
_
,
Hq
,
D
=
q
.
shape
N_k
,
Hk
,
Dk
=
k
.
shape
assert
D
==
Dk
and
Hq
%
Hk
==
0
# 与 non_causal_attn_scores 使用同一 cu(prefill 下即 context.cu_seqlens_q),
# 保证 base_scores 行与 Triton 分段一致;勿与 cu_seqlens_k 混用。
B
=
cu_seqlens
.
numel
()
-
1
G
=
Hq
//
Hk
k_lengths
=
cu_seqlens
[
1
:]
-
cu_seqlens
[:
-
1
]
btr
=
compression_ctx
.
batch_tokens_to_retain
assert
btr
is
not
None
and
btr
.
numel
()
==
B
btr
=
btr
.
to
(
device
=
device
,
dtype
=
torch
.
int32
)
epsilon
=
compression_ctx
.
critical_ada_epsilon
first_stage_ratio
=
compression_ctx
.
critical_ada_first_stage_ratio
alpha_safeguard
=
float
(
compression_ctx
.
critical_ada_alpha_safeguard
)
alpha_safeguard
=
max
(
0.0
,
min
(
1.0
,
alpha_safeguard
))
if
wo_weight
.
dim
()
==
2
:
hidden_size
,
_
=
wo_weight
.
shape
wo
=
wo_weight
.
transpose
(
0
,
1
).
view
(
Hq
,
D
,
hidden_size
).
contiguous
()
else
:
wo
=
wo_weight
.
contiguous
()
hidden_size
=
wo
.
size
(
-
1
)
wo_v_norm
=
torch
.
empty
((
N_k
,
Hk
),
dtype
=
torch
.
float32
,
device
=
device
)
if
B
>
0
and
int
(
k_lengths
.
max
().
item
())
>
0
:
if
_USE_WO_L1_REFERENCE_BACKEND
:
for
b
in
range
(
B
):
k_beg
=
int
(
cu_seqlens
[
b
].
item
())
k_end
=
int
(
cu_seqlens
[
b
+
1
].
item
())
if
k_end
<=
k_beg
:
continue
v_seg
=
v
[
k_beg
:
k_end
,
:,
:].
contiguous
()
wo_v_norm
[
k_beg
:
k_end
,
:]
=
_vwl1_norm_kvpress_reference
(
v_seg
,
wo
,
Hk
,
G
)
else
:
def
grid_wo
(
META
):
max_k_len
=
int
(
k_lengths
.
max
().
item
())
return
(
B
,
Hk
,
triton
.
cdiv
(
max_k_len
,
META
[
"BLOCK_K"
]))
_compute_wo_v_l1_kernel
[
grid_wo
](
v
,
wo
,
cu_seqlens
,
wo_v_norm
,
*
v
.
stride
(),
*
wo
.
stride
(),
*
wo_v_norm
.
stride
(),
Hk
=
Hk
,
Hq
=
Hq
,
D
=
D
,
HIDDEN
=
hidden_size
,
QUERY_GROUP_SIZE
=
G
,
)
# kvpress 用 finfo.max 抬高分数;与 inf 混用时 topk 行为一致
_score_max
=
float
(
torch
.
finfo
(
torch
.
float32
).
max
)
final_scores
=
torch
.
empty
((
N_k
,
Hk
),
dtype
=
torch
.
float32
,
device
=
device
)
head_budgets_by_batch
:
list
[
Optional
[
torch
.
Tensor
]]
=
[]
for
b
in
range
(
B
):
k_len
=
int
(
k_lengths
[
b
].
item
())
k_beg
=
int
(
cu_seqlens
[
b
].
item
())
k_end
=
int
(
cu_seqlens
[
b
+
1
].
item
())
if
k_len
==
0
:
head_budgets_by_batch
.
append
(
None
)
continue
scores_seg
=
base_scores
[
k_beg
:
k_end
,
:].
float
()
keep_pairs
=
int
(
btr
[
b
].
item
())
n_kept_tokens
=
max
(
1
,
keep_pairs
//
Hk
)
n_kept_tokens
=
min
(
n_kept_tokens
,
k_len
)
# scores_work: 布局 [k_len, Hk],对应 kvpress [bsz=1, H, k_len] 的 transpose(0,2) 视角下沿 token 维的 topk
scores_work
=
scores_seg
.
clone
()
# --- Alpha safeguard(kvpress L148–152)---
n_safe
=
int
(
n_kept_tokens
*
alpha_safeguard
)
nk
=
min
(
n_safe
,
k_len
)
if
n_safe
>
0
else
0
if
nk
>
0
:
for
hk
in
range
(
Hk
):
top_idx
=
torch
.
topk
(
scores_work
[:,
hk
],
nk
,
dim
=
0
,
largest
=
True
).
indices
scores_work
[
top_idx
,
hk
]
=
_score_max
# --- Head budgets:kvpress L158–164,展平顺序与 [bsz, H, k_len] 一致(head-major:h*K + t)---
top_pairs
=
min
(
n_kept_tokens
*
Hk
,
k_len
*
Hk
)
if
top_pairs
<=
0
:
head_budgets_by_batch
.
append
(
None
)
wn
=
wo_v_norm
[
k_beg
:
k_end
,
:]
final_scores
[
k_beg
:
k_end
,
:]
=
(
scores_seg
+
epsilon
)
*
wn
continue
budget_flat
=
scores_work
.
permute
(
1
,
0
).
contiguous
().
reshape
(
-
1
)
top_idx_flat
=
torch
.
topk
(
budget_flat
,
top_pairs
,
largest
=
True
,
sorted
=
False
).
indices
top_head_idx
=
top_idx_flat
//
k_len
head_budgets
=
torch
.
bincount
(
top_head_idx
,
minlength
=
Hk
).
to
(
torch
.
int64
)
head_budgets_by_batch
.
append
(
head_budgets
)
# --- Stage 1(kvpress L166–171):在已 safeguard 的 scores_work 上沿 token 维 top-k ---
head_selection_budget_1st
=
(
(
head_budgets
.
to
(
torch
.
float32
)
*
float
(
first_stage_ratio
))
.
to
(
torch
.
int64
)
.
tolist
()
)
M1
=
max
(
head_selection_budget_1st
)
if
head_selection_budget_1st
else
0
mk
=
min
(
M1
,
k_len
)
if
M1
>
0
else
0
if
mk
>
0
:
top_k_index
=
torch
.
topk
(
scores_work
,
mk
,
dim
=
0
,
largest
=
True
,
sorted
=
True
).
indices
for
hk
in
range
(
Hk
):
phase1_budget
=
int
(
head_selection_budget_1st
[
hk
])
if
phase1_budget
<=
0
:
continue
take
=
min
(
phase1_budget
,
mk
)
scores_work
[
top_k_index
[:
take
,
hk
],
hk
]
=
_score_max
# --- Stage 2 重加权(kvpress L173–175)---
wn
=
wo_v_norm
[
k_beg
:
k_end
,
:]
scores_fused
=
(
scores_work
+
epsilon
)
*
wn
# --- Stage 2 scatter(kvpress L176–179)---
M2
=
int
(
head_budgets
.
max
().
item
())
mk2
=
min
(
M2
,
k_len
)
if
M2
>
0
else
0
if
mk2
>
0
:
top_k_index2
=
torch
.
topk
(
scores_fused
,
mk2
,
dim
=
0
,
largest
=
True
,
sorted
=
True
).
indices
for
hk
in
range
(
Hk
):
budget
=
int
(
head_budgets
[
hk
].
item
())
if
budget
<=
0
:
continue
take
=
min
(
budget
,
mk2
)
scores_fused
[
top_k_index2
[:
take
,
hk
],
hk
]
=
_score_max
final_scores
[
k_beg
:
k_end
,
:]
=
scores_fused
masked_key_indices
=
None
for
b
in
range
(
B
):
k_len
=
int
(
k_lengths
[
b
].
item
())
if
k_len
==
0
:
continue
keep_pairs
=
int
(
btr
[
b
].
item
())
total_pairs
=
k_len
*
Hk
if
keep_pairs
>=
total_pairs
:
continue
k_beg
=
int
(
cu_seqlens
[
b
].
item
())
k_end
=
int
(
cu_seqlens
[
b
+
1
].
item
())
n_prune_pairs
=
min
(
total_pairs
-
keep_pairs
,
total_pairs
)
if
n_prune_pairs
<=
0
:
continue
# kvpress L187:``scores.reshape(bsz, -1)`` 即 [H, K] 按 head-major 展平(flat = h*K + t)
flat_scores
=
(
final_scores
[
k_beg
:
k_end
,
:].
permute
(
1
,
0
).
contiguous
().
reshape
(
-
1
)
)
prune_idx
=
torch
.
topk
(
-
flat_scores
,
min
(
n_prune_pairs
,
flat_scores
.
numel
()),
sorted
=
False
).
indices
batch_idx
=
torch
.
full_like
(
prune_idx
,
b
,
dtype
=
torch
.
int64
)
head_idx
=
prune_idx
//
k_len
seq_idx
=
prune_idx
%
k_len
+
k_beg
if
masked_key_indices
is
None
:
masked_key_indices
=
(
batch_idx
,
head_idx
,
seq_idx
)
else
:
masked_key_indices
=
(
torch
.
cat
([
masked_key_indices
[
0
],
batch_idx
]),
torch
.
cat
([
masked_key_indices
[
1
],
head_idx
]),
torch
.
cat
([
masked_key_indices
[
2
],
seq_idx
]),
)
if
store_stream
is
not
None
:
final_scores
.
record_stream
(
store_stream
)
return
final_scores
,
masked_key_indices
class CriticalAdaKVCompression(BaseCompressionMethod):
    """
    CriticalAda re-weighting stacked on top of a selectable base scorer.

    The base scorer is read from ``compression_context.critical_ada_base_scorer``
    (default ``"compactor"``). With ``"compactor"`` the base score comes from
    ``kvpress_compactor_post_rope``; any other value delegates to
    ``SnapKVCompression``. Per the original author's note, the attention layer
    must inject ``compression_context.wo_weight`` before post-RoPE scoring;
    when it is ``None`` the base scores are returned unchanged.
    """

    @staticmethod
    def pre_rope_scoring(
        q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, context
    ) -> Optional[torch.Tensor]:
        """Delegate pre-RoPE scoring to the configured base scorer."""
        cc = context.compression_context
        if cc is None:
            scorer_name = "compactor"
        else:
            scorer_name = getattr(cc, "critical_ada_base_scorer", "compactor")
        if str(scorer_name).lower() != "compactor":
            return SnapKVCompression.pre_rope_scoring(q, k, v, context)
        return CompactorCompression.pre_rope_scoring(q, k, v, context)

    @staticmethod
    def post_rope_scoring(
        q: torch.Tensor,
        k: torch.Tensor,
        v: torch.Tensor,
        pre_rope_scores: Optional[torch.Tensor],
        context,
    ) -> Optional[torch.Tensor]:
        """Compute base scores, then apply CriticalAda when ``wo_weight`` is set."""
        compression_context = context.compression_context
        assert compression_context is not None
        scorer_name = str(
            getattr(compression_context, "critical_ada_base_scorer", "compactor")
        ).lower()

        if scorer_name != "compactor":
            base_scores = SnapKVCompression.post_rope_scoring(
                q, k, v, pre_rope_scores, context
            )
        else:
            # Special case: mirrors CompactorCompression.post_rope_scoring.
            # The current stream waits for the store stream before scoring.
            if context.STORE_STREAM is not None:
                torch.cuda.current_stream().wait_stream(context.STORE_STREAM)
            blending = resolve_kvpress_compactor_blending(compression_context)
            base_scores = maybe_execute_in_stream(
                kvpress_compactor_post_rope,
                q,
                k,
                v,
                context.cu_seqlens_q,
                pre_rope_scores,
                compression_context,
                context.max_seqlen_q,
                chunk_size=CompactorCompression.chunk_size,
                blending=float(blending),
                STORE_STREAM=context.STORE_STREAM,
            )

        wo_weight = compression_context.wo_weight
        if wo_weight is None:
            # Without the output projection there is nothing to re-weight with.
            return base_scores

        result = maybe_execute_in_stream(
            critical_ada_key_scores,
            q,
            k,
            v,
            wo_weight,
            context.cu_seqlens_q,
            base_scores,
            compression_context,
            STORE_STREAM=context.STORE_STREAM,
            store_stream=context.STORE_STREAM,
        )
        # The second element (masked key indices) is not used here.
        scores, _masked = result
        return scores

    @staticmethod
    def prepare_layer(module: torch.nn.Module, device: torch.device, dtype: torch.dtype):
        """Optionally pre-compute and cache Wo on ``module._critical_ada_wo_weight``.

        Per the original note, inference relies on the ``cc.wo_weight`` injected
        in ``Attention.forward``; this cache is a convenience only. ``dtype`` is
        accepted but unused: the cached tensor is always float32.
        """
        has_o_proj = hasattr(module, "o_proj")
        if not has_o_proj or module.o_proj.weight is None:
            return
        if not (hasattr(module, "num_heads") and hasattr(module, "head_dim")):
            return

        raw_weight = module.o_proj.weight.data
        hidden_size, _in_features = raw_weight.shape
        n_heads = module.num_heads
        d_head = module.head_dim
        # [hidden, Hq * D] -> [Hq * D, hidden] -> [Hq, D, hidden]
        per_head = raw_weight.transpose(0, 1).view(n_heads, d_head, hidden_size)
        module._critical_ada_wo_weight = per_head.to(device=device, dtype=torch.float32)
vllm/compactor-vllm/src/compactor_vllm/compression/criticalkv_origin.py
deleted
100644 → 0
View file @
2b7160c6
"""
CriticalAdaKV: 在 Compactor(pre RoPE 杠杆分 + post RoPE 非因果注意力融合)基础上,
用输出投影 Wo 对 Value 的 L1 范数做 Stage-2 重加权;Stage-1 在 Compactor 基础分上做预算内 top-k 保护。
预算与 compactor_vllm 引擎一致:使用 ``compression_context.batch_tokens_to_retain``(flatten 的
(token, head) 对数量)。Stage1/2 与 kvpress 论文/实现一致;``||Wo@V||_1`` 在 **算法上** 与
``CriticalKVPress.vwl1norm`` 相同(GQA 上逐 query 头 L1 再对组取均值)。**默认用 Triton**
(``_compute_wo_v_l1_kernel``);若需与 PyTorch 逐行对齐,将模块内 ``_USE_WO_L1_REFERENCE_BACKEND`` 改为 ``True`` 即走 ``_vwl1_norm_kvpress_reference``。
注意:不得在 import 时加载 ``compactor_vllm.utils.context``(其会再 import ``CompressionMethod``,
与 ``compression/__init__.py`` 导入本模块形成环)。运行时只使用与 ``CompressionContext`` 同字段的 duck 对象。
"""
from
__future__
import
annotations
from
typing
import
Any
,
Optional
,
Tuple
import
torch
import
triton
from
triton
import
language
as
tl
from
transformers.models.llama.modeling_llama
import
repeat_kv
from
compactor_vllm.compression.common
import
BaseCompressionMethod
from
compactor_vllm.compression.compactor
import
(
CompactorCompression
,
non_causal_attn_scores
,
)
from
compactor_vllm.compression.snapkv
import
SnapKVCompression
from
compactor_vllm.utils.helpers
import
maybe_execute_in_stream
from
compactor_vllm.utils.triton_compat
import
autotune
as
triton_autotune
# Backend for the L1 norm of Wo@V: False = Triton (default),
# True = PyTorch reference (debugging / alignment checks).
_USE_WO_L1_REFERENCE_BACKEND = False
def _vwl1_norm_kvpress_reference(
    values_seg: torch.Tensor,
    wo: torch.Tensor,
    num_kv_heads: int,
    num_query_groups: int,
) -> torch.Tensor:
    """
    Optional PyTorch reference for the per-token ``||V @ Wo||_1`` weight.

    Described by the original author as equivalent to kvpress
    ``CriticalKVPress.vwl1norm``; it is selected when
    ``_USE_WO_L1_REFERENCE_BACKEND`` is ``True`` (Triton is the default).

    Algorithm: repeat_kv -> per query head ``|V @ Wo_h|_1`` -> mean over each
    GQA group.

    Args:
        values_seg: values of one sequence, shape ``[k_len, Hk, D]``.
        wo: per-head output projection, shape ``[Hq, D, hidden]``.
        num_kv_heads: must equal ``Hk``.
        num_query_groups: query heads per KV head (``Hq == Hk * groups``).

    Returns:
        Contiguous tensor of shape ``[k_len, Hk]``.
    """
    seq_len, n_kv, d_head = values_seg.shape
    n_q, d_wo, _ = wo.shape
    assert d_head == d_wo and n_kv == num_kv_heads and n_q == n_kv * num_query_groups

    # [1, Hk, k_len, D], the layout HF ``repeat_kv`` expects.
    values_4d = values_seg.permute(1, 0, 2).unsqueeze(0).contiguous()
    # [1, Hq, k_len, D]
    values_rep = repeat_kv(values_4d, num_query_groups)

    # V is cast to Wo's dtype before each matmul (Wo is typically float32
    # while V is bf16/fp16, per the original comment).
    per_head_norms = [
        torch.norm(
            values_rep[0, h, :, :].to(dtype=wo.dtype).matmul(wo[h, :, :]),
            p=1,
            dim=-1,
        )
        for h in range(n_q)
    ]

    # [Hq, k_len] -> [Hk, G, k_len] -> mean over the group axis -> [Hk, k_len]
    by_query_head = torch.stack(per_head_norms, dim=0)
    by_kv_head = by_query_head.view(n_kv, num_query_groups, seq_len).mean(dim=1)
    return by_kv_head.transpose(0, 1).contiguous()
# ============================================================================
# Triton: ||Wo @ V||_1 following the kvpress definition
# (per-query-head L1 within a GQA group, then the mean over the group)
# ============================================================================
@triton_autotune(
    configs=[
        triton.Config({"BLOCK_K": bk, "BLOCK_D": bd}, num_warps=nw, num_stages=ns)
        for bk in [32, 64, 128]
        for bd in [32, 64]
        for nw in [4, 8]
        for ns in [3, 4]
    ],
    key=["Hk", "D", "HIDDEN"],
    cache_results=True,
)
@triton.jit
def _compute_wo_v_l1_kernel(
    V,
    WO,
    cu_k,
    OUT,
    STRIDE_V_NK,
    STRIDE_V_HK,
    STRIDE_V_D,
    STRIDE_WO_HQ,
    STRIDE_WO_D,
    STRIDE_WO_HID,
    STRIDE_OUT_NK,
    STRIDE_OUT_HK,
    Hk: tl.constexpr,
    Hq: tl.constexpr,
    D: tl.constexpr,
    HIDDEN: tl.constexpr,
    QUERY_GROUP_SIZE: tl.constexpr,
    BLOCK_K: tl.constexpr,
    BLOCK_D: tl.constexpr,
):
    """For each KV head: compute ``sum(|V @ Wo|)`` for each of the G query heads, then divide by G (matches the kvpress mean)."""
    # Grid axes: (batch, KV head, BLOCK_K-sized tile of tokens).
    b = tl.program_id(0)
    hk = tl.program_id(1)
    ks = tl.program_id(2)
    # Token range of this batch entry in the packed (varlen) layout.
    k_beg = tl.load(cu_k + b)
    k_end = tl.load(cu_k + b + 1)
    nk_off = ks * BLOCK_K + tl.arange(0, BLOCK_K)
    nk = k_beg + nk_off
    # Tiles past the end of a shorter sequence are fully masked out.
    k_mask = nk < k_end
    out_ptrs = OUT + nk * STRIDE_OUT_NK + hk * STRIDE_OUT_HK
    l1_sum = tl.zeros([BLOCK_K], dtype=tl.float32)
    for g in range(QUERY_GROUP_SIZE):
        # Query head g of this KV head's group selects the Wo slice.
        hq = hk * QUERY_GROUP_SIZE + g
        v_ptrs = (
            V
            + nk[:, None] * STRIDE_V_NK
            + hk * STRIDE_V_HK
            + tl.arange(0, D)[None, :] * STRIDE_V_D
        )
        v_blk = tl.load(v_ptrs, mask=k_mask[:, None], other=0.0).to(tl.float32)
        # Walk the hidden dimension in BLOCK_D tiles, accumulating |V @ Wo|.
        for hid_off in range(0, HIDDEN, BLOCK_D):
            hid_idx = hid_off + tl.arange(0, BLOCK_D)
            hid_mask = hid_idx < HIDDEN
            wo_ptrs = (
                WO
                + hq * STRIDE_WO_HQ
                + tl.arange(0, D)[:, None] * STRIDE_WO_D
                + hid_idx[None, :] * STRIDE_WO_HID
            )
            wo_tile = tl.load(wo_ptrs, mask=hid_mask[None, :], other=0.0).to(tl.float32)
            wov_tile = tl.dot(v_blk, wo_tile)
            l1_sum += tl.sum(tl.abs(wov_tile), axis=1)
    # Mean over the query heads of the group.
    l1_sum = l1_sum / QUERY_GROUP_SIZE
    tl.store(out_ptrs, l1_sum, mask=k_mask)
# ============================================================================
# Triton: Stage 1 protection + Stage 2 weighted fusion (element-wise)
# ============================================================================
@triton_autotune(
    configs=[triton.Config({"BLOCK_K": bk}) for bk in [32, 64, 128, 256]],
    key=["Hk"],
    cache_results=True,
)
@triton.jit
def _critical_ada_fuse_kernel(
    BASE_SCORES,
    WO_V_NORM,
    STAGE1_MASK,
    cu_k,
    OUT,
    STRIDE_BS_NK,
    STRIDE_BS_HK,
    STRIDE_WN_NK,
    STRIDE_WN_HK,
    STRIDE_S1_NK,
    STRIDE_S1_HK,
    STRIDE_OUT_NK,
    STRIDE_OUT_HK,
    EPSILON: tl.constexpr,
    Hk: tl.constexpr,
    BLOCK_K: tl.constexpr,
):
    # One program per (batch, KV head); it loops over that sequence's tokens.
    b = tl.program_id(0)
    hk = tl.program_id(1)
    k_beg = tl.load(cu_k + b)
    k_end = tl.load(cu_k + b + 1)
    for ks in tl.range(k_beg, k_end, BLOCK_K):
        nk = ks + tl.arange(0, BLOCK_K)
        kmask = nk < k_end
        bs_ptrs = BASE_SCORES + nk * STRIDE_BS_NK + hk * STRIDE_BS_HK
        wn_ptrs = WO_V_NORM + nk * STRIDE_WN_NK + hk * STRIDE_WN_HK
        s1_ptrs = STAGE1_MASK + nk * STRIDE_S1_NK + hk * STRIDE_S1_HK
        base = tl.load(bs_ptrs, mask=kmask, other=0.0)
        wnorm = tl.load(wn_ptrs, mask=kmask, other=1.0)
        stage1_protect = tl.load(s1_ptrs, mask=kmask, other=0).to(tl.int32)
        # Stage 2 re-weighting: (base + eps) * ||Wo @ V||_1.
        fused = (base + EPSILON) * wnorm
        # Entries selected in Stage 1 are forced to +inf so they are always kept.
        fused = tl.where(stage1_protect == 1, float("inf"), fused)
        out_ptrs = OUT + nk * STRIDE_OUT_NK + hk * STRIDE_OUT_HK
        tl.store(out_ptrs, fused, mask=kmask)
def critical_ada_key_scores(
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    wo_weight: torch.Tensor,
    cu_seqlens: torch.Tensor,
    base_scores: torch.Tensor,
    compression_ctx: Any,
    *,
    store_stream: Optional[torch.cuda.Stream] = None,
) -> Tuple[torch.Tensor, Optional[Tuple[torch.Tensor, torch.Tensor, torch.Tensor]]]:
    """
    Apply the two-stage CriticalAda re-weighting to per-(token, head) base scores.

    Uses the engine's retention budget ``batch_tokens_to_retain`` (number of
    (token, head) pairs per sequence). Per the original author's note, each
    sequence is processed over its full ``k_len`` with the same top-k / scatter
    order as kvpress ``CriticalAdaKVPress.compress``; only the base scores come
    from compactor_vllm's Compactor/SnapKV.

    Args:
        q: queries, shape ``[N_q, Hq, D]`` with a contiguous last dimension.
        k: keys, shape ``[N_k, Hk, D]`` with a contiguous last dimension.
        v: values; indexed as ``[N_k, Hk, D]`` with a contiguous last dimension.
        wo_weight: output projection, either 2-D ``[hidden, Hq * D]`` or
            already reshaped to ``[Hq, D, hidden]``.
        cu_seqlens: cumulative sequence boundaries, ``B + 1`` entries.
        base_scores: base importance scores, indexed as ``[N_k, Hk]``.
        compression_ctx: any object with the same fields as
            ``CompressionContext`` (duck typing). This function reads
            ``batch_tokens_to_retain``, ``critical_ada_epsilon``,
            ``critical_ada_first_stage_ratio`` and
            ``critical_ada_alpha_safeguard``.
        store_stream: if given, ``final_scores.record_stream`` is called on it.

    Returns:
        ``(final_scores, masked_key_indices)``. ``final_scores`` is float32
        ``[N_k, Hk]``; ``masked_key_indices`` is ``None`` when nothing is pruned,
        otherwise a ``(batch_idx, head_idx, seq_idx)`` tuple of index tensors
        for the pairs with the lowest final scores.
    """
    assert q.stride(-1) == 1 and k.stride(-1) == 1 and v.stride(-1) == 1
    device = q.device
    _, Hq, D = q.shape
    N_k, Hk, Dk = k.shape
    assert D == Dk and Hq % Hk == 0
    # Use the same cu as non_causal_attn_scores (context.cu_seqlens_q during
    # prefill) so that base_scores rows line up with the Triton segments;
    # do not mix it with cu_seqlens_k.
    B = cu_seqlens.numel() - 1
    G = Hq // Hk
    k_lengths = cu_seqlens[1:] - cu_seqlens[:-1]
    btr = compression_ctx.batch_tokens_to_retain
    assert btr is not None and btr.numel() == B
    btr = btr.to(device=device, dtype=torch.int32)
    epsilon = compression_ctx.critical_ada_epsilon
    first_stage_ratio = compression_ctx.critical_ada_first_stage_ratio
    alpha_safeguard = float(compression_ctx.critical_ada_alpha_safeguard)
    # Clamp the safeguard fraction to [0, 1].
    alpha_safeguard = max(0.0, min(1.0, alpha_safeguard))
    if wo_weight.dim() == 2:
        # [hidden, Hq * D] -> [Hq, D, hidden]
        hidden_size, _ = wo_weight.shape
        wo = wo_weight.transpose(0, 1).view(Hq, D, hidden_size).contiguous()
    else:
        wo = wo_weight.contiguous()
        hidden_size = wo.size(-1)
    # Per-(token, head) L1 norm of V @ Wo. Left uninitialised when there is
    # nothing to compute.
    wo_v_norm = torch.empty((N_k, Hk), dtype=torch.float32, device=device)
    if B > 0 and int(k_lengths.max().item()) > 0:
        if _USE_WO_L1_REFERENCE_BACKEND:
            for b in range(B):
                k_beg = int(cu_seqlens[b].item())
                k_end = int(cu_seqlens[b + 1].item())
                if k_end <= k_beg:
                    continue
                v_seg = v[k_beg:k_end, :, :].contiguous()
                wo_v_norm[k_beg:k_end, :] = _vwl1_norm_kvpress_reference(v_seg, wo, Hk, G)
        else:

            def grid_wo(META):
                # The third grid axis covers the longest sequence; shorter
                # ones are masked inside the kernel.
                max_k_len = int(k_lengths.max().item())
                return (B, Hk, triton.cdiv(max_k_len, META["BLOCK_K"]))

            _compute_wo_v_l1_kernel[grid_wo](
                v,
                wo,
                cu_seqlens,
                wo_v_norm,
                *v.stride(),
                *wo.stride(),
                *wo_v_norm.stride(),
                Hk=Hk,
                Hq=Hq,
                D=D,
                HIDDEN=hidden_size,
                QUERY_GROUP_SIZE=G,
            )
    # 1 marks (token, head) pairs protected by Stage 1.
    stage1_mask = torch.zeros((N_k, Hk), dtype=torch.int32, device=device)
    # Per-sequence head budgets (None for skipped sequences), reused by Stage 2.
    head_budgets_by_batch: list[Optional[torch.Tensor]] = []
    for b in range(B):
        k_len = int(k_lengths[b].item())
        if k_len == 0:
            head_budgets_by_batch.append(None)
            continue
        k_beg = int(cu_seqlens[b].item())
        k_end = int(cu_seqlens[b + 1].item())
        keep_pairs = int(btr[b].item())
        scores_seg = base_scores[k_beg:k_end, :]
        # Same as kvpress n_kept: each head keeps n_kept tokens.
        n_kept_tokens = max(1, keep_pairs // Hk)
        n_kept_tokens = min(n_kept_tokens, k_len)
        # kvpress: top-k indices are taken on the unmodified scores and the
        # scatter only writes to a copy, which feeds head_budgets; Stage 1
        # still uses the original scores_seg (see below).
        working = scores_seg.clone()
        n_safe = int(n_kept_tokens * alpha_safeguard)
        if n_safe > 0:
            nk = min(n_safe, k_len)
            for hk in range(Hk):
                top_idx = torch.topk(scores_seg[:, hk], nk, sorted=True).indices
                working[:, hk].scatter_(0, top_idx, float("inf"))
        top_pairs = min(n_kept_tokens * Hk, working.numel())
        if top_pairs <= 0:
            head_budgets_by_batch.append(None)
            continue
        # working is [k_len, Hk] flattened token-major, so flat % Hk is the head.
        top_idx_flat = torch.topk(working.reshape(-1), top_pairs, sorted=False).indices
        top_head_idx = top_idx_flat % Hk
        head_budgets = torch.bincount(top_head_idx, minlength=Hk).to(torch.int32)
        head_budgets_by_batch.append(head_budgets)
        # Stage 1: same as kvpress: topk(..., M1, sorted=True) first, then
        # each head takes its first phase1_budget indices.
        head_selection_budget_1st = (
            (head_budgets.to(torch.float32) * float(first_stage_ratio))
            .to(torch.int64)
            .tolist()
        )
        M1 = max(head_selection_budget_1st) if head_selection_budget_1st else 0
        if M1 > 0:
            mk = min(M1, k_len)
            for hk in range(Hk):
                phase1_budget = int(head_selection_budget_1st[hk])
                if phase1_budget <= 0:
                    continue
                full_idx = torch.topk(scores_seg[:, hk], mk, sorted=True).indices
                take = min(phase1_budget, mk)
                stage1_mask[k_beg + full_idx[:take], hk] = 1
    final_scores = torch.empty((N_k, Hk), dtype=torch.float32, device=device)

    def grid_fuse(_META):
        return (B, Hk)

    _critical_ada_fuse_kernel[grid_fuse](
        base_scores,
        wo_v_norm,
        stage1_mask,
        cu_seqlens,
        final_scores,
        *base_scores.stride(),
        *wo_v_norm.stride(),
        *stage1_mask.stride(),
        *final_scores.stride(),
        Hk=Hk,
        EPSILON=float(epsilon),
    )
    # Stage 2 (kvpress): topk(..., M2, sorted=True) on the fused scores, then
    # each head sets its first `budget` indices to inf.
    for b in range(B):
        hb = head_budgets_by_batch[b]
        if hb is None:
            continue
        k_beg = int(cu_seqlens[b].item())
        k_end = int(cu_seqlens[b + 1].item())
        k_len = k_end - k_beg
        if k_len <= 0:
            continue
        fused_seg = final_scores[k_beg:k_end, :]
        M2 = int(hb.max().item())
        if M2 <= 0:
            continue
        mk = min(M2, k_len)
        for hk in range(Hk):
            budget = int(hb[hk].item())
            if budget <= 0:
                continue
            full_idx = torch.topk(fused_seg[:, hk], mk, sorted=True).indices
            take = min(budget, mk)
            final_scores[k_beg + full_idx[:take], hk] = float("inf")
    # Collect the (batch, head, token) indices of the lowest-scoring pairs.
    masked_key_indices = None
    for b in range(B):
        k_len = int(k_lengths[b].item())
        if k_len == 0:
            continue
        keep_pairs = int(btr[b].item())
        total_pairs = k_len * Hk
        if keep_pairs >= total_pairs:
            # Budget covers the whole sequence: nothing to prune.
            continue
        k_beg = int(cu_seqlens[b].item())
        k_end = int(cu_seqlens[b + 1].item())
        n_prune_pairs = min(total_pairs - keep_pairs, total_pairs)
        if n_prune_pairs <= 0:
            continue
        flat_scores = final_scores[k_beg:k_end, :].reshape(-1)
        # top-k of the negated scores selects the smallest scores.
        prune_idx = torch.topk(
            -flat_scores, min(n_prune_pairs, flat_scores.numel()), sorted=False
        ).indices
        batch_idx = torch.full_like(prune_idx, b, dtype=torch.int64)
        # Token-major flattening: flat = token * Hk + head.
        head_idx = prune_idx % Hk
        seq_idx = prune_idx // Hk + k_beg
        if masked_key_indices is None:
            masked_key_indices = (batch_idx, head_idx, seq_idx)
        else:
            masked_key_indices = (
                torch.cat([masked_key_indices[0], batch_idx]),
                torch.cat([masked_key_indices[1], head_idx]),
                torch.cat([masked_key_indices[2], seq_idx]),
            )
    if store_stream is not None:
        # Tell the caching allocator that final_scores is also used on store_stream.
        final_scores.record_stream(store_stream)
    return final_scores, masked_key_indices
class CriticalAdaKVCompression(BaseCompressionMethod):
    """
    CriticalAda two-stage re-weighting on top of a base scorer.

    The base scorer defaults to ``CompactorCompression`` (pre-RoPE scores
    blended with post-RoPE non-causal attention scores) and switches to
    ``SnapKVCompression`` when ``compression_context.critical_ada_base_scorer``
    is ``"snapkv"``. Per the original author's note, the attention layer must
    inject ``compression_context.wo_weight`` before post-RoPE scoring; when it
    is ``None`` the base scores are returned unchanged.
    """

    @staticmethod
    def pre_rope_scoring(
        q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, context
    ) -> Optional[torch.Tensor]:
        """Delegate pre-RoPE scoring to the configured base scorer.

        A missing ``critical_ada_base_scorer`` attribute resolves to
        ``"compactor"``, the same default ``post_rope_scoring`` uses. The two
        phases must pick the same scorer: otherwise the Compactor post-RoPE
        path would run without the pre-RoPE scores it accumulates onto.
        """
        cc = context.compression_context
        base = (
            getattr(cc, "critical_ada_base_scorer", "compactor")
            if cc is not None
            else "compactor"
        )
        if str(base).lower() == "snapkv":
            return SnapKVCompression.pre_rope_scoring(q, k, v, context)
        return CompactorCompression.pre_rope_scoring(q, k, v, context)

    @staticmethod
    def post_rope_scoring(
        q: torch.Tensor,
        k: torch.Tensor,
        v: torch.Tensor,
        pre_rope_scores: Optional[torch.Tensor],
        context,
    ) -> Optional[torch.Tensor]:
        """Compute base scores, then apply CriticalAda when ``wo_weight`` is set."""
        compression_context = context.compression_context
        assert compression_context is not None
        base = str(
            getattr(compression_context, "critical_ada_base_scorer", "compactor")
        ).lower()
        if base == "snapkv":
            base_scores = SnapKVCompression.post_rope_scoring(
                q, k, v, pre_rope_scores, context
            )
        else:
            # Intentionally mirrors CompactorCompression.post_rope_scoring in
            # compactor.py (per the original note): do not swap in a different
            # wrapper, or the scores diverge from using COMPACTOR on its own.
            if context.STORE_STREAM is not None:
                torch.cuda.current_stream().wait_stream(context.STORE_STREAM)
            base_scores = maybe_execute_in_stream(
                non_causal_attn_scores,
                q,
                k,
                v,
                context.cu_seqlens_q,
                context.max_seqlen_q,
                chunk_size=CompactorCompression.chunk_size,
                sm_scale=1.0,
                normalize=True,
                accum_scores=pre_rope_scores,
                context_lens=compression_context.context_lens,
                protected_first_tokens=compression_context.protected_first_tokens,
                protected_last_tokens=compression_context.protected_last_tokens,
                accum_blending=0.5,
            )
        wo_weight = compression_context.wo_weight
        if wo_weight is None:
            # Without the output projection there is nothing to re-weight with.
            return base_scores
        scores, _masked = maybe_execute_in_stream(
            critical_ada_key_scores,
            q,
            k,
            v,
            wo_weight,
            context.cu_seqlens_q,
            base_scores,
            compression_context,
            STORE_STREAM=context.STORE_STREAM,
            store_stream=context.STORE_STREAM,
        )
        return scores

    @staticmethod
    def prepare_layer(module: torch.nn.Module, device: torch.device, dtype: torch.dtype):
        """Optionally pre-compute and cache Wo on ``module._critical_ada_wo_weight``.

        Per the original note, inference relies on the ``cc.wo_weight`` injected
        in ``Attention.forward``. ``dtype`` is accepted but unused: the cached
        tensor is always float32.
        """
        if not hasattr(module, "o_proj") or module.o_proj.weight is None:
            return
        if not hasattr(module, "num_heads") or not hasattr(module, "head_dim"):
            return
        wo_raw = module.o_proj.weight.data
        hidden_size, _ = wo_raw.shape
        Hq = module.num_heads
        head_dim = module.head_dim
        # [hidden, Hq * D] -> [Hq * D, hidden] -> [Hq, D, hidden]
        wo = (
            wo_raw.transpose(0, 1)
            .view(Hq, head_dim, hidden_size)
            .to(device=device, dtype=torch.float32)
        )
        module._critical_ada_wo_weight = wo
vllm/compactor-vllm/src/compactor_vllm/compression/snapkv.py
deleted
100644 → 0
View file @
2b7160c6
import
math
from
typing
import
Optional
import
torch
import
triton
from
triton
import
language
as
tl
from
compactor_vllm.compression.common
import
BaseCompressionMethod
from
compactor_vllm.utils.helpers
import
maybe_execute_in_stream
from
compactor_vllm.utils.triton_compat
import
autotune
as
triton_autotune
# SnapKV defaults aligned with kvpress `SnapKVPress` (snapkv_press.py).
# Passed as ``w=`` (observation window) to ``query_aware_key_scores``.
DEFAULT_SNAPKV_WINDOW_SIZE = 64
# Passed as ``kernel_size=`` (average-pool smoothing width) to ``query_aware_key_scores``.
DEFAULT_SNAPKV_KERNEL_SIZE = 5
class SnapKVCompression(BaseCompressionMethod):
    """SnapKV-style key scoring; all of the work happens after RoPE."""

    @staticmethod
    def pre_rope_scoring(
        q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, context
    ) -> Optional[torch.Tensor]:
        """SnapKV has no pre-RoPE stage, so there are no scores to return."""
        return None

    @staticmethod
    def post_rope_scoring(
        q: torch.Tensor,
        k: torch.Tensor,
        v: torch.Tensor,
        pre_rope_scores: torch.Tensor,
        context,
    ) -> Optional[torch.Tensor]:
        """Score keys with ``query_aware_key_scores`` using the SnapKV defaults.

        ``v`` and ``pre_rope_scores`` are accepted for interface compatibility
        and are not used.
        """
        store_stream = context.STORE_STREAM
        return maybe_execute_in_stream(
            query_aware_key_scores,
            q,
            k,
            context.cu_seqlens_q,
            context.cu_seqlens_k,
            w=DEFAULT_SNAPKV_WINDOW_SIZE,
            kernel_size=DEFAULT_SNAPKV_KERNEL_SIZE,
            STORE_STREAM=store_stream,
        )
# Pass 1 of the SnapKV scoring: for the last `win` queries of each sequence,
# compute causally masked base-2 logits against all keys, store the logits of
# the prefix keys (those before the window), and store the streaming
# log-sum-exp state (running max m and running sum S) for every row.
@triton_autotune(
    configs=[
        triton.Config(
            {"BLOCK_Q": bq, "BLOCK_K": bk}, num_warps=num_warps, num_stages=num_stages
        )
        for bq in [32, 64]
        for bk in [32, 64]
        for num_warps in [4, 8]
        for num_stages in [3, 4]
    ],
    key=["QUERY_GROUP_SIZE", "D", "ROWS_MAX"],
    cache_results=True,
)
@triton.jit
def _lse_and_store_logits_kernel(
    Q,
    K,
    cu_q,
    cu_k,
    w_b,  # int32 pointers
    out_m,
    out_S,  # [B, Hk, ROWS_MAX] float32
    LOGITS,  # [Nk, Hk, ROWS_MAX] float32
    sm_scale,  # float
    QUERY_GROUP_SIZE: tl.constexpr,
    D: tl.constexpr,
    STRIDE_Q_NQ,
    STRIDE_Q_HQ,
    STRIDE_K_NK,
    STRIDE_K_HK,
    STRIDE_M_B,
    STRIDE_M_H,
    STRIDE_M_R,
    STRIDE_S_B,
    STRIDE_S_H,
    STRIDE_S_R,
    STRIDE_LG_NK,
    STRIDE_LG_HK,
    STRIDE_LG_R,
    BLOCK_Q: tl.constexpr,
    BLOCK_K: tl.constexpr,
    ROWS_MAX,
):
    # program ids
    b = tl.program_id(0)
    hk = tl.program_id(1)
    rid = tl.program_id(2)  # row-tile id
    # batch segment bounds
    q_end = tl.load(cu_q + b + 1)
    k_beg = tl.load(cu_k + b)
    k_end = tl.load(cu_k + b + 1)
    win = tl.load(w_b + b)
    # The window is the last `win` queries; prefix keys end where it starts.
    q_win_beg = q_end - win
    k_eff_end = k_end - win
    if (win <= 0) or (k_eff_end <= k_beg):
        return
    # rows for this (b,hk)
    rows_b = win * QUERY_GROUP_SIZE
    row0 = rid * BLOCK_Q
    if row0 >= rows_b:
        return
    # exp(x) = exp2(x * 1/ln2)
    qk_scale = sm_scale * 1.4426950408889634
    offs_qrow = row0 + tl.arange(0, BLOCK_Q)
    row_mask = offs_qrow < rows_b
    # map row -> (q_idx, hq_local)
    hq_local = offs_qrow % QUERY_GROUP_SIZE
    q_off = offs_qrow // QUERY_GROUP_SIZE
    q_idx = q_win_beg + q_off
    hq_glob = hk * QUERY_GROUP_SIZE + hq_local
    offs_d = tl.arange(0, D)
    # No stride on the D axis: the last dimension is addressed as contiguous.
    q_ptrs = (
        Q
        + q_idx[:, None] * STRIDE_Q_NQ
        + hq_glob[:, None] * STRIDE_Q_HQ
        + offs_d[None, :]
    )
    q_rows = tl.load(q_ptrs, mask=row_mask[:, None], other=0.0)
    m = tl.zeros([BLOCK_Q], dtype=tl.float32) + (-float("inf"))
    S = tl.zeros([BLOCK_Q], dtype=tl.float32)
    # Full-sequence causal attention (matches kvpress softmax), then use prefix columns only.
    for ks in tl.range(k_beg, k_end, BLOCK_K):
        nk = ks + tl.arange(0, BLOCK_K)
        kmask = nk < k_end
        k_ptrs = K + nk[:, None] * STRIDE_K_NK + hk * STRIDE_K_HK + offs_d[None, :]
        k_blk = tl.load(k_ptrs, mask=kmask[:, None], other=0.0)  # [BK, D]
        s = tl.dot(q_rows, k_blk.T) * qk_scale  # [BQ, BK]
        s = tl.where(kmask[None, :], s, -float("inf"))
        # Causal: key j only if j <= q_idx (same as kvpress triu mask on the window×k_len grid).
        # NOTE(review): this compares packed key indices with packed query
        # indices, which presumably requires cu_q and cu_k to describe the
        # same layout; confirm against the caller.
        causal_ok = nk[None, :] <= q_idx[:, None]
        s = tl.where(causal_ok, s, -float("inf"))
        # store prefix logits only (for marginal probs on prefix keys)
        log_ptrs = (
            LOGITS
            + nk[:, None] * STRIDE_LG_NK
            + hk * STRIDE_LG_HK
            + (row0 + tl.arange(0, BLOCK_Q))[None, :] * STRIDE_LG_R
        )
        store_mask = kmask & (nk < k_eff_end)
        tl.store(log_ptrs, s.T, mask=store_mask[:, None] & row_mask[None, :])
        # log2 streaming LSE over all keys in [k_beg, k_end) (after causal mask)
        cur_max = tl.max(s, 1)  # [BQ]
        n_m = tl.maximum(m, cur_max)
        rescale = tl.math.exp2(m - n_m)
        S = S * rescale + tl.sum(tl.math.exp2(s - n_m[:, None]), 1)
        m = n_m
    # store m,S for these rows
    m_base = out_m + b * STRIDE_M_B + hk * STRIDE_M_H + row0 * STRIDE_M_R
    S_base = out_S + b * STRIDE_S_B + hk * STRIDE_S_H + row0 * STRIDE_S_R
    tl.store(m_base + tl.arange(0, BLOCK_Q) * STRIDE_M_R, m, mask=row_mask)
    tl.store(S_base + tl.arange(0, BLOCK_Q) * STRIDE_S_R, S, mask=row_mask)
# Pass 2 of the SnapKV scoring: turn the stored base-2 prefix logits into
# softmax probabilities using the per-row (m, S) log-sum-exp state.
#
# The autotune key must name real kernel arguments. This kernel has no
# ``HK``/``HQ`` parameters, so the previous ``key=["HK", "HQ"]`` was either
# rejected at decoration time (older Triton) or silently ignored (newer
# Triton). ``QUERY_GROUP_SIZE`` is the constexpr argument that determines how
# many rows each (batch, head) program walks.
@triton_autotune(
    configs=[
        triton.Config({"BLOCK_Q": bq, "BLOCK_K": bk})
        for bq in [16, 32, 64]
        for bk in [32, 64, 128]
    ],
    key=["QUERY_GROUP_SIZE"],
    cache_results=True,
)
@triton.jit
def _prefix_probs_kernel(
    cu_k,
    w_b,
    in_m,
    in_S,  # [B, Hk, ROWS_MAX] f32
    LOGITS,  # [Nk, Hk, ROWS_MAX] f32, base-2 logits (prefix keys only)
    PROBS,  # [Nk, Hk, ROWS_MAX] f32 — per-row prefix marginal probs
    #
    QUERY_GROUP_SIZE: tl.constexpr,
    STRIDE_M_B,
    STRIDE_M_H,
    STRIDE_M_R,
    STRIDE_S_B,
    STRIDE_S_H,
    STRIDE_S_R,
    STRIDE_LG_NK,
    STRIDE_LG_HK,
    STRIDE_LG_R,
    STRIDE_PB_NK,
    STRIDE_PB_HK,
    STRIDE_PB_R,
    BLOCK_Q: tl.constexpr,
    BLOCK_K: tl.constexpr,
):
    # One program per (batch, KV head).
    b = tl.program_id(0)
    hk = tl.program_id(1)
    k_beg = tl.load(cu_k + b)
    k_end = tl.load(cu_k + b + 1)
    win = tl.load(w_b + b)
    # Prefix keys are those before the trailing window of `win` tokens.
    k_eff_end = k_end - win
    if (win <= 0) or (k_eff_end <= k_beg):
        return
    rows_b = win * QUERY_GROUP_SIZE
    for ks in tl.range(k_beg, k_eff_end, BLOCK_K):
        nk = ks + tl.arange(0, BLOCK_K)
        kmask = nk < k_eff_end
        for row0 in tl.range(0, rows_b, BLOCK_Q):
            r_idx = row0 + tl.arange(0, BLOCK_Q)
            rmask = r_idx < rows_b
            m_ptr = in_m + b * STRIDE_M_B + hk * STRIDE_M_H + row0 * STRIDE_M_R
            S_ptr = in_S + b * STRIDE_S_B + hk * STRIDE_S_H + row0 * STRIDE_S_R
            m = tl.load(
                m_ptr + tl.arange(0, BLOCK_Q) * STRIDE_M_R,
                mask=rmask,
                other=-float("inf"),
            )
            S = tl.load(S_ptr + tl.arange(0, BLOCK_Q) * STRIDE_S_R, mask=rmask, other=0.0)
            # Rows whose sum is not strictly positive get safe placeholders so
            # the division below cannot produce inf/NaN; they are zeroed after.
            valid_row = S > 0
            m = tl.where(valid_row, m, 0.0)
            S = tl.where(valid_row, S, 1.0)
            log_ptrs = (
                LOGITS
                + nk[:, None] * STRIDE_LG_NK
                + hk * STRIDE_LG_HK
                + (row0 + tl.arange(0, BLOCK_Q))[None, :] * STRIDE_LG_R
            )
            s_T = tl.load(
                log_ptrs, mask=kmask[:, None] & rmask[None, :], other=-float("inf")
            )  # [BK, BQ]
            # softmax in base 2: p = 2^(s - m) / S
            probs_T = tl.math.exp2(s_T - m[None, :]) / S[None, :]
            probs_T = tl.where(valid_row[None, :], probs_T, 0.0)
            prob_ptrs = (
                PROBS
                + nk[:, None] * STRIDE_PB_NK
                + hk * STRIDE_PB_HK
                + (row0 + tl.arange(0, BLOCK_Q))[None, :] * STRIDE_PB_R
            )
            tl.store(prob_ptrs, probs_T, mask=kmask[:, None] & rmask[None, :])
# In-place z-score normalisation of the prefix scores of each sequence: the
# mean and variance are taken jointly over all prefix tokens and all HK heads.
# NOTE(review): the kernel reads and overwrites OUT. Triton's autotuner runs
# every candidate config on the same buffers, so unless the
# ``triton_compat.autotune`` wrapper restores OUT between runs the first call
# may normalise more than once; confirm.
@triton_autotune(
    configs=[triton.Config({"BLOCK_K": bk}) for bk in [32, 64, 128]],
    key=["HK"],
    cache_results=True,
)
@triton.jit
def _zscore_per_batch_epilogue(
    OUT,  # [Nk, Hk], float32
    cu_k,
    w_b,  # [B+1], [B] int32
    STRIDE_OUT_NK,
    STRIDE_OUT_HK,
    HK: tl.constexpr,  # Hk
    EPS: tl.constexpr,  # e.g., 1e-12
    BLOCK_K: tl.constexpr,  # e.g., 128
):
    # One program per batch entry.
    b = tl.program_id(0)
    k_beg = tl.load(cu_k + b)
    k_end = tl.load(cu_k + b + 1)
    win = tl.load(w_b + b)
    # Only prefix tokens (before the trailing window) are normalised.
    k_eff_end = k_end - win
    if k_eff_end <= k_beg:
        return
    sumv = tl.zeros([], dtype=tl.float32)
    sumsq = tl.zeros([], dtype=tl.float32)
    count = ((k_eff_end - k_beg) * HK).to(tl.float32)
    # First pass: accumulate sum and sum of squares.
    for ks in tl.range(k_beg, k_eff_end, BLOCK_K):
        nk = ks + tl.arange(0, BLOCK_K)
        kmask = nk < k_eff_end
        for h in tl.range(0, HK):
            ptrs = OUT + nk * STRIDE_OUT_NK + h * STRIDE_OUT_HK
            vals = tl.load(ptrs, mask=kmask, other=0.0).to(tl.float32)
            sumv += tl.sum(vals, 0)
            sumsq += tl.sum(vals * vals, 0)
    mean = sumv / count
    # E[x^2] - E[x]^2, clamped at zero against rounding error.
    var = tl.maximum(sumsq / count - mean * mean, 0.0)
    invstd = 1.0 / tl.sqrt(var + EPS)
    # Second pass: write the normalised values back in place.
    for ks in tl.range(k_beg, k_eff_end, BLOCK_K):
        nk = ks + tl.arange(0, BLOCK_K)
        kmask = nk < k_eff_end
        for h in tl.range(0, HK):
            ptrs = OUT + nk * STRIDE_OUT_NK + h * STRIDE_OUT_HK
            vals = tl.load(ptrs, mask=kmask, other=0.0).to(tl.float32)
            vals = (vals - mean) * invstd
            tl.store(ptrs, vals, mask=kmask)
@triton_autotune(
    configs=[triton.Config({"BLOCK_T": bt}) for bt in [32, 64, 128, 256]],
    key=["KERNEL_SIZE"],
    cache_results=True,
)
@triton.jit
def _snapkv_avg_pool1d_kernel(
    IN,
    OUT,
    Lp,
    STRIDE_IN_C,
    STRIDE_IN_L,
    STRIDE_OUT_C,
    STRIDE_OUT_L,
    KERNEL_SIZE: tl.constexpr,
    PAD: tl.constexpr,
    BLOCK_T: tl.constexpr,
):
    """
    Symmetric 1D average pool on the last dimension, matching
    `F.avg_pool1d(x, kernel_size=K, padding=K//2, stride=1)` on `x` shaped [C, Lp]
    (equivalent to PyTorch [C, 1, Lp] avg_pool1d with divisor = kernel size).
    """
    # Grid axes: (channel, BLOCK_T-sized tile of output positions).
    c = tl.program_id(0)
    t0 = tl.program_id(1) * BLOCK_T + tl.arange(0, BLOCK_T)
    mask = t0 < Lp
    acc = tl.zeros([BLOCK_T], dtype=tl.float32)
    for j in tl.static_range(KERNEL_SIZE):
        idx = t0 - PAD + j
        # Taps outside [0, Lp) contribute zero (zero padding).
        valid = (idx >= 0) & (idx < Lp)
        ptrs = IN + c * STRIDE_IN_C + idx * STRIDE_IN_L
        v = tl.load(ptrs, mask=valid & mask, other=0.0).to(tl.float32)
        acc += v
    # The divisor is always the full kernel size, even at the borders.
    acc = acc / tl.cast(KERNEL_SIZE, tl.float32)
    out_ptrs = OUT + c * STRIDE_OUT_C + t0 * STRIDE_OUT_L
    tl.store(out_ptrs, acc, mask=mask)
def
_snapkv_avg_pool1d_triton
(
x
:
torch
.
Tensor
,
kernel_size
:
int
)
->
torch
.
Tensor
:
"""
kvpress-equivalent smoothing: same as `F.avg_pool1d` on [Hk*G, 1, Lp].
`x` must be float32 and contiguous along Lp (shape [Hk, G, Lp]).
"""
assert
x
.
dtype
==
torch
.
float32
Hk
,
G
,
Lp
=
x
.
shape
if
Lp
==
0
:
return
x
pad
=
kernel_size
//
2
x2
=
x
.
reshape
(
Hk
*
G
,
Lp
).
contiguous
()
out
=
torch
.
empty_like
(
x2
)
C
=
Hk
*
G
si_c
,
si_l
=
x2
.
stride
()
so_c
,
so_l
=
out
.
stride
()
def
grid
(
meta
):
return
(
C
,
triton
.
cdiv
(
Lp
,
meta
[
"BLOCK_T"
]))
_snapkv_avg_pool1d_kernel
[
grid
](
x2
,
out
,
Lp
,
si_c
,
si_l
,
so_c
,
so_l
,
KERNEL_SIZE
=
kernel_size
,
PAD
=
pad
,
)
return
out
.
view
(
Hk
,
G
,
Lp
)
def
_snapkv_kvpress_epilogue
(
probs_buf
:
torch
.
Tensor
,
out
:
torch
.
Tensor
,
cu_seqlens_k
:
torch
.
Tensor
,
w
:
torch
.
Tensor
,
G
:
int
,
Hk
:
int
,
kernel_size
:
int
,
)
->
None
:
"""
Match kvpress SnapKV order: mean over window queries → symmetric avg_pool1d
→ mean over GQA groups → pad tail with global max of prefix scores.
"""
B
=
cu_seqlens_k
.
numel
()
-
1
for
b
in
range
(
B
):
k_beg
=
int
(
cu_seqlens_k
[
b
].
item
())
k_end
=
int
(
cu_seqlens_k
[
b
+
1
].
item
())
win
=
int
(
w
[
b
].
item
())
k_eff_end
=
k_end
-
win
if
win
<=
0
or
k_eff_end
<=
k_beg
:
continue
Lp
=
k_eff_end
-
k_beg
rows_b
=
win
*
G
p
=
probs_buf
[
k_beg
:
k_eff_end
,
:,
:
rows_b
]
# [Lp, Hk, win, G] — rows are (q_off, g) order per Triton row layout
x
=
p
.
view
(
Lp
,
Hk
,
win
,
G
).
mean
(
dim
=
2
)
x
=
x
.
permute
(
1
,
2
,
0
).
contiguous
()
# [Hk, G, Lp]
x
=
_snapkv_avg_pool1d_triton
(
x
,
kernel_size
)
x
=
x
.
mean
(
dim
=
1
)
seg
=
x
.
permute
(
1
,
0
).
contiguous
()
out
[
k_beg
:
k_eff_end
,
:]
=
seg
pad_val
=
seg
.
max
()
out
[
k_eff_end
:
k_end
,
:]
=
pad_val
def query_aware_key_scores(
    q: torch.Tensor,  # [N_q, Hq, D]
    k: torch.Tensor,  # [N_k, Hk, D]
    cu_seqlens_q: torch.Tensor,  # [B+1], int32
    cu_seqlens_k: torch.Tensor,  # [B+1], int32
    w: torch.Tensor | int,  # [B], int32
    sm_scale: float = None,  # defaults to 1/sqrt(D)
    *,
    kernel_size: int = DEFAULT_SNAPKV_KERNEL_SIZE,
    accum_scores: torch.Tensor = None,
    accum_blending: float = None,
    normalize: bool = False,
) -> Optional[torch.Tensor]:
    """
    Score every key by the attention it receives from the last ``w`` queries.

    Pipeline (all on ``q.device``):
      1. ``_lse_and_store_logits_kernel`` stores scaled logits and the per-row
         running max / sum-of-exponentials.
      2. ``_prefix_probs_kernel`` turns the stored logits into ``probs_buf``.
      3. ``_snapkv_kvpress_epilogue`` reduces ``probs_buf`` to one score per
         (key, kv_head) with pooling width ``kernel_size``.
      4. If ``normalize`` is set, ``_zscore_per_batch_epilogue`` normalises the
         scores of each sequence in place.

    Args:
        q, k: packed queries / keys; the last dimension must have stride 1 and
            ``Hq`` must be a multiple of ``Hk``.
        cu_seqlens_q, cu_seqlens_k: cumulative sequence offsets, ``B + 1`` entries.
        w: observation-window length, either one int for all sequences or a
            tensor with ``B`` entries.
        sm_scale: logit scale; ``1 / sqrt(D)`` when ``None``.
        kernel_size: pooling width used by the epilogue.
        accum_scores: optional running score buffer, updated in place.
        accum_blending: optional factor applied to ``accum_scores`` before the
            new scores are added.
        normalize: apply the per-sequence z-score epilogue.

    Returns:
        ``accum_scores`` (after ``accum_scores = accum_scores * accum_blending
        + scores``) when it is given, otherwise a new float32 [N_k, Hk] tensor.
    """
    assert q.stride(-1) == 1 and k.stride(-1) == 1, "last dim must be contiguous"
    device = q.device
    _, Hq, D = q.shape
    N_k, Hk, _ = k.shape
    assert (Hq % Hk) == 0, "Hq must be a multiple of Hk"
    if sm_scale is None:
        sm_scale = 1.0 / math.sqrt(D)

    B = cu_seqlens_q.numel() - 1
    assert B == cu_seqlens_k.numel() - 1
    G = Hq // Hk

    if isinstance(w, int):
        max_w = w
        w = torch.full((B,), fill_value=w, device=device, dtype=torch.int32)
    else:
        max_w = int(w.max().item())
        assert w.numel() == B
    ROWS_MAX = max_w * G

    def _blend(scores: torch.Tensor) -> torch.Tensor:
        # Shared tail: fold the fresh scores into the running buffer if one was given.
        if accum_scores is None:
            return scores
        if accum_blending is not None:
            accum_scores.mul_(accum_blending)
        accum_scores.add_(scores)
        return accum_scores

    # Zero-initialised because the epilogue leaves skipped sequences untouched.
    out = torch.zeros((N_k, Hk), dtype=torch.float32, device=device)
    if ROWS_MAX <= 0:
        # No window rows anywhere: the new scores are all zero. Still go through
        # the blend step so a caller-supplied accumulator is decayed and returned
        # instead of being silently replaced by a fresh zero tensor.
        return _blend(out)

    m_scratch = torch.empty((B, Hk, ROWS_MAX), dtype=torch.float32, device=device)
    S_scratch = torch.empty((B, Hk, ROWS_MAX), dtype=torch.float32, device=device)
    logits_buf = torch.empty((N_k, Hk, ROWS_MAX), dtype=torch.float32, device=device)
    probs_buf = torch.empty((N_k, Hk, ROWS_MAX), dtype=torch.float32, device=device)

    # strides
    STRIDE_Q_NQ, STRIDE_Q_HQ, _ = q.stride()
    STRIDE_K_NK, STRIDE_K_HK, _ = k.stride()
    STRIDE_M_B, STRIDE_M_H, STRIDE_M_R = m_scratch.stride()
    STRIDE_S_B, STRIDE_S_H, STRIDE_S_R = S_scratch.stride()
    STRIDE_LG_NK, STRIDE_LG_HK, STRIDE_LG_R = logits_buf.stride()
    STRIDE_PB_NK, STRIDE_PB_HK, STRIDE_PB_R = probs_buf.stride()
    STRIDE_OUT_NK, STRIDE_OUT_HK = out.stride()

    def grid(META):
        # One program per (sequence, kv head, tile of BLOCK_Q window rows).
        return B, Hk, triton.cdiv(ROWS_MAX, META["BLOCK_Q"])

    _lse_and_store_logits_kernel[grid](
        q,
        k,
        cu_seqlens_q,
        cu_seqlens_k,
        w,
        m_scratch,
        S_scratch,
        logits_buf,
        sm_scale,
        QUERY_GROUP_SIZE=G,
        D=D,
        STRIDE_Q_NQ=STRIDE_Q_NQ,
        STRIDE_Q_HQ=STRIDE_Q_HQ,
        STRIDE_K_NK=STRIDE_K_NK,
        STRIDE_K_HK=STRIDE_K_HK,
        STRIDE_M_B=STRIDE_M_B,
        STRIDE_M_H=STRIDE_M_H,
        STRIDE_M_R=STRIDE_M_R,
        STRIDE_S_B=STRIDE_S_B,
        STRIDE_S_H=STRIDE_S_H,
        STRIDE_S_R=STRIDE_S_R,
        STRIDE_LG_NK=STRIDE_LG_NK,
        STRIDE_LG_HK=STRIDE_LG_HK,
        STRIDE_LG_R=STRIDE_LG_R,
        ROWS_MAX=ROWS_MAX,
    )
    _prefix_probs_kernel[(B, Hk)](
        cu_seqlens_k,
        w,
        m_scratch,
        S_scratch,
        logits_buf,
        probs_buf,
        QUERY_GROUP_SIZE=G,
        STRIDE_M_B=STRIDE_M_B,
        STRIDE_M_H=STRIDE_M_H,
        STRIDE_M_R=STRIDE_M_R,
        STRIDE_S_B=STRIDE_S_B,
        STRIDE_S_H=STRIDE_S_H,
        STRIDE_S_R=STRIDE_S_R,
        STRIDE_LG_NK=STRIDE_LG_NK,
        STRIDE_LG_HK=STRIDE_LG_HK,
        STRIDE_LG_R=STRIDE_LG_R,
        STRIDE_PB_NK=STRIDE_PB_NK,
        STRIDE_PB_HK=STRIDE_PB_HK,
        STRIDE_PB_R=STRIDE_PB_R,
    )
    _snapkv_kvpress_epilogue(probs_buf, out, cu_seqlens_k, w, G, Hk, kernel_size)

    if normalize:
        _zscore_per_batch_epilogue[(B,)](
            out,
            cu_seqlens_k,
            w,
            STRIDE_OUT_NK,
            STRIDE_OUT_HK,
            HK=Hk,
            EPS=1e-12,
        )
    return _blend(out)
vllm/compactor-vllm/src/compactor_vllm/compression/snapkv_origin.py
deleted
100644 → 0
View file @
2b7160c6
import
math
from
typing
import
Optional
import
torch
import
triton
from
triton
import
language
as
tl
from
compactor_vllm.compression.common
import
BaseCompressionMethod
from
compactor_vllm.utils.helpers
import
maybe_execute_in_stream
from
compactor_vllm.utils.triton_compat
import
autotune
as
triton_autotune
class SnapKVCompression(BaseCompressionMethod):
    """Compression method whose key scores come from ``query_aware_key_scores``."""

    @staticmethod
    def pre_rope_scoring(
        q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, context
    ) -> Optional[torch.Tensor]:
        """This method produces no pre-RoPE scores; always returns ``None``."""
        return None

    @staticmethod
    def post_rope_scoring(
        q: torch.Tensor,
        k: torch.Tensor,
        v: torch.Tensor,
        pre_rope_scores: torch.Tensor,
        context,
    ) -> Optional[torch.Tensor]:
        """
        Score keys against the trailing 32 queries of every sequence.

        ``v`` and ``pre_rope_scores`` are accepted for interface compatibility
        and are not used. The scoring call is dispatched through
        ``maybe_execute_in_stream`` with ``context.STORE_STREAM``.
        """
        cu_q = context.cu_seqlens_q
        cu_k = context.cu_seqlens_k
        store_stream = context.STORE_STREAM
        return maybe_execute_in_stream(
            query_aware_key_scores,
            q,
            k,
            cu_q,
            cu_k,
            w=32,
            STORE_STREAM=store_stream,
        )
@triton_autotune(
    configs=[
        triton.Config(
            {"BLOCK_Q": bq, "BLOCK_K": bk}, num_warps=num_warps, num_stages=num_stages
        )
        for bq in [32, 64]
        for bk in [32, 64]
        for num_warps in [4, 8]
        for num_stages in [3, 4]
    ],
    key=["QUERY_GROUP_SIZE", "D", "ROWS_MAX"],
    cache_results=True,
)
@triton.jit
def _lse_and_store_logits_kernel(
    Q,
    K,
    cu_q,
    cu_k,
    w_b,  # int32 pointers
    out_m,
    out_S,  # [B, Hk, ROWS_MAX] float32
    LOGITS,  # [Nk, Hk, ROWS_MAX] float32
    sm_scale,  # float
    QUERY_GROUP_SIZE: tl.constexpr,
    D: tl.constexpr,
    STRIDE_Q_NQ,
    STRIDE_Q_HQ,
    STRIDE_K_NK,
    STRIDE_K_HK,
    STRIDE_M_B,
    STRIDE_M_H,
    STRIDE_M_R,
    STRIDE_S_B,
    STRIDE_S_H,
    STRIDE_S_R,
    STRIDE_LG_NK,
    STRIDE_LG_HK,
    STRIDE_LG_R,
    BLOCK_Q: tl.constexpr,
    BLOCK_K: tl.constexpr,
    ROWS_MAX,
):
    """
    Pass 1: base-2 logits of the window queries against the prefix keys, plus
    the per-row running max ``m`` and sum of exponentials ``S``.

    Grid: (sequence ``b``, KV head ``hk``, tile of ``BLOCK_Q`` rows). A "row" is
    one (window query, query head within the GQA group) pair, so a sequence has
    ``win * QUERY_GROUP_SIZE`` rows. Keys are limited to the prefix
    ``[k_beg, k_end - win)``; the last ``win`` keys are not scored.

    Writes ``LOGITS[key, hk, row]`` (already multiplied by ``sm_scale / ln 2``)
    and ``out_m[b, hk, row]`` / ``out_S[b, hk, row]`` such that
    ``exp2(logit - m) / S`` is the softmax over the prefix keys.

    NOTE(review): only ``cu_q[b + 1]`` is read; nothing here checks that the
    query segment actually contains ``win`` tokens — confirm the caller
    guarantees it. ``ROWS_MAX`` is not used in the body; it only appears in the
    autotune key.
    """
    # program ids
    b = tl.program_id(0)
    hk = tl.program_id(1)
    rid = tl.program_id(2)  # row-tile id
    # batch segment bounds
    q_end = tl.load(cu_q + b + 1)
    k_beg = tl.load(cu_k + b)
    k_end = tl.load(cu_k + b + 1)
    win = tl.load(w_b + b)
    # The window is the last `win` queries; the scored keys exclude the last `win` keys.
    q_win_beg = q_end - win
    k_eff_end = k_end - win
    # Nothing to do for an empty window or an empty key prefix.
    if (win <= 0) or (k_eff_end <= k_beg):
        return
    # rows for this (b,hk)
    rows_b = win * QUERY_GROUP_SIZE
    row0 = rid * BLOCK_Q
    # The grid is sized for the largest window in the batch; surplus tiles exit here.
    if row0 >= rows_b:
        return
    # exp(x) = exp2(x * 1/ln2)
    qk_scale = sm_scale * 1.4426950408889634
    offs_qrow = row0 + tl.arange(0, BLOCK_Q)
    row_mask = offs_qrow < rows_b
    # map row -> (q_idx, hq_local)
    hq_local = offs_qrow % QUERY_GROUP_SIZE
    q_off = offs_qrow // QUERY_GROUP_SIZE
    q_idx = q_win_beg + q_off
    hq_glob = hk * QUERY_GROUP_SIZE + hq_local
    # The feature axis is addressed with unit stride (no stride argument for it).
    offs_d = tl.arange(0, D)
    q_ptrs = (
        Q
        + q_idx[:, None] * STRIDE_Q_NQ
        + hq_glob[:, None] * STRIDE_Q_HQ
        + offs_d[None, :]
    )
    # Rows past rows_b are loaded as zeros; their results are never stored.
    q_rows = tl.load(q_ptrs, mask=row_mask[:, None], other=0.0)
    # Streaming log-sum-exp state per row: running max and rescaled sum.
    m = tl.zeros([BLOCK_Q], dtype=tl.float32) + (-float("inf"))
    S = tl.zeros([BLOCK_Q], dtype=tl.float32)
    for ks in tl.range(k_beg, k_eff_end, BLOCK_K):
        nk = ks + tl.arange(0, BLOCK_K)
        kmask = nk < k_eff_end
        k_ptrs = K + nk[:, None] * STRIDE_K_NK + hk * STRIDE_K_HK + offs_d[None, :]
        k_blk = tl.load(k_ptrs, mask=kmask[:, None], other=0.0)  # [BK, D]
        s = tl.dot(q_rows, k_blk.T) * qk_scale  # [BQ, BK]
        # Keys beyond the prefix must not contribute to max / sum.
        s = tl.where(kmask[None, :], s, -float("inf"))
        # store into LOGITS[nk, hk, row] -> [BK, BQ]
        log_ptrs = (
            LOGITS
            + nk[:, None] * STRIDE_LG_NK
            + hk * STRIDE_LG_HK
            + (row0 + tl.arange(0, BLOCK_Q))[None, :] * STRIDE_LG_R
        )
        tl.store(log_ptrs, s.T, mask=kmask[:, None] & row_mask[None, :])
        # log2 streaming LSE update
        cur_max = tl.max(s, 1)  # [BQ]
        n_m = tl.maximum(m, cur_max)
        # Rescale the old sum to the new max before adding this tile.
        rescale = tl.math.exp2(m - n_m)
        S = S * rescale + tl.sum(tl.math.exp2(s - n_m[:, None]), 1)
        m = n_m
    # store m,S for these rows
    m_base = out_m + b * STRIDE_M_B + hk * STRIDE_M_H + row0 * STRIDE_M_R
    S_base = out_S + b * STRIDE_S_B + hk * STRIDE_S_H + row0 * STRIDE_S_R
    tl.store(m_base + tl.arange(0, BLOCK_Q) * STRIDE_M_R, m, mask=row_mask)
    tl.store(S_base + tl.arange(0, BLOCK_Q) * STRIDE_S_R, S, mask=row_mask)
# NOTE(review): the autotune key names "HK" and "HQ" are not parameters of this
# kernel — confirm how the project's `triton_autotune` wrapper treats unknown
# key names (ignored vs. error).
@triton_autotune(
    configs=[
        triton.Config({"BLOCK_Q": bq, "BLOCK_K": bk})
        for bq in [16, 32, 64]
        for bk in [32, 64, 128]
    ],
    key=["HK", "HQ"],
    cache_results=True,
)
@triton.jit
def _scores_from_logits_kernel(
    cu_k,
    w_b,
    in_m,
    in_S,  # [B, Hk, ROWS_MAX] f32
    LOGITS,  # [Nk, Hk, ROWS_MAX] f32, base-2 logits
    OUT,  # [Nk, Hk] f32
    #
    QUERY_GROUP_SIZE: tl.constexpr,
    STRIDE_M_B,
    STRIDE_M_H,
    STRIDE_M_R,
    STRIDE_S_B,
    STRIDE_S_H,
    STRIDE_S_R,
    STRIDE_LG_NK,
    STRIDE_LG_HK,
    STRIDE_LG_R,
    STRIDE_OUT_NK,
    STRIDE_OUT_HK,
    BLOCK_Q: tl.constexpr,
    BLOCK_K: tl.constexpr,
    #
    DO_POOL: tl.constexpr,  # set True to enable in-place avg pool
    KPOOL: tl.constexpr,  # kernel size for avg pool (stride=1)
):
    """
    Pass 2: turn stored logits into one score per (key, KV head).

    Grid: (sequence ``b``, KV head ``hk``). For every prefix key the softmax
    probabilities ``exp2(logit - m) / S`` are summed (not averaged) over all
    ``win * QUERY_GROUP_SIZE`` rows. Keys inside the trailing window
    ``[k_end - win, k_end)`` are written as ``+inf``.

    With ``DO_POOL`` and ``KPOOL > 1`` each score is replaced by the mean of
    itself and up to ``KPOOL - 1`` preceding scores *of the same BLOCK_K tile*.
    NOTE(review): the pooling window never reaches into the previous tile, so
    the first ``KPOOL - 1`` keys of each tile are averaged over fewer elements
    and the result depends on the autotuned ``BLOCK_K`` — confirm this is
    acceptable.
    """
    b = tl.program_id(0)
    hk = tl.program_id(1)
    k_beg = tl.load(cu_k + b)
    k_end = tl.load(cu_k + b + 1)
    win = tl.load(w_b + b)
    k_eff_end = k_end - win
    # Same skip rule as the first pass: no window or no prefix keys.
    if (win <= 0) or (k_eff_end <= k_beg):
        return
    rows_b = win * QUERY_GROUP_SIZE
    # === scores over computed region ===
    for ks in tl.range(k_beg, k_eff_end, BLOCK_K):
        nk = ks + tl.arange(0, BLOCK_K)
        kmask = nk < k_eff_end
        scores = tl.zeros([BLOCK_K], dtype=tl.float32)
        for row0 in tl.range(0, rows_b, BLOCK_Q):
            r_idx = row0 + tl.arange(0, BLOCK_Q)
            rmask = r_idx < rows_b
            # load m, S for rows
            m_ptr = in_m + b * STRIDE_M_B + hk * STRIDE_M_H + row0 * STRIDE_M_R
            S_ptr = in_S + b * STRIDE_S_B + hk * STRIDE_S_H + row0 * STRIDE_S_R
            m = tl.load(
                m_ptr + tl.arange(0, BLOCK_Q) * STRIDE_M_R,
                mask=rmask,
                other=-float("inf"),
            )
            S = tl.load(
                S_ptr + tl.arange(0, BLOCK_Q) * STRIDE_S_R, mask=rmask, other=0.0
            )
            # Rows with S <= 0 (including masked rows) get neutral m/S so the
            # division below stays finite; their probabilities are zeroed afterwards.
            valid_row = S > 0
            m = tl.where(valid_row, m, 0.0)
            S = tl.where(valid_row, S, 1.0)
            # load stored logits^T: [BK, BQ]
            log_ptrs = (
                LOGITS
                + nk[:, None] * STRIDE_LG_NK
                + hk * STRIDE_LG_HK
                + (row0 + tl.arange(0, BLOCK_Q))[None, :] * STRIDE_LG_R
            )
            s_T = tl.load(
                log_ptrs, mask=kmask[:, None] & rmask[None, :], other=-float("inf")
            )  # [BK, BQ]
            # probs^T = exp2(s_T - m) / S, sum over rows
            probs_T = tl.math.exp2(s_T - m[None, :]) / S[None, :]
            probs_T = tl.where(valid_row[None, :], probs_T, 0.0)
            scores += tl.sum(probs_T, 1)  # [BK]
        if DO_POOL and (KPOOL > 1):
            # band[i, j] is True when key j of this tile lies in the trailing
            # window (i - KPOOL, i] of key i and is a valid prefix key.
            i = tl.arange(0, BLOCK_K)[:, None]
            j = tl.arange(0, BLOCK_K)[None, :]
            band = (j <= i) & ((i - j) < KPOOL)
            band = band & kmask[None, :]
            # sum within band
            sums = tl.sum(tl.where(band, scores[None, :], 0.0), 1)  # [BK]
            denom = tl.sum(band, 1).to(tl.float32)  # [BK]
            # Guard against an empty band (rows outside the valid key range).
            denom = tl.where(denom > 0, denom, 1.0)
            scores = sums / denom
        out_ptrs = OUT + nk * STRIDE_OUT_NK + hk * STRIDE_OUT_HK
        tl.store(out_ptrs, scores, mask=kmask)
    # Keys inside the observation window receive +inf.
    pad_beg = k_eff_end
    pad_end = k_end
    if pad_end > pad_beg:
        for ks in tl.range(pad_beg, pad_end, BLOCK_K):
            nk = ks + tl.arange(0, BLOCK_K)
            kmask = nk < pad_end
            out_ptrs = OUT + nk * STRIDE_OUT_NK + hk * STRIDE_OUT_HK
            tl.store(
                out_ptrs,
                tl.full([BLOCK_K], float("inf"), dtype=tl.float32),
                mask=kmask
            )
# NOTE(review): this kernel rewrites OUT in place and is autotuned. If the
# autotuner benchmarks several configs on the live buffer, OUT is normalised
# more than once — confirm the `triton_autotune` wrapper restores the buffer
# (or that repeated normalisation is acceptable).
@triton_autotune(
    configs=[triton.Config({"BLOCK_K": bk}) for bk in [32, 64, 128]],
    key=["HK"],
    cache_results=True,
)
@triton.jit
def _zscore_per_batch_epilogue(
    OUT,  # [Nk, Hk], float32
    cu_k,
    w_b,  # [B+1], [B] int32
    STRIDE_OUT_NK,
    STRIDE_OUT_HK,
    HK: tl.constexpr,  # Hk
    EPS: tl.constexpr,  # e.g., 1e-12
    BLOCK_K: tl.constexpr,  # e.g., 128
):
    """
    Z-score the prefix scores of one sequence in place.

    Grid: one program per sequence ``b``. Mean and variance are taken jointly
    over all ``HK`` heads and all prefix keys ``[k_beg, k_end - win)``; every
    such entry of ``OUT`` is then replaced by ``(value - mean) / sqrt(var + EPS)``.
    Entries in the trailing window are left untouched. Unlike the scoring
    kernels there is no ``win <= 0`` early exit: only an empty prefix returns early.
    """
    b = tl.program_id(0)
    k_beg = tl.load(cu_k + b)
    k_end = tl.load(cu_k + b + 1)
    win = tl.load(w_b + b)
    k_eff_end = k_end - win
    if k_eff_end <= k_beg:
        return
    # First sweep: accumulate sum and sum of squares as scalars.
    sumv = tl.zeros([], dtype=tl.float32)
    sumsq = tl.zeros([], dtype=tl.float32)
    count = ((k_eff_end - k_beg) * HK).to(tl.float32)
    for ks in tl.range(k_beg, k_eff_end, BLOCK_K):
        nk = ks + tl.arange(0, BLOCK_K)
        kmask = nk < k_eff_end
        for h in tl.range(0, HK):
            ptrs = OUT + nk * STRIDE_OUT_NK + h * STRIDE_OUT_HK
            # Masked lanes read 0.0 and so add nothing to either sum.
            vals = tl.load(ptrs, mask=kmask, other=0.0).to(tl.float32)
            sumv += tl.sum(vals, 0)
            sumsq += tl.sum(vals * vals, 0)
    mean = sumv / count
    # E[x^2] - E[x]^2 can go slightly negative from rounding; clamp at zero.
    var = tl.maximum(sumsq / count - mean * mean, 0.0)
    invstd = 1.0 / tl.sqrt(var + EPS)
    # Second sweep: normalise in place.
    for ks in tl.range(k_beg, k_eff_end, BLOCK_K):
        nk = ks + tl.arange(0, BLOCK_K)
        kmask = nk < k_eff_end
        for h in tl.range(0, HK):
            ptrs = OUT + nk * STRIDE_OUT_NK + h * STRIDE_OUT_HK
            vals = tl.load(ptrs, mask=kmask, other=0.0).to(tl.float32)
            vals = (vals - mean) * invstd
            tl.store(ptrs, vals, mask=kmask)
def
query_aware_key_scores
(
q
:
torch
.
Tensor
,
# [N_q, Hq, D]
k
:
torch
.
Tensor
,
# [N_k, Hk, D]
cu_seqlens_q
:
torch
.
Tensor
,
# [B+1], int32
cu_seqlens_k
:
torch
.
Tensor
,
# [B+1], int32
w
:
torch
.
Tensor
|
int
,
# [B], int32
sm_scale
:
float
=
None
,
# defaults to 1/sqrt(D)
*
,
accum_scores
:
torch
.
Tensor
=
None
,
accum_blending
:
float
=
None
,
normalize
:
bool
=
False
,
)
->
Optional
[
torch
.
Tensor
]:
assert
q
.
stride
(
-
1
)
==
1
and
k
.
stride
(
-
1
)
==
1
,
"last dim must be contiguous"
device
=
q
.
device
N_q
,
Hq
,
D
=
q
.
shape
N_k
,
Hk
,
Dk
=
k
.
shape
assert
(
Hq
%
Hk
)
==
0
,
"Hq must be a multiple of Hk"
if
sm_scale
is
None
:
sm_scale
=
1.0
/
math
.
sqrt
(
D
)
B
=
cu_seqlens_q
.
numel
()
-
1
assert
B
==
cu_seqlens_k
.
numel
()
-
1
G
=
Hq
//
Hk
if
type
(
w
)
is
int
:
max_w
=
w
w
=
torch
.
full
((
B
,),
fill_value
=
w
,
device
=
device
,
dtype
=
torch
.
int32
)
else
:
max_w
=
int
(
w
.
max
().
item
())
assert
w
.
numel
()
==
B
ROWS_MAX
=
max_w
*
G
if
ROWS_MAX
==
0
:
return
torch
.
zeros
((
N_k
,
Hk
),
dtype
=
torch
.
float32
,
device
=
device
)
out
=
torch
.
empty
((
N_k
,
Hk
),
dtype
=
torch
.
float32
,
device
=
device
)
m_scratch
=
torch
.
empty
((
B
,
Hk
,
ROWS_MAX
),
dtype
=
torch
.
float32
,
device
=
device
)
S_scratch
=
torch
.
empty
((
B
,
Hk
,
ROWS_MAX
),
dtype
=
torch
.
float32
,
device
=
device
)
logits_buf
=
torch
.
empty
((
N_k
,
Hk
,
ROWS_MAX
),
dtype
=
torch
.
float32
,
device
=
device
)
# strides
STRIDE_Q_NQ
,
STRIDE_Q_HQ
,
_
=
q
.
stride
()
STRIDE_K_NK
,
STRIDE_K_HK
,
_
=
k
.
stride
()
STRIDE_M_B
,
STRIDE_M_H
,
STRIDE_M_R
=
m_scratch
.
stride
()
STRIDE_S_B
,
STRIDE_S_H
,
STRIDE_S_R
=
S_scratch
.
stride
()
STRIDE_LG_NK
,
STRIDE_LG_HK
,
STRIDE_LG_R
=
logits_buf
.
stride
()
STRIDE_OUT_NK
,
STRIDE_OUT_HK
=
out
.
stride
()
def
grid
(
META
):
return
B
,
Hk
,
triton
.
cdiv
(
ROWS_MAX
,
META
[
"BLOCK_Q"
])
_lse_and_store_logits_kernel
[
grid
](
q
,
k
,
cu_seqlens_q
,
cu_seqlens_k
,
w
,
m_scratch
,
S_scratch
,
logits_buf
,
sm_scale
,
QUERY_GROUP_SIZE
=
Hq
//
Hk
,
D
=
D
,
STRIDE_Q_NQ
=
STRIDE_Q_NQ
,
STRIDE_Q_HQ
=
STRIDE_Q_HQ
,
STRIDE_K_NK
=
STRIDE_K_NK
,
STRIDE_K_HK
=
STRIDE_K_HK
,
STRIDE_M_B
=
STRIDE_M_B
,
STRIDE_M_H
=
STRIDE_M_H
,
STRIDE_M_R
=
STRIDE_M_R
,
STRIDE_S_B
=
STRIDE_S_B
,
STRIDE_S_H
=
STRIDE_S_H
,
STRIDE_S_R
=
STRIDE_S_R
,
STRIDE_LG_NK
=
STRIDE_LG_NK
,
STRIDE_LG_HK
=
STRIDE_LG_HK
,
STRIDE_LG_R
=
STRIDE_LG_R
,
ROWS_MAX
=
ROWS_MAX
,
)
_scores_from_logits_kernel
[(
B
,
Hk
)](
cu_seqlens_k
,
w
,
m_scratch
,
S_scratch
,
logits_buf
,
out
,
QUERY_GROUP_SIZE
=
Hq
//
Hk
,
STRIDE_M_B
=
STRIDE_M_B
,
STRIDE_M_H
=
STRIDE_M_H
,
STRIDE_M_R
=
STRIDE_M_R
,
STRIDE_S_B
=
STRIDE_S_B
,
STRIDE_S_H
=
STRIDE_S_H
,
STRIDE_S_R
=
STRIDE_S_R
,
STRIDE_LG_NK
=
STRIDE_LG_NK
,
STRIDE_LG_HK
=
STRIDE_LG_HK
,
STRIDE_LG_R
=
STRIDE_LG_R
,
STRIDE_OUT_NK
=
STRIDE_OUT_NK
,
STRIDE_OUT_HK
=
STRIDE_OUT_HK
,
DO_POOL
=
True
,
KPOOL
=
5
,
)
if
normalize
:
_zscore_per_batch_epilogue
[(
B
,)](
out
,
cu_seqlens_k
,
w
,
STRIDE_OUT_NK
,
STRIDE_OUT_HK
,
HK
=
Hk
,
EPS
=
1e-12
,
)
if
accum_scores
is
not
None
:
if
accum_blending
is
not
None
:
accum_scores
.
mul_
(
accum_blending
)
accum_scores
.
add_
(
out
)
return
accum_scores
else
:
return
out
vllm/compactor-vllm/src/compactor_vllm/config/__init__.py
deleted
100644 → 0
View file @
2b7160c6
vllm/compactor-vllm/src/compactor_vllm/config/constants.py
deleted
100644 → 0
View file @
2b7160c6
# Batch index treated as reserved.
# NOTE(review): presumably the batch/page-table slot the engine keeps for
# itself — confirm against the users of this constant.
RESERVED_BATCH: int = 0
# NOTE: Triton `tl.constexpr` is intended for use in kernel signatures/annotations.
# Some Triton builds reject passing `tl.constexpr(...)` objects as constexpr values.
# Keep the runtime value as a plain int and let kernel signatures declare constexpr.
# Same value as RESERVED_BATCH, exported under the name the Triton kernels use.
TRITON_RESERVED_BATCH: int = RESERVED_BATCH
vllm/compactor-vllm/src/compactor_vllm/config/engine_config.py
deleted
100644 → 0
View file @
2b7160c6
import
os
from
dataclasses
import
dataclass
from
enum
import
Enum
,
auto
from
typing
import
List
,
Optional
from
transformers
import
AutoConfig
class AttentionBackend(Enum):
    """Attention implementations selectable through :class:`LLMConfig`."""

    FLASH_ATTENTION = auto()
    COMPACTOR_TRITON = auto()


@dataclass
class LLMConfig:
    """Configuration for the :class:`LLM` engine.

    Parameters
    ----------
    model : str
        Hugging Face model identifier (e.g. ``"meta-llama/Meta-Llama-3-8B"``) or
        a local model name that can be resolved by
        :func:`transformers.AutoConfig.from_pretrained`.
    path : str, optional
        Local directory containing the model weights. If ``None``, the engine
        will attempt to resolve a local snapshot for ``model`` using
        :func:`huggingface_hub.snapshot_download`.
    nccl_port : int, optional, default 1218
        NOTE(review): presumably the TCP port used to initialise the
        tensor-parallel process group — not used in this module, confirm
        against the model runner.
    max_num_seqs : int, default 256
        Upper bound on the number of concurrent batches that the scheduler and
        KV-cache manager are allowed to handle. This affects the size of the
        page table and some internal buffers.
    max_model_len : int, default 40960
        Maximum context length (in tokens) that the engine will allocate KV cache
        and CUDA graphs for. During initialization this value is clamped to
        ``hf_config.max_position_embeddings`` for the chosen model.
    gpu_memory_utilization : float, default 0.9
        Fraction of the total GPU memory that may be used for KV cache and model
        activations. Values should be in ``(0, 1]``. If this budget is too small,
        the KV-cache manager may raise an error at warmup time due
        to insufficient memory.
    tensor_parallel_size : int, default 1
        Number of tensor-parallel workers to shard the model
        across. Must be between 1 and 8, and must evenly divide the model's
        number of key/value heads.
    enforce_eager : bool, default False
        If ``True``, disable CUDA graph capture and always run the model in
        eager mode during decoding. This reduces throughput. When ``False``,
        the engine will capture and reuse CUDA graphs for supported
        batch sizes and sequence lengths.
    hf_config : transformers.AutoConfig, optional
        Pre-loaded Hugging Face configuration for the model. If ``None``,
        it will then be populated automatically based on ``model``.
    eos : int, default -1
        Primary stop token id (warmup / single-id paths). If ``-1``, the
        :class:`LLM` constructor fills this and :attr:`eos_token_ids` from the
        tokenizer.
    eos_token_ids : list of int, optional
        All token ids that terminate generation (e.g. HF tokenizers may expose
        ``eos_token_id`` as a list for chat models). If ``None``, inferred in
        :class:`LLM` from the tokenizer and model type.
    kvcache_page_size : int, default 128
        Number of tokens stored in a single KV-cache page. Smaller pages improve
        allocation flexibility but increase page-table overhead; larger pages
        reduce overhead but have coarser granularity.
    leverage_sketch_size : int, default 48
        Sketch dimension used by the Compactor leverage-score estimator.
    attention_backend : AttentionBackend, default AttentionBackend.COMPACTOR_TRITON
        Attention implementation to use. ``COMPACTOR_TRITON`` selects the custom
        Triton kernels used by Compactor; ``FLASH_ATTENTION`` selects the
        FlashAttention3 varlen backend. The COMPACTOR_TRITON tends to be faster
        for longer sequence lengths, while FA3 is faster at shorter lengths.
    show_progress_bar : bool, default True
        NOTE(review): presumably toggles a progress bar during generation —
        not used in this module, confirm against the model runner.

    Raises
    ------
    NotADirectoryError
        If ``path`` is given but is not an existing directory.
    ValueError
        If ``tensor_parallel_size`` is outside ``[1, 8]``.
    """

    model: str
    path: Optional[str] = None
    nccl_port: Optional[int] = 1218
    max_num_seqs: int = 256
    max_model_len: int = 40960
    gpu_memory_utilization: float = 0.9
    tensor_parallel_size: int = 1
    enforce_eager: bool = False
    # Quoted so the annotation is not evaluated when the class is created.
    hf_config: Optional["AutoConfig"] = None
    eos: int = -1
    eos_token_ids: Optional[List[int]] = None
    kvcache_page_size: int = 128
    leverage_sketch_size: int = 48
    attention_backend: AttentionBackend = AttentionBackend.COMPACTOR_TRITON
    show_progress_bar: bool = True

    def __post_init__(self):
        """Validate the fields, load ``hf_config`` if missing, clamp ``max_model_len``."""
        if self.path is not None and not os.path.isdir(self.path):
            raise NotADirectoryError(f"Engine config dir {self.path} does not exist")
        # The previous version asserted the valid range *inside* this branch, so
        # the assert always failed first and the ValueError below was unreachable
        # (and the check vanished entirely under ``python -O``).
        if self.tensor_parallel_size <= 0 or self.tensor_parallel_size > 8:
            raise ValueError("tensor_parallel_size must be >= 1 and <= 8")
        if self.hf_config is None:
            self.hf_config = AutoConfig.from_pretrained(self.model)
        # Never allocate for more context than the model supports.
        self.max_model_len = min(
            self.max_model_len, self.hf_config.max_position_embeddings
        )
vllm/compactor-vllm/src/compactor_vllm/config/sampling_params.py
deleted
100644 → 0
View file @
2b7160c6
from
dataclasses
import
dataclass
@dataclass
class SamplingParams:
    """Per-request sampling settings.

    ``temperature`` must be non-negative; a negative value raises
    ``ValueError`` at construction time. ``max_new_tokens`` is stored as given.
    """

    temperature: float = 1.0
    max_new_tokens: int = 256

    def __post_init__(self):
        """Reject negative temperatures."""
        temperature_is_negative = self.temperature < 0
        if temperature_is_negative:
            raise ValueError("Temperature cannot be negative")
vllm/compactor-vllm/src/compactor_vllm/core/__init__.py
deleted
100644 → 0
View file @
2b7160c6
vllm/compactor-vllm/src/compactor_vllm/core/llm_engine.py
deleted
100644 → 0
View file @
2b7160c6
import
atexit
import
inspect
import
logging
from
typing
import
Any
,
List
,
Optional
,
Union
import
torch.multiprocessing
as
mp
from
compactor_vllm.compression.compression_config
import
(
BatchCompressionParams
,
SequenceCompressionParams
,
)
from
compactor_vllm.config.engine_config
import
LLMConfig
from
compactor_vllm.config.sampling_params
import
SamplingParams
from
compactor_vllm.core.model_runner
import
ModelRunner
from
compactor_vllm.models
import
MODEL_REGISTRY
from
compactor_vllm.utils.sequence
import
Sequence
from
transformers
import
AutoTokenizer
logger
=
logging
.
getLogger
(
__name__
)
PromptLike
=
Union
[
str
,
List
[
int
]]
def
_infer_stop_token_ids
(
tokenizer
,
hf_config
)
->
list
[
int
]:
"""
Build the set of token ids that should end generation.
Newer HF chat tokenizers often expose ``eos_token_id`` as a *list* of ids.
The engine must not compare generated ids to that list as a single ``int``;
see :attr:`LLMConfig.eos_token_ids` and decode-time ``torch.isin``.
Qwen chat uses ``</think>`` (im_end) as the assistant turn boundary; include it
when present in ``additional_special_tokens`` / ``added_tokens_encoder``. We
avoid loose substring matches like ``
\"
end
\"
`` that can tag unrelated tokens.
"""
raw
=
tokenizer
.
eos_token_id
ids
:
list
[
int
]
=
[]
if
isinstance
(
raw
,
(
list
,
tuple
)):
ids
.
extend
(
int
(
x
)
for
x
in
raw
)
elif
raw
is
not
None
:
ids
.
append
(
int
(
raw
))
unk_id
=
getattr
(
tokenizer
,
"unk_token_id"
,
None
)
def
_maybe_add_tid
(
tid
:
int
)
->
None
:
if
not
isinstance
(
tid
,
int
)
or
tid
<
0
:
return
if
unk_id
is
not
None
and
tid
==
unk_id
:
return
if
tid
not
in
ids
:
ids
.
append
(
tid
)
model_type
=
getattr
(
hf_config
,
"model_type"
,
None
)
if
model_type
in
(
"qwen2"
,
"qwen3"
,
"qwen2_moe"
,
"qwen3_moe"
):
enc
=
getattr
(
tokenizer
,
"added_tokens_encoder"
,
None
)
if
isinstance
(
enc
,
dict
):
for
key
,
tid
in
enc
.
items
():
if
isinstance
(
key
,
str
)
and
"im_end"
in
key
:
_maybe_add_tid
(
int
(
tid
))
for
extra
in
getattr
(
tokenizer
,
"additional_special_tokens"
,
[])
or
[]:
if
not
isinstance
(
extra
,
str
)
or
"im_end"
not
in
extra
:
continue
try
:
tid
=
tokenizer
.
convert_tokens_to_ids
(
extra
)
except
(
TypeError
,
ValueError
,
KeyError
):
continue
_maybe_add_tid
(
tid
)
if
not
ids
:
raise
ValueError
(
"Could not infer stop token ids from the tokenizer; set "
"LLMConfig(eos_token_ids=[...]) explicitly."
)
return
ids
def
_merge_apply_chat_template_kwargs
(
tokenizer
,
user_kwargs
:
Optional
[
dict
[
str
,
Any
]],
)
->
dict
[
str
,
Any
]:
"""
Merge user kwargs with defaults for HF chat templates that support them.
Qwen3 (and similar) instruct models expect `add_generation_prompt=True` so
the first generated token continues the assistant turn; without it, output
can repeat punctuation / template fragments. `enable_thinking=False` avoids
the Qwen3 reasoning channel when the tokenizer supports it.
"""
out
=
dict
(
user_kwargs
or
{})
try
:
sig
=
inspect
.
signature
(
tokenizer
.
apply_chat_template
)
except
(
TypeError
,
ValueError
):
return
out
if
"add_generation_prompt"
in
sig
.
parameters
and
"add_generation_prompt"
not
in
out
:
out
[
"add_generation_prompt"
]
=
True
if
"enable_thinking"
in
sig
.
parameters
and
"enable_thinking"
not
in
out
:
out
[
"enable_thinking"
]
=
False
return
out
def _runner_entry(config: LLMConfig, rank: int, evt):
    """
    Process entry point for a non-master tensor-parallel worker.

    Builds a ``ModelRunner`` for ``rank`` and runs its loop. Any ``Exception``
    is logged with its traceback instead of propagating, and the runner's
    ``exit()`` is always called if construction succeeded.
    """
    worker = None
    try:
        worker = ModelRunner(config, rank, evt)
        worker.loop()
    except Exception as exc:
        logging.exception(f"Rank {rank}: {exc!r}")
    finally:
        # `worker` stays None when the constructor itself failed.
        if worker is not None:
            worker.exit()
class
LLMEngine
:
"""High-level engine coordinating model runners and scheduling"""
def __init__(self, config: LLMConfig):
    """
    Resolve the model snapshot and stop tokens, then start the model runners.

    Steps, in order: verify the model type is registered; resolve
    ``config.path`` from the local Hugging Face cache when it is ``None``;
    load the tokenizer; normalise ``config.eos`` / ``config.eos_token_ids``;
    spawn one daemon worker process per rank ``1..tensor_parallel_size - 1``;
    build the rank-0 ``ModelRunner`` in this process; register ``self.exit``
    with ``atexit``. ``config`` is mutated in place.

    Raises:
        ValueError: if ``config.hf_config.model_type`` is not in ``MODEL_REGISTRY``.
    """
    self.config = config
    cfg = self.config
    if cfg.hf_config.model_type not in MODEL_REGISTRY:
        raise ValueError(f"Unknown model {cfg.model}")

    if cfg.path is None:
        from huggingface_hub import snapshot_download

        # local_files_only: use a snapshot already in the local cache.
        cfg.path = snapshot_download(repo_id=cfg.model, local_files_only=True)
        logger.info(f"Using {cfg.model} snapshot @ {cfg.path}")

    self.tokenizer = AutoTokenizer.from_pretrained(cfg.model, use_fast=True)

    # Stop ids: an explicit list wins, then an explicit single id, then inference.
    if cfg.eos_token_ids is not None:
        stop_ids = [int(x) for x in cfg.eos_token_ids]
    elif cfg.eos != -1:
        stop_ids = [int(cfg.eos)]
    else:
        stop_ids = _infer_stop_token_ids(self.tokenizer, cfg.hf_config)
    cfg.eos_token_ids = sorted(set(stop_ids))

    # Primary id: default to the smallest stop id, otherwise make sure the
    # configured one is part of the stop set.
    if cfg.eos == -1:
        cfg.eos = int(cfg.eos_token_ids[0])
    else:
        cfg.eos = int(cfg.eos)
        if cfg.eos not in cfg.eos_token_ids:
            cfg.eos_token_ids = sorted(cfg.eos_token_ids + [cfg.eos])

    self.ps = []
    world_size = int(cfg.tensor_parallel_size)
    self.events = []
    if world_size > 1:
        spawn_ctx = mp.get_context("spawn")
        for worker_rank in range(1, world_size):
            worker_event = spawn_ctx.Event()
            worker = spawn_ctx.Process(
                target=_runner_entry,
                args=(cfg, worker_rank, worker_event),
                daemon=True,
            )
            worker.start()
            self.ps.append(worker)
            self.events.append(worker_event)

    self.master_model_runner = ModelRunner(cfg, rank=0, peer_events=self.events)
    atexit.register(self.exit)
def exit(self):
    """
    Shut the engine down; safe to call more than once.

    The first call exits the master runner (failures are logged, not raised),
    terminates worker processes that are still alive, joins each worker for up
    to one second and clears the event list. Later calls return immediately.
    """
    already_exited = getattr(self, "_exited", False)
    if already_exited:
        return
    self._exited = True

    master = getattr(self, "master_model_runner", None)
    if master is not None:
        try:
            master.exit()
        except Exception:
            logger.exception("Failed to exit master ModelRunner cleanly")

    for worker in self.ps:
        if worker.is_alive():
            worker.terminate()
        worker.join(timeout=1.0)

    if hasattr(self, "events"):
        self.events.clear()
def tokenize_prompt(self, prompt: PromptLike, **tokenizer_kwargs) -> List[int]:
    """
    Turn a raw prompt into token IDs.

    A ``str`` is encoded with ``self.tokenizer`` (``tokenizer_kwargs`` are
    forwarded) and its ``"input_ids"`` are returned. Anything else is assumed
    to be pre-tokenized and is returned as a new list.
    """
    if not isinstance(prompt, str):
        return list(prompt)
    encoded = self.tokenizer(prompt, **tokenizer_kwargs)
    return encoded["input_ids"]
def detokenize_prompt(
    self, sequences: List[Sequence], **detokenizer_kwargs
) -> List[str]:
    """
    Turn completed Sequences into strings.

    Decodes each sequence's ``completion_token_ids`` with
    ``tokenizer.batch_decode``. ``skip_special_tokens=True`` is the default and
    can be overridden through ``detokenizer_kwargs``.
    """
    decode_kwargs: dict[str, Any] = {"skip_special_tokens": True}
    decode_kwargs.update(detokenizer_kwargs)
    completions = [seq.completion_token_ids for seq in sequences]
    return self.tokenizer.batch_decode(completions, **decode_kwargs)
def _build_sequences(
    self,
    prompts: List[PromptLike] | PromptLike,
    sampling_params: SamplingParams | List[SamplingParams],
    per_sequence_compression_params: Optional[
        SequenceCompressionParams | List[SequenceCompressionParams]
    ] = None,
    tokenizer_kwargs: Optional[dict[str, Any]] = None,
) -> List[Sequence]:
    """
    Build Sequence objects from prompts, sampling params, and optional
    per-sequence compression parameters.

    Args:
        prompts: one prompt or a list of prompts; each prompt is raw text or a
            list of token ids. A non-empty list made up only of ints is treated
            as a single pre-tokenized prompt.
        sampling_params: one ``SamplingParams`` shared by all prompts, or one
            per prompt.
        per_sequence_compression_params: ``None`` (a fresh
            ``SequenceCompressionParams(1.0)`` per prompt), one instance shared
            by all prompts, or one per prompt.
        tokenizer_kwargs: forwarded to the tokenizer for text prompts.

    Returns:
        One ``Sequence`` per prompt. When a prompt is no longer than its
        protected head plus tail, that sequence gets a private copy of its
        compression params with ``compression_ratio`` forced to ``1.0``; the
        caller's objects are left unchanged.

    Raises:
        AssertionError: if a params list does not match the number of prompts.
    """
    import copy

    tokenizer_kwargs = {} if tokenizer_kwargs is None else tokenizer_kwargs

    # `PromptLike` itself may be a List[int]; without this check a single
    # pre-tokenized prompt was split into one "prompt" per token id and failed
    # in tokenize_prompt.
    is_single_token_list = (
        isinstance(prompts, list)
        and len(prompts) > 0
        and all(isinstance(tok, int) for tok in prompts)
    )
    if not isinstance(prompts, list) or is_single_token_list:
        prompts = [prompts]

    if isinstance(sampling_params, SamplingParams):
        sampling_params_list: List[SamplingParams] = [sampling_params] * len(prompts)
    else:
        sampling_params_list = sampling_params
    assert len(sampling_params_list) == len(prompts), (
        "sampling_params list must match prompts length"
    )

    if per_sequence_compression_params is None:
        compression_params_list: List[SequenceCompressionParams] = [
            SequenceCompressionParams(1.0) for _ in prompts
        ]
    elif isinstance(per_sequence_compression_params, SequenceCompressionParams):
        # The same instance is referenced by every slot; see the copy below.
        compression_params_list = [per_sequence_compression_params] * len(prompts)
    else:
        # list-like
        assert len(per_sequence_compression_params) == len(prompts), (
            "per_sequence_compression_params list must match prompts length"
        )
        compression_params_list = list(per_sequence_compression_params)

    seqs: List[Sequence] = []
    for prompt, sparams, cparams in zip(
        prompts, sampling_params_list, compression_params_list
    ):
        token_ids = self.tokenize_prompt(prompt, **tokenizer_kwargs)
        protected = cparams.protected_first_tokens + cparams.protected_last_tokens
        if protected >= len(token_ids):
            # Nothing is left to compress. Override on a private copy: writing
            # to `cparams` directly would also disable compression for every
            # other sequence sharing the object and for the caller's later batches.
            cparams = copy.copy(cparams)
            cparams.compression_ratio = 1.0
        seqs.append(
            Sequence(
                prompt_token_ids=token_ids,
                sampling_params=sparams,
                compression_params=cparams,
            )
        )
    return seqs
def generate(
    self,
    prompts: List[PromptLike] | PromptLike,
    sampling_params: SamplingParams | List[SamplingParams],
    batch_compression_params: BatchCompressionParams,
    *,
    per_sequence_compression_params: Union[
        List[SequenceCompressionParams], SequenceCompressionParams
    ] = None,
    tokenizer_kwargs: Optional[dict[str, Any]] = None,
    detokenizer_kwargs: Optional[dict[str, Any]] = None,
    return_sequences: bool = False,
) -> List[str] | tuple[List[str], List[Sequence]]:
    """
    Generate completions for a batch of prompts and return them as text.

    Args:
        :param prompts:
            Single prompt or list of prompts, each either a raw text prompt,
            or pre-tokenized input IDs.
        :param sampling_params:
            A single SamplingParams for all prompts in this batch or a list of
            SamplingParams with the same length as ``prompts``.
        :param batch_compression_params:
            Compression settings for this batch.
        :param per_sequence_compression_params:
            Per-sequence compression parameters, including the compression
            ratio to be applied and the size of the protected regions of the
            sequence (how many start tokens and end tokens to keep uncompressed).
            If a SequenceCompressionParams instance, the same params will be
            applied to all sequences in this batch; if a list is provided,
            each SequenceCompressionParams will be attached to the corresponding
            prompt in the batch.
        :param tokenizer_kwargs:
            Extra kwargs forwarded to ``tokenizer(...)`` when tokenizing
            string prompts.
        :param detokenizer_kwargs:
            Passed through to `tokenizer.batch_decode`.
        :param return_sequences:
            Whether to also return the Sequence objects.

    Returns:
        :return List[str]:
            One decoded completion per input prompt. When
            ``return_sequences`` is True, a tuple ``(completions, sequences)``
            is returned instead, where ``sequences`` are the Sequence objects
            that were passed to the model runner.
    """
    if tokenizer_kwargs is None:
        tokenizer_kwargs = {}
    if detokenizer_kwargs is None:
        detokenizer_kwargs = {}

    sequences = self._build_sequences(
        prompts,
        sampling_params=sampling_params,
        per_sequence_compression_params=per_sequence_compression_params,
        tokenizer_kwargs=tokenizer_kwargs,
    )
    self.master_model_runner.generate(sequences, batch_compression_params)
    completions = self.detokenize_prompt(sequences, **detokenizer_kwargs)
    return (completions, sequences) if return_sequences else completions
def generate_chat(
    self,
    messages_batch: List[List[dict]],
    sampling_params: SamplingParams | List[SamplingParams],
    batch_compression_params: BatchCompressionParams,
    per_sequence_compression_params: Union[
        SequenceCompressionParams, List[SequenceCompressionParams]
    ],
    *,
    tokenizer_kwargs: Optional[dict[str, Any]] = None,
    detokenizer_kwargs: Optional[dict[str, Any]] = None,
    return_sequences: bool = False,
) -> List[str] | tuple[List[str], List[Sequence]]:
    """
    Chat-style front end: render each conversation with the tokenizer's
    ``apply_chat_template`` and hand the resulting token IDs to ``generate``.

    Args:
        :param messages_batch:
            List of conversations, where each conversation is a list of
            message dicts like:
                {"role": "system" | "user" | "assistant", "content": str}
        :param sampling_params:
            A single SamplingParams for all conversations in this batch or a
            list of SamplingParams with the same length as ``messages_batch``.
        :param batch_compression_params:
            Batch-level compression settings (e.g. the compression method).
        :param per_sequence_compression_params:
            Per-sequence compression parameters, including the compression
            ratio to be applied and the size of the protected regions of the
            sequence (how many start tokens and end tokens to keep
            uncompressed). A single instance is applied to every
            conversation; a list is matched to the conversations one-to-one.
        :param tokenizer_kwargs:
            Merged via ``_merge_apply_chat_template_kwargs`` and then passed
            to ``tokenizer.apply_chat_template``. The merged dict is also
            forwarded to ``generate`` as its ``tokenizer_kwargs``.
        :param detokenizer_kwargs:
            Forwarded to ``generate``; ``None`` is treated as an empty dict.
        :param return_sequences:
            Whether to return sequence objects as well.
    Returns:
        :return List[str] or tuple[List[str], List[Sequence]]:
            Whatever ``generate`` returns: one string per conversation, plus
            the Sequence list when ``return_sequences`` is true.
    """
    chat_kwargs = _merge_apply_chat_template_kwargs(self.tokenizer, tokenizer_kwargs)
    if detokenizer_kwargs is None:
        detokenizer_kwargs = {}

    def _encode(conversation):
        # Tokenize one conversation into a plain list of token IDs.
        ids = self.tokenizer.apply_chat_template(
            conversation,
            tokenize=True,
            **chat_kwargs,
        )
        # Array/tensor-like results are converted to Python lists.
        return ids.tolist() if hasattr(ids, "tolist") else ids

    prompts_token_ids: List[List[int]] = [_encode(m) for m in messages_batch]

    return self.generate(
        prompts_token_ids,
        sampling_params=sampling_params,
        batch_compression_params=batch_compression_params,
        per_sequence_compression_params=per_sequence_compression_params,
        tokenizer_kwargs=chat_kwargs,
        detokenizer_kwargs=detokenizer_kwargs,
        return_sequences=return_sequences,
    )
def generate_from_sequences(
    self,
    seqs: List[Sequence],
    batch_compression_params: BatchCompressionParams,
) -> List[Sequence]:
    """
    Run generation on already-built Sequence objects (no tokenization or
    detokenization is performed here).

    Args:
        :param seqs:
            List of Sequence instances.
        :param batch_compression_params:
            Compression settings passed to the master model runner.
    Returns:
        :return List[Sequence]:
            The very same list object that was passed in; the runner is
            expected to have filled in the completions in place.
    """
    runner = self.master_model_runner
    runner.generate(seqs, batch_compression_params)
    return seqs
vllm/compactor-vllm/src/compactor_vllm/core/memory_manager.py
deleted
100644 → 0
View file @
2b7160c6
import
logging
from
typing
import
Iterable
,
List
,
Optional
import
torch
import
torch.distributed
as
dist
from
compactor_vllm.config.engine_config
import
LLMConfig
from
compactor_vllm.kv_cache.page_table
import
KVAllocationStatus
,
PagedKVCache
from
torch
import
nn
logger
=
logging
.
getLogger
(
__name__
)
class
KVCacheManager
:
def __init__(self, rank: int, config: LLMConfig):
    """
    Record the sizing parameters for the paged KV cache.

    Nothing is allocated here: ``num_pages`` and ``paged_cache`` stay
    ``None`` until ``init_cache`` runs, and ``max_batched_tokens`` stays
    ``None`` until ``estimate_max_batched_tokens`` runs.

    :param rank:
        Rank of this process; ``init_cache`` builds the cache on
        ``cuda:{rank}``.
    :param config:
        Engine configuration. Reads ``gpu_memory_utilization``,
        ``kvcache_page_size``, ``tensor_parallel_size``, ``max_num_seqs``,
        ``max_model_len`` and ``hf_config``.
    :raises AssertionError:
        If the distributed world size does not divide the model's number of
        key/value heads.
    """
    super().__init__()
    hf_config = config.hf_config

    # Validate before dividing so a bad configuration fails up front.
    tp_world_size = dist.get_world_size()
    assert hf_config.num_key_value_heads % tp_world_size == 0, (
        "world size needs to divide num_kv_heads"
    )

    self.rank = rank
    self.gpu_frac = config.gpu_memory_utilization
    self.page_size = config.kvcache_page_size
    self.world_size = config.tensor_parallel_size
    self.max_num_batches = config.max_num_seqs
    self.max_model_len = config.max_model_len
    self.num_layers = hf_config.num_hidden_layers
    self.model_dtype = hf_config.torch_dtype

    # Not every HF config carries an explicit ``head_dim``. Leaving it as
    # None makes the page-size arithmetic in ``get_num_pages`` fail with an
    # opaque TypeError, so fall back to the conventional derivation when the
    # required fields are available.
    head_dim = getattr(hf_config, "head_dim", None)
    if head_dim is None:
        hidden_size = getattr(hf_config, "hidden_size", None)
        num_heads = getattr(hf_config, "num_attention_heads", None)
        if hidden_size is not None and num_heads:
            head_dim = hidden_size // num_heads
    self.head_dim = head_dim

    # Ceil-division: logical pages needed to hold one max-length sequence.
    self.max_pages_per_batch = (
        self.max_model_len + self.page_size - 1
    ) // self.page_size
    # KV heads handled by this rank.
    self.num_kv_heads = hf_config.num_key_value_heads // tp_world_size

    self.num_pages = None
    self.paged_cache: Optional[PagedKVCache] = None
    self.max_batched_tokens = None
    # Maps a sequence id to the batch slot it occupies in the paged cache.
    self.seq_id_to_batch = {}
def allocate_sequences(
    self, seq_ids: List[int], max_positions: List[int]
) -> tuple[bool, Optional[torch.Tensor]]:
    """
    Make sure every sequence has a batch slot and enough reserved tokens.

    A sequence that is not yet known gets a fresh batch slot from the paged
    cache; a known sequence reuses its existing slot. In both cases
    ``reserve_tokens`` is called for the requested length.

    :param seq_ids:
        Sequence ids to allocate for.
    :param max_positions:
        Number of tokens to reserve for each sequence, paired positionally
        with ``seq_ids``.
    :return:
        ``(True, batch_mapping)`` on success, where ``batch_mapping`` is an
        int32 CUDA tensor holding the batch slot of each sequence in input
        order; ``(False, None)`` as soon as a slot or a reservation cannot be
        obtained. Slots handed out earlier in a failed call are not released
        here.
    """
    slots = []
    for seq_id, len_to_alloc in zip(seq_ids, max_positions):
        slot = self.seq_id_to_batch.get(seq_id)
        if slot is None:
            batch_id = self.paged_cache.new_batch()
            if batch_id is None:
                logger.warning("Failed to allocate batch!")
                return False, None
            slot = int(batch_id)
            self.seq_id_to_batch[seq_id] = slot
        slots.append(slot)

        alloc_status = self.paged_cache.reserve_tokens(slot, len_to_alloc)
        if alloc_status != KVAllocationStatus.SUCCESS:
            logger.warning("Failed to allocate pages (%s)!", alloc_status)
            return False, None

    batch_mapping = torch.as_tensor(slots, dtype=torch.int32, device="cuda")
    return True, batch_mapping
def free_sequences(self, seq_ids: Iterable[int]):
    """
    Release the batch slots held by the given sequences.

    Each id is removed from ``seq_id_to_batch`` and its slot is returned to
    the paged cache via ``free_batch``. Ids that are not currently mapped are
    ignored, so ``free_batch`` is never called with ``None``.

    :param seq_ids:
        Sequence ids to free.
    """
    for seq_id in seq_ids:
        global_batch_id = self.seq_id_to_batch.pop(seq_id, None)
        if global_batch_id is None:
            # Unknown or already-freed sequence: nothing to hand back.
            continue
        self.paged_cache.free_batch(global_batch_id)
def init_cache(self, model: nn.Module):
    """
    Size and build the paged KV cache, then attach it to the model.

    The page count comes from ``get_num_pages``; the cache itself is created
    on ``cuda:{rank}`` with the model's dtype, and finally every attention
    layer is wired up through ``_assign_cache_to_layers``.

    :param model:
        The model whose attention layers should receive the cache views.
    """
    logical_pages = self.max_pages_per_batch
    self.num_pages = self.get_num_pages(self.gpu_frac, logical_pages)

    target_device = f"cuda:{self.rank}"
    cache = PagedKVCache(
        num_layers=self.num_layers,
        H_kv=self.num_kv_heads,
        head_dim=self.head_dim,
        page_size=self.page_size,
        num_pages=int(self.num_pages),
        max_num_batches=self.max_num_batches,
        device=target_device,
        dtype=self.model_dtype,
        max_logical_pages_per_head=int(self.max_pages_per_batch),
    )
    self.paged_cache = cache

    self._assign_cache_to_layers(model)
def _assign_cache_to_layers(self, model) -> None:
    """
    Give each decoder layer's attention module its slice of the paged cache.

    For layer ``i`` of ``model.model.layers`` the four objects returned by
    ``paged_cache.layer_slices(i)`` are stored on ``layer.self_attn.attn`` as
    ``k_cache``, ``v_cache``, ``page_table`` and ``bh_seq_lens``, together
    with this manager's ``page_size``.

    :param model:
        Model exposing ``model.layers``, each with ``self_attn.attn``.
    """
    layers = model.model.layers
    for idx, decoder_layer in enumerate(layers):
        attention = decoder_layer.self_attn.attn
        slices = self.paged_cache.layer_slices(idx)
        (
            attention.k_cache,
            attention.v_cache,
            attention.page_table,
            attention.bh_seq_lens,
        ) = slices
        attention.page_size = self.page_size
def get_num_pages(self, frac: float, n_logical_pages_max: int):
    """
    Work out how many KV-cache pages fit into the GPU memory budget.

    The budget is 90% of ``total * frac`` minus the memory already in use on
    the device, minus the gap between the allocator's peak and current
    allocation. The int32 page-table and per-head length bookkeeping for all
    layers is subtracted from that budget before dividing by the size of one
    page across all layers.

    :param frac:
        Fraction of total device memory the engine may use
        (``gpu_memory_utilization``).
    :param n_logical_pages_max:
        Maximum number of logical pages per head in the page table.
    :return:
        Number of pages, never less than 1.
    :raises RuntimeError:
        If the budget is non-positive, or if the page-table bookkeeping alone
        consumes the whole budget.
    """
    free, total = torch.cuda.mem_get_info()
    used = total - free
    stats = torch.cuda.memory_stats()
    # ``.get`` mirrors estimate_max_batched_tokens and tolerates missing keys.
    peak = int(stats.get("allocated_bytes.all.peak", 0))
    current = int(stats.get("allocated_bytes.all.current", 0))

    # ``- peak + current`` holds back the peak-over-current gap seen so far.
    bytes_for_kv_budget = int(total * frac * 0.9) - used - peak + current
    if bytes_for_kv_budget <= 0:
        raise RuntimeError(
            "Insufficient memory for KV cache. "
            f"Try increasing gpu_memory_utilization (currently {frac:.2f})."
        )

    # page_table[L, B, H_kv, N_LOGICAL_PAGES_MAX] + bh_seq_lens[L, B, H_kv]
    int32_sz = torch.empty((), dtype=torch.int32).element_size()  # 4
    entries_per_layer = self.max_num_batches * self.num_kv_heads
    page_table_bytes_per_layer = (
        entries_per_layer * n_logical_pages_max * int32_sz  # page_table
        + entries_per_layer * int32_sz  # bh_seq_lens
    )
    total_page_table_bytes = self.num_layers * page_table_bytes_per_layer

    kv_bytes_net = bytes_for_kv_budget - total_page_table_bytes
    if kv_bytes_net <= 0:
        raise RuntimeError(
            "page-table footprint exceeds KV cache budget. "
            f"reduce max_num_seqs ({self.max_num_batches}) "
            f"or increase gpu_memory_utilization (currently {frac:.2f})."
        )

    dtype_sz = torch.empty((), dtype=self.model_dtype).element_size()
    # Factor 2 accounts for keys and values.
    bytes_per_page_across_layers = self.num_layers * (
        2 * self.page_size * self.head_dim * dtype_sz
    )
    return max(1, kv_bytes_net // bytes_per_page_across_layers)
def estimate_max_batched_tokens(
    self,
    warmup_tokens: int,
    bytes_used_before_warmup: int,
    bytes_peak_after_warmup: int,
) -> int:
    """
    Estimate how many tokens can be processed in one batch without OOM.

    The per-token activation cost is derived from the memory growth observed
    during warmup. The result is the smaller of what the remaining activation
    budget allows and what the KV cache can hold, each rounded down to a
    multiple of the page size. The value is also stored in
    ``self.max_batched_tokens``.

    :param warmup_tokens:
        Number of tokens that were pushed through the model during warmup.
    :param bytes_used_before_warmup:
        Allocated bytes measured before warmup.
    :param bytes_peak_after_warmup:
        Peak allocated bytes measured after warmup.
    :return:
        The estimated token limit.
    """
    assert warmup_tokens > 0, "warmup_tokens must be > 0"
    page = self.page_size

    # Activation bytes per token (ceil-division, at least one byte).
    growth = int(bytes_peak_after_warmup) - int(bytes_used_before_warmup)
    if growth < 0:
        growth = 0
    per_token = max(1, (growth + warmup_tokens - 1) // warmup_tokens)

    free, total = torch.cuda.mem_get_info()
    target = int(total * self.gpu_frac)
    used_now = int(total - free)

    # Reserve headroom equal to the gap between peak and current allocations
    # seen so far.
    stats = torch.cuda.memory_stats()
    peak_cur = int(stats.get("allocated_bytes.all.peak", 0))
    cur_now = int(stats.get("allocated_bytes.all.current", 0))
    cushion = max(0, peak_cur - cur_now)

    remaining = max(0, target - used_now - cushion)
    activation_budget = int(remaining * 0.95)

    by_activations = activation_budget // per_token
    by_cache = (self.num_pages * page) // self.num_kv_heads

    # Round both limits down to a multiple of the page size.
    by_activations -= by_activations % page
    by_cache -= by_cache % page

    self.max_batched_tokens = min(by_cache, by_activations)
    return self.max_batched_tokens
@property
def num_free_batches(self) -> int:
    """Number of entries currently in the paged cache's ``free_batches``."""
    available = self.paged_cache.free_batches
    return len(available)
@property
def num_free_pages(self) -> int:
    """
    Smallest free-page count across the paged cache's ``free_pages`` lists,
    i.e. the number of pages that every list can still provide.
    """
    return min(map(len, self.paged_cache.free_pages))
def reclaim_pages(
    self,
    seq_ids_to_reclaim: Iterable[int],
    future_reserved_buffer: List[int] | torch.Tensor,
) -> int:
    """
    Ask the paged cache to reclaim pages for each of the given sequences.

    :param seq_ids_to_reclaim:
        Sequence ids whose batch slots should be trimmed; each must already
        be present in ``seq_id_to_batch``.
    :param future_reserved_buffer:
        Indexable collection giving, per sequence (same order), the value
        passed as the second argument of ``paged_cache.reclaim_pages``.
    :return:
        Sum of the values returned by ``paged_cache.reclaim_pages`` (an
        approximate number of bytes freed).
    """
    return sum(
        self.paged_cache.reclaim_pages(
            self.seq_id_to_batch[seq_id], future_reserved_buffer[position]
        )
        for position, seq_id in enumerate(seq_ids_to_reclaim)
    )
vllm/compactor-vllm/src/compactor_vllm/core/model_runner.py
deleted
100644 → 0
View file @
2b7160c6
import
atexit
import
logging
import
inspect
from
typing
import
List
,
Optional
import
torch
import
torch.distributed
as
dist
from
compactor_vllm.attention.sparse_decode_kernel
import
num_splits_heuristic
from
compactor_vllm.compression.compression_config
import
BatchCompressionParams
from
compactor_vllm.config.constants
import
RESERVED_BATCH
from
compactor_vllm.config.engine_config
import
AttentionBackend
,
LLMConfig
from
compactor_vllm.core.memory_manager
import
KVCacheManager
from
compactor_vllm.core.scheduler
import
Scheduler
from
compactor_vllm.layers.sampler
import
Sampler
from
compactor_vllm.models
import
MODEL_REGISTRY
from
compactor_vllm.utils.arguments
import
(
DecodeBatchArguments
,
DecodeBatchOutput
,
PackedTensorArguments
,
PrefillBatchArguments
,
)
from
compactor_vllm.utils.context
import
CompressionContext
,
reset_context
,
set_context
from
compactor_vllm.utils.sequence
import
Sequence
from
torch.multiprocessing
import
Event
from
tqdm
import
tqdm
logger
=
logging
.
getLogger
(
__name__
)
class
ModelRunner
:
"""Per-rank execution loop. Manages model, sampler, KV cache, and warmup"""
def __init__(
    self,
    config: LLMConfig,
    rank: int,
    batch_ready: Optional[Event] = None,
    peer_events: Optional[List[Event]] = None,
):
    """
    Bring up one rank: process group, model, sampler, KV cache and graphs.

    Steps, in order: join the NCCL process group, load the model with CUDA
    and the model dtype as temporary torch defaults, run a Flash-attention
    warmup to measure activation memory, build the paged KV cache, estimate
    the batched-token limit, optionally warm up the Compactor Triton backend
    and capture CUDA graphs, and register ``exit`` with ``atexit``.

    :param config:
        Engine configuration; ``eos_token_ids`` must be non-empty.
    :param rank:
        Rank of this process, also used as the CUDA device index.
    :param batch_ready:
        Event polled by ``loop`` and ``_process_batches_peer`` on peer ranks.
    :param peer_events:
        Events the master sets in ``generate`` and clears in
        ``maybe_release_peers``; ``None`` means no peers.
    :raises AssertionError:
        If ``config.eos_token_ids`` is missing or empty.
    """
    self.rank = rank
    self.config = config
    _dev = torch.device(f"cuda:{rank}")
    assert config.eos_token_ids is not None and len(config.eos_token_ids) > 0, (
        "LLMConfig.eos_token_ids must be set (filled in LLMEngine from tokenizer)."
    )
    self._stop_token_ids = torch.tensor(
        config.eos_token_ids, dtype=torch.int64, device=_dev
    )

    hf_config = config.hf_config
    self.enforce_eager = config.enforce_eager
    self.world_size = config.tensor_parallel_size
    self.leverage_sketch_size = config.leverage_sketch_size
    self.show_progress_bar = config.show_progress_bar
    self.max_num_batches = config.max_num_seqs
    self.max_model_len = config.max_model_len
    self.num_layers = hf_config.num_hidden_layers
    self.model_dtype = hf_config.torch_dtype
    self.head_dim = getattr(hf_config, "head_dim", None)

    # Only pass ``device_id`` when this torch version accepts it.
    init_kwargs = {}
    if "device_id" in inspect.signature(dist.init_process_group).parameters:
        init_kwargs["device_id"] = torch.device(f"cuda:{rank}")
    dist.init_process_group(
        "nccl",
        f"tcp://localhost:{config.nccl_port}",
        world_size=self.world_size,
        rank=rank,
        **init_kwargs,
    )
    torch.cuda.set_device(rank)

    # Build everything with CUDA / model dtype as torch defaults. The
    # defaults are process-global, so put them back even if loading or
    # warmup fails.
    default_dtype = torch.get_default_dtype()
    torch.set_default_dtype(hf_config.torch_dtype)
    torch.set_default_device("cuda")
    try:
        model_type = hf_config.model_type
        self.model = MODEL_REGISTRY[model_type](hf_config)
        self.model.load_model(
            config.path, use_tqdm=self.is_master and self.show_progress_bar
        )
        self.sampler = Sampler()

        # Bracket the warmup with memory readings so the KV manager can
        # derive an activation cost per token.
        pre_warmup_mem = torch.cuda.memory_stats().get(
            "allocated_bytes.all.current", 0
        )
        self.warmup(
            num_warmup_tokens=self.max_model_len,
            attention_backend=AttentionBackend.FLASH_ATTENTION,
        )
        post_warmup_peak = torch.cuda.memory_stats().get(
            "allocated_bytes.all.peak", 0
        )

        self.kv_manager = KVCacheManager(rank, config)
        self.kv_manager.init_cache(self.model)
        self.store_stream: Optional[torch.cuda.Stream] = torch.cuda.Stream()
    finally:
        torch.set_default_device("cpu")
        torch.set_default_dtype(default_dtype)

    self.batch_ready = batch_ready
    self.peer_events = peer_events if peer_events is not None else []
    # captured_graphs[batch_size][max_seqlen_k] -> graph variables
    self.captured_graphs = {}
    self.min_captured_len = {}

    self.max_batched_tokens = self.kv_manager.estimate_max_batched_tokens(
        self.max_model_len, pre_warmup_mem, post_warmup_peak
    )
    if self.is_master:
        logger.info(f"Estimated max batched tokens of {self.max_batched_tokens}")

    if self.config.attention_backend == AttentionBackend.COMPACTOR_TRITON:
        self.warmup(
            num_warmup_tokens=self.max_model_len,
            attention_backend=AttentionBackend.COMPACTOR_TRITON,
        )

    if not self.enforce_eager:
        # Powers of two up to max_num_batches ...
        graph_batch_sizes = [
            1 << i for i in range(self.max_num_batches.bit_length())
        ]
        # ... plus max_num_batches itself when it is not a power of two.
        # Without it, decode batches larger than the biggest power of two
        # have no captured graph and get_cuda_graph cannot serve them.
        if graph_batch_sizes and graph_batch_sizes[-1] < self.max_num_batches:
            graph_batch_sizes.append(self.max_num_batches)

        if self.is_master and self.show_progress_bar:
            size_iter = tqdm(graph_batch_sizes, desc="Capturing CUDA Graphs")
        else:
            size_iter = graph_batch_sizes
        for graph_bs in size_iter:
            for seq_len in [1024, 4096, 8192, 16384]:
                self.capture_cudagraph(graph_bs, seq_len)

    self.packed_args = PackedTensorArguments(
        rank=self.rank,
        max_batched_tokens=self.max_batched_tokens,
        config=self.config,
    )
    atexit.register(self.exit)
@torch.inference_mode()
def warmup(self, num_warmup_tokens: int, attention_backend: AttentionBackend):
    """
    Run two dummy prefill passes with the given attention backend.

    A single sequence of ``num_warmup_tokens`` copies of ``config.eos`` is
    pushed through the model and ``compute_logits``. Peak-memory statistics
    are reset before each pass and all ranks meet at a barrier after it. For
    the Compactor Triton backend a temporary KV allocation under sequence id
    ``-1`` is made, its ``bh_seq_lens`` entries are zeroed after every pass,
    and it is freed at the end.

    :param num_warmup_tokens:
        Length of the dummy sequence.
    :param attention_backend:
        Backend to exercise.
    """
    uses_paged_cache = attention_backend == AttentionBackend.COMPACTOR_TRITON

    if self.rank == 0:
        backend_name = "Compactor Triton" if uses_paged_cache else "Flash"
        logger.info(f"Warming up with {backend_name} Attention Backend")

    device = torch.device(f"cuda:{self.rank}")
    tensor_opts = {"device": device}
    input_ids = torch.tensor(
        [self.config.eos] * num_warmup_tokens, dtype=torch.int64, **tensor_opts
    )
    positions = torch.arange(num_warmup_tokens, dtype=torch.int64, **tensor_opts)
    # One sequence spanning the whole dummy input.
    cu_seqlens_q = torch.tensor(
        [0, num_warmup_tokens], dtype=torch.int32, **tensor_opts
    )
    cu_seqlens_k = torch.tensor(
        [0, num_warmup_tokens], dtype=torch.int32, **tensor_opts
    )

    batch_mapping = None
    if uses_paged_cache:
        success, batch_mapping = self.kv_manager.allocate_sequences(
            [-1], [num_warmup_tokens]
        )
        assert success

    set_context(
        is_prefill=True,
        do_compression=False,
        cu_seqlens_q=cu_seqlens_q,
        cu_seqlens_k=cu_seqlens_k,
        max_seqlen_q=num_warmup_tokens,
        max_seqlen_k=num_warmup_tokens,
        batch_mapping=batch_mapping,
        attention_backend=attention_backend,
    )

    for _ in range(2):
        torch.cuda.reset_peak_memory_stats()
        self.model.compute_logits(self.model(input_ids, positions))
        dist.barrier()
        if uses_paged_cache:
            # Zero the recorded lengths for the warmup slot after each pass.
            self.kv_manager.paged_cache.bh_seq_lens.index_fill_(
                1, batch_mapping.to(torch.long), 0
            )

    reset_context()
    if uses_paged_cache:
        self.kv_manager.free_sequences([-1])
def exit(self):
    """
    Tear down this runner once: drop captured CUDA graphs and destroy the
    process group. Repeated calls are no-ops thanks to the ``_exited`` flag.
    """
    already_done = getattr(self, "_exited", False)
    if already_done:
        return
    self._exited = True

    try:
        # ``captured_graphs`` may not exist if construction failed early.
        graphs = getattr(self, "captured_graphs", None) if hasattr(
            self, "captured_graphs"
        ) else None
        if hasattr(self, "captured_graphs"):
            graphs.clear()
    finally:
        # Always attempt to leave the process group, even if clearing failed.
        if dist.is_initialized():
            dist.destroy_process_group()
def loop(self):
    """
    Run forever: wait on ``batch_ready`` with a one-second timeout and, each
    time the wait succeeds, process batches via ``_process_batches_peer``.
    """
    while True:
        signalled = self.batch_ready.wait(1.0)
        if not signalled:
            # Timed out; poll again.
            continue
        self._process_batches_peer()
@torch.inference_mode()
def run_prefill(
    self, prefill_args: PrefillBatchArguments, batch_mapping: torch.Tensor
):
    """
    Execute one prefill forward pass and return the logits.

    Builds a ``CompressionContext`` from the batch arguments, installs the
    global attention context, runs the model followed by ``compute_logits``
    and resets the context again.

    :param prefill_args:
        Packed prefill batch; ``B`` and ``N`` must both be positive.
    :param batch_mapping:
        Batch slots of the sequences in this batch; used to index dimension 1
        of ``bh_seq_lens``.
    :return:
        Whatever ``model.compute_logits`` returns.
    """
    assert prefill_args.B > 0 and prefill_args.N > 0

    # Longest recorded length among the slots taking part in this batch.
    recorded_lens = self.kv_manager.paged_cache.bh_seq_lens
    max_bh_len = recorded_lens.index_select(1, index=batch_mapping).max().item()

    compression_context = CompressionContext(
        compression_method=prefill_args.compression_method,
        compression_chunk_size=prefill_args.compression_chunk_size,
        batch_tokens_to_retain=prefill_args.batch_tokens_to_retain,
        max_tokens_to_retain=prefill_args.max_tokens_to_retain,
        context_lens=prefill_args.context_lens.tolist(),
        PHI=prefill_args.PHI,
        sketch_dimension=self.leverage_sketch_size,
        protected_first_tokens=prefill_args.protected_first,
        protected_last_tokens=prefill_args.protected_last,
        compression_ratio=prefill_args.compression_ratio,
    )
    set_context(
        is_prefill=True,
        do_compression=prefill_args.do_compression,
        cu_seqlens_q=prefill_args.cu_seqlens_q,
        cu_seqlens_k=prefill_args.cu_seqlens_k,
        max_seqlen_q=prefill_args.max_seqlen_q,
        max_seqlen_k=prefill_args.max_seqlen_k,
        batch_mapping=batch_mapping,
        max_bh_len=max_bh_len,
        compression_context=compression_context,
        STORE_STREAM=self.store_stream,
        attention_backend=self.config.attention_backend,
    )

    hidden_states = self.model(prefill_args.input_ids, prefill_args.positions)
    logits = self.model.compute_logits(hidden_states)

    reset_context()
    return logits
def maybe_broadcast(self, tensor: torch.Tensor):
    """
    Broadcast ``tensor`` from rank 0 when running with more than one rank.

    :param tensor:
        Tensor to send (on rank 0) or to receive into (on other ranks).
    :return:
        The result of ``dist.broadcast`` when ``world_size > 1``, otherwise
        ``None``.
    """
    if self.world_size <= 1:
        return None
    return dist.broadcast(tensor, src=0)
def maybe_release_peers(self, do_release=False):
    """
    Synchronise all ranks at a barrier, optionally releasing the peers first.

    Does nothing when running on a single rank. Otherwise every rank enters
    ``dist.barrier()``; the master additionally clears all ``peer_events``
    beforehand when ``do_release`` is true.

    :param do_release:
        Only honoured on the master: clear the peer events before the barrier.
    """
    if self.world_size <= 1:
        return

    if self.is_master and do_release:
        for peer_event in self.peer_events:
            peer_event.clear()

    # Master and peers alike take part in the barrier.
    dist.barrier()
@torch.inference_mode()
def generate(
    self,
    all_sequences: List[Sequence],
    batch_compression_params: Optional[BatchCompressionParams] = None,
):
    """
    Master-side entry point: signal the peers and process all sequences.

    :param all_sequences:
        Sequences to generate completions for.
    :param batch_compression_params:
        Batch-level compression settings; a default-constructed
        ``BatchCompressionParams`` is used when omitted.
    :raises AssertionError:
        If called on a rank other than the master.
    """
    assert self.is_master, "generate can only be called on the master process"

    # Signal every peer before the master starts processing.
    for peer_event in self.peer_events:
        peer_event.set()

    params = batch_compression_params
    if params is None:
        params = BatchCompressionParams()
    self._process_batches_master(all_sequences, params)
@property
def is_master(self):
    """True on the master process, i.e. when this runner's rank is 0."""
    rank = self.rank
    return rank == 0
@torch.inference_mode()
def _process_batches_master(
    self,
    all_sequences: List[Sequence],
    batch_compression_params: BatchCompressionParams,
):
    """
    Master-side scheduling loop: alternate prefill batches and decode runs
    until the scheduler reports that every sequence is finished.

    Each iteration prefills one batch, samples its first tokens, merges it
    into the running decode batch and, when the scheduler cannot prefill
    another batch, runs the decode loop and records finished sequences.
    With more than one rank, the decode decision is broadcast so that
    ``_process_batches_peer`` can follow the same control flow.

    :param all_sequences:
        Sequences to complete.
    :param batch_compression_params:
        Batch-level compression settings forwarded to the prefill argument
        builder.
    """
    assert self.is_master
    compression_details = f"Applying Compression Method: {batch_compression_params.compression_method}"
    # Only log the method when at least one sequence is actually compressed.
    if any(
        seq.compression_params.compression_ratio < 1.0 for seq in all_sequences
    ):
        logger.info(compression_details)
    scheduler = Scheduler(
        all_sequences=all_sequences,
        kv_manager=self.kv_manager,
        use_tqdm=self.show_progress_bar,
    )
    decode_batch = DecodeBatchArguments()
    # [0] = run_decode flag, [1] = occupancy; broadcast to the peers below.
    decode_flags = torch.empty(2, dtype=torch.int32, device="cuda")
    while not scheduler.is_finished():
        sequences = scheduler.get_prefill_batch()
        seq_ids_cpu = [seq.seq_id for seq in sequences]
        scheduler.add_running_sequence_ids(seq_ids_cpu, update_status=True)
        # Pinned host tensor so the host-to-device copy can be asynchronous.
        temps = torch.tensor(
            [s.sampling_params.temperature for s in sequences],
            dtype=torch.float32,
            pin_memory=True,
        ).cuda(non_blocking=True)
        prefill_arguments = self.packed_args.build_prefill_args(
            sequences, batch_compression_params=batch_compression_params
        )
        # Upper bound on each sequence's final length: prompt + new tokens.
        max_ctx_lens = (
            prefill_arguments.max_new_tokens + prefill_arguments.context_lens
        )
        success, batch_mapping = self.kv_manager.allocate_sequences(
            seq_ids_cpu, max_ctx_lens.tolist()
        )
        assert success, "failed to allocate pages for sequences"
        logits = self.run_prefill(prefill_arguments, batch_mapping)
        # Must match prefill `positions` dtype (int64). `context_lens` is int32
        # from the packed buffer; using int32 here breaks RoPE indexing
        # (`cos_sin_cache[positions]`) on CUDA for decode vs prefill.
        positions = prefill_arguments.context_lens.to(dtype=torch.int64)
        token_ids = self.sampler(logits, temps)
        # Prefill KV writes + bh_seq_lens updates run on STORE_STREAM; reclaim
        # reads bh_seq_lens on the default stream and must not race.
        if self.store_stream is not None:
            torch.cuda.default_stream().wait_stream(self.store_stream)
        # TODO: synchronize page counts across dist
        # Until then, pages are only reclaimed in single-rank runs.
        if self.world_size == 1:
            self.kv_manager.reclaim_pages(
                seq_ids_cpu, prefill_arguments.max_new_tokens
            )
            # NOTE: the reclaimed byte count is returned but currently not
            # logged.
        if scheduler.any_pending_sequences():
            # More prompts are waiting: target 66% of the combined batch
            # (presumably the occupancy at which decoding pauses so another
            # prefill can be merged in — confirm against
            # DecodeBatchArguments.update).
            num_pending_batches = (
                0
                if decode_batch.token_ids is None
                else decode_batch.token_ids.shape[0]
            )
            occupancy = int((num_pending_batches + len(seq_ids_cpu)) * 0.66)
        else:
            occupancy = -1
        # Decode only when no further prefill batch can be scheduled now.
        run_decode = not scheduler.can_prefill_another_batch()
        decode_batch = decode_batch.update(
            batch_mapping,
            token_ids,
            positions,
            max_ctx_lens,
            prefill_arguments.seq_ids,
            temps,
            occupancy,
        )
        if self.world_size > 1:
            decode_flags[0] = int(run_decode)
            decode_flags[1] = occupancy
            self.maybe_broadcast(decode_flags)
        if not run_decode:
            continue
        if self.store_stream is not None:
            torch.cuda.default_stream().wait_stream(self.store_stream)
        decode_output, decode_batch = self.run_decode_loop(decode_batch)
        # Sequences no longer present in the decode batch have finished.
        finished_sequence_ids = scheduler.get_finished_sequence_ids_from_unfinished(
            decode_batch.seq_ids.tolist()
        )
        scheduler.record_finished_sequence_ids(
            finished_sequence_ids, update_status=True
        )
        self.kv_manager.free_sequences(finished_sequence_ids)
        # Clears the peer events once everything is finished, then barriers.
        self.maybe_release_peers(scheduler.is_finished())
        scheduler.update_sequences(
            decode_output.output_tokens.tolist(),
            decode_output.output_seq_ids.tolist(),
        )
    scheduler.close()
@torch.inference_mode()
def _process_batches_peer(self):
    """
    Peer-side counterpart of ``_process_batches_master``.

    Runs while ``batch_ready`` is set. Each iteration obtains the prefill
    arguments from ``packed_args`` (called without sequences), runs the same
    prefill, receives the master's decode decision via broadcast and, if
    requested, takes part in the decode loop. No sampling happens here and
    no outputs are collected.
    """
    assert not self.is_master
    # Scheduler with no sequences of its own; used here for id bookkeeping.
    scheduler = Scheduler([], kv_manager=self.kv_manager)
    decode_batch = DecodeBatchArguments()
    # Receive buffer for the master's [run_decode, occupancy] broadcast.
    decode_flags = torch.empty(2, dtype=torch.int32, device="cuda")
    while self.batch_ready.is_set():
        prefill_arguments = self.packed_args.build_prefill_args()
        B = prefill_arguments.B
        max_ctx_lens = (
            prefill_arguments.max_new_tokens + prefill_arguments.context_lens
        )
        seq_ids_cpu = prefill_arguments.seq_ids.tolist()
        scheduler.add_running_sequence_ids(seq_ids_cpu)
        success, batch_mapping = self.kv_manager.allocate_sequences(
            seq_ids_cpu, max_ctx_lens.tolist()
        )
        assert success, "failed to allocate pages for sequences"
        # Logits are discarded on peers; only the master samples.
        self.run_prefill(prefill_arguments, batch_mapping)
        # int64 to match the dtype used for prefill positions.
        positions = prefill_arguments.context_lens.to(dtype=torch.int64)
        self.maybe_broadcast(decode_flags)
        run_decode = bool(decode_flags[0].item())
        occupancy = int(decode_flags[1].item())
        # Uninitialised placeholder; the decode loop's broadcast fills in the
        # actual token ids from the master.
        token_ids = torch.empty(B, dtype=torch.int64, device="cuda")
        decode_batch = decode_batch.update(
            batch_mapping,
            token_ids,
            positions,
            max_ctx_lens,
            prefill_arguments.seq_ids,
            None,  # temps not used in peer process
            occupancy,
        )
        if not run_decode:
            continue
        if self.store_stream is not None:
            torch.cuda.default_stream().wait_stream(self.store_stream)
        _, decode_batch = self.run_decode_loop(decode_batch)
        finished_sequence_ids = scheduler.get_finished_sequence_ids_from_unfinished(
            decode_batch.seq_ids.tolist()
        )
        scheduler.record_finished_sequence_ids(finished_sequence_ids)
        self.kv_manager.free_sequences(finished_sequence_ids)
        # Matches the barrier the master enters in maybe_release_peers.
        self.maybe_release_peers()
    scheduler.close()
@torch.inference_mode()
def run_decode_loop(
    self,
    decode_batch: DecodeBatchArguments,
) -> tuple[DecodeBatchOutput, DecodeBatchArguments]:
    """
    Decode one token per step for every running sequence until the batch is
    empty or has shrunk to ``decode_batch.desired_batch_occupancy``.

    Runs on all ranks. The master samples the tokens and collects them (and
    the matching sequence ids) for the output; peers receive the tokens via
    broadcast at the start of each step.

    :param decode_batch:
        Current decode state; mutated in place as sequences drop out.
    :return:
        ``(output, decode_batch)``. On the master ``output`` holds the
        concatenated tokens and sequence ids of all steps; on peers both
        fields are ``None``.
    """
    if self.is_master:
        # Entries before ``num_stashed_batches`` were already reported by a
        # previous call; start the output with the newly added ones only.
        num_stashed_batches = decode_batch.num_stashed_batches
        tok_buffer = [
            decode_batch.token_ids[num_stashed_batches:].to("cpu", non_blocking=True)
        ]
        seq_buffer = [
            decode_batch.seq_ids[num_stashed_batches:].to("cpu", non_blocking=True)
        ]
    while True:
        # Share the master's latest tokens with the peers (no-op on 1 rank).
        self.maybe_broadcast(decode_batch.token_ids)
        not_stopped = ~torch.isin(decode_batch.token_ids, self._stop_token_ids)
        # A sequence keeps running while it is below its length cap and its
        # last token is not a stop token.
        running_batches = (decode_batch.positions < decode_batch.max_ctx_lens) & (
            not_stopped
        )
        # Drop finished sequences from every per-sequence tensor.
        decode_batch.token_ids = torch.masked_select(
            decode_batch.token_ids, running_batches
        )
        decode_batch.positions = torch.masked_select(
            decode_batch.positions, running_batches
        )
        decode_batch.batch_mapping = torch.masked_select(
            decode_batch.batch_mapping, running_batches
        )
        decode_batch.max_ctx_lens = torch.masked_select(
            decode_batch.max_ctx_lens, running_batches
        )
        decode_batch.seq_ids = torch.masked_select(
            decode_batch.seq_ids, running_batches
        )
        if self.is_master:
            # Temperatures only exist on the master.
            decode_batch.temps = torch.masked_select(
                decode_batch.temps, running_batches
            )
        num_remaining = decode_batch.token_ids.numel()
        if (
            num_remaining == 0
            or num_remaining <= decode_batch.desired_batch_occupancy
        ):
            # Whatever is left is carried over to the next call.
            decode_batch.num_stashed_batches = num_remaining
            break
        if self.enforce_eager:
            set_context(
                is_prefill=False,
                do_compression=False,
                batch_mapping=decode_batch.batch_mapping,
            )
            logits = self.model.compute_logits(
                self.model(decode_batch.token_ids, decode_batch.positions)
            )
        else:
            logits = self.run_graph_decode(
                decode_batch.token_ids,
                decode_batch.positions,
                decode_batch.batch_mapping,
            )
        if self.is_master:
            decode_batch.token_ids = self.sampler(logits, decode_batch.temps)
            tok_buffer.append(decode_batch.token_ids.to("cpu", non_blocking=True))
            seq_buffer.append(decode_batch.seq_ids.to("cpu", non_blocking=True))
        decode_batch.positions += 1
    if self.is_master:
        # non_blocking D2H copies must finish before cat/tolist read CPU data.
        torch.cuda.synchronize()
        output = DecodeBatchOutput(
            output_tokens=torch.cat(tok_buffer),
            output_seq_ids=torch.cat(seq_buffer),
        )
    else:
        output = DecodeBatchOutput(None, None)
    return output, decode_batch
@torch.inference_mode()
def run_graph_decode(
    self,
    input_ids: torch.Tensor,
    positions: torch.Tensor,
    batch_mapping: torch.Tensor,
):
    """
    Run one decode step by replaying a previously captured CUDA graph.

    The inputs are copied into the first ``bs`` entries of the graph's
    static buffers; the whole ``batch_mapping`` buffer is first filled with
    ``RESERVED_BATCH`` so that unused entries point at the reserved slot.

    :param input_ids:
        Token ids of the running sequences; the first dimension is the batch
        size.
    :param positions:
        Position of each token; its maximum selects the captured graph.
    :param batch_mapping:
        Batch slot of each sequence.
    :return:
        The first ``bs`` rows of the graph's logits buffer, or ``None`` when
        the captured logits are ``None``.
    """
    set_context(
        is_prefill=False,
        do_compression=False,
        batch_mapping=batch_mapping,
    )

    n = input_ids.shape[0]
    longest = int(positions.max())
    entry = self.get_cuda_graph(n, longest)

    # Stage the live inputs in the graph's static tensors.
    entry["input_ids"][:n] = input_ids
    entry["positions"][:n] = positions
    mapping_buffer = entry["batch_mapping"]
    mapping_buffer.fill_(RESERVED_BATCH)
    entry["batch_mapping"][:n] = batch_mapping

    entry["graph"].replay()

    if entry["logits"] is None:
        return entry["logits"]
    return entry["logits"][:n]
@torch.inference_mode()
def capture_cudagraph(self, batch_size: int, max_seqlen_k: int):
    """
    Capture a decode CUDA graph for one (batch size, key length) pair.

    Temporary KV allocations for sequence ids ``0 .. batch_size - 1`` (256
    tokens each) back the capture and are freed afterwards. The model is run
    once eagerly, then once more under ``torch.cuda.graph``. The graph and
    its static tensors are stored in
    ``captured_graphs[batch_size][max_seqlen_k]`` and
    ``min_captured_len[batch_size]`` is updated.

    :param batch_size:
        Number of sequences the graph's static buffers hold.
    :param max_seqlen_k:
        Key length handed to ``num_splits_heuristic`` and used as the lookup
        key for this graph.
    """
    dist.barrier()
    device = torch.device("cuda")
    logger.debug(
        f"Capturing CUDA graph for batch size {batch_size} ({max_seqlen_k} tokens)"
    )

    # Static buffers the captured graph reads from.
    _g_input_ids = torch.zeros(batch_size, dtype=torch.int32, device=device)
    _g_positions = torch.zeros(batch_size, dtype=torch.int64, device=device)
    _g_logits = None

    sm_count = torch.cuda.get_device_properties(device).multi_processor_count
    key_split = num_splits_heuristic(
        batch_size * self.kv_manager.num_kv_heads,
        max_seq_len=max_seqlen_k,
        num_sms=sm_count,
        max_splits=12,
    )

    capture_seq_ids = list(range(batch_size))
    success, _g_batch_mapping = self.kv_manager.allocate_sequences(
        capture_seq_ids, [256] * batch_size
    )
    assert success

    set_context(
        is_prefill=False,
        do_compression=False,
        batch_mapping=_g_batch_mapping,
        key_split=key_split,
    )

    # Eager run before capturing.
    self.model.compute_logits(self.model(_g_input_ids, _g_positions))
    dist.barrier()

    decode_graph = torch.cuda.CUDAGraph()
    with torch.cuda.graph(decode_graph):
        _g_logits = self.model.compute_logits(
            self.model(_g_input_ids, _g_positions)
        )

    graph_vars = dict(
        graph=decode_graph,
        input_ids=_g_input_ids,
        positions=_g_positions,
        batch_mapping=_g_batch_mapping,
        logits=_g_logits,
        key_split=key_split,
    )

    # First capture for this batch size: create its bookkeeping entries.
    if batch_size not in self.captured_graphs:
        self.captured_graphs[batch_size] = {}
        self.min_captured_len[batch_size] = float("inf")
    per_size = self.captured_graphs[batch_size]
    per_size[max_seqlen_k] = graph_vars
    shortest_so_far = self.min_captured_len[batch_size]
    self.min_captured_len[batch_size] = min(max_seqlen_k, shortest_so_far)

    self.kv_manager.free_sequences(list(range(batch_size)))
def get_cuda_graph(self, batch_size: int, max_seqlen_k: int):
    """
    Pick the captured CUDA graph to replay for a decode step.

    The smallest captured batch size that is at least ``batch_size`` is
    chosen, independent of the order in which graphs were captured. Within
    that batch size, the graph with the largest captured sequence length not
    exceeding ``max_seqlen_k`` is used; if every captured length is larger,
    the shortest one is used.

    :param batch_size:
        Number of sequences in the decode step.
    :param max_seqlen_k:
        Longest sequence position in the decode step.
    :return:
        The dict of graph variables stored by ``capture_cudagraph``.
    :raises RuntimeError:
        If no graph was captured for a batch size of at least ``batch_size``.
    """
    fitting_size = min(
        (size for size in self.captured_graphs if size >= batch_size),
        default=None,
    )
    if fitting_size is None:
        raise RuntimeError(
            f"No CUDA graph captured for batch size >= {batch_size} "
            f"(captured sizes: {sorted(self.captured_graphs)})"
        )
    batch_size_graphs = self.captured_graphs[fitting_size]

    # Largest captured length that does not exceed max_seqlen_k, falling
    # back to the shortest captured length.
    best = max(
        (seq_len for seq_len in batch_size_graphs if seq_len <= max_seqlen_k),
        default=self.min_captured_len[fitting_size],
    )
    return batch_size_graphs[best]
Prev
1
2
3
4
5
6
…
12
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment