Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
d29c39ca
"vscode:/vscode.git/clone" did not exist on "b89b4d9973c1fbcbfe4f3bf9b53d1dde664d5c5f"
Commit
d29c39ca
authored
Apr 30, 2026
by
chenzk
Browse files
vllm kvprune wo:v1.1.0
parent
f81ce56b
Changes
246
Hide whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
3242 additions
and
30 deletions
+3242
-30
vllm/kvprune/compression/common.py
vllm/kvprune/compression/common.py
+1
-0
vllm/kvprune/compression/criticalkv.py
vllm/kvprune/compression/criticalkv.py
+54
-25
vllm/kvprune/core/model_runner.py
vllm/kvprune/core/model_runner.py
+38
-3
vllm/kvprune/layers/attention.py
vllm/kvprune/layers/attention.py
+10
-2
vllm/kvprune/models/llama3.py
vllm/kvprune/models/llama3.py
+12
-0
vllm/kvprune/models/qwen3_moe.py
vllm/kvprune/models/qwen3_moe.py
+12
-0
vllm/kvprune/utils/context.py
vllm/kvprune/utils/context.py
+2
-0
vllm/kvprune_legacy_save/__init__.py
vllm/kvprune_legacy_save/__init__.py
+20
-0
vllm/kvprune_legacy_save/attention/__init__.py
vllm/kvprune_legacy_save/attention/__init__.py
+7
-0
vllm/kvprune_legacy_save/attention/compile_kernels.py
vllm/kvprune_legacy_save/attention/compile_kernels.py
+261
-0
vllm/kvprune_legacy_save/attention/fa_paged_bridge.py
vllm/kvprune_legacy_save/attention/fa_paged_bridge.py
+192
-0
vllm/kvprune_legacy_save/attention/sparse_decode_kernel.py
vllm/kvprune_legacy_save/attention/sparse_decode_kernel.py
+401
-0
vllm/kvprune_legacy_save/attention/sparse_varlen_kernel.py
vllm/kvprune_legacy_save/attention/sparse_varlen_kernel.py
+455
-0
vllm/kvprune_legacy_save/benchmark/__init__.py
vllm/kvprune_legacy_save/benchmark/__init__.py
+47
-0
vllm/kvprune_legacy_save/compactor_porting_status.py
vllm/kvprune_legacy_save/compactor_porting_status.py
+56
-0
vllm/kvprune_legacy_save/compression/__init__.py
vllm/kvprune_legacy_save/compression/__init__.py
+41
-0
vllm/kvprune_legacy_save/compression/common.py
vllm/kvprune_legacy_save/compression/common.py
+243
-0
vllm/kvprune_legacy_save/compression/compactor.py
vllm/kvprune_legacy_save/compression/compactor.py
+739
-0
vllm/kvprune_legacy_save/compression/compactor_origin.py
vllm/kvprune_legacy_save/compression/compactor_origin.py
+606
-0
vllm/kvprune_legacy_save/compression/compression_config.py
vllm/kvprune_legacy_save/compression/compression_config.py
+45
-0
No files found.
vllm/kvprune/compression/common.py
View file @
d29c39ca
...
@@ -144,6 +144,7 @@ def extract_and_store_top_kv(
...
@@ -144,6 +144,7 @@ def extract_and_store_top_kv(
padding
:
float
=
-
float
(
"inf"
),
padding
:
float
=
-
float
(
"inf"
),
):
):
"""helper method to extract and store top-k indices into KV cache (so they can be executed in a single stream)"""
"""helper method to extract and store top-k indices into KV cache (so they can be executed in a single stream)"""
assert
num_tokens_to_retain
is
not
None
,
"num_tokens_to_retain must be set"
# per_head: per-head highest-scoring remaining tokens for page padding.
# per_head: per-head highest-scoring remaining tokens for page padding.
# global_scan: legacy global ranking order, padded by scanning forward in-kernel.
# global_scan: legacy global ranking order, padded by scanning forward in-kernel.
padding_mode
=
os
.
environ
.
get
(
padding_mode
=
os
.
environ
.
get
(
...
...
vllm/kvprune/compression/criticalkv.py
View file @
d29c39ca
...
@@ -74,7 +74,7 @@ def _vwl1_norm_kvpress_reference(
...
@@ -74,7 +74,7 @@ def _vwl1_norm_kvpress_reference(
for
bk
in
[
32
,
64
,
128
]
for
bk
in
[
32
,
64
,
128
]
for
bd
in
[
32
,
64
]
for
bd
in
[
32
,
64
]
for
nw
in
[
4
,
8
]
for
nw
in
[
4
,
8
]
for
ns
in
[
3
,
4
]
for
ns
in
[
1
]
],
],
key
=
[
"Hk"
,
"D"
,
"HIDDEN"
],
key
=
[
"Hk"
,
"D"
,
"HIDDEN"
],
cache_results
=
True
,
cache_results
=
True
,
...
@@ -139,10 +139,11 @@ def _compute_wo_v_l1_kernel(
...
@@ -139,10 +139,11 @@ def _compute_wo_v_l1_kernel(
)
)
wo_tile
=
tl
.
load
(
wo_ptrs
,
mask
=
hid_mask
[
None
,
:],
other
=
0.0
).
to
(
tl
.
float32
)
wo_tile
=
tl
.
load
(
wo_ptrs
,
mask
=
hid_mask
[
None
,
:],
other
=
0.0
).
to
(
tl
.
float32
)
wov_tile
=
tl
.
dot
(
v_blk
,
wo_tile
)
wov_tile
=
tl
.
dot
(
v_blk
,
wo_tile
,
input_precision
=
"ieee"
)
l1_sum
+=
tl
.
sum
(
tl
.
abs
(
wov_tile
),
axis
=
1
)
l1_sum
+=
tl
.
sum
(
tl
.
abs
(
wov_tile
),
axis
=
1
)
l1_sum
=
l1_sum
/
QUERY_GROUP_SIZE
l1_sum
=
l1_sum
/
QUERY_GROUP_SIZE
l1_sum
=
tl
.
maximum
(
l1_sum
,
0.0
)
tl
.
store
(
out_ptrs
,
l1_sum
,
mask
=
k_mask
)
tl
.
store
(
out_ptrs
,
l1_sum
,
mask
=
k_mask
)
...
@@ -186,6 +187,11 @@ def critical_ada_key_scores(
...
@@ -186,6 +187,11 @@ def critical_ada_key_scores(
btr
=
compression_ctx
.
batch_tokens_to_retain
btr
=
compression_ctx
.
batch_tokens_to_retain
assert
btr
is
not
None
and
btr
.
numel
()
==
B
assert
btr
is
not
None
and
btr
.
numel
()
==
B
btr
=
btr
.
to
(
device
=
device
,
dtype
=
torch
.
int32
)
btr
=
btr
.
to
(
device
=
device
,
dtype
=
torch
.
int32
)
btr_effective
=
getattr
(
compression_ctx
,
"effective_batch_tokens_to_retain"
,
None
)
if
btr_effective
is
not
None
:
btr_effective
=
btr_effective
.
to
(
device
=
device
,
dtype
=
torch
.
int32
)
prot_first
=
getattr
(
compression_ctx
,
"protected_first_tokens"
,
None
)
or
[
0
]
*
B
prot_last
=
getattr
(
compression_ctx
,
"protected_last_tokens"
,
None
)
or
[
0
]
*
B
epsilon
=
compression_ctx
.
critical_ada_epsilon
epsilon
=
compression_ctx
.
critical_ada_epsilon
first_stage_ratio
=
compression_ctx
.
critical_ada_first_stage_ratio
first_stage_ratio
=
compression_ctx
.
critical_ada_first_stage_ratio
...
@@ -236,47 +242,71 @@ def critical_ada_key_scores(
...
@@ -236,47 +242,71 @@ def critical_ada_key_scores(
_score_max
=
float
(
torch
.
finfo
(
torch
.
float32
).
max
)
_score_max
=
float
(
torch
.
finfo
(
torch
.
float32
).
max
)
final_scores
=
torch
.
empty
((
N_k
,
Hk
),
dtype
=
torch
.
float32
,
device
=
device
)
final_scores
=
torch
.
empty
((
N_k
,
Hk
),
dtype
=
torch
.
float32
,
device
=
device
)
head_budgets_by_batch
:
list
[
Optional
[
torch
.
Tensor
]
]
=
[]
keep_pairs_total_by_batch
:
list
[
int
]
=
[]
for
b
in
range
(
B
):
for
b
in
range
(
B
):
k_len
=
int
(
k_lengths
[
b
].
item
())
k_len
=
int
(
k_lengths
[
b
].
item
())
k_beg
=
int
(
cu_seqlens
[
b
].
item
())
k_beg
=
int
(
cu_seqlens
[
b
].
item
())
k_end
=
int
(
cu_seqlens
[
b
+
1
].
item
())
k_end
=
int
(
cu_seqlens
[
b
+
1
].
item
())
if
k_len
==
0
:
if
k_len
==
0
:
head_budgets
_by_batch
.
append
(
None
)
keep_pairs_total
_by_batch
.
append
(
0
)
continue
continue
scores_seg
=
base_scores
[
k_beg
:
k_end
,
:].
float
()
scores_seg
=
base_scores
[
k_beg
:
k_end
,
:].
float
()
keep_pairs
=
int
(
btr
[
b
].
item
())
prot_first_b
=
int
(
prot_first
[
b
])
if
b
<
len
(
prot_first
)
else
0
n_kept_tokens
=
max
(
1
,
keep_pairs
//
Hk
)
prot_last_b
=
int
(
prot_last
[
b
])
if
b
<
len
(
prot_last
)
else
0
n_kept_tokens
=
min
(
n_kept_tokens
,
k_len
)
left_keep
=
min
(
prot_first_b
,
k_len
)
right_keep
=
min
(
prot_last_b
,
max
(
0
,
k_len
-
left_keep
))
mid_lo
=
left_keep
mid_hi
=
k_len
-
right_keep
compressible_len
=
max
(
0
,
mid_hi
-
mid_lo
)
keep_pairs_middle
=
min
(
int
(
btr
[
b
].
item
()),
compressible_len
*
Hk
)
keep_pairs_total
=
(
min
(
int
(
btr_effective
[
b
].
item
()),
k_len
*
Hk
)
if
btr_effective
is
not
None
else
min
(
keep_pairs_middle
+
(
left_keep
+
right_keep
)
*
Hk
,
k_len
*
Hk
,
)
)
keep_pairs_total_by_batch
.
append
(
keep_pairs_total
)
final_seg
=
scores_seg
.
clone
()
if
left_keep
>
0
:
final_seg
[:
left_keep
,
:]
=
_score_max
if
right_keep
>
0
:
final_seg
[
mid_hi
:,
:]
=
_score_max
if
compressible_len
<=
0
or
keep_pairs_middle
<=
0
:
final_scores
[
k_beg
:
k_end
,
:]
=
final_seg
continue
n_kept_tokens
=
max
(
1
,
keep_pairs_middle
//
Hk
)
n_kept_tokens
=
min
(
n_kept_tokens
,
compressible_len
)
# scores_work: 布局 [k_len, Hk],对应 kvpress [bsz=1, H, k_len] 的 transpose(0,2) 视角下沿 token 维的 topk
# scores_work: 布局 [k_len, Hk],对应 kvpress [bsz=1, H, k_len] 的 transpose(0,2) 视角下沿 token 维的 topk
scores_work
=
scores_seg
.
clone
()
scores_work
=
scores_seg
[
mid_lo
:
mid_hi
,
:]
.
clone
()
# --- Alpha safeguard(kvpress L148–152)---
# --- Alpha safeguard(kvpress L148–152)---
n_safe
=
int
(
n_kept_tokens
*
alpha_safeguard
)
n_safe
=
int
(
n_kept_tokens
*
alpha_safeguard
)
nk
=
min
(
n_safe
,
k
_len
)
if
n_safe
>
0
else
0
nk
=
min
(
n_safe
,
compressible
_len
)
if
n_safe
>
0
else
0
if
nk
>
0
:
if
nk
>
0
:
for
hk
in
range
(
Hk
):
for
hk
in
range
(
Hk
):
top_idx
=
torch
.
topk
(
scores_work
[:,
hk
],
nk
,
dim
=
0
,
largest
=
True
).
indices
top_idx
=
torch
.
topk
(
scores_work
[:,
hk
],
nk
,
dim
=
0
,
largest
=
True
).
indices
scores_work
[
top_idx
,
hk
]
=
_score_max
scores_work
[
top_idx
,
hk
]
=
_score_max
# --- Head budgets:kvpress L158–164,展平顺序与 [bsz, H, k_len] 一致(head-major:h*K + t)---
# --- Head budgets:kvpress L158–164,展平顺序与 [bsz, H, k_len] 一致(head-major:h*K + t)---
top_pairs
=
min
(
n_kept_tokens
*
Hk
,
k
_len
*
Hk
)
top_pairs
=
min
(
n_kept_tokens
*
Hk
,
compressible
_len
*
Hk
)
if
top_pairs
<=
0
:
if
top_pairs
<=
0
:
head_budgets_by_batch
.
append
(
None
)
final_scores
[
k_beg
:
k_end
,
:]
=
final_seg
wn
=
wo_v_norm
[
k_beg
:
k_end
,
:]
final_scores
[
k_beg
:
k_end
,
:]
=
(
scores_seg
+
epsilon
)
*
wn
continue
continue
budget_flat
=
scores_work
.
permute
(
1
,
0
).
contiguous
().
reshape
(
-
1
)
budget_flat
=
scores_work
.
permute
(
1
,
0
).
contiguous
().
reshape
(
-
1
)
top_idx_flat
=
torch
.
topk
(
top_idx_flat
=
torch
.
topk
(
budget_flat
,
top_pairs
,
largest
=
True
,
sorted
=
False
budget_flat
,
top_pairs
,
largest
=
True
,
sorted
=
False
).
indices
).
indices
top_head_idx
=
top_idx_flat
//
k
_len
top_head_idx
=
top_idx_flat
//
compressible
_len
head_budgets
=
torch
.
bincount
(
top_head_idx
,
minlength
=
Hk
).
to
(
torch
.
int64
)
head_budgets
=
torch
.
bincount
(
top_head_idx
,
minlength
=
Hk
).
to
(
torch
.
int64
)
head_budgets_by_batch
.
append
(
head_budgets
)
# --- Stage 1(kvpress L166–171):在已 safeguard 的 scores_work 上沿 token 维 top-k ---
# --- Stage 1(kvpress L166–171):在已 safeguard 的 scores_work 上沿 token 维 top-k ---
head_selection_budget_1st
=
(
head_selection_budget_1st
=
(
...
@@ -285,7 +315,7 @@ def critical_ada_key_scores(
...
@@ -285,7 +315,7 @@ def critical_ada_key_scores(
.
tolist
()
.
tolist
()
)
)
M1
=
max
(
head_selection_budget_1st
)
if
head_selection_budget_1st
else
0
M1
=
max
(
head_selection_budget_1st
)
if
head_selection_budget_1st
else
0
mk
=
min
(
M1
,
k
_len
)
if
M1
>
0
else
0
mk
=
min
(
M1
,
compressible
_len
)
if
M1
>
0
else
0
if
mk
>
0
:
if
mk
>
0
:
top_k_index
=
torch
.
topk
(
scores_work
,
mk
,
dim
=
0
,
largest
=
True
,
sorted
=
True
).
indices
top_k_index
=
torch
.
topk
(
scores_work
,
mk
,
dim
=
0
,
largest
=
True
,
sorted
=
True
).
indices
for
hk
in
range
(
Hk
):
for
hk
in
range
(
Hk
):
...
@@ -296,12 +326,12 @@ def critical_ada_key_scores(
...
@@ -296,12 +326,12 @@ def critical_ada_key_scores(
scores_work
[
top_k_index
[:
take
,
hk
],
hk
]
=
_score_max
scores_work
[
top_k_index
[:
take
,
hk
],
hk
]
=
_score_max
# --- Stage 2 重加权(kvpress L173–175)---
# --- Stage 2 重加权(kvpress L173–175)---
wn
=
wo_v_norm
[
k_beg
:
k_end
,
:]
wn
=
wo_v_norm
[
k_beg
+
mid_lo
:
k_beg
+
mid_hi
,
:]
scores_fused
=
(
scores_work
+
epsilon
)
*
wn
scores_fused
=
(
scores_work
+
epsilon
)
*
wn
# --- Stage 2 scatter(kvpress L176–179)---
# --- Stage 2 scatter(kvpress L176–179)---
M2
=
int
(
head_budgets
.
max
().
item
())
M2
=
int
(
head_budgets
.
max
().
item
())
mk2
=
min
(
M2
,
k
_len
)
if
M2
>
0
else
0
mk2
=
min
(
M2
,
compressible
_len
)
if
M2
>
0
else
0
if
mk2
>
0
:
if
mk2
>
0
:
top_k_index2
=
torch
.
topk
(
top_k_index2
=
torch
.
topk
(
scores_fused
,
mk2
,
dim
=
0
,
largest
=
True
,
sorted
=
True
scores_fused
,
mk2
,
dim
=
0
,
largest
=
True
,
sorted
=
True
...
@@ -313,14 +343,15 @@ def critical_ada_key_scores(
...
@@ -313,14 +343,15 @@ def critical_ada_key_scores(
take
=
min
(
budget
,
mk2
)
take
=
min
(
budget
,
mk2
)
scores_fused
[
top_k_index2
[:
take
,
hk
],
hk
]
=
_score_max
scores_fused
[
top_k_index2
[:
take
,
hk
],
hk
]
=
_score_max
final_scores
[
k_beg
:
k_end
,
:]
=
scores_fused
final_seg
[
mid_lo
:
mid_hi
,
:]
=
scores_fused
final_scores
[
k_beg
:
k_end
,
:]
=
final_seg
masked_key_indices
=
None
masked_key_indices
=
None
for
b
in
range
(
B
):
for
b
in
range
(
B
):
k_len
=
int
(
k_lengths
[
b
].
item
())
k_len
=
int
(
k_lengths
[
b
].
item
())
if
k_len
==
0
:
if
k_len
==
0
:
continue
continue
keep_pairs
=
int
(
btr
[
b
].
item
())
keep_pairs
=
keep_pairs_total_by_batch
[
b
]
total_pairs
=
k_len
*
Hk
total_pairs
=
k_len
*
Hk
if
keep_pairs
>=
total_pairs
:
if
keep_pairs
>=
total_pairs
:
continue
continue
...
@@ -368,9 +399,9 @@ class CriticalAdaKVCompression(BaseCompressionMethod):
...
@@ -368,9 +399,9 @@ class CriticalAdaKVCompression(BaseCompressionMethod):
)
->
Optional
[
torch
.
Tensor
]:
)
->
Optional
[
torch
.
Tensor
]:
cc
=
context
.
compression_context
cc
=
context
.
compression_context
base
=
(
base
=
(
getattr
(
cc
,
"critical_ada_base_scorer"
,
"
compactor
"
)
getattr
(
cc
,
"critical_ada_base_scorer"
,
"
snapkv
"
)
if
cc
is
not
None
if
cc
is
not
None
else
"
compactor
"
else
"
snapkv
"
)
)
if
str
(
base
).
lower
()
==
"compactor"
:
if
str
(
base
).
lower
()
==
"compactor"
:
return
CompactorCompression
.
pre_rope_scoring
(
q
,
k
,
v
,
context
)
return
CompactorCompression
.
pre_rope_scoring
(
q
,
k
,
v
,
context
)
...
@@ -386,7 +417,7 @@ class CriticalAdaKVCompression(BaseCompressionMethod):
...
@@ -386,7 +417,7 @@ class CriticalAdaKVCompression(BaseCompressionMethod):
)
->
Optional
[
torch
.
Tensor
]:
)
->
Optional
[
torch
.
Tensor
]:
compression_context
=
context
.
compression_context
compression_context
=
context
.
compression_context
assert
compression_context
is
not
None
assert
compression_context
is
not
None
base
=
str
(
getattr
(
compression_context
,
"critical_ada_base_scorer"
,
"
compactor
"
)).
lower
()
base
=
str
(
getattr
(
compression_context
,
"critical_ada_base_scorer"
,
"
snapkv
"
)).
lower
()
if
base
==
"compactor"
:
if
base
==
"compactor"
:
# 特例:与 ``CompactorPress.score`` / ``CompactorCompression.post_rope_scoring`` 一致。
# 特例:与 ``CompactorPress.score`` / ``CompactorCompression.post_rope_scoring`` 一致。
...
@@ -447,5 +478,3 @@ class CriticalAdaKVCompression(BaseCompressionMethod):
...
@@ -447,5 +478,3 @@ class CriticalAdaKVCompression(BaseCompressionMethod):
.
to
(
device
=
device
,
dtype
=
torch
.
float32
)
.
to
(
device
=
device
,
dtype
=
torch
.
float32
)
)
)
module
.
_critical_ada_wo_weight
=
wo
module
.
_critical_ada_wo_weight
=
wo
vllm/kvprune/core/model_runner.py
View file @
d29c39ca
...
@@ -8,7 +8,10 @@ import torch
...
@@ -8,7 +8,10 @@ import torch
import
torch.nn
as
nn
import
torch.nn
as
nn
import
torch.distributed
as
dist
import
torch.distributed
as
dist
from
vllm.kvprune.attention.sparse_decode_kernel
import
num_splits_heuristic
from
vllm.kvprune.attention.sparse_decode_kernel
import
num_splits_heuristic
from
vllm.kvprune.compression.compression_config
import
BatchCompressionParams
from
vllm.kvprune.compression.compression_config
import
(
BatchCompressionParams
,
CompressionMethod
,
)
from
vllm.kvprune.config.constants
import
RESERVED_BATCH
from
vllm.kvprune.config.constants
import
RESERVED_BATCH
from
vllm.kvprune.config.engine_config
import
LLMConfig
,
KvpruneAttentionSchedule
from
vllm.kvprune.config.engine_config
import
LLMConfig
,
KvpruneAttentionSchedule
from
vllm.kvprune.core.memory_manager
import
KVCacheManager
from
vllm.kvprune.core.memory_manager
import
KVCacheManager
...
@@ -317,6 +320,40 @@ class ModelRunner:
...
@@ -317,6 +320,40 @@ class ModelRunner:
protected_last_tokens
=
prefill_args
.
protected_last
,
protected_last_tokens
=
prefill_args
.
protected_last
,
compression_ratio
=
prefill_args
.
compression_ratio
,
compression_ratio
=
prefill_args
.
compression_ratio
,
)
)
compression_context
.
effective_batch_tokens_to_retain
=
(
prefill_args
.
batch_tokens_to_retain
)
compression_context
.
effective_max_tokens_to_retain
=
(
prefill_args
.
max_tokens_to_retain
)
if
prefill_args
.
compression_method
==
CompressionMethod
.
CRITICALADAKV
:
hk
=
int
(
self
.
kv_manager
.
num_kv_heads
)
lens
=
prefill_args
.
context_lens
.
to
(
device
=
prefill_args
.
batch_tokens_to_retain
.
device
,
dtype
=
torch
.
int32
)
prot_first
=
torch
.
as_tensor
(
prefill_args
.
protected_first
,
device
=
prefill_args
.
batch_tokens_to_retain
.
device
,
dtype
=
torch
.
int32
,
)
prot_last
=
torch
.
as_tensor
(
prefill_args
.
protected_last
,
device
=
prefill_args
.
batch_tokens_to_retain
.
device
,
dtype
=
torch
.
int32
,
)
left_keep
=
torch
.
minimum
(
prot_first
,
lens
)
right_keep
=
torch
.
minimum
(
prot_last
,
torch
.
clamp
(
lens
-
left_keep
,
min
=
0
))
protected_pairs
=
(
left_keep
+
right_keep
)
*
hk
total_pairs
=
lens
*
hk
effective_retain
=
torch
.
minimum
(
prefill_args
.
batch_tokens_to_retain
.
to
(
dtype
=
torch
.
int32
)
+
protected_pairs
,
total_pairs
,
)
compression_context
.
effective_batch_tokens_to_retain
=
effective_retain
compression_context
.
effective_max_tokens_to_retain
=
int
(
effective_retain
.
max
().
item
()
)
cu_q_host
=
tuple
(
cu_q_host
=
tuple
(
int
(
x
)
for
x
in
prefill_args
.
cu_seqlens_q
.
detach
().
cpu
().
view
(
-
1
).
tolist
()
int
(
x
)
for
x
in
prefill_args
.
cu_seqlens_q
.
detach
().
cpu
().
view
(
-
1
).
tolist
()
)
)
...
@@ -800,5 +837,3 @@ class ModelRunner:
...
@@ -800,5 +837,3 @@ class ModelRunner:
return
None
return
None
best_sl
=
max
(
candidates
)
best_sl
=
max
(
candidates
)
return
batch_size_graphs
[
best_sl
]
return
batch_size_graphs
[
best_sl
]
vllm/kvprune/layers/attention.py
View file @
d29c39ca
...
@@ -68,16 +68,24 @@ class Attention(nn.Module):
...
@@ -68,16 +68,24 @@ class Attention(nn.Module):
):
):
compression_context
=
context
.
compression_context
compression_context
=
context
.
compression_context
assert
compression_context
is
not
None
assert
compression_context
is
not
None
retain
=
compression_context
.
effective_batch_tokens_to_retain
if
retain
is
None
:
retain
=
compression_context
.
batch_tokens_to_retain
max_retain
=
compression_context
.
effective_max_tokens_to_retain
if
not
max_retain
:
max_retain
=
compression_context
.
max_tokens_to_retain
max_retain
=
int
(
max_retain
)
maybe_execute_in_stream
(
maybe_execute_in_stream
(
extract_and_store_top_kv
,
extract_and_store_top_kv
,
scores
=
scores
,
scores
=
scores
,
cu_seqlens_k
=
context
.
cu_seqlens_k
,
cu_seqlens_k
=
context
.
cu_seqlens_k
,
max_k_len
=
context
.
max_seqlen_k
,
max_k_len
=
context
.
max_seqlen_k
,
top_k
=
compression_context
.
max_tokens_to
_retain
,
top_k
=
max
_retain
,
H
=
int
(
self
.
num_kv_heads
),
H
=
int
(
self
.
num_kv_heads
),
new_keys
=
k
,
new_keys
=
k
,
new_vals
=
v
,
new_vals
=
v
,
num_tokens_to_retain
=
compression_context
.
batch_tokens_to_
retain
,
num_tokens_to_retain
=
retain
,
page_table
=
self
.
page_table
,
page_table
=
self
.
page_table
,
batch_mapping
=
batch_mapping
,
batch_mapping
=
batch_mapping
,
bh_lens
=
seq_lens
,
bh_lens
=
seq_lens
,
...
...
vllm/kvprune/models/llama3.py
View file @
d29c39ca
...
@@ -9,6 +9,7 @@ from torch import nn
...
@@ -9,6 +9,7 @@ from torch import nn
from
transformers
import
LlamaConfig
from
transformers
import
LlamaConfig
from
vllm.kvprune.compression
import
(
from
vllm.kvprune.compression
import
(
CompressionMethod
,
apply_postrope_compression
,
apply_postrope_compression
,
apply_prerope_compression
,
apply_prerope_compression
,
)
)
...
@@ -105,6 +106,17 @@ class LlamaAttention(nn.Module):
...
@@ -105,6 +106,17 @@ class LlamaAttention(nn.Module):
q
,
k
=
self
.
rotary_emb
(
positions
,
q
,
k
)
q
,
k
=
self
.
rotary_emb
(
positions
,
q
,
k
)
if
context
.
is_prefill
and
context
.
do_compression
:
if
context
.
is_prefill
and
context
.
do_compression
:
cc
=
context
.
compression_context
if
cc
is
not
None
and
cc
.
compression_method
==
CompressionMethod
.
CRITICALADAKV
:
wo_raw
=
self
.
o_proj
.
weight
hidden_size
,
_
=
wo_raw
.
shape
Hq
,
D
=
self
.
num_heads
,
self
.
head_dim
cc
.
wo_weight
=
(
wo_raw
.
transpose
(
0
,
1
)
.
contiguous
()
.
view
(
Hq
,
D
,
hidden_size
)
.
to
(
dtype
=
torch
.
float32
)
)
scores
=
apply_postrope_compression
(
q
,
k
,
v
,
scores
,
context
)
scores
=
apply_postrope_compression
(
q
,
k
,
v
,
scores
,
context
)
o
=
self
.
attn
(
q
,
k
,
v
,
scores
)
o
=
self
.
attn
(
q
,
k
,
v
,
scores
)
...
...
vllm/kvprune/models/qwen3_moe.py
View file @
d29c39ca
...
@@ -9,6 +9,7 @@ from torch import nn
...
@@ -9,6 +9,7 @@ from torch import nn
from
transformers
import
Qwen3MoeConfig
from
transformers
import
Qwen3MoeConfig
from
vllm.kvprune.compression
import
(
from
vllm.kvprune.compression
import
(
CompressionMethod
,
apply_postrope_compression
,
apply_postrope_compression
,
apply_prerope_compression
,
apply_prerope_compression
,
)
)
...
@@ -105,6 +106,17 @@ class Qwen3MoeAttention(nn.Module):
...
@@ -105,6 +106,17 @@ class Qwen3MoeAttention(nn.Module):
q
,
k
=
self
.
rotary_emb
(
positions
,
q
,
k
)
q
,
k
=
self
.
rotary_emb
(
positions
,
q
,
k
)
if
context
.
is_prefill
and
context
.
do_compression
:
if
context
.
is_prefill
and
context
.
do_compression
:
cc
=
context
.
compression_context
if
cc
is
not
None
and
cc
.
compression_method
==
CompressionMethod
.
CRITICALADAKV
:
wo_raw
=
self
.
o_proj
.
weight
hidden_size
,
_
=
wo_raw
.
shape
Hq
,
D
=
self
.
num_heads
,
self
.
head_dim
cc
.
wo_weight
=
(
wo_raw
.
transpose
(
0
,
1
)
.
contiguous
()
.
view
(
Hq
,
D
,
hidden_size
)
.
to
(
dtype
=
torch
.
float32
)
)
scores
=
apply_postrope_compression
(
q
,
k
,
v
,
scores
,
context
)
scores
=
apply_postrope_compression
(
q
,
k
,
v
,
scores
,
context
)
o
=
self
.
attn
(
q
,
k
,
v
,
scores
)
o
=
self
.
attn
(
q
,
k
,
v
,
scores
)
...
...
vllm/kvprune/utils/context.py
View file @
d29c39ca
...
@@ -16,6 +16,8 @@ class CompressionContext:
...
@@ -16,6 +16,8 @@ class CompressionContext:
compression_chunk_size
:
int
=
-
1
compression_chunk_size
:
int
=
-
1
batch_tokens_to_retain
:
torch
.
Tensor
|
None
=
None
batch_tokens_to_retain
:
torch
.
Tensor
|
None
=
None
max_tokens_to_retain
:
int
=
0
max_tokens_to_retain
:
int
=
0
effective_batch_tokens_to_retain
:
torch
.
Tensor
|
None
=
None
effective_max_tokens_to_retain
:
int
=
0
context_lens
:
List
[
int
]
|
None
=
None
context_lens
:
List
[
int
]
|
None
=
None
PHI
:
torch
.
Tensor
|
None
=
None
PHI
:
torch
.
Tensor
|
None
=
None
...
...
vllm/kvprune_legacy_save/__init__.py
0 → 100644
View file @
d29c39ca
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""
KV-cache pruning (compactor-style) under ``vllm.kvprune``.
Use the standard :class:`~vllm.LLM` and pass ``compression=`` to :meth:`~vllm.LLM.generate`
with :class:`CompressionParams` when any prompt needs ``compression_ratio < 1``. The compactor
``LLMEngine`` + ``PagedKVCache`` shares weights with vLLM (no second checkpoint).
Subpackages (``attention``, ``kv_cache``, ``compression``, …) implement the compactor
engine.
"""
from
vllm.kvprune.compression.compression_config
import
CompressionMethod
from
vllm.kvprune.integration
import
CompressionParams
__all__
=
[
"CompressionMethod"
,
"CompressionParams"
,
]
vllm/kvprune_legacy_save/attention/__init__.py
0 → 100644
View file @
d29c39ca
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""Sparse attention Triton kernels (varlen prefill, decode, compile helpers)."""
from
vllm.kvprune.attention.sparse_varlen_kernel
import
causal_sparse_varlen_with_cache
__all__
=
[
"causal_sparse_varlen_with_cache"
]
vllm/kvprune_legacy_save/attention/compile_kernels.py
0 → 100644
View file @
d29c39ca
import
argparse
import
logging
import
math
import
torch
from
vllm.kvprune.attention.sparse_varlen_kernel
import
(
causal_sparse_varlen_with_cache
,
)
logger
=
logging
.
getLogger
(
__name__
)
def build_mock_paged_cache_from_lengths(
    L_cache_per_b: torch.Tensor,
    HKV: int,
    D: int,
    PAGE_SIZE: int,
    N_LOGICAL_PAGES_MAX: int,
    device,
    dtype,
):
    """Build a synthetic paged KV cache filled with random data.

    Every ``(batch, head, logical page)`` triple gets its own physical page,
    numbered in row-major order, and the first ``L_cache_per_b[b]`` token rows
    of each head of batch ``b`` are filled with ``torch.randn`` values. All
    other rows stay zero.

    Args:
        L_cache_per_b: 1-D integer tensor with the cached length per batch
            entry. Each value must be ``<= PAGE_SIZE * N_LOGICAL_PAGES_MAX``.
        HKV: Number of KV heads.
        D: Size of one cache row.
        PAGE_SIZE: Rows per physical page.
        N_LOGICAL_PAGES_MAX: Logical pages reserved per ``(batch, head)``.
        device: Device for every returned tensor.
        dtype: Dtype of the K/V caches.

    Returns:
        ``(K_cache, V_cache, page_table, seq_lens_bh, CACHE_SIZE)`` where the
        caches are ``(CACHE_SIZE, D)``, ``page_table`` is int32
        ``(B, HKV, N_LOGICAL_PAGES_MAX)``, ``seq_lens_bh`` is int32
        ``(B, HKV)`` and ``CACHE_SIZE`` is the total number of cache rows.
    """
    B = len(L_cache_per_b)
    max_len = PAGE_SIZE * N_LOGICAL_PAGES_MAX
    assert (L_cache_per_b <= max_len).all()

    # Pull all lengths to the host once instead of one .item() per token.
    lengths = [int(n) for n in L_cache_per_b.tolist()]

    # Every head of a batch entry starts with the same cached length.
    seq_lens_bh = (
        torch.tensor(lengths, dtype=torch.int32, device=device)
        .reshape(B, 1)
        .repeat(1, HKV)
    )

    num_phys_pages = B * HKV * N_LOGICAL_PAGES_MAX
    CACHE_SIZE = num_phys_pages * PAGE_SIZE
    K_cache = torch.zeros((CACHE_SIZE, D), device=device, dtype=dtype)
    V_cache = torch.zeros((CACHE_SIZE, D), device=device, dtype=dtype)

    # assign unique physical pages per (b, h, lp): page id == row-major index
    page_table = torch.arange(
        num_phys_pages, device=device, dtype=torch.int32
    ).reshape(B, HKV, N_LOGICAL_PAGES_MAX)

    # Because the page table is the identity mapping, the rows owned by one
    # (b, h) pair form a single contiguous range; token i lives at offset i
    # inside that range. Filling through a 4-D view is therefore equivalent
    # to the page-by-page lookup, without per-token host round trips.
    rows_per_head = N_LOGICAL_PAGES_MAX * PAGE_SIZE
    k_view = K_cache.view(B, HKV, rows_per_head, D)
    v_view = V_cache.view(B, HKV, rows_per_head, D)
    for b, Lc in enumerate(lengths):
        if Lc <= 0:
            # Nothing cached for this entry (also guards negative lengths,
            # which would otherwise be treated as "all but the last rows").
            continue
        k_view[b, :, :Lc] = torch.randn(HKV, Lc, D, device=device, dtype=dtype)
        v_view[b, :, :Lc] = torch.randn(HKV, Lc, D, device=device, dtype=dtype)

    return K_cache, V_cache, page_table, seq_lens_bh, CACHE_SIZE
def autotune_causal_sparse_varlen_with_cache(
    *,
    max_length: int = 16384,
    HKV: int = 8,
    HQ: int = 32,
    D: int = 128,
    PAGE_SIZE: int = 128,
    device: str = "cuda",
    dtype=torch.float16,
):
    """
    Autotune causal_sparse_varlen_with_cache over a sweep of cache/append lengths.

    The sweep covers every ordered pair of lengths drawn from
    ``[0, 256, 512, 1024, ...]`` (powers of two strictly below ``max_length``),
    skipping the pair ``(0, 0)``. For each pair a mock paged cache is built
    and the kernel is invoked once on random inputs with a batch of 4.

    Args:
        max_length: Upper bound used to pick the swept lengths and to size
            the mock page table.
        HKV: Number of KV heads.
        HQ: Number of query heads.
        D: Per-head dimension; must be a power of two.
        PAGE_SIZE: Tokens per physical page.
        device: Torch device for all tensors.
        dtype: Dtype of q/k/v and the mock caches.
    """
    import itertools

    import tqdm

    B = 4
    # D must be a power of two (kernel requirement).
    assert D > 0 and (D & (D - 1)) == 0

    lengths_to_sweep = [0, 256]
    i = 9
    while (v := (1 << i)) < max_length:
        lengths_to_sweep.append(v)
        i += 1

    # Number of logical pages per (batch, head). The mock-cache builder
    # multiplies this by PAGE_SIZE to get the per-head capacity, so it must
    # be a page count, not a token count. Size it for the longest length
    # that will actually be swept (256 is swept even when max_length < 256).
    longest = max(max_length, max(lengths_to_sweep))
    N_LOGICAL_PAGES_MAX = (longest + PAGE_SIZE - 1) // PAGE_SIZE

    combos = list(itertools.product(lengths_to_sweep, repeat=2))
    logger.info(
        "tuning kernels. this may take a few minutes, "
        "but only needs to be run once per LLMConfig"
    )
    sm_scale = 1.0 / math.sqrt(D)
    for cache_l, append_l in tqdm.tqdm(combos):
        if cache_l + append_l == 0:
            continue

        L_cache_per_b = torch.tensor(
            [cache_l] * B,
            device=device,
            dtype=torch.int32,
        )
        # The builder asserts that every length fits in the page table.
        K_cache, V_cache, page_table, seq_lens_bh, CACHE_SIZE = (
            build_mock_paged_cache_from_lengths(
                L_cache_per_b=L_cache_per_b,
                HKV=HKV,
                D=D,
                PAGE_SIZE=PAGE_SIZE,
                N_LOGICAL_PAGES_MAX=N_LOGICAL_PAGES_MAX,
                device=device,
                dtype=dtype,
            )
        )

        # Every batch entry appends the same number of tokens, so the
        # cumulative offsets are simple multiples of append_l.
        cu = [append_l * n for n in range(B + 1)]
        cu_seqlens_qk = torch.tensor(cu, dtype=torch.int32, device=device)
        N = cu[-1]
        max_seqlen_q = append_l
        max_seqlen_k = int(seq_lens_bh.max().item())

        q_raw = torch.randn(N, HQ, D, device=device, dtype=dtype)
        k_append_raw = torch.randn(N, HKV, D, device=device, dtype=dtype)
        v_append_raw = torch.randn(N, HKV, D, device=device, dtype=dtype)

        # Identity batch mapping (local batch index == global)
        batch_mapping = torch.arange(B, device=device, dtype=torch.int32)

        causal_sparse_varlen_with_cache(
            q=q_raw,
            k_cache=K_cache,
            v_cache=V_cache,
            k=k_append_raw,
            v=v_append_raw,
            seq_lens_bh=seq_lens_bh,
            global_page_table=page_table,
            batch_mapping=batch_mapping,
            cu_seqlens_q=cu_seqlens_qk,
            HKV=HKV,
            PAGE_SIZE=PAGE_SIZE,
            sm_scale=sm_scale,
            max_seqlen_q=max_seqlen_q,
            max_seqlen_k_cache=max_seqlen_k,
        )
def _parse_args() -> argparse.Namespace:
    """Parse the command-line options of the autotune script.

    Returns:
        The parsed namespace with ``max_length``, ``HKV``, ``HQ``, ``D``,
        ``page_size``, ``device``, ``dtype`` (still a string) and
        ``log_level``.
    """
    parser = argparse.ArgumentParser(
        # Adjacent literals are concatenated verbatim, so each piece needs
        # its own trailing space to keep words from running together.
        description="Autotune Triton kernels. "
        "Results are cached, so this should only need to be run once per configuration. "
        "This script doesn't need to be run, as the kernels will be autotuned at runtime "
        "if no cached autotuning data exists. Running this before hand will prevent run-time "
        "autotuning, which will accelerate compactor-vllm at inference time."
    )
    parser.add_argument(
        "--max-length",
        type=int,
        default=16384,
        help="Maximum total sequence length to consider.",
    )
    parser.add_argument(
        "--HKV",
        type=int,
        default=8,
        help="Number of KV heads.",
    )
    parser.add_argument(
        "--HQ",
        type=int,
        default=32,
        help="Number of query heads.",
    )
    parser.add_argument(
        "--D",
        type=int,
        default=128,
        help="Per-head hidden dimension (must be power of 2).",
    )
    parser.add_argument(
        "--page-size",
        type=int,
        default=128,
        help="Page size (tokens per physical page).",
    )
    parser.add_argument(
        "--device",
        type=str,
        default="cuda",
        help="Torch device to run on (e.g. 'cuda', 'cuda:0', 'cpu').",
    )
    parser.add_argument(
        "--dtype",
        type=str,
        default="float16",
        help="Dtype for tensors: one of {float16, fp16, bfloat16, bf16, float32, fp32}.",
    )
    parser.add_argument(
        "--log-level",
        type=str,
        default="INFO",
        choices=["CRITICAL", "ERROR", "WARNING", "INFO", "DEBUG"],
        help="Logging level.",
    )
    return parser.parse_args()
def _resolve_dtype(dtype_str: str):
    """Map a case-insensitive dtype name to the matching ``torch`` dtype.

    Raises:
        ValueError: If the name is not one of the recognised aliases.
    """
    aliases = {
        "float16": torch.float16,
        "fp16": torch.float16,
        "half": torch.float16,
        "bfloat16": torch.bfloat16,
        "bf16": torch.bfloat16,
        "float32": torch.float32,
        "fp32": torch.float32,
    }
    resolved = aliases.get(dtype_str.lower())
    if resolved is None:
        raise ValueError(f"Unsupported dtype: {dtype_str}")
    return resolved
def main():
    """Script entry point: parse options, set up logging, run the autotune sweep."""
    cli = _parse_args()

    log_level = getattr(logging, cli.log_level.upper())
    logging.basicConfig(
        level=log_level,
        format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
    )

    # The CLI carries the dtype as a string; turn it into a torch dtype once.
    torch_dtype = _resolve_dtype(cli.dtype)

    logger.info(
        "Starting autotune with max_length=%d, HKV=%d, HQ=%d, D=%d, page_size=%d, "
        "device=%s, dtype=%s",
        cli.max_length,
        cli.HKV,
        cli.HQ,
        cli.D,
        cli.page_size,
        cli.device,
        torch_dtype,
    )

    sweep_kwargs = dict(
        max_length=cli.max_length,
        HKV=cli.HKV,
        HQ=cli.HQ,
        D=cli.D,
        PAGE_SIZE=cli.page_size,
        device=cli.device,
        dtype=torch_dtype,
    )
    autotune_causal_sparse_varlen_with_cache(**sweep_kwargs)
if __name__ == "__main__":
    # main() configures logging itself from --log-level. Calling
    # logging.basicConfig() here first would install a root handler and turn
    # main()'s call into a no-op, silently ignoring the requested level.
    main()
vllm/kvprune_legacy_save/attention/fa_paged_bridge.py
0 → 100644
View file @
d29c39ca
# SPDX-License-Identifier: Apache-2.0
"""FlashAttention paths over compactor paged KV (materialize + FA ops).
Used when :class:`~vllm.kvprune.config.engine_config.KvpruneAttentionSchedule`
selects FlashAttention for prefill and/or decode while KV **writes** remain on
Triton (``prefill_store_*``, ``decode_store_kv``). Matches the reference checks
in ``vllm/compactor-vllm/tests/test_triton_attention.py``.
"""
from
__future__
import
annotations
import
math
from
typing
import
TYPE_CHECKING
import
torch
from
flash_attn.flash_attn_interface
import
flash_attn_func
,
flash_attn_varlen_func
if
TYPE_CHECKING
:
pass
def materialize_kv_for_flash_prefill(
    k_cache: torch.Tensor,
    v_cache: torch.Tensor,
    page_table: torch.Tensor,
    batch_mapping: torch.Tensor,
    L_cache_per_b: torch.Tensor,
    k_append: torch.Tensor,
    v_append: torch.Tensor,
    cu_seqlens_q: torch.Tensor,
    H_kv: int,
    PAGE_SIZE: int,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    """Build packed K/V for :func:`flash_attn_varlen_func` (cache prefix + append).

    For each local batch entry ``b`` the output holds ``L_cache_per_b[b]``
    rows gathered from the paged cache (via ``page_table[batch_mapping[b]]``)
    followed by that entry's slice of ``k_append`` / ``v_append``.

    Args:
        k_cache: Paged key storage; row ``page * PAGE_SIZE + offset`` holds
            one token of one head.
        v_cache: Paged value storage, laid out like ``k_cache``.
        page_table: Physical page ids indexed ``[batch slot, head, logical page]``.
        batch_mapping: Local batch index -> row of ``page_table``.
        L_cache_per_b: Number of cached tokens to copy per local batch entry.
        k_append: New keys, shape ``(N, H_kv, D)``.
        v_append: New values, shape ``(N, H_kv, D)``.
        cu_seqlens_q: Cumulative offsets of each entry inside ``k_append``.
        H_kv: Number of KV heads; must match ``k_append.shape[1]``.
        PAGE_SIZE: Rows per physical page.

    Returns:
        ``(K_total, V_total, cu_seqlens_k)``: packed ``(total_k, H_kv, D)``
        tensors and the int32 cumulative key lengths of size ``B + 1``.
    """
    device = k_cache.device
    dtype = k_cache.dtype
    B = cu_seqlens_q.numel() - 1
    _, H_kv_raw, D = k_append.shape
    assert H_kv_raw == H_kv

    # Move all bookkeeping to the host in a few bulk transfers rather than
    # issuing one .item() sync per token and head.
    q_bounds = [int(x) for x in cu_seqlens_q.tolist()]
    cache_lens = [int(x) for x in L_cache_per_b.tolist()]
    slots = [int(x) for x in batch_mapping.tolist()]
    app_lens = [q_bounds[b + 1] - q_bounds[b] for b in range(B)]

    offsets = [0]
    for b in range(B):
        offsets.append(offsets[-1] + cache_lens[b] + app_lens[b])
    total_k = offsets[-1]
    cu_seqlens_k = torch.tensor(offsets, device=device, dtype=torch.int32)

    K_total = torch.empty((total_k, H_kv, D), device=device, dtype=dtype)
    V_total = torch.empty((total_k, H_kv, D), device=device, dtype=dtype)

    for b in range(B):
        Lc = cache_lens[b]
        La = app_lens[b]
        dst = offsets[b]

        if Lc > 0:
            pos = torch.arange(Lc, device=page_table.device)
            logical_page = torch.div(pos, PAGE_SIZE, rounding_mode="floor")
            in_page = pos - logical_page * PAGE_SIZE
            # Widen to int64 before multiplying so large caches cannot
            # overflow the int32 page ids.
            pages = page_table[slots[b], :H_kv].to(torch.long)
            rows = pages[:, logical_page] * PAGE_SIZE + in_page  # (H_kv, Lc)
            rows = rows.to(device)
            # Gather gives (H_kv, Lc, D); the packed layout is token-major.
            K_total[dst : dst + Lc] = k_cache[rows].transpose(0, 1)
            V_total[dst : dst + Lc] = v_cache[rows].transpose(0, 1)

        if La > 0:
            src = q_bounds[b]
            K_total[dst + Lc : dst + Lc + La] = k_append[src : src + La]
            V_total[dst + Lc : dst + Lc + La] = v_append[src : src + La]

    return K_total, V_total, cu_seqlens_k
def flash_prefill_from_paged(
    q: torch.Tensor,
    k_append: torch.Tensor,
    v_append: torch.Tensor,
    k_cache: torch.Tensor,
    v_cache: torch.Tensor,
    *,
    seq_lens_bh_before: torch.Tensor,
    global_page_table: torch.Tensor,
    batch_mapping: torch.Tensor,
    cu_seqlens_q: torch.Tensor,
    max_seqlen_q: int,
    PAGE_SIZE: int,
    HKV: int,
    sm_scale: float | None,
) -> torch.Tensor:
    """Run causal prefill attention with FlashAttention varlen over paged KV.

    The cached prefix and the newly appended K/V are first packed into one
    contiguous buffer by :func:`materialize_kv_for_flash_prefill`; the packed
    buffers are then handed to :func:`flash_attn_varlen_func`.

    The prefix length of each sequence is the maximum over KV heads of
    ``seq_lens_bh_before``, so every head is read to that same length.
    ``sm_scale`` is forwarded unchanged (``None`` lets FlashAttention choose).
    """
    # One prefix length per sequence: the longest among its KV heads.
    longest_per_seq = seq_lens_bh_before.max(dim=1)[0]
    prefix_lens = longest_per_seq.to(torch.int32)

    packed_k, packed_v, cu_seqlens_k = materialize_kv_for_flash_prefill(
        k_cache,
        v_cache,
        global_page_table,
        batch_mapping,
        prefix_lens,
        k_append,
        v_append,
        cu_seqlens_q,
        HKV,
        PAGE_SIZE,
    )

    # Longest key sequence in the packed batch (prefix + append).
    per_seq_k = torch.diff(cu_seqlens_k)
    max_seqlen_k = int(per_seq_k.max().item())

    return flash_attn_varlen_func(
        q,
        packed_k,
        packed_v,
        cu_seqlens_q=cu_seqlens_q,
        cu_seqlens_k=cu_seqlens_k,
        max_seqlen_q=max_seqlen_q,
        max_seqlen_k=max_seqlen_k,
        softmax_scale=sm_scale,
        causal=True,
    )
def materialize_kv_cache_for_flash_decode(
    k_cache: torch.Tensor,
    v_cache: torch.Tensor,
    page_table: torch.Tensor,
    batch_mapping: torch.Tensor,
    L_cache_per_b: torch.Tensor,
    H_kv: int,
    PAGE_SIZE: int,
) -> tuple[torch.Tensor, torch.Tensor]:
    """Dense ``[B, S, H_kv, D]`` cache for :func:`flash_attn_func` decode.

    ``S`` is the largest entry of ``L_cache_per_b``. For sequence ``b`` the
    first ``L_cache_per_b[b]`` positions are gathered from the paged cache via
    ``page_table[batch_mapping[b], head, position // PAGE_SIZE]``; all remaining
    positions stay zero.

    Args:
        k_cache: Paged key cache indexed by physical row, ``[CACHE_SIZE, D]``.
        v_cache: Paged value cache, indexed like ``k_cache``.
        page_table: ``[true_batch, kv_head, logical_page]`` -> physical page id.
        batch_mapping: Local sequence index -> row of ``page_table``.
        L_cache_per_b: Tokens to read per sequence (same for every KV head).
        H_kv: Number of KV heads.
        PAGE_SIZE: Tokens per physical page.

    Returns:
        ``(K_flash, V_flash)`` in ``k_cache``'s dtype and device. An empty batch
        yields empty tensors instead of raising.
    """
    device = k_cache.device
    dtype = k_cache.dtype
    B = L_cache_per_b.shape[0]
    D = k_cache.shape[1]

    # One host transfer for all lengths rather than one ``.item()`` per token.
    cache_lens = [int(n) for n in L_cache_per_b.tolist()]
    seqlen_cache_max = max(cache_lens, default=0)

    K_flash = torch.zeros((B, seqlen_cache_max, H_kv, D), device=device, dtype=dtype)
    V_flash = torch.zeros_like(K_flash)
    if seqlen_cache_max == 0:
        return K_flash, V_flash

    true_rows = [int(n) for n in batch_mapping.tolist()]

    # Logical token position -> (logical page, offset inside the page).
    positions = torch.arange(seqlen_cache_max, device=page_table.device)
    logical_page = torch.div(positions, PAGE_SIZE, rounding_mode="floor")
    in_page = positions - logical_page * PAGE_SIZE

    for b, Lc in enumerate(cache_lens):
        if Lc == 0:
            continue
        # [H_kv, Lc] physical page ids, widened so page * PAGE_SIZE cannot wrap.
        pages = page_table[true_rows[b], :H_kv][:, logical_page[:Lc]].to(torch.int64)
        rows = (pages * PAGE_SIZE + in_page[:Lc]).to(device)
        # cache[rows] is [H_kv, Lc, D]; the dense layout wants [Lc, H_kv, D].
        K_flash[b, :Lc] = k_cache[rows].transpose(0, 1)
        V_flash[b, :Lc] = v_cache[rows].transpose(0, 1)

    return K_flash, V_flash
def flash_decode_from_paged(
    q: torch.Tensor,
    k_cache: torch.Tensor,
    v_cache: torch.Tensor,
    *,
    seq_lens_bh: torch.Tensor,
    global_page_table: torch.Tensor,
    batch_mapping: torch.Tensor,
    PAGE_SIZE: int,
    HKV: int,
    sm_scale: float | None,
) -> torch.Tensor:
    """Decode step via FA: ``decode_store_kv`` has already appended the new K/V row.

    Args:
        q: One query token per sequence, ``[B, HQ, D]``.
        k_cache: Paged key cache, ``[CACHE_SIZE, D]``.
        v_cache: Paged value cache, same layout as ``k_cache``.
        seq_lens_bh: Cached token count per (sequence, KV head), ``[B, HKV]``.
        global_page_table: ``[true_batch, kv_head, logical_page]`` -> page id.
        batch_mapping: Local sequence index -> row of ``global_page_table``.
        PAGE_SIZE: Tokens per physical page.
        HKV: Number of KV heads.
        sm_scale: Softmax scale; ``None`` means ``1 / sqrt(D)``.

    Returns:
        Attention output of shape ``[B, HQ, D]``.

    The materialized cache is zero-padded up to the longest sequence in the
    batch. Feeding that padded tensor to the dense ``flash_attn_func`` would let
    shorter sequences attend to the padding (a zero key still gets a non-zero
    softmax weight), so only the valid rows are packed and the varlen entry
    point is used with explicit per-sequence key lengths.
    """
    L_cache_per_b = seq_lens_bh.max(dim=1).values.to(torch.int32)
    # NOTE(review): each sequence is read to the max length over its KV heads;
    # heads with fewer cached tokens are still read to that length. Confirm the
    # caller only uses this path when all heads of a sequence share one length.
    K_flash, V_flash = materialize_kv_cache_for_flash_decode(
        k_cache,
        v_cache,
        global_page_table,
        batch_mapping,
        L_cache_per_b,
        HKV,
        PAGE_SIZE,
    )
    B, HQ, D = q.shape
    if sm_scale is None:
        sm_scale = 1.0 / math.sqrt(D)

    max_seqlen_k = K_flash.shape[1]
    if max_seqlen_k == 0:
        # Nothing cached for any sequence: there are no keys to attend to.
        return torch.zeros_like(q)

    kv_device = K_flash.device
    lens = L_cache_per_b.to(kv_device)
    # valid[b, s] is True for the first lens[b] positions of sequence b.
    valid = torch.arange(max_seqlen_k, device=kv_device)[None, :] < lens[:, None]

    cu_seqlens_q = torch.arange(B + 1, device=kv_device, dtype=torch.int32)
    cu_seqlens_k = torch.zeros(B + 1, device=kv_device, dtype=torch.int32)
    cu_seqlens_k[1:] = torch.cumsum(lens, dim=0)

    # Boolean indexing walks (b, s) in row-major order, which is exactly the
    # packed layout described by cu_seqlens_k.
    # One query position attends to all of its own L keys (no causal mask).
    return flash_attn_varlen_func(
        q,
        K_flash[valid],
        V_flash[valid],
        cu_seqlens_q=cu_seqlens_q,
        cu_seqlens_k=cu_seqlens_k,
        max_seqlen_q=1,
        max_seqlen_k=max_seqlen_k,
        softmax_scale=sm_scale,
        causal=False,
    )
vllm/kvprune_legacy_save/attention/sparse_decode_kernel.py
0 → 100644
View file @
d29c39ca
import
functools
import
math
import
torch
import
triton
import
triton.language
as
tl
from
vllm.kvprune.utils.triton_compat
import
(
autotune
as
triton_autotune
,
maybe_set_allocator
,
)
def head_sparse_decode_attention(
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    seq_lens_bh: torch.Tensor,
    global_page_table: torch.Tensor,
    batch_mapping: torch.Tensor,
    HKV: int,
    PAGE_SIZE: int,
    sm_scale: float = None,
    key_split: int = None,
):
    """
    Decode-time head-sparse attention over a paged KV cache.

    This is a wrapper around the Triton decode kernels used during incremental
    generation. For each sequence, the cached keys and values are read from a
    global paged KV buffer, attended to by one new query token, and the
    attention output is returned.

    The KV cache is stored in a single global K/V tensor of shape
    ``[CACHE_SIZE, D]`` and indexed via a per-layer page table. Each logical
    (batch, kv_head, token_idx) is mapped to a physical row in the cache by:

      1. Looking up the logical page index in ``global_page_table[b, h, lp]``,
      2. Computing ``phys_row = page_id * PAGE_SIZE + (token_idx % PAGE_SIZE)``.

    Grouped-query attention (GQA / MQA) is supported by passing more query
    heads than KV heads (``HQ`` must be a multiple of ``HKV``).

    Args:
        :param q: Query tensor of shape ``[B, HQ, D]`` or ``[B, HQ, 1, D]``
            (the length-1 axis is squeezed) containing the new decode token of
            each sequence in the launch batch. Must be contiguous.
        :param k: Global key cache of shape ``[CACHE_SIZE, D]``. This is the shared
            backing buffer for all (batch, head) KV pages.
        :param v: Global value cache of shape ``[CACHE_SIZE, D]``.
        :param seq_lens_bh: Tensor of shape ``[B, HKV]`` giving, for each
            local batch index and KV head, the number of valid cached tokens
            in the paged KV cache. Converted to int32 here.
        :param global_page_table: Tensor of shape
            ``[MAX_NUM_BATCHES, HKV, N_LOGICAL_PAGES_MAX]`` mapping
            ``(true_batch_idx, kv_head, logical_page)`` to a physical page id
            in the global cache.
        :param batch_mapping: Tensor of shape ``[B]`` mapping the launch-batch
            index used by this call to the true batch row used to index
            ``global_page_table``.
        :param HKV: Number of KV heads.
        :param PAGE_SIZE: Number of tokens stored per physical KV page. Must be
            divisible by 32.
        :param sm_scale: Optional scaling factor applied to the attention logits
            before softmax. If ``None``, ``1 / sqrt(D)`` is used.
        :param key_split: Optional number of splits along the key sequence length.
            If > 1, the KV sequence is processed in ``key_split`` chunks whose
            partial results are merged by a second kernel. If ``None`` or 0, a
            heuristic is used.

    Returns:
        :return torch.Tensor: Attention output of shape ``[B, HQ, D]`` on the same
            device and dtype as ``q``.
    """
    with torch.cuda.device(q.device):
        if q.ndim == 3:
            B, HQ, D = q.shape
        else:
            assert q.ndim == 4
            B, HQ, S, D = q.shape
            assert S == 1, "head_sparse_decode_attention only supports q_len=1"
            q = q.squeeze(-2)

        CACHE_SIZE = k.shape[0]
        assert PAGE_SIZE % 32 == 0, "PAGE_SIZE must be divisible by 32"
        GROUP_M = HQ // HKV
        assert GROUP_M * HKV == HQ, "HQ must be divisible by H_kv"
        FP8 = hasattr(torch, "float8_e5m2") and q.dtype == torch.float8_e5m2
        seq_lens_bh = seq_lens_bh.to(torch.int32)
        assert B <= 32767, "too many batches"
        assert global_page_table.shape[1] == HKV
        assert q.is_contiguous()
        assert (D & (D - 1)) == 0, "D must be a power of 2"
        N_LOGICAL_PAGES_MAX = global_page_table.shape[-1]
        sm_scale = 1 / math.sqrt(D) if sm_scale is None else sm_scale

        # Triton element type shared by both kernel launches.
        if FP8:
            tl_dtype = tl.float8e5
        elif q.dtype == torch.bfloat16:
            tl_dtype = tl.bfloat16
        else:
            tl_dtype = tl.float16

        # ``None`` and 0 both mean "choose for me"; a literal 0 would otherwise
        # produce empty scratch buffers and a zero-sized launch grid.
        if not key_split:
            # round max_seq_len to the next power of two to maximize cache hits
            key_split = num_splits_heuristic(
                B * HKV,
                max_seq_len=1 << int(seq_lens_bh.max()).bit_length(),
                num_sms=torch.cuda.get_device_properties(q.device).multi_processor_count,
                max_splits=12,
            )

        maybe_set_allocator(
            lambda size, align, _: torch.empty(size, dtype=torch.int8, device=q.device)
        )

        # stage 1 scratch
        # NOTE(review): the stage-1 kernel returns without writing anything for a
        # (batch, head) whose length is 0, so these ``torch.empty`` buffers keep
        # whatever they held for such heads. Confirm callers never pass a
        # zero-length head on the decode path.
        mid_o = torch.empty((B, key_split, HQ, D), device=q.device, dtype=q.dtype)
        mid_lse = torch.empty((B, key_split, HQ), device=q.device, dtype=torch.float32)

        # processes all queries for a KV head together
        # pointers are lowercase, CONSTANTS are upper
        grid1 = (B, HKV, key_split)
        _varkv_stage1_groupM[grid1](
            q=q,
            k=k,
            v=v,
            mid_o=mid_o,
            mid_lse=mid_lse,
            page_table_bhl=global_page_table,
            batch_mapping=batch_mapping,
            seq_lens_bh=seq_lens_bh.contiguous(),
            SM_SCALE=sm_scale,
            B=B,
            HKV=HKV,
            HQ=HQ,
            CACHE_SIZE=CACHE_SIZE,
            STRIDE_LBS=mid_lse.stride(0),
            STRIDE_LS=mid_lse.stride(1),
            STRIDE_LH=mid_lse.stride(2),
            N_LOGICAL_PAGES_MAX=N_LOGICAL_PAGES_MAX,
            D=D,
            KEY_SPLIT=key_split,
            GROUP_M=GROUP_M,
            DTYPE=tl_dtype,
            PAGE_SIZE=PAGE_SIZE,
        )
        if key_split == 1:
            # A single split is already the final, normalized result.
            return mid_o.squeeze(1).contiguous()

        # reduce partial results across splits
        output = torch.empty_like(q)
        grid2 = (B, HQ)
        _varkv_stage2_reduce[grid2](
            mid_o=mid_o,
            mid_lse=mid_lse,
            output=output,
            STRIDE_LBS=mid_lse.stride(0),
            STRIDE_LS=mid_lse.stride(1),
            STRIDE_LH=mid_lse.stride(2),
            STRIDE_OBS=output.stride(0),
            STRIDE_OH=output.stride(1),
            B=B,
            HQ=HQ,
            D=D,  # type: ignore
            KEY_SPLIT=key_split,  # type: ignore
            DTYPE=tl_dtype,
        )
        return output
# similar to flash attention split heuristic
@functools.lru_cache(maxsize=128)
def num_splits_heuristic(
    total_mblocks: int,
    max_seq_len: int,
    num_sms: int,
    max_splits: int,
) -> int:
    """Pick how many chunks to split the key axis into for decode attention.

    Returns 1 when the launch already has at least 80% as many blocks as SMs or
    when ``max_seq_len`` is at most 1024. Otherwise every split count from 1 up
    to ``min(max_splits, num_sms)`` that leaves more than 512 keys per split is
    scored by wave efficiency (``waves / ceil(waves)``), and the smallest count
    reaching 75% of the best score is returned. Results are memoized.
    """
    # Short sequences or an already busy GPU: splitting buys nothing.
    if max_seq_len <= 1024 or total_mblocks >= 0.8 * num_sms:
        return 1

    efficiencies = []
    for n_splits in range(1, min(max_splits, num_sms) + 1):
        # Stop once a split would cover 512 keys or fewer.
        if (max_seq_len / n_splits) <= 512:
            break
        waves = float(total_mblocks * n_splits) / float(num_sms)
        efficiencies.append(waves / math.ceil(waves) if waves > 0 else 0.0)

    cutoff = 0.75 * max(efficiencies, default=0.0)
    candidates = (
        n_splits
        for n_splits, score in enumerate(efficiencies, start=1)
        if score >= cutoff
    )
    return next(candidates, 1)
def prune_invalid_configs(configs, _, **kwargs):
    """Drop autotune configs whose ``BLOCK_N`` does not tile a KV page.

    Used as the ``early_config_prune`` hook of the decode kernel. A config is
    kept when its ``BLOCK_N`` is no larger than ``PAGE_SIZE`` and divides it
    evenly: the kernel tells the compiler that ``page_id * PAGE_SIZE`` is a
    multiple of ``BLOCK_N``, which only holds when ``BLOCK_N`` divides
    ``PAGE_SIZE``. Configs without a ``BLOCK_N`` entry are always kept.

    If the divisibility requirement would remove every config, the looser
    ``BLOCK_N <= PAGE_SIZE`` selection is returned instead so the autotuner is
    never handed a list that this stricter rule alone emptied.

    Args:
        configs: Candidate configs, each exposing a ``kwargs`` dict.
        _: Unused second positional argument supplied by the autotuner.
        **kwargs: Launch keyword arguments; ``PAGE_SIZE`` is required.

    Raises:
        KeyError: If ``PAGE_SIZE`` is not among the keyword arguments.
    """
    PAGE_SIZE = kwargs["PAGE_SIZE"]

    fitting = []
    aligned = []
    for conf in configs:
        block_n = conf.kwargs.get("BLOCK_N", 0)
        if block_n > PAGE_SIZE:
            continue
        fitting.append(conf)
        # block_n == 0 means "not specified"; nothing to align.
        if not block_n or PAGE_SIZE % block_n == 0:
            aligned.append(conf)
    return aligned or fitting
@triton_autotune(
    configs=[
        triton.Config(
            {"BLOCK_N": BLOCK_N, "MIN_BLOCK_KV": MIN_BLOCK_KV, "WARPSPEC": ws},
            num_warps=w,
            num_stages=s,
        )
        for BLOCK_N in [32, 64, 128]
        for MIN_BLOCK_KV in [8]
        for s in [2, 3, 4]
        for w in [4, 8]
        for ws in [True, False]
    ],
    key=[
        "HKV",
        "GROUP_M",
        "D",
        "PAGE_SIZE",
        # "B"
    ],
    cache_results=True,
    prune_configs_by={"early_config_prune": prune_invalid_configs},
)
@triton.jit
def _varkv_stage1_groupM(
    q,  # [B, HQ, D] contiguous
    k,  # GLOBAL cache: [CACHE_SIZE, D], contiguous
    v,  # GLOBAL cache: [CACHE_SIZE, D], contiguous
    mid_o,  # [B, KEY_SPLIT, HQ, D] contiguous, per-split normalized outputs
    mid_lse,  # [B, KEY_SPLIT, HQ] per-split log-sum-exp (strides passed below)
    page_table_bhl,  # int32 [B*H_kv*N_LOGICAL_PAGES_MAX] (flattened)
    batch_mapping,  # int32 [B] maps local pid_b -> true batch index
    seq_lens_bh,  # int32 [B*H_kv] valid tokens per (b,h)
    SM_SCALE,
    B,
    HKV,
    HQ,
    CACHE_SIZE,  # CACHE_SIZE = N_PAGES * PAGE_SIZE
    STRIDE_LBS,
    STRIDE_LS,
    STRIDE_LH,
    # constexprs
    N_LOGICAL_PAGES_MAX: tl.constexpr,  # page table width per (b,h)
    D: tl.constexpr,
    KEY_SPLIT: tl.constexpr,
    GROUP_M: tl.constexpr,
    DTYPE: tl.constexpr,
    BLOCK_N: tl.constexpr,
    MIN_BLOCK_KV: tl.constexpr,
    WARPSPEC: tl.constexpr,
    PAGE_SIZE: tl.constexpr,
):
    """Stage 1 of split-K decode attention: one program per (batch, kv head, split).

    Each program loads the GROUP_M query heads that share one KV head, walks the
    logical pages covering its slice of the key axis, and keeps a streaming
    softmax (running max, running sum, weighted accumulator). It then writes the
    normalized partial output to ``mid_o`` and its log-sum-exp to ``mid_lse``.
    A split that covers no keys writes zeros and ``-inf`` instead. Programs for
    a (batch, head) of length 0 return without writing anything.
    """
    pid_b = tl.program_id(0)  # batch
    pid_kvh = tl.program_id(1)  # kv head
    pid_s = tl.program_id(2)  # split

    # valid length L for this (b,h)
    bh_stride = HKV
    L = tl.load(seq_lens_bh + pid_b * bh_stride + pid_kvh)
    if L == 0:
        return
    tl.assume(L > 0)

    # split sizing on logical token axis [0..L); each split length is rounded up
    # to a multiple of MIN_BLOCK_KV so split starts stay aligned
    base = tl.cdiv(L, KEY_SPLIT)
    per_split_len = tl.cdiv(base, MIN_BLOCK_KV) * MIN_BLOCK_KV
    split_start = pid_s * per_split_len
    split_end = tl.minimum(split_start + per_split_len, L)

    # query heads mapped to this kv head; the tile is padded to at least 16 rows
    # and the padding rows are masked out via mask_m
    base_qh = pid_kvh * GROUP_M
    GROUP_M_PAD: tl.constexpr = 16 if GROUP_M < 16 else GROUP_M
    offs_m = tl.arange(0, GROUP_M_PAD)
    mask_m = offs_m < GROUP_M
    offs_d = tl.arange(0, D)

    # load Q tile [M, D]
    q_ptrs = q + (pid_b * HQ + base_qh + offs_m)[:, None] * D + offs_d[None, :]
    q = tl.load(q_ptrs, mask=mask_m[:, None], other=0.0).to(DTYPE)  # [M, D]

    # streaming softmax state per query
    e_max = tl.zeros([GROUP_M_PAD], dtype=tl.float32) - float("inf")
    e_sum = tl.zeros([GROUP_M_PAD], dtype=tl.float32)
    acc = tl.zeros([GROUP_M_PAD, D], dtype=tl.float32)

    if split_end > split_start:
        # logical pages covering [split_start, split_end)
        lp0 = split_start // PAGE_SIZE
        lp1 = tl.cdiv(split_end, PAGE_SIZE)  # exclusive

        mapped_b = tl.load(batch_mapping + pid_b)
        tl.assume(mapped_b >= 0)

        # page table base for this (b,h)
        pt_stride = N_LOGICAL_PAGES_MAX
        pt_base = (mapped_b * HKV + pid_kvh) * pt_stride

        for lp in tl.range(lp0, lp1):
            # physical page id; widened to int64 because the element offset
            # page_id * PAGE_SIZE * D can exceed the int32 range on large caches
            phys = tl.load(
                page_table_bhl + pt_base + lp, cache_modifier=".cg"
            ).to(tl.int64)

            # bounds within the logical page
            local_start = tl.where(lp == lp0, split_start - lp * PAGE_SIZE, 0)
            local_end = tl.where(
                lp == (lp1 - 1), split_end - lp * PAGE_SIZE, PAGE_SIZE
            )
            page_base = phys * PAGE_SIZE
            page_base = tl.multiple_of(page_base, BLOCK_N)

            for s in tl.range(local_start, local_end, BLOCK_N):
                s = tl.multiple_of(s, MIN_BLOCK_KV)
                offs_bn = tl.arange(0, BLOCK_N)
                key_idx = page_base + s + offs_bn

                # positions at or beyond local_end are not part of this split
                offs_n = s + offs_bn
                mask_n = offs_n < local_end
                # Mask the loads as well as the logits: a NaN/Inf sitting in a
                # row past local_end would otherwise survive its zero softmax
                # weight (0 * NaN is NaN) and poison the accumulator.
                load_mask = (mask_n & (key_idx < CACHE_SIZE))[:, None]

                k_ptrs = k + key_idx[:, None] * D + offs_d[None, :]
                k_blk = tl.load(k_ptrs, mask=load_mask, other=0.0)
                qk = tl.dot(q, k_blk.T) * SM_SCALE  # [M, BN]
                qk = tl.where(mask_n[None, :], qk, -float("inf"))

                n_e_max = tl.maximum(tl.max(qk, 1), e_max)  # [M]
                re_scale = tl.exp(e_max - n_e_max)  # [M]
                acc = acc * re_scale[:, None]  # [M, D]

                v_ptrs = v + key_idx[:, None] * D + offs_d[None, :]
                v_blk = tl.load(v_ptrs, mask=load_mask, other=0.0)
                p = tl.exp(qk - n_e_max[:, None])  # [M, BN]
                acc = tl.dot(p.to(DTYPE), v_blk, acc)

                e_sum = e_sum * re_scale + tl.sum(p, 1)
                e_max = n_e_max

        # write mid outputs [M, D] for this split
        tmp = (acc / e_sum[:, None]).to(DTYPE)
        row_mid = pid_b * (KEY_SPLIT * HQ) + pid_s * HQ + base_qh + offs_m
        mid_ptrs = mid_o + row_mid[:, None] * D + offs_d[None, :]
        tl.store(mid_ptrs, tmp, mask=mask_m[:, None])

        ml_ptrs = (
            mid_lse
            + pid_b * STRIDE_LBS
            + pid_s * STRIDE_LS
            + (base_qh + offs_m) * STRIDE_LH
        )
        # padding rows are never stored; give them a harmless sum for the log
        safe_sum = tl.where(mask_m, e_sum, 1.0)
        tl.store(ml_ptrs, e_max + tl.log(safe_sum), mask=mask_m)
    else:
        # empty split: zero output and -inf log-sum-exp so stage 2 gives it
        # zero weight
        zero_md = tl.zeros([GROUP_M_PAD, D], dtype=DTYPE)
        row_mid = pid_b * (KEY_SPLIT * HQ) + pid_s * HQ + base_qh + offs_m
        mid_ptrs = mid_o + row_mid[:, None] * D + offs_d[None, :]
        tl.store(mid_ptrs, zero_md, mask=mask_m[:, None])
        ml_ptrs = (
            mid_lse
            + pid_b * STRIDE_LBS
            + pid_s * STRIDE_LS
            + (base_qh + offs_m) * STRIDE_LH
        )
        tl.store(ml_ptrs, -float("inf"), mask=mask_m)
@triton.jit
def _varkv_stage2_reduce(
    mid_o,  # per-split partial outputs, addressed as [B, KEY_SPLIT, HQ, D] contiguous
    mid_lse,  # per-split log-sum-exp, addressed through the STRIDE_L* arguments
    output,  # final result, addressed through STRIDE_OBS / STRIDE_OH (+ offs_d)
    STRIDE_LBS,
    STRIDE_LS,
    STRIDE_LH,
    STRIDE_OBS,
    STRIDE_OH,
    B,
    HQ,
    D: tl.constexpr,
    KEY_SPLIT: tl.constexpr,
    DTYPE: tl.constexpr,
):
    """Stage 2 of split-K decode attention: merge the per-split partial results.

    One program handles one (batch, query head) pair. It walks the KEY_SPLIT
    partial outputs, weights each by ``exp(lse - running_max)`` and divides the
    weighted sum by the sum of weights, then stores the result in ``output``.
    """
    pid_b = tl.program_id(0)
    pid_h = tl.program_id(1)
    offs_d = tl.arange(0, D)
    # across split LSE combine
    e_sum = 0.0
    e_max = -float("inf")
    acc = tl.zeros([D], dtype=tl.float32)
    for s in tl.range(KEY_SPLIT):
        # row of mid_o holding split s for this (batch, head)
        row_mid = pid_b * (KEY_SPLIT * HQ) + s * HQ + pid_h
        tv = tl.load(mid_o + row_mid * D + offs_d).to(DTYPE)
        tl_ptr = mid_lse + pid_b * STRIDE_LBS + s * STRIDE_LS + pid_h * STRIDE_LH
        tlogic = tl.load(tl_ptr)
        # rescale what has been accumulated so far to the new running max
        n_e_max = tl.maximum(e_max, tlogic)
        old_scale = tl.exp(e_max - n_e_max)
        acc = acc * old_scale + tl.exp(tlogic - n_e_max) * tv.to(tl.float32)
        e_sum = e_sum * old_scale + tl.exp(tlogic - n_e_max)
        e_max = n_e_max
    # NOTE(review): if every split stored -inf, e_max stays -inf and the
    # exponent above becomes (-inf) - (-inf); confirm the first split is never
    # empty when this kernel is launched.
    o = (acc / e_sum).to(DTYPE)
    o_ptr = output + pid_b * STRIDE_OBS + pid_h * STRIDE_OH + offs_d
    tl.store(o_ptr, o)
vllm/kvprune_legacy_save/attention/sparse_varlen_kernel.py
0 → 100644
View file @
d29c39ca
import
logging
import
math
import
torch
import
triton
import
triton.language
as
tl
from
vllm.kvprune.utils.triton_compat
import
(
autotune
as
triton_autotune
,
cuda_capability_geq
,
maybe_set_allocator
,
)
logger
=
logging
.
getLogger
(
__name__
)
def causal_sparse_varlen_with_cache(
    q,
    k,
    v,
    k_cache,
    v_cache,
    seq_lens_bh,
    global_page_table,
    batch_mapping,
    cu_seqlens_q,
    max_seqlen_q: int,
    max_seqlen_k_cache: int,
    HKV: int,
    PAGE_SIZE: int,
    sm_scale=None,
):
    """
    Causal prefill attention over a paged KV cache plus a block of newly
    appended tokens in a packed batch format.

    This function wraps the Triton kernel
    ``_causal_head_sparse_varlen_with_cache`` to compute prefill attention for
    a batch of variable-length sequences, where:

      • Past keys/values are stored in a paged global KV cache
        (``k_cache``, ``v_cache``) and indexed via ``global_page_table``.
      • New tokens for this step are given as K/V blocks (``k``, ``v``)
        together with a packed query block ``q``.

    Grouped-query attention (GQA / MQA) is supported: ``HQ`` must be divisible
    by ``HKV``.

    Args:
        q: Packed queries, ``[N, HQ, D]``; ``D`` must be a power of two.
        k: Packed new keys, viewable as ``[N, HKV, D]``.
        v: Packed new values, viewable as ``[N, HKV, D]``.
        k_cache: Global key cache, ``[CACHE_SIZE, D]`` with ``CACHE_SIZE`` a
            multiple of ``PAGE_SIZE``.
        v_cache: Global value cache, same shape as ``k_cache``.
        seq_lens_bh: Cached token count per (sequence, KV head).
        global_page_table: Page table whose last axis is the logical page.
        batch_mapping: Local sequence index -> row of ``global_page_table``.
        cu_seqlens_q: Cumulative query lengths, ``B + 1`` entries with ``B > 0``.
        max_seqlen_q: Longest query segment; sizes the launch grid.
        max_seqlen_k_cache: Longest cached prefix; only used as an autotune key.
        HKV: Number of KV heads.
        PAGE_SIZE: Tokens per physical page.
        sm_scale: Softmax scale; ``None`` means ``1 / sqrt(D)``.

    Returns:
        Attention output of shape ``[N, HQ, D]``.
    """
    assert q.ndim == 3, "q should be [N, HQ, D]"
    N, HQ, D = q.shape
    assert (D & (D - 1)) == 0, "D must be power of two"
    B = cu_seqlens_q.numel() - 1
    assert B > 0
    assert HQ % HKV == 0, "Number of query heads must be divisible by number of KV heads"
    H_g = HQ // HKV

    # view Q as [HKV, N, QUERY_GROUP_SIZE, D]
    out = torch.empty_like(q)
    q = q.view(N, HKV, H_g, D).permute(1, 0, 2, 3)
    out = out.view(N, HKV, H_g, D).permute(1, 0, 2, 3)

    # K_app/V_app: [N, HKV, D] -> [HKV, N, D]
    k_app = k.view(N, HKV, D).permute(1, 0, 2)
    v_app = v.view(N, HKV, D).permute(1, 0, 2)

    cu_seqlens_q = cu_seqlens_q.to(dtype=torch.int32, device=q.device)
    seq_lens_bh = seq_lens_bh.to(dtype=torch.int32, device=q.device)
    # int32, not int16: a 16-bit mapping silently wraps batch rows above 32767,
    # and the kernel's page-table offset starts with ``mapped_b * HKV``, which
    # may be evaluated at the mapping's own width.
    batch_mapping = batch_mapping.to(dtype=torch.int32, device=q.device)

    N_LOGICAL_PAGES_MAX = global_page_table.shape[-1]
    CACHE_SIZE = k_cache.shape[0]
    assert v_cache.shape[0] == CACHE_SIZE
    assert k_cache.shape[1] == D and v_cache.shape[1] == D
    assert PAGE_SIZE > 0 and CACHE_SIZE % PAGE_SIZE == 0

    if sm_scale is None:
        sm_scale = 1.0 / math.sqrt(D)

    # strides for Q [G, N, QUERY_GROUP_SIZE, D]
    STRIDE_Q_G, STRIDE_Q_N, STRIDE_Q_H, STRIDE_Q_D = q.stride()
    STRIDE_KC, STRIDE_VC = k_cache.stride(0), v_cache.stride(0)
    # [G, N, D]
    STRIDE_KA_G, STRIDE_KA_N, STRIDE_KA_D = k_app.stride()
    STRIDE_VA_G, STRIDE_VA_N, STRIDE_VA_D = v_app.stride()
    # OUT [G, N, QUERY_GROUP_SIZE, D]
    STRIDE_OUT_G, STRIDE_OUT_N, STRIDE_OUT_H, STRIDE_OUT_D = out.stride()

    maybe_set_allocator(
        lambda size, align, _: torch.empty(size, dtype=torch.int8, device=q.device)
    )
    # The kernel addresses the last axis with a bare element offset.
    assert STRIDE_KA_D == STRIDE_VA_D == STRIDE_Q_D == STRIDE_OUT_D == 1, (
        "final dimension must be contiguous"
    )

    # launch grid: (kv head, sequence, query tile)
    def grid(META):
        return HKV, B, triton.cdiv(max_seqlen_q, META["BLOCK_M"])

    # On a fresh batch, max_seqlen_k_cache==0 (no KV prefix yet). Passing
    # `triton.next_power_of_2(0)` into autotune constexpr keys breaks
    # kernel selection / tuning and can yield garbage outputs.
    _k_max_autotune = max(int(max_seqlen_k_cache), 1)
    AUTOTUNE_MAX_Q_LEN = triton.next_power_of_2(max_seqlen_q)
    AUTOTUNE_MAX_K_LEN = triton.next_power_of_2(_k_max_autotune)

    _causal_head_sparse_varlen_with_cache[grid](
        Q=q,
        K_cache=k_cache,
        V_cache=v_cache,
        K_app=k_app,
        V_app=v_app,
        cu_seqlens_qk=cu_seqlens_q,
        seq_lens_bh=seq_lens_bh,
        page_table=global_page_table,
        batch_mapping=batch_mapping,
        OUT=out,
        HKV=HKV,
        QUERY_GROUP_SIZE=H_g,
        PAGE_SIZE=PAGE_SIZE,
        N_LOGICAL_PAGES_MAX=N_LOGICAL_PAGES_MAX,
        STRIDE_Q_G=STRIDE_Q_G,
        STRIDE_Q_N=STRIDE_Q_N,
        STRIDE_Q_H=STRIDE_Q_H,
        STRIDE_KC=STRIDE_KC,
        STRIDE_VC=STRIDE_VC,
        STRIDE_KA_G=STRIDE_KA_G,
        STRIDE_KA_N=STRIDE_KA_N,
        STRIDE_VA_G=STRIDE_VA_G,
        STRIDE_VA_N=STRIDE_VA_N,
        STRIDE_OUT_G=STRIDE_OUT_G,
        STRIDE_OUT_N=STRIDE_OUT_N,
        STRIDE_OUT_H=STRIDE_OUT_H,
        sm_scale=sm_scale,
        D=D,
        AUTOTUNE_MAX_Q_LEN=AUTOTUNE_MAX_Q_LEN,
        AUTOTUNE_MAX_K_LEN=AUTOTUNE_MAX_K_LEN,
    )
    # Undo the head-major permutation; the underlying storage is the original
    # [N, HQ, D] buffer, so the view is already contiguous.
    return out.permute(1, 0, 2, 3).view(N, HQ, D)
# Hand-picked autotune candidates returned by get_autotune_configs() when the
# device reports compute capability >= 9.0.
# Each row is (BLOCK_N, BLOCK_M, WARPSPEC, num_warps, num_stages).
autotune_configs_cc9 = [
    triton.Config(
        {"BLOCK_N": block_n, "BLOCK_M": block_m, "WARPSPEC": warpspec},
        num_warps=warps,
        num_stages=stages,
    )
    for block_n, block_m, warpspec, warps, stages in (
        (64, 64, True, 16, 3),
        (64, 64, True, 8, 3),
        (64, 32, True, 8, 4),
        (64, 32, True, 8, 3),
        (64, 32, False, 4, 3),
        (64, 16, True, 8, 3),
        (64, 16, True, 8, 4),
        (64, 16, False, 4, 4),
        (32, 32, True, 8, 4),
        (32, 32, False, 8, 4),
        (32, 16, False, 8, 3),
        (32, 16, False, 4, 4),
    )
]
# Autotune candidates returned by get_autotune_configs() for devices below
# compute capability 9.0: BLOCK_M is fixed at 64 and WARPSPEC is always True;
# BLOCK_N, num_warps and num_stages are swept (in that nesting order).
autotune_configs_cc8 = [
    triton.Config(
        {"BLOCK_N": block_n, "BLOCK_M": 64, "WARPSPEC": True},
        num_warps=warps,
        num_stages=stages,
    )
    for block_n in (16, 32)
    for warps in (4, 8)
    for stages in (2, 3)
]
def prune_invalid_configs(configs, _, **kwargs):
    """Drop configs that combine ``BLOCK_N == 32`` with four pipeline stages.

    ``triton.Config`` keeps ``num_stages`` as an attribute of the config, not
    inside its ``kwargs`` dict (which holds only the kernel meta-parameters such
    as ``BLOCK_N``). The stage count is therefore read from the attribute; the
    ``kwargs`` entry is consulted only when the attribute is absent.

    Args:
        configs: Candidate configs, each exposing a ``kwargs`` dict.
        _: Unused second positional argument supplied by the autotuner.
        **kwargs: Launch keyword arguments (unused).

    Returns:
        The configs that survive the filter, in their original order.
    """

    def _stages(conf):
        # Attribute first; fall back to kwargs for config-like objects that
        # carry the value there.
        stages = getattr(conf, "num_stages", None)
        return conf.kwargs.get("num_stages") if stages is None else stages

    return [
        conf
        for conf in configs
        if not (conf.kwargs.get("BLOCK_N") == 32 and _stages(conf) == 4)
    ]
def get_autotune_configs():
    """Return the autotune candidate list matching the GPU's compute capability.

    ``autotune_configs_cc9`` is used when ``cuda_capability_geq(9, 0)`` is true,
    ``autotune_configs_cc8`` otherwise.
    """
    use_cc9 = cuda_capability_geq(9, 0)
    return autotune_configs_cc9 if use_cc9 else autotune_configs_cc8
@triton_autotune(
    configs=get_autotune_configs(),
    key=[
        "HKV",
        "QUERY_GROUP_SIZE",
        "D",
        "PAGE_SIZE",
        "AUTOTUNE_MAX_K_LEN",
        "AUTOTUNE_MAX_Q_LEN",
    ],
    cache_results=True,
)
@triton.jit
def _causal_head_sparse_varlen_with_cache(
    Q,  # [HKV, N, QUERY_GROUP_SIZE, D] (non-contiguous)
    K_cache,
    V_cache,  # [CACHE_SIZE, D]
    K_app,
    V_app,  # [HKV, N, D]
    cu_seqlens_qk,  # [B+1]
    seq_lens_bh,  # [B, HKV]
    page_table,  # [B_total, HKV, N_LOGICAL_PAGES_MAX]
    batch_mapping,  # [B], maps local b -> global batch index
    OUT,  # [HKV, N, QUERY_GROUP_SIZE, D]
    #
    HKV: tl.constexpr,
    QUERY_GROUP_SIZE: tl.constexpr,
    PAGE_SIZE: tl.constexpr,
    N_LOGICAL_PAGES_MAX,
    STRIDE_Q_G,
    STRIDE_Q_N,
    STRIDE_Q_H,
    STRIDE_KC,
    STRIDE_VC,
    STRIDE_KA_G,
    STRIDE_KA_N,
    STRIDE_VA_G,
    STRIDE_VA_N,
    STRIDE_OUT_G,
    STRIDE_OUT_N,
    STRIDE_OUT_H,
    sm_scale,
    #
    D: tl.constexpr,
    BLOCK_M: tl.constexpr,
    BLOCK_N: tl.constexpr,
    WARPSPEC: tl.constexpr,
    AUTOTUNE_MAX_Q_LEN: tl.constexpr,  # used for autotune key
    AUTOTUNE_MAX_K_LEN: tl.constexpr,  # used for autotune key
):
    """Causal prefill attention for one (kv head, sequence, query tile).

    The program loads a tile of up to BLOCK_M query positions for all
    QUERY_GROUP_SIZE query heads that share KV head ``pid_g``, then runs a
    streaming softmax over (1) the cached keys of that head, read page by page
    through ``page_table``, and (2) the newly appended keys of the same
    sequence, with a causal mask applied only where keys overlap the tile.
    Logits are pre-multiplied by log2(e) so ``exp2`` can be used throughout.
    """
    TOTAL_N_QUERIES: tl.constexpr = BLOCK_M * QUERY_GROUP_SIZE
    pid_g = tl.program_id(0)  # kv_head id in [0, HKV)
    pid_b = tl.program_id(1)  # batch id
    pid_m = tl.program_id(2)  # query-tile id within batch

    # batch segment [qb, qe) in N
    off_b = tl.load(cu_seqlens_qk + pid_b)
    off_b1 = tl.load(cu_seqlens_qk + pid_b + 1)
    seq_len_append = off_b1 - off_b
    q_start = off_b + pid_m * BLOCK_M
    q_end = tl.minimum(q_start + BLOCK_M, off_b1)

    # number of queries in this tile for this batch
    M = q_end - q_start
    if M <= 0:
        # the grid is sized for the longest sequence; shorter ones have no
        # work for the trailing tiles
        return

    # cached length for (b, kv_head=pid_g)
    L_cache = tl.load(seq_lens_bh + pid_b * HKV + pid_g)

    # row indices flattened over [QUERY_GROUP_SIZE, M]
    offs_row = tl.arange(0, TOTAL_N_QUERIES)
    row_m = offs_row % BLOCK_M
    row_h = offs_row // BLOCK_M
    # valid rows: only those with row_m < M
    row_mask = row_m < M
    # global query index per row
    q_idx = q_start + row_m
    offs_d = tl.arange(0, D)

    # Q tile: [TOTAL_N_QUERIES, D]
    # Q layout: [HKV, N, QUERY_GROUP_SIZE, D]
    q_ptrs = (
        Q
        + pid_g * STRIDE_Q_G
        + q_idx[:, None] * STRIDE_Q_N
        + row_h[:, None] * STRIDE_Q_H
        + offs_d[None, :]
    )
    q = tl.load(q_ptrs, mask=row_mask[:, None], other=0.0)

    # streaming softmax state: running max, running sum, weighted accumulator
    e_max = tl.zeros([TOTAL_N_QUERIES], dtype=tl.float32) - float("inf")
    e_sum = tl.zeros([TOTAL_N_QUERIES], dtype=tl.float32)
    acc = tl.zeros([TOTAL_N_QUERIES, D], dtype=tl.float32)
    offs_block_n = tl.arange(0, BLOCK_N)
    # 1.44269504 ~= log2(e): fold the base change into the scale so exp2 can
    # replace exp below
    qk_scale = sm_scale * 1.44269504

    # 1) attend over cached K/V
    if L_cache > 0:
        # map local (b) to global batch index
        mapped_b = tl.load(batch_mapping + pid_b)
        pt_base = (mapped_b * HKV + pid_g) * N_LOGICAL_PAGES_MAX
        # iterate logical pages
        num_lp = tl.cdiv(L_cache, PAGE_SIZE)
        for lp in tl.range(0, num_lp):
            # can overflow in 32 bits so upcast
            phys = tl.load(page_table + pt_base + lp).to(tl.int64)
            page_start = phys * PAGE_SIZE
            # how many valid tokens in this page for this (b,g)
            remain = L_cache - lp * PAGE_SIZE
            page_len = tl.minimum(PAGE_SIZE, remain)
            # iterate over this page in BLOCK_N chunks
            for ks in tl.range(0, page_len, BLOCK_N):
                offs_n = ks + offs_block_n
                mask_n = offs_n < page_len
                key_idx = page_start + offs_n
                k_ptrs = K_cache + key_idx[:, None] * STRIDE_KC + offs_d[None, :]
                k = tl.load(k_ptrs, mask=mask_n[:, None], other=0.0)  # [BN, D]
                qk = tl.dot(q, k.T) * qk_scale  # [TOTAL_N_QUERIES, BN]
                # masked entries get a large negative finite value, not -inf
                qk = tl.where(row_mask[:, None] & mask_n[None, :], qk, -1.0e6)
                # softmax update
                cur_max = tl.max(qk, 1)
                n_e_max = tl.maximum(e_max, cur_max)
                re_scale = tl.math.exp2(e_max - n_e_max)
                p = tl.math.exp2(qk - n_e_max[:, None])
                v_ptrs = V_cache + key_idx[:, None] * STRIDE_VC + offs_d[None, :]
                v = tl.load(v_ptrs, mask=mask_n[:, None], other=0.0)  # [BN, D]
                acc = acc * re_scale[:, None]
                acc = tl.dot(p.to(v.dtype), v, acc)
                e_sum = e_sum * re_scale + tl.sum(p, 1)
                e_max = n_e_max

    # 2) attend over appended K_app/V_app (causal)
    # appended tokens for batch b are in [off_b, off_b1)
    # query tile is [q_start, q_end)
    # for each query at index q_idx, valid appended keys k satisfy off_b <= k <= q_idx
    if q_end > off_b:
        # exactly one appended token
        if seq_len_append == 1:
            # single key/value row: use a broadcast multiply instead of tl.dot
            ka_ptrs = K_app + pid_g * STRIDE_KA_G + off_b * STRIDE_KA_N + offs_d
            k = tl.load(ka_ptrs)  # [D]
            qk = tl.sum(q * k[None, :], 1) * qk_scale
            qk = tl.where(row_mask, qk, -1.0e6)
            n_e_max = tl.maximum(e_max, qk)
            re_scale = tl.math.exp2(e_max - n_e_max)
            p = tl.math.exp2(qk - n_e_max)
            va_ptrs = V_app + pid_g * STRIDE_VA_G + off_b * STRIDE_VA_N + offs_d
            v = tl.load(va_ptrs)  # [D]
            acc = acc * re_scale[:, None] + p[:, None] * v[None, :]
            e_sum = e_sum * re_scale + p
            # e_max is not refreshed here; nothing after this branch reads it
        else:
            # off-band: k in [off_b, q_start)
            # for all queries t in [q_start, q_end), any k < q_start satisfies k <= t.
            # so no causal mask needed.
            off_band_start = off_b
            off_band_end = q_start
            if off_band_end > off_band_start:
                for ks in tl.range(off_band_start, off_band_end, BLOCK_N):
                    offs_n = ks + offs_block_n
                    mask_n = offs_n < off_band_end
                    ka_ptrs = (
                        K_app
                        + pid_g * STRIDE_KA_G
                        + offs_n[:, None] * STRIDE_KA_N
                        + offs_d[None, :]
                    )
                    k = tl.load(ka_ptrs, mask=mask_n[:, None], other=0.0)
                    qk = tl.dot(q, k.T) * qk_scale
                    qk = tl.where(row_mask[:, None] & mask_n[None, :], qk, -1.0e6)
                    cur_max = tl.max(qk, 1)
                    n_e_max = tl.maximum(e_max, cur_max)
                    re_scale = tl.math.exp2(e_max - n_e_max)
                    p = tl.math.exp2(qk - n_e_max[:, None])
                    va_ptrs = (
                        V_app
                        + pid_g * STRIDE_VA_G
                        + offs_n[:, None] * STRIDE_VA_N
                        + offs_d[None, :]
                    )
                    v = tl.load(va_ptrs, mask=mask_n[:, None], other=0.0)
                    acc = acc * re_scale[:, None]
                    acc = tl.dot(p.to(v.dtype), v, acc)
                    e_sum = e_sum * re_scale + tl.sum(p, 1)
                    e_max = n_e_max

            # on-band remaining k: keys inside the tile itself, which need the
            # per-row causal mask
            on_band_start = tl.maximum(q_start, off_b)
            if on_band_start < q_end:
                for ks in tl.range(on_band_start, q_end, BLOCK_N):
                    offs_n = ks + tl.arange(0, BLOCK_N)
                    mask_n = offs_n < q_end
                    ka_ptrs = (
                        K_app
                        + pid_g * STRIDE_KA_G
                        + offs_n[:, None] * STRIDE_KA_N
                        + offs_d[None, :]
                    )
                    k = tl.load(ka_ptrs, mask=mask_n[:, None], other=0.0)
                    qk = tl.dot(q, k.T) * qk_scale
                    # a key is visible only to queries at the same or a later index
                    caus_mask = offs_n[None, :] <= q_idx[:, None]
                    full_mask = row_mask[:, None] & mask_n[None, :] & caus_mask
                    qk = tl.where(full_mask, qk, -1.0e6)
                    cur_max = tl.max(qk, 1)
                    n_e_max = tl.maximum(e_max, cur_max)
                    re_scale = tl.math.exp2(e_max - n_e_max)
                    p = tl.math.exp2(qk - n_e_max[:, None])
                    va_ptrs = (
                        V_app
                        + pid_g * STRIDE_VA_G
                        + offs_n[:, None] * STRIDE_VA_N
                        + offs_d[None, :]
                    )
                    v = tl.load(va_ptrs, mask=mask_n[:, None], other=0.0)
                    acc = acc * re_scale[:, None]
                    acc = tl.dot(p.to(v.dtype), v, acc)
                    e_sum = e_sum * re_scale + tl.sum(p, 1)
                    e_max = n_e_max

    # 3) write outputs (normalize by the softmax denominator, cast to Q's dtype)
    o = (acc / e_sum[:, None]).to(q.dtype)
    out_ptrs = (
        OUT
        + pid_g * STRIDE_OUT_G
        + q_idx[:, None] * STRIDE_OUT_N
        + row_h[:, None] * STRIDE_OUT_H
        + offs_d[None, :]
    )
    tl.store(out_ptrs, o, mask=row_mask[:, None])
vllm/kvprune_legacy_save/benchmark/__init__.py
0 → 100644
View file @
d29c39ca
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""
Benchmark helpers for kv-prune / compactor kernels.
Upstream snapshot (``compactor-vllm/src/compactor_vllm/benchmark``) contained **only**
an empty ``__init__.py`` — no additional ``.py`` scripts. Those files are merged here
as-is; there is nothing else to list under that directory in upstream.
Use :data:`BENCHMARK_REGISTRY` to register microbenchmarks or CLI entrypoints you
add under ``vllm.kvprune.benchmark``.
"""
from
__future__
import
annotations
from
typing
import
Any
,
Callable
# Files copied from upstream ``compactor_vllm/benchmark/`` (relative to that dir).
UPSTREAM_BENCHMARK_FILES: tuple[str, ...] = ("__init__.py",)

# Optional: name -> benchmark callable or import path string (e.g. "mymod:main").
# Populated when you add real benchmarks beside this package.
BENCHMARK_REGISTRY: dict[str, Callable[..., Any] | str] = {}


def list_upstream_benchmark_files() -> tuple[str, ...]:
    """Return the filenames recorded in :data:`UPSTREAM_BENCHMARK_FILES`."""
    upstream = UPSTREAM_BENCHMARK_FILES
    return upstream


def register_benchmark(name: str, target: Callable[..., Any] | str) -> None:
    """Store ``target`` in :data:`BENCHMARK_REGISTRY` under ``name``.

    ``target`` is either a callable or a ``"module:attr"`` import path string.
    Registering an existing name replaces its previous target.
    """
    BENCHMARK_REGISTRY.update({name: target})


def iter_registered_benchmarks() -> list[tuple[str, Callable[..., Any] | str]]:
    """Return a snapshot list of ``(name, target)`` pairs from :data:`BENCHMARK_REGISTRY`."""
    return [(name, target) for name, target in BENCHMARK_REGISTRY.items()]
__all__
=
[
"BENCHMARK_REGISTRY"
,
"UPSTREAM_BENCHMARK_FILES"
,
"iter_registered_benchmarks"
,
"list_upstream_benchmark_files"
,
"register_benchmark"
,
]
vllm/kvprune_legacy_save/compactor_porting_status.py
0 → 100644
View file @
d29c39ca
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""
Layout notes: ``vllm/compactor-vllm/src/compactor_vllm`` (or sibling tree) →
``vllm.kvprune.<subdir>``.
The upstream tree is merged into parallel subpackages under ``vllm/kvprune/``
(``attention``, ``kv_cache``, ``compression``, ``config``, ``core``, ``layers``,
``models``, ``triton_kernels``, ``utils``, ``benchmark``). Imports use
``from vllm.kvprune.<module>.*``.
v1 integration (FlashAttention, ``gpu_model_runner``) lives in
``core.runtime``, ``core.flash_integration``, and ``compression/prefill.py``.
**Note:** filenames with hyphens under ``compression/`` are not importable as
Python modules; rename or load via ``importlib`` if needed.
**TP / embedding in vLLM workers:** upstream compactor-vllm used only
``vllm.kvprune`` ``ParallelLMHead`` + ``dist.gather``. When embedded in v1 workers,
prefer ``delegate_kvprune_embed_tokens_to_vllm`` and
``delegate_kvprune_compute_logits_to_vllm`` so token masking and logits match
``vocab_parallel_embedding`` + ``LogitsProcessor`` (garbled text often came from
TP gather / padded-vocab handling, not from the transformer body).
"""
from
__future__
import
annotations
import
pathlib
def kvprune_root() -> pathlib.Path:
    """Return the absolute, fully resolved directory that contains this module."""
    module_file = pathlib.Path(__file__).resolve()
    return module_file.parent
def list_py_files() -> list[str]:
    """Return sorted, ``/``-separated relative paths of all ``.py`` files under the root.

    The root is whatever :func:`kvprune_root` returns. Files located inside a
    ``__pycache__`` directory below that root are skipped.
    """
    root = kvprune_root()
    found = []
    for path in root.rglob("*.py"):
        rel = path.relative_to(root)
        # Test only the part below the root, so a ``__pycache__`` component in
        # the root's own location cannot hide every file.
        if "__pycache__" in rel.parts:
            continue
        # as_posix() converts path separators only; a plain string replace of
        # "\\" would also rewrite backslashes that are part of a POSIX file name.
        found.append(rel.as_posix())
    return sorted(found)
def format_layout_report() -> str:
    """Render a plain-text listing of the files reported by :func:`list_py_files`.

    The report has a title line, the total file count, a 50-character ``=``
    rule and then at most the first 250 paths. When more files exist, a final
    line states how many were left out. Lines are joined with newlines.
    """
    paths = list_py_files()
    total = len(paths)
    omitted = total - 250

    report = [
        "vllm.kvprune — merged compactor layout",
        f"python file count: {total}",
        "=" * 50,
    ]
    report.extend(paths[:250])
    if omitted > 0:
        report.append(f"... and {omitted} more")
    return "\n".join(report)
vllm/kvprune_legacy_save/compression/__init__.py
0 → 100644
View file @
d29c39ca
from
vllm.kvprune.compression.common
import
(
BaseCompressionMethod
,
NoCompression
,
)
from
vllm.kvprune.compression.criticalkv
import
CriticalAdaKVCompression
from
vllm.kvprune.compression.compactor
import
CompactorCompression
from
vllm.kvprune.compression.compression_config
import
(
BatchCompressionParams
,
CompressionMethod
,
SequenceCompressionParams
,
)
from
vllm.kvprune.compression.snapkv
import
SnapKVCompression
# Lookup table from each CompressionMethod member to the class implementing it.
# apply_prerope_compression / apply_postrope_compression index this dict with
# ``context.compression_context.compression_method`` and call the scoring
# methods directly on the class (no instance is created).
COMPRESSION_REGISTRY: dict[CompressionMethod, type[BaseCompressionMethod]] = {
    CompressionMethod.CRITICALADAKV: CriticalAdaKVCompression,
    CompressionMethod.COMPACTOR: CompactorCompression,
    CompressionMethod.SNAPKV: SnapKVCompression,
    CompressionMethod.NONE: NoCompression,
}
def apply_prerope_compression(q, k, v, context):
    """Run the pre-RoPE scoring phase of the compression method chosen by ``context``.

    The method is read from ``context.compression_context.compression_method``
    and looked up in ``COMPRESSION_REGISTRY``; the return value of that class's
    ``pre_rope_scoring(q, k, v, context=context)`` is passed through unchanged.

    Raises:
        KeyError: If the selected method has no entry in ``COMPRESSION_REGISTRY``.
    """
    scorer = COMPRESSION_REGISTRY[context.compression_context.compression_method]
    return scorer.pre_rope_scoring(q, k, v, context=context)
def apply_postrope_compression(q, k, v, prerope_scores, context):
    """Run the post-RoPE scoring phase of the compression method chosen by ``context``.

    The method is read from ``context.compression_context.compression_method``
    and looked up in ``COMPRESSION_REGISTRY``; ``prerope_scores`` is forwarded
    positionally to that class's ``post_rope_scoring`` and its return value is
    passed through unchanged.

    Raises:
        KeyError: If the selected method has no entry in ``COMPRESSION_REGISTRY``.
    """
    scorer = COMPRESSION_REGISTRY[context.compression_context.compression_method]
    return scorer.post_rope_scoring(q, k, v, prerope_scores, context=context)
__all__
=
[
"apply_prerope_compression"
,
"apply_postrope_compression"
,
"CompressionMethod"
,
"BatchCompressionParams"
,
"SequenceCompressionParams"
,
"COMPRESSION_REGISTRY"
]
vllm/kvprune_legacy_save/compression/common.py
0 → 100644
View file @
d29c39ca
from
abc
import
ABC
,
abstractmethod
from
typing
import
Optional
import
torch
from
vllm.kvprune.kv_cache.store_kv_cache
import
prefill_store_topk_kv
class BaseCompressionMethod(ABC):
    """
    Abstract interface for KV cache compression methods.

    A compression method is implemented as a pair of optional scoring phases
    that run before and after rotary position embedding (RoPE) is applied:

    1. ``pre_rope_scoring`` operates on pre-RoPE Q/K.
    2. ``post_rope_scoring`` operates on post-RoPE Q/K and can either:
         - refine / reweight the pre-RoPE scores, or
         - compute new, potentially position-aware scores.

    Concrete subclasses are expected to implement both static methods and
    return a single tensor of scores (or ``None`` if the phase is a no-op),
    which the caller can then feed into the shared
    "scores -> top-k indices -> KV extraction" pipeline.
    """

    @staticmethod
    @abstractmethod
    def pre_rope_scoring(
        q: torch.Tensor,
        k: torch.Tensor,
        v: torch.Tensor,
        context,
    ) -> Optional[torch.Tensor]:
        """
        Compute per-token importance scores from pre-RoPE queries/keys.

        Args:
            :param q:
                Pre-RoPE query tensor. Shape ``[total_tokens, HQ, D]``.
            :param k:
                Pre-RoPE key tensor. Shape ``[total_tokens, HKV, D]``.
            :param v:
                Value tensor. Shape ``[total_tokens, HKV, D]``.
            :param context:
                Context object carrying additional metadata, such as batch
                mappings or temporary buffers. (Originally documented as
                ``compactor_vllm.utils.context.Context``.
                NOTE(review): confirm the current module path.)
        Returns:
            :return Optional[torch.Tensor]:
                A tensor of scores (e.g. per-token, per-head importance
                values) to be passed to ``post_rope_scoring`` or directly
                into the top-k selection step. If this phase is a no-op,
                implementations should return ``None``.
                Shape ``[total_tokens, HKV]``.
        """
        pass

    @staticmethod
    @abstractmethod
    def post_rope_scoring(
        q: torch.Tensor,
        k: torch.Tensor,
        v: torch.Tensor,
        pre_rope_scores: Optional[torch.Tensor],
        context,
    ) -> Optional[torch.Tensor]:
        """
        Compute or refine importance scores from post-RoPE queries/keys.

        This method is called after rotary embeddings have been applied. It
        can optionally use both the post-RoPE Q/K and any scores produced by
        ``pre_rope_scoring`` to produce final scores used for token selection.

        Common patterns include:
          * Using ``pre_rope_scores`` as a base signal and applying a
            position-aware correction.
          * Only computing scores that depend on absolute or relative
            positions.
          * Simply passing through ``pre_rope_scores`` unchanged.

        Args:
            :param q:
                Post-RoPE query tensor. Shape ``[total_tokens, HQ, D]``.
            :param k:
                Post-RoPE key tensor. Shape ``[total_tokens, HKV, D]``.
            :param v:
                Value tensor. Shape ``[total_tokens, HKV, D]``.
            :param pre_rope_scores:
                Optional scores returned by ``pre_rope_scoring``. May be
                ``None`` if the pre-RoPE phase returned ``None``.
            :param context:
                Context object carrying additional metadata, such as batch
                mappings or temporary buffers. (Originally documented as
                ``compactor_vllm.utils.context.Context``.
                NOTE(review): confirm the current module path.)
        Returns:
            :return Optional[torch.Tensor]:
                Final importance scores to be consumed by the compression
                pipeline (for top-k token selection). If this phase is a
                no-op, implementations may return ``pre_rope_scores``. If
                ``None`` is returned, no compression will be applied.
        """
        pass
class NoCompression(BaseCompressionMethod):
    """
    Trivial compression method that disables KV cache compression.

    Neither phase computes anything: the pre-RoPE phase returns ``None`` and
    the post-RoPE phase hands back whatever it was given.
    """

    @staticmethod
    def pre_rope_scoring(
        q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, context
    ) -> Optional[torch.Tensor]:
        """Return ``None``; no pre-RoPE scores are produced."""
        return None

    @staticmethod
    def post_rope_scoring(
        q: torch.Tensor,
        k: torch.Tensor,
        v: torch.Tensor,
        pre_rope_scores: Optional[torch.Tensor],
        context,
    ) -> Optional[torch.Tensor]:
        """Return ``pre_rope_scores`` unchanged (``None`` when paired with
        this class's own ``pre_rope_scoring``)."""
        return pre_rope_scores
def extract_and_store_top_kv(
    scores: torch.Tensor,
    cu_seqlens_k: torch.Tensor,
    max_k_len: int,
    top_k: int,
    H: int,
    new_keys: torch.Tensor,  # [N_total, H, D]
    new_vals: torch.Tensor,  # [N_total, H, D]
    num_tokens_to_retain: torch.Tensor,  # [B] int32
    page_table: torch.Tensor,  # [B_total, H, N_LOGICAL_PAGES_MAX] int32
    batch_mapping: torch.Tensor,  # [B] int32 (local -> true batch rows)
    bh_lens: torch.Tensor,  # [B, H] int32 (contiguous), UPDATED atomically
    k_cache: torch.Tensor,  # [N_PAGES * PAGE_SIZE, D]
    v_cache: torch.Tensor,  # [N_PAGES * PAGE_SIZE, D]
    PAGE_SIZE: int,
    PAD_TO_PAGE_SIZE: bool = True,
    K_TILE: int = 16,
    padding: float = -float("inf"),
):
    """Select the top-k (token, head) entries and write them to the KV cache.

    Convenience wrapper that chains ``scores_to_retain_indices`` and
    ``prefill_store_topk_kv`` so both steps can be submitted as one unit
    (e.g. on a single stream). The shape notes on the parameters are the
    ones given in the signature comments.

    :param scores: Packed scores passed to ``scores_to_retain_indices``.
    :param cu_seqlens_k: Cumulative key lengths; used by both steps.
    :param max_k_len: Upper bound on the per-sequence key length.
    :param top_k: Number of entries to select per sequence.
    :param H: Number of heads in ``scores``.
    :param padding: Fill value for the padded score buffer.
    :param new_keys, new_vals, num_tokens_to_retain, page_table,
        batch_mapping, bh_lens, k_cache, v_cache, PAGE_SIZE,
        PAD_TO_PAGE_SIZE, K_TILE: Forwarded verbatim to
        ``prefill_store_topk_kv``.
    :return: ``None``; the result is the side effect of the store step.
    """
    selected = scores_to_retain_indices(
        scores,
        cu_seqlens_k=cu_seqlens_k,
        max_k_len=max_k_len,
        top_k=top_k,
        H=H,
        padding=padding,
    )
    store_kwargs = dict(
        new_keys=new_keys,
        new_vals=new_vals,
        indices_topk=selected,
        num_tokens_to_retain=num_tokens_to_retain,
        page_table=page_table,
        batch_mapping=batch_mapping,
        bh_lens=bh_lens,
        k_cache=k_cache,
        v_cache=v_cache,
        cu_seqlens_k=cu_seqlens_k,
        PAGE_SIZE=PAGE_SIZE,
        PAD_TO_PAGE_SIZE=PAD_TO_PAGE_SIZE,
        K_TILE=K_TILE,
    )
    prefill_store_topk_kv(**store_kwargs)
def scores_to_retain_indices(
    scores: torch.Tensor,
    cu_seqlens_k: torch.Tensor,
    max_k_len: int,
    top_k: int,
    H: int,
    padding: float = -float("inf"),
) -> torch.Tensor:
    """
    Select global top-k token-head indices per sequence from packed scores.

    This helper takes per-token, per-head scores in packed varlen form and
    returns, for each batch element, the indices of the top-k (token, head)
    pairs in the flattened global layout.

    Inputs follow the usual packed varlen convention:
      * ``scores`` is laid out as ``[N_total, H]`` where
        ``N_total = sum_b seqlen_k[b]`` and ``H`` is the number of KV heads.
      * ``cu_seqlens_k`` is ``[B + 1]`` with cumulative key lengths, so
        ``seqlen_k[b] = cu_seqlens_k[b + 1] - cu_seqlens_k[b]``.
      * ``max_k_len`` is an upper bound on ``seqlen_k[b]`` across the batch.

    Each sequence is copied into a ``[B, max_k_len, H]`` buffer pre-filled
    with ``padding``, the buffer is viewed as ``[B, max_k_len * H]`` and a
    per-row top-k is taken. The row-local indices are then shifted by
    ``cu_seqlens_k[b] * H`` so they index ``scores.view(-1)`` directly:

        global_index = token_global_offset * H + head_index

    Note that when a sequence has fewer than ``k_eff`` real entries, the
    trailing slots of its row refer to padding; after the shift those
    indices fall outside the sequence, so callers must only consume as many
    entries per row as the sequence really has.

    Args:
        :param scores:
            Tensor of shape ``[N_total, H]`` with one score per
            (token, head) pair in packed varlen format.
        :param cu_seqlens_k:
            1-D integer tensor of shape ``[B + 1]`` with cumulative key
            sequence lengths.
        :param max_k_len:
            Maximum key sequence length across the batch; sizes the padded
            buffer.
        :param top_k:
            Number of (token, head) entries to retain per batch element.
            Clamped to ``max_k_len * H``.
        :param H:
            Number of key heads; must match ``scores.shape[1]``.
        :param padding:
            Fill value for positions past the end of a sequence. Defaults to
            ``-inf`` so padding ranks below every real score.
    Returns:
        :return torch.Tensor:
            Integer tensor of shape ``[B, k_eff]`` with
            ``k_eff = min(top_k, max_k_len * H)``; each row is sorted by
            descending score.
    """
    B = cu_seqlens_k.numel() - 1
    padded = torch.full(
        (B, max_k_len, H),
        fill_value=padding,
        dtype=scores.dtype,
        device=scores.device,
    )
    # Read every boundary with a single device->host transfer rather than
    # converting two tensor elements per sequence inside the loop.
    bounds = cu_seqlens_k.tolist()
    for b, (start, end) in enumerate(zip(bounds, bounds[1:])):
        padded[b, : end - start, :].copy_(scores[start:end, :])

    flat = padded.view(B, max_k_len * H)
    k_eff = min(top_k, max_k_len * H)
    idx = torch.topk(flat, k=k_eff, dim=1, largest=True, sorted=True).indices
    # Shift row-local positions to offsets into ``scores.view(-1)``.
    return idx + (cu_seqlens_k[:-1] * H).unsqueeze(-1)
vllm/kvprune_legacy_save/compression/compactor.py
0 → 100644
View file @
d29c39ca
"""
Compactor 压缩:与 kvpress ``CompactorPress`` / ``LeverageScorePress`` / ``NonCausalAttnPress``
算法对齐(Cholesky 杠杆分、右高斯 sketch、非因果分块注意力无 1/sqrt(d) 缩放、×||V||、avg_pool、
全局 z-score、blending 与首尾 sink pad)。
非因果分块注意力与 ``×||V||``+``avg_pool1d(k=3)`` 在 CUDA 上为 Triton;非 CUDA 回退 PyTorch。
"""
from
__future__
import
annotations
import
math
from
typing
import
List
,
Optional
import
torch
import
triton
import
triton.language
as
tl
from
transformers.models.llama.modeling_llama
import
repeat_kv
from
vllm.kvprune.compression.common
import
BaseCompressionMethod
from
vllm.kvprune.utils.context
import
get_context
from
vllm.kvprune.utils.helpers
import
maybe_execute_in_stream
def resolve_kvpress_compactor_blending(compression_context) -> float:
    """Resolve the blending weight used by the Compactor post-RoPE step.

    Resolution order (per the original note, this follows kvpress
    ``CompactorPress.score``): the context's ``compactor_blending`` if set,
    otherwise its ``compression_ratio`` if set, otherwise ``0.35``.

    :param compression_context: Object that may carry ``compactor_blending``
        and/or ``compression_ratio`` attributes, or ``None``.
    :return: The resolved weight as a ``float``.
    """
    fallback = 0.35
    if compression_context is None:
        return fallback
    for attr_name in ("compactor_blending", "compression_ratio"):
        candidate = getattr(compression_context, attr_name, None)
        if candidate is not None:
            return float(candidate)
    return fallback
class CompactorCompression(BaseCompressionMethod):
    """Compactor compression built on the kvpress-style scoring helpers.

    Per the original note, the default ``chunk_size=256`` matches kvpress
    ``CompactorPress`` / ``NonCausalAttnPress``.
    """

    chunk_size: int = 256

    @staticmethod
    def pre_rope_scoring(
        q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, context
    ) -> Optional[torch.Tensor]:
        """Compute leverage scores for the packed pre-RoPE keys.

        Delegates to ``kvpress_leverage_scores_packed`` through
        ``maybe_execute_in_stream`` with ``STORE_STREAM=None``. ``q`` and
        ``v`` are not used.

        NOTE(review): the host-side boundaries come from ``get_context()``
        rather than from the ``context`` argument — confirm both refer to
        the same batch.
        """
        comp_ctx = context.compression_context
        # Key rows are indexed by the K packed layout. An explicit ``None``
        # test is required: these are tensors, so ``a or b`` is not valid.
        device_bounds = getattr(context, "cu_seqlens_k", None)
        if device_bounds is None:
            device_bounds = context.cu_seqlens_q
        active_ctx = get_context()
        host_bounds = active_ctx.cu_seqlens_k_host
        if host_bounds is None:
            host_bounds = active_ctx.cu_seqlens_q_host
        return maybe_execute_in_stream(
            kvpress_leverage_scores_packed,
            k,
            device_bounds,
            comp_ctx,
            host_bounds,
            STORE_STREAM=None,
        )

    @staticmethod
    def post_rope_scoring(
        q: torch.Tensor,
        k: torch.Tensor,
        v: torch.Tensor,
        pre_rope_scores: torch.Tensor,
        context,
    ) -> Optional[torch.Tensor]:
        """Blend the pre-RoPE scores with post-RoPE attention scores.

        Delegates to ``kvpress_compactor_post_rope`` through
        ``maybe_execute_in_stream`` on ``context.STORE_STREAM``; the blending
        weight is resolved by ``resolve_kvpress_compactor_blending``.
        """
        comp_ctx = context.compression_context
        weight = resolve_kvpress_compactor_blending(comp_ctx)
        return maybe_execute_in_stream(
            kvpress_compactor_post_rope,
            q,
            k,
            v,
            context.cu_seqlens_q,
            pre_rope_scores,
            comp_ctx,
            context.max_seqlen_q,
            chunk_size=CompactorCompression.chunk_size,
            blending=float(weight),
            STORE_STREAM=context.STORE_STREAM,
        )
# ---------------------------------------------------------------------------
# Cholesky 杠杆分(kvpress ``LeverageScorePress``)
# ---------------------------------------------------------------------------
def chol_with_jitter(G: torch.Tensor, jitter: float = 0.0, max_tries: int = 5) -> torch.Tensor:
    """Cholesky-factorise ``G``, growing a diagonal shift until it succeeds.

    Each attempt factorises ``G + shift * I`` with
    ``torch.linalg.cholesky_ex``. The first attempt uses ``shift = jitter``;
    after a failure the shift becomes ``1e-2`` if it was zero, otherwise ten
    times its previous value (never below ``1e-8``).

    :param G: Square matrix or batch of square matrices.
    :param jitter: Initial diagonal shift.
    :param max_tries: Maximum number of factorisation attempts.
    :return: The lower-triangular factor from the first attempt in which
        every matrix of the batch factorised.
    :raises RuntimeError: If all ``max_tries`` attempts fail.
    """
    eye = torch.eye(G.shape[-1], device=G.device, dtype=G.dtype)
    shift = float(jitter)
    for _attempt in range(max_tries):
        factor, status = torch.linalg.cholesky_ex(G + shift * eye, upper=False)
        # ``status`` holds one code per matrix; zero means success.
        if bool((status == 0).all()):
            return factor
        if shift == 0.0:
            grown = 1e-2
        else:
            grown = 10.0 * shift
        shift = max(1e-8, grown)
    raise RuntimeError(f"Cholesky failed after {max_tries} tries.")
def compute_leverage_scores_mid(key_states: torch.Tensor, sketch_dimension: int) -> torch.Tensor:
    """
    Approximate per-token leverage scores of the keys via a Gaussian sketch.

    Per the original note this mirrors kvpress
    ``LeverageScorePress.compute_leverage_scores``. The input is
    ``[L, H, D]`` and the result is ``[L, H]`` (no z-score applied).

    Axis order is aligned with kvpress' ``(B, H, S, D)``: the keys are first
    rearranged to ``[1, H, L, D]``, centred along the sequence axis
    (``dim=-2``), and batch-multiplied with ``Phi`` of shape ``(1, H, D, K)``
    to give ``[1, H, L, K]``.

    :param key_states: Key tensor of shape ``[L, H, D]``.
    :param sketch_dimension: Sketch width ``K`` (columns of ``Phi``).
    :return: Contiguous float32 tensor of shape ``[L, H]``, clamped at zero.
    """
    d, k = key_states.shape[-1], sketch_dimension
    device, dtype = key_states.device, key_states.dtype
    H = key_states.shape[1]
    # Fresh Gaussian sketch per call, scaled by 1/sqrt(k); results therefore
    # depend on the current RNG state.
    Phi = torch.randn(1, H, d, k, device=device, dtype=dtype) * (1.0 / math.sqrt(k))
    # [L, H, d] -> [1, H, L, d], matching kvpress (B, H, S, d)
    X0 = key_states.transpose(0, 1).unsqueeze(0).contiguous()
    # ROCm batched GEMM is sensitive to non-contiguous strides after transpose/mean.
    X = (X0 - X0.mean(dim=-2, keepdim=True)).contiguous()
    # The projection runs in the key dtype; everything after is float32.
    X = torch.matmul(X, Phi).to(torch.float32).contiguous()
    XT = X.transpose(-2, -1).contiguous()
    # Gram matrix of the sketched keys, [1, H, K, K]; symmetrised explicitly.
    G = (XT @ X).contiguous()
    G_sym = 0.5 * (G + G.transpose(-2, -1)).contiguous()
    # HIP/ROCm: rocBLAS TRSM (used by cholesky_solve and often by linalg.solve for
    # triangular solves) can launch blocks (e.g. 16x64x1) > __launch_bounds__(256).
    # Small sketch_dim k: inv(G) @ XT avoids TRSM; k is typically <= 128.
    if torch.version.hip is not None:
        kk = G_sym.shape[-1]
        eye = torch.eye(kk, device=G_sym.device, dtype=G_sym.dtype, requires_grad=False)
        # Same 1e-2 diagonal regulariser as the initial jitter of the other branch.
        G_reg = G_sym + 1e-2 * eye
        inv_Xt = torch.linalg.inv(G_reg) @ XT
    else:
        L = chol_with_jitter(G_sym, jitter=1e-2, max_tries=5)
        inv_Xt = torch.cholesky_solve(XT, L, upper=False)
    # Row-wise x_i^T (G + reg)^-1 x_i; clamp removes tiny negative round-off.
    scores = (X * inv_Xt.transpose(-2, -1)).sum(dim=-1).clamp_min(0)
    # [1, H, L] -> [L, H]
    return scores.squeeze(0).transpose(0, 1).contiguous()
def
kvpress_leverage_scores_packed
(
key_states
:
torch
.
Tensor
,
cu_seqlens
:
torch
.
Tensor
,
compression_ctx
,
cu_seqlens_host
:
tuple
[
int
,
...]
|
None
=
None
,
)
->
torch
.
Tensor
:
device
=
key_states
.
device
N
,
Hkv
,
_D
=
key_states
.
shape
sketch_dim
=
int
(
getattr
(
compression_ctx
,
"sketch_dimension"
,
48
))
sink_start
=
int
(
getattr
(
compression_ctx
,
"sink_size_start"
,
8
))
sink_end
=
int
(
getattr
(
compression_ctx
,
"sink_size_end"
,
4
))
if
cu_seqlens_host
is
not
None
:
bounds
=
list
(
cu_seqlens_host
)
total
=
bounds
[
-
1
]
else
:
cu_cpu
=
cu_seqlens
.
detach
().
cpu
().
view
(
-
1
)
total
=
int
(
cu_cpu
[
-
1
])
bounds
=
cu_cpu
.
tolist
()
if
total
!=
N
:
raise
RuntimeError
(
f
"kvpress_leverage_scores_packed: cu_seqlens[-1]=
{
total
}
!= key_states "
f
"num_rows=
{
N
}
(check packed prefill / TP broadcast)."
)
out
=
torch
.
zeros
(
N
,
Hkv
,
device
=
device
,
dtype
=
torch
.
float32
)
mids_flat
:
list
[
torch
.
Tensor
]
=
[]
mid_ranges
:
list
[
tuple
[
int
,
int
,
int
]]
=
[]
for
b
in
range
(
len
(
bounds
)
-
1
):
k_beg
=
int
(
bounds
[
b
])
k_end
=
int
(
bounds
[
b
+
1
])
L
=
k_end
-
k_beg
if
L
==
0
:
continue
left_keep
=
min
(
sink_start
,
L
)
right_keep
=
min
(
sink_end
,
max
(
0
,
L
-
left_keep
))
mid_start
=
k_beg
+
left_keep
mid_end
=
k_end
-
right_keep
if
mid_start
>=
mid_end
:
continue
k_mid
=
key_states
[
mid_start
:
mid_end
,
:,
:].
contiguous
()
raw
=
compute_leverage_scores_mid
(
k_mid
,
sketch_dim
)
mids_flat
.
append
(
raw
.
reshape
(
-
1
))
mid_ranges
.
append
((
mid_start
,
mid_end
,
Hkv
))
if
not
mids_flat
:
return
out
flat
=
torch
.
cat
(
mids_flat
,
dim
=
0
)
z
=
_zscore_flat_f32_global
(
flat
)
offset
=
0
for
(
mid_start
,
mid_end
,
_Hkv
),
r
in
zip
(
mid_ranges
,
mids_flat
):
n
=
r
.
numel
()
seg
=
z
[
offset
:
offset
+
n
].
view
(
mid_end
-
mid_start
,
Hkv
)
out
[
mid_start
:
mid_end
,
:]
=
seg
offset
+=
n
return
out
# ---------------------------------------------------------------------------
# 非因果分块注意力(kvpress ``NonCausalAttnPress.non_causal_chunked_attn``)— Triton
# ---------------------------------------------------------------------------
def
_non_causal_chunked_attn_pytorch
(
q
:
torch
.
Tensor
,
k
:
torch
.
Tensor
,
chunk_size
:
int
)
->
torch
.
Tensor
:
"""参考实现:与 kvpress 逐算子一致。"""
assert
chunk_size
>
0
and
q
.
shape
==
k
.
shape
L
,
H
,
d
=
q
.
shape
B
=
1
q
=
q
.
permute
(
1
,
0
,
2
).
unsqueeze
(
0
).
contiguous
()
k
=
k
.
permute
(
1
,
0
,
2
).
unsqueeze
(
0
).
contiguous
()
_B
,
H
,
S
,
_d
=
k
.
shape
S_pad
=
math
.
ceil
(
S
/
chunk_size
)
*
chunk_size
pad_len
=
S_pad
-
S
if
pad_len
>
0
:
q_padded
=
torch
.
cat
(
[
q
,
torch
.
zeros
(
B
,
H
,
pad_len
,
d
,
device
=
q
.
device
,
dtype
=
q
.
dtype
)],
dim
=
2
)
k_padded
=
torch
.
cat
(
[
k
,
torch
.
zeros
(
B
,
H
,
pad_len
,
d
,
device
=
k
.
device
,
dtype
=
k
.
dtype
)],
dim
=
2
)
last_chunk_start
=
(
S
//
chunk_size
)
*
chunk_size
in_valid
=
torch
.
arange
(
last_chunk_start
,
S_pad
,
device
=
q
.
device
)
>=
S
query_mask
=
key_mask
=
in_valid
.
view
(
1
,
1
,
chunk_size
).
expand
(
B
,
H
,
chunk_size
)
else
:
q_padded
,
k_padded
=
q
,
k
last_chunk_start
=
((
S
-
1
)
//
chunk_size
)
*
chunk_size
in_valid
=
torch
.
arange
(
last_chunk_start
,
S_pad
,
device
=
q
.
device
)
>=
S
query_mask
=
key_mask
=
in_valid
.
view
(
1
,
1
,
chunk_size
).
expand
(
B
,
H
,
chunk_size
)
num_chunks
=
S_pad
//
chunk_size
q_chunks
=
q_padded
.
view
(
B
,
H
,
num_chunks
,
chunk_size
,
d
)
k_chunks
=
k_padded
.
view
(
B
,
H
,
num_chunks
,
chunk_size
,
d
)
dots
=
torch
.
matmul
(
q_chunks
,
k_chunks
.
transpose
(
-
2
,
-
1
))
dots
[:,
:,
-
1
].
masked_fill_
(
query_mask
.
unsqueeze
(
-
1
),
0
)
dots
[:,
:,
-
1
].
masked_fill_
(
key_mask
.
unsqueeze
(
-
2
),
-
1e-9
)
attn
=
torch
.
softmax
(
dots
.
to
(
torch
.
float32
),
dim
=-
1
)
out
=
attn
.
sum
(
dim
=-
2
).
view
(
B
,
H
,
S_pad
)[...,
:
S
]
return
out
.
squeeze
(
0
).
transpose
(
0
,
1
).
contiguous
()
@triton.jit
def _non_causal_chunk_row_kernel(
    Q_ptr,
    K_ptr,
    Out_ptr,
    stride_qh,
    stride_qs,
    stride_qd,
    stride_kh,
    stride_ks,
    stride_kd,
    stride_oh,
    stride_os,
    S,
    S_pad,
    num_chunks,
    CHUNK_SIZE: tl.constexpr,
    D: tl.constexpr,
    BLOCK_D: tl.constexpr,
    ND: tl.constexpr,
):
    """
    One program per (head, chunk, query row).

    Computes the logits of that query row against all keys of its chunk,
    applies ``softmax`` over the key axis, and ``atomic_add``s the resulting
    probabilities into the per-key output (equivalent to summing over the
    query rows of the chunk).

    ``S`` is the unpadded sequence length. ``S_pad`` and ``num_chunks`` are
    accepted but not read inside the kernel.
    """
    h = tl.program_id(0)
    c = tl.program_id(1)
    iq = tl.program_id(2)
    # Global position of this program's query row.
    g_i = c * CHUNK_SIZE + iq
    offs_j = tl.arange(0, CHUNK_SIZE)
    logits = tl.zeros([CHUNK_SIZE], dtype=tl.float32)
    # Accumulate the dot products over the head dimension in ND tiles of BLOCK_D.
    for db in range(ND):
        offs_d = tl.arange(0, BLOCK_D) + db * BLOCK_D
        mask_d = offs_d < D
        q_off = (h * stride_qh + g_i * stride_qs + offs_d * stride_qd)
        qd = tl.load(Q_ptr + q_off, mask=mask_d, other=0.0).to(tl.float32)
        g_j = c * CHUNK_SIZE + offs_j
        k_row_off = h * stride_kh + g_j[:, None] * stride_ks + offs_d[None, :] * stride_kd
        # Only the head-dim tail is masked; the loads rely on Q/K covering
        # every position of the chunk (the launcher pads them to S_pad).
        kj = tl.load(K_ptr + k_row_off, mask=mask_d[None, :], other=0.0).to(tl.float32)
        logits += tl.sum(qd[None, :] * kj, axis=1)
    row_invalid = g_i >= S
    g_j_all = c * CHUNK_SIZE + offs_j
    col_invalid = g_j_all >= S
    # Padded query rows: all logits forced to 0.
    logits = tl.where(row_invalid, tl.zeros([CHUNK_SIZE], dtype=tl.float32), logits)
    # Valid query rows: padded key columns get -1e-9 (same constant as the
    # PyTorch reference). Padded rows keep their zeros here.
    logits = tl.where(
        row_invalid,
        logits,
        tl.where(col_invalid, tl.full([CHUNK_SIZE], -1e-9, dtype=tl.float32), logits),
    )
    # Max-subtracted softmax over the key axis.
    m = tl.max(logits)
    logits = logits - m
    exp_v = tl.exp(logits)
    denom = tl.sum(exp_v)
    p = exp_v / denom
    out_base = h * stride_oh + g_j_all * stride_os
    # Several query-row programs add into the same key slots, hence atomics;
    # padded key positions are not written.
    tl.atomic_add(Out_ptr + out_base, p, mask=g_j_all < S)
def _non_causal_chunked_attn_triton(q: torch.Tensor, k: torch.Tensor, chunk_size: int) -> torch.Tensor:
    """Triton launcher; same algorithm as ``_non_causal_chunked_attn_pytorch``.

    Pads ``q``/``k`` with zero rows to a multiple of ``chunk_size``, converts
    them to float32 in ``[H, S_pad, d]`` layout and launches
    ``_non_causal_chunk_row_kernel`` on a ``(H, num_chunks, chunk_size)``
    grid, i.e. one program per query row.

    :param q: CUDA tensor of shape ``[L, H, d]``.
    :param k: CUDA tensor with the same shape as ``q``.
    :param chunk_size: Positive chunk length; passed to the kernel as the
        compile-time ``CHUNK_SIZE``.
    :return: Contiguous float32 tensor of shape ``[L, H]``.
    """
    assert q.is_cuda and k.is_cuda and q.shape == k.shape
    seq_len, num_heads, head_dim = q.shape
    assert chunk_size > 0

    padded_len = math.ceil(seq_len / chunk_size) * chunk_size
    extra_rows = padded_len - seq_len
    if extra_rows > 0:
        q_fill = torch.zeros(
            extra_rows, num_heads, head_dim, device=q.device, dtype=q.dtype, requires_grad=False
        )
        k_fill = torch.zeros(
            extra_rows, num_heads, head_dim, device=k.device, dtype=k.dtype, requires_grad=False
        )
        q = torch.cat([q, q_fill], dim=0)
        k = torch.cat([k, k_fill], dim=0)

    # [S_pad, H, d] -> contiguous float32 [H, S_pad, d]
    q_hsd = q.transpose(0, 1).contiguous().to(dtype=torch.float32)
    k_hsd = k.transpose(0, 1).contiguous().to(dtype=torch.float32)

    chunk_count = padded_len // chunk_size
    accum = torch.zeros(num_heads, padded_len, device=q.device, dtype=torch.float32)
    valid_len = int(seq_len)

    # Head dimension is processed in tiles of BLOCK_D.
    tile = 32 if head_dim <= 128 else 64
    tile_count = (head_dim + tile - 1) // tile

    launch_grid = (num_heads, chunk_count, chunk_size)
    _non_causal_chunk_row_kernel[launch_grid](
        q_hsd,
        k_hsd,
        accum,
        *q_hsd.stride(),
        *k_hsd.stride(),
        *accum.stride(),
        valid_len,
        padded_len,
        int(chunk_count),
        CHUNK_SIZE=chunk_size,
        D=head_dim,
        BLOCK_D=tile,
        ND=tile_count,
        num_warps=4,
    )
    return accum[:, :valid_len].transpose(0, 1).contiguous()
def non_causal_chunked_attn(q: torch.Tensor, k: torch.Tensor, chunk_size: int) -> torch.Tensor:
    """Chunked non-causal attention mass: ``[L, H, d]`` inputs -> ``[L, H]``.

    No ``1/sqrt(d)`` scaling is applied. Uses the Triton implementation when
    both tensors are on CUDA and the PyTorch implementation otherwise.
    """
    both_on_gpu = q.is_cuda and k.is_cuda
    if not both_on_gpu:
        return _non_causal_chunked_attn_pytorch(q, k, chunk_size)
    return _non_causal_chunked_attn_triton(q, k, chunk_size)
# ---------------------------------------------------------------------------
# ×||V|| + avg_pool1d(k=3) — Triton(CUDA)
# ---------------------------------------------------------------------------
@triton.jit
def _mul_vnorm_avgpool3_kernel(
    A_ptr,
    V_ptr,
    OUT_ptr,
    stride_al,
    stride_ah,
    stride_vl,
    stride_vh,
    stride_vd,
    stride_ol,
    stride_oh,
    L,
    D: tl.constexpr,
):
    """
    Fused ``a * ||v||`` followed by a width-3 average along the sequence.

    One program per (position ``l``, head ``h``). For each of the positions
    ``l-1``, ``l`` and ``l+1`` it computes ``A[pos, h] * ||V[pos, h, :]||_2``
    (0 when the position is out of range) and stores the sum divided by 3.

    Triton does not support nested ``def``, so the per-position logic is
    written out three times.

    NOTE(review): ``tl.arange(0, D)`` is used without a mask, so ``D`` is
    expected to be a size Triton accepts for ``arange`` (a power of two) —
    confirm for the head dimensions in use.
    """
    l = tl.program_id(0)
    h = tl.program_id(1)
    offs = tl.arange(0, D)
    # --- left neighbour (l - 1) ---
    pos_m1 = l - 1
    inb_m1 = (pos_m1 >= 0) & (pos_m1 < L)
    # Clamp the index to 0 when out of range so the address stays valid.
    ps_m1 = tl.where(inb_m1, pos_m1, 0)
    a_m1 = tl.load(
        A_ptr + ps_m1 * stride_al + h * stride_ah,
        mask=inb_m1,
        other=0.0,
    ).to(tl.float32)
    v_m1 = tl.load(
        V_ptr + ps_m1 * stride_vl + h * stride_vh + offs * stride_vd,
        mask=inb_m1,
        other=0.0,
    ).to(tl.float32)
    s_m1 = tl.where(inb_m1, a_m1 * tl.sqrt(tl.sum(v_m1 * v_m1)), 0.0)
    # --- centre (l) ---
    inb_0 = (l >= 0) & (l < L)
    ps0 = tl.where(inb_0, l, 0)
    a0 = tl.load(
        A_ptr + ps0 * stride_al + h * stride_ah,
        mask=inb_0,
        other=0.0,
    ).to(tl.float32)
    v0 = tl.load(
        V_ptr + ps0 * stride_vl + h * stride_vh + offs * stride_vd,
        mask=inb_0,
        other=0.0,
    ).to(tl.float32)
    s_0 = tl.where(inb_0, a0 * tl.sqrt(tl.sum(v0 * v0)), 0.0)
    # --- right neighbour (l + 1) ---
    pos_p1 = l + 1
    inb_p1 = (pos_p1 >= 0) & (pos_p1 < L)
    ps_p1 = tl.where(inb_p1, pos_p1, 0)
    a_p1 = tl.load(
        A_ptr + ps_p1 * stride_al + h * stride_ah,
        mask=inb_p1,
        other=0.0,
    ).to(tl.float32)
    v_p1 = tl.load(
        V_ptr + ps_p1 * stride_vl + h * stride_vh + offs * stride_vd,
        mask=inb_p1,
        other=0.0,
    ).to(tl.float32)
    s_p1 = tl.where(inb_p1, a_p1 * tl.sqrt(tl.sum(v_p1 * v_p1)), 0.0)
    # Always divide by 3, so out-of-range neighbours count as zeros.
    out = (s_m1 + s_0 + s_p1) * (1.0 / 3.0)
    tl.store(OUT_ptr + l * stride_ol + h * stride_oh, out)
def
_mul_vnorm_avgpool3_fused
(
a
:
torch
.
Tensor
,
v
:
torch
.
Tensor
,
out
:
torch
.
Tensor
|
None
=
None
)
->
torch
.
Tensor
:
assert
a
.
dim
()
==
2
and
v
.
dim
()
==
3
and
a
.
shape
[
0
]
==
v
.
shape
[
0
]
and
a
.
shape
[
1
]
==
v
.
shape
[
1
]
L
,
H
,
D
=
v
.
shape
a
=
a
.
contiguous
()
v
=
v
.
contiguous
()
if
a
.
dtype
!=
torch
.
float32
:
a
=
a
.
float
()
if
out
is
None
:
out
=
torch
.
empty
((
L
,
H
),
device
=
v
.
device
,
dtype
=
torch
.
float32
)
if
L
==
0
or
H
==
0
:
return
out
grid
=
(
L
,
H
)
_mul_vnorm_avgpool3_kernel
[
grid
](
a
,
v
,
out
,
a
.
stride
(
0
),
a
.
stride
(
1
),
v
.
stride
(
0
),
v
.
stride
(
1
),
v
.
stride
(
2
),
out
.
stride
(
0
),
out
.
stride
(
1
),
L
,
D
=
D
,
num_warps
=
4
,
)
return
out
def
_maybe_mul_vnorm_avgpool3_fused
(
a
:
torch
.
Tensor
,
v
:
torch
.
Tensor
)
->
torch
.
Tensor
:
if
not
a
.
is_cuda
or
not
v
.
is_cuda
:
import
torch.nn.functional
as
F
s
=
a
*
v
.
norm
(
dim
=-
1
)
return
(
F
.
avg_pool1d
(
s
.
transpose
(
0
,
1
).
unsqueeze
(
0
),
kernel_size
=
3
,
padding
=
1
,
stride
=
1
)
.
squeeze
(
0
)
.
transpose
(
0
,
1
)
)
return
_mul_vnorm_avgpool3_fused
(
a
,
v
)
@triton.jit
def _zscore_elem_1d_kernel(
    X_ptr,
    OUT_ptr,
    n,
    mean,
    inv_std,
    BLOCK: tl.constexpr,
):
    """Element-wise ``(x - mean) * inv_std`` over a flat buffer of ``n`` items.

    Each program handles one tile of ``BLOCK`` consecutive elements; ``mean``
    and ``inv_std`` are scalars supplied by the caller.
    """
    pid = tl.program_id(0)
    offs = pid * BLOCK + tl.arange(0, BLOCK)
    # Guard the final, possibly partial tile.
    mask = offs < n
    x = tl.load(X_ptr + offs, mask=mask, other=0.0)
    tl.store(OUT_ptr + offs, (x - mean) * inv_std, mask=mask)
def
_zscore_flat_f32_global
(
x
:
torch
.
Tensor
)
->
torch
.
Tensor
:
"""
与 kvpress ``(t - t.mean()) / t.std()`` 一致的一维全局 z-score。
``mean/std`` 用 PyTorch;CUDA 上缩放阶段用 Triton 逐元素写入。
"""
if
x
.
numel
()
==
0
:
return
x
mu
=
x
.
mean
()
sig
=
x
.
std
().
clamp_min
(
1e-6
)
inv
=
1.0
/
sig
if
not
x
.
is_cuda
:
return
(
x
-
mu
)
*
inv
x
=
x
.
contiguous
()
out
=
torch
.
empty_like
(
x
)
n
=
x
.
numel
()
BLOCK
=
1024
grid
=
(
triton
.
cdiv
(
n
,
BLOCK
),)
_zscore_elem_1d_kernel
[
grid
](
x
,
out
,
n
,
float
(
mu
.
item
()),
float
(
inv
.
item
()),
BLOCK
=
BLOCK
,
num_warps
=
4
,
)
return
out
def
_attn_scores_kvpress_middle
(
q
:
torch
.
Tensor
,
k
:
torch
.
Tensor
,
v
:
torch
.
Tensor
,
cu_seqlens
:
torch
.
Tensor
,
sink_start
:
int
,
sink_end
:
int
,
chunk_size
:
int
,
do_zscore
:
bool
=
True
,
)
->
torch
.
Tensor
:
"""仅中间子序列上的非因果分 + ×||V|| + avg_pool;输出全长 ``[N, Hkv]``,非中间为 0。"""
N
,
HQ
,
D
=
q
.
shape
Hkv
=
k
.
shape
[
1
]
G
=
HQ
//
Hkv
device
=
q
.
device
attn_out
=
torch
.
zeros
(
N
,
Hkv
,
device
=
device
,
dtype
=
torch
.
float32
)
parts
:
list
[
torch
.
Tensor
]
=
[]
for
b
in
range
(
cu_seqlens
.
numel
()
-
1
):
k_beg
=
int
(
cu_seqlens
[
b
].
item
())
k_end
=
int
(
cu_seqlens
[
b
+
1
].
item
())
L
=
k_end
-
k_beg
if
L
==
0
:
continue
left_keep
=
min
(
sink_start
,
L
)
right_keep
=
min
(
sink_end
,
max
(
0
,
L
-
left_keep
))
mid_start
=
k_beg
+
left_keep
mid_end
=
k_end
-
right_keep
if
mid_start
>=
mid_end
:
continue
q_m
=
q
[
mid_start
:
mid_end
,
:,
:].
contiguous
()
k_m
=
k
[
mid_start
:
mid_end
,
:,
:].
contiguous
()
v_m
=
v
[
mid_start
:
mid_end
,
:,
:].
contiguous
()
# HF ``repeat_kv`` 约定:``[batch, num_kv_heads, seq_len, head_dim]``
k_4d
=
k_m
.
unsqueeze
(
0
).
transpose
(
1
,
2
).
contiguous
()
# [1, Hkv, Lm, D]
k_rep
=
repeat_kv
(
k_4d
,
G
)[
0
].
transpose
(
0
,
1
).
contiguous
()
# [Lm, HQ, D]
A
=
non_causal_chunked_attn
(
q_m
,
k_rep
,
chunk_size
)
Lm
,
HQa
=
A
.
shape
assert
HQa
==
HQ
A
=
A
.
view
(
Lm
,
Hkv
,
G
).
mean
(
dim
=-
1
)
scores
=
_maybe_mul_vnorm_avgpool3_fused
(
A
,
v_m
)
parts
.
append
(
scores
.
reshape
(
-
1
))
if
not
parts
:
return
attn_out
flat_a
=
torch
.
cat
(
parts
,
dim
=
0
)
if
do_zscore
:
z_a
=
_zscore_flat_f32_global
(
flat_a
)
else
:
z_a
=
flat_a
offset
=
0
for
b
in
range
(
cu_seqlens
.
numel
()
-
1
):
k_beg
=
int
(
cu_seqlens
[
b
].
item
())
k_end
=
int
(
cu_seqlens
[
b
+
1
].
item
())
L
=
k_end
-
k_beg
if
L
==
0
:
continue
left_keep
=
min
(
sink_start
,
L
)
right_keep
=
min
(
sink_end
,
max
(
0
,
L
-
left_keep
))
mid_start
=
k_beg
+
left_keep
mid_end
=
k_end
-
right_keep
if
mid_start
>=
mid_end
:
continue
n
=
(
mid_end
-
mid_start
)
*
Hkv
attn_out
[
mid_start
:
mid_end
,
:]
=
z_a
[
offset
:
offset
+
n
].
view
(
mid_end
-
mid_start
,
Hkv
)
offset
+=
n
return
attn_out
def non_causal_attn_scores(
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    cu_seqlens_qk: torch.Tensor,
    max_seqlen_qk: int,
    chunk_size: int,
    sm_scale: float = None,
    normalize: bool = True,
    context_lens: Optional[List[int]] = None,
    protected_first_tokens: Optional[List[int]] = None,
    protected_last_tokens: Optional[List[int]] = None,
    *,
    accum_scores: torch.Tensor = None,
    accum_blending: float = None,
) -> torch.Tensor:
    """
    Non-causal attention scores with optional accumulation and protection.

    Per the original note this follows the kvpress non-causal branch.
    ``sm_scale`` and ``max_seqlen_qk`` are accepted for compatibility and
    ignored (dot products are not scaled by ``1/sqrt(d)``). The sink sizes
    are fixed here at 8 leading and 4 trailing tokens.

    Steps:
      1. Score the middle tokens with ``_attn_scores_kvpress_middle``;
         ``normalize=True`` enables the global z-score over them.
      2. If ``accum_scores`` is given, add ``accum_blending * accum_scores``
         (weight ``0.5`` when ``accum_blending`` is ``None``).
      3. If both protected lists are given and ``context_lens`` is non-empty,
         set the first/last protected rows of each sequence to ``inf``.

    :return: Score tensor as produced by ``_attn_scores_kvpress_middle``
        (``[N, Hkv]`` float32), after the optional steps above.
    """
    # sm_scale / max_seqlen_qk are intentionally unused.
    head_sink, tail_sink = 8, 4
    result = _attn_scores_kvpress_middle(
        q,
        k,
        v,
        cu_seqlens_qk,
        head_sink,
        tail_sink,
        chunk_size,
        do_zscore=normalize,
    )

    if accum_scores is not None:
        if accum_blending is None:
            weight = 0.5
        else:
            weight = float(accum_blending)
        result = result + weight * accum_scores.to(device=result.device, dtype=result.dtype)

    protect = (
        protected_first_tokens is not None
        and protected_last_tokens is not None
        and context_lens
    )
    if not protect:
        return result

    seq_begin = 0
    for n_first, n_last, seq_len in zip(protected_first_tokens, protected_last_tokens, context_lens):
        seq_end = seq_begin + int(seq_len)
        result[seq_begin : seq_begin + int(n_first)].fill_(torch.inf)
        result[seq_end - int(n_last) : seq_end].fill_(torch.inf)
        seq_begin = seq_end
    return result
def kvpress_compactor_post_rope(
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    cu_seqlens: torch.Tensor,
    pre_rope_scores: torch.Tensor,
    compression_ctx,
    max_seqlen_q: int,
    chunk_size: int,
    blending: float,
) -> torch.Tensor:
    """Blend leverage and attention scores and pad the sink tokens.

    Steps:
      1. Compute z-scored non-causal attention scores on the middle tokens
         with ``_attn_scores_kvpress_middle``.
      2. On each middle span set
         ``blended = blending * pre_rope_scores + attention_scores``.
      3. Set the leading/trailing sink rows of every sequence to the global
         maximum of ``blended`` (``1.0`` if that maximum is zero or
         non-finite).
      4. If the context provides ``protected_first_tokens``,
         ``protected_last_tokens`` and a non-empty ``context_lens``, set the
         protected rows to ``inf``.

    :param q: Post-RoPE packed queries, ``[N, HQ, D]``.
    :param k: Post-RoPE packed keys, ``[N, Hkv, D]``.
    :param v: Packed values, ``[N, Hkv, D]``.
    :param cu_seqlens: 1-D cumulative sequence boundaries.
    :param pre_rope_scores: Scores from the pre-RoPE phase; converted to
        float32 on ``q.device``.
    :param compression_ctx: Source of the optional ``sink_size_start``
        (default 8), ``sink_size_end`` (default 4), ``context_lens``,
        ``protected_first_tokens`` and ``protected_last_tokens``.
    :param max_seqlen_q: Unused; kept for signature compatibility.
    :param chunk_size: Chunk length for the non-causal attention.
    :param blending: Weight applied to ``pre_rope_scores``.
    :return: Float32 tensor with the shape of ``pre_rope_scores``.
    """
    del max_seqlen_q
    device = q.device
    sink_start = int(getattr(compression_ctx, "sink_size_start", 8))
    sink_end = int(getattr(compression_ctx, "sink_size_end", 4))
    context_lens: Optional[List[int]] = getattr(compression_ctx, "context_lens", None)
    protected_first: Optional[List[int]] = getattr(compression_ctx, "protected_first_tokens", None)
    protected_last: Optional[List[int]] = getattr(compression_ctx, "protected_last_tokens", None)

    attn_out = _attn_scores_kvpress_middle(q, k, v, cu_seqlens, sink_start, sink_end, chunk_size)
    lev = pre_rope_scores.to(device=device, dtype=torch.float32)
    blended = torch.zeros_like(lev)

    # Read the boundaries once (a single host transfer) and derive the
    # (begin, mid_start, mid_end, end) layout of every non-empty sequence,
    # rather than re-reading tensor elements in each of the loops below.
    bounds = [int(x) for x in cu_seqlens.reshape(-1).tolist()]
    layouts: list[tuple[int, int, int, int]] = []
    for k_beg, k_end in zip(bounds, bounds[1:]):
        seq_len = k_end - k_beg
        if seq_len == 0:
            continue
        left_keep = min(sink_start, seq_len)
        right_keep = min(sink_end, max(0, seq_len - left_keep))
        layouts.append((k_beg, k_beg + left_keep, k_end - right_keep, k_end))

    for _k_beg, mid_start, mid_end, _k_end in layouts:
        if mid_start >= mid_end:
            continue
        blended[mid_start:mid_end, :] = (
            blending * lev[mid_start:mid_end, :] + attn_out[mid_start:mid_end, :]
        )

    # The sink rows receive the largest blended score seen anywhere.
    pad_val = blended.max()
    if not torch.isfinite(pad_val) or pad_val == 0:
        pad_val = torch.tensor(1.0, device=device, dtype=torch.float32)

    for k_beg, mid_start, mid_end, k_end in layouts:
        if mid_start > k_beg:  # at least one leading sink token
            blended[k_beg:mid_start, :] = pad_val
        if k_end > mid_end:  # at least one trailing sink token
            blended[mid_end:k_end, :] = pad_val

    if protected_first is not None and protected_last is not None and context_lens:
        start = 0
        for first, last, Lc in zip(protected_first, protected_last, context_lens):
            stop = start + int(Lc)
            blended[start : start + int(first)].fill_(torch.inf)
            blended[stop - int(last) : stop].fill_(torch.inf)
            start = stop
    return blended
vllm/kvprune_legacy_save/compression/compactor_origin.py
0 → 100644
View file @
d29c39ca
import
logging
import
math
from
typing
import
List
,
Optional
import
torch
import
triton
from
tqdm.contrib.logging
import
logging_redirect_tqdm
from
triton
import
language
as
tl
from
vllm.kvprune.compression.common
import
BaseCompressionMethod
from
vllm.kvprune.utils.helpers
import
maybe_execute_in_stream
from
vllm.kvprune.utils.triton_compat
import
autotune
as
triton_autotune
logger
=
logging
.
getLogger
(
__name__
)
class CompactorCompression(BaseCompressionMethod):
    """Compactor compression: leverage scores blended with attention scores.

    The pre-RoPE phase computes approximate leverage scores of the keys; the
    post-RoPE phase computes non-causal attention scores and accumulates the
    pre-RoPE scores into them with a fixed weight of ``0.5``.
    """

    chunk_size: int = 128

    @staticmethod
    def pre_rope_scoring(
        q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, context
    ) -> Optional[torch.Tensor]:
        """Run ``approximate_leverage_scores`` on the pre-RoPE keys.

        The call goes through ``maybe_execute_in_stream`` with
        ``context.STORE_STREAM``; the sequence lengths, projection matrix and
        chunk size come from ``context.compression_context``. ``q`` and ``v``
        are not used.
        """
        comp_ctx = context.compression_context
        return maybe_execute_in_stream(
            approximate_leverage_scores,
            k,
            comp_ctx.context_lens,
            comp_ctx.PHI,
            normalize=True,
            chunk_size=comp_ctx.compression_chunk_size,
            STORE_STREAM=context.STORE_STREAM,
        )

    @staticmethod
    def post_rope_scoring(
        q: torch.Tensor,
        k: torch.Tensor,
        v: torch.Tensor,
        pre_rope_scores: torch.Tensor,
        context,
    ) -> Optional[torch.Tensor]:
        """Run ``non_causal_attn_scores`` and fold in the pre-RoPE scores.

        NOTE(review): unlike ``pre_rope_scoring``, no ``STORE_STREAM`` is
        passed to ``maybe_execute_in_stream`` here — confirm this is
        intended.
        """
        comp_ctx = context.compression_context
        return maybe_execute_in_stream(
            non_causal_attn_scores,
            q,
            k,
            v,
            context.cu_seqlens_q,
            context.max_seqlen_q,
            chunk_size=CompactorCompression.chunk_size,
            sm_scale=1.0,
            normalize=True,
            accum_scores=pre_rope_scores,
            context_lens=comp_ctx.context_lens,
            protected_first_tokens=comp_ctx.protected_first_tokens,
            protected_last_tokens=comp_ctx.protected_last_tokens,
            accum_blending=0.5,
        )
def split_into_chunks(xs, chunk_size):
    """
    Decompose sequence lengths into fixed-size chunks plus optional tails.

    For each length ``n`` in ``xs`` two parallel descriptions are produced:

      * ``chunks`` - the individual block sizes: ``n // chunk_size`` entries
        equal to ``chunk_size`` followed by at most one tail entry equal to
        ``n - (n // chunk_size) * chunk_size``.
      * ``coalesced_chunks`` - the same blocks, except that the run of
        full-size chunks belonging to one sequence is merged into a single
        entry; the tail, if any, stays a separate entry.

    Zero-length pieces are never emitted.

    Example:
        xs = [257, 127], chunk_size = 128
        coalesced_chunks = [256, 1, 127]
        chunks           = [128, 128, 1, 127]

    Args:
        :param xs:
            Iterable of non-negative integers.
        :param chunk_size:
            Target chunk size (must be non-zero).
    Returns:
        :return Tuple[List[int], List[int]]:
            ``(coalesced_chunks, chunks)`` as described above.
    """
    merged_lens = []
    piece_lens = []
    for length in xs:
        full_count = length // chunk_size
        full_span = full_count * chunk_size
        tail = length - full_span
        if full_span > 0:
            merged_lens.append(full_span)
            piece_lens += [chunk_size] * full_count
        if tail > 0:
            merged_lens.append(tail)
            piece_lens.append(tail)
    return merged_lens, piece_lens
def approximate_leverage_scores(
    key_states: torch.Tensor,  # [N, H, D]
    context_lens: List[int],  # [B]
    PHI: torch.Tensor,  # [D, k]
    regularizer: float = 5e-3,
    normalize: bool = False,
    chunk_size: int = 512,
) -> torch.Tensor:  # returns [N, H]
    """
    Approximate per-token leverage scores of the keys via randomized sketching.

    The keys are projected with ``PHI``, each segment is mean-centered, a
    regularized Gram matrix is factorized per segment with a batched SVD, and
    the score of a token is the squared norm of its whitened sketch. See
    Compactor: Calibrated Query-Agnostic KV Cache Compression with
    Approximate Leverage Scores (https://arxiv.org/abs/2507.08143).

    Args:
        :param key_states:
            Tensor of shape ``[N, H, D]`` holding the key states of all
            tokens, packed along the sequence dimension
            (``N = sum(context_lens)``). It is not modified; centering is
            done in place on the freshly computed sketch.
        :param context_lens:
            List of per-sequence context lengths, length ``B``.
        :param PHI:
            Projection matrix of shape ``[D, k]`` used to sketch the keys.
        :param regularizer:
            Scalar added to the diagonal of every Gram matrix before the
            SVD. Defaults to ``5e-3``. If the SVD raises ``RuntimeError``,
            a further ``10 * regularizer`` is added and the SVD is retried
            once before falling back to the QR path.
        :param normalize:
            If True, z-score normalize the scores in place with
            ``_zscore_per_batch_epilogue_no_window``. Statistics are taken
            over each segment of the chunk decomposition (all heads and
            tokens of that segment); with ``chunk_size <= 0`` a segment is
            a whole sequence.
        :param chunk_size:
            Target chunk size along the sequence dimension. If > 0, every
            sequence is split into chunks of at most this size before the
            Gram matrices are formed. If <= 0, each sequence is one chunk.
    Returns:
        :return torch.Tensor:
            Approximate leverage scores of shape ``[N, H]``.
    """
    chunked = chunk_size > 0
    if chunked:
        coalesced_lens, segment_lens = split_into_chunks(context_lens, chunk_size)
    else:
        coalesced_lens = segment_lens = context_lens

    # Same device as key_states (avoid bare .cuda() -> wrong GPU in
    # multi-device processes); the leading 0 makes the later cumsum a
    # prefix-offset table for the z-score kernel.
    segment_lens_dev = torch.tensor(
        [0] + segment_lens,
        device=key_states.device,
        dtype=torch.int32,
    )

    sketch = torch.matmul(key_states.transpose(0, 1), PHI)  # [H, N, k]
    num_heads, _, sketch_dim = sketch.shape
    pieces = torch.split(sketch, coalesced_lens, dim=-2)

    def _is_single_block(length):
        # True for whole-sequence pieces (chunking disabled) and for
        # epilogue tails; False for runs of full-size chunks.
        return not chunked or length % chunk_size != 0

    def _gesvda(matrix):
        return torch.linalg.svd(matrix, full_matrices=False, driver="gesvda")

    # ---- Gram matrices: one per chunk, stacked along dim 1 ----
    grams = []
    for piece, length in zip(pieces, coalesced_lens):
        if _is_single_block(length):
            # In-place centering writes through to ``sketch``.
            piece.sub_(piece.mean(dim=-2, keepdim=True))
            gram = torch.matmul(piece.transpose(-1, -2), piece)  # [H, k, k]
            grams.append(gram.unsqueeze(1))
        else:
            tiled = piece.view(num_heads, -1, chunk_size, sketch_dim)
            tiled.sub_(tiled.mean(dim=-2, keepdim=True))
            # [H, num_chunks, k, k]
            grams.append(torch.matmul(tiled.transpose(-1, -2), tiled))

    G = torch.cat(grams, dim=1).to(torch.float32)
    G.diagonal(dim1=-2, dim2=-1).add_(regularizer)

    try:
        left, sing, _ = _gesvda(G)
    except RuntimeError:
        try:
            # Second attempt with a stronger ridge on the same matrices.
            G.diagonal(dim1=-2, dim2=-1).add_(regularizer * 10)
            left, sing, _ = _gesvda(G)
        except RuntimeError:
            with logging_redirect_tqdm():
                logger.warning(
                    "GESVDA failed, falling back to QR decomposition, which will be MUCH slower. "
                    "Try increasing chunk_size if this issue persists."
                )
            # this is over 50 times slower than using GESVDA
            return _approximate_leverage_scores_qr_fallback(
                X=sketch,
                chunks_lens=segment_lens,
                chunk_lens_cuda=segment_lens_dev,
                normalize=normalize,
                chunk_size=chunk_size,
            )

    # Columns of ``left`` scaled by 1/sqrt(singular value): whitening map.
    whitening = (left * sing.rsqrt().unsqueeze(-2)).to(sketch.dtype)

    # ---- Scores: squared norm of every whitened row ----
    cursor = 0  # index of the next unused Gram matrix along dim 1
    per_piece = []
    for piece, length in zip(pieces, coalesced_lens):
        if _is_single_block(length):
            n_blocks = 1
            proj = whitening[:, cursor]
        else:
            n_blocks = length // chunk_size
            piece = piece.view(num_heads, -1, chunk_size, sketch_dim)
            proj = whitening[:, cursor:cursor + n_blocks]
        projected = torch.matmul(piece, proj)
        lev = (projected * projected).sum(dim=-1).clamp_min_(0.0).view(num_heads, -1)
        per_piece.append(lev.transpose(-1, -2))  # [length, H]
        cursor += n_blocks

    scores = torch.cat(per_piece, dim=0)

    if normalize:
        launch_grid = (len(segment_lens),)
        cu_k = segment_lens_dev.cumsum(dim=0)
        _zscore_per_batch_epilogue_no_window[launch_grid](
            scores, cu_k, scores.stride(0), scores.stride(1), num_heads
        )
    return scores
# In-place z-score normalization of OUT, one program per segment.
#
# Program ``b`` reads the row range [cu_k[b], cu_k[b + 1]) and normalizes all
# HK columns of those rows with a single mean / standard deviation computed
# over the whole (rows x HK) slab. Empty segments are skipped.
#
# NOTE(review): the kernel overwrites OUT. If ``triton_autotune`` benchmarks
# every config on the live buffer, OUT is normalized more than once while
# tuning -- confirm that the wrapper restores or tolerates this.
@triton_autotune(
    configs=[triton.Config({"BLOCK_K": bk}) for bk in [32, 64, 128]],
    key=["HK"],
    cache_results=True,
)
@triton.jit
def _zscore_per_batch_epilogue_no_window(
    OUT,  # [Nk, Hk] scores, read and written in place
    cu_k,  # [B+1] cumulative row offsets of the segments
    STRIDE_OUT_NK,
    STRIDE_OUT_HK,
    HK: tl.constexpr,  # Hk
    BLOCK_K: tl.constexpr,  # rows handled per inner step, e.g. 128
):
    b = tl.program_id(0)
    k_beg = tl.load(cu_k + b)
    k_end = tl.load(cu_k + b + 1)
    if k_end <= k_beg:
        return

    sumv = tl.zeros([], dtype=tl.float32)
    sumsq = tl.zeros([], dtype=tl.float32)
    count = ((k_end - k_beg) * HK).to(tl.float32)

    # Pass 1: accumulate sum and sum of squares over the whole slab.
    for ks in tl.range(k_beg, k_end, BLOCK_K):
        nk = ks + tl.arange(0, BLOCK_K)
        kmask = nk < k_end
        for h in tl.range(0, HK):
            ptrs = OUT + nk * STRIDE_OUT_NK + h * STRIDE_OUT_HK
            vals = tl.load(ptrs, mask=kmask, other=0.0).to(tl.float32)
            sumv += tl.sum(vals, 0)
            sumsq += tl.sum(vals * vals, 0)

    mean = sumv / count
    # E[x^2] - E[x]^2 can go slightly negative through rounding; clamp it.
    var = tl.maximum(sumsq / count - mean * mean, 0.0)
    # A slab with zero variance (e.g. a single centered token whose scores
    # are all 0) used to give invstd = inf and therefore NaN outputs.
    # Such a slab is now mapped to all zeros instead.
    invstd = tl.where(var > 0.0, 1.0 / tl.sqrt(var), 0.0)

    # Pass 2: write the normalized values back.
    for ks in tl.range(k_beg, k_end, BLOCK_K):
        nk = ks + tl.arange(0, BLOCK_K)
        kmask = nk < k_end
        for h in tl.range(0, HK):
            ptrs = OUT + nk * STRIDE_OUT_NK + h * STRIDE_OUT_HK
            vals = tl.load(ptrs, mask=kmask, other=0.0).to(tl.float32)
            vals = (vals - mean) * invstd
            tl.store(ptrs, vals, mask=kmask)
def
_approximate_leverage_scores_qr_fallback
(
X
:
torch
.
Tensor
,
# [H, N, k], already sketched (KΦ) and centered in-place
chunks_lens
:
List
[
int
],
# [num_chunks]
chunk_lens_cuda
:
torch
.
Tensor
,
# [num_chunks + 1] (prefix base)
normalize
:
bool
,
chunk_size
:
int
,
)
->
torch
.
Tensor
:
H
,
N
,
k
=
X
.
shape
device
,
dtype
=
X
.
device
,
X
.
dtype
offsets
:
List
[
int
]
=
[]
offset
=
0
for
L
in
chunks_lens
:
offsets
.
append
(
offset
)
offset
+=
L
if
offset
!=
N
:
raise
RuntimeError
(
f
"QR fallback: sum(chunks_lens)=
{
offset
}
does not match N=
{
N
}
"
)
blocks
=
torch
.
split
(
X
,
chunks_lens
,
dim
=-
2
)
scores
=
torch
.
empty
(
N
,
H
,
device
=
device
,
dtype
=
dtype
)
if
chunk_size
>
0
:
full_indices
=
[
i
for
i
,
L
in
enumerate
(
chunks_lens
)
if
L
==
chunk_size
]
epi_indices
=
[
i
for
i
,
L
in
enumerate
(
chunks_lens
)
if
L
!=
chunk_size
]
if
full_indices
:
# stack full chunks
full_blocks
=
torch
.
stack
(
[
blocks
[
i
]
for
i
in
full_indices
],
dim
=
0
)
# [M, H, CS, k]
M
,
Hf
,
Lf
,
kf
=
full_blocks
.
shape
assert
Lf
==
chunk_size
# merge (M, H) into a single batch dim for torch.linalg.q
full_blocks_2d
=
full_blocks
.
view
(
M
*
Hf
,
Lf
,
kf
).
to
(
torch
.
float32
)
U_full
,
_
=
torch
.
linalg
.
qr
(
full_blocks_2d
,
mode
=
"reduced"
)
U_full
=
U_full
.
to
(
dtype
)
scores_full
=
(
U_full
*
U_full
).
sum
(
dim
=-
1
).
clamp_min
(
0.0
)
# [M * Hf, Lf]
scores_full
=
scores_full
.
view
(
M
,
Hf
,
Lf
).
transpose
(
-
1
,
-
2
)
# [M, H, CS]
for
m
,
chunk_idx
in
enumerate
(
full_indices
):
start
=
offsets
[
chunk_idx
]
Lc
=
chunks_lens
[
chunk_idx
]
scores
[
start
:
start
+
Lc
].
copy_
(
scores_full
[
m
])
else
:
epi_indices
=
list
(
range
(
len
(
chunks_lens
)))
for
chunk_idx
in
epi_indices
:
block
=
blocks
[
chunk_idx
]
_
,
Lc
,
_
=
block
.
shape
if
Lc
==
0
:
continue
U_epi
,
_
=
torch
.
linalg
.
qr
(
block
.
to
(
torch
.
float32
),
mode
=
"reduced"
)
scores_epi
=
(
U_epi
*
U_epi
).
sum
(
dim
=-
1
).
to
(
dtype
)
# [H, Lc]
start
=
offsets
[
chunk_idx
]
scores
[
start
:
start
+
Lc
]
=
scores_epi
.
transpose
(
0
,
1
)
# [Lc, H]
if
normalize
:
grid
=
(
len
(
chunks_lens
),)
cu_k
=
chunk_lens_cuda
.
cumsum
(
dim
=
0
)
_zscore_per_batch_epilogue_no_window
[
grid
](
scores
,
cu_k
,
scores
.
stride
(
0
),
scores
.
stride
(
1
),
H
)
return
scores
# Chunked, non-causal attention-mass kernel.
#
# Grid: (KV head, batch entry, chunk index within the batch entry). Each
# program handles one chunk of at most CHUNK_SIZE tokens taken from
# [cu_seqlens_qk[b], cu_seqlens_qk[b + 1]). Inside the chunk every query
# attends to every key of the same chunk (no causal mask). For each key the
# softmax probabilities are summed over all queries and all QUERY_GROUP_SIZE
# query heads that share the KV head, and the sum is added to
# accum_scores[key, kv_head].
#
# The softmax is evaluated in two passes over the keys: pass 1 finds the
# running row maximum and denominator, pass 2 recomputes the logits and
# accumulates the normalized probabilities.
#
# V, the STRIDE_V_* arguments and WARPSPEC are accepted but never read in the
# body below.
@triton_autotune(
    configs=[
        triton.Config(
            {"BLOCK_M": BM, "BLOCK_K": BK, "WARPSPEC": False}, num_warps=w, num_stages=s
        )
        for BM in [64]
        for BK in [64]
        for w in [4]
        for s in [2]
    ],
    key=[
        "QUERY_GROUP_SIZE",
        "D",
        "CHUNK_SIZE",
    ],
    cache_results=True,
)
@triton.jit
def _non_causal_attn_kernel(
    Q,
    K,
    V,
    accum_scores,
    cu_seqlens_qk,
    #
    STRIDE_Q_G,
    STRIDE_Q_N,
    STRIDE_Q_H,
    STRIDE_Q_D,
    STRIDE_K_G,
    STRIDE_K_N,
    STRIDE_K_D,
    STRIDE_V_G,
    STRIDE_V_N,
    STRIDE_V_D,
    STRIDE_OUT_N,
    STRIDE_OUT_H,
    sm_scale,
    #
    CHUNK_SIZE: tl.constexpr,
    QUERY_GROUP_SIZE: tl.constexpr,
    BLOCK_M: tl.constexpr,
    BLOCK_K: tl.constexpr,
    D: tl.constexpr,
    WARPSPEC: tl.constexpr,
):
    # One query tile holds BLOCK_M tokens for each of the QUERY_GROUP_SIZE
    # query heads of this KV head, flattened into a single row axis.
    TOTAL_QUERIES_PER_BLOCK: tl.constexpr = BLOCK_M * QUERY_GROUP_SIZE
    INVERSE_CHUNK: tl.constexpr = 1.0 / CHUNK_SIZE
    pid_g = tl.program_id(0)  # KV head in [0, HKV)
    pid_b = tl.program_id(1)  # batch id
    pid_m = tl.program_id(2)  # chunk id within batch
    # Token range [off_b, off_b1) of this batch entry in the packed layout.
    off_b = tl.load(cu_seqlens_qk + pid_b)
    off_b1 = tl.load(cu_seqlens_qk + pid_b + 1)
    chunk_start = off_b + pid_m * CHUNK_SIZE
    chunk_end = tl.minimum(chunk_start + CHUNK_SIZE, off_b1)
    M = chunk_end - chunk_start
    # The grid is sized for the longest sequence, so shorter sequences get
    # programs whose chunk lies past their end; those exit immediately.
    if M <= 0:
        return
    offs_d = tl.arange(0, D)
    offs_k = tl.arange(0, BLOCK_K)
    # Flattened query rows inside a [BLOCK_M, QUERY_GROUP_SIZE] tile
    offs_q = tl.arange(0, TOTAL_QUERIES_PER_BLOCK)
    row_m = offs_q % BLOCK_M  # token offset in this tile
    row_h = offs_q // BLOCK_M  # query-group index
    # 1.44269504 ~= log2(e): lets exp2 stand in for exp below.
    qk_scale = sm_scale * 1.44269504  # convert to log2-domain
    # Finite stand-in for -inf used to mask logits.
    NEG_INF = -1.0e9
    # Iterate over query tiles within this chunk
    for qs in tl.range(chunk_start, chunk_end, BLOCK_M):
        # Global query indices for rows in this tile
        q_idx = qs + row_m  # [TOTAL_QUERIES_PER_BLOCK]
        q_mask = q_idx < chunk_end  # mask for valid rows in this tile
        # Load Q tile: [TOTAL_QUERIES_PER_BLOCK, D]
        q_ptrs = (
            Q
            + pid_g * STRIDE_Q_G
            + q_idx[:, None] * STRIDE_Q_N
            + row_h[:, None] * STRIDE_Q_H
            + offs_d[None, :] * STRIDE_Q_D
        )
        q = tl.load(q_ptrs, mask=q_mask[:, None], other=0.0)
        # ---- Pass 1: per-row max and denominator over all keys in this chunk ----
        row_max = tl.full([TOTAL_QUERIES_PER_BLOCK], NEG_INF, tl.float32)
        row_sum = tl.zeros([TOTAL_QUERIES_PER_BLOCK], dtype=tl.float32)
        for ks in tl.range(chunk_start, chunk_end, BLOCK_K):
            k_idx = ks + offs_k  # [BLOCK_K]
            k_mask = k_idx < chunk_end  # which keys are valid in this tile
            k_ptrs = (
                K
                + pid_g * STRIDE_K_G
                + k_idx[:, None] * STRIDE_K_N
                + offs_d[None, :] * STRIDE_K_D
            )
            k = tl.load(k_ptrs, mask=k_mask[:, None], other=0.0)  # [BLOCK_K, D]
            # logits: [TOTAL_QUERIES_PER_BLOCK, BLOCK_K]
            qk = tl.dot(q, k.T) * qk_scale
            qk = tl.where(q_mask[:, None] & k_mask[None, :], qk, NEG_INF)
            cur_max = tl.max(qk, 1)
            new_max = tl.maximum(row_max, cur_max)
            # rescale previous sum to new_max (base 2)
            rescale = tl.math.exp2(row_max - new_max)
            p = tl.math.exp2(qk - new_max[:, None])
            row_sum = row_sum * rescale + tl.sum(p, 1)
            row_max = new_max
        # Avoid division by zero for inactive rows
        denom = tl.where(q_mask, row_sum, 1.0)
        # ---- Pass 2: recompute logits, normalize, accumulate per-key mass ----
        for ks in tl.range(chunk_start, chunk_end, BLOCK_K):
            k_idx = ks + offs_k
            k_mask = k_idx < chunk_end
            k_ptrs = (
                K
                + pid_g * STRIDE_K_G
                + k_idx[:, None] * STRIDE_K_N
                + offs_d[None, :] * STRIDE_K_D
            )
            k = tl.load(k_ptrs, mask=k_mask[:, None], other=0.0)
            qk = tl.dot(q, k.T) * qk_scale
            qk = tl.where(q_mask[:, None] & k_mask[None, :], qk, NEG_INF)
            # p has shape [TOTAL_QUERIES_PER_BLOCK, BLOCK_K]
            p = tl.math.exp2(qk - row_max[:, None]) / denom[:, None]
            # Rows past chunk_end are not zeroed: each contributes a uniform
            # INVERSE_CHUNK to every key column.
            p = tl.where(q_mask[:, None], p, INVERSE_CHUNK)  # preserve attention mass in shorter chunks
            contrib = tl.sum(p, 0)  # [BLOCK_K], sum over queries & query-groups
            out_ptrs = accum_scores + k_idx * STRIDE_OUT_N + pid_g * STRIDE_OUT_H
            # Read-modify-write: successive query tiles of this program add
            # onto the same key rows.
            old = tl.load(out_ptrs, mask=k_mask, other=0.0)
            new = old + contrib.to(old.dtype)
            tl.store(out_ptrs, new, mask=k_mask)
def non_causal_attn_scores(
    q: torch.Tensor,  # [N, HQ, D]
    k: torch.Tensor,  # [N, HKV, D]
    v: torch.Tensor,  # [N, HKV, D]
    cu_seqlens_qk: torch.Tensor,  # [B + 1]
    max_seqlen_qk: int,
    chunk_size: int,
    sm_scale: Optional[float] = None,
    normalize: bool = True,
    context_lens: Optional[List[int]] = None,
    protected_first_tokens: Optional[List[int]] = None,
    protected_last_tokens: Optional[List[int]] = None,
    *,
    accum_scores: Optional[torch.Tensor] = None,  # [N, HKV] (float32)
    accum_blending: Optional[float] = None,
) -> torch.Tensor:
    """
    Score every key by the attention mass it receives inside its chunk.

    Launches ``_non_causal_attn_kernel`` over chunks of ``chunk_size`` tokens,
    optionally z-score normalizes the result per sequence, blends in
    ``accum_scores`` and finally forces protected tokens to ``+inf``.

    :param q: Tensor of shape ``[N, HQ, D]`` containing post-rope queries
    :param k: Tensor of shape ``[N, HKV, D]`` containing post-rope keys
    :param v: Tensor of shape ``[N, HKV, D]`` containing values. It is passed
        to the kernel, which does not read it.
    :param cu_seqlens_qk: Tensor of shape ``[B + 1]`` demarcating batch boundaries
    :param max_seqlen_qk: int containing the maximum sequence length
    :param chunk_size: int specifying the size of the chunk to perform non-causal attention over
    :param sm_scale: float specifying the scaling factor applied to attention scores (1/sqrt(D) if None)
    :param normalize: bool specifying whether to z-score normalize final attention scores
    :param context_lens: List[int] specifying the context lengths, i.e. the CPU version of
        ``cu_seqlens_qk.diff()``. Only needed when tokens are protected; if it is None in that
        case it is derived from ``cu_seqlens_qk`` (which forces a device-to-host copy).
    :param protected_first_tokens: List[int] specifying how many tokens should be protected at the
        start of each sequence. ``None`` protects no leading tokens.
    :param protected_last_tokens: List[int] specifying how many tokens should be protected at the
        end of each sequence. ``None`` protects no trailing tokens.
    :param accum_scores: Tensor of shape ``[N, HKV]`` containing key scores that are added to the output
    :param accum_blending: float specifying the scaling of ``accum_scores`` prior to adding it to the
        new non-causal attention scores, i.e. the result is ``out + accum_blending * accum_scores``
        (weight 1 if None)
    :return: float32 Tensor of shape ``[N, HKV]``
    """
    assert q.ndim == 3 and k.ndim == 3
    assert q.shape[0] == k.shape[0] and q.shape[-1] == k.shape[-1]
    N, HQ, D = q.shape
    HKV = k.shape[1]
    assert HQ % HKV == 0, "Number of query heads must be divisible by number of KV heads"
    assert (D & (D - 1)) == 0, "D must be a power of two"

    B = cu_seqlens_qk.numel() - 1
    H_g = HQ // HKV  # query-group size per KV head

    if sm_scale is None:
        sm_scale = 1.0 / math.sqrt(D)

    out = torch.zeros(N, HKV, device=q.device, dtype=torch.float32)

    # Put the KV-head axis first so one kernel program can address one head.
    q = q.view(N, HKV, H_g, D).permute(1, 0, 2, 3)
    k = k.view(N, HKV, D).permute(1, 0, 2)
    # v is deliberately left in its [N, HKV, D] layout: the kernel never
    # dereferences it, so its strides below are placeholders only.

    if cu_seqlens_qk.device != q.device:
        cu_seqlens_qk = cu_seqlens_qk.to(device=q.device)
    cu_seqlens_qk = cu_seqlens_qk.to(torch.int32)

    STRIDE_Q_G, STRIDE_Q_N, STRIDE_Q_H, STRIDE_Q_D = q.stride()
    STRIDE_K_G, STRIDE_K_N, STRIDE_K_D = k.stride()
    STRIDE_V_G, STRIDE_V_N, STRIDE_V_D = v.stride()
    STRIDE_OUT_N, STRIDE_OUT_H = out.stride()
    assert STRIDE_Q_D == 1 and STRIDE_K_D == 1, "last dim must be contiguous"

    def grid(_):
        # The third axis is sized for the longest sequence; programs past
        # the end of a shorter sequence return immediately.
        return (
            HKV,
            B,
            triton.cdiv(max_seqlen_qk, chunk_size),
        )

    _non_causal_attn_kernel[grid](
        q,
        k,
        v,
        out,
        cu_seqlens_qk,
        STRIDE_Q_G,
        STRIDE_Q_N,
        STRIDE_Q_H,
        STRIDE_Q_D,
        STRIDE_K_G,
        STRIDE_K_N,
        STRIDE_K_D,
        STRIDE_V_G,
        STRIDE_V_N,
        STRIDE_V_D,
        STRIDE_OUT_N,
        STRIDE_OUT_H,
        sm_scale,
        CHUNK_SIZE=chunk_size,
        QUERY_GROUP_SIZE=H_g,
        D=D,
    )

    if normalize:
        zscore_grid = (B,)
        _zscore_per_batch_epilogue_no_window[zscore_grid](
            out, cu_seqlens_qk, out.stride(0), out.stride(1), HKV
        )

    if accum_scores is not None:
        if accum_blending is not None:
            out += accum_scores * accum_blending
        else:
            out += accum_scores

    if protected_first_tokens is not None or protected_last_tokens is not None:
        if context_lens is None:
            # Fall back to the device-side boundaries (host sync).
            context_lens = (cu_seqlens_qk[1:] - cu_seqlens_qk[:-1]).tolist()
        num_seqs = len(context_lens)
        # A missing list means "protect nothing" on that side; previously a
        # single None here made zip() raise TypeError.
        firsts = (
            protected_first_tokens
            if protected_first_tokens is not None
            else [0] * num_seqs
        )
        lasts = (
            protected_last_tokens
            if protected_last_tokens is not None
            else [0] * num_seqs
        )
        start = 0
        for first, last, L in zip(firsts, lasts, context_lens):
            end = start + L
            # Clamp to [0, L] so that a count larger than the sequence (or a
            # negative one) cannot spill into a neighbouring sequence or wrap
            # around through negative indexing.
            first = min(max(first, 0), L)
            last = min(max(last, 0), L)
            out[start:start + first].fill_(torch.inf)
            out[end - last:end].fill_(torch.inf)
            start = end
    return out
vllm/kvprune_legacy_save/compression/compression_config.py
0 → 100644
View file @
d29c39ca
import
logging
from
dataclasses
import
dataclass
from
enum
import
Enum
,
auto
logger
=
logging
.
getLogger
(
__name__
)
class CompressionMethod(Enum):
    """Closed set of KV-cache compression methods selectable in the config."""

    # auto() numbers the members 1, 2, 3, 4 in declaration order, so the
    # order below is part of the values.
    CRITICALADAKV = auto()
    COMPACTOR = auto()
    SNAPKV = auto()
    NONE = auto()
# class CachingPolicy(Enum):
# CACHE_PROMPT = auto()
# DONT_CACHE = auto()
# class CompressionType(Enum):
# QUERY_AWARE = auto()
# QUERY_AGNOSTIC = auto()
@dataclass
class SequenceCompressionParams:
    """Per-sequence compression settings."""

    # NOTE(review): the default of 1.0 presumably means "keep everything";
    # confirm the direction of the ratio against the code that consumes it.
    compression_ratio: float = 1.0
    # Number of tokens at the start of the sequence to protect.
    protected_first_tokens: int = 16
    # Number of tokens at the end of the sequence to protect.
    protected_last_tokens: int = 64
@dataclass
class BatchCompressionParams:
    """Batch-wide compression settings.

    ``CompressionMethod.SNAPKV`` cannot be combined with chunked compression,
    so ``__post_init__`` switches ``do_chunked_compression`` off for it.
    """

    # compression_type: CompressionType = CompressionType.QUERY_AGNOSTIC
    # Which compression method to apply to the batch.
    compression_method: CompressionMethod = CompressionMethod.COMPACTOR
    # Whether compression is performed chunk by chunk.
    do_chunked_compression: bool = True
    # Chunk length used when chunked compression is enabled.
    chunk_size: int = 512

    def __post_init__(self):
        """Disable chunked compression for SNAPKV, warning only on a real change."""
        if (
            self.compression_method == CompressionMethod.SNAPKV
            and self.do_chunked_compression
        ):
            # Warn only when a setting is actually overridden; a caller who
            # already passed do_chunked_compression=False gets no warning.
            self.do_chunked_compression = False
            logger.warning(
                "CompressionMethod.SNAPKV is not compatible with chunked compression. Disabling it."
            )
Prev
1
…
3
4
5
6
7
8
9
10
11
…
13
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment