Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
d29c39ca
Commit
d29c39ca
authored
Apr 30, 2026
by
chenzk
Browse files
vllm kvprune wo:v1.1.0
parent
f81ce56b
Changes
246
Hide whitespace changes
Inline
Side-by-side
Showing
6 changed files
with
434 additions
and
0 deletions
+434
-0
vllm/kvprune_legacy_save/utils/kv_dist.py
vllm/kvprune_legacy_save/utils/kv_dist.py
+35
-0
vllm/kvprune_legacy_save/utils/layout_bridge.py
vllm/kvprune_legacy_save/utils/layout_bridge.py
+167
-0
vllm/kvprune_legacy_save/utils/sequence.py
vllm/kvprune_legacy_save/utils/sequence.py
+83
-0
vllm/kvprune_legacy_save/utils/tp_collectives.py
vllm/kvprune_legacy_save/utils/tp_collectives.py
+48
-0
vllm/kvprune_legacy_save/utils/tp_utils.py
vllm/kvprune_legacy_save/utils/tp_utils.py
+40
-0
vllm/kvprune_legacy_save/utils/triton_compat.py
vllm/kvprune_legacy_save/utils/triton_compat.py
+61
-0
No files found.
vllm/kvprune_legacy_save/utils/kv_dist.py
0 → 100644
View file @
d29c39ca
"""Distributed helpers for kvprune when embedded in vLLM (use TP process group)."""
from
__future__
import
annotations
import
torch
import
torch.distributed
as
dist
def broadcast_from_tp_rank0(tensor: torch.Tensor, *, use_tp_group: bool) -> None:
    """Broadcast ``tensor`` from rank 0 of the selected process group.

    Args:
        tensor: Tensor passed to the broadcast call (always with ``src=0``).
        use_tp_group: Selects the process group.
            ``False`` (standalone compactor subprocesses) uses the default
            process group via :func:`torch.distributed.broadcast`.
            ``True`` (embedded in a vLLM worker) uses the group returned by
            vLLM's ``get_tp_group()`` so the collective stays inside the
            tensor-parallel group even if the default group is global.
    """
    if use_tp_group:
        # Imported lazily: the default-group path must not require vLLM's
        # parallel state module.
        from vllm.distributed.parallel_state import get_tp_group

        tp_group = get_tp_group()
        tp_group.broadcast(tensor, src=0)
    else:
        dist.broadcast(tensor, src=0)
def barrier_sync(*, use_tp_group: bool) -> None:
    """Run a barrier on either the default process group or vLLM's TP group.

    Args:
        use_tp_group: ``False`` calls :func:`torch.distributed.barrier` on the
            default group; ``True`` calls ``barrier()`` on the group returned
            by vLLM's ``get_tp_group()``.
    """
    if use_tp_group:
        # Imported lazily so the default-group path does not depend on vLLM.
        from vllm.distributed.parallel_state import get_tp_group

        tp_group = get_tp_group()
        tp_group.barrier()
    else:
        dist.barrier()
vllm/kvprune_legacy_save/utils/layout_bridge.py
0 → 100644
View file @
d29c39ca
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""
Bridge vLLM paged KV layout to compactor Triton kernels.
vLLM FlashAttention KV cache is shaped
[num_blocks, block_size, num_kv_heads, head_dim].
Compactor kernels expect a flat buffer [CACHE_SIZE, head_dim] and a page table
global_page_table[batch, kv_head, logical_page] -> physical_page_id
where each physical page holds ``block_size`` consecutive rows belonging to that
KV head only.
When num_kv_heads == 1 (MQA), a vLLM block maps 1:1 to compactor rows:
row_index = physical_block_id * block_size + offset_in_block.
When ``num_kv_heads > 1``, we permute to head-major
``[num_kv_heads, num_blocks, block_size, head_dim]`` and flatten to
``[num_kv_heads * num_blocks * block_size, head_dim]`` so each KV head occupies
a disjoint row range in the flat buffer. The page table is built so each
logical compression page maps to ``global_row // PAGE_SIZE`` in that layout
(see ``build_page_table_head_major``).
"""
from
__future__
import
annotations
import
torch
def
_cdiv
(
n
:
int
,
d
:
int
)
->
int
:
return
(
n
+
d
-
1
)
//
d
def flatten_kv_cache_head_major(
    key_cache: torch.Tensor,
    value_cache: torch.Tensor,
) -> tuple[torch.Tensor, torch.Tensor]:
    """Flatten ``[nb, bs, H, D]`` caches to ``[H*nb*bs, D]`` in head-major order.

    Each cache is permuted to ``[H, nb, bs, D]``, made contiguous, and reshaped
    so that the rows of one KV head are adjacent. ``.contiguous()`` materializes
    the permuted layout, so the results generally do not alias the inputs.

    Args:
        key_cache: 4-D key cache ``[nb, bs, H, D]``.
        value_cache: 4-D value cache with the same shape as ``key_cache``.

    Returns:
        ``(k_flat, v_flat)``, each of shape ``[H*nb*bs, D]``.

    Raises:
        ValueError: If the two caches do not have identical shapes.
    """
    if key_cache.shape != value_cache.shape:
        raise ValueError("key_cache and value_cache must match")

    num_blocks, block_size, num_heads, head_dim = key_cache.shape
    total_rows = num_heads * num_blocks * block_size

    def _head_major_rows(cache: torch.Tensor) -> torch.Tensor:
        # [nb, bs, H, D] -> [H, nb, bs, D] -> [H*nb*bs, D]
        head_major = cache.permute(2, 0, 1, 3).contiguous()
        return head_major.reshape(total_rows, head_dim)

    return _head_major_rows(key_cache), _head_major_rows(value_cache)
def write_head_major_flat_to_interleaved(
    k_flat: torch.Tensor,
    v_flat: torch.Tensor,
    key_cache: torch.Tensor,
    value_cache: torch.Tensor,
) -> None:
    """Copy ``[H*nb*bs, D]`` head-major flats back into ``[nb, bs, H, D]`` caches.

    The target layout ``(nb, bs, H, D)`` is read from ``key_cache.shape`` and is
    used for both the key and the value flat. ``key_cache`` and ``value_cache``
    are overwritten in place.

    Args:
        k_flat: Head-major key rows, viewable as ``[H, nb, bs, D]``.
        v_flat: Head-major value rows, viewable as ``[H, nb, bs, D]``.
        key_cache: Destination key cache ``[nb, bs, H, D]``.
        value_cache: Destination value cache.
    """
    num_blocks, block_size, num_heads, head_dim = key_cache.shape
    head_major_shape = (num_heads, num_blocks, block_size, head_dim)

    # Build both views before touching either destination, so an incompatible
    # flat buffer fails before any cache has been modified.
    k_head_major = k_flat.view(head_major_shape)
    v_head_major = v_flat.view(head_major_shape)

    for destination, head_major in (
        (key_cache, k_head_major),
        (value_cache, v_head_major),
    ):
        # [H, nb, bs, D] -> [nb, bs, H, D]
        destination.copy_(head_major.permute(1, 2, 0, 3))
def build_page_table_head_major(
    block_table: torch.Tensor,
    num_kv_heads: int,
    num_blocks: int,
    block_size: int,
    page_size: int,
    max_batches: int,
) -> torch.Tensor:
    """Build ``[max_batches, H, max_chain]`` page table for head-major flat KV.

    For every request row of ``block_table`` the non-negative block ids are
    chained in table order. Each block contributes
    ``ceil(block_size / page_size)`` consecutive entries, and each entry is
    ``global_row // page_size`` where
    ``global_row = h * num_blocks * block_size + block_id * block_size + p * page_size``
    indexes rows of the head-major flat buffer (the layout
    ``[num_kv_heads * num_blocks * block_size, head_dim]``).

    Negative block ids are skipped, so later entries shift left. Slots past the
    end of a chain, and rows ``bsz..max_batches-1``, stay 0. Page id 0 is also
    a valid entry, so padding cannot be told apart from it by value alone.

    The table is assembled on the host from a single ``tolist()`` transfer and
    moved to ``block_table.device`` once, instead of issuing one ``.item()``
    sync and one scalar device write per entry. Block ids are validated once
    per request row, independent of ``num_kv_heads``.

    Args:
        block_table: 2-D ``[bsz, max_blocks]`` tensor of physical block ids.
        num_kv_heads: Number of KV heads ``H``.
        num_blocks: Number of physical blocks; every non-negative id must be
            smaller than this.
        block_size: Rows per block.
        page_size: Rows per compactor page.
        max_batches: Size of the leading dimension of the result.

    Returns:
        ``int32`` tensor of shape ``[max_batches, num_kv_heads, max_chain]`` on
        ``block_table.device``, with
        ``max_chain = max_blocks * ceil(block_size / page_size)``.

    Raises:
        ValueError: If ``bsz > max_batches`` or a block id is ``>= num_blocks``.
    """
    bsz, max_blocks = block_table.shape
    if bsz > max_batches:
        raise ValueError("batch size exceeds max_batches for page table")

    # Ceil division: number of page slots contributed by one block.
    num_pages_per_block = (block_size + page_size - 1) // page_size
    max_chain = max_blocks * num_pages_per_block

    table = torch.zeros(
        (max_batches, num_kv_heads, max_chain),
        dtype=torch.int32,
    )

    # Rows occupied by one KV head in the head-major flat buffer.
    head_stride = num_blocks * block_size

    # One device->host transfer for the whole block table.
    for b, block_ids in enumerate(block_table.to(torch.int64).tolist()):
        # Head-independent first row of every page in this request's chain.
        page_start_rows = []
        for blk_i, bid in enumerate(block_ids):
            if bid < 0:
                continue
            if bid >= num_blocks:
                raise ValueError(
                    f"block_table[{b},{blk_i}]={bid} out of range "
                    f"num_blocks={num_blocks}"
                )
            first_row = bid * block_size
            # p < ceil(block_size / page_size) guarantees p * page_size stays
            # inside the block, so no per-page bounds check is needed.
            page_start_rows.extend(
                first_row + p * page_size for p in range(num_pages_per_block)
            )

        if not page_start_rows:
            continue

        chain_len = len(page_start_rows)
        for h in range(num_kv_heads):
            head_base = h * head_stride
            table[b, h, :chain_len] = torch.tensor(
                [(head_base + row) // page_size for row in page_start_rows],
                dtype=torch.int32,
            )

    return table.to(block_table.device)
def flatten_kv_cache_plane(
    key_cache: torch.Tensor,
    value_cache: torch.Tensor,
    num_kv_heads: int,
) -> tuple[torch.Tensor, torch.Tensor]:
    """Flatten ``(num_blocks, block_size, 1, D)`` caches to ``[num_blocks*block_size, D]``.

    Only the single-KV-head case is supported: both the ``num_kv_heads``
    argument and the head dimension of the cache itself must be 1.

    Args:
        key_cache: 4-D key cache ``[num_blocks, block_size, 1, D]``.
        value_cache: Value cache with the same shape as ``key_cache``.
        num_kv_heads: Declared number of KV heads; must be 1.

    Returns:
        ``(k_flat, v_flat)``, each contiguous with shape
        ``[num_blocks * block_size, D]``.

    Raises:
        ValueError: If ``num_kv_heads != 1``, if the cache shapes differ, or if
            the cache's head dimension is not 1.
    """
    if num_kv_heads != 1:
        raise ValueError(
            "flatten_kv_cache_plane requires num_kv_heads==1 for compactor layout"
        )
    if key_cache.shape != value_cache.shape:
        raise ValueError("key_cache and value_cache must match")

    num_blocks, block_size, heads_in_cache, head_dim = key_cache.shape
    # The argument was checked above; this guards the tensor's actual head axis.
    if heads_in_cache != 1:
        raise ValueError("expected num_kv_heads==1")

    # [num_blocks, block_size, 1, D] -> [num_blocks * block_size, D]
    flat_shape = (num_blocks * block_size, head_dim)
    k_rows = key_cache.reshape(flat_shape)
    v_rows = value_cache.reshape(flat_shape)

    flattened = [
        rows if rows.is_contiguous() else rows.contiguous()
        for rows in (k_rows, v_rows)
    ]
    return flattened[0], flattened[1]
def block_table_to_global_page_table(
    block_table: torch.Tensor,
    num_kv_heads: int,
    max_batches: int,
) -> torch.Tensor:
    """Build a ``[max_batches, HKV, num_logical_pages]`` int32 page table.

    The block ids of ``block_table`` (``[num_reqs_padded, max_num_blocks]``)
    are copied unchanged into every KV-head slot of the first ``bsz`` rows;
    rows ``bsz..max_batches-1`` stay 0.

    Args:
        block_table: 2-D tensor of physical block ids.
        num_kv_heads: Size of the head axis of the result.
        max_batches: Size of the leading axis of the result.

    Returns:
        ``int32`` tensor on ``block_table.device``.

    Raises:
        ValueError: If ``block_table`` has more rows than ``max_batches``.
    """
    bsz, max_lp = block_table.shape
    if bsz > max_batches:
        raise ValueError("batch size exceeds max_batches for page table")

    page_table = torch.zeros(
        (max_batches, num_kv_heads, max_lp),
        dtype=torch.int32,
        device=block_table.device,
    )
    block_ids = block_table.to(torch.int32)
    # Insert a size-1 head axis so the same ids broadcast across every head.
    page_table[:bsz] = block_ids.unsqueeze(1)
    return page_table
def build_batch_mapping(num_reqs: int, device: torch.device) -> torch.Tensor:
    """Return the identity mapping ``[0, 1, ..., num_reqs - 1]`` as int32 on ``device``.

    Local batch index ``i`` maps to global batch row ``i``.
    """
    identity = torch.arange(0, num_reqs, dtype=torch.int32, device=device)
    return identity
vllm/kvprune_legacy_save/utils/sequence.py
0 → 100644
View file @
d29c39ca
from
dataclasses
import
dataclass
,
field
from
enum
import
Enum
,
auto
from
itertools
import
count
from
typing
import
List
from
vllm.kvprune.compression.compression_config
import
SequenceCompressionParams
from
vllm.kvprune.config.sampling_params
import
SamplingParams
class SequenceStatus(Enum):
    """Status values stored in ``Sequence.status``."""

    # Explicit values equal the ones ``auto()`` assigns in definition order.
    WAITING = 1
    RUNNING = 2
    FINISHED = 3
@dataclass
class Sequence:
    """
    Represents a single user request / sequence being generated.

    Holds the prompt and completion token ids together with the per-sequence
    sampling and compression parameters, a status, a process-unique
    ``seq_id`` and a running count of processed tokens.
    """

    # Shared id source. It is unannotated, so the dataclass machinery treats it
    # as a plain class attribute rather than a field.
    _counter = count()

    prompt_token_ids: List[int]
    completion_token_ids: List[int] = field(default_factory=list)
    sampling_params: SamplingParams = field(default_factory=SamplingParams)
    compression_params: SequenceCompressionParams = field(
        default_factory=SequenceCompressionParams
    )
    status: SequenceStatus = SequenceStatus.WAITING
    # Not an ``__init__`` argument: each new instance draws the next counter value.
    seq_id: int = field(default_factory=lambda: next(Sequence._counter), init=False)
    num_tokens_processed: int = 0

    @property
    def num_prompt_tokens(self) -> int:
        """Number of prompt token ids."""
        return len(self.prompt_token_ids)

    @property
    def num_generated_tokens(self) -> int:
        """Number of completion token ids appended so far."""
        return len(self.completion_token_ids)

    def add_new_token(self, token_id: int) -> None:
        """Append ``token_id`` to the completion and update ``num_tokens_processed``.

        The first appended token additionally adds the prompt length to the
        counter; every call adds 1 for the new token itself.
        """
        is_first_completion_token = not self.completion_token_ids
        prompt_contribution = (
            self.num_prompt_tokens if is_first_completion_token else 0
        )
        self.completion_token_ids.append(token_id)
        self.num_tokens_processed += prompt_contribution + 1

    def tokens_to_retain_per_layer(self, num_kv_heads: int) -> int:
        """Return ``int(compression_ratio * num_prompt_tokens * num_kv_heads)``, at least 1."""
        ratio = self.compression_params.compression_ratio
        budget = int(ratio * self.num_prompt_tokens * num_kv_heads)
        return budget if budget > 1 else 1

    def __getstate__(self):
        """Return the pickled state: every field, with the token lists copied."""
        return {
            "prompt_token_ids": list(self.prompt_token_ids),
            "completion_token_ids": list(self.completion_token_ids),
            "sampling_params": self.sampling_params,
            "compression_params": self.compression_params,
            "status": self.status,
            "seq_id": self.seq_id,
            "num_tokens_processed": self.num_tokens_processed,
        }

    def __setstate__(self, state):
        """Restore every field from ``state``; token lists are copied into new lists."""
        for list_field in ("prompt_token_ids", "completion_token_ids"):
            setattr(self, list_field, list(state[list_field]))
        for plain_field in (
            "sampling_params",
            "compression_params",
            "status",
            "seq_id",
            "num_tokens_processed",
        ):
            setattr(self, plain_field, state[plain_field])

    @property
    def prompt_len(self) -> int:
        """Alias of :attr:`num_prompt_tokens`."""
        return self.num_prompt_tokens

    @property
    def completion_len(self) -> int:
        """Alias of :attr:`num_generated_tokens`."""
        return self.num_generated_tokens
vllm/kvprune_legacy_save/utils/tp_collectives.py
0 → 100644
View file @
d29c39ca
"""Tensor-parallel collectives for kvprune (match vLLM TP process group when embedded)."""
from
__future__
import
annotations
import
torch.distributed
as
dist
def tensor_parallel_all_reduce(tensor: torch.Tensor) -> torch.Tensor:
    """All-reduce across tensor-parallel ranks, leaving the result in ``tensor``.

    Group selection:

    * If ``torch.distributed`` is not initialized, or the default group's world
      size is at most 1, ``tensor`` is returned untouched.
    * If vLLM's :mod:`vllm.distributed.parallel_state` is importable and reports
      that model parallelism is initialized (e.g. kvprune runs inside a vLLM GPU
      worker), the reduction goes through
      :func:`~vllm.distributed.communication_op.tensor_model_parallel_all_reduce`,
      i.e. the same TP group as the main model.
    * Otherwise (standalone kvprune subprocesses that only have the default
      process group, world == ``tensor_parallel_size``) it falls back to
      :func:`torch.distributed.all_reduce` on the default group.

    vLLM's TP all-reduce is **out-of-place** and returns a new tensor. Call
    sites such as :class:`~vllm.kvprune.layers.linear.RowParallelLinear`
    historically invoked ``tensor_parallel_all_reduce(y)`` without using the
    return value, which left ``y`` as the unreduced per-rank partial output
    under TP>1. The reduced result is therefore copied back into ``tensor`` so
    those call sites stay correct.

    Only the optional import of vLLM's parallel state is guarded. An error
    raised by the TP all-reduce itself is propagated rather than swallowed:
    silently retrying on the default group would run a second collective over a
    possibly different set of ranks.

    Args:
        tensor: Tensor to reduce; updated in place.

    Returns:
        The same ``tensor`` object.
    """
    if not dist.is_initialized() or dist.get_world_size() <= 1:
        return tensor

    try:
        from vllm.distributed.parallel_state import model_parallel_is_initialized
    except ImportError:
        # vLLM's parallel state is unavailable: standalone mode, default group.
        model_parallel_is_initialized = None

    if model_parallel_is_initialized is not None and model_parallel_is_initialized():
        from vllm.distributed.communication_op import (
            tensor_model_parallel_all_reduce as vllm_tp_all_reduce,
        )

        reduced = vllm_tp_all_reduce(tensor)
        if reduced is not tensor:
            # vLLM TP all_reduce is out-of-place: `reduced` holds the cross-rank
            # sum. Call sites ignore the return value and expect `tensor` to be
            # updated, so materialize the reduced values here; otherwise TP>1
            # keeps per-rank partials.
            tensor.copy_(reduced)
        return tensor

    dist.all_reduce(tensor)
    return tensor
vllm/kvprune_legacy_save/utils/tp_utils.py
0 → 100644
View file @
d29c39ca
"""Tensor-parallel helpers for kvprune when embedded in a vLLM worker."""
from
__future__
import
annotations
import
torch.distributed
as
dist
def tensor_parallel_rank_for_sharding() -> int:
    """Return the rank to use for tensor-parallel sharding.

    Prefers vLLM's ``get_tensor_model_parallel_rank()`` (matches vLLM weight
    shards when embedded). If that import or call raises, falls back to
    :func:`torch.distributed.get_rank` when ``torch.distributed`` is
    initialized (standalone kvprune with only the default process group), and
    to 0 otherwise.
    """
    try:
        from vllm.distributed.parallel_state import get_tensor_model_parallel_rank

        tp_rank = get_tensor_model_parallel_rank()
        return int(tp_rank)
    except Exception:
        # Covers both a failed vLLM import and any error from the rank query.
        return int(dist.get_rank()) if dist.is_initialized() else 0
def tensor_parallel_world_size_for_sharding() -> int:
    """Return the world size of the tensor-parallel group.

    Prefers vLLM's ``get_tensor_model_parallel_world_size()``. If that import
    or call raises, falls back to :func:`torch.distributed.get_world_size` when
    ``torch.distributed`` is initialized, and to 1 otherwise.
    """
    try:
        from vllm.distributed.parallel_state import (
            get_tensor_model_parallel_world_size,
        )

        tp_world_size = get_tensor_model_parallel_world_size()
        return int(tp_world_size)
    except Exception:
        # Covers both a failed vLLM import and any error from the size query.
        return int(dist.get_world_size()) if dist.is_initialized() else 1
def kv_heads_shard_divisor() -> int:
    """Return the world size used to shard KV heads.

    This is exactly the value of ``tensor_parallel_world_size_for_sharding()``.
    """
    divisor = tensor_parallel_world_size_for_sharding()
    return divisor
vllm/kvprune_legacy_save/utils/triton_compat.py
0 → 100644
View file @
d29c39ca
from
__future__
import
annotations
import
inspect
from
typing
import
Any
,
Callable
,
Mapping
import
torch
def
_filter_kwargs_for_callable
(
fn
:
Callable
[...,
Any
],
kwargs
:
Mapping
[
str
,
Any
]
)
->
dict
[
str
,
Any
]:
try
:
params
=
inspect
.
signature
(
fn
).
parameters
except
(
TypeError
,
ValueError
):
return
dict
(
kwargs
)
return
{
k
:
v
for
k
,
v
in
kwargs
.
items
()
if
k
in
params
}
def autotune(*, configs, key, **kwargs):
    """
    Compatibility wrapper around `triton.autotune`.

    Some Triton builds (e.g., custom vendor builds) may not support newer
    keyword arguments like `cache_results`. Extra keyword arguments are passed
    through `_filter_kwargs_for_callable` against the runtime
    `triton.autotune` before being forwarded; `configs` and `key` are always
    forwarded by keyword.
    """
    import triton

    triton_autotune = triton.autotune
    supported = _filter_kwargs_for_callable(triton_autotune, kwargs)
    return triton_autotune(configs=configs, key=key, **supported)
def maybe_set_allocator(alloc_fn: Callable[[int, int, int | None], Any]) -> bool:
    """
    Call `triton.set_allocator(alloc_fn)` if the installed Triton exposes it.

    Returns True if the allocator was set, False if `triton.set_allocator` is
    missing (or is None), in which case nothing is done.
    """
    import triton

    set_allocator = getattr(triton, "set_allocator", None)
    if set_allocator is not None:
        set_allocator(alloc_fn)
        return True
    return False
def cuda_capability_geq(major: int, minor: int = 0, device: int | None = None) -> bool:
    """
    Host-side CUDA capability check that works even when `tl.target_info` is absent.

    Returns False when CUDA is unavailable. Otherwise compares the capability
    tuple of `device` against `(major, minor)`. When `device` is None the
    current CUDA device is used, falling back to device 0 if querying the
    current device raises.
    """
    cuda = torch.cuda
    if cuda.is_available():
        target = device
        if target is None:
            try:
                target = cuda.current_device()
            except Exception:
                target = 0
        required = (major, minor)
        return cuda.get_device_capability(target) >= required
    return False
Prev
1
…
9
10
11
12
13
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment