Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
f81ce56b
Commit
f81ce56b
authored
Apr 23, 2026
by
chenzk
Browse files
vllm kvprune:v1.0.1
parent
2b7160c6
Changes
237
Hide whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
0 additions
and
3044 deletions
+0
-3044
vllm/kvprune_legacy_save/kv_cache/page_table.py
vllm/kvprune_legacy_save/kv_cache/page_table.py
+0
-313
vllm/kvprune_legacy_save/kv_cache/store_kv_cache.py
vllm/kvprune_legacy_save/kv_cache/store_kv_cache.py
+0
-468
vllm/kvprune_legacy_save/kv_cache/write_page_table.py
vllm/kvprune_legacy_save/kv_cache/write_page_table.py
+0
-110
vllm/kvprune_legacy_save/kvprune_to_vllm.md
vllm/kvprune_legacy_save/kvprune_to_vllm.md
+0
-56
vllm/kvprune_legacy_save/layers/__init__.py
vllm/kvprune_legacy_save/layers/__init__.py
+0
-9
vllm/kvprune_legacy_save/layers/activation.py
vllm/kvprune_legacy_save/layers/activation.py
+0
-13
vllm/kvprune_legacy_save/layers/attention.py
vllm/kvprune_legacy_save/layers/attention.py
+0
-212
vllm/kvprune_legacy_save/layers/embed_head.py
vllm/kvprune_legacy_save/layers/embed_head.py
+0
-111
vllm/kvprune_legacy_save/layers/layernorm.py
vllm/kvprune_legacy_save/layers/layernorm.py
+0
-49
vllm/kvprune_legacy_save/layers/linear.py
vllm/kvprune_legacy_save/layers/linear.py
+0
-158
vllm/kvprune_legacy_save/layers/moe.py
vllm/kvprune_legacy_save/layers/moe.py
+0
-177
vllm/kvprune_legacy_save/layers/rotary_embedding.py
vllm/kvprune_legacy_save/layers/rotary_embedding.py
+0
-121
vllm/kvprune_legacy_save/layers/sampler.py
vllm/kvprune_legacy_save/layers/sampler.py
+0
-27
vllm/kvprune_legacy_save/layers/triton_helpers.py
vllm/kvprune_legacy_save/layers/triton_helpers.py
+0
-101
vllm/kvprune_legacy_save/models/__init__.py
vllm/kvprune_legacy_save/models/__init__.py
+0
-20
vllm/kvprune_legacy_save/models/llama3.py
vllm/kvprune_legacy_save/models/llama3.py
+0
-299
vllm/kvprune_legacy_save/models/qwen3.py
vllm/kvprune_legacy_save/models/qwen3.py
+0
-296
vllm/kvprune_legacy_save/models/qwen3_moe.py
vllm/kvprune_legacy_save/models/qwen3_moe.py
+0
-406
vllm/kvprune_legacy_save/triton_kernels/__init__.py
vllm/kvprune_legacy_save/triton_kernels/__init__.py
+0
-22
vllm/kvprune_legacy_save/triton_kernels/compaction.py
vllm/kvprune_legacy_save/triton_kernels/compaction.py
+0
-76
No files found.
vllm/kvprune_legacy_save/kv_cache/page_table.py
deleted
100644 → 0
View file @
2b7160c6
import
heapq
import
logging
from
enum
import
Enum
,
auto
from
typing
import
List
,
Optional
,
Union
import
torch
from
vllm.kvprune.config.constants
import
RESERVED_BATCH
from
vllm.kvprune.kv_cache.write_page_table
import
scatter_to_page_table
logger
=
logging
.
getLogger
(
__name__
)
def
cdiv
(
a
,
b
):
return
(
a
+
b
-
1
)
//
b
def
next_multiple
(
a
,
b
):
return
cdiv
(
a
,
b
)
*
b
class
KVAllocationStatus
(
Enum
):
EXCEEDS_MAX_SEQUENCE_LENGTH
=
auto
()
EXCEEDS_CURRENTLY_AVAILABLE_PAGES
=
auto
()
EXCEEDS_MAX_NUM_BATCHES
=
auto
()
SUCCESS
=
auto
()
class
PagedKVCache
(
torch
.
nn
.
Module
):
"""
Global paged KV cache.
This module manages:
* A global K/V backing buffer for all layers:
``kv_cache[2, num_layers, n_pages * page_size, head_dim]``,
where the first dimension indexes K vs V.
* A per-layer page table:
``page_table[num_layers, max_num_seqs, H_kv, max_pages_per_head]``,
mapping logical (batch, kv-head, logical_page) to a physical page ID
in the global K/V buffer.
* Per-layer, per-(batch, kv-head) logical sequence lengths
``bh_seq_lens[num_layers, max_num_seqs, H_kv]`` (in tokens), and
the number of allocated pages ``bh_num_pages`` for each (layer, batch,
head).
* A page allocator implemented as a min-heap of free physical pages
per layer, plus free batch indices.
Pages are of fixed size ``page_size`` tokens.
Args:
:param num_layers:
Number of transformer layers that will use this cache.
:param max_logical_pages_per_head:
Maximum number of logical pages that can be assigned to a single
(batch, kv-head) pair.
:param num_pages:
Total number of physical pages available in the global cache per
layer. The global K/V buffers are of length
``num_pages * page_size`` along the token dimension.
:param page_size:
Number of tokens stored per page.
:param H_kv:
Number of KV heads per layer.
:param head_dim:
Head dimension for K/V.
:param max_num_batches:
Maximum number of concurrent batches / sequences supported. One
batch index is reserved for internal use (``RESERVED_BATCH``).
:param dtype:
Data type of K/V entries (e.g. ``torch.float16`` or ``torch.bfloat16``).
:param device:
Device on which to allocate the cache (string, torch.device, or
int; defaults to ``"cuda"``).
"""
def
__init__
(
self
,
num_layers
:
int
,
max_logical_pages_per_head
:
int
,
num_pages
:
int
,
page_size
:
int
,
# tokens per page
H_kv
:
int
,
head_dim
:
int
,
max_num_batches
:
int
,
dtype
:
torch
.
dtype
,
device
:
Union
[
str
,
torch
.
device
,
int
]
=
"cuda"
,
):
super
().
__init__
()
self
.
n_pages
=
num_pages
self
.
num_layers
=
num_layers
self
.
page_size
:
int
=
int
(
page_size
)
self
.
H_kv
=
int
(
H_kv
)
self
.
max_pages_per_head
=
max_logical_pages_per_head
max_num_batches
+=
1
self
.
max_num_batches
=
max_num_batches
self
.
head_dim
=
head_dim
cache_shape
=
(
2
,
num_layers
,
num_pages
*
page_size
,
head_dim
)
self
.
kv_cache
=
torch
.
empty
(
cache_shape
,
dtype
=
dtype
,
device
=
device
)
self
.
page_table
=
torch
.
empty
(
(
num_layers
,
max_num_batches
,
H_kv
,
self
.
max_pages_per_head
),
device
=
device
,
dtype
=
torch
.
int32
,
)
# Per-(batch, head) logical seq length (tokens)
self
.
bh_seq_lens
=
torch
.
zeros
(
(
num_layers
,
max_num_batches
,
H_kv
),
device
=
device
,
dtype
=
torch
.
int32
)
# self._bh_seq_lens_cpu_buffer = torch.zeros((num_layers, H_kv), device="cpu", dtype=torch.int32)
self
.
bh_num_pages
=
torch
.
zeros
(
(
num_layers
,
max_num_batches
,
H_kv
),
device
=
device
,
dtype
=
torch
.
int32
)
# Page allocator (min-heap of free physical pages)
self
.
free_pages
:
List
[
List
[
int
]]
=
[
list
(
range
(
num_pages
))
for
_
in
range
(
num_layers
)
]
for
free_pages
in
self
.
free_pages
:
heapq
.
heapify
(
free_pages
)
# batch zero is reserved
self
.
free_batches
:
List
[
int
]
=
list
(
reversed
(
range
(
max_num_batches
)))
self
.
free_batches
.
remove
(
RESERVED_BATCH
)
# Record of physical page ids owned by a batch (for freeing)
self
.
pages_indices_per_batch
:
List
[
List
[
set
[
int
]]]
=
[
[
set
()
for
_
in
range
(
num_layers
)]
for
_
in
range
(
max_num_batches
)
]
def
new_batch
(
self
)
->
Optional
[
int
]:
"""
Reserve a new batch slot.
A batch slot corresponds to a row in ``bh_seq_lens`` /
``bh_num_pages`` and a slice in ``page_table`` for all layers and KV
heads. This method checks whether a free batch index is available, and
whether each layer has at least ``H_kv`` free pages remaining.
If both checks pass, it returns a batch index and removes it from
``free_batches``. Otherwise, it returns ``None``.
Returns:
:return Optional[int]:
Newly reserved batch index, or ``None`` if no capacity is
available.
"""
if
self
.
free_batches
and
all
([
self
.
H_kv
<=
len
(
fp
)
for
fp
in
self
.
free_pages
]):
return
self
.
free_batches
.
pop
()
return
None
def
reserve_tokens
(
self
,
batch_index
:
int
,
add_tokens
:
int
)
->
KVAllocationStatus
:
"""
Ensure enough pages are allocated to handle ``add_tokens`` new tokens.
Args:
:param batch_index:
Batch index to reserve space for.
:param add_tokens:
Number of additional tokens to reserve capacity for.
All heads in this batch and all layers reserve
the same number of extra tokens.
Returns:
:return bool:
``True`` if the reservation succeeds; ``False`` otherwise .
"""
cur_bh_lens
=
self
.
bh_seq_lens
[:,
batch_index
]
# [L, H]
curr_pages
=
self
.
bh_num_pages
[:,
batch_index
]
# [L, H]
curr_cap_tokens
=
curr_pages
*
self
.
page_size
# [L, H]
need_tokens
=
cur_bh_lens
+
add_tokens
# [L, H]
if
(
need_tokens
<=
curr_cap_tokens
).
all
():
return
KVAllocationStatus
.
SUCCESS
missing_tokens
=
need_tokens
-
curr_cap_tokens
add_pages
=
cdiv
(
missing_tokens
,
self
.
page_size
)
new_total_pages
=
curr_pages
+
add_pages
if
(
new_total_pages
>
self
.
max_pages_per_head
).
any
():
return
KVAllocationStatus
.
EXCEEDS_MAX_SEQUENCE_LENGTH
# CPU work
pages_per_layer_cpu
=
add_pages
.
sum
(
dim
=-
1
).
tolist
()
new_phys_pages
=
[]
for
layer_index
in
range
(
self
.
num_layers
):
if
pages_per_layer_cpu
[
layer_index
]
>
len
(
self
.
free_pages
[
layer_index
]):
return
KVAllocationStatus
.
EXCEEDS_CURRENTLY_AVAILABLE_PAGES
for
layer_index
in
range
(
self
.
num_layers
):
this_layer_pages
=
[
heapq
.
heappop
(
self
.
free_pages
[
layer_index
])
for
_
in
range
(
pages_per_layer_cpu
[
layer_index
])
]
self
.
pages_indices_per_batch
[
batch_index
][
layer_index
]
|=
set
(
this_layer_pages
)
new_phys_pages
.
extend
(
this_layer_pages
)
new_phys_pages
=
torch
.
tensor
(
new_phys_pages
,
dtype
=
torch
.
int32
,
device
=
"cuda"
)
scatter_to_page_table
(
add_pages
=
add_pages
,
new_phys_pages
=
new_phys_pages
,
curr_pages
=
curr_pages
,
page_table
=
self
.
page_table
[:,
batch_index
],
max_pages_per_head
=
self
.
max_pages_per_head
,
)
self
.
bh_num_pages
[:,
batch_index
,
:]
=
new_total_pages
.
to
(
self
.
bh_num_pages
.
dtype
)
return
KVAllocationStatus
.
SUCCESS
def
reclaim_pages
(
self
,
batch_index
:
int
,
future_reserve_tokens
:
int
=
0
,
):
"""
Reclaim unused pages for a single batch index. This shrinks the KV
allocation for the batch down to the minimum number of pages needed
to hold the current (plus optional future) sequence length.
Args:
:param batch_index:
Batch index whose pages should be compacted.
:param future_reserve_tokens:
Optional number of extra tokens to keep capacity for, beyond
the current sequence length. This can reduce churn when
sequences are expected to grow slightly in the near future.
Returns:
:return int:
Approximate number of bytes freed across both K and V.
"""
device
=
self
.
bh_seq_lens
.
device
L
,
B
,
H
=
self
.
bh_seq_lens
.
shape
assert
0
<=
batch_index
<
B
seq
=
self
.
bh_seq_lens
[:,
batch_index
,
:]
+
future_reserve_tokens
# [L, H]
alloc
=
self
.
bh_num_pages
[:,
batch_index
,
:]
# [L, H]
pt
=
self
.
page_table
[:,
batch_index
,
:,
:].
reshape
(
-
1
)
# [L, H, P]
# Compute used pages: ceil_div(seq, page_size), clamped into [0, alloc]
used_pages
=
cdiv
(
seq
,
self
.
page_size
)
used_pages
=
torch
.
minimum
(
used_pages
,
alloc
)
# page indices [0..P-1], broadcasted over [L, H, P]
p
=
torch
.
arange
(
self
.
max_pages_per_head
,
device
=
device
,
dtype
=
torch
.
int32
).
view
(
1
,
1
,
self
.
max_pages_per_head
)
# allocated: p < alloc
alloc_mask
=
p
<
alloc
.
unsqueeze
(
-
1
)
# [L, H, P]
# to free: allocated and p in [used_pages, alloc)
free_mask
=
alloc_mask
&
(
p
>=
used_pages
.
unsqueeze
(
-
1
))
free_mask_flat
=
free_mask
.
view
(
-
1
)
# [L*H*P]
if
not
free_mask_flat
.
any
():
return
0
idx
=
free_mask_flat
.
nonzero
(
as_tuple
=
False
).
squeeze
(
-
1
)
# indices of freed slots
# Freed physical page ids
freed_pages
=
pt
[
idx
]
# Compute layer index for each freed slot:
# layout is [L, H, P] → flat index = ((l * H) + h) * P + p
freed_layers
=
(
idx
//
(
H
*
self
.
max_pages_per_head
)).
to
(
torch
.
int32
)
freed_pages
=
freed_pages
.
tolist
()
layer_mapping
=
freed_layers
.
tolist
()
self
.
bh_num_pages
[:,
batch_index
,
:]
=
used_pages
for
page
,
layer
in
zip
(
freed_pages
,
layer_mapping
):
self
.
pages_indices_per_batch
[
batch_index
][
layer
].
remove
(
page
)
heapq
.
heappush
(
self
.
free_pages
[
layer
],
page
)
approximate_bytes_freed
=
(
len
(
freed_pages
)
*
(
self
.
page_size
*
self
.
head_dim
*
self
.
kv_cache
.
element_size
())
*
2
)
# multiply for two for K + V
return
approximate_bytes_freed
def
_free_batch_layer
(
self
,
layer_index
:
int
,
batch_index
:
int
)
->
None
:
"""
Free all pages belonging to batch_index and reset its metadata.
"""
# Return pages to the global heap
for
phys
in
self
.
pages_indices_per_batch
[
batch_index
][
layer_index
]:
heapq
.
heappush
(
self
.
free_pages
[
layer_index
],
int
(
phys
))
self
.
pages_indices_per_batch
[
batch_index
][
layer_index
]
=
set
()
def
free_batch
(
self
,
batch_index
:
int
)
->
None
:
"""
Free all resources associated with a batch index.
Args:
:param batch_index:
Batch index to release. Must have been previously allocated
via :meth:`new_batch`.
"""
for
layer
in
range
(
self
.
num_layers
):
self
.
_free_batch_layer
(
layer
,
batch_index
)
self
.
bh_seq_lens
[:,
batch_index
].
zero_
()
self
.
bh_num_pages
[:,
batch_index
].
zero_
()
self
.
free_batches
.
append
(
batch_index
)
def
layer_slices
(
self
,
layer
:
int
):
"""
Return layer-local views needed by the attention module.
For a given ``layer`` index, this method returns the slices of the
global K/V cache, page table, and per-(batch, head) sequence lengths
corresponding to that layer.
Args:
:param layer:
Layer index ``l`` in ``[0, num_layers)``.
Returns:
:return Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
``(k, v, pt, bh)`` as described above.
"""
assert
0
<=
layer
<
self
.
num_layers
k
=
self
.
kv_cache
[
0
,
layer
]
v
=
self
.
kv_cache
[
1
,
layer
]
pt
=
self
.
page_table
[
layer
]
bh
=
self
.
bh_seq_lens
[
layer
]
return
k
,
v
,
pt
,
bh
vllm/kvprune_legacy_save/kv_cache/store_kv_cache.py
deleted
100644 → 0
View file @
2b7160c6
import
torch
import
triton
import
triton.language
as
tl
from
vllm.kvprune.config.constants
import
(
TRITON_RESERVED_BATCH
as
_TRITON_RESERVED_BATCH
,
)
@
triton
.
jit
def
_prefill_store_topk_kv_kernel
(
key
,
value
,
# [N_total, H, D] (D stride assumed 1)
batch_mapping
,
# [B] int32 (local b -> true batch)
num_tokens_to_retain
,
# [B] int32
indices_topk
,
# [B, MAX_SEL] int32 (across all heads)
# Lengths & page table:
bh_lens
,
# [B, H] int32 (contiguous)
page_table
,
# [B_total * H * N_LOGICAL_PAGES_MAX] int32 (flattened), read-only
k_cache
,
v_cache
,
# [N_PAGES * PAGE_SIZE, D]
sk_n
,
sk_h
,
# strides for key,value. D stride assumed 1
sv_n
,
sv_h
,
# Runtime ints
MAX_SEL
,
# num tokens that are ranked in indices for each batch (might be bigger than num_tokens_to_retain)
HKV
:
tl
.
constexpr
,
N_LOGICAL_PAGES_MAX
:
tl
.
constexpr
,
D
:
tl
.
constexpr
,
PAGE_SIZE
:
tl
.
constexpr
,
K_TILE
:
tl
.
constexpr
,
# how many selected tokens each program processes
TRITON_RESERVED_BATCH
:
tl
.
constexpr
,
):
b_local
=
tl
.
program_id
(
0
)
tile_id
=
tl
.
program_id
(
1
)
offs
=
tl
.
arange
(
0
,
D
)
# how many tokens we actually keep for this batch
k_total
=
tl
.
load
(
num_tokens_to_retain
+
b_local
)
if
k_total
==
0
:
return
# map to true batch row in the page table
b_true
=
tl
.
load
(
batch_mapping
+
b_local
)
if
b_true
==
TRITON_RESERVED_BATCH
:
return
base
=
tile_id
*
K_TILE
# process up to K_TILE tokens
for
j
in
tl
.
range
(
0
,
K_TILE
):
sel_idx
=
base
+
j
if
sel_idx
<
k_total
and
sel_idx
<
MAX_SEL
:
# flattened selection: sel = token * H + head
sel
=
tl
.
load
(
indices_topk
+
b_local
*
MAX_SEL
+
sel_idx
)
tok
=
sel
//
HKV
head
=
sel
-
(
tok
*
HKV
)
# atomically reserve one position in (b_local, hed)
# i.e the KV cache is scrambled when storing
len_ptr
=
bh_lens
+
b_local
*
HKV
+
head
pos
=
tl
.
atomic_add
(
len_ptr
,
1
)
# old length (int32)
lp
=
pos
//
PAGE_SIZE
off
=
pos
-
lp
*
PAGE_SIZE
# translate logical page to physical page
pt_base
=
(
b_true
*
HKV
+
head
)
*
N_LOGICAL_PAGES_MAX
phys
=
tl
.
load
(
page_table
+
pt_base
+
lp
).
to
(
tl
.
int64
)
# destination row and element offset
dst_row
=
phys
*
PAGE_SIZE
+
off
dst_off
=
dst_row
*
D
+
offs
# load one vector from [N_total, H, D]
k_src
=
key
+
tok
*
sk_n
+
head
*
sk_h
+
offs
v_src
=
value
+
tok
*
sv_n
+
head
*
sv_h
+
offs
tl
.
store
(
k_cache
+
dst_off
,
tl
.
load
(
k_src
,
cache_modifier
=
".cv"
,
eviction_policy
=
"evict_first"
),
eviction_policy
=
"evict_first"
,
)
tl
.
store
(
v_cache
+
dst_off
,
tl
.
load
(
v_src
,
cache_modifier
=
".cv"
,
eviction_policy
=
"evict_first"
),
eviction_policy
=
"evict_first"
,
)
def
prefill_store_topk_kv
(
*
,
new_keys
:
torch
.
Tensor
,
# [N_total, H, D]
new_vals
:
torch
.
Tensor
,
# [N_total, H, D]
indices_topk
:
torch
.
Tensor
,
# [B, MAX_SEL] int32 (global flattened token*H + head)
num_tokens_to_retain
:
torch
.
Tensor
,
# [B] int32
page_table
:
torch
.
Tensor
,
# [B_total, H, N_LOGICAL_PAGES_MAX] int32
batch_mapping
:
torch
.
Tensor
,
# [B] int32 (local -> true batch rows)
bh_lens
:
torch
.
Tensor
,
# [B, H] int32 (contiguous), UPDATED atomically
k_cache
:
torch
.
Tensor
,
# [N_PAGES * PAGE_SIZE, D]
v_cache
:
torch
.
Tensor
,
# [N_PAGES * PAGE_SIZE, D]
PAGE_SIZE
:
int
,
PAD_TO_PAGE_SIZE
:
bool
=
True
,
cu_seqlens_k
:
torch
.
Tensor
|
None
=
None
,
K_TILE
:
int
=
16
,
TRITON_RESERVED_BATCH
:
int
=
None
,
):
assert
new_keys
.
shape
==
new_vals
.
shape
N_total
,
H
,
D
=
new_keys
.
shape
B
=
indices_topk
.
shape
[
0
]
assert
page_table
.
shape
[
1
]
==
H
assert
bh_lens
.
shape
==
(
B
,
H
)
assert
new_keys
.
device
==
k_cache
.
device
==
v_cache
.
device
assert
page_table
.
is_contiguous
(),
"page table must be contiguous."
assert
bh_lens
.
is_contiguous
(),
"bh_lens must be contiguous."
assert
batch_mapping
.
is_contiguous
(),
"batch mapping must be contiguous."
assert
k_cache
.
is_contiguous
()
and
v_cache
.
is_contiguous
()
assert
new_keys
.
stride
(
-
1
)
==
1
and
new_vals
.
stride
(
-
1
)
==
1
,
(
"new_keys/new_vals last dim must be contiguous."
)
assert
(
D
&
(
D
-
1
))
==
0
,
"D must be a power of 2"
page_table
=
page_table
.
to
(
torch
.
int32
)
bh_lens
=
bh_lens
.
to
(
torch
.
int32
)
batch_mapping
=
batch_mapping
.
to
(
torch
.
int32
)
indices_topk
=
indices_topk
.
to
(
torch
.
int32
)
num_tokens_to_retain
=
num_tokens_to_retain
.
to
(
torch
.
int32
)
# strides (elements) for [N_total, H, D]
sk_n
,
sk_h
,
_
=
new_keys
.
stride
()
sv_n
,
sv_h
,
_
=
new_vals
.
stride
()
# tile second grid dim
MAX_SEL
=
indices_topk
.
shape
[
-
1
]
N_TILES
=
(
MAX_SEL
+
K_TILE
-
1
)
//
K_TILE
grid
=
(
B
,
max
(
1
,
N_TILES
))
if
TRITON_RESERVED_BATCH
is
None
:
TRITON_RESERVED_BATCH
=
_TRITON_RESERVED_BATCH
_prefill_store_topk_kv_kernel
[
grid
](
key
=
new_keys
,
value
=
new_vals
,
batch_mapping
=
batch_mapping
,
num_tokens_to_retain
=
num_tokens_to_retain
,
indices_topk
=
indices_topk
,
bh_lens
=
bh_lens
,
page_table
=
page_table
,
k_cache
=
k_cache
,
v_cache
=
v_cache
,
sk_n
=
sk_n
,
sk_h
=
sk_h
,
sv_n
=
sv_n
,
sv_h
=
sv_h
,
MAX_SEL
=
int
(
MAX_SEL
),
HKV
=
H
,
N_LOGICAL_PAGES_MAX
=
page_table
.
shape
[
2
],
D
=
D
,
PAGE_SIZE
=
PAGE_SIZE
,
K_TILE
=
K_TILE
,
TRITON_RESERVED_BATCH
=
TRITON_RESERVED_BATCH
,
)
if
PAD_TO_PAGE_SIZE
:
assert
cu_seqlens_k
is
not
None
assert
indices_topk
.
is_contiguous
()
assert
page_table
.
is_contiguous
()
_prefill_store_topk_pad_kernel
[(
B
,
H
)](
key
=
new_keys
,
value
=
new_vals
,
batch_mapping
=
batch_mapping
,
num_tokens_to_retain
=
num_tokens_to_retain
,
indices
=
indices_topk
,
local_lens
=
bh_lens
,
page_table_flat
=
page_table
,
k_cache
=
k_cache
,
v_cache
=
v_cache
,
cu_seqlens_k
=
cu_seqlens_k
,
sk_n
=
sk_n
,
sk_h
=
sk_h
,
sv_n
=
sv_n
,
sv_h
=
sv_h
,
MAX_SEL
=
int
(
MAX_SEL
),
H
=
H
,
# type: ignore
N_LOGICAL_PAGES_MAX
=
page_table
.
shape
[
2
],
# type: ignore
D
=
D
,
# type: ignore
PAGE_SIZE
=
PAGE_SIZE
,
# type: ignore
TRITON_RESERVED_BATCH
=
TRITON_RESERVED_BATCH
,
)
@
triton
.
jit
def
_prefill_store_topk_pad_kernel
(
key
,
# [N_total, H, D]
value
,
# [N_total, H, D]
batch_mapping
,
# [B] int32 (local b -> true batch)
num_tokens_to_retain
,
# [B] int32
indices
,
# [B, MAX_SEL] int32 (across all heads)
local_lens
,
# [B, H] int32 (contiguous)
page_table_flat
,
# [B_total*H*N_LOGICAL_PAGES_MAX] int32
k_cache
,
v_cache
,
# [N_PAGES*PAGE_SIZE, D]
cu_seqlens_k
,
sk_n
,
sk_h
,
sv_n
,
sv_h
,
MAX_SEL
,
# Constexprs
H
:
tl
.
constexpr
,
# number of KV heads
N_LOGICAL_PAGES_MAX
:
tl
.
constexpr
,
D
:
tl
.
constexpr
,
PAGE_SIZE
:
tl
.
constexpr
,
TRITON_RESERVED_BATCH
:
tl
.
constexpr
,
):
b_local
=
tl
.
program_id
(
0
)
h
=
tl
.
program_id
(
1
)
offs_d
=
tl
.
arange
(
0
,
D
)
L
=
tl
.
load
(
local_lens
+
b_local
*
H
+
h
)
modulo_page_size
=
L
-
(
L
//
PAGE_SIZE
)
*
PAGE_SIZE
if
modulo_page_size
==
0
:
return
need
=
PAGE_SIZE
-
modulo_page_size
b_true
=
tl
.
load
(
batch_mapping
+
b_local
)
if
b_true
==
TRITON_RESERVED_BATCH
:
return
pt_base
=
(
b_true
*
H
+
h
)
*
N_LOGICAL_PAGES_MAX
written_tokens
=
0
idx
=
tl
.
load
(
num_tokens_to_retain
+
b_local
)
this_batch_ctx_len
=
tl
.
load
(
cu_seqlens_k
+
b_local
+
1
)
-
tl
.
load
(
cu_seqlens_k
+
b_local
)
max_additional
=
this_batch_ctx_len
-
L
while
(
written_tokens
<
need
and
idx
<
MAX_SEL
)
and
(
written_tokens
<
max_additional
):
# candidate head
cand_idx
=
tl
.
load
(
indices
+
b_local
*
MAX_SEL
+
idx
)
cand_h
=
cand_idx
%
H
if
cand_h
==
h
:
tok
=
cand_idx
//
H
pos
=
L
+
written_tokens
lp
=
pos
//
PAGE_SIZE
off
=
pos
-
lp
*
PAGE_SIZE
phys
=
tl
.
load
(
page_table_flat
+
pt_base
+
lp
).
to
(
tl
.
int32
)
dst_row
=
phys
*
PAGE_SIZE
+
off
dst_off
=
dst_row
.
to
(
tl
.
int64
)
*
D
+
offs_d
k_src
=
key
+
tok
*
sk_n
+
h
*
sk_h
+
offs_d
v_src
=
value
+
tok
*
sv_n
+
h
*
sv_h
+
offs_d
tl
.
store
(
k_cache
+
dst_off
,
tl
.
load
(
k_src
),
)
tl
.
store
(
v_cache
+
dst_off
,
tl
.
load
(
v_src
),
)
written_tokens
+=
1
idx
+=
1
tl
.
store
(
local_lens
+
b_local
*
H
+
h
,
L
+
written_tokens
)
@
triton
.
jit
def
_prefill_store_all_kv_kernel
(
key
,
value
,
# [N, H, D] (D contiguous)
cu_seqlens_k
,
# [B + 1] int32
batch_mapping
,
# [B] int32 (local b -> true batch index)
bh_lens
,
# [B * HKV] int32 (UPDATED)
pt_flat
,
# [B_total * HKV * N_LOGICAL_PAGES_MAX] int32 (flattened)
k_cache
,
v_cache
,
# [N_PAGES * PAGE_SIZE, D]
# source strides (elements)
sk_n
,
sk_h
,
sv_n
,
sv_h
,
# constexpr
HKV
:
tl
.
constexpr
,
N_LOGICAL_PAGES_MAX
:
tl
.
constexpr
,
D
:
tl
.
constexpr
,
PAGE_SIZE
:
tl
.
constexpr
,
K_TILE
:
tl
.
constexpr
,
# number of (token, head) pairs processed per program
):
pid_b
=
tl
.
program_id
(
0
)
pid_blk
=
tl
.
program_id
(
1
)
start
=
tl
.
load
(
cu_seqlens_k
+
pid_b
)
end
=
tl
.
load
(
cu_seqlens_k
+
pid_b
+
1
)
num_toks_this_batch
=
end
-
start
if
num_toks_this_batch
<=
0
:
return
total_elems
=
num_toks_this_batch
*
HKV
# base linear index in (token, head) grid for this program
base
=
pid_blk
*
K_TILE
offs_d
=
tl
.
arange
(
0
,
D
)
# Iterate K_TILE elements in this tile
for
i
in
tl
.
range
(
0
,
K_TILE
):
idx
=
base
+
i
if
idx
<
total_elems
:
# map linear idx -> (t, h)
t
=
idx
//
HKV
h
=
idx
-
t
*
HKV
len_idx
=
pid_b
*
HKV
+
h
L0
=
tl
.
load
(
bh_lens
+
len_idx
)
token_idx_in_cache
=
L0
+
t
lp
=
token_idx_in_cache
//
PAGE_SIZE
# logical page
off_in_pg
=
token_idx_in_cache
-
lp
*
PAGE_SIZE
# pos in page
# physical page
b_true
=
tl
.
load
(
batch_mapping
+
pid_b
).
to
(
tl
.
int32
)
pt_base
=
(
b_true
*
HKV
+
h
)
*
N_LOGICAL_PAGES_MAX
phys
=
tl
.
load
(
pt_flat
+
pt_base
+
lp
).
to
(
tl
.
int64
)
row
=
phys
*
PAGE_SIZE
+
off_in_pg
dst_off
=
row
*
D
+
offs_d
n_global
=
(
start
+
t
).
to
(
tl
.
int64
)
# Use strides for non-contiguous [N, H, D] (D stride == 1)
k_src
=
key
+
n_global
*
sk_n
+
h
*
sk_h
+
offs_d
v_src
=
value
+
n_global
*
sv_n
+
h
*
sv_h
+
offs_d
tl
.
store
(
k_cache
+
dst_off
,
tl
.
load
(
k_src
))
tl
.
store
(
v_cache
+
dst_off
,
tl
.
load
(
v_src
))
def
prefill_store_all_kv
(
*
,
new_keys
:
torch
.
Tensor
,
new_values
:
torch
.
Tensor
,
# [N, H_kv, D]
cu_seqlens_k
:
torch
.
Tensor
,
# [B + 1] int32
max_seqlen_k
:
int
,
k_cache
:
torch
.
Tensor
,
v_cache
:
torch
.
Tensor
,
page_table
:
torch
.
Tensor
,
# [B_total, H_kv, N_LOGICAL_PAGES_MAX] int32
bh_lens
:
torch
.
Tensor
,
# [B, H_kv] int32 (UPDATED)
batch_mapping
:
torch
.
Tensor
,
# [B] int32 (local->true)
PAGE_SIZE
:
int
,
K_TILE
:
int
=
32
,
# how many (token, head) pairs per program
):
assert
new_keys
.
stride
(
-
1
)
==
1
and
new_values
.
stride
(
-
1
)
==
1
,
(
"last dim must be contiguous"
)
assert
page_table
.
is_contiguous
(),
"page table must be contiguous"
assert
bh_lens
.
is_contiguous
(),
"bh_lens must be contiguous"
assert
batch_mapping
.
is_contiguous
(),
"batch mapping must be contiguous"
assert
k_cache
.
is_contiguous
()
and
v_cache
.
is_contiguous
()
N
,
HKV
,
D
=
new_keys
.
shape
B
=
batch_mapping
.
shape
[
0
]
assert
(
D
&
(
D
-
1
))
==
0
,
"D must be a power of 2"
sk_n
,
sk_h
,
_
=
new_keys
.
stride
()
sv_n
,
sv_h
,
_
=
new_values
.
stride
()
n_tiles
=
(
max_seqlen_k
*
HKV
+
K_TILE
-
1
)
//
K_TILE
grid
=
(
B
,
n_tiles
)
_prefill_store_all_kv_kernel
[
grid
](
new_keys
,
new_values
,
cu_seqlens_k
,
batch_mapping
,
bh_lens
,
page_table
,
k_cache
,
v_cache
,
sk_n
=
sk_n
,
sk_h
=
sk_h
,
sv_n
=
sv_n
,
sv_h
=
sv_h
,
HKV
=
HKV
,
N_LOGICAL_PAGES_MAX
=
page_table
.
shape
[
-
1
],
D
=
D
,
PAGE_SIZE
=
PAGE_SIZE
,
K_TILE
=
K_TILE
,
)
bh_lens
+=
cu_seqlens_k
.
diff
()[:,
None
]
@
triton
.
jit
def
_decode_store_kv_kernel
(
key
,
value
,
batch_mapping
,
# [B] int32
bh_lens
,
# [B*HKV] int32
page_table
,
# [B_total*HKV*N_LOGICAL_PAGES_MAX]
k_cache
,
v_cache
,
# [N_PAGES*PAGE_SIZE, D]
sk_b
,
sk_h
,
sv_b
,
sv_h
,
HKV
:
tl
.
constexpr
,
N_LOGICAL_PAGES_MAX
:
tl
.
constexpr
,
D
:
tl
.
constexpr
,
PAGE_SIZE
:
tl
.
constexpr
,
TRITON_RESERVED_BATCH
:
tl
.
constexpr
,
):
pid_b
=
tl
.
program_id
(
0
)
h
=
tl
.
program_id
(
1
)
mapped_b
=
tl
.
load
(
batch_mapping
+
pid_b
)
if
mapped_b
==
TRITON_RESERVED_BATCH
:
return
offs_d
=
tl
.
arange
(
0
,
D
)
length
=
tl
.
load
(
bh_lens
+
pid_b
*
HKV
+
h
)
logical_page
=
length
//
PAGE_SIZE
internal_offset
=
length
-
logical_page
*
PAGE_SIZE
pt_base
=
(
mapped_b
*
HKV
+
h
)
*
N_LOGICAL_PAGES_MAX
physical_page
=
tl
.
load
(
page_table
+
pt_base
+
logical_page
).
to
(
tl
.
int64
)
dst_row
=
physical_page
*
PAGE_SIZE
+
internal_offset
# Source addressing using strides (D stride == 1)
k_src
=
key
+
pid_b
*
sk_b
+
h
*
sk_h
+
offs_d
v_src
=
value
+
pid_b
*
sv_b
+
h
*
sv_h
+
offs_d
dst_off
=
dst_row
*
D
+
offs_d
tl
.
store
(
k_cache
+
dst_off
,
tl
.
load
(
k_src
))
tl
.
store
(
v_cache
+
dst_off
,
tl
.
load
(
v_src
))
tl
.
store
(
bh_lens
+
pid_b
*
HKV
+
h
,
length
+
1
)
def
decode_store_kv
(
*
,
key
:
torch
.
Tensor
,
# [B, HKV, D]
value
:
torch
.
Tensor
,
# [B, HKV, D]
batch_mapping
:
torch
.
Tensor
,
# [B] int32
bh_lens
:
torch
.
Tensor
,
# [B, HKV] or flattened [B*HKV] int32
page_table
:
torch
.
Tensor
,
# [B_total, HKV, N_LOGICAL_PAGES_MAX] int32
k_cache
:
torch
.
Tensor
,
v_cache
:
torch
.
Tensor
,
# [N_PAGES*PAGE_SIZE, D]
PAGE_SIZE
:
int
,
TRITON_RESERVED_BATCH
:
int
=
None
,
):
assert
key
.
shape
==
value
.
shape
and
key
.
ndim
==
3
,
"key/value must be [B, HKV, D]"
B
,
HKV
,
D
=
key
.
shape
assert
key
.
stride
(
-
1
)
==
1
and
value
.
stride
(
-
1
)
==
1
,
(
"key/value last dim must be contiguous."
)
assert
page_table
.
is_contiguous
(),
"page table must be contiguous."
assert
bh_lens
.
is_contiguous
(),
"bh_lens must be contiguous."
assert
batch_mapping
.
is_contiguous
(),
"batch mapping must be contiguous."
assert
k_cache
.
is_contiguous
()
and
v_cache
.
is_contiguous
()
assert
(
D
&
(
D
-
1
))
==
0
,
"D must be a power of 2"
sk_b
,
sk_h
,
_
=
key
.
stride
()
sv_b
,
sv_h
,
_
=
value
.
stride
()
grid
=
(
int
(
batch_mapping
.
shape
[
0
]),
HKV
,
)
_decode_store_kv_kernel
[
grid
](
key
=
key
,
value
=
value
,
batch_mapping
=
batch_mapping
,
bh_lens
=
bh_lens
,
page_table
=
page_table
,
k_cache
=
k_cache
,
v_cache
=
v_cache
,
sk_b
=
sk_b
,
sk_h
=
sk_h
,
sv_b
=
sv_b
,
sv_h
=
sv_h
,
HKV
=
HKV
,
N_LOGICAL_PAGES_MAX
=
page_table
.
shape
[
2
],
D
=
D
,
PAGE_SIZE
=
PAGE_SIZE
,
TRITON_RESERVED_BATCH
=
TRITON_RESERVED_BATCH
if
TRITON_RESERVED_BATCH
is
not
None
else
_TRITON_RESERVED_BATCH
,
)
vllm/kvprune_legacy_save/kv_cache/write_page_table.py
deleted
100644 → 0
View file @
2b7160c6
import
torch
import
triton
import
triton.language
as
tl
def
scatter_to_page_table
(
add_pages
:
torch
.
Tensor
,
# [L, H] int32
new_phys_pages
:
torch
.
Tensor
,
# [N]
curr_pages
:
torch
.
Tensor
,
# [L, H] int32
page_table
:
torch
.
Tensor
,
# [L, H, max_pages_per_head] int32, NOT assumed contiguous globally
max_pages_per_head
:
int
,
):
"""
Append newly allocated physical pages into a layered page table via Triton.
For each (layer ``l``, head ``h``):
Args:
:param add_pages:
Tensor of shape ``[L, H]`` (int32) indicating how many pages to
append for each (layer, head).
:param new_phys_pages:
1D tensor of shape ``[N]`` (int32) containing physical page IDs
for all (layer, head) pairs, concatenated in row-major (L, H)
order. ``N`` must equal ``add_pages.sum()``.
:param curr_pages:
Tensor of shape ``[L, H]`` (int32) with the current logical page
counts per (layer, head) before this update.
:param page_table:
Tensor of shape ``[L, H, max_pages_per_head]`` (int32) holding
the logical to physical page mapping. The last dimension is
logically indexed as logical_page ∈ [0, max_pages_per_head).
:param max_pages_per_head:
Maximum number of logical pages permitted per (layer, head). The
kernel skips writes beyond this bound.
Returns:
None. The function updates ``page_table`` in-place.
"""
L
,
H
=
add_pages
.
shape
if
L
==
0
or
H
==
0
:
return
add_flat
=
add_pages
.
to
(
torch
.
int32
).
contiguous
().
view
(
-
1
)
curr_flat
=
curr_pages
.
to
(
torch
.
int32
).
contiguous
().
view
(
-
1
)
cum_page_heads
=
torch
.
empty
(
L
*
H
+
1
,
device
=
"cuda"
,
dtype
=
torch
.
int32
)
cum_page_heads
[
0
]
=
0
torch
.
cumsum
(
add_flat
,
0
,
out
=
cum_page_heads
[
1
:])
stride_pl
,
stride_ph
,
stride_pp
=
page_table
.
stride
()
grid
=
(
L
,
H
)
_scatter_pages_kernel_lh
[
grid
](
add_flat
,
cum_page_heads
,
new_phys_pages
,
curr_flat
,
page_table
,
stride_pl
,
stride_ph
,
stride_pp
,
L
=
L
,
H
=
H
,
max_pages_per_head
=
max_pages_per_head
,
)
@
triton
.
jit
def
_scatter_pages_kernel_lh
(
add_pages
,
# int32 [L*H]
cum_page_heads
,
# int32 [L*H], base offset in flat_new_phys per (l,h)
flat_new_phys
,
# int32 [total_pages]
curr_pages
,
# int32 [L*H], existing logical pages per (l,h)
page_table_ptr
,
# int32* base pointer to page_table
stride_pl
,
# int, stride for layer dim
stride_ph
,
# int, stride for head dim
stride_pp
,
# int, stride for page dim
L
:
tl
.
constexpr
,
H
:
tl
.
constexpr
,
max_pages_per_head
:
tl
.
constexpr
,
):
layer_idx
=
tl
.
program_id
(
0
)
h
=
tl
.
program_id
(
1
)
if
layer_idx
>=
L
or
h
>=
H
:
return
lh
=
layer_idx
*
H
+
h
ap
=
tl
.
load
(
add_pages
+
lh
)
if
ap
<=
0
:
return
base
=
tl
.
load
(
cum_page_heads
+
lh
)
cp
=
tl
.
load
(
curr_pages
+
lh
)
# Append ap pages: logical pages [cp .. cp+ap)
for
i
in
tl
.
range
(
0
,
ap
):
phys
=
tl
.
load
(
flat_new_phys
+
base
+
i
)
lp
=
cp
+
i
if
lp
<
max_pages_per_head
:
offset
=
layer_idx
*
stride_pl
+
h
*
stride_ph
+
lp
*
stride_pp
tl
.
store
(
page_table_ptr
+
offset
,
phys
)
# TODO: write reclaim kernel
@
triton
.
jit
def
reclaim_page_kernel
():
pass
def
reclaim_pages
(
batch_index
:
int
,
bh_seq_lens
:
torch
.
Tensor
,
bh_num_pages
:
torch
.
Tensor
,
page_table
:
torch
.
Tensor
,
):
pass
vllm/kvprune_legacy_save/kvprune_to_vllm.md
deleted
100644 → 0
View file @
2b7160c6
# KV-prune 与上游 vLLM 的集成说明
本文说明:
**剪枝/压缩(Compactor)功能**
在「官网 vLLM 主仓库」里改动了哪些位置、是否只有少量文件、以及随 vLLM 版本升级时如何预期合并成本。
## 1. 是否「仅仅」改了少数几个脚本?
**核心运行时接线**
确实集中在少数几个
**非**
`vllm/kvprune/`
下的文件;功能主体在
`vllm/kvprune/`
包内独立维护。
| 路径 | 作用简述 |
|------|-----------|
|
`vllm/env_override.py`
| 在
`import vllm`
最早阶段设置与 kvprune 相关的默认环境变量(如 v1 多进程默认、压缩默认开关、可选释放 v1 KV 等)。 |
|
`vllm/__init__.py`
| 对外导出
`CompressionParams`
(懒加载至
`vllm.kvprune.integration.compression_params`
)。 |
|
`vllm/entrypoints/llm.py`
|
`kvprune_compression`
参数、
`generate(..., compression=...)`
、v1
`enforce_eager`
/
`num_gpu_blocks_override`
策略、懒加载 compactor、委托
`compressed_generate`
。 |
|
`vllm/v1/worker/gpu_worker.py`
|
`kvprune_v1_compressed_generate`
:供
`collective_rpc`
调用的 TP 多卡压缩生成入口。 |
|
`tests/conftest.py`
| 测试在导入 vLLM 前覆盖部分
`VLLM_KVPRUNE_*`
默认值,避免全量测试默认走压缩路径。 |
|
`vllm\vllm\envs.py`
| envs.py 中对 VLLM_KVPRUNE_
*
的集中注册 |
**此外(可选/示例,非引擎必需):**
-
`examples/offline_inference/`
下若干
`*kvprune*`
示例脚本:演示用法,不参与核心引擎加载。
**结论:**
-
**「官网 vLLM 主包」里与 kvprune 强相关的改动,主要就是上表 4 个文件 + 测试根配置**
(若把测试也算进「集成面」,共 5 处常见提法)。
-
**算法、Compactor、TP 内嵌 runner 等**
均在
`vllm/kvprune/`
(及该目录下的
`integration/`
)中,与上游 diff 相对隔离。
## 2. 随 vLLM 版本更新,是否「很容易」同步剪枝压缩功能?
**相对容易的部分:**
-
**集成面小**
:合并冲突主要出现在上述少数文件,而不是遍布整个 executor / attention / model 层。
-
**逻辑内聚**
:大量代码在
`vllm/kvprune/`
,可整体移植或
`git`
三方合并时以子树为主处理。
**仍需人工跟进的点(不能假设「自动无痛」):**
-
**`entrypoints/llm.py` 属于高频变更文件**
:上游每次大版本可能重构
`LLM`
构造参数、
`generate`
签名或引擎初始化;需要
**逐次解决冲突**
并回归压缩路径。
-
**`v1/worker/gpu_worker.py`**
同样会随 executor / RPC 接口变动;
`collective_rpc`
方法名或 worker 基类若有变化,需对齐。
-
**`env_override.py`**
若上游调整导入顺序或新增全局默认环境变量,需避免覆盖冲突或行为打架。
-
**vLLM v1 内部 API**
(如
`worker.get_model()`
、
`vllm_config`
结构)若变更,
`vllm/kvprune/integration/*`
也可能要跟着改——这类改动
**不在**
「仅 5 个文件」里,但仍是
**集成层**
维护成本。
**建议同步流程(简版):**
1.
在新上游 tag 上先合并/应用
`vllm/kvprune/`
目录。
2.
再手动合并上述 4 个主包文件 +
`tests/conftest.py`
。
3.
跑与 kvprune 相关的测试与至少一条离线
`compression`
示例。
4.
关注发行说明中
`LLM`
、
`EngineArgs`
、
`gpu_worker`
、多进程默认的破坏性变更。
## 3. 与「深度改内核」方案的区别
当前设计
**没有**
在
`model_executor`
的统一注意力路径上大规模插入 kvprune 钩子(相关辅助逻辑主要在
`vllm/kvprune`
内部)。因此:
-
**上游同步时**
,通常不必与 FlashAttention / 每层模型代码逐文件对打;
-
**代价是**
:功能边界以「共享权重 + compactor 引擎 + 可选 TP RPC」为主,与「原生 KV 算子级一体化」的改动面不同。
---
*文档随仓库维护;若集成文件列表有增删,请同步更新本节表格。*
vllm/kvprune_legacy_save/layers/__init__.py
deleted
100644 → 0
View file @
2b7160c6
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""
Layers from upstream compactor (attention, linear, MoE, …).
Prefer importing concrete modules, e.g. ``from vllm.kvprune.layers.attention import ...``.
"""
__all__
:
list
[
str
]
=
[]
vllm/kvprune_legacy_save/layers/activation.py
deleted
100644 → 0
View file @
2b7160c6
import
torch
import
torch.nn.functional
as
F
from
torch
import
nn
class
SiluAndMul
(
nn
.
Module
):
def
__init__
(
self
):
super
().
__init__
()
# @torch.compile
def
forward
(
self
,
x
:
torch
.
Tensor
)
->
torch
.
Tensor
:
x
,
y
=
x
.
chunk
(
2
,
-
1
)
return
F
.
silu
(
x
)
*
y
vllm/kvprune_legacy_save/layers/attention.py
deleted
100644 → 0
View file @
2b7160c6
from
typing
import
Optional
import
torch
from
flash_attn.flash_attn_interface
import
flash_attn_varlen_func
from
torch
import
nn
from
vllm.kvprune.attention.fa_paged_bridge
import
(
flash_decode_from_paged
,
flash_prefill_from_paged
,
)
from
vllm.kvprune.attention.sparse_decode_kernel
import
head_sparse_decode_attention
from
vllm.kvprune.attention.sparse_varlen_kernel
import
(
causal_sparse_varlen_with_cache
,
)
from
vllm.kvprune.compression.common
import
extract_and_store_top_kv
from
vllm.kvprune.config.engine_config
import
KvpruneAttentionSchedule
from
vllm.kvprune.kv_cache.store_kv_cache
import
decode_store_kv
,
prefill_store_all_kv
from
vllm.kvprune.utils.context
import
Context
,
get_context
from
vllm.kvprune.utils.helpers
import
maybe_execute_in_stream
class
Attention
(
nn
.
Module
):
def
__init__
(
self
,
num_heads
,
head_dim
,
scale
,
num_kv_heads
,
):
super
().
__init__
()
self
.
num_heads
:
int
=
num_heads
self
.
head_dim
=
head_dim
self
.
scale
:
float
=
scale
self
.
num_kv_heads
=
int
(
num_kv_heads
)
self
.
k_cache
:
Optional
[
torch
.
Tensor
]
=
None
self
.
v_cache
:
Optional
[
torch
.
Tensor
]
=
None
self
.
page_table
:
Optional
[
torch
.
Tensor
]
=
None
self
.
bh_seq_lens
:
Optional
[
torch
.
Tensor
]
=
None
self
.
page_size
:
Optional
[
int
]
=
None
def
forward
(
self
,
q
:
torch
.
Tensor
,
k
:
torch
.
Tensor
,
v
:
torch
.
Tensor
,
scores
:
Optional
[
torch
.
Tensor
]
=
None
,
):
context
:
Context
=
get_context
()
batch_mapping
=
context
.
batch_mapping
seq_lens
=
(
None
if
self
.
bh_seq_lens
is
None
else
self
.
bh_seq_lens
.
index_select
(
0
,
batch_mapping
).
contiguous
()
)
sched
=
context
.
attention_schedule
use_triton_prefill_attn
=
(
sched
==
KvpruneAttentionSchedule
.
TRITON_PREFILL_TRITON_DECODE
)
use_fa_decode
=
sched
==
KvpruneAttentionSchedule
.
PDFA
if
context
.
is_prefill
:
seq_lens_copy
=
seq_lens
.
clone
()
if
seq_lens
is
not
None
else
None
if
(
self
.
k_cache
is
not
None
and
context
.
do_compression
and
scores
is
not
None
):
compression_context
=
context
.
compression_context
assert
scores
is
not
None
assert
compression_context
is
not
None
maybe_execute_in_stream
(
extract_and_store_top_kv
,
scores
=
scores
,
cu_seqlens_k
=
context
.
cu_seqlens_k
,
max_k_len
=
context
.
max_seqlen_k
,
top_k
=
compression_context
.
max_tokens_to_retain
,
H
=
int
(
self
.
num_kv_heads
),
new_keys
=
k
,
new_vals
=
v
,
num_tokens_to_retain
=
compression_context
.
batch_tokens_to_retain
,
page_table
=
self
.
page_table
,
batch_mapping
=
batch_mapping
,
bh_lens
=
seq_lens
,
k_cache
=
self
.
k_cache
,
v_cache
=
self
.
v_cache
,
PAGE_SIZE
=
self
.
page_size
,
PAD_TO_PAGE_SIZE
=
True
,
STORE_STREAM
=
context
.
STORE_STREAM
,
)
elif
self
.
k_cache
is
not
None
:
maybe_execute_in_stream
(
prefill_store_all_kv
,
new_keys
=
k
,
new_values
=
v
,
cu_seqlens_k
=
context
.
cu_seqlens_k
,
max_seqlen_k
=
context
.
max_seqlen_k
,
k_cache
=
self
.
k_cache
,
v_cache
=
self
.
v_cache
,
page_table
=
self
.
page_table
,
bh_lens
=
seq_lens
,
batch_mapping
=
batch_mapping
,
PAGE_SIZE
=
self
.
page_size
,
STORE_STREAM
=
context
.
STORE_STREAM
,
)
if
use_triton_prefill_attn
:
if
context
.
do_compression
and
context
.
STORE_STREAM
is
not
None
:
torch
.
cuda
.
current_stream
().
wait_stream
(
context
.
STORE_STREAM
)
assert
seq_lens_copy
is
not
None
o
=
causal_sparse_varlen_with_cache
(
q
,
k
,
v
,
self
.
k_cache
,
self
.
v_cache
,
seq_lens_bh
=
seq_lens_copy
,
global_page_table
=
self
.
page_table
,
batch_mapping
=
batch_mapping
,
cu_seqlens_q
=
context
.
cu_seqlens_q
,
max_seqlen_q
=
context
.
max_seqlen_q
,
max_seqlen_k_cache
=
context
.
max_bh_len
,
HKV
=
int
(
self
.
num_kv_heads
),
PAGE_SIZE
=
self
.
page_size
,
sm_scale
=
self
.
scale
,
)
elif
context
.
do_compression
:
if
context
.
STORE_STREAM
is
not
None
:
torch
.
cuda
.
current_stream
().
wait_stream
(
context
.
STORE_STREAM
)
assert
seq_lens_copy
is
not
None
o
=
flash_prefill_from_paged
(
q
,
k
,
v
,
self
.
k_cache
,
self
.
v_cache
,
seq_lens_bh_before
=
seq_lens_copy
,
global_page_table
=
self
.
page_table
,
batch_mapping
=
batch_mapping
,
cu_seqlens_q
=
context
.
cu_seqlens_q
,
max_seqlen_q
=
context
.
max_seqlen_q
,
PAGE_SIZE
=
self
.
page_size
,
HKV
=
int
(
self
.
num_kv_heads
),
sm_scale
=
self
.
scale
,
)
else
:
o
=
flash_attn_varlen_func
(
q
,
k
,
v
,
max_seqlen_q
=
context
.
max_seqlen_q
,
cu_seqlens_q
=
context
.
cu_seqlens_q
,
max_seqlen_k
=
context
.
max_seqlen_k
,
cu_seqlens_k
=
context
.
cu_seqlens_k
,
softmax_scale
=
self
.
scale
,
causal
=
True
,
)
else
:
assert
self
.
k_cache
is
not
None
,
"KV Cache must be initialized for decoding"
decode_store_kv
(
key
=
k
,
value
=
v
,
batch_mapping
=
batch_mapping
,
bh_lens
=
seq_lens
,
page_table
=
self
.
page_table
,
k_cache
=
self
.
k_cache
,
v_cache
=
self
.
v_cache
,
PAGE_SIZE
=
self
.
page_size
,
)
if
use_fa_decode
:
assert
seq_lens
is
not
None
o
=
flash_decode_from_paged
(
q
,
self
.
k_cache
,
self
.
v_cache
,
seq_lens_bh
=
seq_lens
,
global_page_table
=
self
.
page_table
,
batch_mapping
=
batch_mapping
,
PAGE_SIZE
=
self
.
page_size
,
HKV
=
int
(
self
.
num_kv_heads
),
sm_scale
=
self
.
scale
,
)
else
:
o
=
head_sparse_decode_attention
(
q
,
self
.
k_cache
,
self
.
v_cache
,
seq_lens
,
self
.
page_table
,
batch_mapping
,
int
(
self
.
num_kv_heads
),
self
.
page_size
,
self
.
scale
,
key_split
=
context
.
key_split
,
)
# Match compactor_vllm ``Attention``: ``index_copy_`` into the global
# ``bh_seq_lens`` table. The Triton masked copy was a CUDA fast path but
# disagreed with decode_store_kv / paged attention bookkeeping in edge
# cases and could leave lengths stale → garbage logits / immediate EOS.
if
self
.
bh_seq_lens
is
not
None
:
longbm
=
batch_mapping
.
to
(
device
=
self
.
bh_seq_lens
.
device
,
dtype
=
torch
.
long
)
maybe_execute_in_stream
(
self
.
bh_seq_lens
.
index_copy_
,
0
,
longbm
,
seq_lens
,
STORE_STREAM
=
context
.
STORE_STREAM
if
context
.
is_prefill
else
None
,
)
return
o
vllm/kvprune_legacy_save/layers/embed_head.py
deleted
100644 → 0
View file @
2b7160c6
import
torch
import
torch.distributed
as
dist
import
torch.nn.functional
as
F
from
vllm.kvprune.utils.context
import
get_context
from
vllm.kvprune.utils.tp_collectives
import
tensor_parallel_all_reduce
from
vllm.kvprune.utils.tp_utils
import
(
tensor_parallel_rank_for_sharding
,
tensor_parallel_world_size_for_sharding
,
)
from
torch
import
nn
class
VocabParallelEmbedding
(
nn
.
Module
):
def
__init__
(
self
,
num_embeddings
:
int
,
embedding_dim
:
int
,
):
super
().
__init__
()
self
.
tp_rank
=
tensor_parallel_rank_for_sharding
()
self
.
tp_size
=
tensor_parallel_world_size_for_sharding
()
assert
num_embeddings
%
self
.
tp_size
==
0
self
.
num_embeddings
=
num_embeddings
self
.
num_embeddings_per_partition
=
self
.
num_embeddings
//
self
.
tp_size
self
.
vocab_start_idx
=
self
.
num_embeddings_per_partition
*
self
.
tp_rank
self
.
vocab_end_idx
=
self
.
vocab_start_idx
+
self
.
num_embeddings_per_partition
self
.
weight
=
nn
.
Parameter
(
torch
.
empty
(
self
.
num_embeddings_per_partition
,
embedding_dim
)
)
self
.
weight
.
weight_loader
=
self
.
weight_loader
def
weight_loader
(
self
,
param
:
nn
.
Parameter
,
loaded_weight
:
torch
.
Tensor
):
param_data
=
param
.
data
shard_size
=
param_data
.
size
(
0
)
start_idx
=
self
.
tp_rank
*
shard_size
loaded_weight
=
loaded_weight
.
narrow
(
0
,
start_idx
,
shard_size
)
param_data
.
copy_
(
loaded_weight
)
def
forward
(
self
,
x
:
torch
.
Tensor
):
if
self
.
tp_size
>
1
:
mask
=
(
x
>=
self
.
vocab_start_idx
)
&
(
x
<
self
.
vocab_end_idx
)
x
=
mask
*
(
x
-
self
.
vocab_start_idx
)
y
=
F
.
embedding
(
x
,
self
.
weight
)
if
self
.
tp_size
>
1
:
y
=
mask
.
unsqueeze
(
1
)
*
y
tensor_parallel_all_reduce
(
y
)
return
y
class
ParallelLMHead
(
VocabParallelEmbedding
):
"""LM head with TP vocab sharding.
When embedded in a vLLM worker, logits must be gathered on the **tensor-
parallel** process group (see :func:`~vllm.distributed.communication_op.tensor_model_parallel_gather`),
not the default :func:`torch.distributed.gather` — otherwise shard order / group
mismatch yields garbage logits and decoded gibberish.
After gather, logits are truncated to ``org_vocab_size`` (HF tokenizer vocab),
matching :class:`~vllm.model_executor.layers.logits_processor.LogitsProcessor`
removal of padded vocabulary columns.
"""
def
__init__
(
self
,
num_embeddings
:
int
,
embedding_dim
:
int
,
bias
:
bool
=
False
,
*
,
org_vocab_size
:
int
|
None
=
None
,
):
assert
not
bias
super
().
__init__
(
num_embeddings
,
embedding_dim
)
# Original (unpadded) vocab size for logits truncation; defaults to num_embeddings.
self
.
org_vocab_size
=
(
int
(
org_vocab_size
)
if
org_vocab_size
is
not
None
else
num_embeddings
)
def
forward
(
self
,
x
:
torch
.
Tensor
):
context
=
get_context
()
if
context
.
is_prefill
:
cu
=
context
.
cu_seqlens_q
last_indices
=
(
cu
[
1
:]
-
1
).
to
(
torch
.
long
)
n_tok
=
x
.
shape
[
0
]
if
n_tok
>
0
:
last_indices
=
last_indices
.
clamp
(
min
=
0
,
max
=
n_tok
-
1
)
x
=
x
[
last_indices
].
contiguous
()
logits
=
F
.
linear
(
x
,
self
.
weight
)
if
self
.
tp_size
>
1
:
logits
=
self
.
_gather_logits_tp
(
logits
)
if
logits
is
not
None
and
logits
.
shape
[
-
1
]
>
self
.
org_vocab_size
:
logits
=
logits
[...,
:
self
.
org_vocab_size
]
return
logits
def
_gather_logits_tp
(
self
,
logits
:
torch
.
Tensor
)
->
torch
.
Tensor
|
None
:
try
:
from
vllm.distributed.parallel_state
import
model_parallel_is_initialized
from
vllm.distributed.communication_op
import
(
tensor_model_parallel_gather
,
)
if
model_parallel_is_initialized
():
return
tensor_model_parallel_gather
(
logits
,
dst
=
0
,
dim
=-
1
)
except
Exception
:
pass
all_logits
=
(
[
torch
.
empty_like
(
logits
)
for
_
in
range
(
self
.
tp_size
)]
if
self
.
tp_rank
==
0
else
None
)
dist
.
gather
(
logits
,
all_logits
,
0
)
return
torch
.
cat
(
all_logits
,
-
1
)
if
self
.
tp_rank
==
0
else
None
vllm/kvprune_legacy_save/layers/layernorm.py
deleted
100644 → 0
View file @
2b7160c6
import
torch
from
torch
import
nn
class
RMSNorm
(
nn
.
Module
):
def
__init__
(
self
,
hidden_size
:
int
,
eps
:
float
=
1e-6
,
)
->
None
:
super
().
__init__
()
self
.
eps
=
eps
self
.
weight
=
nn
.
Parameter
(
torch
.
ones
(
hidden_size
))
# @torch.compile
def
rms_forward
(
self
,
x
:
torch
.
Tensor
,
)
->
torch
.
Tensor
:
orig_dtype
=
x
.
dtype
x
=
x
.
float
()
var
=
x
.
pow
(
2
).
mean
(
dim
=-
1
,
keepdim
=
True
)
x
.
mul_
(
torch
.
rsqrt
(
var
+
self
.
eps
))
x
=
x
.
to
(
orig_dtype
).
mul_
(
self
.
weight
)
return
x
# @torch.compile
def
add_rms_forward
(
self
,
x
:
torch
.
Tensor
,
residual
:
torch
.
Tensor
,
)
->
tuple
[
torch
.
Tensor
,
torch
.
Tensor
]:
orig_dtype
=
x
.
dtype
x
=
x
.
float
().
add_
(
residual
.
float
())
residual
=
x
.
to
(
orig_dtype
)
var
=
x
.
pow
(
2
).
mean
(
dim
=-
1
,
keepdim
=
True
)
x
.
mul_
(
torch
.
rsqrt
(
var
+
self
.
eps
))
x
=
x
.
to
(
orig_dtype
).
mul_
(
self
.
weight
)
return
x
,
residual
def
forward
(
self
,
x
:
torch
.
Tensor
,
residual
:
torch
.
Tensor
|
None
=
None
,
)
->
torch
.
Tensor
|
tuple
[
torch
.
Tensor
,
torch
.
Tensor
]:
if
residual
is
None
:
return
self
.
rms_forward
(
x
)
else
:
return
self
.
add_rms_forward
(
x
,
residual
)
vllm/kvprune_legacy_save/layers/linear.py
deleted
100644 → 0
View file @
2b7160c6
import
torch
import
torch.distributed
as
dist
import
torch.nn.functional
as
F
from
vllm.kvprune.utils.tp_collectives
import
tensor_parallel_all_reduce
from
vllm.kvprune.utils.tp_utils
import
(
tensor_parallel_rank_for_sharding
,
tensor_parallel_world_size_for_sharding
,
)
from
torch
import
nn
def
divide
(
numerator
,
denominator
):
assert
numerator
%
denominator
==
0
return
numerator
//
denominator
class
LinearBase
(
nn
.
Module
):
def
__init__
(
self
,
input_size
:
int
,
output_size
:
int
,
bias
:
bool
=
False
,
tp_dim
:
int
|
None
=
None
,
):
super
().
__init__
()
self
.
tp_dim
=
tp_dim
self
.
tp_rank
=
tensor_parallel_rank_for_sharding
()
self
.
tp_size
=
tensor_parallel_world_size_for_sharding
()
self
.
weight
=
nn
.
Parameter
(
torch
.
empty
(
output_size
,
input_size
))
self
.
weight
.
weight_loader
=
self
.
weight_loader
if
bias
:
self
.
bias
=
nn
.
Parameter
(
torch
.
empty
(
output_size
))
self
.
bias
.
weight_loader
=
self
.
weight_loader
else
:
self
.
register_parameter
(
"bias"
,
None
)
def
forward
(
self
,
x
:
torch
.
Tensor
)
->
torch
.
Tensor
:
raise
NotImplementedError
class
ReplicatedLinear
(
LinearBase
):
def
__init__
(
self
,
input_size
:
int
,
output_size
:
int
,
bias
:
bool
=
False
,
):
super
().
__init__
(
input_size
,
output_size
,
bias
)
def
weight_loader
(
self
,
param
:
nn
.
Parameter
,
loaded_weight
:
torch
.
Tensor
):
param
.
data
.
copy_
(
loaded_weight
)
def
forward
(
self
,
x
:
torch
.
Tensor
)
->
torch
.
Tensor
:
return
F
.
linear
(
x
,
self
.
weight
,
self
.
bias
)
class
ColumnParallelLinear
(
LinearBase
):
def
__init__
(
self
,
input_size
:
int
,
output_size
:
int
,
bias
:
bool
=
False
,
):
tp_size
=
tensor_parallel_world_size_for_sharding
()
super
().
__init__
(
input_size
,
divide
(
output_size
,
tp_size
),
bias
,
0
)
def
weight_loader
(
self
,
param
:
nn
.
Parameter
,
loaded_weight
:
torch
.
Tensor
):
param_data
=
param
.
data
shard_size
=
param_data
.
size
(
self
.
tp_dim
)
start_idx
=
self
.
tp_rank
*
shard_size
loaded_weight
=
loaded_weight
.
narrow
(
self
.
tp_dim
,
start_idx
,
shard_size
)
param_data
.
copy_
(
loaded_weight
)
def
forward
(
self
,
x
:
torch
.
Tensor
)
->
torch
.
Tensor
:
return
F
.
linear
(
x
,
self
.
weight
,
self
.
bias
)
class
MergedColumnParallelLinear
(
ColumnParallelLinear
):
def
__init__
(
self
,
input_size
:
int
,
output_sizes
:
list
[
int
],
bias
:
bool
=
False
,
):
self
.
output_sizes
=
output_sizes
super
().
__init__
(
input_size
,
sum
(
output_sizes
),
bias
)
def
weight_loader
(
self
,
param
:
nn
.
Parameter
,
loaded_weight
:
torch
.
Tensor
,
loaded_shard_id
:
int
):
param_data
=
param
.
data
shard_offset
=
sum
(
self
.
output_sizes
[:
loaded_shard_id
])
//
self
.
tp_size
shard_size
=
self
.
output_sizes
[
loaded_shard_id
]
//
self
.
tp_size
param_data
=
param_data
.
narrow
(
self
.
tp_dim
,
shard_offset
,
shard_size
)
loaded_weight
=
loaded_weight
.
chunk
(
self
.
tp_size
,
self
.
tp_dim
)[
self
.
tp_rank
]
param_data
.
copy_
(
loaded_weight
)
class
QKVParallelLinear
(
ColumnParallelLinear
):
def
__init__
(
self
,
hidden_size
:
int
,
head_size
:
int
,
total_num_heads
:
int
,
total_num_kv_heads
:
int
|
None
=
None
,
bias
:
bool
=
False
,
):
tp_size
=
tensor_parallel_world_size_for_sharding
()
total_num_kv_heads
=
total_num_kv_heads
or
total_num_heads
self
.
head_size
=
head_size
self
.
num_heads
=
divide
(
total_num_heads
,
tp_size
)
self
.
num_kv_heads
=
divide
(
total_num_kv_heads
,
tp_size
)
output_size
=
(
total_num_heads
+
2
*
total_num_kv_heads
)
*
self
.
head_size
super
().
__init__
(
hidden_size
,
output_size
,
bias
)
def
weight_loader
(
self
,
param
:
nn
.
Parameter
,
loaded_weight
:
torch
.
Tensor
,
loaded_shard_id
:
str
):
param_data
=
param
.
data
assert
loaded_shard_id
in
[
"q"
,
"k"
,
"v"
]
if
loaded_shard_id
==
"q"
:
shard_size
=
self
.
num_heads
*
self
.
head_size
shard_offset
=
0
elif
loaded_shard_id
==
"k"
:
shard_size
=
self
.
num_kv_heads
*
self
.
head_size
shard_offset
=
self
.
num_heads
*
self
.
head_size
else
:
shard_size
=
self
.
num_kv_heads
*
self
.
head_size
shard_offset
=
(
self
.
num_heads
*
self
.
head_size
+
self
.
num_kv_heads
*
self
.
head_size
)
param_data
=
param_data
.
narrow
(
self
.
tp_dim
,
shard_offset
,
shard_size
)
loaded_weight
=
loaded_weight
.
chunk
(
self
.
tp_size
,
self
.
tp_dim
)[
self
.
tp_rank
]
param_data
.
copy_
(
loaded_weight
)
class
RowParallelLinear
(
LinearBase
):
def
__init__
(
self
,
input_size
:
int
,
output_size
:
int
,
bias
:
bool
=
False
,
):
tp_size
=
tensor_parallel_world_size_for_sharding
()
super
().
__init__
(
divide
(
input_size
,
tp_size
),
output_size
,
bias
,
1
)
def
weight_loader
(
self
,
param
:
nn
.
Parameter
,
loaded_weight
:
torch
.
Tensor
):
param_data
=
param
.
data
shard_size
=
param_data
.
size
(
self
.
tp_dim
)
start_idx
=
self
.
tp_rank
*
shard_size
loaded_weight
=
loaded_weight
.
narrow
(
self
.
tp_dim
,
start_idx
,
shard_size
)
param_data
.
copy_
(
loaded_weight
)
def
forward
(
self
,
x
:
torch
.
Tensor
)
->
torch
.
Tensor
:
y
=
F
.
linear
(
x
,
self
.
weight
,
self
.
bias
if
self
.
tp_rank
==
0
else
None
)
if
self
.
tp_size
>
1
:
tensor_parallel_all_reduce
(
y
)
return
y
vllm/kvprune_legacy_save/layers/moe.py
deleted
100644 → 0
View file @
2b7160c6
import
torch
import
torch.distributed
as
dist
from
vllm.kvprune.triton_kernels.matmul_ogs
import
matmul_ogs
from
vllm.kvprune.utils.tp_collectives
import
tensor_parallel_all_reduce
from
vllm.kvprune.utils.tp_utils
import
(
tensor_parallel_rank_for_sharding
,
tensor_parallel_world_size_for_sharding
,
)
from
torch
import
nn
def
divide
(
numerator
,
denominator
):
assert
numerator
%
denominator
==
0
return
numerator
//
denominator
class
TritonFusedMoeLinearBase
(
nn
.
Module
):
def
__init__
(
self
,
in_features
:
int
,
out_features
:
int
,
num_experts
:
int
,
bias
:
bool
=
False
,
tp_dim
:
int
|
None
=
None
,
)
->
None
:
super
().
__init__
()
self
.
tp_dim
=
tp_dim
self
.
tp_rank
=
tensor_parallel_rank_for_sharding
()
self
.
tp_size
=
tensor_parallel_world_size_for_sharding
()
self
.
in_features
=
in_features
self
.
out_features
=
out_features
self
.
num_experts
=
num_experts
self
.
weight
=
nn
.
Parameter
(
torch
.
empty
((
num_experts
,
in_features
,
out_features
)).
transpose
(
-
1
,
-
2
)
)
self
.
weight
.
weight_loader
=
self
.
weight_loader
if
bias
:
self
.
bias
=
nn
.
Parameter
(
torch
.
empty
((
num_experts
,
out_features
)))
self
.
bias
.
weight_loader
=
self
.
weight_loader
else
:
self
.
register_parameter
(
"bias"
,
None
)
def
forward
(
self
,
x
:
torch
.
Tensor
,
**
kwargs
)
->
torch
.
Tensor
:
raise
NotImplementedError
class
ReplicatedTritonFusedMoeLinear
(
TritonFusedMoeLinearBase
):
def
__init__
(
self
,
in_features
:
int
,
out_features
:
int
,
num_experts
:
int
,
bias
:
bool
=
False
,
)
->
None
:
super
().
__init__
(
in_features
,
out_features
,
num_experts
,
bias
)
def
weight_loader
(
self
,
param
:
nn
.
Parameter
,
loaded_weight
:
torch
.
Tensor
,
expert_idx
:
int
):
param
.
data
[
expert_idx
].
copy_
(
loaded_weight
,
non_blocking
=
True
)
def
forward
(
self
,
x
:
torch
.
Tensor
,
**
kwargs
)
->
torch
.
Tensor
:
w
=
self
.
weight
.
transpose
(
-
1
,
-
2
)
assert
w
.
is_contiguous
()
return
matmul_ogs
(
x
,
self
.
weight
,
self
.
bias
,
**
kwargs
,
)
class
RowParallelTritonFusedMoeLinear
(
TritonFusedMoeLinearBase
):
def
__init__
(
self
,
in_features
:
int
,
out_features
:
int
,
num_experts
:
int
,
bias
:
bool
=
False
,
)
->
None
:
tp_size
=
(
tensor_parallel_world_size_for_sharding
()
if
dist
.
is_initialized
()
else
1
)
super
().
__init__
(
divide
(
in_features
,
tp_size
),
out_features
,
num_experts
,
bias
,
2
)
def
weight_loader
(
self
,
param
:
nn
.
Parameter
,
loaded_weight
:
torch
.
Tensor
,
expert_idx
:
int
):
shard_size
=
param
.
size
(
2
)
start_idx
=
self
.
tp_rank
*
shard_size
local_shard
=
loaded_weight
[:,
start_idx
:
start_idx
+
shard_size
]
param
.
data
[
expert_idx
].
copy_
(
local_shard
,
non_blocking
=
True
)
def
forward
(
self
,
x
:
torch
.
Tensor
,
**
kwargs
)
->
torch
.
Tensor
:
w
=
self
.
weight
.
transpose
(
-
1
,
-
2
)
assert
w
.
is_contiguous
()
y
=
matmul_ogs
(
x
,
w
,
self
.
bias
,
**
kwargs
,
)
if
self
.
tp_size
>
1
:
tensor_parallel_all_reduce
(
y
)
return
y
class
ColumnParallelTritonFusedMoeLinear
(
TritonFusedMoeLinearBase
):
def
__init__
(
self
,
in_features
:
int
,
out_features
:
int
,
num_experts
:
int
,
bias
:
bool
=
False
,
)
->
None
:
tp_size
=
(
tensor_parallel_world_size_for_sharding
()
if
dist
.
is_initialized
()
else
1
)
super
().
__init__
(
in_features
,
divide
(
out_features
,
tp_size
),
num_experts
,
bias
,
1
)
def
weight_loader
(
self
,
param
:
nn
.
Parameter
,
loaded_weight
:
torch
.
Tensor
,
expert_idx
:
int
):
shard_size
=
param
.
size
(
1
)
start_idx
=
self
.
tp_rank
*
shard_size
local_shard
=
loaded_weight
[
start_idx
:
start_idx
+
shard_size
,
:]
param
.
data
[
expert_idx
].
copy_
(
local_shard
,
non_blocking
=
True
)
def
forward
(
self
,
x
:
torch
.
Tensor
,
**
kwargs
)
->
torch
.
Tensor
:
w
=
self
.
weight
.
transpose
(
-
1
,
-
2
)
assert
w
.
is_contiguous
()
y
=
matmul_ogs
(
x
,
w
,
self
.
bias
,
**
kwargs
,
)
return
y
class
MergedColumnParallelTritonFusedMoeLinear
(
ColumnParallelTritonFusedMoeLinear
):
def
__init__
(
self
,
in_features
:
int
,
out_feature_list
:
list
[
int
],
num_experts
:
int
,
bias
:
bool
=
False
,
):
self
.
out_feature_list
=
out_feature_list
super
().
__init__
(
in_features
,
sum
(
out_feature_list
),
num_experts
,
bias
)
def
weight_loader
(
self
,
param
:
nn
.
Parameter
,
loaded_weight
:
torch
.
Tensor
,
expert_idx
:
int
,
shard_id
:
int
,
):
param_data
=
param
.
data
shard_offset
=
sum
(
self
.
out_feature_list
[:
shard_id
])
//
self
.
tp_size
shard_size
=
self
.
out_feature_list
[
shard_id
]
//
self
.
tp_size
param_data
=
param_data
.
narrow
(
self
.
tp_dim
,
shard_offset
,
shard_size
)
local_weight
=
loaded_weight
.
chunk
(
self
.
tp_size
,
dim
=
self
.
tp_dim
-
1
)[
self
.
tp_rank
]
param_data
[
expert_idx
].
copy_
(
local_weight
,
non_blocking
=
True
)
vllm/kvprune_legacy_save/layers/rotary_embedding.py
deleted
100644 → 0
View file @
2b7160c6
import
math
from
functools
import
lru_cache
from
typing
import
Any
import
torch
from
torch
import
nn
def
apply_rotary_emb
(
x
:
torch
.
Tensor
,
cos
:
torch
.
Tensor
,
sin
:
torch
.
Tensor
,
)
->
torch
.
Tensor
:
x1
,
x2
=
torch
.
chunk
(
x
.
float
(),
2
,
dim
=-
1
)
y1
=
x1
*
cos
-
x2
*
sin
y2
=
x2
*
cos
+
x1
*
sin
return
torch
.
cat
((
y1
,
y2
),
dim
=-
1
).
to
(
x
.
dtype
)
def
rope_theta_from_hf_config
(
config
:
Any
)
->
float
:
"""Match vLLM/HF: ``rope_theta`` may live only under ``rope_parameters`` in config.json."""
rp
=
getattr
(
config
,
"rope_parameters"
,
None
)
if
isinstance
(
rp
,
dict
)
and
"rope_theta"
in
rp
:
return
float
(
rp
[
"rope_theta"
])
return
float
(
getattr
(
config
,
"rope_theta"
,
1_000_000.0
))
class
RotaryEmbedding
(
nn
.
Module
):
def
__init__
(
self
,
head_size
:
int
,
rotary_dim
:
int
,
max_position_embeddings
:
int
,
base
:
float
,
rope_scaling
:
tuple
|
None
,
)
->
None
:
super
().
__init__
()
self
.
head_size
=
head_size
self
.
rotary_dim
=
rotary_dim
inv_freq
=
1.0
/
(
base
**
(
torch
.
arange
(
0
,
rotary_dim
,
2
,
dtype
=
torch
.
float
)
/
rotary_dim
)
)
if
rope_scaling
is
not
None
:
(
rope_type
,
factor
,
low_freq_factor
,
high_freq_factor
,
original_max_position_embeddings
,
)
=
rope_scaling
assert
rope_type
==
"llama3"
old_context_len
=
original_max_position_embeddings
low_freq_wavelen
=
old_context_len
/
low_freq_factor
high_freq_wavelen
=
old_context_len
/
high_freq_factor
wavelen
=
2
*
math
.
pi
/
inv_freq
inv_freq_llama
=
torch
.
where
(
wavelen
>
low_freq_wavelen
,
inv_freq
/
factor
,
inv_freq
)
smooth_factor
=
(
old_context_len
/
wavelen
-
low_freq_factor
)
/
(
high_freq_factor
-
low_freq_factor
)
smoothed_inv_freq
=
(
1
-
smooth_factor
)
*
inv_freq_llama
/
factor
+
smooth_factor
*
inv_freq_llama
is_medium_freq
=
~
(
wavelen
<
high_freq_wavelen
)
*
~
(
wavelen
>
low_freq_wavelen
)
inv_freq
=
torch
.
where
(
is_medium_freq
,
smoothed_inv_freq
,
inv_freq_llama
)
t
=
torch
.
arange
(
max_position_embeddings
,
dtype
=
torch
.
float
)
freqs
=
torch
.
einsum
(
"i,j -> ij"
,
t
,
inv_freq
)
cos
=
freqs
.
cos
()
sin
=
freqs
.
sin
()
cache
=
torch
.
cat
((
cos
,
sin
),
dim
=-
1
).
unsqueeze_
(
1
)
self
.
register_buffer
(
"cos_sin_cache"
,
cache
,
persistent
=
False
)
# @torch.compile
def
forward
(
self
,
positions
:
torch
.
Tensor
,
query
:
torch
.
Tensor
,
key
:
torch
.
Tensor
,
)
->
tuple
[
torch
.
Tensor
,
torch
.
Tensor
]:
cache_len
=
self
.
cos_sin_cache
.
shape
[
0
]
# CUDA graph capture forbids device→CPU sync (e.g. ``.item()``) inside the
# captured region; :meth:`ModelRunner.capture_cudagraph` runs decode with
# placeholder positions. Skip the range check while capturing; eager runs
# still validate.
_capturing
=
(
torch
.
cuda
.
is_available
()
and
torch
.
cuda
.
is_current_stream_capturing
()
)
if
positions
.
numel
()
>
0
and
not
_capturing
:
pmax
=
int
(
positions
.
max
().
item
())
pmin
=
int
(
positions
.
min
().
item
())
if
pmax
>=
cache_len
or
pmin
<
0
:
raise
ValueError
(
f
"RoPE positions out of range: need 0 <= pos <
{
cache_len
}
, "
f
"got min=
{
pmin
}
, max=
{
pmax
}
. "
"Shorten the prompt or increase max_model_len (and align vLLM "
"RoPE cos_sin_cache with tie_kvprune_rope_buffers_from_vllm)."
)
cos_sin
=
self
.
cos_sin_cache
[
positions
]
cos
,
sin
=
cos_sin
.
chunk
(
2
,
dim
=-
1
)
query
=
apply_rotary_emb
(
query
,
cos
,
sin
)
key
=
apply_rotary_emb
(
key
,
cos
,
sin
)
return
query
,
key
@
lru_cache
(
1
)
def
get_rope
(
head_size
:
int
,
rotary_dim
:
int
,
max_position
:
int
,
base
:
float
,
rope_scaling
:
tuple
|
None
=
None
,
):
rotary_emb
=
RotaryEmbedding
(
head_size
,
rotary_dim
,
max_position
,
base
,
rope_scaling
)
return
rotary_emb
vllm/kvprune_legacy_save/layers/sampler.py
deleted
100644 → 0
View file @
2b7160c6
import
torch
from
torch
import
nn
class
Sampler
(
nn
.
Module
):
def
__init__
(
self
):
super
().
__init__
()
# @torch.compile
def
forward
(
self
,
logits
:
torch
.
Tensor
,
temperatures
:
torch
.
Tensor
):
temps
=
temperatures
.
view
(
-
1
)
scaled
=
logits
.
float
()
greedy_mask
=
temps
==
0.0
sample_mask
=
~
greedy_mask
if
sample_mask
.
any
():
temps_sample
=
temps
[
sample_mask
].
unsqueeze
(
-
1
)
# [B_sample, 1]
scaled_sample
=
scaled
[
sample_mask
].
div
(
temps_sample
)
# temperature scaling
E
=
torch
.
empty_like
(
scaled_sample
).
exponential_
(
1
).
clamp_min_
(
1e-10
).
log
()
scaled_sample
=
scaled_sample
-
E
scaled
=
scaled
.
clone
()
scaled
[
sample_mask
]
=
scaled_sample
return
scaled
.
argmax
(
dim
=-
1
)
vllm/kvprune_legacy_save/layers/triton_helpers.py
deleted
100644 → 0
View file @
2b7160c6
import
torch
import
triton
import
triton.language
as
tl
@
triton
.
jit
def
_masked_index_select_kernel
(
X_ptr
,
IDX_ptr
,
OUT_ptr
,
N
,
stride_xn
,
stride_xh
,
stride_ob
,
stride_oh
,
):
b
=
tl
.
program_id
(
0
)
# which output row (0..B-1)
h
=
tl
.
program_id
(
1
)
idx
=
tl
.
load
(
IDX_ptr
+
b
)
# int32
valid
=
(
idx
>=
0
)
&
(
idx
<
N
)
out_ptrs
=
OUT_ptr
+
b
*
stride_ob
+
h
*
stride_oh
if
not
valid
:
tl
.
store
(
out_ptrs
,
0
)
else
:
x_ptrs
=
X_ptr
+
idx
*
stride_xn
+
h
*
stride_xh
vals
=
tl
.
load
(
x_ptrs
)
tl
.
store
(
out_ptrs
,
vals
)
def
masked_index_select_triton_dim0
(
input
:
torch
.
Tensor
,
index
:
torch
.
Tensor
)
->
torch
.
Tensor
:
"""
X: [N, H] : contiguous in the H dimension
b_m: [B] int32/int64 on same device; out-of-range -> zeros)
Returns: [B, H]
"""
assert
input
.
ndim
==
2
and
index
.
ndim
==
1
N
,
H
=
input
.
shape
B
=
index
.
numel
()
out
=
torch
.
empty
((
B
,
H
),
dtype
=
input
.
dtype
,
device
=
input
.
device
)
_masked_index_select_kernel
[(
B
,
H
)](
input
,
index
,
out
,
N
,
input
.
stride
(
0
),
input
.
stride
(
1
),
out
.
stride
(
0
),
out
.
stride
(
1
),
)
return
out
@
triton
.
jit
def
_masked_index_copy_kernel
(
DST_ptr
,
IDX_ptr
,
SRC_ptr
,
N
,
stride_dn
,
stride_dh
,
stride_sb
,
stride_sh
,
):
b
=
tl
.
program_id
(
0
)
h
=
tl
.
program_id
(
1
)
idx
=
tl
.
load
(
IDX_ptr
+
b
)
valid
=
(
idx
>=
0
)
&
(
idx
<
N
)
if
valid
:
src_ptrs
=
SRC_ptr
+
b
*
stride_sb
+
h
*
stride_sh
dst_ptrs
=
DST_ptr
+
idx
*
stride_dn
+
h
*
stride_dh
tl
.
store
(
dst_ptrs
,
tl
.
load
(
src_ptrs
))
def
masked_index_copy_triton_dim0
(
dst
:
torch
.
Tensor
,
index
:
torch
.
Tensor
,
src
:
torch
.
Tensor
):
"""
In-place: dst.index_copy_(0, index, src) but masked:
- rows with index[b] < 0 or >= dst.shape[0] are skipped (no write).
Shapes:
dst: [N, H]
src: [B, H]
index: [B]
"""
assert
dst
.
ndim
==
2
and
src
.
ndim
==
2
and
index
.
ndim
==
1
N
,
H
=
dst
.
shape
B
,
Hs
=
src
.
shape
assert
Hs
==
H
and
index
.
numel
()
==
B
_masked_index_copy_kernel
[(
B
,
H
)](
dst
,
index
,
src
,
N
,
dst
.
stride
(
0
),
dst
.
stride
(
1
),
src
.
stride
(
0
),
src
.
stride
(
1
),
)
vllm/kvprune_legacy_save/models/__init__.py
deleted
100644 → 0
View file @
2b7160c6
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import
logging
from
vllm.kvprune.models.llama3
import
LlamaForCausalLM
from
vllm.kvprune.models.qwen3
import
Qwen3ForCausalLM
logger
=
logging
.
getLogger
(
__name__
)
MODEL_REGISTRY
=
{
"llama"
:
LlamaForCausalLM
,
"qwen3"
:
Qwen3ForCausalLM
,
}
try
:
from
vllm.kvprune.models.qwen3_moe
import
Qwen3MoeForCausalLM
except
Exception
as
exc
:
logger
.
warning
(
"Disabling qwen3_moe due to import error: %s"
,
exc
)
else
:
MODEL_REGISTRY
[
"qwen3_moe"
]
=
Qwen3MoeForCausalLM
vllm/kvprune_legacy_save/models/llama3.py
deleted
100644 → 0
View file @
2b7160c6
import
os
from
glob
import
glob
import
torch
import
tqdm
from
safetensors
import
safe_open
from
torch
import
nn
from
transformers
import
LlamaConfig
from
vllm.kvprune.compression
import
(
CompressionMethod
,
apply_postrope_compression
,
apply_prerope_compression
,
)
from
vllm.kvprune.layers.activation
import
SiluAndMul
from
vllm.kvprune.layers.attention
import
Attention
from
vllm.kvprune.layers.embed_head
import
ParallelLMHead
,
VocabParallelEmbedding
from
vllm.kvprune.layers.layernorm
import
RMSNorm
from
vllm.kvprune.layers.linear
import
(
MergedColumnParallelLinear
,
QKVParallelLinear
,
RowParallelLinear
,
)
from
vllm.kvprune.layers.rotary_embedding
import
get_rope
from
vllm.kvprune.utils.context
import
get_context
from
vllm.kvprune.utils.tp_utils
import
tensor_parallel_world_size_for_sharding
class
LlamaAttention
(
nn
.
Module
):
def
__init__
(
self
,
hidden_size
:
int
,
num_heads
:
int
,
num_kv_heads
:
int
,
max_position
:
int
=
4096
*
32
,
head_dim
:
int
|
None
=
None
,
qkv_bias
:
bool
=
False
,
rope_theta
:
float
=
10000
,
rope_scaling
:
dict
|
None
=
None
,
)
->
None
:
super
().
__init__
()
tp_size
=
tensor_parallel_world_size_for_sharding
()
self
.
total_num_heads
=
num_heads
assert
self
.
total_num_heads
%
tp_size
==
0
self
.
num_heads
=
self
.
total_num_heads
//
tp_size
self
.
total_num_kv_heads
=
num_kv_heads
assert
self
.
total_num_kv_heads
%
tp_size
==
0
self
.
num_kv_heads
=
self
.
total_num_kv_heads
//
tp_size
self
.
head_dim
=
head_dim
or
hidden_size
//
self
.
total_num_heads
self
.
q_size
=
self
.
num_heads
*
self
.
head_dim
self
.
kv_size
=
self
.
num_kv_heads
*
self
.
head_dim
self
.
scaling
=
self
.
head_dim
**-
0.5
self
.
qkv_proj
=
QKVParallelLinear
(
hidden_size
,
self
.
head_dim
,
self
.
total_num_heads
,
self
.
total_num_kv_heads
,
bias
=
qkv_bias
,
)
self
.
o_proj
=
RowParallelLinear
(
self
.
total_num_heads
*
self
.
head_dim
,
hidden_size
,
bias
=
False
,
)
if
rope_scaling
is
not
None
:
rope_scaling_tuple
=
(
rope_scaling
[
"rope_type"
],
rope_scaling
[
"factor"
],
rope_scaling
[
"low_freq_factor"
],
rope_scaling
[
"high_freq_factor"
],
rope_scaling
[
"original_max_position_embeddings"
],
)
else
:
rope_scaling_tuple
=
None
self
.
rotary_emb
=
get_rope
(
self
.
head_dim
,
rotary_dim
=
self
.
head_dim
,
max_position
=
max_position
,
base
=
rope_theta
,
rope_scaling
=
rope_scaling_tuple
,
)
self
.
attn
=
Attention
(
self
.
num_heads
,
self
.
head_dim
,
self
.
scaling
,
self
.
num_kv_heads
,
)
def
forward
(
self
,
positions
:
torch
.
Tensor
,
hidden_states
:
torch
.
Tensor
,
)
->
torch
.
Tensor
:
context
=
get_context
()
qkv
=
self
.
qkv_proj
(
hidden_states
)
q
,
k
,
v
=
qkv
.
split
([
self
.
q_size
,
self
.
kv_size
,
self
.
kv_size
],
dim
=-
1
)
q
=
q
.
view
(
-
1
,
self
.
num_heads
,
self
.
head_dim
)
k
=
k
.
view
(
-
1
,
self
.
num_kv_heads
,
self
.
head_dim
)
v
=
v
.
view
(
-
1
,
self
.
num_kv_heads
,
self
.
head_dim
)
scores
=
None
if
context
.
is_prefill
and
context
.
do_compression
:
scores
=
apply_prerope_compression
(
q
,
k
,
v
,
context
)
q
,
k
=
self
.
rotary_emb
(
positions
,
q
,
k
)
if
context
.
is_prefill
and
context
.
do_compression
:
cc
=
context
.
compression_context
if
cc
is
not
None
and
cc
.
compression_method
==
CompressionMethod
.
CRITICALADAKV
:
# 关键:注入 wo_weight 到 compression_context
wo_raw
=
self
.
o_proj
.
weight
hidden_size
,
_
=
wo_raw
.
shape
Hq
,
D
=
self
.
num_heads
,
self
.
head_dim
cc
.
wo_weight
=
(
wo_raw
.
transpose
(
0
,
1
)
.
contiguous
()
.
view
(
Hq
,
D
,
hidden_size
)
.
to
(
dtype
=
torch
.
float32
)
)
scores
=
apply_postrope_compression
(
q
,
k
,
v
,
scores
,
context
)
o
=
self
.
attn
(
q
,
k
,
v
,
scores
)
output
=
self
.
o_proj
(
o
.
flatten
(
1
,
-
1
))
return
output
class
LlamaMLP
(
nn
.
Module
):
def
__init__
(
self
,
hidden_size
:
int
,
intermediate_size
:
int
,
hidden_act
:
str
,
mlp_bias
:
bool
,
)
->
None
:
super
().
__init__
()
self
.
gate_up_proj
=
MergedColumnParallelLinear
(
hidden_size
,
[
intermediate_size
]
*
2
,
bias
=
mlp_bias
,
)
self
.
down_proj
=
RowParallelLinear
(
intermediate_size
,
hidden_size
,
bias
=
mlp_bias
,
)
assert
hidden_act
==
"silu"
self
.
act_fn
=
SiluAndMul
()
def
forward
(
self
,
x
):
gate_up
=
self
.
gate_up_proj
(
x
)
x
=
self
.
act_fn
(
gate_up
)
x
=
self
.
down_proj
(
x
)
return
x
class
LlamaDecoderLayer
(
nn
.
Module
):
def
__init__
(
self
,
config
:
LlamaConfig
,
)
->
None
:
super
().
__init__
()
self
.
self_attn
=
LlamaAttention
(
hidden_size
=
config
.
hidden_size
,
num_heads
=
config
.
num_attention_heads
,
num_kv_heads
=
config
.
num_key_value_heads
,
max_position
=
config
.
max_position_embeddings
,
qkv_bias
=
getattr
(
config
,
"attention_bias"
,
False
),
head_dim
=
getattr
(
config
,
"head_dim"
,
None
),
rope_theta
=
getattr
(
config
,
"rope_theta"
,
500000.0
),
rope_scaling
=
getattr
(
config
,
"rope_scaling"
,
None
),
)
self
.
mlp
=
LlamaMLP
(
hidden_size
=
config
.
hidden_size
,
intermediate_size
=
config
.
intermediate_size
,
hidden_act
=
config
.
hidden_act
,
mlp_bias
=
config
.
mlp_bias
,
)
self
.
input_layernorm
=
RMSNorm
(
config
.
hidden_size
,
eps
=
config
.
rms_norm_eps
)
self
.
post_attention_layernorm
=
RMSNorm
(
config
.
hidden_size
,
eps
=
config
.
rms_norm_eps
)
def
forward
(
self
,
positions
:
torch
.
Tensor
,
hidden_states
:
torch
.
Tensor
,
residual
:
torch
.
Tensor
|
None
,
)
->
tuple
[
torch
.
Tensor
,
torch
.
Tensor
]:
if
residual
is
None
:
hidden_states
,
residual
=
self
.
input_layernorm
(
hidden_states
),
hidden_states
else
:
hidden_states
,
residual
=
self
.
input_layernorm
(
hidden_states
,
residual
)
hidden_states
=
self
.
self_attn
(
positions
,
hidden_states
)
hidden_states
,
residual
=
self
.
post_attention_layernorm
(
hidden_states
,
residual
)
hidden_states
=
self
.
mlp
(
hidden_states
)
return
hidden_states
,
residual
class
LlamaModel
(
nn
.
Module
):
def
__init__
(
self
,
config
:
LlamaConfig
,
)
->
None
:
super
().
__init__
()
self
.
embed_tokens
=
VocabParallelEmbedding
(
config
.
vocab_size
,
config
.
hidden_size
)
self
.
layers
=
nn
.
ModuleList
(
[
LlamaDecoderLayer
(
config
)
for
_
in
range
(
config
.
num_hidden_layers
)]
)
self
.
norm
=
RMSNorm
(
config
.
hidden_size
,
eps
=
config
.
rms_norm_eps
)
def
forward
(
self
,
input_ids
:
torch
.
Tensor
,
positions
:
torch
.
Tensor
,
)
->
torch
.
Tensor
:
hidden_states
=
self
.
embed_tokens
(
input_ids
)
residual
=
None
for
layer
in
self
.
layers
:
hidden_states
,
residual
=
layer
(
positions
,
hidden_states
,
residual
)
hidden_states
,
_
=
self
.
norm
(
hidden_states
,
residual
)
return
hidden_states
class
LlamaForCausalLM
(
nn
.
Module
):
packed_modules_mapping
=
{
"q_proj"
:
(
"qkv_proj"
,
"q"
),
"k_proj"
:
(
"qkv_proj"
,
"k"
),
"v_proj"
:
(
"qkv_proj"
,
"v"
),
"gate_proj"
:
(
"gate_up_proj"
,
0
),
"up_proj"
:
(
"gate_up_proj"
,
1
),
}
def
__init__
(
self
,
config
:
LlamaConfig
)
->
None
:
super
().
__init__
()
self
.
model
=
LlamaModel
(
config
)
self
.
lm_head
=
ParallelLMHead
(
config
.
vocab_size
,
config
.
hidden_size
,
org_vocab_size
=
config
.
vocab_size
,
)
if
config
.
tie_word_embeddings
:
self
.
lm_head
.
weight
.
data
=
self
.
model
.
embed_tokens
.
weight
.
data
def
forward
(
self
,
input_ids
:
torch
.
Tensor
,
positions
:
torch
.
Tensor
,
)
->
torch
.
Tensor
:
return
self
.
model
(
input_ids
,
positions
)
def
compute_logits
(
self
,
hidden_states
:
torch
.
Tensor
,
)
->
torch
.
Tensor
:
return
self
.
lm_head
(
hidden_states
)
def
load_model
(
self
,
path
:
str
,
*
,
use_tqdm
:
bool
=
False
,
)
->
None
:
all_shards
=
glob
(
os
.
path
.
join
(
path
,
"*.safetensors"
))
for
file
in
(
tqdm
.
tqdm
(
all_shards
,
desc
=
"Loading model"
)
if
use_tqdm
else
all_shards
):
with
safe_open
(
file
,
"pt"
,
"cpu"
)
as
f
:
for
weight_name
in
f
.
keys
():
weight_tensor
=
f
.
get_tensor
(
weight_name
)
is_loaded
=
False
# Load packed modules
for
k
in
self
.
packed_modules_mapping
:
if
k
in
weight_name
:
v
,
shard_id
=
self
.
packed_modules_mapping
[
k
]
param_name
=
weight_name
.
replace
(
k
,
v
)
param
=
self
.
get_parameter
(
param_name
)
weight_loader
=
getattr
(
param
,
"weight_loader"
)
weight_loader
(
param
,
weight_tensor
,
shard_id
)
is_loaded
=
True
break
# Load other modules
if
not
is_loaded
:
param
=
self
.
get_parameter
(
weight_name
)
weight_loader
=
getattr
(
param
,
"weight_loader"
,
lambda
p
,
loaded_weight
:
p
.
data
.
copy_
(
loaded_weight
),
)
weight_loader
(
param
,
weight_tensor
)
is_loaded
=
True
assert
is_loaded
,
f
"Weight
{
weight_name
}
not loaded"
vllm/kvprune_legacy_save/models/qwen3.py
deleted
100644 → 0
View file @
2b7160c6
import
os
from
glob
import
glob
import
torch
import
tqdm
from
safetensors
import
safe_open
from
torch
import
nn
from
transformers
import
Qwen3Config
from
vllm.kvprune.compression
import
(
CompressionMethod
,
apply_postrope_compression
,
apply_prerope_compression
,
)
from
vllm.kvprune.layers.activation
import
SiluAndMul
from
vllm.kvprune.layers.attention
import
Attention
from
vllm.kvprune.layers.embed_head
import
ParallelLMHead
,
VocabParallelEmbedding
from
vllm.kvprune.layers.layernorm
import
RMSNorm
from
vllm.kvprune.layers.linear
import
(
MergedColumnParallelLinear
,
QKVParallelLinear
,
RowParallelLinear
,
)
from
vllm.kvprune.layers.rotary_embedding
import
get_rope
,
rope_theta_from_hf_config
from
vllm.kvprune.utils.context
import
get_context
from
vllm.kvprune.utils.tp_utils
import
tensor_parallel_world_size_for_sharding
class
Qwen3Attention
(
nn
.
Module
):
def
__init__
(
self
,
hidden_size
:
int
,
num_heads
:
int
,
num_kv_heads
:
int
,
max_position
:
int
=
4096
*
32
,
head_dim
:
int
|
None
=
None
,
rms_norm_eps
:
float
=
1e-06
,
qkv_bias
:
bool
=
False
,
rope_theta
:
float
=
10000
,
rope_scaling
:
tuple
|
None
=
None
,
)
->
None
:
super
().
__init__
()
tp_size
=
tensor_parallel_world_size_for_sharding
()
self
.
total_num_heads
=
num_heads
assert
self
.
total_num_heads
%
tp_size
==
0
self
.
num_heads
=
self
.
total_num_heads
//
tp_size
self
.
total_num_kv_heads
=
num_kv_heads
assert
self
.
total_num_kv_heads
%
tp_size
==
0
self
.
num_kv_heads
=
self
.
total_num_kv_heads
//
tp_size
self
.
head_dim
=
head_dim
or
hidden_size
//
self
.
total_num_heads
self
.
q_size
=
self
.
num_heads
*
self
.
head_dim
self
.
kv_size
=
self
.
num_kv_heads
*
self
.
head_dim
self
.
scaling
=
self
.
head_dim
**-
0.5
self
.
qkv_proj
=
QKVParallelLinear
(
hidden_size
,
self
.
head_dim
,
self
.
total_num_heads
,
self
.
total_num_kv_heads
,
bias
=
qkv_bias
,
)
self
.
o_proj
=
RowParallelLinear
(
self
.
total_num_heads
*
self
.
head_dim
,
hidden_size
,
bias
=
False
,
)
self
.
rotary_emb
=
get_rope
(
self
.
head_dim
,
rotary_dim
=
self
.
head_dim
,
max_position
=
max_position
,
base
=
rope_theta
,
rope_scaling
=
rope_scaling
,
)
self
.
attn
=
Attention
(
self
.
num_heads
,
self
.
head_dim
,
self
.
scaling
,
self
.
num_kv_heads
,
)
self
.
q_norm
=
RMSNorm
(
self
.
head_dim
,
eps
=
rms_norm_eps
)
self
.
k_norm
=
RMSNorm
(
self
.
head_dim
,
eps
=
rms_norm_eps
)
def
forward
(
self
,
positions
:
torch
.
Tensor
,
hidden_states
:
torch
.
Tensor
,
)
->
torch
.
Tensor
:
context
=
get_context
()
qkv
=
self
.
qkv_proj
(
hidden_states
)
q
,
k
,
v
=
qkv
.
split
([
self
.
q_size
,
self
.
kv_size
,
self
.
kv_size
],
dim
=-
1
)
q
=
self
.
q_norm
(
q
.
view
(
-
1
,
self
.
num_heads
,
self
.
head_dim
))
k
=
self
.
k_norm
(
k
.
view
(
-
1
,
self
.
num_kv_heads
,
self
.
head_dim
))
scores
=
None
if
context
.
is_prefill
and
context
.
do_compression
:
scores
=
apply_prerope_compression
(
q
,
k
,
v
,
context
)
v
=
v
.
view
(
-
1
,
self
.
num_kv_heads
,
self
.
head_dim
)
q
,
k
=
self
.
rotary_emb
(
positions
,
q
,
k
)
if
context
.
is_prefill
and
context
.
do_compression
:
cc
=
context
.
compression_context
if
cc
is
not
None
and
cc
.
compression_method
==
CompressionMethod
.
CRITICALADAKV
:
# 关键:注入 wo_weight 到 compression_context
wo_raw
=
self
.
o_proj
.
weight
hidden_size
,
_
=
wo_raw
.
shape
Hq
,
D
=
self
.
num_heads
,
self
.
head_dim
cc
.
wo_weight
=
(
wo_raw
.
transpose
(
0
,
1
)
.
contiguous
()
.
view
(
Hq
,
D
,
hidden_size
)
.
to
(
dtype
=
torch
.
float32
)
)
scores
=
apply_postrope_compression
(
q
,
k
,
v
,
scores
,
context
)
o
=
self
.
attn
(
q
,
k
,
v
,
scores
)
output
=
self
.
o_proj
(
o
.
flatten
(
1
,
-
1
))
return
output
class
Qwen3MLP
(
nn
.
Module
):
def
__init__
(
self
,
hidden_size
:
int
,
intermediate_size
:
int
,
hidden_act
:
str
,
)
->
None
:
super
().
__init__
()
self
.
gate_up_proj
=
MergedColumnParallelLinear
(
hidden_size
,
[
intermediate_size
]
*
2
,
bias
=
False
,
)
self
.
down_proj
=
RowParallelLinear
(
intermediate_size
,
hidden_size
,
bias
=
False
,
)
assert
hidden_act
==
"silu"
self
.
act_fn
=
SiluAndMul
()
def
forward
(
self
,
x
):
gate_up
=
self
.
gate_up_proj
(
x
)
x
=
self
.
act_fn
(
gate_up
)
x
=
self
.
down_proj
(
x
)
return
x
class
Qwen3DecoderLayer
(
nn
.
Module
):
def
__init__
(
self
,
config
:
Qwen3Config
,
)
->
None
:
super
().
__init__
()
head_dim
=
getattr
(
config
,
"head_dim"
,
None
)
if
head_dim
is
None
:
head_dim
=
config
.
hidden_size
//
config
.
num_attention_heads
rope_theta
=
rope_theta_from_hf_config
(
config
)
rs
=
getattr
(
config
,
"rope_scaling"
,
None
)
rope_scaling_tuple
:
tuple
|
None
=
rs
if
isinstance
(
rs
,
tuple
)
else
None
self
.
self_attn
=
Qwen3Attention
(
hidden_size
=
config
.
hidden_size
,
num_heads
=
config
.
num_attention_heads
,
num_kv_heads
=
config
.
num_key_value_heads
,
max_position
=
config
.
max_position_embeddings
,
rms_norm_eps
=
config
.
rms_norm_eps
,
qkv_bias
=
getattr
(
config
,
"attention_bias"
,
False
),
head_dim
=
head_dim
,
rope_theta
=
rope_theta
,
rope_scaling
=
rope_scaling_tuple
,
)
self
.
mlp
=
Qwen3MLP
(
hidden_size
=
config
.
hidden_size
,
intermediate_size
=
config
.
intermediate_size
,
hidden_act
=
config
.
hidden_act
,
)
self
.
input_layernorm
=
RMSNorm
(
config
.
hidden_size
,
eps
=
config
.
rms_norm_eps
)
self
.
post_attention_layernorm
=
RMSNorm
(
config
.
hidden_size
,
eps
=
config
.
rms_norm_eps
)
def
forward
(
self
,
positions
:
torch
.
Tensor
,
hidden_states
:
torch
.
Tensor
,
residual
:
torch
.
Tensor
|
None
,
)
->
tuple
[
torch
.
Tensor
,
torch
.
Tensor
]:
if
residual
is
None
:
hidden_states
,
residual
=
self
.
input_layernorm
(
hidden_states
),
hidden_states
else
:
hidden_states
,
residual
=
self
.
input_layernorm
(
hidden_states
,
residual
)
hidden_states
=
self
.
self_attn
(
positions
,
hidden_states
)
hidden_states
,
residual
=
self
.
post_attention_layernorm
(
hidden_states
,
residual
)
hidden_states
=
self
.
mlp
(
hidden_states
)
return
hidden_states
,
residual
class
Qwen3Model
(
nn
.
Module
):
def
__init__
(
self
,
config
:
Qwen3Config
,
)
->
None
:
super
().
__init__
()
self
.
embed_tokens
=
VocabParallelEmbedding
(
config
.
vocab_size
,
config
.
hidden_size
)
self
.
layers
=
nn
.
ModuleList
(
[
Qwen3DecoderLayer
(
config
)
for
_
in
range
(
config
.
num_hidden_layers
)]
)
self
.
norm
=
RMSNorm
(
config
.
hidden_size
,
eps
=
config
.
rms_norm_eps
)
def
forward
(
self
,
input_ids
:
torch
.
Tensor
,
positions
:
torch
.
Tensor
,
)
->
torch
.
Tensor
:
hidden_states
=
self
.
embed_tokens
(
input_ids
)
residual
=
None
for
layer
in
self
.
layers
:
hidden_states
,
residual
=
layer
(
positions
,
hidden_states
,
residual
)
hidden_states
,
_
=
self
.
norm
(
hidden_states
,
residual
)
return
hidden_states
class
Qwen3ForCausalLM
(
nn
.
Module
):
packed_modules_mapping
=
{
"q_proj"
:
(
"qkv_proj"
,
"q"
),
"k_proj"
:
(
"qkv_proj"
,
"k"
),
"v_proj"
:
(
"qkv_proj"
,
"v"
),
"gate_proj"
:
(
"gate_up_proj"
,
0
),
"up_proj"
:
(
"gate_up_proj"
,
1
),
}
def
__init__
(
self
,
config
:
Qwen3Config
)
->
None
:
super
().
__init__
()
self
.
model
=
Qwen3Model
(
config
)
self
.
lm_head
=
ParallelLMHead
(
config
.
vocab_size
,
config
.
hidden_size
,
org_vocab_size
=
config
.
vocab_size
,
)
if
config
.
tie_word_embeddings
:
self
.
lm_head
.
weight
.
data
=
self
.
model
.
embed_tokens
.
weight
.
data
def
forward
(
self
,
input_ids
:
torch
.
Tensor
,
positions
:
torch
.
Tensor
,
)
->
torch
.
Tensor
:
return
self
.
model
(
input_ids
,
positions
)
def
compute_logits
(
self
,
hidden_states
:
torch
.
Tensor
,
)
->
torch
.
Tensor
:
return
self
.
lm_head
(
hidden_states
)
def
load_model
(
self
,
path
:
str
,
*
,
use_tqdm
:
bool
=
False
,
)
->
None
:
all_shards
=
glob
(
os
.
path
.
join
(
path
,
"*.safetensors"
))
for
file
in
(
tqdm
.
tqdm
(
all_shards
,
desc
=
"Loading model"
)
if
use_tqdm
else
all_shards
):
with
safe_open
(
file
,
"pt"
,
"cpu"
)
as
f
:
for
weight_name
in
f
.
keys
():
weight_tensor
=
f
.
get_tensor
(
weight_name
)
is_loaded
=
False
# Load packed modules
for
k
in
self
.
packed_modules_mapping
:
if
k
in
weight_name
:
v
,
shard_id
=
self
.
packed_modules_mapping
[
k
]
param_name
=
weight_name
.
replace
(
k
,
v
)
param
=
self
.
get_parameter
(
param_name
)
weight_loader
=
getattr
(
param
,
"weight_loader"
)
weight_loader
(
param
,
weight_tensor
,
shard_id
)
is_loaded
=
True
break
# Load other modules
if
not
is_loaded
:
param
=
self
.
get_parameter
(
weight_name
)
weight_loader
=
getattr
(
param
,
"weight_loader"
,
lambda
p
,
loaded_weight
:
p
.
data
.
copy_
(
loaded_weight
),
)
weight_loader
(
param
,
weight_tensor
)
is_loaded
=
True
assert
is_loaded
,
f
"Weight
{
weight_name
}
not loaded"
vllm/kvprune_legacy_save/models/qwen3_moe.py
deleted
100644 → 0
View file @
2b7160c6
import
os
from
glob
import
glob
import
torch
import
tqdm
from
safetensors
import
safe_open
from
torch
import
nn
from
transformers
import
Qwen3MoeConfig
from
vllm.kvprune.compression
import
(
CompressionMethod
,
apply_postrope_compression
,
apply_prerope_compression
,
)
from
vllm.kvprune.layers.activation
import
SiluAndMul
from
vllm.kvprune.layers.attention
import
Attention
from
vllm.kvprune.layers.embed_head
import
ParallelLMHead
,
VocabParallelEmbedding
from
vllm.kvprune.layers.layernorm
import
RMSNorm
from
vllm.kvprune.layers.linear
import
(
MergedColumnParallelLinear
,
QKVParallelLinear
,
ReplicatedLinear
,
RowParallelLinear
,
)
from
vllm.kvprune.layers.moe
import
(
MergedColumnParallelTritonFusedMoeLinear
,
RowParallelTritonFusedMoeLinear
,
)
from
vllm.kvprune.layers.rotary_embedding
import
get_rope
,
rope_theta_from_hf_config
from
vllm.kvprune.triton_kernels.routing
import
routing
from
vllm.kvprune.utils.context
import
get_context
from
vllm.kvprune.utils.tp_utils
import
(
tensor_parallel_rank_for_sharding
,
tensor_parallel_world_size_for_sharding
,
)
class
Qwen3MoeAttention
(
nn
.
Module
):
def
__init__
(
self
,
hidden_size
:
int
,
num_heads
:
int
,
num_kv_heads
:
int
,
max_position
:
int
=
4096
*
32
,
head_dim
:
int
|
None
=
None
,
rms_norm_eps
:
float
=
1e-06
,
qkv_bias
:
bool
=
False
,
rope_theta
:
float
=
10000
,
rope_scaling
:
tuple
|
None
=
None
,
sliding_window
:
int
|
None
=
None
,
)
->
None
:
super
().
__init__
()
tp_size
=
tensor_parallel_world_size_for_sharding
()
self
.
total_num_heads
=
num_heads
assert
self
.
total_num_heads
%
tp_size
==
0
self
.
num_heads
=
self
.
total_num_heads
//
tp_size
self
.
total_num_kv_heads
=
num_kv_heads
assert
self
.
total_num_kv_heads
%
tp_size
==
0
self
.
num_kv_heads
=
self
.
total_num_kv_heads
//
tp_size
self
.
head_dim
=
head_dim
or
hidden_size
//
self
.
total_num_heads
self
.
q_size
=
self
.
num_heads
*
self
.
head_dim
self
.
kv_size
=
self
.
num_kv_heads
*
self
.
head_dim
self
.
scaling
=
self
.
head_dim
**-
0.5
self
.
sliding_window
=
sliding_window
self
.
qkv_proj
=
QKVParallelLinear
(
hidden_size
,
self
.
head_dim
,
self
.
total_num_heads
,
self
.
total_num_kv_heads
,
bias
=
qkv_bias
,
)
self
.
o_proj
=
RowParallelLinear
(
self
.
total_num_heads
*
self
.
head_dim
,
hidden_size
,
bias
=
False
,
)
self
.
rotary_emb
=
get_rope
(
self
.
head_dim
,
rotary_dim
=
self
.
head_dim
,
max_position
=
max_position
,
base
=
rope_theta
,
rope_scaling
=
rope_scaling
,
)
self
.
attn
=
Attention
(
self
.
num_heads
,
self
.
head_dim
,
self
.
scaling
,
self
.
num_kv_heads
,
)
self
.
q_norm
=
RMSNorm
(
self
.
head_dim
,
eps
=
rms_norm_eps
)
self
.
k_norm
=
RMSNorm
(
self
.
head_dim
,
eps
=
rms_norm_eps
)
def
forward
(
self
,
positions
:
torch
.
Tensor
,
hidden_states
:
torch
.
Tensor
,
)
->
torch
.
Tensor
:
context
=
get_context
()
qkv
=
self
.
qkv_proj
(
hidden_states
)
q
,
k
,
v
=
qkv
.
split
([
self
.
q_size
,
self
.
kv_size
,
self
.
kv_size
],
dim
=-
1
)
q
=
self
.
q_norm
(
q
.
view
(
-
1
,
self
.
num_heads
,
self
.
head_dim
))
k
=
self
.
k_norm
(
k
.
view
(
-
1
,
self
.
num_kv_heads
,
self
.
head_dim
))
scores
=
None
if
context
.
is_prefill
and
context
.
do_compression
:
scores
=
apply_prerope_compression
(
q
,
k
,
v
,
context
)
v
=
v
.
view
(
-
1
,
self
.
num_kv_heads
,
self
.
head_dim
)
q
,
k
=
self
.
rotary_emb
(
positions
,
q
,
k
)
if
context
.
is_prefill
and
context
.
do_compression
:
cc
=
context
.
compression_context
if
cc
is
not
None
and
cc
.
compression_method
==
CompressionMethod
.
CRITICALADAKV
:
# 关键:注入 wo_weight 到 compression_context
wo_raw
=
self
.
o_proj
.
weight
hidden_size
,
_
=
wo_raw
.
shape
Hq
,
D
=
self
.
num_heads
,
self
.
head_dim
cc
.
wo_weight
=
(
wo_raw
.
transpose
(
0
,
1
)
.
contiguous
()
.
view
(
Hq
,
D
,
hidden_size
)
.
to
(
dtype
=
torch
.
float32
)
)
scores
=
apply_postrope_compression
(
q
,
k
,
v
,
scores
,
context
)
o
=
self
.
attn
(
q
,
k
,
v
,
scores
)
output
=
self
.
o_proj
(
o
.
flatten
(
1
,
-
1
))
return
output
class
Qwen3MoeMLP
(
nn
.
Module
):
def
__init__
(
self
,
hidden_size
:
int
,
intermediate_size
:
int
,
hidden_act
:
str
,
)
->
None
:
super
().
__init__
()
self
.
gate_up_proj
=
MergedColumnParallelLinear
(
hidden_size
,
[
intermediate_size
]
*
2
,
bias
=
False
,
)
self
.
down_proj
=
RowParallelLinear
(
intermediate_size
,
hidden_size
,
bias
=
False
,
)
assert
hidden_act
==
"silu"
self
.
act_fn
=
SiluAndMul
()
def
forward
(
self
,
x
):
gate_up
=
self
.
gate_up_proj
(
x
)
x
=
self
.
act_fn
(
gate_up
)
x
=
self
.
down_proj
(
x
)
return
x
class
Qwen3MoeTritonSparseMoeBlock
(
nn
.
Module
):
def
__init__
(
self
,
num_experts
:
int
,
hidden_size
:
int
,
intermediate_size
:
int
,
num_experts_per_tok
:
int
,
norm_topk_prob
:
bool
,
hidden_act
:
str
,
)
->
None
:
super
().
__init__
()
self
.
num_experts
=
num_experts
self
.
num_experts_per_tok
=
num_experts_per_tok
self
.
norm_topk_prob
=
norm_topk_prob
self
.
hidden_size
=
hidden_size
self
.
moe_intermediate_size
=
intermediate_size
self
.
gate
=
ReplicatedLinear
(
hidden_size
,
num_experts
,
bias
=
False
)
self
.
gate_up_proj
=
MergedColumnParallelTritonFusedMoeLinear
(
hidden_size
,
[
intermediate_size
]
*
2
,
num_experts
)
self
.
down_proj
=
RowParallelTritonFusedMoeLinear
(
intermediate_size
,
hidden_size
,
num_experts
)
self
.
act_fn
=
SiluAndMul
()
def
forward
(
self
,
hidden_states
:
torch
.
Tensor
)
->
torch
.
Tensor
:
x
=
hidden_states
if
x
.
numel
()
==
0
:
return
x
logits
=
self
.
gate
(
x
)
rdata
,
gather_indx
,
scatter_indx
=
routing
(
logits
,
self
.
num_experts_per_tok
,
simulated_ep
=
1
,
# single device, replicated experts
)
x
=
self
.
gate_up_proj
(
x
,
routing_data
=
rdata
,
gather_indx
=
gather_indx
)
x
=
self
.
act_fn
(
x
)
x
=
self
.
down_proj
(
x
,
routing_data
=
rdata
,
scatter_indx
=
scatter_indx
,
gammas
=
rdata
.
gate_scal
)
return
x
class
Qwen3MoeBlock
(
Qwen3MoeTritonSparseMoeBlock
):
pass
class
Qwen3MoeRMSNorm
(
RMSNorm
):
pass
class
Qwen3MoeDecoderLayer
(
nn
.
Module
):
def
__init__
(
self
,
config
:
Qwen3MoeConfig
,
layer_idx
:
int
,
)
->
None
:
super
().
__init__
()
head_dim
=
getattr
(
config
,
"head_dim"
,
None
)
if
head_dim
is
None
:
head_dim
=
config
.
hidden_size
//
config
.
num_attention_heads
rope_theta
=
rope_theta_from_hf_config
(
config
)
rs
=
getattr
(
config
,
"rope_scaling"
,
None
)
rope_scaling_tuple
:
tuple
|
None
=
rs
if
isinstance
(
rs
,
tuple
)
else
None
self
.
self_attn
=
Qwen3MoeAttention
(
hidden_size
=
config
.
hidden_size
,
num_heads
=
config
.
num_attention_heads
,
num_kv_heads
=
config
.
num_key_value_heads
,
max_position
=
config
.
max_position_embeddings
,
head_dim
=
head_dim
,
rms_norm_eps
=
config
.
rms_norm_eps
,
qkv_bias
=
getattr
(
config
,
"attention_bias"
,
False
),
rope_theta
=
rope_theta
,
rope_scaling
=
rope_scaling_tuple
,
sliding_window
=
config
.
sliding_window
,
)
if
(
layer_idx
not
in
config
.
mlp_only_layers
)
and
(
config
.
num_experts
>
0
and
(
layer_idx
+
1
)
%
config
.
decoder_sparse_step
==
0
):
self
.
mlp
=
Qwen3MoeBlock
(
num_experts
=
config
.
num_experts
,
hidden_size
=
config
.
hidden_size
,
intermediate_size
=
config
.
moe_intermediate_size
,
num_experts_per_tok
=
config
.
num_experts_per_tok
,
norm_topk_prob
=
config
.
norm_topk_prob
,
hidden_act
=
config
.
hidden_act
,
)
else
:
self
.
mlp
=
Qwen3MoeMLP
(
hidden_size
=
config
.
hidden_size
,
intermediate_size
=
config
.
intermediate_size
,
hidden_act
=
config
.
hidden_act
,
)
self
.
input_layernorm
=
Qwen3MoeRMSNorm
(
config
.
hidden_size
,
eps
=
config
.
rms_norm_eps
)
self
.
post_attention_layernorm
=
Qwen3MoeRMSNorm
(
config
.
hidden_size
,
eps
=
config
.
rms_norm_eps
)
def
forward
(
self
,
hidden_states
:
torch
.
Tensor
,
positions
:
torch
.
Tensor
,
)
->
torch
.
Tensor
:
# Self Attention
residual
=
hidden_states
hidden_states
=
self
.
input_layernorm
(
hidden_states
)
hidden_states
=
self
.
self_attn
(
positions
,
hidden_states
)
hidden_states
=
residual
+
hidden_states
# Fully Connected
residual
=
hidden_states
hidden_states
=
self
.
post_attention_layernorm
(
hidden_states
)
hidden_states
=
self
.
mlp
(
hidden_states
)
hidden_states
=
residual
+
hidden_states
return
hidden_states
class
Qwen3MoeModel
(
nn
.
Module
):
def
__init__
(
self
,
config
:
Qwen3MoeConfig
,
)
->
None
:
super
().
__init__
()
self
.
embed_tokens
=
VocabParallelEmbedding
(
config
.
vocab_size
,
config
.
hidden_size
)
self
.
layers
=
nn
.
ModuleList
(
[
Qwen3MoeDecoderLayer
(
config
,
layer_idx
)
for
layer_idx
in
range
(
config
.
num_hidden_layers
)
]
)
self
.
norm
=
Qwen3MoeRMSNorm
(
config
.
hidden_size
,
eps
=
config
.
rms_norm_eps
)
def
forward
(
self
,
input_ids
:
torch
.
Tensor
,
position_ids
:
torch
.
Tensor
,
)
->
torch
.
Tensor
:
hidden_states
=
self
.
embed_tokens
(
input_ids
)
for
decoder_layer
in
self
.
layers
:
hidden_states
=
decoder_layer
(
hidden_states
,
position_ids
,
)
hidden_states
=
self
.
norm
(
hidden_states
)
return
hidden_states
class
Qwen3MoeForCausalLM
(
nn
.
Module
):
packed_modules_mapping
=
{
"q_proj"
:
(
"qkv_proj"
,
"q"
),
"k_proj"
:
(
"qkv_proj"
,
"k"
),
"v_proj"
:
(
"qkv_proj"
,
"v"
),
"gate_proj"
:
(
"gate_up_proj"
,
0
),
"up_proj"
:
(
"gate_up_proj"
,
1
),
}
def
__init__
(
self
,
config
:
Qwen3MoeConfig
,
)
->
None
:
super
().
__init__
()
self
.
model
=
Qwen3MoeModel
(
config
)
self
.
num_experts
=
config
.
num_experts
self
.
lm_head
=
ParallelLMHead
(
config
.
vocab_size
,
config
.
hidden_size
,
org_vocab_size
=
config
.
vocab_size
,
)
if
config
.
tie_word_embeddings
:
self
.
lm_head
.
weight
.
data
=
self
.
model
.
embed_tokens
.
weight
.
data
def
forward
(
self
,
input_ids
:
torch
.
Tensor
,
position_ids
:
torch
.
Tensor
,
)
->
torch
.
Tensor
:
return
self
.
model
(
input_ids
,
position_ids
)
def
compute_logits
(
self
,
hidden_states
:
torch
.
Tensor
,
)
->
torch
.
Tensor
:
return
self
.
lm_head
(
hidden_states
)
def
load_model
(
self
,
path
:
str
,
*
,
use_tqdm
:
bool
=
False
,
)
->
None
:
rank
=
tensor_parallel_rank_for_sharding
()
device
=
torch
.
cuda
.
current_device
()
if
torch
.
cuda
.
is_available
()
else
rank
all_shards
=
glob
(
os
.
path
.
join
(
path
,
"*.safetensors"
))
for
file
in
(
tqdm
.
tqdm
(
all_shards
,
desc
=
"Loading model"
)
if
use_tqdm
else
all_shards
):
with
safe_open
(
file
,
"pt"
,
f
"cuda:
{
device
}
"
)
as
f
:
for
weight_name
in
f
.
keys
():
weight_tensor
=
f
.
get_tensor
(
weight_name
)
is_expert
=
"mlp.experts"
in
weight_name
is_loaded
=
False
# Process experts params name
if
is_expert
:
mlp_module_name
,
expert_module_name
=
weight_name
.
split
(
".experts."
)
expert_idx
=
int
(
expert_module_name
.
split
(
"."
)[
0
])
proj_name
=
expert_module_name
.
replace
(
f
"
{
expert_idx
}
."
,
""
)
weight_name
=
f
"
{
mlp_module_name
}
.
{
proj_name
}
"
# Load packed modules
for
k
in
self
.
packed_modules_mapping
:
if
k
in
weight_name
:
v
,
shard_id
=
self
.
packed_modules_mapping
[
k
]
param_name
=
weight_name
.
replace
(
k
,
v
)
param
=
self
.
get_parameter
(
param_name
)
weight_loader
=
getattr
(
param
,
"weight_loader"
)
if
is_expert
:
weight_loader
(
param
,
weight_tensor
,
expert_idx
,
shard_id
)
else
:
weight_loader
(
param
,
weight_tensor
,
shard_id
)
is_loaded
=
True
break
# Load other modules
if
not
is_loaded
:
param
=
self
.
get_parameter
(
weight_name
)
weight_loader
=
getattr
(
param
,
"weight_loader"
,
lambda
p
,
lw
:
p
.
data
.
copy_
(
lw
,
non_blocking
=
True
),
)
if
is_expert
:
weight_loader
(
param
,
weight_tensor
,
expert_idx
)
else
:
weight_loader
(
param
,
weight_tensor
)
is_loaded
=
True
assert
is_loaded
,
f
"Weight
{
weight_name
}
not loaded"
vllm/kvprune_legacy_save/triton_kernels/__init__.py
deleted
100644 → 0
View file @
2b7160c6
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""
Triton kernel utilities (matmul_ogs, MoE, topk, …) plus KV-facing entrypoints.
For KV pruning attention/store, see also ``vllm.kvprune.attention`` and
``vllm.kvprune.kv_cache``.
"""
from
vllm.kvprune.attention.sparse_varlen_kernel
import
causal_sparse_varlen_with_cache
from
vllm.kvprune.kv_cache.store_kv_cache
import
(
decode_store_kv
,
prefill_store_all_kv
,
prefill_store_topk_kv
,
)
__all__
=
[
"causal_sparse_varlen_with_cache"
,
"decode_store_kv"
,
"prefill_store_all_kv"
,
"prefill_store_topk_kv"
,
]
vllm/kvprune_legacy_save/triton_kernels/compaction.py
deleted
100644 → 0
View file @
2b7160c6
import
torch
from
.compaction_details._masked_compaction
import
_masked_compaction
from
.tensor
import
Bitmatrix
def
compaction
(
yv
,
yi
,
bitmask
,
sentinel
=-
1
):
"""
Return compacted copies of *yv* and *yi* based on a per-row bitmask.
Only the elements whose index appears among the active bits of *bitmask*
are kept; the rest are replaced by *sentinel*. Kept elements preserve
their original left-to-right order.
Parameters
----------
yv : torch.Tensor, shape (B, K)
Values tensor.
yi : torch.Tensor, shape (B, K), dtype torch.long
Integer indices (0 ≤ index < 32) associated with *yv*.
bitmask : torch.Tensor, shape (B,) **or** (B, 32)
Per-row mask of active indices. See the in-place version for details.
sentinel : int, default -1
Value written into dropped positions of the returned tensors.
Returns
-------
(yv_out, yi_out) : Tuple[torch.Tensor, torch.Tensor], each shape (B, K)
New tensors with the same dtype/device as the inputs.
"""
n_rows
,
n_cols
=
yi
.
shape
ret_yv
=
torch
.
empty_like
(
yv
)
ret_yi
=
torch
.
empty_like
(
yi
)
if
isinstance
(
bitmask
,
Bitmatrix
):
bitmask
=
bitmask
.
storage
.
data
_masked_compaction
[(
n_rows
,)](
yv
,
yi
,
bitmask
,
bitmask
.
stride
(
0
),
bitmask
.
stride
(
1
),
# inputs
ret_yv
,
ret_yi
,
# outputs
sentinel
,
# sentinel
K
=
n_cols
,
# constants
)
return
ret_yv
,
ret_yi
def
compaction_torch
(
yv
:
torch
.
Tensor
,
yi
:
torch
.
Tensor
,
bitmask
:
torch
.
Tensor
,
sentinel
=-
1
):
"""
reference implementation of `masked_compact`
"""
B
,
K
=
yi
.
shape
device
=
yi
.
device
# Expand bitmask to a boolean matrix of active bits (B, 32)
w
=
1
<<
torch
.
arange
(
32
,
device
=
device
,
dtype
=
bitmask
.
dtype
)
bits
=
(
bitmask
.
unsqueeze
(
-
1
)
&
w
)
!=
0
mask
=
bits
.
flatten
(
start_dim
=-
2
)
# or bits.reshape(B, -1)
# For every yi element decide whether it should be kept
keep
=
mask
.
gather
(
1
,
yi
.
long
())
# Build a stable permutation that brings all "keep" items forward
# False→0, True→1 ==> invert so kept==0, dropped==1, then argsort
order
=
(
~
keep
).
to
(
torch
.
int
).
argsort
(
dim
=
1
,
stable
=
True
)
# Re‑order tensors according to above permutation
yi_sorted
=
yi
.
gather
(
1
,
order
)
yv_sorted
=
yv
.
gather
(
1
,
order
)
# fill relevant positions with sentinel
keep_sorted
=
keep
.
gather
(
1
,
order
)
yi_sorted
[
~
keep_sorted
]
=
sentinel
yv_sorted
[
~
keep_sorted
]
=
sentinel
return
yv_sorted
,
yi_sorted
Prev
1
…
5
6
7
8
9
10
11
12
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment