Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
2b7160c6
Commit
2b7160c6
authored
Apr 23, 2026
by
chenzk
Browse files
vllm kvprune:v1.0.0
parent
fa718036
Changes
305
Hide whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
3651 additions
and
0 deletions
+3651
-0
vllm/compactor-vllm/src/compactor_vllm/utils/helpers.py
vllm/compactor-vllm/src/compactor_vllm/utils/helpers.py
+35
-0
vllm/compactor-vllm/src/compactor_vllm/utils/sequence.py
vllm/compactor-vllm/src/compactor_vllm/utils/sequence.py
+83
-0
vllm/compactor-vllm/src/compactor_vllm/utils/triton_compat.py
.../compactor-vllm/src/compactor_vllm/utils/triton_compat.py
+61
-0
vllm/compactor-vllm/tests/test_store_kv.py
vllm/compactor-vllm/tests/test_store_kv.py
+239
-0
vllm/compactor-vllm/tests/test_triton_attention.py
vllm/compactor-vllm/tests/test_triton_attention.py
+407
-0
vllm/compactor-vllm/vllm_memory_comparison.png
vllm/compactor-vllm/vllm_memory_comparison.png
+0
-0
vllm/compactor-vllm/vllm_throughput_comparison.png
vllm/compactor-vllm/vllm_throughput_comparison.png
+0
-0
vllm/entrypoints/llm.py
vllm/entrypoints/llm.py
+80
-0
vllm/env_override.py
vllm/env_override.py
+54
-0
vllm/envs.py
vllm/envs.py
+21
-0
vllm/kvprune/__init__.py
vllm/kvprune/__init__.py
+20
-0
vllm/kvprune/attention/__init__.py
vllm/kvprune/attention/__init__.py
+7
-0
vllm/kvprune/attention/compile_kernels.py
vllm/kvprune/attention/compile_kernels.py
+261
-0
vllm/kvprune/attention/fa_paged_bridge.py
vllm/kvprune/attention/fa_paged_bridge.py
+244
-0
vllm/kvprune/attention/sparse_decode_kernel.py
vllm/kvprune/attention/sparse_decode_kernel.py
+405
-0
vllm/kvprune/attention/sparse_varlen_kernel.py
vllm/kvprune/attention/sparse_varlen_kernel.py
+600
-0
vllm/kvprune/benchmark/__init__.py
vllm/kvprune/benchmark/__init__.py
+47
-0
vllm/kvprune/compression/__init__.py
vllm/kvprune/compression/__init__.py
+41
-0
vllm/kvprune/compression/common.py
vllm/kvprune/compression/common.py
+324
-0
vllm/kvprune/compression/compactor.py
vllm/kvprune/compression/compactor.py
+722
-0
No files found.
Too many changes to show.
To preserve performance only
305 of 305+
files are displayed.
Plain diff
Email patch
vllm/compactor-vllm/src/compactor_vllm/utils/helpers.py
0 → 100644
View file @
2b7160c6
from
collections.abc
import
Callable
import
torch
def
maybe_execute_in_stream
(
fn
:
Callable
,
*
args
,
STORE_STREAM
:
torch
.
cuda
.
Stream
=
None
,
**
kwargs
):
if
STORE_STREAM
is
not
None
:
tensors
=
[
arg
for
arg
in
args
if
isinstance
(
arg
,
torch
.
Tensor
)]
tensors
+=
[
val
for
val
in
kwargs
.
values
()
if
isinstance
(
val
,
torch
.
Tensor
)]
obj
=
getattr
(
fn
,
"__self__"
,
None
)
if
isinstance
(
obj
,
torch
.
Tensor
):
tensors
.
append
(
obj
)
STORE_STREAM
.
wait_stream
(
torch
.
cuda
.
default_stream
())
# Some PyTorch builds don't make `torch.cuda.Stream` a context manager.
# The portable API is `torch.cuda.stream(stream)`.
stream_ctx
=
(
STORE_STREAM
if
hasattr
(
STORE_STREAM
,
"__enter__"
)
else
torch
.
cuda
.
stream
(
STORE_STREAM
)
)
with
stream_ctx
:
output
=
fn
(
*
args
,
**
kwargs
)
for
t
in
tensors
:
t
.
record_stream
(
STORE_STREAM
)
if
isinstance
(
output
,
tuple
):
for
o
in
output
:
if
isinstance
(
o
,
torch
.
Tensor
):
o
.
record_stream
(
torch
.
cuda
.
default_stream
())
elif
isinstance
(
output
,
torch
.
Tensor
):
output
.
record_stream
(
torch
.
cuda
.
default_stream
())
return
output
else
:
return
fn
(
*
args
,
**
kwargs
)
vllm/compactor-vllm/src/compactor_vllm/utils/sequence.py
0 → 100644
View file @
2b7160c6
from
dataclasses
import
dataclass
,
field
from
enum
import
Enum
,
auto
from
itertools
import
count
from
typing
import
List
from
compactor_vllm.compression.compression_config
import
SequenceCompressionParams
from
compactor_vllm.config.sampling_params
import
SamplingParams
class
SequenceStatus
(
Enum
):
WAITING
=
auto
()
RUNNING
=
auto
()
FINISHED
=
auto
()
@
dataclass
class
Sequence
:
"""
Represents a single user request / sequence being generated.
"""
_counter
=
count
()
prompt_token_ids
:
List
[
int
]
completion_token_ids
:
List
[
int
]
=
field
(
default_factory
=
list
)
sampling_params
:
SamplingParams
=
field
(
default_factory
=
SamplingParams
)
compression_params
:
SequenceCompressionParams
=
field
(
default_factory
=
SequenceCompressionParams
)
status
:
SequenceStatus
=
SequenceStatus
.
WAITING
seq_id
:
int
=
field
(
default_factory
=
lambda
:
next
(
Sequence
.
_counter
),
init
=
False
)
num_tokens_processed
:
int
=
0
@
property
def
num_prompt_tokens
(
self
)
->
int
:
return
len
(
self
.
prompt_token_ids
)
@
property
def
num_generated_tokens
(
self
)
->
int
:
return
len
(
self
.
completion_token_ids
)
def
add_new_token
(
self
,
token_id
:
int
)
->
None
:
if
len
(
self
.
completion_token_ids
)
==
0
:
self
.
num_tokens_processed
+=
self
.
num_prompt_tokens
self
.
completion_token_ids
.
append
(
token_id
)
self
.
num_tokens_processed
+=
1
def
tokens_to_retain_per_layer
(
self
,
num_kv_heads
:
int
)
->
int
:
n
=
int
(
self
.
compression_params
.
compression_ratio
*
self
.
num_prompt_tokens
*
num_kv_heads
)
return
max
(
1
,
n
)
def
__getstate__
(
self
):
return
dict
(
prompt_token_ids
=
list
(
self
.
prompt_token_ids
),
completion_token_ids
=
list
(
self
.
completion_token_ids
),
sampling_params
=
self
.
sampling_params
,
compression_params
=
self
.
compression_params
,
status
=
self
.
status
,
seq_id
=
self
.
seq_id
,
num_tokens_processed
=
self
.
num_tokens_processed
,
)
def
__setstate__
(
self
,
state
):
self
.
prompt_token_ids
=
list
(
state
[
"prompt_token_ids"
])
self
.
completion_token_ids
=
list
(
state
[
"completion_token_ids"
])
self
.
sampling_params
=
state
[
"sampling_params"
]
self
.
compression_params
=
state
[
"compression_params"
]
self
.
status
=
state
[
"status"
]
self
.
seq_id
=
state
[
"seq_id"
]
self
.
num_tokens_processed
=
state
[
"num_tokens_processed"
]
@
property
def
prompt_len
(
self
)
->
int
:
return
len
(
self
.
prompt_token_ids
)
@
property
def
completion_len
(
self
)
->
int
:
return
len
(
self
.
completion_token_ids
)
vllm/compactor-vllm/src/compactor_vllm/utils/triton_compat.py
0 → 100644
View file @
2b7160c6
from
__future__
import
annotations
import
inspect
from
typing
import
Any
,
Callable
,
Mapping
import
torch
def
_filter_kwargs_for_callable
(
fn
:
Callable
[...,
Any
],
kwargs
:
Mapping
[
str
,
Any
]
)
->
dict
[
str
,
Any
]:
try
:
params
=
inspect
.
signature
(
fn
).
parameters
except
(
TypeError
,
ValueError
):
return
dict
(
kwargs
)
return
{
k
:
v
for
k
,
v
in
kwargs
.
items
()
if
k
in
params
}
def
autotune
(
*
,
configs
,
key
,
**
kwargs
):
"""
Compatibility wrapper around `triton.autotune`.
Some Triton builds (e.g., custom vendor builds) may not support newer
keyword arguments like `cache_results`. This wrapper filters unsupported
kwargs based on the runtime `triton.autotune` signature.
"""
import
triton
filtered
=
_filter_kwargs_for_callable
(
triton
.
autotune
,
kwargs
)
return
triton
.
autotune
(
configs
=
configs
,
key
=
key
,
**
filtered
)
def
maybe_set_allocator
(
alloc_fn
:
Callable
[[
int
,
int
,
int
|
None
],
Any
])
->
bool
:
"""
Call `triton.set_allocator(alloc_fn)` if present; otherwise no-op.
Returns True if the allocator was set.
"""
import
triton
setter
=
getattr
(
triton
,
"set_allocator"
,
None
)
if
setter
is
None
:
return
False
setter
(
alloc_fn
)
return
True
def
cuda_capability_geq
(
major
:
int
,
minor
:
int
=
0
,
device
:
int
|
None
=
None
)
->
bool
:
"""
Host-side CUDA capability check that works even when `tl.target_info` is absent.
"""
if
not
torch
.
cuda
.
is_available
():
return
False
if
device
is
None
:
try
:
device
=
torch
.
cuda
.
current_device
()
except
Exception
:
device
=
0
cap
=
torch
.
cuda
.
get_device_capability
(
device
)
return
cap
>=
(
major
,
minor
)
vllm/compactor-vllm/tests/test_store_kv.py
0 → 100644
View file @
2b7160c6
import
collections
import
logging
from
dataclasses
import
dataclass
from
typing
import
List
import
pytest
import
torch
import
triton
from
compactor_vllm.compression.common
import
scores_to_retain_indices
from
src.compactor_vllm.kv_cache.store_kv_cache
import
prefill_store_topk_kv
logger
=
logging
.
getLogger
(
__name__
)
@
dataclass
class
Workload
:
name
:
str
batch_size
:
int
nk_heads
:
int
head_dim
:
int
frac
:
float
# per-sequence cached context length fractionf
page_size
:
int
cache_lens
:
List
[
int
]
# per-sequence cached context length
WORKLOADS
:
List
[
Workload
]
=
[
Workload
(
name
=
f
"batch_size=
{
BATCH
}
kv_cache_len=
{
cache_lens
}
"
f
"FRAC=
{
frac
}
HKV=
{
NK_HEADS
}
HEAD_DIM=
{
HEAD_DIM
}
"
,
batch_size
=
BATCH
,
nk_heads
=
NK_HEADS
,
head_dim
=
HEAD_DIM
,
cache_lens
=
[
cache_lens
]
*
BATCH
,
frac
=
frac
,
page_size
=
ps
,
)
for
BATCH
in
[
1
,
2
,
3
,
8
]
for
frac
in
[
0.10
,
0.20
,
0.30
,
0.40
]
for
NK_HEADS
in
[
2
,
4
,
8
]
for
HEAD_DIM
in
[
32
,
64
,
128
]
for
cache_lens
in
[
10
,
20
,
30
,
70
,
1000
]
for
ps
in
[
128
,
256
]
]
@
pytest
.
mark
.
parametrize
(
"workload"
,
WORKLOADS
,
ids
=
lambda
wl
:
wl
.
name
)
def
test_prefill_store_topk_kv
(
workload
:
Workload
):
B
=
workload
.
batch_size
H
=
workload
.
nk_heads
D
=
workload
.
head_dim
TOP_K
=
int
(
workload
.
cache_lens
[
0
]
*
workload
.
nk_heads
*
workload
.
frac
)
PAGE_SIZE
=
workload
.
page_size
dtype
=
torch
.
float16
device
=
triton
.
runtime
.
driver
.
active
.
get_active_torch_device
()
lens
=
torch
.
tensor
(
workload
.
cache_lens
,
dtype
=
torch
.
int32
,
device
=
device
)
cu
=
torch
.
zeros
(
B
+
1
,
dtype
=
torch
.
int32
,
device
=
device
)
cu
[
1
:]
=
torch
.
cumsum
(
lens
,
dim
=
0
)
N_total
=
int
(
cu
[
-
1
].
item
())
keys
=
torch
.
randn
((
N_total
,
H
,
D
),
dtype
=
dtype
,
device
=
device
)
vals
=
torch
.
randn_like
(
keys
)
scores_flat
=
torch
.
randn
((
N_total
,
H
),
dtype
=
torch
.
float32
,
device
=
device
)
top_k_eff
=
max
(
0
,
min
(
TOP_K
,
int
(
lens
.
max
().
item
())
*
H
))
max_k_len
=
cu
.
diff
().
max
().
item
()
indices
=
scores_to_retain_indices
(
scores_flat
,
cu
,
max_k_len
,
top_k_eff
,
H
)
# [B, TOP_K]
LP
=
max
(
1
,
(
top_k_eff
+
PAGE_SIZE
-
1
)
//
PAGE_SIZE
)
N_LOGICAL_PAGES_MAX
=
LP
N_PAGES
=
B
*
H
*
LP
+
32
S_LARGE
=
N_PAGES
*
PAGE_SIZE
k_cache
=
torch
.
empty
((
S_LARGE
,
D
),
dtype
=
dtype
,
device
=
device
)
v_cache
=
torch
.
empty_like
(
k_cache
)
page_table
=
torch
.
empty
(
(
B
,
H
,
N_LOGICAL_PAGES_MAX
),
dtype
=
torch
.
int32
,
device
=
device
)
phys
=
0
for
b
in
range
(
B
):
for
h
in
range
(
H
):
for
lp
in
range
(
LP
):
page_table
[
b
,
h
,
lp
]
=
phys
phys
+=
1
assert
phys
<=
N_PAGES
,
"Not enough physical pages"
local_lens
=
torch
.
zeros
((
B
,
H
),
dtype
=
torch
.
int32
,
device
=
device
)
batch_mapping
=
torch
.
arange
(
B
,
dtype
=
torch
.
int32
,
device
=
device
)
num_to_retain
=
torch
.
full
((
B
,),
top_k_eff
,
dtype
=
torch
.
int32
,
device
=
device
)
prefill_store_topk_kv
(
new_keys
=
keys
,
new_vals
=
vals
,
indices_topk
=
indices
,
num_tokens_to_retain
=
num_to_retain
,
page_table
=
page_table
,
batch_mapping
=
batch_mapping
,
bh_lens
=
local_lens
,
PAGE_SIZE
=
PAGE_SIZE
,
k_cache
=
k_cache
,
v_cache
=
v_cache
,
PAD_TO_PAGE_SIZE
=
False
,
TRITON_RESERVED_BATCH
=-
1
,
)
torch
.
cuda
.
synchronize
()
local_lens_cpu
=
local_lens
.
cpu
()
page_table_cpu
=
page_table
.
cpu
()
k_cache_cpu
=
k_cache
.
cpu
()
v_cache_cpu
=
v_cache
.
cpu
()
keys_cpu
=
keys
.
cpu
()
vals_cpu
=
vals
.
cpu
()
indices_cpu
=
indices
.
cpu
()
for
b
in
range
(
B
):
hed
=
(
indices_cpu
[
b
]
%
H
).
numpy
()
counts
=
collections
.
Counter
(
hed
.
tolist
())
for
h
in
range
(
H
):
expected
=
counts
.
get
(
h
,
0
)
# type: ignore
got
=
int
(
local_lens_cpu
[
b
,
h
].
item
())
assert
got
==
expected
,
(
f
"Length mismatch at (b=
{
b
}
, h=
{
h
}
): got
{
got
}
, expected
{
expected
}
"
)
def
rows_for_head
(
b
,
h
,
L
):
"""Return the list of cache row indices storing the first L logical positions for (b,h)."""
rows
=
[]
for
pos
in
range
(
L
):
lp
=
pos
//
PAGE_SIZE
off
=
pos
%
PAGE_SIZE
phys
=
int
(
page_table_cpu
[
b
,
h
,
lp
].
item
())
rows
.
append
(
phys
*
PAGE_SIZE
+
off
)
return
rows
for
b
in
range
(
B
):
# which tokens per head were selected for this batch?
tok
=
(
indices_cpu
[
b
]
//
H
).
numpy
()
hed
=
(
indices_cpu
[
b
]
%
H
).
numpy
()
per_head
=
collections
.
defaultdict
(
list
)
for
t
,
h
in
zip
(
tok
,
hed
):
per_head
[
int
(
h
)].
append
(
int
(
t
))
for
h
in
range
(
H
):
L
=
int
(
local_lens_cpu
[
b
,
h
].
item
())
if
L
==
0
:
continue
# expected vectors (unordered) from source
toks_h
=
per_head
.
get
(
h
,
[])
assert
len
(
toks_h
)
==
L
expK
=
keys_cpu
[
toks_h
,
h
,
:].
contiguous
().
view
(
L
,
-
1
)
expV
=
vals_cpu
[
toks_h
,
h
,
:].
contiguous
().
view
(
L
,
-
1
)
# actual vectors read back from cache rows
rows
=
rows_for_head
(
b
,
h
,
L
)
actK
=
k_cache_cpu
[
rows
,
:].
contiguous
().
view
(
L
,
-
1
)
actV
=
v_cache_cpu
[
rows
,
:].
contiguous
().
view
(
L
,
-
1
)
expK_tuples
=
[
tuple
(
row
)
for
row
in
expK
.
numpy
().
tolist
()]
actK_tuples
=
[
tuple
(
row
)
for
row
in
actK
.
numpy
().
tolist
()]
expV_tuples
=
[
tuple
(
row
)
for
row
in
expV
.
numpy
().
tolist
()]
actV_tuples
=
[
tuple
(
row
)
for
row
in
actV
.
numpy
().
tolist
()]
assert
collections
.
Counter
(
expK_tuples
)
==
collections
.
Counter
(
actK_tuples
),
f
"K content mismatch at (b=
{
b
}
, h=
{
h
}
)"
assert
collections
.
Counter
(
expV_tuples
)
==
collections
.
Counter
(
actV_tuples
),
f
"V content mismatch at (b=
{
b
}
, h=
{
h
}
)"
def
test_prefill_store_topk_kv_pad_to_page_size
():
torch
.
manual_seed
(
0
)
B
,
H
,
D
=
2
,
2
,
64
PAGE_SIZE
=
128
RETAIN
=
64
dtype
=
torch
.
float16
device
=
triton
.
runtime
.
driver
.
active
.
get_active_torch_device
()
lens
=
torch
.
full
((
B
,),
256
,
dtype
=
torch
.
int32
,
device
=
device
)
cu
=
torch
.
zeros
(
B
+
1
,
dtype
=
torch
.
int32
,
device
=
device
)
cu
[
1
:]
=
torch
.
cumsum
(
lens
,
dim
=
0
)
N_total
=
int
(
cu
[
-
1
].
item
())
keys
=
torch
.
randn
((
N_total
,
H
,
D
),
dtype
=
dtype
,
device
=
device
)
vals
=
torch
.
randn_like
(
keys
)
scores_flat
=
torch
.
randn
((
N_total
,
H
),
dtype
=
torch
.
float32
,
device
=
device
)
max_k_len
=
int
(
lens
.
max
().
item
())
max_sel
=
max_k_len
*
H
indices
=
scores_to_retain_indices
(
scores_flat
,
cu
,
max_k_len
,
max_sel
,
H
)
N_LOGICAL_PAGES_MAX
=
2
N_PAGES
=
B
*
H
*
N_LOGICAL_PAGES_MAX
+
32
S_LARGE
=
N_PAGES
*
PAGE_SIZE
k_cache
=
torch
.
empty
((
S_LARGE
,
D
),
dtype
=
dtype
,
device
=
device
)
v_cache
=
torch
.
empty_like
(
k_cache
)
page_table
=
torch
.
empty
(
(
B
,
H
,
N_LOGICAL_PAGES_MAX
),
dtype
=
torch
.
int32
,
device
=
device
)
phys
=
0
for
b
in
range
(
B
):
for
h
in
range
(
H
):
for
lp
in
range
(
N_LOGICAL_PAGES_MAX
):
page_table
[
b
,
h
,
lp
]
=
phys
phys
+=
1
assert
phys
<=
N_PAGES
,
"Not enough physical pages"
local_lens
=
torch
.
zeros
((
B
,
H
),
dtype
=
torch
.
int32
,
device
=
device
)
batch_mapping
=
torch
.
arange
(
B
,
dtype
=
torch
.
int32
,
device
=
device
)
num_to_retain
=
torch
.
full
((
B
,),
RETAIN
,
dtype
=
torch
.
int32
,
device
=
device
)
prefill_store_topk_kv
(
new_keys
=
keys
,
new_vals
=
vals
,
indices_topk
=
indices
,
num_tokens_to_retain
=
num_to_retain
,
page_table
=
page_table
,
batch_mapping
=
batch_mapping
,
bh_lens
=
local_lens
,
PAGE_SIZE
=
PAGE_SIZE
,
k_cache
=
k_cache
,
v_cache
=
v_cache
,
PAD_TO_PAGE_SIZE
=
True
,
cu_seqlens_k
=
cu
,
TRITON_RESERVED_BATCH
=-
1
,
)
torch
.
cuda
.
synchronize
()
local_lens_cpu
=
local_lens
.
cpu
()
lens_cpu
=
lens
.
cpu
()
assert
(
local_lens_cpu
%
PAGE_SIZE
==
0
).
all
()
assert
(
local_lens_cpu
<=
lens_cpu
[:,
None
]).
all
()
vllm/compactor-vllm/tests/test_triton_attention.py
0 → 100644
View file @
2b7160c6
import
logging
import
math
from
dataclasses
import
dataclass
from
typing
import
List
import
pytest
import
torch
import
triton
from
flash_attn.flash_attn_interface
import
(
flash_attn_varlen_func
,
flash_attn_with_kvcache
,
)
from
compactor_vllm.attention.sparse_decode_kernel
import
head_sparse_decode_attention
from
compactor_vllm.attention.sparse_varlen_kernel
import
(
causal_sparse_varlen_with_cache
,
)
logger
=
logging
.
getLogger
(
__name__
)
@
dataclass
class
Workload
:
name
:
str
batch_size
:
int
nq_heads
:
int
nk_heads
:
int
head_dim
:
int
cache_lens
:
List
[
int
]
# per-sequence cached context length
append_lens
:
List
[
int
]
# per-sequence new tokens this step (Q_app, K_app, V_app)
WORKLOADS
:
List
[
Workload
]
=
[
Workload
(
name
=
f
"batch_size=
{
BATCH
}
kv_cache_len=
{
cache_lens
}
append_len=
{
append_lens
}
"
f
"HQ=
{
NQ_HEADS
}
HKV=
{
NK_HEADS
}
HEAD_DIM=
{
HEAD_DIM
}
"
,
batch_size
=
BATCH
,
nq_heads
=
NQ_HEADS
,
nk_heads
=
NK_HEADS
,
head_dim
=
HEAD_DIM
,
cache_lens
=
[
cache_lens
]
*
BATCH
,
append_lens
=
[
append_lens
]
*
BATCH
,
)
for
BATCH
in
[
1
,
2
,
3
,
8
]
for
NQ_HEADS
in
[
32
]
for
NK_HEADS
in
[
8
]
for
HEAD_DIM
in
[
128
]
for
cache_lens
in
[
0
,
1
,
70
,
128
,
8193
]
for
append_lens
in
[
1
,
2
,
13
,
8000
]
]
WORKLOADS_DECODE
:
List
[
Workload
]
=
[
Workload
(
name
=
f
"batch_size=
{
BATCH
}
kv_cache_len=
{
cache_lens
}
"
f
"HQ=
{
NQ_HEADS
}
HKV=
{
NK_HEADS
}
HEAD_DIM=
{
HEAD_DIM
}
"
,
batch_size
=
BATCH
,
nq_heads
=
NQ_HEADS
,
nk_heads
=
NK_HEADS
,
head_dim
=
HEAD_DIM
,
cache_lens
=
[
cache_lens
]
*
BATCH
,
append_lens
=
[
1
]
*
BATCH
,
)
for
BATCH
in
[
1
,
2
,
3
,
8
]
for
NQ_HEADS
in
[
32
]
for
NK_HEADS
in
[
8
]
for
HEAD_DIM
in
[
128
]
for
cache_lens
in
[
1
,
2
,
70
,
128
,
8000
]
]
def
build_paged_cache_from_lengths
(
B
,
H_kv
,
D
,
PAGE_SIZE
,
N_LOGICAL_PAGES_MAX
,
L_cache_per_b
,
# int32 [B], per-batch cache length
device
,
dtype
,
):
"""
Construct:
- seq_lens_bh[b, h] = L_cache_per_b[b]
- page_table[b, h, lp] giving physical page ids
- K_cache, V_cache filled for valid cached tokens
Physical layout:
physical_page_id = (b * H_kv + h) * N_LOGICAL_PAGES_MAX + lp
CACHE_SIZE = num_phys_pages * PAGE_SIZE
"""
assert
L_cache_per_b
.
shape
[
0
]
==
B
max_len
=
PAGE_SIZE
*
N_LOGICAL_PAGES_MAX
assert
(
L_cache_per_b
<=
max_len
).
all
()
seq_lens_bh
=
torch
.
empty
((
B
,
H_kv
),
dtype
=
torch
.
int32
,
device
=
device
)
for
b
in
range
(
B
):
seq_lens_bh
[
b
,
:].
fill_
(
L_cache_per_b
[
b
])
num_phys_pages
=
B
*
H_kv
*
N_LOGICAL_PAGES_MAX
CACHE_SIZE
=
num_phys_pages
*
PAGE_SIZE
K_cache
=
torch
.
zeros
((
CACHE_SIZE
,
D
),
device
=
device
,
dtype
=
dtype
)
V_cache
=
torch
.
zeros
((
CACHE_SIZE
,
D
),
device
=
device
,
dtype
=
dtype
)
page_table
=
torch
.
empty
(
(
B
,
H_kv
,
N_LOGICAL_PAGES_MAX
),
device
=
device
,
dtype
=
torch
.
int32
)
# assign unique physical pages per (b, h, lp)
phys_page
=
0
for
b
in
range
(
B
):
for
h
in
range
(
H_kv
):
for
lp
in
range
(
N_LOGICAL_PAGES_MAX
):
page_table
[
b
,
h
,
lp
]
=
phys_page
phys_page
+=
1
# fill cached tokens
g
=
torch
.
Generator
(
device
=
device
).
manual_seed
(
1234
)
for
b
in
range
(
B
):
Lc
=
int
(
L_cache_per_b
[
b
].
item
())
for
h
in
range
(
H_kv
):
for
i
in
range
(
Lc
):
lp
=
i
//
PAGE_SIZE
off
=
i
%
PAGE_SIZE
phys
=
int
(
page_table
[
b
,
h
,
lp
].
item
())
idx
=
phys
*
PAGE_SIZE
+
off
K_cache
[
idx
]
=
torch
.
randn
(
D
,
device
=
device
,
dtype
=
dtype
,
generator
=
g
)
V_cache
[
idx
]
=
torch
.
randn
(
D
,
device
=
device
,
dtype
=
dtype
,
generator
=
g
)
return
K_cache
,
V_cache
,
page_table
,
seq_lens_bh
,
CACHE_SIZE
def
materialize_kv_for_flash_mixed
(
K_cache
,
V_cache
,
page_table
,
L_cache_per_b
,
# [B]
k_append_raw
,
# [N, H_kv, D]
v_append_raw
,
# [N, H_kv, D]
cu_seqlens_qk
,
# [B+1]
H_kv
,
PAGE_SIZE
,
):
"""
Build (K_total, V_total, cu_seqlens_k) for flash_attn_varlen_func such that:
For each batch b:
seqlen_q[b] = L_app[b] = cu[b+1] - cu[b]
seqlen_k[b] = L_cache_per_b[b] + L_app[b]
Keys:
- first L_cache_per_b[b] positions from paged cache
- next L_app[b] positions from k_append_raw for that batch
"""
device
=
K_cache
.
device
dtype
=
K_cache
.
dtype
B
=
cu_seqlens_qk
.
numel
()
-
1
N
,
H_kv_raw
,
D
=
k_append_raw
.
shape
assert
H_kv_raw
==
H_kv
# appended lengths
L_app
=
(
cu_seqlens_qk
[
1
:]
-
cu_seqlens_qk
[:
-
1
]).
to
(
torch
.
int32
)
# [B]
seqlen_k
=
L_cache_per_b
+
L_app
# [B]
cu_seqlens_k
=
torch
.
empty
(
B
+
1
,
device
=
device
,
dtype
=
torch
.
int32
)
cu_seqlens_k
[
0
]
=
0
total_k
=
int
(
seqlen_k
.
sum
().
item
())
K_total
=
torch
.
empty
((
total_k
,
H_kv
,
D
),
device
=
device
,
dtype
=
dtype
)
V_total
=
torch
.
empty
((
total_k
,
H_kv
,
D
),
device
=
device
,
dtype
=
dtype
)
for
b
in
range
(
B
):
offset_k
=
int
(
cu_seqlens_k
[
b
].
item
())
Lc
=
int
(
L_cache_per_b
[
b
].
item
())
La
=
int
(
L_app
[
b
].
item
())
q_start
=
int
(
cu_seqlens_qk
[
b
].
item
())
# cache segment
for
g
in
range
(
H_kv
):
for
i
in
range
(
Lc
):
lp
=
i
//
PAGE_SIZE
off
=
i
%
PAGE_SIZE
phys
=
int
(
page_table
[
b
,
g
,
lp
].
item
())
idx
=
phys
*
PAGE_SIZE
+
off
K_total
[
offset_k
+
i
,
g
]
=
K_cache
[
idx
]
V_total
[
offset_k
+
i
,
g
]
=
V_cache
[
idx
]
# appended segment
if
k_append_raw
.
numel
()
>
0
:
for
g
in
range
(
H_kv
):
for
j
in
range
(
La
):
src
=
q_start
+
j
dst
=
offset_k
+
Lc
+
j
K_total
[
dst
,
g
]
=
k_append_raw
[
src
,
g
]
V_total
[
dst
,
g
]
=
v_append_raw
[
src
,
g
]
cu_seqlens_k
[
b
+
1
]
=
cu_seqlens_k
[
b
]
+
(
Lc
+
La
)
return
K_total
,
V_total
,
cu_seqlens_k
@
pytest
.
mark
.
parametrize
(
"workload"
,
WORKLOADS
,
ids
=
lambda
wl
:
wl
.
name
)
def
test_causal_sparse_varlen_with_cache
(
workload
:
Workload
):
dtype
=
torch
.
float16
device
=
triton
.
runtime
.
driver
.
active
.
get_active_torch_device
()
DEFAULT_PAGE_SIZE
=
256
N_LOGICAL_PAGES_MAX
=
256
L_cache_per_b
=
torch
.
as_tensor
(
workload
.
cache_lens
,
device
=
device
,
dtype
=
torch
.
int32
)
K_cache
,
V_cache
,
page_table
,
seq_lens_bh
,
CACHE_SIZE
=
(
build_paged_cache_from_lengths
(
B
=
workload
.
batch_size
,
H_kv
=
workload
.
nk_heads
,
D
=
workload
.
head_dim
,
PAGE_SIZE
=
DEFAULT_PAGE_SIZE
,
N_LOGICAL_PAGES_MAX
=
N_LOGICAL_PAGES_MAX
,
L_cache_per_b
=
L_cache_per_b
,
device
=
device
,
dtype
=
dtype
,
)
)
assert
len
(
workload
.
append_lens
)
==
workload
.
batch_size
cu
=
[
0
]
for
L
in
workload
.
append_lens
:
cu
.
append
(
cu
[
-
1
]
+
L
)
cu_seqlens_qk
=
torch
.
tensor
(
cu
,
dtype
=
torch
.
int32
,
device
=
device
)
N
=
int
(
cu_seqlens_qk
[
-
1
].
item
())
q_raw
=
torch
.
randn
(
N
,
workload
.
nq_heads
,
workload
.
head_dim
,
device
=
device
,
dtype
=
dtype
)
k_append_raw
=
torch
.
randn
(
N
,
workload
.
nk_heads
,
workload
.
head_dim
,
device
=
device
,
dtype
=
dtype
)
v_append_raw
=
torch
.
randn_like
(
k_append_raw
)
batch_mapping
=
torch
.
arange
(
workload
.
batch_size
,
device
=
device
,
dtype
=
torch
.
int32
)
sm_scale
=
1.0
/
math
.
sqrt
(
workload
.
head_dim
)
K_total
,
V_total
,
cu_seqlens_k
=
materialize_kv_for_flash_mixed
(
K_cache
=
K_cache
,
V_cache
=
V_cache
,
page_table
=
page_table
,
L_cache_per_b
=
L_cache_per_b
,
k_append_raw
=
k_append_raw
,
v_append_raw
=
v_append_raw
,
cu_seqlens_qk
=
cu_seqlens_qk
,
H_kv
=
workload
.
nk_heads
,
PAGE_SIZE
=
DEFAULT_PAGE_SIZE
,
)
max_seqlen_q
=
int
((
cu_seqlens_qk
[
1
:]
-
cu_seqlens_qk
[:
-
1
]).
max
().
item
())
max_seqlen_k
=
int
((
cu_seqlens_k
[
1
:]
-
cu_seqlens_k
[:
-
1
]).
max
().
item
())
max_seqlen_k_triton
=
seq_lens_bh
.
max
().
item
()
out_triton
=
causal_sparse_varlen_with_cache
(
q
=
q_raw
,
k_cache
=
K_cache
,
v_cache
=
V_cache
,
k
=
k_append_raw
,
v
=
v_append_raw
,
seq_lens_bh
=
seq_lens_bh
,
global_page_table
=
page_table
,
batch_mapping
=
batch_mapping
,
cu_seqlens_q
=
cu_seqlens_qk
,
HKV
=
workload
.
nk_heads
,
PAGE_SIZE
=
DEFAULT_PAGE_SIZE
,
sm_scale
=
sm_scale
,
max_seqlen_q
=
max_seqlen_q
,
max_seqlen_k_cache
=
max_seqlen_k_triton
,
)
out_flash
=
flash_attn_varlen_func
(
q
=
q_raw
,
k
=
K_total
,
v
=
V_total
,
cu_seqlens_q
=
cu_seqlens_qk
,
cu_seqlens_k
=
cu_seqlens_k
,
max_seqlen_q
=
max_seqlen_q
,
max_seqlen_k
=
max_seqlen_k
,
dropout_p
=
0.0
,
softmax_scale
=
sm_scale
,
causal
=
True
,
)
assert
torch
.
allclose
(
out_triton
,
out_flash
,
rtol
=
1e-6
,
atol
=
3e-3
)
max_diff
=
(
out_triton
-
out_flash
).
abs
().
max
().
item
()
logger
.
info
(
f
"[causal_sparse_varlen_with_cache:
{
workload
.
name
}
]: max abs diff=
{
max_diff
:
.
5
f
}
"
)
def
materialize_kv_cache_for_flash_decode
(
K_cache
,
V_cache
,
page_table
,
L_cache_per_b
,
# [B] int32
H_kv
:
int
,
PAGE_SIZE
:
int
,
):
"""
Build (K_flash, V_flash) suitable for flash_attn_with_kvcache, with shape:
(B, seqlen_cache_max, H_kv, D)
For each batch b:
- cache_seqlen[b] = L_cache_per_b[b]
- K_flash[b, :cache_seqlen[b], g] and V_flash[...] are filled from the paged KV cache.
- Tokens beyond cache_seqlen[b] (if any) are left as zeros and will be masked out
by flash_attn_with_kvcache via cache_seqlens.
"""
device
=
K_cache
.
device
dtype
=
K_cache
.
dtype
B
=
L_cache_per_b
.
shape
[
0
]
D
=
K_cache
.
shape
[
1
]
seqlen_cache_max
=
int
(
L_cache_per_b
.
max
().
item
())
K_flash
=
torch
.
zeros
((
B
,
seqlen_cache_max
,
H_kv
,
D
),
device
=
device
,
dtype
=
dtype
)
V_flash
=
torch
.
zeros_like
(
K_flash
)
for
b
in
range
(
B
):
Lc
=
int
(
L_cache_per_b
[
b
].
item
())
if
Lc
==
0
:
continue
for
g
in
range
(
H_kv
):
for
i
in
range
(
Lc
):
lp
=
i
//
PAGE_SIZE
off
=
i
%
PAGE_SIZE
phys
=
int
(
page_table
[
b
,
g
,
lp
].
item
())
idx
=
phys
*
PAGE_SIZE
+
off
K_flash
[
b
,
i
,
g
]
=
K_cache
[
idx
]
V_flash
[
b
,
i
,
g
]
=
V_cache
[
idx
]
return
K_flash
,
V_flash
@
pytest
.
mark
.
parametrize
(
"workload"
,
WORKLOADS_DECODE
,
ids
=
lambda
wl
:
wl
.
name
)
def
test_sparse_decode_attention
(
workload
:
Workload
):
dtype
=
torch
.
float16
device
=
triton
.
runtime
.
driver
.
active
.
get_active_torch_device
()
DEFAULT_PAGE_SIZE
=
256
N_LOGICAL_PAGES_MAX
=
256
# per-sequence cache lengths (all equal for WORKLOADS_DECODE)
L_cache_per_b
=
torch
.
as_tensor
(
workload
.
cache_lens
,
device
=
device
,
dtype
=
torch
.
int32
)
# build paged KV cache used by the Triton kernel
K_cache
,
V_cache
,
page_table
,
seq_lens_bh
,
CACHE_SIZE
=
(
build_paged_cache_from_lengths
(
B
=
workload
.
batch_size
,
H_kv
=
workload
.
nk_heads
,
D
=
workload
.
head_dim
,
PAGE_SIZE
=
DEFAULT_PAGE_SIZE
,
N_LOGICAL_PAGES_MAX
=
N_LOGICAL_PAGES_MAX
,
L_cache_per_b
=
L_cache_per_b
,
device
=
device
,
dtype
=
dtype
,
)
)
B
=
workload
.
batch_size
HQ
=
workload
.
nq_heads
HKV
=
workload
.
nk_heads
D
=
workload
.
head_dim
# Triton kernel expects q: [B, HQ, D]
q_triton
=
torch
.
randn
(
B
,
HQ
,
D
,
device
=
device
,
dtype
=
dtype
)
batch_mapping
=
torch
.
arange
(
B
,
device
=
device
,
dtype
=
torch
.
int32
)
sm_scale
=
1.0
/
math
.
sqrt
(
D
)
out_triton
=
head_sparse_decode_attention
(
q
=
q_triton
,
k
=
K_cache
,
v
=
V_cache
,
seq_lens_bh
=
seq_lens_bh
,
global_page_table
=
page_table
,
batch_mapping
=
batch_mapping
,
HKV
=
HKV
,
PAGE_SIZE
=
DEFAULT_PAGE_SIZE
,
sm_scale
=
sm_scale
,
)
# [B, HQ, D]
# materialize contiguous KV cache with shape [B, seqlen_cache_max, HKV, D]
K_flash
,
V_flash
=
materialize_kv_cache_for_flash_decode
(
K_cache
=
K_cache
,
V_cache
=
V_cache
,
page_table
=
page_table
,
L_cache_per_b
=
L_cache_per_b
,
H_kv
=
HKV
,
PAGE_SIZE
=
DEFAULT_PAGE_SIZE
,
)
# flash_attn_with_kvcache expects q: [B, seqlen_q, HQ, D]
q_flash
=
q_triton
.
unsqueeze
(
1
)
# seqlen_q = 1
out_flash
=
flash_attn_with_kvcache
(
q
=
q_flash
,
k_cache
=
K_flash
,
v_cache
=
V_flash
,
cache_seqlens
=
L_cache_per_b
,
softmax_scale
=
sm_scale
,
causal
=
True
,
).
squeeze
(
1
)
# [B, 1, HQ, D]
assert
torch
.
allclose
(
out_triton
,
out_flash
,
rtol
=
1e-6
,
atol
=
3e-3
)
max_diff
=
(
out_triton
-
out_flash
).
abs
().
max
().
item
()
logger
.
info
(
f
"[head_sparse_decode_attention:
{
workload
.
name
}
]: max abs diff=
{
max_diff
:
.
5
f
}
"
)
vllm/compactor-vllm/vllm_memory_comparison.png
0 → 100644
View file @
2b7160c6
79.4 KB
vllm/compactor-vllm/vllm_throughput_comparison.png
0 → 100644
View file @
2b7160c6
98.6 KB
vllm/entrypoints/llm.py
View file @
2b7160c6
...
@@ -2,6 +2,7 @@
...
@@ -2,6 +2,7 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import
itertools
import
itertools
import
os
from
collections.abc
import
Callable
,
Iterable
,
Sequence
from
collections.abc
import
Callable
,
Iterable
,
Sequence
from
pathlib
import
Path
from
pathlib
import
Path
from
typing
import
TYPE_CHECKING
,
Any
from
typing
import
TYPE_CHECKING
,
Any
...
@@ -95,6 +96,7 @@ from vllm.v1.engine.llm_engine import LLMEngine
...
@@ -95,6 +96,7 @@ from vllm.v1.engine.llm_engine import LLMEngine
from
vllm.v1.sample.logits_processor
import
LogitsProcessor
from
vllm.v1.sample.logits_processor
import
LogitsProcessor
if
TYPE_CHECKING
:
if
TYPE_CHECKING
:
from
vllm.kvprune.integration.compression_params
import
CompressionParams
from
vllm.v1.metrics.reader
import
Metric
from
vllm.v1.metrics.reader
import
Metric
logger
=
init_logger
(
__name__
)
logger
=
init_logger
(
__name__
)
...
@@ -184,6 +186,15 @@ class LLM:
...
@@ -184,6 +186,15 @@ class LLM:
enforce_eager: Whether to enforce eager execution. If True, we will
enforce_eager: Whether to enforce eager execution. If True, we will
disable CUDA graph and always execute the model in eager mode.
disable CUDA graph and always execute the model in eager mode.
If False, we will use CUDA graph and eager execution in hybrid.
If False, we will use CUDA graph and eager execution in hybrid.
kvprune_compression: If True, sets ``enforce_eager=True`` for the **v1**
engine only (no v1 CUDA graph capture). If ``None`` (default), read
``VLLM_KVPRUNE_COMPRESSION_DEFAULT`` (``"0"`` = allow v1 graphs;
``"1"`` = skip v1 graphs). This is independent of the compactor's
``LLMConfig.enforce_eager`` (see ``VLLM_KVPRUNE_COMPACTOR_CUDA_GRAPH`` /
``VLLM_KVPRUNE_COMPACTOR_ENFORCE_EAGER``; default tries compactor graphs).
When True, v1's GPU KV pool defaults to **one** block (minimum allowed by
the scheduler) unless ``num_gpu_blocks_override`` is passed in ``**kwargs``
or ``VLLM_KVPRUNE_V1_NUM_GPU_BLOCKS`` is set (``auto`` = profiled allocation).
enable_return_routed_experts: Whether to return routed experts.
enable_return_routed_experts: Whether to return routed experts.
disable_custom_all_reduce: See
disable_custom_all_reduce: See
[ParallelConfig][vllm.config.ParallelConfig].
[ParallelConfig][vllm.config.ParallelConfig].
...
@@ -240,6 +251,7 @@ class LLM:
...
@@ -240,6 +251,7 @@ class LLM:
offload_prefetch_step
:
int
=
1
,
offload_prefetch_step
:
int
=
1
,
offload_params
:
set
[
str
]
|
None
=
None
,
offload_params
:
set
[
str
]
|
None
=
None
,
enforce_eager
:
bool
=
False
,
enforce_eager
:
bool
=
False
,
kvprune_compression
:
bool
|
None
=
None
,
enable_return_routed_experts
:
bool
=
False
,
enable_return_routed_experts
:
bool
=
False
,
disable_custom_all_reduce
:
bool
=
False
,
disable_custom_all_reduce
:
bool
=
False
,
hf_token
:
bool
|
str
|
None
=
None
,
hf_token
:
bool
|
str
|
None
=
None
,
...
@@ -339,6 +351,26 @@ class LLM:
...
@@ -339,6 +351,26 @@ class LLM:
"'examples/offline_inference/data_parallel.py'."
"'examples/offline_inference/data_parallel.py'."
)
)
# v1 ``enforce_eager`` is independent of kvprune compactor ``LLMConfig.enforce_eager``.
if
kvprune_compression
is
None
:
_kvd
=
os
.
environ
.
get
(
"VLLM_KVPRUNE_COMPRESSION_DEFAULT"
,
"0"
).
strip
().
lower
()
kvprune_compression
=
_kvd
in
(
"1"
,
"true"
,
"yes"
)
if
kvprune_compression
:
enforce_eager
=
True
# Reserve minimal v1 GPU KV so compactor can use the rest of VRAM. v1
# scheduler requires num_gpu_blocks >= 1; profiling would allocate a
# large pool from gpu_memory_utilization. Override:
# VLLM_KVPRUNE_V1_NUM_GPU_BLOCKS unset -> 1 block (default)
# VLLM_KVPRUNE_V1_NUM_GPU_BLOCKS=auto -> profiled (no override)
# VLLM_KVPRUNE_V1_NUM_GPU_BLOCKS=<int> -> max(1, int)
if
"num_gpu_blocks_override"
not
in
kwargs
:
_v1_kv
=
os
.
environ
.
get
(
"VLLM_KVPRUNE_V1_NUM_GPU_BLOCKS"
,
""
).
strip
()
if
_v1_kv
.
lower
()
in
(
"auto"
,
"profile"
):
pass
elif
not
_v1_kv
:
kwargs
[
"num_gpu_blocks_override"
]
=
1
else
:
kwargs
[
"num_gpu_blocks_override"
]
=
max
(
1
,
int
(
_v1_kv
))
engine_args
=
EngineArgs
(
engine_args
=
EngineArgs
(
model
=
model
,
model
=
model
,
runner
=
runner
,
runner
=
runner
,
...
@@ -405,6 +437,9 @@ class LLM:
...
@@ -405,6 +437,9 @@ class LLM:
)
)
# Cache for __repr__ to avoid repeated collective_rpc calls
# Cache for __repr__ to avoid repeated collective_rpc calls
self
.
_cached_repr
:
str
|
None
=
None
self
.
_cached_repr
:
str
|
None
=
None
# Lazy compactor engine (``vllm.kvprune``) when :meth:`generate` uses compression.
self
.
_kvprune_compactor_engine
:
Any
=
None
self
.
_kvprune_compression_enabled
=
bool
(
kvprune_compression
)
def
get_tokenizer
(
self
)
->
TokenizerLike
:
def
get_tokenizer
(
self
)
->
TokenizerLike
:
return
self
.
llm_engine
.
get_tokenizer
()
return
self
.
llm_engine
.
get_tokenizer
()
...
@@ -446,6 +481,7 @@ class LLM:
...
@@ -446,6 +481,7 @@ class LLM:
lora_request
:
Sequence
[
LoRARequest
]
|
LoRARequest
|
None
=
None
,
lora_request
:
Sequence
[
LoRARequest
]
|
LoRARequest
|
None
=
None
,
priority
:
list
[
int
]
|
None
=
None
,
priority
:
list
[
int
]
|
None
=
None
,
tokenization_kwargs
:
dict
[
str
,
Any
]
|
None
=
None
,
tokenization_kwargs
:
dict
[
str
,
Any
]
|
None
=
None
,
compression
:
"CompressionParams | Sequence[CompressionParams] | None"
=
None
,
)
->
list
[
RequestOutput
]:
)
->
list
[
RequestOutput
]:
"""Generates the completions for the input prompts.
"""Generates the completions for the input prompts.
...
@@ -473,6 +509,15 @@ class LLM:
...
@@ -473,6 +509,15 @@ class LLM:
of `prompts`, where each priority value corresponds to the prompt
of `prompts`, where each priority value corresponds to the prompt
at the same index.
at the same index.
tokenization_kwargs: Overrides for `tokenizer.encode`.
tokenization_kwargs: Overrides for `tokenizer.encode`.
compression: Optional per-prompt KV compression (``vllm.kvprune``). If any
prompt has ``compression_ratio < 1.0``, the batch is run on the integrated
compactor engine with weights shared from this ``LLM``. Omit or use all
``compression_ratio >= 1`` to use the standard v1 engine only.
Use ``kvprune_compression=True`` or ``VLLM_KVPRUNE_COMPRESSION_DEFAULT=1``
so the v1 engine skips CUDA graph capture. Compactor decode graphs
default on (``VLLM_KVPRUNE_COMPACTOR_CUDA_GRAPH`` default ``1``) with
eager fallback if capture fails; set ``VLLM_KVPRUNE_COMPACTOR_ENFORCE_EAGER=1``
to skip compactor graph capture entirely.
Returns:
Returns:
A list of `RequestOutput` objects containing the
A list of `RequestOutput` objects containing the
...
@@ -485,6 +530,41 @@ class LLM:
...
@@ -485,6 +530,41 @@ class LLM:
"Try passing `--runner generate` to use the model as a "
"Try passing `--runner generate` to use the model as a "
"generative model."
"generative model."
)
)
compression_eff
=
compression
if
compression
is
None
and
getattr
(
self
,
"_kvprune_compression_enabled"
,
False
):
pc
=
self
.
llm_engine
.
vllm_config
.
parallel_config
if
(
pc
.
tensor_parallel_size
>
1
and
pc
.
pipeline_parallel_size
==
1
and
pc
.
data_parallel_size
==
1
):
from
vllm.kvprune.integration.compression_params
import
CompressionParams
from
vllm.kvprune.integration.compressed_generate
import
(
_normalize_prompt_list
,
)
_plist
=
_normalize_prompt_list
(
prompts
)
compression_eff
=
[
CompressionParams
(
compression_ratio
=
1.0
)
for
_
in
_plist
]
if
compression_eff
is
not
None
:
from
vllm.kvprune.integration.compressed_generate
import
(
try_compressed_generate
,
)
compressed_out
=
try_compressed_generate
(
self
,
prompts
,
sampling_params
,
compression
=
compression_eff
,
use_tqdm
=
use_tqdm
,
lora_request
=
lora_request
,
priority
=
priority
,
tokenization_kwargs
=
tokenization_kwargs
,
)
if
compressed_out
is
not
None
:
return
compressed_out
if
sampling_params
is
None
:
if
sampling_params
is
None
:
sampling_params
=
self
.
get_default_sampling_params
()
sampling_params
=
self
.
get_default_sampling_params
()
...
...
vllm/env_override.py
View file @
2b7160c6
...
@@ -4,6 +4,60 @@
...
@@ -4,6 +4,60 @@
import
importlib.util
import
importlib.util
import
os
import
os
# KV-prune (compactor) shared-weight integration needs the v1 engine in-process
# (`worker.get_model()` in the parent). Upstream defaults to multiprocess workers
# (`VLLM_ENABLE_V1_MULTIPROCESSING=1`). If unset, default to in-process so
# `LLM.generate(..., compression=...)` works without requiring env to be set
# before `import vllm`. Set `VLLM_ENABLE_V1_MULTIPROCESSING=1` to restore
# multiprocess workers.
if
"VLLM_ENABLE_V1_MULTIPROCESSING"
not
in
os
.
environ
:
os
.
environ
[
"VLLM_ENABLE_V1_MULTIPROCESSING"
]
=
"0"
# In-process EngineCore (``VLLM_ENABLE_V1_MULTIPROCESSING=0``) shares the process with
# user code; ``import vllm`` already runs ``import torch`` below. TP workers are then
# created via multiprocessing. If we use ``fork`` after CUDA has been initialized in
# the parent, PyTorch raises ``Cannot re-initialize CUDA in forked subprocess``.
# ``_maybe_force_spawn()`` can miss this when CUDA is still uninitialized at the
# moment ``get_mp_context()`` runs, so default to ``spawn`` for worker processes unless
# the user set ``VLLM_WORKER_MULTIPROC_METHOD`` explicitly.
os
.
environ
.
setdefault
(
"VLLM_WORKER_MULTIPROC_METHOD"
,
"spawn"
)
# Tensor-parallel workers use NCCL, which queries **NVML for topology** (independent of
# PyTorch device counting). A faulty GPU on the host (e.g. ``nvidia-smi -L`` shows
# ``Unable to determine the device handle`` for one PCI address) often causes
# ``nvmlDeviceGetHandleByIndex(k) failed`` and then ``ncclCommInitRank`` errors.
# Mitigations: fix or isolate the bad GPU; or **before** ``import vllm`` restrict the
# container to healthy GPUs via UUID, e.g.
# export NVIDIA_VISIBLE_DEVICES=GPU-xxxx,GPU-yyyy,...
# (not only ``CUDA_VISIBLE_DEVICES=0,1,2,3``, which can still leave a dead GPU in
# NVML's enumeration). ``VLLM_KVPRUNE_NCCL_SAFE=1`` only tweaks P2P/IB, not NVML.
# For Docker, also consider ``--shm-size=10g`` or ``--ipc=host``.
if
os
.
environ
.
get
(
"VLLM_KVPRUNE_NCCL_SAFE"
,
""
).
strip
().
lower
()
in
(
"1"
,
"true"
,
"yes"
,
):
os
.
environ
.
setdefault
(
"NCCL_P2P_DISABLE"
,
"1"
)
os
.
environ
.
setdefault
(
"NCCL_IB_DISABLE"
,
"1"
)
# KV-prune: default ``LLM(kvprune_compression=None)`` to skip v1 CUDA graph capture
# (``enforce_eager=True`` on v1 only). Tests set ``VLLM_KVPRUNE_COMPRESSION_DEFAULT=0``
# in ``tests/conftest.py`` before importing vLLM.
os
.
environ
.
setdefault
(
"VLLM_KVPRUNE_COMPRESSION_DEFAULT"
,
"1"
)
# Before first compactor init: opt-in sleep(level=1)+wake_up to discard v1 KV (tests/conftest
# also set 0). Default off now that kvprune path can use num_gpu_blocks_override=1 for v1.
os
.
environ
.
setdefault
(
"VLLM_KVPRUNE_RELEASE_V1_KV"
,
"0"
)
# Optional: ``VLLM_KVPRUNE_ATTENTION_SCHEDULE`` (fa_triton / pdtriton / pdfa) or legacy
# ``VLLM_KVPRUNE_ATTENTION_BACKEND`` see ``vllm/kvprune/integration/config_adapter.py``.
# Optional: ``VLLM_KVPRUNE_SHARED_WEIGHT_GRAPH=1`` experimental compactor decode CUDA graphs.
#
# When ``LLM(..., kvprune_compression=True)`` (or default-on via
# ``VLLM_KVPRUNE_COMPRESSION_DEFAULT``), v1's ``num_gpu_blocks_override`` defaults
# to 1 in ``entrypoints/llm.py`` so the primary engine does not reserve a full
# profiled KV pool on the same GPU as the compactor. Use
# ``VLLM_KVPRUNE_V1_NUM_GPU_BLOCKS=auto`` for profiled blocks, or a positive int.
def
_get_torch_cuda_version
():
def
_get_torch_cuda_version
():
"""Peripheral function to _maybe_set_cuda_compatibility_path().
"""Peripheral function to _maybe_set_cuda_compatibility_path().
...
...
vllm/envs.py
View file @
2b7160c6
...
@@ -1030,6 +1030,21 @@ environment_variables: dict[str, Callable[[], Any]] = {
...
@@ -1030,6 +1030,21 @@ environment_variables: dict[str, Callable[[], Any]] = {
"VLLM_ENABLE_V1_MULTIPROCESSING"
:
lambda
:
bool
(
"VLLM_ENABLE_V1_MULTIPROCESSING"
:
lambda
:
bool
(
int
(
os
.
getenv
(
"VLLM_ENABLE_V1_MULTIPROCESSING"
,
"1"
))
int
(
os
.
getenv
(
"VLLM_ENABLE_V1_MULTIPROCESSING"
,
"1"
))
),
),
# KV-prune / compactor integration (see ``vllm/env_override.py``, ``vllm/kvprune/``).
"VLLM_KVPRUNE_ATTENTION_SCHEDULE"
:
lambda
:
os
.
getenv
(
"VLLM_KVPRUNE_ATTENTION_SCHEDULE"
,
""
),
"VLLM_KVPRUNE_ATTENTION_BACKEND"
:
lambda
:
os
.
getenv
(
"VLLM_KVPRUNE_ATTENTION_BACKEND"
,
""
),
"VLLM_KVPRUNE_COMPRESSION_DEFAULT"
:
lambda
:
os
.
getenv
(
"VLLM_KVPRUNE_COMPRESSION_DEFAULT"
,
""
),
"VLLM_KVPRUNE_RELEASE_V1_KV"
:
lambda
:
os
.
getenv
(
"VLLM_KVPRUNE_RELEASE_V1_KV"
,
""
),
"VLLM_KVPRUNE_NCCL_SAFE"
:
lambda
:
os
.
getenv
(
"VLLM_KVPRUNE_NCCL_SAFE"
,
""
),
"VLLM_KVPRUNE_V1_NUM_GPU_BLOCKS"
:
lambda
:
os
.
getenv
(
"VLLM_KVPRUNE_V1_NUM_GPU_BLOCKS"
,
""
),
"VLLM_LOG_BATCHSIZE_INTERVAL"
:
lambda
:
float
(
"VLLM_LOG_BATCHSIZE_INTERVAL"
:
lambda
:
float
(
os
.
getenv
(
"VLLM_LOG_BATCHSIZE_INTERVAL"
,
"-1"
)
os
.
getenv
(
"VLLM_LOG_BATCHSIZE_INTERVAL"
,
"-1"
)
),
),
...
@@ -1771,6 +1786,12 @@ def compile_factors() -> dict[str, object]:
...
@@ -1771,6 +1786,12 @@ def compile_factors() -> dict[str, object]:
"VLLM_ASSETS_CACHE_MODEL_CLEAN"
,
"VLLM_ASSETS_CACHE_MODEL_CLEAN"
,
"VLLM_WORKER_MULTIPROC_METHOD"
,
"VLLM_WORKER_MULTIPROC_METHOD"
,
"VLLM_ENABLE_V1_MULTIPROCESSING"
,
"VLLM_ENABLE_V1_MULTIPROCESSING"
,
"VLLM_KVPRUNE_ATTENTION_SCHEDULE"
,
"VLLM_KVPRUNE_ATTENTION_BACKEND"
,
"VLLM_KVPRUNE_COMPRESSION_DEFAULT"
,
"VLLM_KVPRUNE_RELEASE_V1_KV"
,
"VLLM_KVPRUNE_NCCL_SAFE"
,
"VLLM_KVPRUNE_V1_NUM_GPU_BLOCKS"
,
"VLLM_V1_OUTPUT_PROC_CHUNK_SIZE"
,
"VLLM_V1_OUTPUT_PROC_CHUNK_SIZE"
,
"VLLM_CPU_KVCACHE_SPACE"
,
"VLLM_CPU_KVCACHE_SPACE"
,
"VLLM_CPU_MOE_PREPACK"
,
"VLLM_CPU_MOE_PREPACK"
,
...
...
vllm/kvprune/__init__.py
0 → 100644
View file @
2b7160c6
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""
KV-cache pruning (compactor-style) under ``vllm.kvprune``.
Use the standard :class:`~vllm.LLM` and pass ``compression=`` to :meth:`~vllm.LLM.generate`
with :class:`CompressionParams` when any prompt needs ``compression_ratio < 1``. The compactor
``LLMEngine`` + ``PagedKVCache`` shares weights with vLLM (no second checkpoint).
Subpackages (``attention``, ``kv_cache``, ``compression``, …) implement the compactor
engine.
"""
from
vllm.kvprune.compression.compression_config
import
CompressionMethod
from
vllm.kvprune.integration
import
CompressionParams
__all__
=
[
"CompressionMethod"
,
"CompressionParams"
,
]
vllm/kvprune/attention/__init__.py
0 → 100644
View file @
2b7160c6
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""Sparse attention Triton kernels (varlen prefill, decode, compile helpers)."""
from
vllm.kvprune.attention.sparse_varlen_kernel
import
causal_sparse_varlen_with_cache
__all__
=
[
"causal_sparse_varlen_with_cache"
]
vllm/kvprune/attention/compile_kernels.py
0 → 100644
View file @
2b7160c6
import
argparse
import
logging
import
math
import
torch
from
vllm.kvprune.attention.sparse_varlen_kernel
import
(
causal_sparse_varlen_with_cache
,
)
logger
=
logging
.
getLogger
(
__name__
)
def
build_mock_paged_cache_from_lengths
(
L_cache_per_b
:
torch
.
Tensor
,
HKV
:
int
,
D
:
int
,
PAGE_SIZE
:
int
,
N_LOGICAL_PAGES_MAX
:
int
,
device
,
dtype
,
):
B
=
len
(
L_cache_per_b
)
max_len
=
PAGE_SIZE
*
N_LOGICAL_PAGES_MAX
assert
(
L_cache_per_b
<=
max_len
).
all
()
seq_lens_bh
=
torch
.
empty
((
B
,
HKV
),
dtype
=
torch
.
int32
,
device
=
device
)
for
b
in
range
(
B
):
seq_lens_bh
[
b
,
:].
fill_
(
L_cache_per_b
[
b
])
num_phys_pages
=
B
*
HKV
*
N_LOGICAL_PAGES_MAX
CACHE_SIZE
=
num_phys_pages
*
PAGE_SIZE
K_cache
=
torch
.
zeros
((
CACHE_SIZE
,
D
),
device
=
device
,
dtype
=
dtype
)
V_cache
=
torch
.
zeros
((
CACHE_SIZE
,
D
),
device
=
device
,
dtype
=
dtype
)
page_table
=
torch
.
empty
(
(
B
,
HKV
,
N_LOGICAL_PAGES_MAX
),
device
=
device
,
dtype
=
torch
.
int32
)
# assign unique physical pages per (b, h, lp)
phys_page
=
0
for
b
in
range
(
B
):
for
h
in
range
(
HKV
):
for
lp
in
range
(
N_LOGICAL_PAGES_MAX
):
page_table
[
b
,
h
,
lp
]
=
phys_page
phys_page
+=
1
for
b
in
range
(
B
):
Lc
=
int
(
L_cache_per_b
[
b
].
item
())
for
h
in
range
(
HKV
):
for
i
in
range
(
Lc
):
lp
=
i
//
PAGE_SIZE
off
=
i
%
PAGE_SIZE
phys
=
int
(
page_table
[
b
,
h
,
lp
].
item
())
idx
=
phys
*
PAGE_SIZE
+
off
K_cache
[
idx
]
=
torch
.
randn
(
D
,
device
=
device
,
dtype
=
dtype
)
V_cache
[
idx
]
=
torch
.
randn
(
D
,
device
=
device
,
dtype
=
dtype
)
return
K_cache
,
V_cache
,
page_table
,
seq_lens_bh
,
CACHE_SIZE
def
autotune_causal_sparse_varlen_with_cache
(
*
,
max_length
:
int
=
16384
,
HKV
:
int
=
8
,
HQ
:
int
=
32
,
D
:
int
=
128
,
PAGE_SIZE
:
int
=
128
,
device
:
str
=
"cuda"
,
dtype
=
torch
.
float16
,
):
"""
Autotune causal_sparse_varlen_with_cache over a sweep of cache/append lengths.
"""
import
itertools
import
tqdm
N_LOGICAL_PAGES_MAX
=
((
max_length
+
PAGE_SIZE
-
1
)
//
PAGE_SIZE
)
*
PAGE_SIZE
B
=
4
# D must be a power of two (kernel requirement).
assert
(
D
&
(
D
-
1
))
==
0
lengths_to_sweep
=
[
0
,
256
]
i
=
9
while
(
v
:
=
(
1
<<
i
))
<
max_length
:
lengths_to_sweep
.
append
(
v
)
i
+=
1
combos
=
list
(
itertools
.
product
(
lengths_to_sweep
,
repeat
=
2
))
logger
.
info
(
"tuning kernels. this may take a few minutes, "
"but only needs to be run once per LLMConfig"
)
for
cache_l
,
append_l
in
tqdm
.
tqdm
(
combos
):
if
cache_l
+
append_l
==
0
:
continue
L_cache_per_b
=
torch
.
tensor
(
[
cache_l
]
*
B
,
device
=
device
,
dtype
=
torch
.
int32
,
)
assert
(
L_cache_per_b
<=
PAGE_SIZE
*
N_LOGICAL_PAGES_MAX
).
all
()
K_cache
,
V_cache
,
page_table
,
seq_lens_bh
,
CACHE_SIZE
=
(
build_mock_paged_cache_from_lengths
(
L_cache_per_b
=
L_cache_per_b
,
HKV
=
HKV
,
D
=
D
,
PAGE_SIZE
=
PAGE_SIZE
,
N_LOGICAL_PAGES_MAX
=
N_LOGICAL_PAGES_MAX
,
device
=
device
,
dtype
=
dtype
,
)
)
L_app_list
=
[
append_l
]
*
B
cu
=
[
0
]
for
L
in
L_app_list
:
cu
.
append
(
cu
[
-
1
]
+
L
)
cu_seqlens_qk
=
torch
.
tensor
(
cu
,
dtype
=
torch
.
int32
,
device
=
device
)
N
=
int
(
cu_seqlens_qk
[
-
1
].
item
())
max_seqlen_q
=
int
((
cu_seqlens_qk
[
1
:]
-
cu_seqlens_qk
[:
-
1
]).
max
().
item
())
max_seqlen_k
=
seq_lens_bh
.
max
().
item
()
q_raw
=
torch
.
randn
(
N
,
HQ
,
D
,
device
=
device
,
dtype
=
dtype
)
k_append_raw
=
torch
.
randn
(
N
,
HKV
,
D
,
device
=
device
,
dtype
=
dtype
)
v_append_raw
=
torch
.
randn
(
N
,
HKV
,
D
,
device
=
device
,
dtype
=
dtype
)
# Identity batch mapping (local batch index == global)
batch_mapping
=
torch
.
arange
(
B
,
device
=
device
,
dtype
=
torch
.
int32
)
sm_scale
=
1.0
/
math
.
sqrt
(
D
)
causal_sparse_varlen_with_cache
(
q
=
q_raw
,
k_cache
=
K_cache
,
v_cache
=
V_cache
,
k
=
k_append_raw
,
v
=
v_append_raw
,
seq_lens_bh
=
seq_lens_bh
,
global_page_table
=
page_table
,
batch_mapping
=
batch_mapping
,
cu_seqlens_q
=
cu_seqlens_qk
,
HKV
=
HKV
,
PAGE_SIZE
=
PAGE_SIZE
,
sm_scale
=
sm_scale
,
max_seqlen_q
=
max_seqlen_q
,
max_seqlen_k_cache
=
max_seqlen_k
,
)
def
_parse_args
()
->
argparse
.
Namespace
:
parser
=
argparse
.
ArgumentParser
(
description
=
"Autotune Triton kernels. "
"Results are cached, so this should only need to be run once per configuration."
"This script doesn't need to be run, as the kernels will be autotuned at runtime"
"if no cached autotuning data exists. Running this before hand will prevent run-time"
"autotuning, which will accelerate compactor-vllm at inference time."
)
parser
.
add_argument
(
"--max-length"
,
type
=
int
,
default
=
16384
,
help
=
"Maximum total sequence length to consider."
,
)
parser
.
add_argument
(
"--HKV"
,
type
=
int
,
default
=
8
,
help
=
"Number of KV heads."
,
)
parser
.
add_argument
(
"--HQ"
,
type
=
int
,
default
=
32
,
help
=
"Number of query heads."
,
)
parser
.
add_argument
(
"--D"
,
type
=
int
,
default
=
128
,
help
=
"Per-head hidden dimension (must be power of 2)."
,
)
parser
.
add_argument
(
"--page-size"
,
type
=
int
,
default
=
128
,
help
=
"Page size (tokens per physical page)."
,
)
parser
.
add_argument
(
"--device"
,
type
=
str
,
default
=
"cuda"
,
help
=
"Torch device to run on (e.g. 'cuda', 'cuda:0', 'cpu')."
,
)
parser
.
add_argument
(
"--dtype"
,
type
=
str
,
default
=
"float16"
,
help
=
"Dtype for tensors: one of {float16, fp16, bfloat16, bf16, float32, fp32}."
,
)
parser
.
add_argument
(
"--log-level"
,
type
=
str
,
default
=
"INFO"
,
choices
=
[
"CRITICAL"
,
"ERROR"
,
"WARNING"
,
"INFO"
,
"DEBUG"
],
help
=
"Logging level."
,
)
return
parser
.
parse_args
()
def
_resolve_dtype
(
dtype_str
:
str
):
s
=
dtype_str
.
lower
()
if
s
in
(
"float16"
,
"fp16"
,
"half"
):
return
torch
.
float16
if
s
in
(
"bfloat16"
,
"bf16"
):
return
torch
.
bfloat16
if
s
in
(
"float32"
,
"fp32"
):
return
torch
.
float32
raise
ValueError
(
f
"Unsupported dtype:
{
dtype_str
}
"
)
def
main
():
args
=
_parse_args
()
logging
.
basicConfig
(
level
=
getattr
(
logging
,
args
.
log_level
.
upper
()),
format
=
"%(asctime)s - %(name)s - %(levelname)s - %(message)s"
,
)
dtype
=
_resolve_dtype
(
args
.
dtype
)
logger
.
info
(
"Starting autotune with max_length=%d, HKV=%d, HQ=%d, D=%d, page_size=%d, "
"device=%s, dtype=%s"
,
args
.
max_length
,
args
.
HKV
,
args
.
HQ
,
args
.
D
,
args
.
page_size
,
args
.
device
,
dtype
,
)
autotune_causal_sparse_varlen_with_cache
(
max_length
=
args
.
max_length
,
HKV
=
args
.
HKV
,
HQ
=
args
.
HQ
,
D
=
args
.
D
,
PAGE_SIZE
=
args
.
page_size
,
device
=
args
.
device
,
dtype
=
dtype
,
)
if
__name__
==
"__main__"
:
logging
.
basicConfig
(
level
=
logging
.
INFO
,
format
=
"%(asctime)s %(levelname)s: %(message)s"
,
)
main
()
vllm/kvprune/attention/fa_paged_bridge.py
0 → 100644
View file @
2b7160c6
# SPDX-License-Identifier: Apache-2.0
"""FlashAttention paths over compactor paged KV (materialize + FA ops).
Used when :class:`~vllm.kvprune.config.engine_config.KvpruneAttentionSchedule`
selects FlashAttention for prefill and/or decode while KV **writes** remain on
Triton (``prefill_store_*``, ``decode_store_kv``).
**Why compactor-vllm looked fine but kvprune ``fa_triton`` + compression did not**
compactor-vllm ``layers/attention.py`` (prefill)::
use_flash_prefill = (backend == FLASH) or (COMPACTOR_TRITON and not do_compression)
if use_flash_prefill:
flash_attn_varlen_func(q, k, v, ...) # dense packed Q/K/V, one length per batch
elif COMPACTOR_TRITON:
causal_sparse_varlen_with_cache(..., seq_lens_bh=...) # paged KV, **per-(b,h)** lengths
So **with compression** (``do_compression``), compactor-vllm **never** runs FlashAttention on
paged top-K KV; it always uses Triton ``causal_sparse_varlen_with_cache``.
kvprune ``fa_triton`` (``FA_PREFILL_TRITON_DECODE``) keeps the intended split: **FA prefill**
+ **Triton decode**. For compressed prefill it calls :func:`flash_prefill_from_paged`, which
builds a dense ``[total_k, H_kv, D]`` tensor and calls ``flash_attn_varlen_func``. That layout
assumes **one cache prefix length per batch row shared by all KV heads** (same ``Lc`` for every
``g`` when copying from ``k_cache``). Top-K retention instead updates ``bh_lens`` with
**different** counts per head (``seq_lens_bh`` shape ``[B, HKV]``). Taking ``max(dim=1)``
(older code) used one ``Lc`` per batch but still filled ``K_total[offset+i, g]`` for every head
``g`` — heads with **shorter** real cache were **over-read**, corrupting attention.
We therefore **require** ``seq_lens_bh[b, :]`` to be constant in ``h`` for each ``b`` before
materializing for FA (see :func:`_require_uniform_kv_lens_per_batch_for_fa_materialize`). If your
retention policy yields unequal per-head lengths, use ``pdtriton`` (Triton prefill) for that
run, or disable compression while using ``fa_triton``.
"""
from
__future__
import
annotations
import
math
from
typing
import
TYPE_CHECKING
import
torch
from
flash_attn.flash_attn_interface
import
flash_attn_func
,
flash_attn_varlen_func
if
TYPE_CHECKING
:
pass
def
_require_uniform_kv_lens_per_batch_for_fa_materialize
(
seq_lens_bh
:
torch
.
Tensor
,
*
,
caller
:
str
)
->
None
:
"""FlashAttention varlen + dense ``[total_k, H_kv, D]`` layout needs one K length per batch."""
if
seq_lens_bh
.
ndim
!=
2
:
raise
ValueError
(
f
"
{
caller
}
: expected seq_lens_bh [B, HKV], got
{
seq_lens_bh
.
shape
}
"
)
row_min
=
seq_lens_bh
.
min
(
dim
=
1
).
values
row_max
=
seq_lens_bh
.
max
(
dim
=
1
).
values
if
not
bool
((
row_min
==
row_max
).
all
().
item
()):
raise
RuntimeError
(
f
"
{
caller
}
: FlashAttention materialization needs identical cached KV lengths "
"across KV heads for each batch row (seq_lens_bh[b, :] constant in h). "
f
"Got per-batch min/max mismatch: min=
{
row_min
.
tolist
()
}
max=
{
row_max
.
tolist
()
}
. "
"Typical top-K compression uses different counts per head; compactor-vllm uses "
"Triton causal_sparse_varlen_with_cache in that case, not FA on materialized paged KV. "
"Use schedule ``pdtriton`` (Triton prefill + Triton decode), or disable compression "
"for this model run with ``fa_triton``."
)
def
materialize_kv_for_flash_prefill
(
k_cache
:
torch
.
Tensor
,
v_cache
:
torch
.
Tensor
,
page_table
:
torch
.
Tensor
,
batch_mapping
:
torch
.
Tensor
,
L_cache_per_b
:
torch
.
Tensor
,
k_append
:
torch
.
Tensor
,
v_append
:
torch
.
Tensor
,
cu_seqlens_q
:
torch
.
Tensor
,
H_kv
:
int
,
PAGE_SIZE
:
int
,
)
->
tuple
[
torch
.
Tensor
,
torch
.
Tensor
,
torch
.
Tensor
]:
"""Build packed K/V for :func:`flash_attn_varlen_func` (cache prefix + append)."""
device
=
k_cache
.
device
dtype
=
k_cache
.
dtype
B
=
cu_seqlens_q
.
numel
()
-
1
N
,
H_kv_raw
,
D
=
k_append
.
shape
assert
H_kv_raw
==
H_kv
L_app
=
(
cu_seqlens_q
[
1
:]
-
cu_seqlens_q
[:
-
1
]).
to
(
torch
.
int32
)
seqlen_k
=
L_cache_per_b
.
to
(
torch
.
int32
)
+
L_app
cu_seqlens_k
=
torch
.
empty
(
B
+
1
,
device
=
device
,
dtype
=
torch
.
int32
)
cu_seqlens_k
[
0
]
=
0
total_k
=
int
(
seqlen_k
.
sum
().
item
())
K_total
=
torch
.
empty
((
total_k
,
H_kv
,
D
),
device
=
device
,
dtype
=
dtype
)
V_total
=
torch
.
empty
((
total_k
,
H_kv
,
D
),
device
=
device
,
dtype
=
dtype
)
for
b
in
range
(
B
):
offset_k
=
int
(
cu_seqlens_k
[
b
].
item
())
Lc
=
int
(
L_cache_per_b
[
b
].
item
())
La
=
int
(
L_app
[
b
].
item
())
q_start
=
int
(
cu_seqlens_q
[
b
].
item
())
b_true
=
int
(
batch_mapping
[
b
].
item
())
for
g
in
range
(
H_kv
):
for
i
in
range
(
Lc
):
lp
=
i
//
PAGE_SIZE
off
=
i
%
PAGE_SIZE
phys
=
int
(
page_table
[
b_true
,
g
,
lp
].
item
())
idx
=
phys
*
PAGE_SIZE
+
off
K_total
[
offset_k
+
i
,
g
]
=
k_cache
[
idx
]
V_total
[
offset_k
+
i
,
g
]
=
v_cache
[
idx
]
for
g
in
range
(
H_kv
):
for
j
in
range
(
La
):
src
=
q_start
+
j
dst
=
offset_k
+
Lc
+
j
K_total
[
dst
,
g
]
=
k_append
[
src
,
g
]
V_total
[
dst
,
g
]
=
v_append
[
src
,
g
]
cu_seqlens_k
[
b
+
1
]
=
cu_seqlens_k
[
b
]
+
(
Lc
+
La
)
return
K_total
,
V_total
,
cu_seqlens_k
def
flash_prefill_from_paged
(
q
:
torch
.
Tensor
,
k_append
:
torch
.
Tensor
,
v_append
:
torch
.
Tensor
,
k_cache
:
torch
.
Tensor
,
v_cache
:
torch
.
Tensor
,
*
,
seq_lens_bh_before
:
torch
.
Tensor
,
global_page_table
:
torch
.
Tensor
,
batch_mapping
:
torch
.
Tensor
,
cu_seqlens_q
:
torch
.
Tensor
,
max_seqlen_q
:
int
,
PAGE_SIZE
:
int
,
HKV
:
int
,
sm_scale
:
float
|
None
,
)
->
torch
.
Tensor
:
"""Prefill attention via FlashAttention-2 varlen after materializing paged KV + append."""
_require_uniform_kv_lens_per_batch_for_fa_materialize
(
seq_lens_bh_before
,
caller
=
"flash_prefill_from_paged"
)
L_cache_per_b
=
seq_lens_bh_before
.
max
(
dim
=
1
).
values
.
to
(
torch
.
int32
)
K_total
,
V_total
,
cu_seqlens_k
=
materialize_kv_for_flash_prefill
(
k_cache
,
v_cache
,
global_page_table
,
batch_mapping
,
L_cache_per_b
,
k_append
,
v_append
,
cu_seqlens_q
,
HKV
,
PAGE_SIZE
,
)
max_seqlen_k
=
int
((
cu_seqlens_k
[
1
:]
-
cu_seqlens_k
[:
-
1
]).
max
().
item
())
return
flash_attn_varlen_func
(
q
,
K_total
,
V_total
,
cu_seqlens_q
=
cu_seqlens_q
,
cu_seqlens_k
=
cu_seqlens_k
,
max_seqlen_q
=
max_seqlen_q
,
max_seqlen_k
=
max_seqlen_k
,
softmax_scale
=
sm_scale
if
sm_scale
is
not
None
else
None
,
causal
=
True
,
)
def
materialize_kv_cache_for_flash_decode
(
k_cache
:
torch
.
Tensor
,
v_cache
:
torch
.
Tensor
,
page_table
:
torch
.
Tensor
,
batch_mapping
:
torch
.
Tensor
,
L_cache_per_b
:
torch
.
Tensor
,
H_kv
:
int
,
PAGE_SIZE
:
int
,
)
->
tuple
[
torch
.
Tensor
,
torch
.
Tensor
]:
"""Dense ``[B, S, H_kv, D]`` cache for :func:`flash_attn_func` decode."""
device
=
k_cache
.
device
dtype
=
k_cache
.
dtype
B
=
L_cache_per_b
.
shape
[
0
]
D
=
k_cache
.
shape
[
1
]
seqlen_cache_max
=
int
(
L_cache_per_b
.
max
().
item
())
K_flash
=
torch
.
zeros
((
B
,
seqlen_cache_max
,
H_kv
,
D
),
device
=
device
,
dtype
=
dtype
)
V_flash
=
torch
.
zeros_like
(
K_flash
)
for
b
in
range
(
B
):
Lc
=
int
(
L_cache_per_b
[
b
].
item
())
if
Lc
==
0
:
continue
b_true
=
int
(
batch_mapping
[
b
].
item
())
for
g
in
range
(
H_kv
):
for
i
in
range
(
Lc
):
lp
=
i
//
PAGE_SIZE
off
=
i
%
PAGE_SIZE
phys
=
int
(
page_table
[
b_true
,
g
,
lp
].
item
())
idx
=
phys
*
PAGE_SIZE
+
off
K_flash
[
b
,
i
,
g
]
=
k_cache
[
idx
]
V_flash
[
b
,
i
,
g
]
=
v_cache
[
idx
]
return
K_flash
,
V_flash
def
flash_decode_from_paged
(
q
:
torch
.
Tensor
,
k_cache
:
torch
.
Tensor
,
v_cache
:
torch
.
Tensor
,
*
,
seq_lens_bh
:
torch
.
Tensor
,
global_page_table
:
torch
.
Tensor
,
batch_mapping
:
torch
.
Tensor
,
PAGE_SIZE
:
int
,
HKV
:
int
,
sm_scale
:
float
|
None
,
)
->
torch
.
Tensor
:
"""Decode step via FA: ``decode_store_kv`` has already appended the new K/V row."""
_require_uniform_kv_lens_per_batch_for_fa_materialize
(
seq_lens_bh
,
caller
=
"flash_decode_from_paged"
)
L_cache_per_b
=
seq_lens_bh
.
max
(
dim
=
1
).
values
.
to
(
torch
.
int32
)
K_flash
,
V_flash
=
materialize_kv_cache_for_flash_decode
(
k_cache
,
v_cache
,
global_page_table
,
batch_mapping
,
L_cache_per_b
,
HKV
,
PAGE_SIZE
,
)
B
,
HQ
,
D
=
q
.
shape
q_b
=
q
.
unsqueeze
(
1
)
if
sm_scale
is
None
:
sm_scale
=
1.0
/
math
.
sqrt
(
D
)
# One query position attends to all L keys already materialized in K/V (no causal mask).
out
=
flash_attn_func
(
q_b
,
K_flash
,
V_flash
,
softmax_scale
=
sm_scale
,
causal
=
False
,
)
return
out
.
squeeze
(
1
)
vllm/kvprune/attention/sparse_decode_kernel.py
0 → 100644
View file @
2b7160c6
import
functools
import
math
import
torch
import
triton
import
triton.language
as
tl
from
vllm.kvprune.utils.triton_compat
import
(
autotune
as
triton_autotune
,
maybe_set_allocator
,
)
def
head_sparse_decode_attention
(
q
:
torch
.
Tensor
,
k
:
torch
.
Tensor
,
v
:
torch
.
Tensor
,
seq_lens_bh
:
torch
.
Tensor
,
global_page_table
:
torch
.
Tensor
,
batch_mapping
:
torch
.
Tensor
,
HKV
:
int
,
PAGE_SIZE
:
int
,
sm_scale
:
float
=
None
,
key_split
:
int
=
None
,
):
"""
Decode-time head-sparse attention over a paged KV cache.
This is a wrapper around the Triton decode kernel used during incremental
generation. For each batch, we read the cached keys
and values from a global paged KV buffer, apply causal attention with one
new query token, and return the attention output.
The KV cache is stored in a single global K/V tensor of shape
``[CACHE_SIZE, D]`` and indexed via a per-layer page table. Each logical
(batch, kv_head, token_idx) is mapped to a physical row in the cache by:
1. Looking up the logical page index in ``global_page_table[b, h, lp]``,
2. Computing ``phys_row = page_id * PAGE_SIZE + (token_idx % PAGE_SIZE)``.
Grouped-query attention (GQA / MQA) is supported by passing more query
heads than KV heads (``HQ`` must be a multiple of ``HKV``).
Args:
:param q: Query tensor of shape ``[B, HQ, D]`` or `[B, 1, HQ, D]``
containing the new decode tokens for each sequence in the launch batch.
:param k: Global key cache of shape ``[CACHE_SIZE, D]``. This is the shared
backing buffer for all (batch, head) KV pages.
:param v: Global value cache of shape ``[CACHE_SIZE, D]``.
:param seq_lens_bh: Tensor of shape ``[B, HKV]`` (int32) giving, for each
local batch index and KV head, the number of valid cached tokens
in the paged KV cache.
:param global_page_table: Tensor of shape
``[MAX_NUM_BATCHES, HKV, N_LOGICAL_PAGES_MAX]`` (int32) mapping
``(true_batch_idx, kv_head, logical_page)`` to a physical page id
in the global cache.
:param batch_mapping: Tensor of shape ``[B]`` (int32) mapping the launch-batch
index used by this call to the true batch row used to index
``global_page_table``.
:param HKV: Number of KV heads.
:param PAGE_SIZE: Number of tokens stored per physical KV page.
:param sm_scale: Optional scaling factor applied to the attention logits
before softmax. If ``None``, ``1 / sqrt(D)`` is used.
:param key_split: Optional number of splits along the key sequence length.
If > 1, the kernel will process the KV sequence in ``key_split``
chunks to reduce on-chip memory usage. If ``None`` or 0, a
heuristic is used.
Returns:
:return torch.Tensor: Attention output of shape ``[B, HQ, D]`` on the same
device and dtype as ``q``.
"""
with
torch
.
cuda
.
device
(
q
.
device
):
if
q
.
ndim
!=
3
:
assert
q
.
ndim
==
4
B
,
HQ
,
S
,
D
=
q
.
shape
assert
S
==
1
,
"head_sparse_decode_attention only supports q_len=1"
q
=
q
.
squeeze
(
-
2
)
elif
q
.
ndim
==
3
:
B
,
HQ
,
D
=
q
.
shape
CACHE_SIZE
=
k
.
shape
[
0
]
assert
PAGE_SIZE
%
32
==
0
,
"PAGE_SIZE must be divisible by 32"
GROUP_M
=
HQ
//
HKV
assert
GROUP_M
*
HKV
==
HQ
,
"HQ must be divisible by H_kv"
FP8
=
hasattr
(
torch
,
"float8_e5m2"
)
and
q
.
dtype
==
torch
.
float8_e5m2
seq_lens_bh
=
seq_lens_bh
.
to
(
torch
.
int32
)
assert
B
<=
32767
,
"too many batches"
assert
global_page_table
.
shape
[
1
]
==
HKV
assert
q
.
is_contiguous
()
k
=
k
.
contiguous
()
v
=
v
.
contiguous
()
global_page_table
=
global_page_table
.
contiguous
()
batch_mapping
=
batch_mapping
.
contiguous
()
assert
(
D
&
(
D
-
1
))
==
0
,
"D must be a power of 2"
N_LOGICAL_PAGES_MAX
=
global_page_table
.
shape
[
-
1
]
sm_scale
=
1
/
math
.
sqrt
(
D
)
if
sm_scale
is
None
else
sm_scale
if
key_split
is
None
:
# round max_seq_len to the next power of two to maximize cache hits
key_split
=
num_splits_heuristic
(
B
*
HKV
,
max_seq_len
=
1
<<
int
(
seq_lens_bh
.
max
()).
bit_length
(),
num_sms
=
torch
.
cuda
.
get_device_properties
(
q
.
device
).
multi_processor_count
,
max_splits
=
12
,
)
maybe_set_allocator
(
lambda
size
,
align
,
_
:
torch
.
empty
(
size
,
dtype
=
torch
.
int8
,
device
=
q
.
device
)
)
# stage 1 scratch
mid_o
=
torch
.
empty
((
B
,
key_split
,
HQ
,
D
),
device
=
q
.
device
,
dtype
=
q
.
dtype
)
mid_lse
=
torch
.
empty
((
B
,
key_split
,
HQ
),
device
=
q
.
device
,
dtype
=
torch
.
float32
)
# processes all queries for a KV head together
# pointers are lowercase, CONSTANTS are upper
grid1
=
(
B
,
HKV
,
key_split
)
_varkv_stage1_groupM
[
grid1
](
q
=
q
,
k
=
k
,
v
=
v
,
mid_o
=
mid_o
,
mid_lse
=
mid_lse
,
page_table_bhl
=
global_page_table
,
batch_mapping
=
batch_mapping
,
seq_lens_bh
=
seq_lens_bh
.
contiguous
(),
SM_SCALE
=
sm_scale
,
B
=
B
,
HKV
=
HKV
,
HQ
=
HQ
,
CACHE_SIZE
=
CACHE_SIZE
,
STRIDE_LBS
=
mid_lse
.
stride
(
0
),
STRIDE_LS
=
mid_lse
.
stride
(
1
),
STRIDE_LH
=
mid_lse
.
stride
(
2
),
N_LOGICAL_PAGES_MAX
=
N_LOGICAL_PAGES_MAX
,
D
=
D
,
KEY_SPLIT
=
key_split
,
GROUP_M
=
GROUP_M
,
DTYPE
=
tl
.
float8e5
if
FP8
else
(
tl
.
bfloat16
if
q
.
dtype
==
torch
.
bfloat16
else
tl
.
float16
),
PAGE_SIZE
=
PAGE_SIZE
,
)
if
key_split
==
1
:
return
mid_o
.
squeeze
(
1
).
contiguous
()
# reduce partial results across splits
output
=
torch
.
empty_like
(
q
)
grid2
=
(
B
,
HQ
)
_varkv_stage2_reduce
[
grid2
](
mid_o
=
mid_o
,
mid_lse
=
mid_lse
,
output
=
output
,
STRIDE_LBS
=
mid_lse
.
stride
(
0
),
STRIDE_LS
=
mid_lse
.
stride
(
1
),
STRIDE_LH
=
mid_lse
.
stride
(
2
),
STRIDE_OBS
=
output
.
stride
(
0
),
STRIDE_OH
=
output
.
stride
(
1
),
B
=
B
,
HQ
=
HQ
,
D
=
D
,
# type: ignore
KEY_SPLIT
=
key_split
,
# type: ignore
DTYPE
=
tl
.
float8e5
if
FP8
else
(
tl
.
bfloat16
if
q
.
dtype
==
torch
.
bfloat16
else
tl
.
float16
),
)
return
output
# similar to flash attention split heuristic
@
functools
.
lru_cache
(
maxsize
=
128
)
def
num_splits_heuristic
(
total_mblocks
:
int
,
max_seq_len
:
int
,
num_sms
:
int
,
max_splits
:
int
,
)
->
int
:
# If we nearly fill SMs already, prefer 1 split
if
total_mblocks
>=
0.8
*
num_sms
or
max_seq_len
<=
1024
:
return
1
eff
=
[]
max_eff
=
0.0
for
s
in
range
(
1
,
min
(
max_splits
,
num_sms
)
+
1
):
if
(
max_seq_len
/
s
)
<=
512
:
break
n_waves
=
float
(
total_mblocks
*
s
)
/
float
(
num_sms
)
e
=
n_waves
/
math
.
ceil
(
n_waves
)
if
n_waves
>
0
else
0.0
eff
.
append
(
e
)
max_eff
=
max
(
max_eff
,
e
)
threshold
=
0.75
*
max_eff
# if not split_min_hit else 0.9 * max_eff
for
i
,
e
in
enumerate
(
eff
,
start
=
1
):
if
e
>=
threshold
:
return
i
return
1
def
prune_invalid_configs
(
configs
,
_
,
**
kwargs
):
PAGE_SIZE
=
kwargs
[
"PAGE_SIZE"
]
return
[
conf
for
conf
in
configs
if
conf
.
kwargs
.
get
(
"BLOCK_N"
,
0
)
<=
PAGE_SIZE
]
@
triton_autotune
(
configs
=
[
triton
.
Config
(
{
"BLOCK_N"
:
BLOCK_N
,
"MIN_BLOCK_KV"
:
MIN_BLOCK_KV
,
"WARPSPEC"
:
ws
},
num_warps
=
w
,
num_stages
=
s
,
)
for
BLOCK_N
in
[
32
,
64
,
128
]
for
MIN_BLOCK_KV
in
[
8
]
for
s
in
[
2
,
3
,
4
]
for
w
in
[
4
,
8
]
for
ws
in
[
True
,
False
]
],
key
=
[
"HKV"
,
"GROUP_M"
,
"D"
,
"PAGE_SIZE"
,
# "B"
],
cache_results
=
True
,
prune_configs_by
=
{
"early_config_prune"
:
prune_invalid_configs
},
)
@
triton
.
jit
def
_varkv_stage1_groupM
(
q
,
# [B, HQ, D] contiguous
k
,
# GLOBAL cache: [CACHE_SIZE, D], contiguous
v
,
# GLOBAL cache: [CACHE_SIZE, D], contiguous
mid_o
,
mid_lse
,
page_table_bhl
,
# int32 [B*H_kv*N_LOGICAL_PAGES_MAX] (flattened)
batch_mapping
,
# int32 [B] maps local pid_b -> true batch index
seq_lens_bh
,
# int32 [B*H_kv] valid tokens per (b,h)
SM_SCALE
,
B
,
HKV
,
HQ
,
CACHE_SIZE
,
# CACHE_SIZE = N_PAGES * PAGE_SIZE
STRIDE_LBS
,
STRIDE_LS
,
STRIDE_LH
,
# constexprs
N_LOGICAL_PAGES_MAX
:
tl
.
constexpr
,
# page table width per (b,h)
D
:
tl
.
constexpr
,
KEY_SPLIT
:
tl
.
constexpr
,
GROUP_M
:
tl
.
constexpr
,
DTYPE
:
tl
.
constexpr
,
BLOCK_N
:
tl
.
constexpr
,
MIN_BLOCK_KV
:
tl
.
constexpr
,
WARPSPEC
:
tl
.
constexpr
,
PAGE_SIZE
:
tl
.
constexpr
,
):
pid_b
=
tl
.
program_id
(
0
)
# batch
pid_kvh
=
tl
.
program_id
(
1
)
# kv head
pid_s
=
tl
.
program_id
(
2
)
# split
# valid length L for this (b,h)
bh_stride
=
HKV
L
=
tl
.
load
(
seq_lens_bh
+
pid_b
*
bh_stride
+
pid_kvh
)
if
L
==
0
:
return
tl
.
assume
(
L
>
0
)
# split sizing on logical token axis [0..L)
base
=
tl
.
cdiv
(
L
,
KEY_SPLIT
)
per_split_len
=
tl
.
cdiv
(
base
,
MIN_BLOCK_KV
)
*
MIN_BLOCK_KV
split_start
=
pid_s
*
per_split_len
split_end
=
tl
.
minimum
(
split_start
+
per_split_len
,
L
)
# query heads mapped to this kv head
base_qh
=
pid_kvh
*
GROUP_M
GROUP_M_PAD
:
tl
.
constexpr
=
16
if
GROUP_M
<
16
else
GROUP_M
offs_m
=
tl
.
arange
(
0
,
GROUP_M_PAD
)
mask_m
=
offs_m
<
GROUP_M
offs_d
=
tl
.
arange
(
0
,
D
)
# load Q tile [M, D]
q_ptrs
=
q
+
(
pid_b
*
HQ
+
base_qh
+
offs_m
)[:,
None
]
*
D
+
offs_d
[
None
,
:]
q
=
tl
.
load
(
q_ptrs
,
mask
=
mask_m
[:,
None
],
other
=
0.0
).
to
(
DTYPE
)
# [M, D]
# streaming softmax state per query
e_max
=
tl
.
zeros
([
GROUP_M_PAD
],
dtype
=
tl
.
float32
)
-
float
(
"inf"
)
e_sum
=
tl
.
zeros
([
GROUP_M_PAD
],
dtype
=
tl
.
float32
)
acc
=
tl
.
zeros
([
GROUP_M_PAD
,
D
],
dtype
=
tl
.
float32
)
if
split_end
>
split_start
:
# logical pages covering [split_start, split_end)
lp0
=
split_start
//
PAGE_SIZE
lp1
=
tl
.
cdiv
(
split_end
,
PAGE_SIZE
)
# exclusive
mapped_b
=
tl
.
load
(
batch_mapping
+
pid_b
)
tl
.
assume
(
mapped_b
>=
0
)
# page table base for this (b,h)
pt_stride
=
N_LOGICAL_PAGES_MAX
pt_base
=
(
mapped_b
*
HKV
+
pid_kvh
)
*
pt_stride
for
lp
in
tl
.
range
(
lp0
,
lp1
):
phys
=
tl
.
load
(
page_table_bhl
+
pt_base
+
lp
,
cache_modifier
=
".cg"
)
# physical page id
# bounds within the logical page
local_start
=
tl
.
where
(
lp
==
lp0
,
split_start
-
lp
*
PAGE_SIZE
,
0
)
local_end
=
tl
.
where
(
lp
==
(
lp1
-
1
),
split_end
-
lp
*
PAGE_SIZE
,
PAGE_SIZE
)
page_base
=
phys
*
PAGE_SIZE
page_base
=
tl
.
multiple_of
(
page_base
,
BLOCK_N
)
for
s
in
tl
.
range
(
local_start
,
local_end
,
BLOCK_N
):
s
=
tl
.
multiple_of
(
s
,
MIN_BLOCK_KV
)
offs_bn
=
tl
.
arange
(
0
,
BLOCK_N
)
key_idx
=
page_base
+
s
+
offs_bn
k_ptrs
=
k
+
key_idx
[:,
None
]
*
D
+
offs_d
[
None
,
:]
k_blk
=
tl
.
load
(
k_ptrs
,
mask
=
(
key_idx
<
CACHE_SIZE
)[:,
None
],
other
=
0.0
)
qk
=
tl
.
dot
(
q
,
k_blk
.
T
)
*
SM_SCALE
# [M, BN]
offs_n
=
s
+
tl
.
arange
(
0
,
BLOCK_N
)
mask_n
=
offs_n
<
local_end
qk
=
tl
.
where
(
mask_n
[
None
,
:],
qk
,
-
float
(
"inf"
))
n_e_max
=
tl
.
maximum
(
tl
.
max
(
qk
,
1
),
e_max
)
# [M]
re_scale
=
tl
.
exp
(
e_max
-
n_e_max
)
# [M]
acc
=
acc
*
re_scale
[:,
None
]
# [M, D]
v_ptrs
=
v
+
key_idx
[:,
None
]
*
D
+
offs_d
[
None
,
:]
v_blk
=
tl
.
load
(
v_ptrs
,
mask
=
(
key_idx
<
CACHE_SIZE
)[:,
None
],
other
=
0.0
)
p
=
tl
.
exp
(
qk
-
n_e_max
[:,
None
])
# [M, BN]
acc
=
tl
.
dot
(
p
.
to
(
DTYPE
),
v_blk
,
acc
)
e_sum
=
e_sum
*
re_scale
+
tl
.
sum
(
p
,
1
)
e_max
=
n_e_max
# write mid outputs [M, D] for this split
tmp
=
(
acc
/
e_sum
[:,
None
]).
to
(
DTYPE
)
row_mid
=
pid_b
*
(
KEY_SPLIT
*
HQ
)
+
pid_s
*
HQ
+
base_qh
+
offs_m
mid_ptrs
=
mid_o
+
row_mid
[:,
None
]
*
D
+
offs_d
[
None
,
:]
tl
.
store
(
mid_ptrs
,
tmp
,
mask
=
mask_m
[:,
None
])
ml_ptrs
=
(
mid_lse
+
pid_b
*
STRIDE_LBS
+
pid_s
*
STRIDE_LS
+
(
base_qh
+
offs_m
)
*
STRIDE_LH
)
safe_sum
=
tl
.
where
(
mask_m
,
e_sum
,
1.0
)
tl
.
store
(
ml_ptrs
,
e_max
+
tl
.
log
(
safe_sum
),
mask
=
mask_m
)
else
:
# empty split
zero_md
=
tl
.
zeros
([
GROUP_M_PAD
,
D
],
dtype
=
DTYPE
)
row_mid
=
pid_b
*
(
KEY_SPLIT
*
HQ
)
+
pid_s
*
HQ
+
base_qh
+
offs_m
mid_ptrs
=
mid_o
+
row_mid
[:,
None
]
*
D
+
offs_d
[
None
,
:]
tl
.
store
(
mid_ptrs
,
zero_md
,
mask
=
mask_m
[:,
None
])
ml_ptrs
=
(
mid_lse
+
pid_b
*
STRIDE_LBS
+
pid_s
*
STRIDE_LS
+
(
base_qh
+
offs_m
)
*
STRIDE_LH
)
tl
.
store
(
ml_ptrs
,
-
float
(
"inf"
),
mask
=
mask_m
)
@
triton
.
jit
def
_varkv_stage2_reduce
(
mid_o
,
mid_lse
,
output
,
STRIDE_LBS
,
STRIDE_LS
,
STRIDE_LH
,
STRIDE_OBS
,
STRIDE_OH
,
B
,
HQ
,
D
:
tl
.
constexpr
,
KEY_SPLIT
:
tl
.
constexpr
,
DTYPE
:
tl
.
constexpr
,
):
pid_b
=
tl
.
program_id
(
0
)
pid_h
=
tl
.
program_id
(
1
)
offs_d
=
tl
.
arange
(
0
,
D
)
# across split LSE combine
e_sum
=
0.0
e_max
=
-
float
(
"inf"
)
acc
=
tl
.
zeros
([
D
],
dtype
=
tl
.
float32
)
for
s
in
tl
.
range
(
KEY_SPLIT
):
row_mid
=
pid_b
*
(
KEY_SPLIT
*
HQ
)
+
s
*
HQ
+
pid_h
tv
=
tl
.
load
(
mid_o
+
row_mid
*
D
+
offs_d
).
to
(
DTYPE
)
tl_ptr
=
mid_lse
+
pid_b
*
STRIDE_LBS
+
s
*
STRIDE_LS
+
pid_h
*
STRIDE_LH
tlogic
=
tl
.
load
(
tl_ptr
)
n_e_max
=
tl
.
maximum
(
e_max
,
tlogic
)
old_scale
=
tl
.
exp
(
e_max
-
n_e_max
)
acc
=
acc
*
old_scale
+
tl
.
exp
(
tlogic
-
n_e_max
)
*
tv
.
to
(
tl
.
float32
)
e_sum
=
e_sum
*
old_scale
+
tl
.
exp
(
tlogic
-
n_e_max
)
e_max
=
n_e_max
o
=
(
acc
/
e_sum
).
to
(
DTYPE
)
o_ptr
=
output
+
pid_b
*
STRIDE_OBS
+
pid_h
*
STRIDE_OH
+
offs_d
tl
.
store
(
o_ptr
,
o
)
vllm/kvprune/attention/sparse_varlen_kernel.py
0 → 100644
View file @
2b7160c6
import
logging
import
math
import
torch
import
triton
import
triton.language
as
tl
from
flash_attn.flash_attn_interface
import
flash_attn_varlen_func
from
vllm.kvprune.utils.triton_compat
import
(
autotune
as
triton_autotune
,
cuda_capability_geq
,
maybe_set_allocator
,
)
logger
=
logging
.
getLogger
(
__name__
)
def
_causal_appended_only_exact
(
q
:
torch
.
Tensor
,
k
:
torch
.
Tensor
,
v
:
torch
.
Tensor
,
cu_seqlens_q
:
torch
.
Tensor
,
*
,
sm_scale
:
float
,
max_seqlen_q
:
int
,
)
->
torch
.
Tensor
:
"""Exact zero-prefix prefill attention over appended q/k/v only.
This is the mathematically correct subcase of
:func:`causal_sparse_varlen_with_cache` when there is no cached KV prefix.
It avoids the problematic Triton on-band appended branch while preserving
``pdtriton`` semantics for later cached-prefix steps. Use the same
``flash_attn_varlen_func`` path as the debug reference so this subcase is
numerically identical to the known-good result.
"""
return
flash_attn_varlen_func
(
q
,
k
,
v
,
cu_seqlens_q
=
cu_seqlens_q
,
cu_seqlens_k
=
cu_seqlens_q
,
max_seqlen_q
=
max_seqlen_q
,
max_seqlen_k
=
max_seqlen_q
,
softmax_scale
=
sm_scale
,
causal
=
True
,
)
def
causal_sparse_varlen_with_cache
(
q
,
k
,
v
,
k_cache
,
v_cache
,
seq_lens_bh
,
global_page_table
,
batch_mapping
,
cu_seqlens_q
,
max_seqlen_q
:
int
,
max_seqlen_k_cache
:
int
,
HKV
:
int
,
PAGE_SIZE
:
int
,
sm_scale
=
None
,
):
"""
Causal prefill attention over a paged KV cache plus a block of newly
appended tokens in a packed batch format.
This function wraps the Triton kernel
``_causal_head_sparse_varlen_with_cache`` to compute prefill attention for
a batch of variable-length sequences, where:
鈥?Past keys/values are stored in a paged global KV cache
(``k_cache``, ``v_cache``) with a (per-layer) page table.
鈥?New tokens for this step are given as K/V blocks
(``k``, ``v``), together with a packed query block ``q``.
鈥?The result is equivalent to applying causal attention over the
concatenation of:
[ cached KV prefix || (K_app, V_app) for this step ]
for each sequence in the batch.
Grouped-query attention (GQA / MQA) is supported by allowing more query
heads than KV heads: ``HQ`` must be divisible by ``HKV``.
Args:
:param q:
Query tensor of shape ``[N, HQ, D]`` (float16 / bfloat16/float32).
``N`` is the total number of new tokens across the batch
(i.e. ``N = sum_b seqlen_q[b]``), packed according to
``cu_seqlens_q``. ``HQ`` is the number of query heads, ``D`` the
head dimension (must be a power of two).
:param k:
New key tensor of shape ``[N, HKV, D]`` for the same tokens as
``q``. These are the K values appended to the cache for this
prefill step.
:param v:
New value tensor of shape ``[N, HKV, D]`` for the same tokens as
``q``.
:param k_cache:
Global key cache backing buffer of shape ``[CACHE_SIZE, D]``.
Keys for all cached tokens and heads are stored here; the mapping
from (batch, head, token index) to a row in this buffer is
given by ``global_page_table``.
:param v_cache:
Global value cache of shape ``[CACHE_SIZE, D]``. Must have the
same layout as ``k_cache`` (same ``CACHE_SIZE`` and ``D``).
:param seq_lens_bh:
Tensor of shape ``[B, HKV]`` (int32) giving, for each local batch
index and KV head, the number of cached tokens already present
in the paged KV cache before this prefill step.
:param global_page_table:
Tensor of shape ``[MAX_NUM_BATCHES, HKV, N_LOGICAL_PAGES_MAX]`` (int32)
mapping ``(true_batch_idx, kv_head, logical_page)`` to a physical
page id in the global KV cache. A physical page id `p` refers to
the slice:
``k_cache[p * PAGE_SIZE : (p + 1) * PAGE_SIZE]``.
:param batch_mapping:
Tensor of shape ``[B]`` (int16 / int32) mapping the local batch
index used in this kernel launch to the global batch index used
to index ``global_page_table``. This allows the same global cache
to be shared across multiple microbatches.
:param cu_seqlens_q:
Tensor of shape ``[B + 1]`` (int32) with cumulative sequence
lengths for the *new* tokens (q/k/v) in packed form. For batch
element ``b``:
``seqlen_q[b] = cu_seqlens_q[b + 1] - cu_seqlens_q[b]``.
The total number of tokens satisfies
``N = cu_seqlens_q[-1]``.
:param max_seqlen_q:
Maximum new query sequence length across the batch, i.e.
``max_b seqlen_q[b]``.
:param max_seqlen_k_cache:
Maximum cached sequence length across (batch, KV head), i.e.
``max_{b,h} seq_lens_bh[b, h]``.
:param HKV:
Number of KV heads. Must divide ``HQ``.
:param PAGE_SIZE:
Number of tokens stored per physical page in the paged KV cache.
``CACHE_SIZE`` must be divisible by ``PAGE_SIZE``.
:param sm_scale:
Optional scaling factor applied to the attention logits before
softmax. If ``None``, defaults to ``1.0 / sqrt(D)``.
:returns torch.Tensor:
Attention output of shape ``[N, HQ, D]``, with the same dtype and
device as ``q``. The output is laid out in the same packed
varlen format as the input queries, i.e. the first
``seqlen_q[0]`` rows correspond to batch 0, the next
``seqlen_q[1]`` rows to batch 1, etc.
"""
assert
q
.
ndim
==
3
,
"q should be [N, HQ, D]"
N
,
HQ
,
D
=
q
.
shape
assert
(
D
&
(
D
-
1
))
==
0
,
"D must be power of two"
B
=
cu_seqlens_q
.
numel
()
-
1
assert
B
>
0
assert
HQ
%
HKV
==
0
,
"Number of query heads must divide number of keys heads"
if
max_seqlen_k_cache
==
0
:
# Zero-prefix compressed prefill on DCU produced repeated-character output in
# the Triton on-band appended branch; use exact varlen FA for this subcase.
if
sm_scale
is
None
:
sm_scale
=
1.0
/
math
.
sqrt
(
D
)
return
_causal_appended_only_exact
(
q
,
k
,
v
,
cu_seqlens_q
,
sm_scale
=
sm_scale
,
max_seqlen_q
=
max_seqlen_q
,
)
H_g
=
HQ
//
HKV
# view Q as [HKV, N, QUERY_GROUP_SIZE, D]
out
=
torch
.
empty_like
(
q
)
q
=
q
.
view
(
N
,
HKV
,
H_g
,
D
).
permute
(
1
,
0
,
2
,
3
)
out
=
out
.
view
(
N
,
HKV
,
H_g
,
D
).
permute
(
1
,
0
,
2
,
3
)
# K_app/V_app: [N, HKV, D] -> [HKV, N, D]
k_app
=
k
.
view
(
N
,
HKV
,
D
).
permute
(
1
,
0
,
2
)
v_app
=
v
.
view
(
N
,
HKV
,
D
).
permute
(
1
,
0
,
2
)
q
=
q
.
contiguous
()
out
=
out
.
contiguous
()
k_app
=
k_app
.
contiguous
()
v_app
=
v_app
.
contiguous
()
cu_seqlens_q
=
cu_seqlens_q
.
to
(
dtype
=
torch
.
int32
,
device
=
q
.
device
)
seq_lens_bh
=
seq_lens_bh
.
to
(
dtype
=
torch
.
int32
,
device
=
q
.
device
)
batch_mapping
=
batch_mapping
.
to
(
dtype
=
torch
.
int16
,
device
=
q
.
device
)
N_LOGICAL_PAGES_MAX
=
global_page_table
.
shape
[
-
1
]
CACHE_SIZE
=
k_cache
.
shape
[
0
]
assert
v_cache
.
shape
[
0
]
==
CACHE_SIZE
assert
k_cache
.
shape
[
1
]
==
D
and
v_cache
.
shape
[
1
]
==
D
assert
PAGE_SIZE
>
0
and
CACHE_SIZE
%
PAGE_SIZE
==
0
k_cache
=
k_cache
.
contiguous
()
v_cache
=
v_cache
.
contiguous
()
global_page_table
=
global_page_table
.
contiguous
()
if
sm_scale
is
None
:
sm_scale
=
1.0
/
math
.
sqrt
(
D
)
# strides for Q [G, N, QUERY_GROUP_SIZE, D]
STRIDE_Q_G
,
STRIDE_Q_N
,
STRIDE_Q_H
,
STRIDE_Q_D
=
q
.
stride
()
STRIDE_KC
,
STRIDE_VC
=
k_cache
.
stride
(
0
),
v_cache
.
stride
(
0
)
# [G, N, D]
STRIDE_KA_G
,
STRIDE_KA_N
,
STRIDE_KA_D
=
k_app
.
stride
()
STRIDE_VA_G
,
STRIDE_VA_N
,
STRIDE_VA_D
=
v_app
.
stride
()
# OUT [G, N, QUERY_GROUP_SIZE, D]
STRIDE_OUT_G
,
STRIDE_OUT_N
,
STRIDE_OUT_H
,
STRIDE_OUT_D
=
out
.
stride
()
# launch grid
maybe_set_allocator
(
lambda
size
,
align
,
_
:
torch
.
empty
(
size
,
dtype
=
torch
.
int8
,
device
=
q
.
device
)
)
assert
STRIDE_KA_D
==
STRIDE_VA_D
==
STRIDE_Q_D
==
STRIDE_OUT_D
==
1
,
(
"final dimension must be contiguous"
)
def
grid
(
META
):
return
HKV
,
B
,
triton
.
cdiv
(
max_seqlen_q
,
META
[
"BLOCK_M"
])
# Autotune key must reflect the **total** K length seen by the kernel:
# cached prefix + appended tokens from the current prefill chunk.
#
# Using only `max_seqlen_k_cache` is wrong for the first compressed prefill
# step in `pdtriton`: the cache prefix is 0, but the kernel actually attends
# over the entire appended prompt (`seq_len_append`). On DCU this can cause
# Triton to autotune/select a kernel as if K==1 while executing on a long K,
# which has been observed to produce incorrect outputs. We still clamp to 1
# to avoid `next_power_of_2(0)`.
_k_max_autotune
=
max
(
int
(
max_seqlen_k_cache
)
+
int
(
max_seqlen_q
),
1
)
AUTOTUNE_MAX_Q_LEN
=
triton
.
next_power_of_2
(
max_seqlen_q
)
AUTOTUNE_MAX_K_LEN
=
triton
.
next_power_of_2
(
_k_max_autotune
)
_causal_head_sparse_varlen_with_cache
[
grid
](
Q
=
q
,
K_cache
=
k_cache
,
V_cache
=
v_cache
,
K_app
=
k_app
,
V_app
=
v_app
,
cu_seqlens_qk
=
cu_seqlens_q
,
seq_lens_bh
=
seq_lens_bh
,
page_table
=
global_page_table
,
batch_mapping
=
batch_mapping
,
OUT
=
out
,
HKV
=
HKV
,
QUERY_GROUP_SIZE
=
H_g
,
PAGE_SIZE
=
PAGE_SIZE
,
N_LOGICAL_PAGES_MAX
=
N_LOGICAL_PAGES_MAX
,
STRIDE_Q_G
=
STRIDE_Q_G
,
STRIDE_Q_N
=
STRIDE_Q_N
,
STRIDE_Q_H
=
STRIDE_Q_H
,
STRIDE_KC
=
STRIDE_KC
,
STRIDE_VC
=
STRIDE_VC
,
STRIDE_KA_G
=
STRIDE_KA_G
,
STRIDE_KA_N
=
STRIDE_KA_N
,
STRIDE_VA_G
=
STRIDE_VA_G
,
STRIDE_VA_N
=
STRIDE_VA_N
,
STRIDE_OUT_G
=
STRIDE_OUT_G
,
STRIDE_OUT_N
=
STRIDE_OUT_N
,
STRIDE_OUT_H
=
STRIDE_OUT_H
,
sm_scale
=
sm_scale
,
D
=
D
,
AUTOTUNE_MAX_Q_LEN
=
AUTOTUNE_MAX_Q_LEN
,
AUTOTUNE_MAX_K_LEN
=
AUTOTUNE_MAX_K_LEN
,
)
# permute breaks contiguity; view() requires a single contiguous span.
return
out
.
permute
(
1
,
0
,
2
,
3
).
reshape
(
N
,
HQ
,
D
)
autotune_configs_cc9
=
[
triton
.
Config
(
{
"BLOCK_N"
:
64
,
"BLOCK_M"
:
64
,
"WARPSPEC"
:
True
},
num_warps
=
16
,
num_stages
=
3
),
triton
.
Config
(
{
"BLOCK_N"
:
64
,
"BLOCK_M"
:
64
,
"WARPSPEC"
:
True
},
num_warps
=
8
,
num_stages
=
3
),
triton
.
Config
(
{
"BLOCK_N"
:
64
,
"BLOCK_M"
:
32
,
"WARPSPEC"
:
True
},
num_warps
=
8
,
num_stages
=
4
),
triton
.
Config
(
{
"BLOCK_N"
:
64
,
"BLOCK_M"
:
32
,
"WARPSPEC"
:
True
},
num_warps
=
8
,
num_stages
=
3
),
triton
.
Config
(
{
"BLOCK_N"
:
64
,
"BLOCK_M"
:
32
,
"WARPSPEC"
:
False
},
num_warps
=
4
,
num_stages
=
3
),
triton
.
Config
(
{
"BLOCK_N"
:
64
,
"BLOCK_M"
:
16
,
"WARPSPEC"
:
True
},
num_warps
=
8
,
num_stages
=
3
),
triton
.
Config
(
{
"BLOCK_N"
:
64
,
"BLOCK_M"
:
16
,
"WARPSPEC"
:
True
},
num_warps
=
8
,
num_stages
=
4
),
triton
.
Config
(
{
"BLOCK_N"
:
64
,
"BLOCK_M"
:
16
,
"WARPSPEC"
:
False
},
num_warps
=
4
,
num_stages
=
4
),
triton
.
Config
(
{
"BLOCK_N"
:
32
,
"BLOCK_M"
:
32
,
"WARPSPEC"
:
True
},
num_warps
=
8
,
num_stages
=
4
),
triton
.
Config
(
{
"BLOCK_N"
:
32
,
"BLOCK_M"
:
32
,
"WARPSPEC"
:
False
},
num_warps
=
8
,
num_stages
=
4
),
triton
.
Config
(
{
"BLOCK_N"
:
32
,
"BLOCK_M"
:
16
,
"WARPSPEC"
:
False
},
num_warps
=
8
,
num_stages
=
3
),
triton
.
Config
(
{
"BLOCK_N"
:
32
,
"BLOCK_M"
:
16
,
"WARPSPEC"
:
False
},
num_warps
=
4
,
num_stages
=
4
),
]
autotune_configs_cc8
=
[
triton
.
Config
(
{
"BLOCK_N"
:
BN
,
"BLOCK_M"
:
BM
,
"WARPSPEC"
:
True
},
num_warps
=
w
,
num_stages
=
s
)
for
BN
in
[
16
,
32
]
for
BM
in
[
64
]
for
w
in
[
4
,
8
]
for
s
in
[
2
,
3
]
]
def
prune_invalid_configs
(
configs
,
_
,
**
kwargs
):
return
[
conf
for
conf
in
configs
if
not
(
conf
.
kwargs
.
get
(
"BLOCK_N"
)
==
32
and
conf
.
kwargs
.
get
(
"num_stages"
)
==
4
)
]
def
get_autotune_configs
():
if
cuda_capability_geq
(
9
,
0
):
return
autotune_configs_cc9
else
:
return
autotune_configs_cc8
@
triton_autotune
(
configs
=
get_autotune_configs
(),
key
=
[
"HKV"
,
"QUERY_GROUP_SIZE"
,
"D"
,
"PAGE_SIZE"
,
"AUTOTUNE_MAX_K_LEN"
,
"AUTOTUNE_MAX_Q_LEN"
,
],
cache_results
=
True
,
)
@
triton
.
jit
def
_causal_head_sparse_varlen_with_cache
(
Q
,
# [HKV, N, QUERY_GROUP_SIZE, D] (non-contiguous)
K_cache
,
V_cache
,
# [CACHE_SIZE, D]
K_app
,
V_app
,
# [HKV, N, D]
cu_seqlens_qk
,
# [B+1]
seq_lens_bh
,
# [B, HKV]
page_table
,
# [B_total, HKV, N_LOGICAL_PAGES_MAX]
batch_mapping
,
# [B], maps local b -> global batch index
OUT
,
# [HKV, N, QUERY_GROUP_SIZE, D]
#
HKV
:
tl
.
constexpr
,
QUERY_GROUP_SIZE
:
tl
.
constexpr
,
PAGE_SIZE
:
tl
.
constexpr
,
N_LOGICAL_PAGES_MAX
,
STRIDE_Q_G
,
STRIDE_Q_N
,
STRIDE_Q_H
,
STRIDE_KC
,
STRIDE_VC
,
STRIDE_KA_G
,
STRIDE_KA_N
,
STRIDE_VA_G
,
STRIDE_VA_N
,
STRIDE_OUT_G
,
STRIDE_OUT_N
,
STRIDE_OUT_H
,
sm_scale
,
#
D
:
tl
.
constexpr
,
BLOCK_M
:
tl
.
constexpr
,
BLOCK_N
:
tl
.
constexpr
,
WARPSPEC
:
tl
.
constexpr
,
AUTOTUNE_MAX_Q_LEN
:
tl
.
constexpr
,
# used for autotune key
AUTOTUNE_MAX_K_LEN
:
tl
.
constexpr
,
# used for autotune key
):
TOTAL_N_QUERIES
:
tl
.
constexpr
=
BLOCK_M
*
QUERY_GROUP_SIZE
pid_g
=
tl
.
program_id
(
0
)
# kv_head id in [0, HKV)
pid_b
=
tl
.
program_id
(
1
)
# batch id
pid_m
=
tl
.
program_id
(
2
)
# query-tile id within batch
# batch segment [qb, qe) in N
off_b
=
tl
.
load
(
cu_seqlens_qk
+
pid_b
)
off_b1
=
tl
.
load
(
cu_seqlens_qk
+
pid_b
+
1
)
seq_len_append
=
off_b1
-
off_b
q_start
=
off_b
+
pid_m
*
BLOCK_M
q_end
=
tl
.
minimum
(
q_start
+
BLOCK_M
,
off_b1
)
# number of queries in this tile for this batch
M
=
q_end
-
q_start
if
M
<=
0
:
return
# cached length for (b, kv_head=pid_g)
L_cache
=
tl
.
load
(
seq_lens_bh
+
pid_b
*
HKV
+
pid_g
)
# row indices flattened over [QUERY_GROUP_SIZE, M]
offs_row
=
tl
.
arange
(
0
,
TOTAL_N_QUERIES
)
row_m
=
offs_row
%
BLOCK_M
row_h
=
offs_row
//
BLOCK_M
# valid rows: only those with row_m < M
row_mask
=
row_m
<
M
# global query index per row
q_idx
=
q_start
+
row_m
offs_d
=
tl
.
arange
(
0
,
D
)
# Q tile: [TOTAL_N_QUERIES, D]
# Q layout: [HKV, N, QUERY_GROUP_SIZE, D]
q_ptrs
=
(
Q
+
pid_g
*
STRIDE_Q_G
+
q_idx
[:,
None
]
*
STRIDE_Q_N
+
row_h
[:,
None
]
*
STRIDE_Q_H
+
offs_d
[
None
,
:]
)
q
=
tl
.
load
(
q_ptrs
,
mask
=
row_mask
[:,
None
],
other
=
0.0
)
e_max
=
tl
.
zeros
([
TOTAL_N_QUERIES
],
dtype
=
tl
.
float32
)
-
float
(
"inf"
)
e_sum
=
tl
.
zeros
([
TOTAL_N_QUERIES
],
dtype
=
tl
.
float32
)
acc
=
tl
.
zeros
([
TOTAL_N_QUERIES
,
D
],
dtype
=
tl
.
float32
)
offs_block_n
=
tl
.
arange
(
0
,
BLOCK_N
)
# Convert natural-log softmax scale into log2 domain for exp2-based updates.
# Use the full log2(e) constant; this is mathematically equivalent to exp and
# not the source of the zero-prefix bug, but avoids avoidable rounding loss.
qk_scale
=
sm_scale
*
1.4426950408889634
# 1) attend over cachee K/V
if
L_cache
>
0
:
# map local (b) to global batch index
mapped_b
=
tl
.
load
(
batch_mapping
+
pid_b
)
pt_base
=
(
mapped_b
*
HKV
+
pid_g
)
*
N_LOGICAL_PAGES_MAX
# iterate logical pages
num_lp
=
tl
.
cdiv
(
L_cache
,
PAGE_SIZE
)
for
lp
in
tl
.
range
(
0
,
num_lp
):
# can overflow in 32 bits so upcast
phys
=
tl
.
load
(
page_table
+
pt_base
+
lp
).
to
(
tl
.
int64
)
page_start
=
phys
*
PAGE_SIZE
# how many valid tokens in this page for this (b,g)
remain
=
L_cache
-
lp
*
PAGE_SIZE
page_len
=
tl
.
minimum
(
PAGE_SIZE
,
remain
)
# iterate over this page in BLOCK_N chunks
for
ks
in
tl
.
range
(
0
,
page_len
,
BLOCK_N
):
offs_n
=
ks
+
offs_block_n
mask_n
=
offs_n
<
page_len
key_idx
=
page_start
+
offs_n
k_ptrs
=
K_cache
+
key_idx
[:,
None
]
*
STRIDE_KC
+
offs_d
[
None
,
:]
k
=
tl
.
load
(
k_ptrs
,
mask
=
mask_n
[:,
None
],
other
=
0.0
)
# [BN, D]
qk
=
tl
.
dot
(
q
,
k
.
T
)
*
qk_scale
# [TOTAL_N_QUERIES, BN]
qk
=
tl
.
where
(
row_mask
[:,
None
]
&
mask_n
[
None
,
:],
qk
,
-
1.0e6
)
# softmax update
cur_max
=
tl
.
max
(
qk
,
1
)
n_e_max
=
tl
.
maximum
(
e_max
,
cur_max
)
re_scale
=
tl
.
math
.
exp2
(
e_max
-
n_e_max
)
p
=
tl
.
math
.
exp2
(
qk
-
n_e_max
[:,
None
])
v_ptrs
=
V_cache
+
key_idx
[:,
None
]
*
STRIDE_VC
+
offs_d
[
None
,
:]
v
=
tl
.
load
(
v_ptrs
,
mask
=
mask_n
[:,
None
],
other
=
0.0
)
# [BN, D]
acc
=
acc
*
re_scale
[:,
None
]
acc
=
tl
.
dot
(
p
.
to
(
v
.
dtype
),
v
,
acc
)
e_sum
=
e_sum
*
re_scale
+
tl
.
sum
(
p
,
1
)
e_max
=
n_e_max
# 2) attend over appended K_app/V_app (causal)
# appended tokens for batch b are in [off_b, off_b1)
# query tile is [q_start, q_end)
# for each query at index q_idx, valid appended keys k satisfy off_b <= k <= q_idx
if
q_end
>
off_b
:
# exactly one appended token
if
seq_len_append
==
1
:
ka_ptrs
=
K_app
+
pid_g
*
STRIDE_KA_G
+
off_b
*
STRIDE_KA_N
+
offs_d
k
=
tl
.
load
(
ka_ptrs
)
# [D]
qk
=
tl
.
sum
(
q
*
k
[
None
,
:],
1
)
*
qk_scale
qk
=
tl
.
where
(
row_mask
,
qk
,
-
1.0e6
)
n_e_max
=
tl
.
maximum
(
e_max
,
qk
)
re_scale
=
tl
.
math
.
exp2
(
e_max
-
n_e_max
)
p
=
tl
.
math
.
exp2
(
qk
-
n_e_max
)
va_ptrs
=
V_app
+
pid_g
*
STRIDE_VA_G
+
off_b
*
STRIDE_VA_N
+
offs_d
v
=
tl
.
load
(
va_ptrs
)
# [D]
acc
=
acc
*
re_scale
[:,
None
]
+
p
[:,
None
]
*
v
[
None
,
:]
e_sum
=
e_sum
*
re_scale
+
p
else
:
# off-band: k in [off_b, q_start)
# for all queries t in [q_start, q_end), any k < q_start satisfies k <= t.
# so no causal mask needed.
off_band_start
=
off_b
off_band_end
=
q_start
if
off_band_end
>
off_band_start
:
for
ks
in
tl
.
range
(
off_band_start
,
off_band_end
,
BLOCK_N
):
offs_n
=
ks
+
offs_block_n
mask_n
=
offs_n
<
off_band_end
ka_ptrs
=
(
K_app
+
pid_g
*
STRIDE_KA_G
+
offs_n
[:,
None
]
*
STRIDE_KA_N
+
offs_d
[
None
,
:]
)
k
=
tl
.
load
(
ka_ptrs
,
mask
=
mask_n
[:,
None
],
other
=
0.0
)
qk
=
tl
.
dot
(
q
,
k
.
T
)
*
qk_scale
qk
=
tl
.
where
(
row_mask
[:,
None
]
&
mask_n
[
None
,
:],
qk
,
-
1.0e6
)
cur_max
=
tl
.
max
(
qk
,
1
)
n_e_max
=
tl
.
maximum
(
e_max
,
cur_max
)
re_scale
=
tl
.
math
.
exp2
(
e_max
-
n_e_max
)
p
=
tl
.
math
.
exp2
(
qk
-
n_e_max
[:,
None
])
va_ptrs
=
(
V_app
+
pid_g
*
STRIDE_VA_G
+
offs_n
[:,
None
]
*
STRIDE_VA_N
+
offs_d
[
None
,
:]
)
v
=
tl
.
load
(
va_ptrs
,
mask
=
mask_n
[:,
None
],
other
=
0.0
)
acc
=
acc
*
re_scale
[:,
None
]
acc
=
tl
.
dot
(
p
.
to
(
v
.
dtype
),
v
,
acc
)
e_sum
=
e_sum
*
re_scale
+
tl
.
sum
(
p
,
1
)
e_max
=
n_e_max
# on-band remaining k
on_band_start
=
tl
.
maximum
(
q_start
,
off_b
)
if
on_band_start
<
q_end
:
for
ks
in
tl
.
range
(
on_band_start
,
q_end
,
BLOCK_N
):
offs_n
=
ks
+
tl
.
arange
(
0
,
BLOCK_N
)
mask_n
=
offs_n
<
q_end
ka_ptrs
=
(
K_app
+
pid_g
*
STRIDE_KA_G
+
offs_n
[:,
None
]
*
STRIDE_KA_N
+
offs_d
[
None
,
:]
)
k
=
tl
.
load
(
ka_ptrs
,
mask
=
mask_n
[:,
None
],
other
=
0.0
)
qk
=
tl
.
dot
(
q
,
k
.
T
)
*
qk_scale
# DCU/ROCm: using a single fused boolean expression here can lead
# to early query rows in the tile behaving as if they could attend
# to later appended keys in the same on-band block. That shows up
# as token-0 output deviating from V[0] while the last token in the
# batch remains almost exact. Apply the three masks explicitly.
#
# Use local positions within the current query tile for the causal
# relation: all off-band keys (< q_start) were already handled
# above, so the on-band block only needs a lower-triangular mask
# relative to q_start.
qk
=
tl
.
where
(
row_mask
[:,
None
],
qk
,
-
1.0e6
)
qk
=
tl
.
where
(
mask_n
[
None
,
:],
qk
,
-
1.0e6
)
local_q
=
row_m
local_k
=
offs_n
-
q_start
caus_mask
=
local_k
[
None
,
:]
<=
local_q
[:,
None
]
qk
=
tl
.
where
(
caus_mask
,
qk
,
-
1.0e6
)
cur_max
=
tl
.
max
(
qk
,
1
)
n_e_max
=
tl
.
maximum
(
e_max
,
cur_max
)
re_scale
=
tl
.
math
.
exp2
(
e_max
-
n_e_max
)
p
=
tl
.
math
.
exp2
(
qk
-
n_e_max
[:,
None
])
va_ptrs
=
(
V_app
+
pid_g
*
STRIDE_VA_G
+
offs_n
[:,
None
]
*
STRIDE_VA_N
+
offs_d
[
None
,
:]
)
v
=
tl
.
load
(
va_ptrs
,
mask
=
mask_n
[:,
None
],
other
=
0.0
)
acc
=
acc
*
re_scale
[:,
None
]
acc
=
tl
.
dot
(
p
.
to
(
v
.
dtype
),
v
,
acc
)
e_sum
=
e_sum
*
re_scale
+
tl
.
sum
(
p
,
1
)
e_max
=
n_e_max
# 3) write outputs
o
=
(
acc
/
e_sum
[:,
None
]).
to
(
q
.
dtype
)
out_ptrs
=
(
OUT
+
pid_g
*
STRIDE_OUT_G
+
q_idx
[:,
None
]
*
STRIDE_OUT_N
+
row_h
[:,
None
]
*
STRIDE_OUT_H
+
offs_d
[
None
,
:]
)
tl
.
store
(
out_ptrs
,
o
,
mask
=
row_mask
[:,
None
])
vllm/kvprune/benchmark/__init__.py
0 → 100644
View file @
2b7160c6
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""
Benchmark helpers for kv-prune / compactor kernels.
Upstream snapshot (``compactor-vllm/src/compactor_vllm/benchmark``) contained **only**
an empty ``__init__.py`` — no additional ``.py`` scripts. Those files are merged here
as-is; there is nothing else to list under that directory in upstream.
Use :data:`BENCHMARK_REGISTRY` to register microbenchmarks or CLI entrypoints you
add under ``vllm.kvprune.benchmark``.
"""
from
__future__
import
annotations
from
typing
import
Any
,
Callable
# Files copied from upstream ``compactor_vllm/benchmark/`` (relative to that dir).
UPSTREAM_BENCHMARK_FILES
:
tuple
[
str
,
...]
=
(
"__init__.py"
,)
# Optional: name -> benchmark callable or import path string (e.g. "mymod:main").
# Populated when you add real benchmarks beside this package.
BENCHMARK_REGISTRY
:
dict
[
str
,
Callable
[...,
Any
]
|
str
]
=
{}
def
list_upstream_benchmark_files
()
->
tuple
[
str
,
...]:
"""Return the list of filenames that existed in upstream ``benchmark/``."""
return
UPSTREAM_BENCHMARK_FILES
def
register_benchmark
(
name
:
str
,
target
:
Callable
[...,
Any
]
|
str
)
->
None
:
"""Register a benchmark by name (callable or ``"module:attr"`` import path)."""
BENCHMARK_REGISTRY
[
name
]
=
target
def
iter_registered_benchmarks
()
->
list
[
tuple
[
str
,
Callable
[...,
Any
]
|
str
]]:
"""Return ``(name, target)`` pairs from :data:`BENCHMARK_REGISTRY`."""
return
list
(
BENCHMARK_REGISTRY
.
items
())
__all__
=
[
"BENCHMARK_REGISTRY"
,
"UPSTREAM_BENCHMARK_FILES"
,
"iter_registered_benchmarks"
,
"list_upstream_benchmark_files"
,
"register_benchmark"
,
]
vllm/kvprune/compression/__init__.py
0 → 100644
View file @
2b7160c6
from
vllm.kvprune.compression.common
import
(
BaseCompressionMethod
,
NoCompression
,
)
from
vllm.kvprune.compression.criticalkv
import
CriticalAdaKVCompression
from
vllm.kvprune.compression.compactor
import
CompactorCompression
from
vllm.kvprune.compression.compression_config
import
(
BatchCompressionParams
,
CompressionMethod
,
SequenceCompressionParams
,
)
from
vllm.kvprune.compression.snapkv
import
SnapKVCompression
COMPRESSION_REGISTRY
:
dict
[
CompressionMethod
,
type
[
BaseCompressionMethod
]]
=
{
CompressionMethod
.
CRITICALADAKV
:
CriticalAdaKVCompression
,
CompressionMethod
.
COMPACTOR
:
CompactorCompression
,
CompressionMethod
.
SNAPKV
:
SnapKVCompression
,
CompressionMethod
.
NONE
:
NoCompression
,
}
def
apply_prerope_compression
(
q
,
k
,
v
,
context
):
method
=
context
.
compression_context
.
compression_method
return
COMPRESSION_REGISTRY
[
method
].
pre_rope_scoring
(
q
,
k
,
v
,
context
=
context
)
def
apply_postrope_compression
(
q
,
k
,
v
,
prerope_scores
,
context
):
method
=
context
.
compression_context
.
compression_method
return
COMPRESSION_REGISTRY
[
method
].
post_rope_scoring
(
q
,
k
,
v
,
prerope_scores
,
context
=
context
)
__all__
=
[
"apply_prerope_compression"
,
"apply_postrope_compression"
,
"CompressionMethod"
,
"BatchCompressionParams"
,
"SequenceCompressionParams"
,
"COMPRESSION_REGISTRY"
]
vllm/kvprune/compression/common.py
0 → 100644
View file @
2b7160c6
from
abc
import
ABC
,
abstractmethod
import
os
from
typing
import
Optional
import
torch
from
vllm.kvprune.kv_cache.store_kv_cache
import
prefill_store_topk_kv
class
BaseCompressionMethod
(
ABC
):
"""
Abstract interface for KV cache compression methods.
A compression method is implemented as a pair of optional scoring phases
that run before and after rotary position embedding (RoPE) is applied:
1. ``pre_rope_scoring`` operates on pre-RoPE Q/K.
2. ``post_rope_scoring`` operates on post-RoPE Q/K and can either:
- refine / reweight the pre-RoPE scores, or
- compute potentially position-aware.
Concrete subclasses are expected to implement both
static methods and return a single tensor of scores (or ``None`` if the
phase is a no-op), which the caller can then feed into the shared
“scores → top-k indices → KV extraction” pipeline.
"""
@
staticmethod
@
abstractmethod
def
pre_rope_scoring
(
q
:
torch
.
Tensor
,
k
:
torch
.
Tensor
,
v
:
torch
.
Tensor
,
context
,
)
->
Optional
[
torch
.
Tensor
]:
"""
Compute per-token importance scores from pre-RoPE queries/keys.
Args:
:param q:
Pre-RoPE query tensor. Shape ``[total_tokens, HQ, D]```.
:param k:
Pre-RoPE key tensor. Shape ``[total_tokens, HKV, D]```.
:param v:
Value tensor. Shape ``[total_tokens, HKV, D]```
:param context:
vllm.kvprune.utils.context.Context object carrying additional metadata,
such as batch mappings or temporary buffers
Returns:
:return Optional[torch.Tensor]:
A tensor of scores (e.g. per-token, per-head importance values)
to be passed to ``post_rope_scoring`` or directly into the
top-k selection step. If this phase is a no-op, implementations
should return ``None``. Shape ``[total_tokens, HKV]```.
"""
pass
@
staticmethod
@
abstractmethod
def
post_rope_scoring
(
q
:
torch
.
Tensor
,
k
:
torch
.
Tensor
,
v
:
torch
.
Tensor
,
pre_rope_scores
:
Optional
[
torch
.
Tensor
],
context
,
)
->
Optional
[
torch
.
Tensor
]:
"""
Compute or refine importance scores from post-RoPE queries/keys.
This method is called after rotary embeddings have been applied. It can
optionally use both the post-RoPE Q/K and any scores produced by
``pre_rope_scoring`` to produce final scores used for token selection.
Common patterns include:
* Using ``pre_rope_scores`` as a base signal and applying a
position-aware correction.
* Only computing scores that depend on absolute or relative positions.
* Simply passing through ``pre_rope_scores`` unchanged.
Args:
:param q:
Post-RoPE query tensor. Shape ``[total_tokens, HQ, D]```.
:param k:
Post-RoPE key tensor. Shape ``[total_tokens, HKV, D]```.
:param pre_rope_scores:
Optional scores returned by ``pre_rope_scoring``. May be
``None`` if the pre-RoPE phase returned None.
:param v:
Value tensor. Shape ``[total_tokens, HKV, D]```
:param context:
vllm.kvprune.utils.context.Context object carrying additional metadata,
such as batch mappings or temporary buffers
Returns:
:return Optional[torch.Tensor]:
Final importance scores to be consumed by the compression
pipeline (for top-k token selection). If this phase is a
no-op, implementations may return ``pre_rope_scores``. If
None is returned, no compression will be applied.
"""
pass
class
NoCompression
(
BaseCompressionMethod
):
"""
Trivial compression method that disables KV cache compression.
"""
@
staticmethod
def
pre_rope_scoring
(
q
:
torch
.
Tensor
,
k
:
torch
.
Tensor
,
v
:
torch
.
Tensor
,
context
)
->
Optional
[
torch
.
Tensor
]:
return
None
@
staticmethod
def
post_rope_scoring
(
q
:
torch
.
Tensor
,
k
:
torch
.
Tensor
,
v
:
torch
.
Tensor
,
pre_rope_scores
:
torch
.
Tensor
,
context
,
)
->
Optional
[
torch
.
Tensor
]:
return
pre_rope_scores
def
extract_and_store_top_kv
(
scores
:
torch
.
Tensor
,
cu_seqlens_k
:
torch
.
Tensor
,
max_k_len
:
int
,
top_k
:
int
,
H
:
int
,
new_keys
:
torch
.
Tensor
,
# [N_total, H, D]
new_vals
:
torch
.
Tensor
,
# [N_total, H, D]
num_tokens_to_retain
:
torch
.
Tensor
,
# [B] int32
page_table
:
torch
.
Tensor
,
# [B_total, H, N_LOGICAL_PAGES_MAX] int32
batch_mapping
:
torch
.
Tensor
,
# [B] int32 (local -> true batch rows)
bh_lens
:
torch
.
Tensor
,
# [B, H] int32 (contiguous), UPDATED atomically
k_cache
:
torch
.
Tensor
,
# [N_PAGES * PAGE_SIZE, D]
v_cache
:
torch
.
Tensor
,
# [N_PAGES * PAGE_SIZE, D]
PAGE_SIZE
:
int
,
PAD_TO_PAGE_SIZE
:
bool
=
True
,
K_TILE
:
int
=
16
,
padding
:
float
=
-
float
(
"inf"
),
):
"""helper method to extract and store top-k indices into KV cache (so they can be executed in a single stream)"""
# per_head: per-head highest-scoring remaining tokens for page padding.
# global_scan: legacy global ranking order, padded by scanning forward in-kernel.
padding_mode
=
os
.
environ
.
get
(
"VLLM_KVPRUNE_PADDING_MODE"
,
"per_head"
).
strip
().
lower
()
max_pairs_per_batch
=
(
cu_seqlens_k
[
1
:]
-
cu_seqlens_k
[:
-
1
]).
to
(
device
=
num_tokens_to_retain
.
device
,
dtype
=
num_tokens_to_retain
.
dtype
)
*
H
num_tokens_to_retain
=
torch
.
minimum
(
num_tokens_to_retain
,
max_pairs_per_batch
)
indices_topk
,
candidate_counts
=
scores_to_retain_indices
(
scores
,
cu_seqlens_k
=
cu_seqlens_k
,
max_k_len
=
max_k_len
,
top_k
=
top_k
,
H
=
H
,
num_tokens_to_retain
=
num_tokens_to_retain
,
page_size
=
PAGE_SIZE
,
padding_mode
=
padding_mode
,
padding
=
padding
,
)
prefill_store_topk_kv
(
new_keys
=
new_keys
,
new_vals
=
new_vals
,
indices_topk
=
indices_topk
,
candidate_counts
=
candidate_counts
,
num_tokens_to_retain
=
num_tokens_to_retain
,
page_table
=
page_table
,
batch_mapping
=
batch_mapping
,
bh_lens
=
bh_lens
,
k_cache
=
k_cache
,
v_cache
=
v_cache
,
cu_seqlens_k
=
cu_seqlens_k
,
PAGE_SIZE
=
PAGE_SIZE
,
PAD_TO_PAGE_SIZE
=
PAD_TO_PAGE_SIZE
,
K_TILE
=
K_TILE
,
)
def
scores_to_retain_indices
(
scores
:
torch
.
Tensor
,
cu_seqlens_k
:
torch
.
Tensor
,
max_k_len
:
int
,
top_k
:
int
,
H
:
int
,
num_tokens_to_retain
:
torch
.
Tensor
,
page_size
:
int
,
padding_mode
:
str
=
"per_head"
,
padding
:
float
=
-
float
(
"inf"
),
)
->
tuple
[
torch
.
Tensor
,
torch
.
Tensor
]:
"""
Build candidate token-head indices for compression writes.
For each batch element, this helper returns:
1. a prefix of the true global top-k ``(token, head)`` pairs, and
2. a suffix of additional padding candidates according to ``padding_mode``:
- ``per_head``: choose each head's highest-scoring remaining tokens.
- ``global_scan``: keep the legacy global ranking order and let the
store kernel scan forward until it finds enough entries for that head.
The page-alignment requirement comes from the paged KV cache, but the
padding candidates themselves do not need to be discovered inside the
Triton store kernel. Choosing them here avoids the older "scan the global
candidate list until you stumble across enough entries for this head"
behavior, which could distort the retained set even though the page-table
/ reclaim logic only cares about the final per-head counts.
Args:
:param scores:
Tensor of shape ``[N_total, HKV]`` containing scores for each
(token, head) pair in packed varlen format.
:param cu_seqlens_k:
Tensor of shape ``[B + 1]`` (int32) with cumulative key sequence
lengths for each batch element. The total number of tokens
satisfies ``N_total = cu_seqlens_k[-1]``.
:param max_k_len:
Maximum key sequence length across the batch (i.e.
``max_b seqlen_k[b]``). Used to allocate the padded buffer.
:param top_k:
Kept for API compatibility with the caller. The retained prefix is
determined by ``num_tokens_to_retain``; the tail is built from
per-head padding needs.
:param H:
Number of key heads; must match ``scores.shape[1]``.
:param num_tokens_to_retain:
The true number of token-head pairs to keep for each batch element
before page padding.
:param page_size:
Page size of the KV cache. Determines how many extra candidates
are needed per head to reach page alignment.
:param padding_mode:
``per_head`` for per-head optimal padding candidates, or
``global_scan`` for the legacy "scan the global ranking" behavior.
:param padding:
Kept for backward compatibility; no longer used.
Returns:
A tuple ``(indices, counts)`` where:
- ``indices`` is ``[B, MAX_SEL]`` int64, containing global flattened
``token * H + head`` indices.
- ``counts`` is ``[B]`` int32, the number of valid candidates for each
batch row inside ``indices``.
"""
del
max_k_len
,
top_k
,
padding
B
,
device
=
cu_seqlens_k
.
numel
()
-
1
,
scores
.
device
row_indices
:
list
[
torch
.
Tensor
]
=
[]
candidate_counts
=
torch
.
zeros
(
B
,
dtype
=
torch
.
int32
,
device
=
device
)
if
padding_mode
not
in
(
"per_head"
,
"global_scan"
):
raise
ValueError
(
"Unsupported VLLM_KVPRUNE_PADDING_MODE. "
f
"Expected 'per_head' or 'global_scan', got
{
padding_mode
!
r
}
."
)
for
b
in
range
(
B
):
s
=
int
(
cu_seqlens_k
[
b
].
item
())
e
=
int
(
cu_seqlens_k
[
b
+
1
].
item
())
seq_len
=
e
-
s
total_pairs
=
seq_len
*
H
keep
=
min
(
int
(
num_tokens_to_retain
[
b
].
item
()),
total_pairs
)
if
total_pairs
==
0
or
keep
==
0
:
row_indices
.
append
(
torch
.
empty
(
0
,
dtype
=
torch
.
int64
,
device
=
device
))
continue
seq_scores
=
scores
[
s
:
e
,
:]
# [L, H]
flat_scores
=
seq_scores
.
reshape
(
-
1
)
if
padding_mode
==
"global_scan"
:
row
=
torch
.
argsort
(
flat_scores
,
dim
=
0
,
descending
=
True
)
else
:
prefix
=
torch
.
topk
(
flat_scores
,
k
=
keep
,
dim
=
0
,
largest
=
True
,
sorted
=
True
).
indices
selected_flat
=
torch
.
zeros
(
total_pairs
,
dtype
=
torch
.
bool
,
device
=
device
)
selected_flat
[
prefix
]
=
True
selected_mask
=
selected_flat
.
view
(
seq_len
,
H
)
head_counts
=
torch
.
bincount
(
prefix
%
H
,
minlength
=
H
)
need_per_head
=
(
page_size
-
(
head_counts
%
page_size
))
%
page_size
max_extra_per_head
=
seq_len
-
head_counts
need_per_head
=
torch
.
minimum
(
need_per_head
,
max_extra_per_head
)
tails
:
list
[
torch
.
Tensor
]
=
[]
for
h
in
range
(
H
):
need
=
int
(
need_per_head
[
h
].
item
())
if
need
<=
0
:
continue
rem_scores_h
=
seq_scores
[:,
h
].
masked_fill
(
selected_mask
[:,
h
],
-
torch
.
inf
)
tail_tok
=
torch
.
topk
(
rem_scores_h
,
k
=
need
,
dim
=
0
,
largest
=
True
,
sorted
=
True
).
indices
tails
.
append
(
tail_tok
*
H
+
h
)
if
tails
:
row
=
torch
.
cat
([
prefix
,
*
tails
],
dim
=
0
)
else
:
row
=
prefix
row_indices
.
append
(
row
+
s
*
H
)
candidate_counts
[
b
]
=
int
(
row
.
numel
())
max_sel
=
max
((
int
(
x
.
numel
())
for
x
in
row_indices
),
default
=
0
)
if
max_sel
==
0
:
return
(
torch
.
zeros
((
B
,
1
),
dtype
=
torch
.
int64
,
device
=
device
),
candidate_counts
,
)
indices
=
torch
.
zeros
((
B
,
max_sel
),
dtype
=
torch
.
int64
,
device
=
device
)
for
b
,
row
in
enumerate
(
row_indices
):
if
row
.
numel
():
indices
[
b
,
:
row
.
numel
()]
=
row
return
indices
,
candidate_counts
vllm/kvprune/compression/compactor.py
0 → 100644
View file @
2b7160c6
"""
Compactor 压缩:与 kvpress ``CompactorPress`` / ``LeverageScorePress`` / ``NonCausalAttnPress``
算法对齐(Cholesky 杠杆分、右高斯 sketch、非因果分块注意力无 1/sqrt(d) 缩放、×||V||、avg_pool、
全局 z-score、blending 与首尾 sink pad)。
非因果分块注意力与 ``×||V||``+``avg_pool1d(k=3)`` 在 CUDA 上为 Triton;非 CUDA 回退 PyTorch。
杠杆分路径使用 batched ``torch.matmul``;在 transpose 与进入线性代数前对张量 ``.contiguous()``。
CUDA 上用 ``cholesky_solve``;在 HIP/ROCm 上对小的 sketch 维 ``k`` 用 ``linalg.inv(G+λI) @ X^T``
代替 ``cholesky_solve``,避开 rocBLAS TRSM 的 launch-bounds 告警与部分栈上的不稳定行为。
非因果 PyTorch 回退同理。
"""
from
__future__
import
annotations
import
math
from
typing
import
List
,
Optional
import
torch
import
triton
import
triton.language
as
tl
from
transformers.models.llama.modeling_llama
import
repeat_kv
from
vllm.kvprune.compression.common
import
BaseCompressionMethod
from
vllm.kvprune.utils.helpers
import
maybe_execute_in_stream
def
resolve_kvpress_compactor_blending
(
compression_context
)
->
float
:
"""与 kvpress ``CompactorPress.score`` 相同:``blending`` 或 ``compression_ratio``,再否则 0.35。"""
if
compression_context
is
None
:
return
0.35
b
=
getattr
(
compression_context
,
"compactor_blending"
,
None
)
if
b
is
not
None
:
return
float
(
b
)
cr
=
getattr
(
compression_context
,
"compression_ratio"
,
None
)
if
cr
is
not
None
:
return
float
(
cr
)
return
0.35
class
CompactorCompression
(
BaseCompressionMethod
):
"""与 kvpress ``CompactorPress`` / ``NonCausalAttnPress`` 默认 ``chunk_size=256`` 一致。"""
chunk_size
:
int
=
256
@
staticmethod
def
pre_rope_scoring
(
q
:
torch
.
Tensor
,
k
:
torch
.
Tensor
,
v
:
torch
.
Tensor
,
context
)
->
Optional
[
torch
.
Tensor
]:
compression_context
=
context
.
compression_context
return
maybe_execute_in_stream
(
kvpress_leverage_scores_packed
,
k
,
context
.
cu_seqlens_q
,
compression_context
,
STORE_STREAM
=
context
.
STORE_STREAM
,
)
@
staticmethod
def
post_rope_scoring
(
q
:
torch
.
Tensor
,
k
:
torch
.
Tensor
,
v
:
torch
.
Tensor
,
pre_rope_scores
:
torch
.
Tensor
,
context
,
)
->
Optional
[
torch
.
Tensor
]:
compression_context
=
context
.
compression_context
blending
=
resolve_kvpress_compactor_blending
(
compression_context
)
return
maybe_execute_in_stream
(
kvpress_compactor_post_rope
,
q
,
k
,
v
,
context
.
cu_seqlens_q
,
pre_rope_scores
,
compression_context
,
context
.
max_seqlen_q
,
chunk_size
=
CompactorCompression
.
chunk_size
,
blending
=
float
(
blending
),
STORE_STREAM
=
context
.
STORE_STREAM
,
)
# ---------------------------------------------------------------------------
# Cholesky 杠杆分(kvpress ``LeverageScorePress``)
# ---------------------------------------------------------------------------
def
chol_with_jitter
(
G
:
torch
.
Tensor
,
jitter
:
float
=
0.0
,
max_tries
:
int
=
5
)
->
torch
.
Tensor
:
identity
=
torch
.
eye
(
G
.
shape
[
-
1
],
device
=
G
.
device
,
dtype
=
G
.
dtype
)
cur
=
float
(
jitter
)
for
_
in
range
(
max_tries
):
L
,
info
=
torch
.
linalg
.
cholesky_ex
(
(
G
+
cur
*
identity
).
contiguous
(),
upper
=
False
)
if
bool
((
info
==
0
).
all
()):
return
L
cur
=
max
(
1e-8
,
(
1e-2
if
cur
==
0.0
else
10.0
*
cur
))
raise
RuntimeError
(
f
"Cholesky failed after
{
max_tries
}
tries."
)
def
compute_leverage_scores_mid
(
key_states
:
torch
.
Tensor
,
sketch_dimension
:
int
)
->
torch
.
Tensor
:
"""
与 kvpress ``LeverageScorePress.compute_leverage_scores`` 相同;输入 ``[L, H, D]``,
返回 ``[L, H]``(未 z-score)。
维序与 kvpress 的 ``(B, H, S, D)`` 对齐;batched GEMM + ``.contiguous()`` 以利于后端库。
"""
d
,
k
=
key_states
.
shape
[
-
1
],
sketch_dimension
device
,
dtype
=
key_states
.
device
,
key_states
.
dtype
H
=
key_states
.
shape
[
1
]
Phi
=
torch
.
randn
(
1
,
H
,
d
,
k
,
device
=
device
,
dtype
=
dtype
)
*
(
1.0
/
math
.
sqrt
(
k
))
X0
=
key_states
.
transpose
(
0
,
1
).
unsqueeze
(
0
).
contiguous
()
X
=
(
X0
-
X0
.
mean
(
dim
=-
2
,
keepdim
=
True
)).
contiguous
()
Phi
=
Phi
.
contiguous
()
X
=
torch
.
matmul
(
X
,
Phi
).
to
(
torch
.
float32
).
contiguous
()
XT
=
X
.
transpose
(
-
2
,
-
1
).
contiguous
()
G
=
torch
.
matmul
(
XT
,
X
)
G_sym
=
0.5
*
(
G
+
G
.
transpose
(
-
2
,
-
1
)).
contiguous
()
# HIP: avoid batched cholesky_solve -> rocBLAS TRSM (launch_bounds noise / edge cases).
# k is sketch_dim (typically modest); inv is O(k^3) but batched over heads.
if
torch
.
version
.
hip
is
not
None
:
kk
=
G_sym
.
shape
[
-
1
]
eye
=
torch
.
eye
(
kk
,
device
=
G_sym
.
device
,
dtype
=
G_sym
.
dtype
,
requires_grad
=
False
)
G_reg
=
G_sym
+
1e-2
*
eye
inv_Xt
=
torch
.
linalg
.
inv
(
G_reg
)
@
XT
else
:
L_mat
=
chol_with_jitter
(
G_sym
,
jitter
=
1e-2
,
max_tries
=
5
)
inv_Xt
=
torch
.
cholesky_solve
(
XT
,
L_mat
,
upper
=
False
)
inv_Xt_T
=
inv_Xt
.
transpose
(
-
2
,
-
1
).
contiguous
()
scores
=
(
X
*
inv_Xt_T
).
sum
(
dim
=-
1
).
clamp_min
(
0
)
return
scores
.
squeeze
(
0
).
transpose
(
0
,
1
).
contiguous
()
def
kvpress_leverage_scores_packed
(
key_states
:
torch
.
Tensor
,
cu_seqlens
:
torch
.
Tensor
,
compression_ctx
,
)
->
torch
.
Tensor
:
device
=
key_states
.
device
N
,
Hkv
,
_D
=
key_states
.
shape
sketch_dim
=
int
(
getattr
(
compression_ctx
,
"sketch_dimension"
,
48
))
sink_start
=
int
(
getattr
(
compression_ctx
,
"sink_size_start"
,
8
))
sink_end
=
int
(
getattr
(
compression_ctx
,
"sink_size_end"
,
4
))
out
=
torch
.
zeros
(
N
,
Hkv
,
device
=
device
,
dtype
=
torch
.
float32
)
mids_flat
:
list
[
torch
.
Tensor
]
=
[]
mid_ranges
:
list
[
tuple
[
int
,
int
,
int
]]
=
[]
for
b
in
range
(
cu_seqlens
.
numel
()
-
1
):
k_beg
=
int
(
cu_seqlens
[
b
].
item
())
k_end
=
int
(
cu_seqlens
[
b
+
1
].
item
())
L
=
k_end
-
k_beg
if
L
==
0
:
continue
left_keep
=
min
(
sink_start
,
L
)
right_keep
=
min
(
sink_end
,
max
(
0
,
L
-
left_keep
))
mid_start
=
k_beg
+
left_keep
mid_end
=
k_end
-
right_keep
if
mid_start
>=
mid_end
:
continue
k_mid
=
key_states
[
mid_start
:
mid_end
,
:,
:].
contiguous
()
raw
=
compute_leverage_scores_mid
(
k_mid
,
sketch_dim
)
mids_flat
.
append
(
raw
.
reshape
(
-
1
))
mid_ranges
.
append
((
mid_start
,
mid_end
,
Hkv
))
if
not
mids_flat
:
return
out
flat
=
torch
.
cat
(
mids_flat
,
dim
=
0
)
z
=
_zscore_flat_f32_global
(
flat
)
offset
=
0
for
(
mid_start
,
mid_end
,
_Hkv
),
r
in
zip
(
mid_ranges
,
mids_flat
):
n
=
r
.
numel
()
seg
=
z
[
offset
:
offset
+
n
].
view
(
mid_end
-
mid_start
,
Hkv
)
out
[
mid_start
:
mid_end
,
:]
=
seg
offset
+=
n
return
out
# ---------------------------------------------------------------------------
# 非因果分块注意力(kvpress ``NonCausalAttnPress.non_causal_chunked_attn``)— Triton
# ---------------------------------------------------------------------------
def
_non_causal_chunked_attn_pytorch
(
q
:
torch
.
Tensor
,
k
:
torch
.
Tensor
,
chunk_size
:
int
)
->
torch
.
Tensor
:
"""参考实现:与 kvpress 逐算子一致。"""
assert
chunk_size
>
0
and
q
.
shape
==
k
.
shape
L
,
H
,
d
=
q
.
shape
B
=
1
q
=
q
.
permute
(
1
,
0
,
2
).
unsqueeze
(
0
).
contiguous
()
k
=
k
.
permute
(
1
,
0
,
2
).
unsqueeze
(
0
).
contiguous
()
_B
,
H
,
S
,
_d
=
k
.
shape
S_pad
=
math
.
ceil
(
S
/
chunk_size
)
*
chunk_size
pad_len
=
S_pad
-
S
if
pad_len
>
0
:
q_padded
=
torch
.
cat
(
[
q
,
torch
.
zeros
(
B
,
H
,
pad_len
,
d
,
device
=
q
.
device
,
dtype
=
q
.
dtype
)],
dim
=
2
)
k_padded
=
torch
.
cat
(
[
k
,
torch
.
zeros
(
B
,
H
,
pad_len
,
d
,
device
=
k
.
device
,
dtype
=
k
.
dtype
)],
dim
=
2
)
last_chunk_start
=
(
S
//
chunk_size
)
*
chunk_size
in_valid
=
torch
.
arange
(
last_chunk_start
,
S_pad
,
device
=
q
.
device
)
>=
S
query_mask
=
key_mask
=
in_valid
.
view
(
1
,
1
,
chunk_size
).
expand
(
B
,
H
,
chunk_size
)
else
:
q_padded
,
k_padded
=
q
,
k
last_chunk_start
=
((
S
-
1
)
//
chunk_size
)
*
chunk_size
in_valid
=
torch
.
arange
(
last_chunk_start
,
S_pad
,
device
=
q
.
device
)
>=
S
query_mask
=
key_mask
=
in_valid
.
view
(
1
,
1
,
chunk_size
).
expand
(
B
,
H
,
chunk_size
)
num_chunks
=
S_pad
//
chunk_size
q_chunks
=
q_padded
.
contiguous
().
view
(
B
,
H
,
num_chunks
,
chunk_size
,
d
)
k_chunks
=
k_padded
.
contiguous
().
view
(
B
,
H
,
num_chunks
,
chunk_size
,
d
)
dots
=
torch
.
matmul
(
q_chunks
,
k_chunks
.
transpose
(
-
2
,
-
1
).
contiguous
()
)
dots
[:,
:,
-
1
].
masked_fill_
(
query_mask
.
unsqueeze
(
-
1
),
0
)
dots
[:,
:,
-
1
].
masked_fill_
(
key_mask
.
unsqueeze
(
-
2
),
-
1e-9
)
attn
=
torch
.
softmax
(
dots
.
to
(
torch
.
float32
),
dim
=-
1
)
out
=
attn
.
sum
(
dim
=-
2
).
view
(
B
,
H
,
S_pad
)[...,
:
S
]
return
out
.
squeeze
(
0
).
transpose
(
0
,
1
).
contiguous
()
@
triton
.
jit
def
_non_causal_chunk_row_kernel
(
Q_ptr
,
K_ptr
,
Out_ptr
,
stride_qh
,
stride_qs
,
stride_qd
,
stride_kh
,
stride_ks
,
stride_kd
,
stride_oh
,
stride_os
,
S
,
S_pad
,
num_chunks
,
CHUNK_SIZE
:
tl
.
constexpr
,
D
:
tl
.
constexpr
,
BLOCK_D
:
tl
.
constexpr
,
ND
:
tl
.
constexpr
,
):
"""
每个 program:一个 head、一个 chunk、一条 query 行。
对 logits 行做 softmax(dim=-1),再对 key 列 j 做 atomic_add 累加到输出(与 sum over query 等价)。
"""
h
=
tl
.
program_id
(
0
)
c
=
tl
.
program_id
(
1
)
iq
=
tl
.
program_id
(
2
)
g_i
=
c
*
CHUNK_SIZE
+
iq
offs_j
=
tl
.
arange
(
0
,
CHUNK_SIZE
)
logits
=
tl
.
zeros
([
CHUNK_SIZE
],
dtype
=
tl
.
float32
)
for
db
in
range
(
ND
):
offs_d
=
tl
.
arange
(
0
,
BLOCK_D
)
+
db
*
BLOCK_D
mask_d
=
offs_d
<
D
q_off
=
(
h
*
stride_qh
+
g_i
*
stride_qs
+
offs_d
*
stride_qd
)
qd
=
tl
.
load
(
Q_ptr
+
q_off
,
mask
=
mask_d
,
other
=
0.0
).
to
(
tl
.
float32
)
g_j
=
c
*
CHUNK_SIZE
+
offs_j
k_row_off
=
h
*
stride_kh
+
g_j
[:,
None
]
*
stride_ks
+
offs_d
[
None
,
:]
*
stride_kd
kj
=
tl
.
load
(
K_ptr
+
k_row_off
,
mask
=
mask_d
[
None
,
:],
other
=
0.0
).
to
(
tl
.
float32
)
logits
+=
tl
.
sum
(
qd
[
None
,
:]
*
kj
,
axis
=
1
)
row_invalid
=
g_i
>=
S
g_j_all
=
c
*
CHUNK_SIZE
+
offs_j
col_invalid
=
g_j_all
>=
S
logits
=
tl
.
where
(
row_invalid
,
tl
.
zeros
([
CHUNK_SIZE
],
dtype
=
tl
.
float32
),
logits
)
logits
=
tl
.
where
(
row_invalid
,
logits
,
tl
.
where
(
col_invalid
,
tl
.
full
([
CHUNK_SIZE
],
-
1e-9
,
dtype
=
tl
.
float32
),
logits
),
)
m
=
tl
.
max
(
logits
)
logits
=
logits
-
m
exp_v
=
tl
.
exp
(
logits
)
denom
=
tl
.
sum
(
exp_v
)
p
=
exp_v
/
denom
out_base
=
h
*
stride_oh
+
g_j_all
*
stride_os
tl
.
atomic_add
(
Out_ptr
+
out_base
,
p
,
mask
=
g_j_all
<
S
)
def
_non_causal_chunked_attn_triton
(
q
:
torch
.
Tensor
,
k
:
torch
.
Tensor
,
chunk_size
:
int
)
->
torch
.
Tensor
:
"""CUDA Triton:与 ``_non_causal_chunked_attn_pytorch`` 同算法。"""
assert
q
.
is_cuda
and
k
.
is_cuda
and
q
.
shape
==
k
.
shape
L
,
H
,
d
=
q
.
shape
assert
chunk_size
>
0
S_pad
=
math
.
ceil
(
L
/
chunk_size
)
*
chunk_size
pad_len
=
S_pad
-
L
if
pad_len
>
0
:
zq
=
torch
.
zeros
(
pad_len
,
H
,
d
,
device
=
q
.
device
,
dtype
=
q
.
dtype
,
requires_grad
=
False
)
zk
=
torch
.
zeros
(
pad_len
,
H
,
d
,
device
=
k
.
device
,
dtype
=
k
.
dtype
,
requires_grad
=
False
)
q
=
torch
.
cat
([
q
,
zq
],
dim
=
0
)
k
=
torch
.
cat
([
k
,
zk
],
dim
=
0
)
Q
=
q
.
transpose
(
0
,
1
).
contiguous
().
to
(
dtype
=
torch
.
float32
)
K
=
k
.
transpose
(
0
,
1
).
contiguous
().
to
(
dtype
=
torch
.
float32
)
num_chunks
=
S_pad
//
chunk_size
out_acc
=
torch
.
zeros
(
H
,
S_pad
,
device
=
q
.
device
,
dtype
=
torch
.
float32
)
S
=
int
(
L
)
grid
=
(
H
,
num_chunks
,
chunk_size
)
BLOCK_D
=
32
if
d
<=
128
else
64
ND
=
(
d
+
BLOCK_D
-
1
)
//
BLOCK_D
_non_causal_chunk_row_kernel
[
grid
](
Q
,
K
,
out_acc
,
Q
.
stride
(
0
),
Q
.
stride
(
1
),
Q
.
stride
(
2
),
K
.
stride
(
0
),
K
.
stride
(
1
),
K
.
stride
(
2
),
out_acc
.
stride
(
0
),
out_acc
.
stride
(
1
),
S
,
S_pad
,
int
(
num_chunks
),
CHUNK_SIZE
=
chunk_size
,
D
=
d
,
BLOCK_D
=
BLOCK_D
,
ND
=
ND
,
num_warps
=
4
,
)
return
out_acc
[:,
:
S
].
transpose
(
0
,
1
).
contiguous
()
def
non_causal_chunked_attn
(
q
:
torch
.
Tensor
,
k
:
torch
.
Tensor
,
chunk_size
:
int
)
->
torch
.
Tensor
:
"""q, k: ``[L, H, d]`` → ``[L, H]``;**无** ``1/sqrt(d)``。CUDA 用 Triton,否则 PyTorch。"""
if
q
.
is_cuda
and
k
.
is_cuda
:
return
_non_causal_chunked_attn_triton
(
q
,
k
,
chunk_size
)
return
_non_causal_chunked_attn_pytorch
(
q
,
k
,
chunk_size
)
# ---------------------------------------------------------------------------
# ×||V|| + avg_pool1d(k=3) — Triton(CUDA)
# ---------------------------------------------------------------------------
@
triton
.
jit
def
_mul_vnorm_avgpool3_kernel
(
A_ptr
,
V_ptr
,
OUT_ptr
,
stride_al
,
stride_ah
,
stride_vl
,
stride_vh
,
stride_vd
,
stride_ol
,
stride_oh
,
L
,
D
:
tl
.
constexpr
,
):
"""Triton 不支持嵌套 def;``t_at`` 逻辑对 ``l-1,l,l+1`` 各展开一份。"""
l
=
tl
.
program_id
(
0
)
h
=
tl
.
program_id
(
1
)
offs
=
tl
.
arange
(
0
,
D
)
pos_m1
=
l
-
1
inb_m1
=
(
pos_m1
>=
0
)
&
(
pos_m1
<
L
)
ps_m1
=
tl
.
where
(
inb_m1
,
pos_m1
,
0
)
a_m1
=
tl
.
load
(
A_ptr
+
ps_m1
*
stride_al
+
h
*
stride_ah
,
mask
=
inb_m1
,
other
=
0.0
,
).
to
(
tl
.
float32
)
v_m1
=
tl
.
load
(
V_ptr
+
ps_m1
*
stride_vl
+
h
*
stride_vh
+
offs
*
stride_vd
,
mask
=
inb_m1
,
other
=
0.0
,
).
to
(
tl
.
float32
)
s_m1
=
tl
.
where
(
inb_m1
,
a_m1
*
tl
.
sqrt
(
tl
.
sum
(
v_m1
*
v_m1
)),
0.0
)
inb_0
=
(
l
>=
0
)
&
(
l
<
L
)
ps0
=
tl
.
where
(
inb_0
,
l
,
0
)
a0
=
tl
.
load
(
A_ptr
+
ps0
*
stride_al
+
h
*
stride_ah
,
mask
=
inb_0
,
other
=
0.0
,
).
to
(
tl
.
float32
)
v0
=
tl
.
load
(
V_ptr
+
ps0
*
stride_vl
+
h
*
stride_vh
+
offs
*
stride_vd
,
mask
=
inb_0
,
other
=
0.0
,
).
to
(
tl
.
float32
)
s_0
=
tl
.
where
(
inb_0
,
a0
*
tl
.
sqrt
(
tl
.
sum
(
v0
*
v0
)),
0.0
)
pos_p1
=
l
+
1
inb_p1
=
(
pos_p1
>=
0
)
&
(
pos_p1
<
L
)
ps_p1
=
tl
.
where
(
inb_p1
,
pos_p1
,
0
)
a_p1
=
tl
.
load
(
A_ptr
+
ps_p1
*
stride_al
+
h
*
stride_ah
,
mask
=
inb_p1
,
other
=
0.0
,
).
to
(
tl
.
float32
)
v_p1
=
tl
.
load
(
V_ptr
+
ps_p1
*
stride_vl
+
h
*
stride_vh
+
offs
*
stride_vd
,
mask
=
inb_p1
,
other
=
0.0
,
).
to
(
tl
.
float32
)
s_p1
=
tl
.
where
(
inb_p1
,
a_p1
*
tl
.
sqrt
(
tl
.
sum
(
v_p1
*
v_p1
)),
0.0
)
out
=
(
s_m1
+
s_0
+
s_p1
)
*
(
1.0
/
3.0
)
tl
.
store
(
OUT_ptr
+
l
*
stride_ol
+
h
*
stride_oh
,
out
)
def
_mul_vnorm_avgpool3_fused
(
a
:
torch
.
Tensor
,
v
:
torch
.
Tensor
,
out
:
torch
.
Tensor
|
None
=
None
)
->
torch
.
Tensor
:
assert
a
.
dim
()
==
2
and
v
.
dim
()
==
3
and
a
.
shape
[
0
]
==
v
.
shape
[
0
]
and
a
.
shape
[
1
]
==
v
.
shape
[
1
]
L
,
H
,
D
=
v
.
shape
a
=
a
.
contiguous
()
v
=
v
.
contiguous
()
if
a
.
dtype
!=
torch
.
float32
:
a
=
a
.
float
()
if
out
is
None
:
out
=
torch
.
empty
((
L
,
H
),
device
=
v
.
device
,
dtype
=
torch
.
float32
)
if
L
==
0
or
H
==
0
:
return
out
grid
=
(
L
,
H
)
_mul_vnorm_avgpool3_kernel
[
grid
](
a
,
v
,
out
,
a
.
stride
(
0
),
a
.
stride
(
1
),
v
.
stride
(
0
),
v
.
stride
(
1
),
v
.
stride
(
2
),
out
.
stride
(
0
),
out
.
stride
(
1
),
L
,
D
=
D
,
num_warps
=
4
,
)
return
out
def
_maybe_mul_vnorm_avgpool3_fused
(
a
:
torch
.
Tensor
,
v
:
torch
.
Tensor
)
->
torch
.
Tensor
:
if
not
a
.
is_cuda
or
not
v
.
is_cuda
:
import
torch.nn.functional
as
F
s
=
a
*
v
.
norm
(
dim
=-
1
)
return
(
F
.
avg_pool1d
(
s
.
transpose
(
0
,
1
).
unsqueeze
(
0
),
kernel_size
=
3
,
padding
=
1
,
stride
=
1
)
.
squeeze
(
0
)
.
transpose
(
0
,
1
)
)
return
_mul_vnorm_avgpool3_fused
(
a
,
v
)
@
triton
.
jit
def
_zscore_elem_1d_kernel
(
X_ptr
,
OUT_ptr
,
n
,
mean
,
inv_std
,
BLOCK
:
tl
.
constexpr
,
):
pid
=
tl
.
program_id
(
0
)
offs
=
pid
*
BLOCK
+
tl
.
arange
(
0
,
BLOCK
)
mask
=
offs
<
n
x
=
tl
.
load
(
X_ptr
+
offs
,
mask
=
mask
,
other
=
0.0
)
tl
.
store
(
OUT_ptr
+
offs
,
(
x
-
mean
)
*
inv_std
,
mask
=
mask
)
def
_zscore_flat_f32_global
(
x
:
torch
.
Tensor
)
->
torch
.
Tensor
:
"""
与 kvpress ``(t - t.mean()) / t.std()`` 一致的一维全局 z-score。
``mean/std`` 用 PyTorch;CUDA 上缩放阶段用 Triton 逐元素写入。
"""
if
x
.
numel
()
==
0
:
return
x
mu
=
x
.
mean
()
sig
=
x
.
std
().
clamp_min
(
1e-6
)
inv
=
1.0
/
sig
if
not
x
.
is_cuda
:
return
(
x
-
mu
)
*
inv
x
=
x
.
contiguous
()
out
=
torch
.
empty_like
(
x
)
n
=
x
.
numel
()
BLOCK
=
1024
grid
=
(
triton
.
cdiv
(
n
,
BLOCK
),)
_zscore_elem_1d_kernel
[
grid
](
x
,
out
,
n
,
float
(
mu
.
item
()),
float
(
inv
.
item
()),
BLOCK
=
BLOCK
,
num_warps
=
4
,
)
return
out
def
_attn_scores_kvpress_middle
(
q
:
torch
.
Tensor
,
k
:
torch
.
Tensor
,
v
:
torch
.
Tensor
,
cu_seqlens
:
torch
.
Tensor
,
sink_start
:
int
,
sink_end
:
int
,
chunk_size
:
int
,
do_zscore
:
bool
=
True
,
)
->
torch
.
Tensor
:
"""仅中间子序列上的非因果分 + ×||V|| + avg_pool;输出全长 ``[N, Hkv]``,非中间为 0。"""
N
,
HQ
,
D
=
q
.
shape
Hkv
=
k
.
shape
[
1
]
G
=
HQ
//
Hkv
device
=
q
.
device
attn_out
=
torch
.
zeros
(
N
,
Hkv
,
device
=
device
,
dtype
=
torch
.
float32
)
parts
:
list
[
torch
.
Tensor
]
=
[]
for
b
in
range
(
cu_seqlens
.
numel
()
-
1
):
k_beg
=
int
(
cu_seqlens
[
b
].
item
())
k_end
=
int
(
cu_seqlens
[
b
+
1
].
item
())
L
=
k_end
-
k_beg
if
L
==
0
:
continue
left_keep
=
min
(
sink_start
,
L
)
right_keep
=
min
(
sink_end
,
max
(
0
,
L
-
left_keep
))
mid_start
=
k_beg
+
left_keep
mid_end
=
k_end
-
right_keep
if
mid_start
>=
mid_end
:
continue
q_m
=
q
[
mid_start
:
mid_end
,
:,
:].
contiguous
()
k_m
=
k
[
mid_start
:
mid_end
,
:,
:].
contiguous
()
v_m
=
v
[
mid_start
:
mid_end
,
:,
:].
contiguous
()
# HF ``repeat_kv`` 约定:``[batch, num_kv_heads, seq_len, head_dim]``
k_4d
=
k_m
.
unsqueeze
(
0
).
transpose
(
1
,
2
).
contiguous
()
# [1, Hkv, Lm, D]
k_rep
=
repeat_kv
(
k_4d
,
G
)[
0
].
transpose
(
0
,
1
).
contiguous
()
# [Lm, HQ, D]
A
=
non_causal_chunked_attn
(
q_m
,
k_rep
,
chunk_size
)
Lm
,
HQa
=
A
.
shape
assert
HQa
==
HQ
A
=
A
.
view
(
Lm
,
Hkv
,
G
).
mean
(
dim
=-
1
)
scores
=
_maybe_mul_vnorm_avgpool3_fused
(
A
,
v_m
)
parts
.
append
(
scores
.
reshape
(
-
1
))
if
not
parts
:
return
attn_out
flat_a
=
torch
.
cat
(
parts
,
dim
=
0
)
if
do_zscore
:
z_a
=
_zscore_flat_f32_global
(
flat_a
)
else
:
z_a
=
flat_a
offset
=
0
for
b
in
range
(
cu_seqlens
.
numel
()
-
1
):
k_beg
=
int
(
cu_seqlens
[
b
].
item
())
k_end
=
int
(
cu_seqlens
[
b
+
1
].
item
())
L
=
k_end
-
k_beg
if
L
==
0
:
continue
left_keep
=
min
(
sink_start
,
L
)
right_keep
=
min
(
sink_end
,
max
(
0
,
L
-
left_keep
))
mid_start
=
k_beg
+
left_keep
mid_end
=
k_end
-
right_keep
if
mid_start
>=
mid_end
:
continue
n
=
(
mid_end
-
mid_start
)
*
Hkv
attn_out
[
mid_start
:
mid_end
,
:]
=
z_a
[
offset
:
offset
+
n
].
view
(
mid_end
-
mid_start
,
Hkv
)
offset
+=
n
return
attn_out
def
non_causal_attn_scores
(
q
:
torch
.
Tensor
,
k
:
torch
.
Tensor
,
v
:
torch
.
Tensor
,
cu_seqlens_qk
:
torch
.
Tensor
,
max_seqlen_qk
:
int
,
chunk_size
:
int
,
sm_scale
:
float
=
None
,
normalize
:
bool
=
True
,
context_lens
:
Optional
[
List
[
int
]]
=
None
,
protected_first_tokens
:
Optional
[
List
[
int
]]
=
None
,
protected_last_tokens
:
Optional
[
List
[
int
]]
=
None
,
*
,
accum_scores
:
torch
.
Tensor
=
None
,
accum_blending
:
float
=
None
,
)
->
torch
.
Tensor
:
"""
与 kvpress 非因果分支一致(**忽略** ``sm_scale``:点积不乘 ``1/sqrt(d)``)。
``normalize=True``:对中间子序列拼接后做全局 z-score(与单独非因果 press 一致)。
然后 ``out += accum_blending * accum_scores``(若给定);最后可对首尾 protected 置 ``inf``。
"""
del
sm_scale
,
max_seqlen_qk
sink_start
,
sink_end
=
8
,
4
out
=
_attn_scores_kvpress_middle
(
q
,
k
,
v
,
cu_seqlens_qk
,
sink_start
,
sink_end
,
chunk_size
,
do_zscore
=
normalize
,
)
if
accum_scores
is
not
None
:
w
=
0.5
if
accum_blending
is
None
else
float
(
accum_blending
)
out
=
out
+
w
*
accum_scores
.
to
(
device
=
out
.
device
,
dtype
=
out
.
dtype
)
if
protected_first_tokens
is
not
None
and
protected_last_tokens
is
not
None
and
context_lens
:
start
=
0
for
first
,
last
,
Lc
in
zip
(
protected_first_tokens
,
protected_last_tokens
,
context_lens
):
out
[
start
:
start
+
int
(
first
)].
fill_
(
torch
.
inf
)
out
[
start
+
int
(
Lc
)
-
int
(
last
)
:
start
+
int
(
Lc
)].
fill_
(
torch
.
inf
)
start
+=
int
(
Lc
)
return
out
def
kvpress_compactor_post_rope
(
q
:
torch
.
Tensor
,
k
:
torch
.
Tensor
,
v
:
torch
.
Tensor
,
cu_seqlens
:
torch
.
Tensor
,
pre_rope_scores
:
torch
.
Tensor
,
compression_ctx
,
max_seqlen_q
:
int
,
chunk_size
:
int
,
blending
:
float
,
)
->
torch
.
Tensor
:
del
max_seqlen_q
Hkv
=
k
.
shape
[
1
]
device
=
q
.
device
sink_start
=
int
(
getattr
(
compression_ctx
,
"sink_size_start"
,
8
))
sink_end
=
int
(
getattr
(
compression_ctx
,
"sink_size_end"
,
4
))
context_lens
:
Optional
[
List
[
int
]]
=
getattr
(
compression_ctx
,
"context_lens"
,
None
)
protected_first
:
Optional
[
List
[
int
]]
=
getattr
(
compression_ctx
,
"protected_first_tokens"
,
None
)
protected_last
:
Optional
[
List
[
int
]]
=
getattr
(
compression_ctx
,
"protected_last_tokens"
,
None
)
attn_out
=
_attn_scores_kvpress_middle
(
q
,
k
,
v
,
cu_seqlens
,
sink_start
,
sink_end
,
chunk_size
)
lev
=
pre_rope_scores
.
to
(
device
=
device
,
dtype
=
torch
.
float32
)
blended
=
torch
.
zeros_like
(
lev
)
for
b
in
range
(
cu_seqlens
.
numel
()
-
1
):
k_beg
=
int
(
cu_seqlens
[
b
].
item
())
k_end
=
int
(
cu_seqlens
[
b
+
1
].
item
())
L
=
k_end
-
k_beg
if
L
==
0
:
continue
left_keep
=
min
(
sink_start
,
L
)
right_keep
=
min
(
sink_end
,
max
(
0
,
L
-
left_keep
))
mid_start
=
k_beg
+
left_keep
mid_end
=
k_end
-
right_keep
if
mid_start
>=
mid_end
:
continue
blended
[
mid_start
:
mid_end
,
:]
=
(
blending
*
lev
[
mid_start
:
mid_end
,
:]
+
attn_out
[
mid_start
:
mid_end
,
:]
)
pad_val
=
blended
.
max
()
if
not
torch
.
isfinite
(
pad_val
)
or
pad_val
==
0
:
pad_val
=
torch
.
tensor
(
1.0
,
device
=
device
,
dtype
=
torch
.
float32
)
for
b
in
range
(
cu_seqlens
.
numel
()
-
1
):
k_beg
=
int
(
cu_seqlens
[
b
].
item
())
k_end
=
int
(
cu_seqlens
[
b
+
1
].
item
())
L
=
k_end
-
k_beg
if
L
==
0
:
continue
left_keep
=
min
(
sink_start
,
L
)
right_keep
=
min
(
sink_end
,
max
(
0
,
L
-
left_keep
))
mid_start
=
k_beg
+
left_keep
mid_end
=
k_end
-
right_keep
if
left_keep
>
0
:
blended
[
k_beg
:
mid_start
,
:]
=
pad_val
if
right_keep
>
0
:
blended
[
mid_end
:
k_end
,
:]
=
pad_val
if
protected_first
is
not
None
and
protected_last
is
not
None
and
context_lens
:
start
=
0
for
first
,
last
,
Lc
in
zip
(
protected_first
,
protected_last
,
context_lens
):
blended
[
start
:
start
+
int
(
first
)].
fill_
(
torch
.
inf
)
blended
[
start
+
int
(
Lc
)
-
int
(
last
)
:
start
+
int
(
Lc
)].
fill_
(
torch
.
inf
)
start
+=
int
(
Lc
)
return
blended
Prev
1
…
4
5
6
7
8
9
10
11
12
…
16
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment