Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
2b7160c6
Commit
2b7160c6
authored
Apr 23, 2026
by
chenzk
Browse files
vllm kvprune:v1.0.0
parent
fa718036
Changes
305
Hide whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
2562 additions
and
0 deletions
+2562
-0
vllm/kvprune/integration/compression_params.py
vllm/kvprune/integration/compression_params.py
+52
-0
vllm/kvprune/integration/config_adapter.py
vllm/kvprune/integration/config_adapter.py
+143
-0
vllm/kvprune/integration/v1_tp_runner.py
vllm/kvprune/integration/v1_tp_runner.py
+203
-0
vllm/kvprune/integration/vllm_model_access.py
vllm/kvprune/integration/vllm_model_access.py
+46
-0
vllm/kvprune/integration/weight_tie.py
vllm/kvprune/integration/weight_tie.py
+192
-0
vllm/kvprune/kv_cache/__init__.py
vllm/kvprune/kv_cache/__init__.py
+15
-0
vllm/kvprune/kv_cache/page_table.py
vllm/kvprune/kv_cache/page_table.py
+313
-0
vllm/kvprune/kv_cache/store_kv_cache.py
vllm/kvprune/kv_cache/store_kv_cache.py
+473
-0
vllm/kvprune/kv_cache/write_page_table.py
vllm/kvprune/kv_cache/write_page_table.py
+110
-0
vllm/kvprune/kvprune_to_vllm.md
vllm/kvprune/kvprune_to_vllm.md
+68
-0
vllm/kvprune/layers/__init__.py
vllm/kvprune/layers/__init__.py
+9
-0
vllm/kvprune/layers/activation.py
vllm/kvprune/layers/activation.py
+13
-0
vllm/kvprune/layers/attention.py
vllm/kvprune/layers/attention.py
+208
-0
vllm/kvprune/layers/embed_head.py
vllm/kvprune/layers/embed_head.py
+111
-0
vllm/kvprune/layers/layernorm.py
vllm/kvprune/layers/layernorm.py
+49
-0
vllm/kvprune/layers/linear.py
vllm/kvprune/layers/linear.py
+158
-0
vllm/kvprune/layers/moe.py
vllm/kvprune/layers/moe.py
+177
-0
vllm/kvprune/layers/rotary_embedding.py
vllm/kvprune/layers/rotary_embedding.py
+94
-0
vllm/kvprune/layers/sampler.py
vllm/kvprune/layers/sampler.py
+27
-0
vllm/kvprune/layers/triton_helpers.py
vllm/kvprune/layers/triton_helpers.py
+101
-0
No files found.
Too many changes to show.
To preserve performance only
305 of 305+
files are displayed.
Plain diff
Email patch
vllm/kvprune/integration/compression_params.py
0 → 100644
View file @
2b7160c6
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""Per-request KV compression for :meth:`vllm.LLM.generate` (``compression=`` kwarg)."""
from
__future__
import
annotations
from
dataclasses
import
dataclass
@
dataclass
class
CompressionParams
:
"""Per-prompt compression intent for :meth:`vllm.LLM.generate`.
If **any** prompt in the batch has ``compression_ratio < 1.0``, the **whole** batch
is run on the compactor ``LLMEngine`` (same stack as standalone compactor-vllm:
``PagedKVCache`` + pruning kernels). If all prompts have ``compression_ratio >= 1.0``,
the batch stays on standard vLLM.
``compression_method`` follows :mod:`vllm.kvprune.core.compression_bridge` aliases:
``none``, ``criticaladakv``, ``compactor``, ``snapkv`` (ignored when
``compression_ratio`` is effectively 1).
``protected_*`` map to compactor :class:`~vllm.kvprune.compression.compression_config.SequenceCompressionParams`
(defaults match standalone compactor-vllm-style usage).
"""
compression_ratio
:
float
=
1.0
compression_method
:
str
=
"compactor"
protected_first_tokens
:
int
=
16
protected_last_tokens
:
int
=
64
def
__post_init__
(
self
)
->
None
:
if
not
0.0
<
self
.
compression_ratio
<=
1.0
:
raise
ValueError
(
f
"compression_ratio must be in (0, 1], got
{
self
.
compression_ratio
}
"
)
self
.
compression_method
=
(
self
.
compression_method
or
"compactor"
).
strip
().
lower
()
from
vllm.kvprune.core.compression_bridge
import
VALID_ALIASES_FOR_SAMPLING
if
self
.
compression_method
not
in
VALID_ALIASES_FOR_SAMPLING
:
raise
ValueError
(
f
"compression_method must be one of
{
sorted
(
VALID_ALIASES_FOR_SAMPLING
)
}
, "
f
"got
{
self
.
compression_method
!
r
}
"
)
if
self
.
compression_ratio
>=
1.0
-
1e-9
:
self
.
compression_method
=
"none"
elif
self
.
compression_method
==
"none"
:
raise
ValueError
(
"When compression_ratio < 1.0, compression_method cannot be 'none'."
)
vllm/kvprune/integration/config_adapter.py
0 → 100644
View file @
2b7160c6
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""Build :class:`vllm.kvprune.config.engine_config.LLMConfig` from :class:`VllmConfig`."""
from
__future__
import
annotations
import
os
from
pathlib
import
Path
from
vllm.config
import
VllmConfig
from
vllm.kvprune.config.engine_config
import
LLMConfig
,
KvpruneAttentionSchedule
from
vllm.logger
import
init_logger
logger
=
init_logger
(
__name__
)
def
_attention_schedule_from_env
()
->
KvpruneAttentionSchedule
:
"""Resolve :class:`KvpruneAttentionSchedule` from env.
Primary (``VLLM_KVPRUNE_ATTENTION_SCHEDULE``):
- ``fa_triton`` — FA prefill, Triton decode (default). Aliases: ``fa_prefill``,
``default``, empty.
- ``pdtriton`` — Triton prefill + Triton decode. Aliases: ``triton``,
``triton_prefill``, ``compactor_prefill``, ``pd_triton``.
- ``pdfa`` — FA prefill + FA decode (KV stores still Triton). Aliases:
``fa_full``, ``fa_both``.
Legacy: ``VLLM_KVPRUNE_ATTENTION_BACKEND`` maps ``flash``/``fa`` → ``fa_triton``,
``compactor``/``triton`` → ``pdtriton``.
"""
s
=
os
.
environ
.
get
(
"VLLM_KVPRUNE_ATTENTION_SCHEDULE"
,
""
).
strip
().
lower
()
if
s
in
(
"fa_triton"
,
"fa_prefill"
,
"default"
,
""
):
return
KvpruneAttentionSchedule
.
FA_PREFILL_TRITON_DECODE
if
s
in
(
"pdtriton"
,
"pd_triton"
,
"triton"
,
"triton_prefill"
,
"compactor_prefill"
):
return
KvpruneAttentionSchedule
.
TRITON_PREFILL_TRITON_DECODE
if
s
in
(
"pdfa"
,
"fa_full"
,
"fa_both"
):
return
KvpruneAttentionSchedule
.
PDFA
if
s
:
logger
.
warning
(
"Unknown VLLM_KVPRUNE_ATTENTION_SCHEDULE=%r; using FA_PREFILL_TRITON_DECODE"
,
s
,
)
return
KvpruneAttentionSchedule
.
FA_PREFILL_TRITON_DECODE
v
=
os
.
environ
.
get
(
"VLLM_KVPRUNE_ATTENTION_BACKEND"
,
""
).
strip
().
lower
()
if
v
in
(
"flash"
,
"fa"
,
"flash_attention"
,
"flashattention"
):
return
KvpruneAttentionSchedule
.
FA_PREFILL_TRITON_DECODE
if
v
in
(
"compactor"
,
"triton"
,
"compactor_triton"
,
""
):
return
KvpruneAttentionSchedule
.
TRITON_PREFILL_TRITON_DECODE
logger
.
warning
(
"Unknown VLLM_KVPRUNE_ATTENTION_BACKEND=%r; using FA_PREFILL_TRITON_DECODE"
,
v
)
return
KvpruneAttentionSchedule
.
FA_PREFILL_TRITON_DECODE
def
_compactor_kvcache_page_size
(
vllm_block_size
:
int
|
None
)
->
int
:
"""Tokens per physical KV page for compactor :class:`LLMConfig`.
``compactor-vllm`` uses ``kvcache_page_size=128`` by default. Keeping that page
size is important for correctness comparisons when validating the integrated
``kvprune`` backend against standalone compactor, especially for ``pdtriton``
where paged-KV layout and page-padding behavior are part of the observed
divergence on DCU.
Override with ``VLLM_KVPRUNE_PAGE_SIZE``:
- unset: use standalone-compactor-compatible ``128``
- positive integer: use that exact page size (must be divisible by 32)
- ``vllm`` / ``inherit`` / ``block``: derive from vLLM ``block_size`` and round up
to the next multiple of 32 (the older integrated behavior)
"""
env_v
=
os
.
environ
.
get
(
"VLLM_KVPRUNE_PAGE_SIZE"
,
""
).
strip
().
lower
()
if
env_v
:
if
env_v
in
(
"vllm"
,
"inherit"
,
"block"
):
bs
=
128
if
vllm_block_size
is
None
else
int
(
vllm_block_size
)
if
bs
<=
0
:
return
128
if
bs
%
32
==
0
:
return
bs
return
((
bs
+
31
)
//
32
)
*
32
page_size
=
int
(
env_v
)
if
page_size
<=
0
or
page_size
%
32
!=
0
:
raise
ValueError
(
"VLLM_KVPRUNE_PAGE_SIZE must be a positive multiple of 32, "
f
"got
{
page_size
}
."
)
return
page_size
return
128
def
vllm_config_to_llm_config
(
vc
:
VllmConfig
)
->
LLMConfig
:
"""Map vLLM engine config to compactor :class:`LLMConfig`."""
mc
=
vc
.
model_config
cc
=
vc
.
cache_config
pc
=
vc
.
parallel_config
sc
=
vc
.
scheduler_config
block_size
=
cc
.
block_size
if
block_size
is
None
:
block_size
=
16
max_num_seqs
=
getattr
(
sc
,
"max_num_seqs"
,
256
)
# Do **not** forward ``model_config.enforce_eager`` (v1) into compactor
# :class:`LLMConfig`. They are independent flags: v1 uses it only to skip
# *v1* ``capture_model()``; kvprune :class:`~vllm.kvprune.core.model_runner.ModelRunner`
# uses :attr:`LLMConfig.enforce_eager` only for *compactor* decode CUDA graphs.
# Shared-weight setup in ``compactor_shared`` defaults compactor to eager decode;
# see ``VLLM_KVPRUNE_COMPACTOR_CUDA_GRAPH`` (default try graphs) /
# ``VLLM_KVPRUNE_COMPACTOR_ENFORCE_EAGER``.
# Local checkpoint directory: forward so compactor skips redundant Hub fetches.
_model_s
=
str
(
mc
.
model
)
_path
:
str
|
None
=
None
try
:
if
_model_s
and
Path
(
_model_s
).
is_dir
()
and
(
Path
(
_model_s
)
/
"config.json"
).
is_file
():
_path
=
str
(
Path
(
_model_s
).
resolve
())
except
OSError
:
pass
page_size
=
_compactor_kvcache_page_size
(
block_size
)
attention_schedule
=
_attention_schedule_from_env
()
logger
.
info
(
"kvprune compactor config: attention_schedule=%s, kvcache_page_size=%d"
,
attention_schedule
.
name
,
page_size
,
)
return
LLMConfig
(
model
=
_model_s
,
path
=
_path
,
nccl_port
=
1218
,
max_num_seqs
=
max_num_seqs
,
max_model_len
=
mc
.
max_model_len
,
gpu_memory_utilization
=
cc
.
gpu_memory_utilization
,
tensor_parallel_size
=
pc
.
tensor_parallel_size
,
enforce_eager
=
False
,
hf_config
=
mc
.
hf_config
,
eos
=-
1
,
eos_token_ids
=
None
,
kvcache_page_size
=
page_size
,
leverage_sketch_size
=
48
,
attention_schedule
=
attention_schedule
,
attention_backend
=
None
,
)
vllm/kvprune/integration/v1_tp_runner.py
0 → 100644
View file @
2b7160c6
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""TP>1: one kvprune :class:`~vllm.kvprune.core.model_runner.ModelRunner` per vLLM worker.
Invoked via v1 ``collective_rpc("kvprune_v1_compressed_generate", ...)`` so every tensor-
parallel rank participates in the same compactor forward/broadcast sequence as the
standalone multi-process compactor.
Compactor decode CUDA graphs (when not ``enforce_eager``) capture the full decode step
including ``compute_logits``. To force eager on embedded TP workers, set
``VLLM_KVPRUNE_TP_EMBEDDED_GRAPH=0`` or ``VLLM_KVPRUNE_COMPACTOR_ENFORCE_EAGER=1``.
Peer/master session boundaries use TP-group ``broadcast``/``barrier`` (see
``ModelRunner.maybe_release_peers``), not ``multiprocessing.Event`` — RPC payloads must
be picklable across worker processes.
"""
from
__future__
import
annotations
import
os
from
typing
import
Any
import
torch
import
torch.nn
as
nn
from
vllm.kvprune.compression.compression_config
import
(
BatchCompressionParams
,
SequenceCompressionParams
,
)
from
vllm.kvprune.config.sampling_params
import
SamplingParams
as
CompactorSamplingParams
from
vllm.kvprune.core.compression_bridge
import
(
compression_method_id_to_enum
,
compression_method_str_to_id
,
)
from
vllm.kvprune.core.model_runner
import
ModelRunner
from
vllm.kvprune.integration.config_adapter
import
vllm_config_to_llm_config
from
vllm.kvprune.utils.kv_dist
import
barrier_sync
from
vllm.kvprune.integration.weight_tie
import
(
delegate_kvprune_compute_logits_to_vllm
,
delegate_kvprune_embed_tokens_to_vllm
,
tie_kvprune_rope_buffers_from_vllm
,
tie_kvprune_weights_from_vllm
,
)
from
vllm.kvprune.models
import
MODEL_REGISTRY
from
vllm.kvprune.utils.sequence
import
Sequence
_ATTR
=
"_kvprune_tp_embedded_runner"
def
_apply_compactor_env_overrides
(
cfg
:
Any
)
->
None
:
"""Match :func:`~vllm.kvprune.integration.compactor_shared.create_compactor_engine_with_shared_weights` caps."""
_cap
=
os
.
environ
.
get
(
"VLLM_KVPRUNE_COMPACTOR_MAX_NUM_SEQS"
,
"32"
).
strip
()
if
_cap
:
lim
=
int
(
_cap
)
if
lim
>
0
:
cfg
.
max_num_seqs
=
min
(
cfg
.
max_num_seqs
,
lim
)
_ce
=
os
.
environ
.
get
(
"VLLM_KVPRUNE_COMPACTOR_ENFORCE_EAGER"
,
""
).
strip
().
lower
()
if
_ce
in
(
"1"
,
"true"
,
"yes"
):
cfg
.
enforce_eager
=
True
elif
_ce
in
(
"0"
,
"false"
,
"no"
):
cfg
.
enforce_eager
=
False
else
:
_dg
=
os
.
environ
.
get
(
"VLLM_KVPRUNE_COMPACTOR_CUDA_GRAPH"
,
"1"
).
strip
().
lower
()
cfg
.
enforce_eager
=
_dg
in
(
"0"
,
"false"
,
"no"
)
def
_build_sequences
(
payload
:
dict
[
str
,
Any
])
->
list
[
Sequence
]:
prompt_ids
:
list
[
list
[
int
]]
=
payload
[
"prompt_token_ids"
]
sps
:
list
[
dict
[
str
,
Any
]]
=
payload
[
"sampling_params"
]
cps
:
list
[
dict
[
str
,
Any
]]
=
payload
[
"compression_params"
]
seqs
:
list
[
Sequence
]
=
[]
for
i
,
ids
in
enumerate
(
prompt_ids
):
sp
=
CompactorSamplingParams
(
temperature
=
float
(
sps
[
i
][
"temperature"
]),
max_new_tokens
=
int
(
sps
[
i
][
"max_new_tokens"
]),
)
cp
=
SequenceCompressionParams
(
compression_ratio
=
float
(
cps
[
i
][
"compression_ratio"
]),
protected_first_tokens
=
int
(
cps
[
i
].
get
(
"protected_first_tokens"
,
16
)),
protected_last_tokens
=
int
(
cps
[
i
].
get
(
"protected_last_tokens"
,
64
)),
)
if
cp
.
protected_first_tokens
+
cp
.
protected_last_tokens
>=
len
(
ids
):
cp
.
compression_ratio
=
1.0
seqs
.
append
(
Sequence
(
prompt_token_ids
=
list
(
ids
),
sampling_params
=
sp
,
compression_params
=
cp
,
)
)
return
seqs
def
_batch_compression_from_payload
(
payload
:
dict
[
str
,
Any
])
->
BatchCompressionParams
:
cps
=
payload
[
"compression_params"
]
for
c
in
cps
:
if
float
(
c
[
"compression_ratio"
])
<
1.0
:
mid
=
compression_method_str_to_id
(
str
(
c
.
get
(
"compression_method"
,
"none"
)))
return
BatchCompressionParams
(
compression_method
=
compression_method_id_to_enum
(
mid
)
)
return
BatchCompressionParams
()
def
_get_or_create_runner
(
worker
:
Any
,
payload
:
dict
[
str
,
Any
])
->
ModelRunner
:
existing
=
getattr
(
worker
,
_ATTR
,
None
)
if
existing
is
not
None
:
return
existing
from
vllm.distributed.parallel_state
import
(
get_tensor_model_parallel_rank
,
get_tensor_model_parallel_world_size
,
)
vc
=
worker
.
vllm_config
pc
=
vc
.
parallel_config
if
pc
.
pipeline_parallel_size
!=
1
or
pc
.
data_parallel_size
!=
1
:
raise
NotImplementedError
(
"KV-prune TP compressed generate requires pipeline_parallel_size=1 and "
f
"data_parallel_size=1; got PP=
{
pc
.
pipeline_parallel_size
}
, "
f
"DP=
{
pc
.
data_parallel_size
}
."
)
tp_ws
=
get_tensor_model_parallel_world_size
()
if
tp_ws
!=
pc
.
tensor_parallel_size
:
raise
RuntimeError
(
f
"parallel_state TP world size
{
tp_ws
}
!= config.tensor_parallel_size "
f
"
{
pc
.
tensor_parallel_size
}
"
)
hf
=
vc
.
model_config
.
hf_config
model_type
=
getattr
(
hf
,
"model_type"
,
None
)
if
model_type
not
in
MODEL_REGISTRY
:
raise
ValueError
(
f
"KV-prune TP path: unsupported model_type=
{
model_type
!
r
}
; "
f
"registry has
{
sorted
(
MODEL_REGISTRY
)
}
"
)
cfg
=
vllm_config_to_llm_config
(
vc
)
eos_ids
=
payload
[
"eos_token_ids"
]
cfg
.
eos_token_ids
=
sorted
({
int
(
x
)
for
x
in
eos_ids
})
cfg
.
eos
=
int
(
cfg
.
eos_token_ids
[
0
])
_apply_compactor_env_overrides
(
cfg
)
vllm_model
=
worker
.
get_model
()
kv_model
:
nn
.
Module
=
MODEL_REGISTRY
[
model_type
](
hf
)
tie_kvprune_weights_from_vllm
(
vllm_model
,
kv_model
)
dev
=
next
(
vllm_model
.
parameters
()).
device
dtype
=
next
(
vllm_model
.
parameters
()).
dtype
kv_model
.
to
(
device
=
dev
,
dtype
=
dtype
)
tie_kvprune_rope_buffers_from_vllm
(
vllm_model
,
kv_model
)
delegate_kvprune_embed_tokens_to_vllm
(
vllm_model
,
kv_model
)
delegate_kvprune_compute_logits_to_vllm
(
vllm_model
,
kv_model
)
tp_rank
=
get_tensor_model_parallel_rank
()
device
=
torch
.
device
(
f
"cuda:
{
torch
.
cuda
.
current_device
()
}
"
)
if
tp_rank
==
0
:
runner
=
ModelRunner
(
cfg
,
rank
=
0
,
peer_events
=
[],
external_model
=
kv_model
,
embedded_in_vllm_worker
=
True
,
device
=
device
,
)
else
:
runner
=
ModelRunner
(
cfg
,
rank
=
tp_rank
,
batch_ready
=
None
,
external_model
=
kv_model
,
embedded_in_vllm_worker
=
True
,
device
=
device
,
)
setattr
(
worker
,
_ATTR
,
runner
)
return
runner
def
run_kvprune_tp_compressed_generate
(
worker
:
Any
,
payload
:
dict
[
str
,
Any
])
->
dict
[
str
,
Any
]:
"""Execute one compressed generation session on this worker (all TP ranks)."""
from
vllm.distributed.parallel_state
import
get_tensor_model_parallel_rank
tp_rank
=
get_tensor_model_parallel_rank
()
runner
=
_get_or_create_runner
(
worker
,
payload
)
sequences
=
_build_sequences
(
payload
)
batch_c
=
_batch_compression_from_payload
(
payload
)
barrier_sync
(
use_tp_group
=
True
)
if
tp_rank
==
0
:
runner
.
generate
(
sequences
,
batch_c
)
return
{
"tensor_parallel_rank"
:
0
,
"prompt_token_ids"
:
[
list
(
s
.
prompt_token_ids
)
for
s
in
sequences
],
"completion_token_ids"
:
[
list
(
s
.
completion_token_ids
)
for
s
in
sequences
],
}
runner
.
run_peer_session
()
return
{
"tensor_parallel_rank"
:
int
(
tp_rank
),
"ok"
:
True
}
vllm/kvprune/integration/vllm_model_access.py
0 → 100644
View file @
2b7160c6
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""Access the in-process vLLM model weights for compactor weight sharing."""
from
__future__
import
annotations
import
torch.nn
as
nn
from
vllm.logger
import
init_logger
logger
=
init_logger
(
__name__
)
def
extract_vllm_causal_lm
(
llm
:
object
)
->
nn
.
Module
:
"""Return the root ``nn.Module`` holding transformer + lm_head from a v1 ``LLM``.
Requires ``LLMEngine`` to have been constructed with ``multiprocess_mode=False``
so ``model_executor`` lives in-process (set ``VLLM_ENABLE_V1_MULTIPROCESSING=0``).
"""
llm_engine
=
getattr
(
llm
,
"llm_engine"
,
None
)
if
llm_engine
is
None
:
raise
RuntimeError
(
"Expected an object with a ``llm_engine`` attribute (e.g. ``vllm.LLM``)."
)
ex
=
getattr
(
llm_engine
,
"model_executor"
,
None
)
if
ex
is
None
:
raise
RuntimeError
(
"model_executor is unavailable (multiprocess engine mode). "
"Set environment variable VLLM_ENABLE_V1_MULTIPROCESSING=0 for "
"in-process weight sharing."
)
driver
=
getattr
(
ex
,
"driver_worker"
,
None
)
if
driver
is
None
:
raise
RuntimeError
(
"Executor has no driver_worker (unexpected executor type for weight sharing)."
)
worker
=
getattr
(
driver
,
"worker"
,
None
)
if
worker
is
None
:
raise
RuntimeError
(
"Worker wrapper has no worker loaded."
)
get_model
=
getattr
(
worker
,
"get_model"
,
None
)
if
not
callable
(
get_model
):
raise
RuntimeError
(
"Worker does not expose get_model()."
)
return
get_model
()
vllm/kvprune/integration/weight_tie.py
0 → 100644
View file @
2b7160c6
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""Share vLLM parameter storage with compactor ``MODEL_REGISTRY`` models (TP=1)."""
from
__future__
import
annotations
import
types
import
torch
import
torch.nn
as
nn
from
vllm.kvprune.utils.context
import
get_context
from
vllm.logger
import
init_logger
logger
=
init_logger
(
__name__
)
def
tie_kvprune_weights_from_vllm
(
vllm_model
:
nn
.
Module
,
kvprune_model
:
nn
.
Module
,
*
,
strict
:
bool
=
True
,
)
->
int
:
"""Point compactor parameters to the same tensors as vLLM where names match.
Returns the number of parameters tied. Requires identical parameter names
and shapes for overlapping weights (typical when both stacks mirror HF
naming for the same architecture).
Args:
vllm_model: Model returned by ``worker.get_model()`` (e.g. ``Qwen3ForCausalLM``).
kvprune_model: Instance from ``vllm.kvprune.models.MODEL_REGISTRY``.
strict: If True, raise when any ``kvprune`` parameter name is missing from
``vllm_model`` or shapes differ.
"""
vd
=
dict
(
vllm_model
.
named_parameters
())
kd
=
dict
(
kvprune_model
.
named_parameters
())
tied
=
0
for
name
,
kp
in
kd
.
items
():
if
name
not
in
vd
:
if
strict
:
raise
ValueError
(
f
"kvprune parameter
{
name
!
r
}
not found in vLLM model; "
"architecture/layout may differ (disable strict tying only "
"for expert debugging)."
)
continue
vp
=
vd
[
name
]
if
vp
.
shape
!=
kp
.
shape
:
raise
ValueError
(
f
"Shape mismatch for
{
name
}
: vllm
{
vp
.
shape
}
vs kvprune
{
kp
.
shape
}
"
)
kp
.
data
=
vp
.
data
tied
+=
1
if
tied
==
0
:
raise
ValueError
(
"No parameters were tied — check that vLLM and kvprune model types match "
"and use the same state_dict names."
)
logger
.
info
(
"Tied %d parameters from vLLM into compactor model (shared storage)."
,
tied
)
return
tied
def
tie_kvprune_rope_buffers_from_vllm
(
vllm_model
:
nn
.
Module
,
kvprune_model
:
nn
.
Module
,
)
->
int
:
"""Copy RoPE ``cos_sin_cache`` buffers from vLLM into kvprune.
:func:`tie_kvprune_weights_from_vllm` only aliases :class:`~torch.nn.Parameter`
tensors. RoPE tables live in buffers; kvprune's simplified ``RotaryEmbedding``
can disagree with vLLM's ``rope_parameters`` (YaRN, etc.). Copying
``cos_sin_cache`` after ``.to(device, dtype)`` keeps Q/K rotation aligned with
the main model.
kvprune uses layout ``[max_len, 1, rotary_dim]``; vLLM uses ``[max_len,
rotary_dim]``. The singleton dim is filled via ``unsqueeze(1)`` on the vLLM
tensor when copying.
"""
vd
=
dict
(
vllm_model
.
named_buffers
())
copied
=
0
for
name
,
kb
in
kvprune_model
.
named_buffers
():
if
"cos_sin_cache"
not
in
name
:
continue
if
name
not
in
vd
:
logger
.
warning
(
"kvprune RoPE buffer %r not found in vLLM; leaving kvprune cache"
,
name
,
)
continue
vb
=
vd
[
name
]
if
vb
.
shape
==
kb
.
shape
:
kb
.
copy_
(
vb
)
copied
+=
1
elif
kb
.
dim
()
==
3
and
vb
.
dim
()
==
2
:
if
(
kb
.
shape
[
0
]
!=
vb
.
shape
[
0
]
or
kb
.
shape
[
2
]
!=
vb
.
shape
[
1
]
or
kb
.
shape
[
1
]
!=
1
):
raise
ValueError
(
f
"cos_sin_cache shape mismatch for
{
name
!
r
}
: "
f
"vLLM
{
tuple
(
vb
.
shape
)
}
vs kvprune
{
tuple
(
kb
.
shape
)
}
"
)
kb
.
copy_
(
vb
.
unsqueeze
(
1
))
copied
+=
1
else
:
raise
ValueError
(
f
"Unsupported cos_sin_cache layout for
{
name
!
r
}
: "
f
"vLLM
{
tuple
(
vb
.
shape
)
}
vs kvprune
{
tuple
(
kb
.
shape
)
}
"
)
if
copied
:
logger
.
info
(
"Copied %d RoPE cos_sin_cache buffer(s) from vLLM into kvprune model."
,
copied
,
)
return
copied
def
delegate_kvprune_embed_tokens_to_vllm
(
vllm_model
:
nn
.
Module
,
kvprune_model
:
nn
.
Module
,
)
->
bool
:
"""Use vLLM's ``model.embed_tokens`` forward for kvprune (TP-safe token→shard mapping).
Even with tied weights, kvprune's simplified contiguous
``VocabParallelEmbedding`` (``vocab_start = rank * partition``) can disagree with
vLLM's padded vocabulary and org/added shard ranges, producing invalid indices for
``F.embedding`` on non-zero TP ranks (``index_copy_`` / device-side assert).
Delegating the forward to vLLM's embedding module keeps masks and indices aligned
with the main model while parameters remain shared storage.
"""
if
not
hasattr
(
vllm_model
,
"model"
)
or
not
hasattr
(
kvprune_model
,
"model"
):
return
False
vm
=
getattr
(
vllm_model
.
model
,
"embed_tokens"
,
None
)
km
=
getattr
(
kvprune_model
.
model
,
"embed_tokens"
,
None
)
if
vm
is
None
or
km
is
None
:
logger
.
warning
(
"delegate_kvprune_embed_tokens_to_vllm: embed_tokens missing; skipped"
)
return
False
def
_forward
(
_self_unused
:
nn
.
Module
,
x
):
return
vm
(
x
)
km
.
forward
=
types
.
MethodType
(
_forward
,
km
)
logger
.
info
(
"kvprune model.embed_tokens forward delegated to vLLM (correct vocab-parallel masks)."
)
return
True
def
delegate_kvprune_compute_logits_to_vllm
(
vllm_model
:
nn
.
Module
,
kvprune_model
:
nn
.
Module
,
)
->
bool
:
"""Route ``kvprune_model.compute_logits`` through vLLM's ``compute_logits``.
Standalone compactor used :class:`~vllm.kvprune.layers.embed_head.ParallelLMHead`
with ``F.linear`` + TP gather. vLLM applies :class:`~vllm.model_executor.layers.logits_processor.LogitsProcessor`
(gather/all-gather, padded-vocab trim, quant hooks). Mismatch here commonly
produces garbage token distributions while the rest of the stack looks fine.
After weight tying, ``vllm_model.compute_logits(hidden)`` uses the same lm_head
storage as kvprune; only the *application* path matches production vLLM.
"""
if
not
callable
(
getattr
(
vllm_model
,
"compute_logits"
,
None
)):
logger
.
warning
(
"delegate_kvprune_compute_logits_to_vllm: vLLM model has no compute_logits; skipped"
)
return
False
def
_compute_logits
(
_self
:
nn
.
Module
,
hidden_states
):
# Match kvprune :class:`~vllm.kvprune.layers.embed_head.ParallelLMHead`:
# prefill logits are for the **last** token of each packed sequence only.
context
=
get_context
()
if
context
.
is_prefill
and
context
.
cu_seqlens_q
is
not
None
:
cuq
=
context
.
cu_seqlens_q
last_indices
=
(
cuq
[
1
:]
-
1
).
to
(
torch
.
long
)
n_tok
=
hidden_states
.
shape
[
0
]
if
n_tok
>
0
:
last_indices
=
last_indices
.
clamp
(
min
=
0
,
max
=
n_tok
-
1
)
hidden_states
=
hidden_states
[
last_indices
].
contiguous
()
# vLLM lm_head + gather expect contiguous activations; non-contiguous views have
# caused garbage logits under TP in edge cases.
hidden_states
=
hidden_states
.
contiguous
()
logits
=
vllm_model
.
compute_logits
(
hidden_states
)
return
logits
kvprune_model
.
compute_logits
=
types
.
MethodType
(
_compute_logits
,
kvprune_model
)
return
True
vllm/kvprune/kv_cache/__init__.py
0 → 100644
View file @
2b7160c6
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""Paged KV cache helpers and Triton KV store."""
from
vllm.kvprune.kv_cache.store_kv_cache
import
(
decode_store_kv
,
prefill_store_all_kv
,
prefill_store_topk_kv
,
)
__all__
=
[
"decode_store_kv"
,
"prefill_store_all_kv"
,
"prefill_store_topk_kv"
,
]
vllm/kvprune/kv_cache/page_table.py
0 → 100644
View file @
2b7160c6
import
heapq
import
logging
from
enum
import
Enum
,
auto
from
typing
import
List
,
Optional
,
Union
import
torch
from
vllm.kvprune.config.constants
import
RESERVED_BATCH
from
vllm.kvprune.kv_cache.write_page_table
import
scatter_to_page_table
logger
=
logging
.
getLogger
(
__name__
)
def
cdiv
(
a
,
b
):
return
(
a
+
b
-
1
)
//
b
def
next_multiple
(
a
,
b
):
return
cdiv
(
a
,
b
)
*
b
class
KVAllocationStatus
(
Enum
):
EXCEEDS_MAX_SEQUENCE_LENGTH
=
auto
()
EXCEEDS_CURRENTLY_AVAILABLE_PAGES
=
auto
()
EXCEEDS_MAX_NUM_BATCHES
=
auto
()
SUCCESS
=
auto
()
class
PagedKVCache
(
torch
.
nn
.
Module
):
"""
Global paged KV cache.
This module manages:
* A global K/V backing buffer for all layers:
``kv_cache[2, num_layers, n_pages * page_size, head_dim]``,
where the first dimension indexes K vs V.
* A per-layer page table:
``page_table[num_layers, max_num_seqs, H_kv, max_pages_per_head]``,
mapping logical (batch, kv-head, logical_page) to a physical page ID
in the global K/V buffer.
* Per-layer, per-(batch, kv-head) logical sequence lengths
``bh_seq_lens[num_layers, max_num_seqs, H_kv]`` (in tokens), and
the number of allocated pages ``bh_num_pages`` for each (layer, batch,
head).
* A page allocator implemented as a min-heap of free physical pages
per layer, plus free batch indices.
Pages are of fixed size ``page_size`` tokens.
Args:
:param num_layers:
Number of transformer layers that will use this cache.
:param max_logical_pages_per_head:
Maximum number of logical pages that can be assigned to a single
(batch, kv-head) pair.
:param num_pages:
Total number of physical pages available in the global cache per
layer. The global K/V buffers are of length
``num_pages * page_size`` along the token dimension.
:param page_size:
Number of tokens stored per page.
:param H_kv:
Number of KV heads per layer.
:param head_dim:
Head dimension for K/V.
:param max_num_batches:
Maximum number of concurrent batches / sequences supported. One
batch index is reserved for internal use (``RESERVED_BATCH``).
:param dtype:
Data type of K/V entries (e.g. ``torch.float16`` or ``torch.bfloat16``).
:param device:
Device on which to allocate the cache (string, torch.device, or
int; defaults to ``"cuda"``).
"""
def
__init__
(
self
,
num_layers
:
int
,
max_logical_pages_per_head
:
int
,
num_pages
:
int
,
page_size
:
int
,
# tokens per page
H_kv
:
int
,
head_dim
:
int
,
max_num_batches
:
int
,
dtype
:
torch
.
dtype
,
device
:
Union
[
str
,
torch
.
device
,
int
]
=
"cuda"
,
):
super
().
__init__
()
self
.
n_pages
=
num_pages
self
.
num_layers
=
num_layers
self
.
page_size
:
int
=
int
(
page_size
)
self
.
H_kv
=
int
(
H_kv
)
self
.
max_pages_per_head
=
max_logical_pages_per_head
max_num_batches
+=
1
self
.
max_num_batches
=
max_num_batches
self
.
head_dim
=
head_dim
cache_shape
=
(
2
,
num_layers
,
num_pages
*
page_size
,
head_dim
)
self
.
kv_cache
=
torch
.
empty
(
cache_shape
,
dtype
=
dtype
,
device
=
device
)
self
.
page_table
=
torch
.
empty
(
(
num_layers
,
max_num_batches
,
H_kv
,
self
.
max_pages_per_head
),
device
=
device
,
dtype
=
torch
.
int32
,
)
# Per-(batch, head) logical seq length (tokens)
self
.
bh_seq_lens
=
torch
.
zeros
(
(
num_layers
,
max_num_batches
,
H_kv
),
device
=
device
,
dtype
=
torch
.
int32
)
# self._bh_seq_lens_cpu_buffer = torch.zeros((num_layers, H_kv), device="cpu", dtype=torch.int32)
self
.
bh_num_pages
=
torch
.
zeros
(
(
num_layers
,
max_num_batches
,
H_kv
),
device
=
device
,
dtype
=
torch
.
int32
)
# Page allocator (min-heap of free physical pages)
self
.
free_pages
:
List
[
List
[
int
]]
=
[
list
(
range
(
num_pages
))
for
_
in
range
(
num_layers
)
]
for
free_pages
in
self
.
free_pages
:
heapq
.
heapify
(
free_pages
)
# batch zero is reserved
self
.
free_batches
:
List
[
int
]
=
list
(
reversed
(
range
(
max_num_batches
)))
self
.
free_batches
.
remove
(
RESERVED_BATCH
)
# Record of physical page ids owned by a batch (for freeing)
self
.
pages_indices_per_batch
:
List
[
List
[
set
[
int
]]]
=
[
[
set
()
for
_
in
range
(
num_layers
)]
for
_
in
range
(
max_num_batches
)
]
def
new_batch
(
self
)
->
Optional
[
int
]:
"""
Reserve a new batch slot.
A batch slot corresponds to a row in ``bh_seq_lens`` /
``bh_num_pages`` and a slice in ``page_table`` for all layers and KV
heads. This method checks whether a free batch index is available, and
whether each layer has at least ``H_kv`` free pages remaining.
If both checks pass, it returns a batch index and removes it from
``free_batches``. Otherwise, it returns ``None``.
Returns:
:return Optional[int]:
Newly reserved batch index, or ``None`` if no capacity is
available.
"""
if
self
.
free_batches
and
all
([
self
.
H_kv
<=
len
(
fp
)
for
fp
in
self
.
free_pages
]):
return
self
.
free_batches
.
pop
()
return
None
def
reserve_tokens
(
self
,
batch_index
:
int
,
add_tokens
:
int
)
->
KVAllocationStatus
:
"""
Ensure enough pages are allocated to handle ``add_tokens`` new tokens.
Args:
:param batch_index:
Batch index to reserve space for.
:param add_tokens:
Number of additional tokens to reserve capacity for.
All heads in this batch and all layers reserve
the same number of extra tokens.
Returns:
:return bool:
``True`` if the reservation succeeds; ``False`` otherwise .
"""
cur_bh_lens
=
self
.
bh_seq_lens
[:,
batch_index
]
# [L, H]
curr_pages
=
self
.
bh_num_pages
[:,
batch_index
]
# [L, H]
curr_cap_tokens
=
curr_pages
*
self
.
page_size
# [L, H]
need_tokens
=
cur_bh_lens
+
add_tokens
# [L, H]
if
(
need_tokens
<=
curr_cap_tokens
).
all
():
return
KVAllocationStatus
.
SUCCESS
missing_tokens
=
need_tokens
-
curr_cap_tokens
add_pages
=
cdiv
(
missing_tokens
,
self
.
page_size
)
new_total_pages
=
curr_pages
+
add_pages
if
(
new_total_pages
>
self
.
max_pages_per_head
).
any
():
return
KVAllocationStatus
.
EXCEEDS_MAX_SEQUENCE_LENGTH
# CPU work
pages_per_layer_cpu
=
add_pages
.
sum
(
dim
=-
1
).
tolist
()
new_phys_pages
=
[]
for
layer_index
in
range
(
self
.
num_layers
):
if
pages_per_layer_cpu
[
layer_index
]
>
len
(
self
.
free_pages
[
layer_index
]):
return
KVAllocationStatus
.
EXCEEDS_CURRENTLY_AVAILABLE_PAGES
for
layer_index
in
range
(
self
.
num_layers
):
this_layer_pages
=
[
heapq
.
heappop
(
self
.
free_pages
[
layer_index
])
for
_
in
range
(
pages_per_layer_cpu
[
layer_index
])
]
self
.
pages_indices_per_batch
[
batch_index
][
layer_index
]
|=
set
(
this_layer_pages
)
new_phys_pages
.
extend
(
this_layer_pages
)
new_phys_pages
=
torch
.
tensor
(
new_phys_pages
,
dtype
=
torch
.
int32
,
device
=
"cuda"
)
scatter_to_page_table
(
add_pages
=
add_pages
,
new_phys_pages
=
new_phys_pages
,
curr_pages
=
curr_pages
,
page_table
=
self
.
page_table
[:,
batch_index
],
max_pages_per_head
=
self
.
max_pages_per_head
,
)
self
.
bh_num_pages
[:,
batch_index
,
:]
=
new_total_pages
.
to
(
self
.
bh_num_pages
.
dtype
)
return
KVAllocationStatus
.
SUCCESS
def
reclaim_pages
(
self
,
batch_index
:
int
,
future_reserve_tokens
:
int
=
0
,
):
"""
Reclaim unused pages for a single batch index. This shrinks the KV
allocation for the batch down to the minimum number of pages needed
to hold the current (plus optional future) sequence length.
Args:
:param batch_index:
Batch index whose pages should be compacted.
:param future_reserve_tokens:
Optional number of extra tokens to keep capacity for, beyond
the current sequence length. This can reduce churn when
sequences are expected to grow slightly in the near future.
Returns:
:return int:
Approximate number of bytes freed across both K and V.
"""
device
=
self
.
bh_seq_lens
.
device
L
,
B
,
H
=
self
.
bh_seq_lens
.
shape
assert
0
<=
batch_index
<
B
seq
=
self
.
bh_seq_lens
[:,
batch_index
,
:]
+
future_reserve_tokens
# [L, H]
alloc
=
self
.
bh_num_pages
[:,
batch_index
,
:]
# [L, H]
pt
=
self
.
page_table
[:,
batch_index
,
:,
:].
reshape
(
-
1
)
# [L, H, P]
# Compute used pages: ceil_div(seq, page_size), clamped into [0, alloc]
used_pages
=
cdiv
(
seq
,
self
.
page_size
)
used_pages
=
torch
.
minimum
(
used_pages
,
alloc
)
# page indices [0..P-1], broadcasted over [L, H, P]
p
=
torch
.
arange
(
self
.
max_pages_per_head
,
device
=
device
,
dtype
=
torch
.
int32
).
view
(
1
,
1
,
self
.
max_pages_per_head
)
# allocated: p < alloc
alloc_mask
=
p
<
alloc
.
unsqueeze
(
-
1
)
# [L, H, P]
# to free: allocated and p in [used_pages, alloc)
free_mask
=
alloc_mask
&
(
p
>=
used_pages
.
unsqueeze
(
-
1
))
free_mask_flat
=
free_mask
.
view
(
-
1
)
# [L*H*P]
if
not
free_mask_flat
.
any
():
return
0
idx
=
free_mask_flat
.
nonzero
(
as_tuple
=
False
).
squeeze
(
-
1
)
# indices of freed slots
# Freed physical page ids
freed_pages
=
pt
[
idx
]
# Compute layer index for each freed slot:
# layout is [L, H, P] 鈫?flat index = ((l * H) + h) * P + p
freed_layers
=
(
idx
//
(
H
*
self
.
max_pages_per_head
)).
to
(
torch
.
int32
)
freed_pages
=
freed_pages
.
tolist
()
layer_mapping
=
freed_layers
.
tolist
()
self
.
bh_num_pages
[:,
batch_index
,
:]
=
used_pages
for
page
,
layer
in
zip
(
freed_pages
,
layer_mapping
):
self
.
pages_indices_per_batch
[
batch_index
][
layer
].
remove
(
page
)
heapq
.
heappush
(
self
.
free_pages
[
layer
],
page
)
approximate_bytes_freed
=
(
len
(
freed_pages
)
*
(
self
.
page_size
*
self
.
head_dim
*
self
.
kv_cache
.
element_size
())
*
2
)
# multiply for two for K + V
return
approximate_bytes_freed
def
_free_batch_layer
(
self
,
layer_index
:
int
,
batch_index
:
int
)
->
None
:
"""
Free all pages belonging to batch_index and reset its metadata.
"""
# Return pages to the global heap
for
phys
in
self
.
pages_indices_per_batch
[
batch_index
][
layer_index
]:
heapq
.
heappush
(
self
.
free_pages
[
layer_index
],
int
(
phys
))
self
.
pages_indices_per_batch
[
batch_index
][
layer_index
]
=
set
()
def
free_batch
(
self
,
batch_index
:
int
)
->
None
:
"""
Free all resources associated with a batch index.
Args:
:param batch_index:
Batch index to release. Must have been previously allocated
via :meth:`new_batch`.
"""
for
layer
in
range
(
self
.
num_layers
):
self
.
_free_batch_layer
(
layer
,
batch_index
)
self
.
bh_seq_lens
[:,
batch_index
].
zero_
()
self
.
bh_num_pages
[:,
batch_index
].
zero_
()
self
.
free_batches
.
append
(
batch_index
)
def
layer_slices
(
self
,
layer
:
int
):
"""
Return layer-local views needed by the attention module.
For a given ``layer`` index, this method returns the slices of the
global K/V cache, page table, and per-(batch, head) sequence lengths
corresponding to that layer.
Args:
:param layer:
Layer index ``l`` in ``[0, num_layers)``.
Returns:
:return Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
``(k, v, pt, bh)`` as described above.
"""
assert
0
<=
layer
<
self
.
num_layers
k
=
self
.
kv_cache
[
0
,
layer
]
v
=
self
.
kv_cache
[
1
,
layer
]
pt
=
self
.
page_table
[
layer
]
bh
=
self
.
bh_seq_lens
[
layer
]
return
k
,
v
,
pt
,
bh
vllm/kvprune/kv_cache/store_kv_cache.py
0 → 100644
View file @
2b7160c6
import
torch
import
triton
import
triton.language
as
tl
from
vllm.kvprune.config.constants
import
(
TRITON_RESERVED_BATCH
as
_TRITON_RESERVED_BATCH
,
)
@
triton
.
jit
def
_prefill_store_topk_kv_kernel
(
key
,
value
,
# [N_total, H, D] (D stride assumed 1)
batch_mapping
,
# [B] int32 (local b -> true batch)
num_tokens_to_retain
,
# [B] int32
indices_topk
,
# [B, MAX_SEL] int32 (across all heads)
# Lengths & page table:
bh_lens
,
# [B, H] int32 (contiguous)
page_table
,
# [B_total * H * N_LOGICAL_PAGES_MAX] int32 (flattened), read-only
k_cache
,
v_cache
,
# [N_PAGES * PAGE_SIZE, D]
sk_n
,
sk_h
,
# strides for key,value. D stride assumed 1
sv_n
,
sv_h
,
# Runtime ints
MAX_SEL
,
# num tokens that are ranked in indices for each batch (might be bigger than num_tokens_to_retain)
HKV
:
tl
.
constexpr
,
N_LOGICAL_PAGES_MAX
:
tl
.
constexpr
,
D
:
tl
.
constexpr
,
PAGE_SIZE
:
tl
.
constexpr
,
K_TILE
:
tl
.
constexpr
,
# how many selected tokens each program processes
TRITON_RESERVED_BATCH
:
tl
.
constexpr
,
):
b_local
=
tl
.
program_id
(
0
)
tile_id
=
tl
.
program_id
(
1
)
offs
=
tl
.
arange
(
0
,
D
)
# how many tokens we actually keep for this batch
k_total
=
tl
.
load
(
num_tokens_to_retain
+
b_local
)
if
k_total
==
0
:
return
# map to true batch row in the page table
b_true
=
tl
.
load
(
batch_mapping
+
b_local
)
if
b_true
==
TRITON_RESERVED_BATCH
:
return
base
=
tile_id
*
K_TILE
# process up to K_TILE tokens
for
j
in
tl
.
range
(
0
,
K_TILE
):
sel_idx
=
base
+
j
if
sel_idx
<
k_total
and
sel_idx
<
MAX_SEL
:
# flattened selection: sel = token * H + head
sel
=
tl
.
load
(
indices_topk
+
b_local
*
MAX_SEL
+
sel_idx
)
tok
=
sel
//
HKV
head
=
sel
-
(
tok
*
HKV
)
# atomically reserve one position in (b_local, hed)
# i.e the KV cache is scrambled when storing
len_ptr
=
bh_lens
+
b_local
*
HKV
+
head
pos
=
tl
.
atomic_add
(
len_ptr
,
1
)
# old length (int32)
lp
=
pos
//
PAGE_SIZE
off
=
pos
-
lp
*
PAGE_SIZE
# translate logical page to physical page
pt_base
=
(
b_true
*
HKV
+
head
)
*
N_LOGICAL_PAGES_MAX
phys
=
tl
.
load
(
page_table
+
pt_base
+
lp
).
to
(
tl
.
int64
)
# destination row and element offset
dst_row
=
phys
*
PAGE_SIZE
+
off
dst_off
=
dst_row
*
D
+
offs
# load one vector from [N_total, H, D]
k_src
=
key
+
tok
*
sk_n
+
head
*
sk_h
+
offs
v_src
=
value
+
tok
*
sv_n
+
head
*
sv_h
+
offs
tl
.
store
(
k_cache
+
dst_off
,
tl
.
load
(
k_src
,
cache_modifier
=
".cv"
,
eviction_policy
=
"evict_first"
),
eviction_policy
=
"evict_first"
,
)
tl
.
store
(
v_cache
+
dst_off
,
tl
.
load
(
v_src
,
cache_modifier
=
".cv"
,
eviction_policy
=
"evict_first"
),
eviction_policy
=
"evict_first"
,
)
def
prefill_store_topk_kv
(
*
,
new_keys
:
torch
.
Tensor
,
# [N_total, H, D]
new_vals
:
torch
.
Tensor
,
# [N_total, H, D]
indices_topk
:
torch
.
Tensor
,
# [B, MAX_SEL] int32 (global flattened token*H + head)
candidate_counts
:
torch
.
Tensor
,
# [B] int32, valid candidates in indices_topk
num_tokens_to_retain
:
torch
.
Tensor
,
# [B] int32
page_table
:
torch
.
Tensor
,
# [B_total, H, N_LOGICAL_PAGES_MAX] int32
batch_mapping
:
torch
.
Tensor
,
# [B] int32 (local -> true batch rows)
bh_lens
:
torch
.
Tensor
,
# [B, H] int32 (contiguous), UPDATED atomically
k_cache
:
torch
.
Tensor
,
# [N_PAGES * PAGE_SIZE, D]
v_cache
:
torch
.
Tensor
,
# [N_PAGES * PAGE_SIZE, D]
PAGE_SIZE
:
int
,
PAD_TO_PAGE_SIZE
:
bool
=
True
,
cu_seqlens_k
:
torch
.
Tensor
|
None
=
None
,
K_TILE
:
int
=
16
,
TRITON_RESERVED_BATCH
:
int
=
None
,
):
assert
new_keys
.
shape
==
new_vals
.
shape
N_total
,
H
,
D
=
new_keys
.
shape
B
=
indices_topk
.
shape
[
0
]
assert
page_table
.
shape
[
1
]
==
H
assert
bh_lens
.
shape
==
(
B
,
H
)
assert
new_keys
.
device
==
k_cache
.
device
==
v_cache
.
device
assert
page_table
.
is_contiguous
(),
"page table must be contiguous."
assert
bh_lens
.
is_contiguous
(),
"bh_lens must be contiguous."
assert
batch_mapping
.
is_contiguous
(),
"batch mapping must be contiguous."
assert
k_cache
.
is_contiguous
()
and
v_cache
.
is_contiguous
()
assert
new_keys
.
stride
(
-
1
)
==
1
and
new_vals
.
stride
(
-
1
)
==
1
,
(
"new_keys/new_vals last dim must be contiguous."
)
assert
(
D
&
(
D
-
1
))
==
0
,
"D must be a power of 2"
page_table
=
page_table
.
to
(
torch
.
int32
)
bh_lens
=
bh_lens
.
to
(
torch
.
int32
)
batch_mapping
=
batch_mapping
.
to
(
torch
.
int32
)
indices_topk
=
indices_topk
.
to
(
torch
.
int32
)
candidate_counts
=
candidate_counts
.
to
(
torch
.
int32
)
num_tokens_to_retain
=
num_tokens_to_retain
.
to
(
torch
.
int32
)
# strides (elements) for [N_total, H, D]
sk_n
,
sk_h
,
_
=
new_keys
.
stride
()
sv_n
,
sv_h
,
_
=
new_vals
.
stride
()
# tile second grid dim
MAX_SEL
=
indices_topk
.
shape
[
-
1
]
N_TILES
=
(
MAX_SEL
+
K_TILE
-
1
)
//
K_TILE
grid
=
(
B
,
max
(
1
,
N_TILES
))
if
TRITON_RESERVED_BATCH
is
None
:
TRITON_RESERVED_BATCH
=
_TRITON_RESERVED_BATCH
_prefill_store_topk_kv_kernel
[
grid
](
key
=
new_keys
,
value
=
new_vals
,
batch_mapping
=
batch_mapping
,
num_tokens_to_retain
=
num_tokens_to_retain
,
indices_topk
=
indices_topk
,
bh_lens
=
bh_lens
,
page_table
=
page_table
,
k_cache
=
k_cache
,
v_cache
=
v_cache
,
sk_n
=
sk_n
,
sk_h
=
sk_h
,
sv_n
=
sv_n
,
sv_h
=
sv_h
,
MAX_SEL
=
int
(
MAX_SEL
),
HKV
=
H
,
N_LOGICAL_PAGES_MAX
=
page_table
.
shape
[
2
],
D
=
D
,
PAGE_SIZE
=
PAGE_SIZE
,
K_TILE
=
K_TILE
,
TRITON_RESERVED_BATCH
=
TRITON_RESERVED_BATCH
,
)
if
PAD_TO_PAGE_SIZE
:
assert
cu_seqlens_k
is
not
None
assert
indices_topk
.
is_contiguous
()
assert
page_table
.
is_contiguous
()
_prefill_store_topk_pad_kernel
[(
B
,
H
)](
key
=
new_keys
,
value
=
new_vals
,
batch_mapping
=
batch_mapping
,
candidate_counts
=
candidate_counts
,
num_tokens_to_retain
=
num_tokens_to_retain
,
indices
=
indices_topk
,
local_lens
=
bh_lens
,
page_table_flat
=
page_table
,
k_cache
=
k_cache
,
v_cache
=
v_cache
,
cu_seqlens_k
=
cu_seqlens_k
,
sk_n
=
sk_n
,
sk_h
=
sk_h
,
sv_n
=
sv_n
,
sv_h
=
sv_h
,
MAX_SEL
=
int
(
MAX_SEL
),
H
=
H
,
# type: ignore
N_LOGICAL_PAGES_MAX
=
page_table
.
shape
[
2
],
# type: ignore
D
=
D
,
# type: ignore
PAGE_SIZE
=
PAGE_SIZE
,
# type: ignore
TRITON_RESERVED_BATCH
=
TRITON_RESERVED_BATCH
,
)
@
triton
.
jit
def
_prefill_store_topk_pad_kernel
(
key
,
# [N_total, H, D]
value
,
# [N_total, H, D]
batch_mapping
,
# [B] int32 (local b -> true batch)
candidate_counts
,
# [B] int32
num_tokens_to_retain
,
# [B] int32
indices
,
# [B, MAX_SEL] int32 (across all heads)
local_lens
,
# [B, H] int32 (contiguous)
page_table_flat
,
# [B_total*H*N_LOGICAL_PAGES_MAX] int32
k_cache
,
v_cache
,
# [N_PAGES*PAGE_SIZE, D]
cu_seqlens_k
,
sk_n
,
sk_h
,
sv_n
,
sv_h
,
MAX_SEL
,
# Constexprs
H
:
tl
.
constexpr
,
# number of KV heads
N_LOGICAL_PAGES_MAX
:
tl
.
constexpr
,
D
:
tl
.
constexpr
,
PAGE_SIZE
:
tl
.
constexpr
,
TRITON_RESERVED_BATCH
:
tl
.
constexpr
,
):
b_local
=
tl
.
program_id
(
0
)
h
=
tl
.
program_id
(
1
)
offs_d
=
tl
.
arange
(
0
,
D
)
L
=
tl
.
load
(
local_lens
+
b_local
*
H
+
h
)
modulo_page_size
=
L
-
(
L
//
PAGE_SIZE
)
*
PAGE_SIZE
if
modulo_page_size
==
0
:
return
need
=
PAGE_SIZE
-
modulo_page_size
b_true
=
tl
.
load
(
batch_mapping
+
b_local
)
if
b_true
==
TRITON_RESERVED_BATCH
:
return
pt_base
=
(
b_true
*
H
+
h
)
*
N_LOGICAL_PAGES_MAX
written_tokens
=
0
idx
=
tl
.
load
(
num_tokens_to_retain
+
b_local
)
candidate_count
=
tl
.
load
(
candidate_counts
+
b_local
)
this_batch_ctx_len
=
tl
.
load
(
cu_seqlens_k
+
b_local
+
1
)
-
tl
.
load
(
cu_seqlens_k
+
b_local
)
max_additional
=
this_batch_ctx_len
-
L
while
(
written_tokens
<
need
and
idx
<
candidate_count
)
and
(
written_tokens
<
max_additional
):
# candidate head
cand_idx
=
tl
.
load
(
indices
+
b_local
*
MAX_SEL
+
idx
)
cand_h
=
cand_idx
%
H
if
cand_h
==
h
:
tok
=
cand_idx
//
H
pos
=
L
+
written_tokens
lp
=
pos
//
PAGE_SIZE
off
=
pos
-
lp
*
PAGE_SIZE
phys
=
tl
.
load
(
page_table_flat
+
pt_base
+
lp
).
to
(
tl
.
int32
)
dst_row
=
phys
*
PAGE_SIZE
+
off
dst_off
=
dst_row
.
to
(
tl
.
int64
)
*
D
+
offs_d
k_src
=
key
+
tok
*
sk_n
+
h
*
sk_h
+
offs_d
v_src
=
value
+
tok
*
sv_n
+
h
*
sv_h
+
offs_d
tl
.
store
(
k_cache
+
dst_off
,
tl
.
load
(
k_src
),
)
tl
.
store
(
v_cache
+
dst_off
,
tl
.
load
(
v_src
),
)
written_tokens
+=
1
idx
+=
1
tl
.
store
(
local_lens
+
b_local
*
H
+
h
,
L
+
written_tokens
)
@
triton
.
jit
def
_prefill_store_all_kv_kernel
(
key
,
value
,
# [N, H, D] (D contiguous)
cu_seqlens_k
,
# [B + 1] int32
batch_mapping
,
# [B] int32 (local b -> true batch index)
bh_lens
,
# [B * HKV] int32 (UPDATED)
pt_flat
,
# [B_total * HKV * N_LOGICAL_PAGES_MAX] int32 (flattened)
k_cache
,
v_cache
,
# [N_PAGES * PAGE_SIZE, D]
# source strides (elements)
sk_n
,
sk_h
,
sv_n
,
sv_h
,
# constexpr
HKV
:
tl
.
constexpr
,
N_LOGICAL_PAGES_MAX
:
tl
.
constexpr
,
D
:
tl
.
constexpr
,
PAGE_SIZE
:
tl
.
constexpr
,
K_TILE
:
tl
.
constexpr
,
# number of (token, head) pairs processed per program
):
pid_b
=
tl
.
program_id
(
0
)
pid_blk
=
tl
.
program_id
(
1
)
start
=
tl
.
load
(
cu_seqlens_k
+
pid_b
)
end
=
tl
.
load
(
cu_seqlens_k
+
pid_b
+
1
)
num_toks_this_batch
=
end
-
start
if
num_toks_this_batch
<=
0
:
return
total_elems
=
num_toks_this_batch
*
HKV
# base linear index in (token, head) grid for this program
base
=
pid_blk
*
K_TILE
offs_d
=
tl
.
arange
(
0
,
D
)
# Iterate K_TILE elements in this tile
for
i
in
tl
.
range
(
0
,
K_TILE
):
idx
=
base
+
i
if
idx
<
total_elems
:
# map linear idx -> (t, h)
t
=
idx
//
HKV
h
=
idx
-
t
*
HKV
len_idx
=
pid_b
*
HKV
+
h
L0
=
tl
.
load
(
bh_lens
+
len_idx
)
token_idx_in_cache
=
L0
+
t
lp
=
token_idx_in_cache
//
PAGE_SIZE
# logical page
off_in_pg
=
token_idx_in_cache
-
lp
*
PAGE_SIZE
# pos in page
# physical page
b_true
=
tl
.
load
(
batch_mapping
+
pid_b
).
to
(
tl
.
int32
)
pt_base
=
(
b_true
*
HKV
+
h
)
*
N_LOGICAL_PAGES_MAX
phys
=
tl
.
load
(
pt_flat
+
pt_base
+
lp
).
to
(
tl
.
int64
)
row
=
phys
*
PAGE_SIZE
+
off_in_pg
dst_off
=
row
*
D
+
offs_d
n_global
=
(
start
+
t
).
to
(
tl
.
int64
)
# Use strides for non-contiguous [N, H, D] (D stride == 1)
k_src
=
key
+
n_global
*
sk_n
+
h
*
sk_h
+
offs_d
v_src
=
value
+
n_global
*
sv_n
+
h
*
sv_h
+
offs_d
tl
.
store
(
k_cache
+
dst_off
,
tl
.
load
(
k_src
))
tl
.
store
(
v_cache
+
dst_off
,
tl
.
load
(
v_src
))
def
prefill_store_all_kv
(
*
,
new_keys
:
torch
.
Tensor
,
new_values
:
torch
.
Tensor
,
# [N, H_kv, D]
cu_seqlens_k
:
torch
.
Tensor
,
# [B + 1] int32
max_seqlen_k
:
int
,
k_cache
:
torch
.
Tensor
,
v_cache
:
torch
.
Tensor
,
page_table
:
torch
.
Tensor
,
# [B_total, H_kv, N_LOGICAL_PAGES_MAX] int32
bh_lens
:
torch
.
Tensor
,
# [B, H_kv] int32 (UPDATED)
batch_mapping
:
torch
.
Tensor
,
# [B] int32 (local->true)
PAGE_SIZE
:
int
,
K_TILE
:
int
=
32
,
# how many (token, head) pairs per program
):
assert
new_keys
.
stride
(
-
1
)
==
1
and
new_values
.
stride
(
-
1
)
==
1
,
(
"last dim must be contiguous"
)
assert
page_table
.
is_contiguous
(),
"page table must be contiguous"
assert
bh_lens
.
is_contiguous
(),
"bh_lens must be contiguous"
assert
batch_mapping
.
is_contiguous
(),
"batch mapping must be contiguous"
assert
k_cache
.
is_contiguous
()
and
v_cache
.
is_contiguous
()
N
,
HKV
,
D
=
new_keys
.
shape
B
=
batch_mapping
.
shape
[
0
]
assert
(
D
&
(
D
-
1
))
==
0
,
"D must be a power of 2"
sk_n
,
sk_h
,
_
=
new_keys
.
stride
()
sv_n
,
sv_h
,
_
=
new_values
.
stride
()
n_tiles
=
(
max_seqlen_k
*
HKV
+
K_TILE
-
1
)
//
K_TILE
grid
=
(
B
,
n_tiles
)
_prefill_store_all_kv_kernel
[
grid
](
new_keys
,
new_values
,
cu_seqlens_k
,
batch_mapping
,
bh_lens
,
page_table
,
k_cache
,
v_cache
,
sk_n
=
sk_n
,
sk_h
=
sk_h
,
sv_n
=
sv_n
,
sv_h
=
sv_h
,
HKV
=
HKV
,
N_LOGICAL_PAGES_MAX
=
page_table
.
shape
[
-
1
],
D
=
D
,
PAGE_SIZE
=
PAGE_SIZE
,
K_TILE
=
K_TILE
,
)
bh_lens
+=
cu_seqlens_k
.
diff
()[:,
None
]
@
triton
.
jit
def
_decode_store_kv_kernel
(
key
,
value
,
batch_mapping
,
# [B] int32
bh_lens
,
# [B*HKV] int32
page_table
,
# [B_total*HKV*N_LOGICAL_PAGES_MAX]
k_cache
,
v_cache
,
# [N_PAGES*PAGE_SIZE, D]
sk_b
,
sk_h
,
sv_b
,
sv_h
,
HKV
:
tl
.
constexpr
,
N_LOGICAL_PAGES_MAX
:
tl
.
constexpr
,
D
:
tl
.
constexpr
,
PAGE_SIZE
:
tl
.
constexpr
,
TRITON_RESERVED_BATCH
:
tl
.
constexpr
,
):
pid_b
=
tl
.
program_id
(
0
)
h
=
tl
.
program_id
(
1
)
mapped_b
=
tl
.
load
(
batch_mapping
+
pid_b
)
if
mapped_b
==
TRITON_RESERVED_BATCH
:
return
offs_d
=
tl
.
arange
(
0
,
D
)
length
=
tl
.
load
(
bh_lens
+
pid_b
*
HKV
+
h
)
logical_page
=
length
//
PAGE_SIZE
internal_offset
=
length
-
logical_page
*
PAGE_SIZE
pt_base
=
(
mapped_b
*
HKV
+
h
)
*
N_LOGICAL_PAGES_MAX
physical_page
=
tl
.
load
(
page_table
+
pt_base
+
logical_page
).
to
(
tl
.
int64
)
dst_row
=
physical_page
*
PAGE_SIZE
+
internal_offset
# Source addressing using strides (D stride == 1)
k_src
=
key
+
pid_b
*
sk_b
+
h
*
sk_h
+
offs_d
v_src
=
value
+
pid_b
*
sv_b
+
h
*
sv_h
+
offs_d
dst_off
=
dst_row
*
D
+
offs_d
tl
.
store
(
k_cache
+
dst_off
,
tl
.
load
(
k_src
))
tl
.
store
(
v_cache
+
dst_off
,
tl
.
load
(
v_src
))
tl
.
store
(
bh_lens
+
pid_b
*
HKV
+
h
,
length
+
1
)
def
decode_store_kv
(
*
,
key
:
torch
.
Tensor
,
# [B, HKV, D]
value
:
torch
.
Tensor
,
# [B, HKV, D]
batch_mapping
:
torch
.
Tensor
,
# [B] int32
bh_lens
:
torch
.
Tensor
,
# [B, HKV] or flattened [B*HKV] int32
page_table
:
torch
.
Tensor
,
# [B_total, HKV, N_LOGICAL_PAGES_MAX] int32
k_cache
:
torch
.
Tensor
,
v_cache
:
torch
.
Tensor
,
# [N_PAGES*PAGE_SIZE, D]
PAGE_SIZE
:
int
,
TRITON_RESERVED_BATCH
:
int
=
None
,
):
assert
key
.
shape
==
value
.
shape
and
key
.
ndim
==
3
,
"key/value must be [B, HKV, D]"
B
,
HKV
,
D
=
key
.
shape
assert
key
.
stride
(
-
1
)
==
1
and
value
.
stride
(
-
1
)
==
1
,
(
"key/value last dim must be contiguous."
)
assert
page_table
.
is_contiguous
(),
"page table must be contiguous."
assert
bh_lens
.
is_contiguous
(),
"bh_lens must be contiguous."
assert
batch_mapping
.
is_contiguous
(),
"batch mapping must be contiguous."
assert
k_cache
.
is_contiguous
()
and
v_cache
.
is_contiguous
()
assert
(
D
&
(
D
-
1
))
==
0
,
"D must be a power of 2"
sk_b
,
sk_h
,
_
=
key
.
stride
()
sv_b
,
sv_h
,
_
=
value
.
stride
()
grid
=
(
int
(
batch_mapping
.
shape
[
0
]),
HKV
,
)
_decode_store_kv_kernel
[
grid
](
key
=
key
,
value
=
value
,
batch_mapping
=
batch_mapping
,
bh_lens
=
bh_lens
,
page_table
=
page_table
,
k_cache
=
k_cache
,
v_cache
=
v_cache
,
sk_b
=
sk_b
,
sk_h
=
sk_h
,
sv_b
=
sv_b
,
sv_h
=
sv_h
,
HKV
=
HKV
,
N_LOGICAL_PAGES_MAX
=
page_table
.
shape
[
2
],
D
=
D
,
PAGE_SIZE
=
PAGE_SIZE
,
TRITON_RESERVED_BATCH
=
TRITON_RESERVED_BATCH
if
TRITON_RESERVED_BATCH
is
not
None
else
_TRITON_RESERVED_BATCH
,
)
vllm/kvprune/kv_cache/write_page_table.py
0 → 100644
View file @
2b7160c6
import
torch
import
triton
import
triton.language
as
tl
def
scatter_to_page_table
(
add_pages
:
torch
.
Tensor
,
# [L, H] int32
new_phys_pages
:
torch
.
Tensor
,
# [N]
curr_pages
:
torch
.
Tensor
,
# [L, H] int32
page_table
:
torch
.
Tensor
,
# [L, H, max_pages_per_head] int32, NOT assumed contiguous globally
max_pages_per_head
:
int
,
):
"""
Append newly allocated physical pages into a layered page table via Triton.
For each (layer ``l``, head ``h``):
Args:
:param add_pages:
Tensor of shape ``[L, H]`` (int32) indicating how many pages to
append for each (layer, head).
:param new_phys_pages:
1D tensor of shape ``[N]`` (int32) containing physical page IDs
for all (layer, head) pairs, concatenated in row-major (L, H)
order. ``N`` must equal ``add_pages.sum()``.
:param curr_pages:
Tensor of shape ``[L, H]`` (int32) with the current logical page
counts per (layer, head) before this update.
:param page_table:
Tensor of shape ``[L, H, max_pages_per_head]`` (int32) holding
the logical to physical page mapping. The last dimension is
logically indexed as logical_page ∈ [0, max_pages_per_head).
:param max_pages_per_head:
Maximum number of logical pages permitted per (layer, head). The
kernel skips writes beyond this bound.
Returns:
None. The function updates ``page_table`` in-place.
"""
L
,
H
=
add_pages
.
shape
if
L
==
0
or
H
==
0
:
return
add_flat
=
add_pages
.
to
(
torch
.
int32
).
contiguous
().
view
(
-
1
)
curr_flat
=
curr_pages
.
to
(
torch
.
int32
).
contiguous
().
view
(
-
1
)
cum_page_heads
=
torch
.
empty
(
L
*
H
+
1
,
device
=
"cuda"
,
dtype
=
torch
.
int32
)
cum_page_heads
[
0
]
=
0
torch
.
cumsum
(
add_flat
,
0
,
out
=
cum_page_heads
[
1
:])
stride_pl
,
stride_ph
,
stride_pp
=
page_table
.
stride
()
grid
=
(
L
,
H
)
_scatter_pages_kernel_lh
[
grid
](
add_flat
,
cum_page_heads
,
new_phys_pages
,
curr_flat
,
page_table
,
stride_pl
,
stride_ph
,
stride_pp
,
L
=
L
,
H
=
H
,
max_pages_per_head
=
max_pages_per_head
,
)
@
triton
.
jit
def
_scatter_pages_kernel_lh
(
add_pages
,
# int32 [L*H]
cum_page_heads
,
# int32 [L*H], base offset in flat_new_phys per (l,h)
flat_new_phys
,
# int32 [total_pages]
curr_pages
,
# int32 [L*H], existing logical pages per (l,h)
page_table_ptr
,
# int32* base pointer to page_table
stride_pl
,
# int, stride for layer dim
stride_ph
,
# int, stride for head dim
stride_pp
,
# int, stride for page dim
L
:
tl
.
constexpr
,
H
:
tl
.
constexpr
,
max_pages_per_head
:
tl
.
constexpr
,
):
layer_idx
=
tl
.
program_id
(
0
)
h
=
tl
.
program_id
(
1
)
if
layer_idx
>=
L
or
h
>=
H
:
return
lh
=
layer_idx
*
H
+
h
ap
=
tl
.
load
(
add_pages
+
lh
)
if
ap
<=
0
:
return
base
=
tl
.
load
(
cum_page_heads
+
lh
)
cp
=
tl
.
load
(
curr_pages
+
lh
)
# Append ap pages: logical pages [cp .. cp+ap)
for
i
in
tl
.
range
(
0
,
ap
):
phys
=
tl
.
load
(
flat_new_phys
+
base
+
i
)
lp
=
cp
+
i
if
lp
<
max_pages_per_head
:
offset
=
layer_idx
*
stride_pl
+
h
*
stride_ph
+
lp
*
stride_pp
tl
.
store
(
page_table_ptr
+
offset
,
phys
)
# TODO: write reclaim kernel
@
triton
.
jit
def
reclaim_page_kernel
():
pass
def
reclaim_pages
(
batch_index
:
int
,
bh_seq_lens
:
torch
.
Tensor
,
bh_num_pages
:
torch
.
Tensor
,
page_table
:
torch
.
Tensor
,
):
pass
vllm/kvprune/kvprune_to_vllm.md
0 → 100644
View file @
2b7160c6
# KV-prune 与上游 vLLM 的集成说明
本文说明:
**剪枝/压缩(Compactor)功能**
在「官网 vLLM 主仓库」里改动了哪些位置、是否只有少量文件、以及随 vLLM 版本升级时如何预期合并成本。
## 1. 是否「仅仅」改了少数几个脚本?
**核心运行时接线**
确实集中在少数几个
**非**
`vllm/kvprune/`
下的文件;功能主体在
`vllm/kvprune/`
包内独立维护。
| 路径 | 作用简述 |
|------|-----------|
|
`vllm/env_override.py`
| 在
`import vllm`
最早阶段设置与 kvprune 相关的默认环境变量(如 v1 多进程默认、压缩默认开关、可选释放 v1 KV 等)。 |
|
`vllm/__init__.py`
| 对外导出
`CompressionParams`
(懒加载至
`vllm.kvprune.integration.compression_params`
)。 |
|
`vllm/entrypoints/llm.py`
|
`kvprune_compression`
参数、
`generate(..., compression=...)`
、v1
`enforce_eager`
/
`num_gpu_blocks_override`
策略、懒加载 compactor、委托
`compressed_generate`
。 |
|
`vllm/v1/worker/gpu_worker.py`
|
`kvprune_v1_compressed_generate`
:供
`collective_rpc`
调用的 TP 多卡压缩生成入口。 |
|
`tests/conftest.py`
| 测试在导入 vLLM 前覆盖部分
`VLLM_KVPRUNE_*`
默认值,避免全量测试默认走压缩路径。 |
|
`vllm\vllm\envs.py`
| envs.py 中对 VLLM_KVPRUNE_
*
的集中注册 |
**此外(可选/示例,非引擎必需):**
-
`examples/offline_inference/`
下若干
`*kvprune*`
示例脚本:演示用法,不参与核心引擎加载。
**结论:**
-
**「官网 vLLM 主包」里与 kvprune 强相关的改动,主要就是上表 4 个文件 + 测试根配置**
(若把测试也算进「集成面」,共 5 处常见提法)。
-
**算法、Compactor、TP 内嵌 runner 等**
均在
`vllm/kvprune/`
(及该目录下的
`integration/`
)中,与上游 diff 相对隔离。
## 2. 随 vLLM 版本更新,是否「很容易」同步剪枝压缩功能?
**相对容易的部分:**
-
**集成面小**
:合并冲突主要出现在上述少数文件,而不是遍布整个 executor / attention / model 层。
-
**逻辑内聚**
:大量代码在
`vllm/kvprune/`
,可整体移植或
`git`
三方合并时以子树为主处理。
**仍需人工跟进的点(不能假设「自动无痛」):**
-
**`entrypoints/llm.py` 属于高频变更文件**
:上游每次大版本可能重构
`LLM`
构造参数、
`generate`
签名或引擎初始化;需要
**逐次解决冲突**
并回归压缩路径。
-
**`v1/worker/gpu_worker.py`**
同样会随 executor / RPC 接口变动;
`collective_rpc`
方法名或 worker 基类若有变化,需对齐。
-
**`env_override.py`**
若上游调整导入顺序或新增全局默认环境变量,需避免覆盖冲突或行为打架。
-
**vLLM v1 内部 API**
(如
`worker.get_model()`
、
`vllm_config`
结构)若变更,
`vllm/kvprune/integration/*`
也可能要跟着改——这类改动
**不在**
「仅 5 个文件」里,但仍是
**集成层**
维护成本。
**建议同步流程(简版):**
1.
在新上游 tag 上先合并/应用
`vllm/kvprune/`
目录。
2.
再手动合并上述 4 个主包文件 +
`tests/conftest.py`
。
3.
跑与 kvprune 相关的测试与至少一条离线
`compression`
示例。
4.
关注发行说明中
`LLM`
、
`EngineArgs`
、
`gpu_worker`
、多进程默认的破坏性变更。
## 3. 与「深度改内核」方案的区别
当前设计
**没有**
在
`model_executor`
的统一注意力路径上大规模插入 kvprune 钩子(相关辅助逻辑主要在
`vllm/kvprune`
内部)。因此:
-
**上游同步时**
,通常不必与 FlashAttention / 每层模型代码逐文件对打;
-
**代价是**
:功能边界以「共享权重 + compactor 引擎 + 可选 TP RPC」为主,与「原生 KV 算子级一体化」的改动面不同。
---
## 4. 目录重建说明(与 `compactor-vllm` 对齐)
`vllm/kvprune/`
以
`vllm/compactor-vllm/src/compactor_vllm/`
为算法与内核基线整体迁入(
`compactor_vllm`
→
`vllm.kvprune`
),再叠加上游集成层:
-
**集成**
:
`integration/*`
(
`compressed_generate`
、
`compactor_shared`
、
`config_adapter`
、
`v1_tp_runner`
、
`weight_tie`
等)仍负责「同
`LLM.generate`
前端、双后端」。
-
**TP / 调度**
:
`core/model_runner.py`
、
`utils/tp_utils.py`
、
`utils/tp_collectives.py`
、
`utils/kv_dist.py`
等保留 vLLM 内嵌 TP 与
`collective_rpc`
路径。
-
**三种 attention 模式**
:
`config/engine_config.py`
的
`KvpruneAttentionSchedule`
+
`integration/config_adapter.py`
的环境变量解析;
`layers/attention.py`
+
`attention/fa_paged_bridge.py`
实现
`fa_triton`
/
`pdtriton`
/
`pdfa`
。
临时备份目录
`vllm/kvprune_legacy_save/`
可在确认无误后手动删除。
---
*文档随仓库维护;若集成文件列表有增删,请同步更新本节表格。*
vllm/kvprune/layers/__init__.py
0 → 100644
View file @
2b7160c6
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""
Layers from upstream compactor (attention, linear, MoE, …).
Prefer importing concrete modules, e.g. ``from vllm.kvprune.layers.attention import ...``.
"""
__all__
:
list
[
str
]
=
[]
vllm/kvprune/layers/activation.py
0 → 100644
View file @
2b7160c6
import
torch
import
torch.nn.functional
as
F
from
torch
import
nn
class
SiluAndMul
(
nn
.
Module
):
def
__init__
(
self
):
super
().
__init__
()
# @torch.compile
def
forward
(
self
,
x
:
torch
.
Tensor
)
->
torch
.
Tensor
:
x
,
y
=
x
.
chunk
(
2
,
-
1
)
return
F
.
silu
(
x
)
*
y
vllm/kvprune/layers/attention.py
0 → 100644
View file @
2b7160c6
from
typing
import
Optional
import
torch
from
flash_attn.flash_attn_interface
import
flash_attn_varlen_func
from
torch
import
nn
from
vllm.kvprune.attention.fa_paged_bridge
import
(
flash_decode_from_paged
,
flash_prefill_from_paged
,
)
from
vllm.kvprune.attention.sparse_decode_kernel
import
head_sparse_decode_attention
from
vllm.kvprune.attention.sparse_varlen_kernel
import
(
causal_sparse_varlen_with_cache
,
)
from
vllm.kvprune.compression.common
import
extract_and_store_top_kv
from
vllm.kvprune.config.engine_config
import
KvpruneAttentionSchedule
from
vllm.kvprune.kv_cache.store_kv_cache
import
decode_store_kv
,
prefill_store_all_kv
from
vllm.kvprune.utils.context
import
Context
,
get_context
from
vllm.kvprune.utils.helpers
import
maybe_execute_in_stream
class
Attention
(
nn
.
Module
):
def
__init__
(
self
,
num_heads
,
head_dim
,
scale
,
num_kv_heads
,
):
super
().
__init__
()
self
.
num_heads
:
int
=
num_heads
self
.
head_dim
=
head_dim
self
.
scale
:
float
=
scale
self
.
num_kv_heads
=
int
(
num_kv_heads
)
self
.
k_cache
:
Optional
[
torch
.
Tensor
]
=
None
self
.
v_cache
:
Optional
[
torch
.
Tensor
]
=
None
self
.
page_table
:
Optional
[
torch
.
Tensor
]
=
None
self
.
bh_seq_lens
:
Optional
[
torch
.
Tensor
]
=
None
self
.
page_size
:
Optional
[
int
]
=
None
def
forward
(
self
,
q
:
torch
.
Tensor
,
k
:
torch
.
Tensor
,
v
:
torch
.
Tensor
,
scores
:
Optional
[
torch
.
Tensor
]
=
None
,
):
context
:
Context
=
get_context
()
batch_mapping
=
context
.
batch_mapping
seq_lens
=
(
None
if
self
.
bh_seq_lens
is
None
else
self
.
bh_seq_lens
.
index_select
(
0
,
batch_mapping
).
contiguous
()
)
sched
=
context
.
attention_schedule
use_triton_prefill_attn
=
(
sched
==
KvpruneAttentionSchedule
.
TRITON_PREFILL_TRITON_DECODE
)
use_fa_decode
=
sched
==
KvpruneAttentionSchedule
.
PDFA
if
context
.
is_prefill
:
seq_lens_copy
=
seq_lens
.
clone
()
if
seq_lens
is
not
None
else
None
if
(
self
.
k_cache
is
not
None
and
context
.
do_compression
and
scores
is
not
None
):
compression_context
=
context
.
compression_context
assert
compression_context
is
not
None
maybe_execute_in_stream
(
extract_and_store_top_kv
,
scores
=
scores
,
cu_seqlens_k
=
context
.
cu_seqlens_k
,
max_k_len
=
context
.
max_seqlen_k
,
top_k
=
compression_context
.
max_tokens_to_retain
,
H
=
int
(
self
.
num_kv_heads
),
new_keys
=
k
,
new_vals
=
v
,
num_tokens_to_retain
=
compression_context
.
batch_tokens_to_retain
,
page_table
=
self
.
page_table
,
batch_mapping
=
batch_mapping
,
bh_lens
=
seq_lens
,
k_cache
=
self
.
k_cache
,
v_cache
=
self
.
v_cache
,
PAGE_SIZE
=
self
.
page_size
,
PAD_TO_PAGE_SIZE
=
True
,
STORE_STREAM
=
context
.
STORE_STREAM
,
)
elif
self
.
k_cache
is
not
None
:
maybe_execute_in_stream
(
prefill_store_all_kv
,
new_keys
=
k
,
new_values
=
v
,
cu_seqlens_k
=
context
.
cu_seqlens_k
,
max_seqlen_k
=
context
.
max_seqlen_k
,
k_cache
=
self
.
k_cache
,
v_cache
=
self
.
v_cache
,
page_table
=
self
.
page_table
,
bh_lens
=
seq_lens
,
batch_mapping
=
batch_mapping
,
PAGE_SIZE
=
self
.
page_size
,
STORE_STREAM
=
context
.
STORE_STREAM
,
)
if
use_triton_prefill_attn
:
if
context
.
do_compression
and
context
.
STORE_STREAM
is
not
None
:
torch
.
cuda
.
current_stream
().
wait_stream
(
context
.
STORE_STREAM
)
assert
seq_lens_copy
is
not
None
o
=
causal_sparse_varlen_with_cache
(
q
,
k
,
v
,
self
.
k_cache
,
self
.
v_cache
,
seq_lens_bh
=
seq_lens_copy
,
global_page_table
=
self
.
page_table
,
batch_mapping
=
batch_mapping
,
cu_seqlens_q
=
context
.
cu_seqlens_q
,
max_seqlen_q
=
context
.
max_seqlen_q
,
max_seqlen_k_cache
=
context
.
max_bh_len
,
HKV
=
int
(
self
.
num_kv_heads
),
PAGE_SIZE
=
self
.
page_size
,
sm_scale
=
self
.
scale
,
)
elif
context
.
do_compression
:
if
context
.
STORE_STREAM
is
not
None
:
torch
.
cuda
.
current_stream
().
wait_stream
(
context
.
STORE_STREAM
)
assert
seq_lens_copy
is
not
None
o
=
flash_prefill_from_paged
(
q
,
k
,
v
,
self
.
k_cache
,
self
.
v_cache
,
seq_lens_bh_before
=
seq_lens_copy
,
global_page_table
=
self
.
page_table
,
batch_mapping
=
batch_mapping
,
cu_seqlens_q
=
context
.
cu_seqlens_q
,
max_seqlen_q
=
context
.
max_seqlen_q
,
PAGE_SIZE
=
self
.
page_size
,
HKV
=
int
(
self
.
num_kv_heads
),
sm_scale
=
self
.
scale
,
)
else
:
o
=
flash_attn_varlen_func
(
q
,
k
,
v
,
max_seqlen_q
=
context
.
max_seqlen_q
,
cu_seqlens_q
=
context
.
cu_seqlens_q
,
max_seqlen_k
=
context
.
max_seqlen_k
,
cu_seqlens_k
=
context
.
cu_seqlens_k
,
softmax_scale
=
self
.
scale
,
causal
=
True
,
)
else
:
assert
self
.
k_cache
is
not
None
,
"KV Cache must be initialized for decoding"
decode_store_kv
(
key
=
k
,
value
=
v
,
batch_mapping
=
batch_mapping
,
bh_lens
=
seq_lens
,
page_table
=
self
.
page_table
,
k_cache
=
self
.
k_cache
,
v_cache
=
self
.
v_cache
,
PAGE_SIZE
=
self
.
page_size
,
)
if
use_fa_decode
:
assert
seq_lens
is
not
None
o
=
flash_decode_from_paged
(
q
,
self
.
k_cache
,
self
.
v_cache
,
seq_lens_bh
=
seq_lens
,
global_page_table
=
self
.
page_table
,
batch_mapping
=
batch_mapping
,
PAGE_SIZE
=
self
.
page_size
,
HKV
=
int
(
self
.
num_kv_heads
),
sm_scale
=
self
.
scale
,
)
else
:
o
=
head_sparse_decode_attention
(
q
,
self
.
k_cache
,
self
.
v_cache
,
seq_lens
,
self
.
page_table
,
batch_mapping
,
int
(
self
.
num_kv_heads
),
self
.
page_size
,
self
.
scale
,
key_split
=
context
.
key_split
,
)
if
self
.
bh_seq_lens
is
not
None
:
longbm
=
batch_mapping
.
to
(
device
=
self
.
bh_seq_lens
.
device
,
dtype
=
torch
.
long
)
maybe_execute_in_stream
(
self
.
bh_seq_lens
.
index_copy_
,
0
,
longbm
,
seq_lens
,
STORE_STREAM
=
context
.
STORE_STREAM
if
context
.
is_prefill
else
None
,
)
return
o
vllm/kvprune/layers/embed_head.py
0 → 100644
View file @
2b7160c6
import
torch
import
torch.distributed
as
dist
import
torch.nn.functional
as
F
from
vllm.kvprune.utils.context
import
get_context
from
vllm.kvprune.utils.tp_collectives
import
tensor_parallel_all_reduce
from
vllm.kvprune.utils.tp_utils
import
(
tensor_parallel_rank_for_sharding
,
tensor_parallel_world_size_for_sharding
,
)
from
torch
import
nn
class
VocabParallelEmbedding
(
nn
.
Module
):
def
__init__
(
self
,
num_embeddings
:
int
,
embedding_dim
:
int
,
):
super
().
__init__
()
self
.
tp_rank
=
tensor_parallel_rank_for_sharding
()
self
.
tp_size
=
tensor_parallel_world_size_for_sharding
()
assert
num_embeddings
%
self
.
tp_size
==
0
self
.
num_embeddings
=
num_embeddings
self
.
num_embeddings_per_partition
=
self
.
num_embeddings
//
self
.
tp_size
self
.
vocab_start_idx
=
self
.
num_embeddings_per_partition
*
self
.
tp_rank
self
.
vocab_end_idx
=
self
.
vocab_start_idx
+
self
.
num_embeddings_per_partition
self
.
weight
=
nn
.
Parameter
(
torch
.
empty
(
self
.
num_embeddings_per_partition
,
embedding_dim
)
)
self
.
weight
.
weight_loader
=
self
.
weight_loader
def
weight_loader
(
self
,
param
:
nn
.
Parameter
,
loaded_weight
:
torch
.
Tensor
):
param_data
=
param
.
data
shard_size
=
param_data
.
size
(
0
)
start_idx
=
self
.
tp_rank
*
shard_size
loaded_weight
=
loaded_weight
.
narrow
(
0
,
start_idx
,
shard_size
)
param_data
.
copy_
(
loaded_weight
)
def
forward
(
self
,
x
:
torch
.
Tensor
):
if
self
.
tp_size
>
1
:
mask
=
(
x
>=
self
.
vocab_start_idx
)
&
(
x
<
self
.
vocab_end_idx
)
x
=
mask
*
(
x
-
self
.
vocab_start_idx
)
y
=
F
.
embedding
(
x
,
self
.
weight
)
if
self
.
tp_size
>
1
:
y
=
mask
.
unsqueeze
(
1
)
*
y
tensor_parallel_all_reduce
(
y
)
return
y
class
ParallelLMHead
(
VocabParallelEmbedding
):
"""LM head with TP vocab sharding.
When embedded in a vLLM worker, logits must be gathered on the **tensor-
parallel** process group (see :func:`~vllm.distributed.communication_op.tensor_model_parallel_gather`),
not the default :func:`torch.distributed.gather` — otherwise shard order / group
mismatch yields garbage logits and decoded gibberish.
After gather, logits are truncated to ``org_vocab_size`` (HF tokenizer vocab),
matching :class:`~vllm.model_executor.layers.logits_processor.LogitsProcessor`
removal of padded vocabulary columns.
"""
def
__init__
(
self
,
num_embeddings
:
int
,
embedding_dim
:
int
,
bias
:
bool
=
False
,
*
,
org_vocab_size
:
int
|
None
=
None
,
):
assert
not
bias
super
().
__init__
(
num_embeddings
,
embedding_dim
)
# Original (unpadded) vocab size for logits truncation; defaults to num_embeddings.
self
.
org_vocab_size
=
(
int
(
org_vocab_size
)
if
org_vocab_size
is
not
None
else
num_embeddings
)
def
forward
(
self
,
x
:
torch
.
Tensor
):
context
=
get_context
()
if
context
.
is_prefill
:
cu
=
context
.
cu_seqlens_q
last_indices
=
(
cu
[
1
:]
-
1
).
to
(
torch
.
long
)
n_tok
=
x
.
shape
[
0
]
if
n_tok
>
0
:
last_indices
=
last_indices
.
clamp
(
min
=
0
,
max
=
n_tok
-
1
)
x
=
x
[
last_indices
].
contiguous
()
logits
=
F
.
linear
(
x
,
self
.
weight
)
if
self
.
tp_size
>
1
:
logits
=
self
.
_gather_logits_tp
(
logits
)
if
logits
is
not
None
and
logits
.
shape
[
-
1
]
>
self
.
org_vocab_size
:
logits
=
logits
[...,
:
self
.
org_vocab_size
]
return
logits
def
_gather_logits_tp
(
self
,
logits
:
torch
.
Tensor
)
->
torch
.
Tensor
|
None
:
try
:
from
vllm.distributed.parallel_state
import
model_parallel_is_initialized
from
vllm.distributed.communication_op
import
(
tensor_model_parallel_gather
,
)
if
model_parallel_is_initialized
():
return
tensor_model_parallel_gather
(
logits
,
dst
=
0
,
dim
=-
1
)
except
Exception
:
pass
all_logits
=
(
[
torch
.
empty_like
(
logits
)
for
_
in
range
(
self
.
tp_size
)]
if
self
.
tp_rank
==
0
else
None
)
dist
.
gather
(
logits
,
all_logits
,
0
)
return
torch
.
cat
(
all_logits
,
-
1
)
if
self
.
tp_rank
==
0
else
None
vllm/kvprune/layers/layernorm.py
0 → 100644
View file @
2b7160c6
import
torch
from
torch
import
nn
class
RMSNorm
(
nn
.
Module
):
def
__init__
(
self
,
hidden_size
:
int
,
eps
:
float
=
1e-6
,
)
->
None
:
super
().
__init__
()
self
.
eps
=
eps
self
.
weight
=
nn
.
Parameter
(
torch
.
ones
(
hidden_size
))
# @torch.compile
def
rms_forward
(
self
,
x
:
torch
.
Tensor
,
)
->
torch
.
Tensor
:
orig_dtype
=
x
.
dtype
x
=
x
.
float
()
var
=
x
.
pow
(
2
).
mean
(
dim
=-
1
,
keepdim
=
True
)
x
.
mul_
(
torch
.
rsqrt
(
var
+
self
.
eps
))
x
=
x
.
to
(
orig_dtype
).
mul_
(
self
.
weight
)
return
x
# @torch.compile
def
add_rms_forward
(
self
,
x
:
torch
.
Tensor
,
residual
:
torch
.
Tensor
,
)
->
tuple
[
torch
.
Tensor
,
torch
.
Tensor
]:
orig_dtype
=
x
.
dtype
x
=
x
.
float
().
add_
(
residual
.
float
())
residual
=
x
.
to
(
orig_dtype
)
var
=
x
.
pow
(
2
).
mean
(
dim
=-
1
,
keepdim
=
True
)
x
.
mul_
(
torch
.
rsqrt
(
var
+
self
.
eps
))
x
=
x
.
to
(
orig_dtype
).
mul_
(
self
.
weight
)
return
x
,
residual
def
forward
(
self
,
x
:
torch
.
Tensor
,
residual
:
torch
.
Tensor
|
None
=
None
,
)
->
torch
.
Tensor
|
tuple
[
torch
.
Tensor
,
torch
.
Tensor
]:
if
residual
is
None
:
return
self
.
rms_forward
(
x
)
else
:
return
self
.
add_rms_forward
(
x
,
residual
)
vllm/kvprune/layers/linear.py
0 → 100644
View file @
2b7160c6
import
torch
import
torch.distributed
as
dist
import
torch.nn.functional
as
F
from
vllm.kvprune.utils.tp_collectives
import
tensor_parallel_all_reduce
from
vllm.kvprune.utils.tp_utils
import
(
tensor_parallel_rank_for_sharding
,
tensor_parallel_world_size_for_sharding
,
)
from
torch
import
nn
def
divide
(
numerator
,
denominator
):
assert
numerator
%
denominator
==
0
return
numerator
//
denominator
class
LinearBase
(
nn
.
Module
):
def
__init__
(
self
,
input_size
:
int
,
output_size
:
int
,
bias
:
bool
=
False
,
tp_dim
:
int
|
None
=
None
,
):
super
().
__init__
()
self
.
tp_dim
=
tp_dim
self
.
tp_rank
=
tensor_parallel_rank_for_sharding
()
self
.
tp_size
=
tensor_parallel_world_size_for_sharding
()
self
.
weight
=
nn
.
Parameter
(
torch
.
empty
(
output_size
,
input_size
))
self
.
weight
.
weight_loader
=
self
.
weight_loader
if
bias
:
self
.
bias
=
nn
.
Parameter
(
torch
.
empty
(
output_size
))
self
.
bias
.
weight_loader
=
self
.
weight_loader
else
:
self
.
register_parameter
(
"bias"
,
None
)
def
forward
(
self
,
x
:
torch
.
Tensor
)
->
torch
.
Tensor
:
raise
NotImplementedError
class
ReplicatedLinear
(
LinearBase
):
def
__init__
(
self
,
input_size
:
int
,
output_size
:
int
,
bias
:
bool
=
False
,
):
super
().
__init__
(
input_size
,
output_size
,
bias
)
def
weight_loader
(
self
,
param
:
nn
.
Parameter
,
loaded_weight
:
torch
.
Tensor
):
param
.
data
.
copy_
(
loaded_weight
)
def
forward
(
self
,
x
:
torch
.
Tensor
)
->
torch
.
Tensor
:
return
F
.
linear
(
x
,
self
.
weight
,
self
.
bias
)
class
ColumnParallelLinear
(
LinearBase
):
def
__init__
(
self
,
input_size
:
int
,
output_size
:
int
,
bias
:
bool
=
False
,
):
tp_size
=
tensor_parallel_world_size_for_sharding
()
super
().
__init__
(
input_size
,
divide
(
output_size
,
tp_size
),
bias
,
0
)
def
weight_loader
(
self
,
param
:
nn
.
Parameter
,
loaded_weight
:
torch
.
Tensor
):
param_data
=
param
.
data
shard_size
=
param_data
.
size
(
self
.
tp_dim
)
start_idx
=
self
.
tp_rank
*
shard_size
loaded_weight
=
loaded_weight
.
narrow
(
self
.
tp_dim
,
start_idx
,
shard_size
)
param_data
.
copy_
(
loaded_weight
)
def
forward
(
self
,
x
:
torch
.
Tensor
)
->
torch
.
Tensor
:
return
F
.
linear
(
x
,
self
.
weight
,
self
.
bias
)
class
MergedColumnParallelLinear
(
ColumnParallelLinear
):
def
__init__
(
self
,
input_size
:
int
,
output_sizes
:
list
[
int
],
bias
:
bool
=
False
,
):
self
.
output_sizes
=
output_sizes
super
().
__init__
(
input_size
,
sum
(
output_sizes
),
bias
)
def
weight_loader
(
self
,
param
:
nn
.
Parameter
,
loaded_weight
:
torch
.
Tensor
,
loaded_shard_id
:
int
):
param_data
=
param
.
data
shard_offset
=
sum
(
self
.
output_sizes
[:
loaded_shard_id
])
//
self
.
tp_size
shard_size
=
self
.
output_sizes
[
loaded_shard_id
]
//
self
.
tp_size
param_data
=
param_data
.
narrow
(
self
.
tp_dim
,
shard_offset
,
shard_size
)
loaded_weight
=
loaded_weight
.
chunk
(
self
.
tp_size
,
self
.
tp_dim
)[
self
.
tp_rank
]
param_data
.
copy_
(
loaded_weight
)
class
QKVParallelLinear
(
ColumnParallelLinear
):
def
__init__
(
self
,
hidden_size
:
int
,
head_size
:
int
,
total_num_heads
:
int
,
total_num_kv_heads
:
int
|
None
=
None
,
bias
:
bool
=
False
,
):
tp_size
=
tensor_parallel_world_size_for_sharding
()
total_num_kv_heads
=
total_num_kv_heads
or
total_num_heads
self
.
head_size
=
head_size
self
.
num_heads
=
divide
(
total_num_heads
,
tp_size
)
self
.
num_kv_heads
=
divide
(
total_num_kv_heads
,
tp_size
)
output_size
=
(
total_num_heads
+
2
*
total_num_kv_heads
)
*
self
.
head_size
super
().
__init__
(
hidden_size
,
output_size
,
bias
)
def
weight_loader
(
self
,
param
:
nn
.
Parameter
,
loaded_weight
:
torch
.
Tensor
,
loaded_shard_id
:
str
):
param_data
=
param
.
data
assert
loaded_shard_id
in
[
"q"
,
"k"
,
"v"
]
if
loaded_shard_id
==
"q"
:
shard_size
=
self
.
num_heads
*
self
.
head_size
shard_offset
=
0
elif
loaded_shard_id
==
"k"
:
shard_size
=
self
.
num_kv_heads
*
self
.
head_size
shard_offset
=
self
.
num_heads
*
self
.
head_size
else
:
shard_size
=
self
.
num_kv_heads
*
self
.
head_size
shard_offset
=
(
self
.
num_heads
*
self
.
head_size
+
self
.
num_kv_heads
*
self
.
head_size
)
param_data
=
param_data
.
narrow
(
self
.
tp_dim
,
shard_offset
,
shard_size
)
loaded_weight
=
loaded_weight
.
chunk
(
self
.
tp_size
,
self
.
tp_dim
)[
self
.
tp_rank
]
param_data
.
copy_
(
loaded_weight
)
class
RowParallelLinear
(
LinearBase
):
def
__init__
(
self
,
input_size
:
int
,
output_size
:
int
,
bias
:
bool
=
False
,
):
tp_size
=
tensor_parallel_world_size_for_sharding
()
super
().
__init__
(
divide
(
input_size
,
tp_size
),
output_size
,
bias
,
1
)
def
weight_loader
(
self
,
param
:
nn
.
Parameter
,
loaded_weight
:
torch
.
Tensor
):
param_data
=
param
.
data
shard_size
=
param_data
.
size
(
self
.
tp_dim
)
start_idx
=
self
.
tp_rank
*
shard_size
loaded_weight
=
loaded_weight
.
narrow
(
self
.
tp_dim
,
start_idx
,
shard_size
)
param_data
.
copy_
(
loaded_weight
)
def
forward
(
self
,
x
:
torch
.
Tensor
)
->
torch
.
Tensor
:
y
=
F
.
linear
(
x
,
self
.
weight
,
self
.
bias
if
self
.
tp_rank
==
0
else
None
)
if
self
.
tp_size
>
1
:
tensor_parallel_all_reduce
(
y
)
return
y
vllm/kvprune/layers/moe.py
0 → 100644
View file @
2b7160c6
import
torch
import
torch.distributed
as
dist
from
vllm.kvprune.triton_kernels.matmul_ogs
import
matmul_ogs
from
vllm.kvprune.utils.tp_collectives
import
tensor_parallel_all_reduce
from
vllm.kvprune.utils.tp_utils
import
(
tensor_parallel_rank_for_sharding
,
tensor_parallel_world_size_for_sharding
,
)
from
torch
import
nn
def
divide
(
numerator
,
denominator
):
assert
numerator
%
denominator
==
0
return
numerator
//
denominator
class
TritonFusedMoeLinearBase
(
nn
.
Module
):
def
__init__
(
self
,
in_features
:
int
,
out_features
:
int
,
num_experts
:
int
,
bias
:
bool
=
False
,
tp_dim
:
int
|
None
=
None
,
)
->
None
:
super
().
__init__
()
self
.
tp_dim
=
tp_dim
self
.
tp_rank
=
tensor_parallel_rank_for_sharding
()
self
.
tp_size
=
tensor_parallel_world_size_for_sharding
()
self
.
in_features
=
in_features
self
.
out_features
=
out_features
self
.
num_experts
=
num_experts
self
.
weight
=
nn
.
Parameter
(
torch
.
empty
((
num_experts
,
in_features
,
out_features
)).
transpose
(
-
1
,
-
2
)
)
self
.
weight
.
weight_loader
=
self
.
weight_loader
if
bias
:
self
.
bias
=
nn
.
Parameter
(
torch
.
empty
((
num_experts
,
out_features
)))
self
.
bias
.
weight_loader
=
self
.
weight_loader
else
:
self
.
register_parameter
(
"bias"
,
None
)
def
forward
(
self
,
x
:
torch
.
Tensor
,
**
kwargs
)
->
torch
.
Tensor
:
raise
NotImplementedError
class
ReplicatedTritonFusedMoeLinear
(
TritonFusedMoeLinearBase
):
def
__init__
(
self
,
in_features
:
int
,
out_features
:
int
,
num_experts
:
int
,
bias
:
bool
=
False
,
)
->
None
:
super
().
__init__
(
in_features
,
out_features
,
num_experts
,
bias
)
def
weight_loader
(
self
,
param
:
nn
.
Parameter
,
loaded_weight
:
torch
.
Tensor
,
expert_idx
:
int
):
param
.
data
[
expert_idx
].
copy_
(
loaded_weight
,
non_blocking
=
True
)
def
forward
(
self
,
x
:
torch
.
Tensor
,
**
kwargs
)
->
torch
.
Tensor
:
w
=
self
.
weight
.
transpose
(
-
1
,
-
2
)
assert
w
.
is_contiguous
()
return
matmul_ogs
(
x
,
self
.
weight
,
self
.
bias
,
**
kwargs
,
)
class
RowParallelTritonFusedMoeLinear
(
TritonFusedMoeLinearBase
):
def
__init__
(
self
,
in_features
:
int
,
out_features
:
int
,
num_experts
:
int
,
bias
:
bool
=
False
,
)
->
None
:
tp_size
=
(
tensor_parallel_world_size_for_sharding
()
if
dist
.
is_initialized
()
else
1
)
super
().
__init__
(
divide
(
in_features
,
tp_size
),
out_features
,
num_experts
,
bias
,
2
)
def
weight_loader
(
self
,
param
:
nn
.
Parameter
,
loaded_weight
:
torch
.
Tensor
,
expert_idx
:
int
):
shard_size
=
param
.
size
(
2
)
start_idx
=
self
.
tp_rank
*
shard_size
local_shard
=
loaded_weight
[:,
start_idx
:
start_idx
+
shard_size
]
param
.
data
[
expert_idx
].
copy_
(
local_shard
,
non_blocking
=
True
)
def
forward
(
self
,
x
:
torch
.
Tensor
,
**
kwargs
)
->
torch
.
Tensor
:
w
=
self
.
weight
.
transpose
(
-
1
,
-
2
)
assert
w
.
is_contiguous
()
y
=
matmul_ogs
(
x
,
w
,
self
.
bias
,
**
kwargs
,
)
if
self
.
tp_size
>
1
:
tensor_parallel_all_reduce
(
y
)
return
y
class
ColumnParallelTritonFusedMoeLinear
(
TritonFusedMoeLinearBase
):
def
__init__
(
self
,
in_features
:
int
,
out_features
:
int
,
num_experts
:
int
,
bias
:
bool
=
False
,
)
->
None
:
tp_size
=
(
tensor_parallel_world_size_for_sharding
()
if
dist
.
is_initialized
()
else
1
)
super
().
__init__
(
in_features
,
divide
(
out_features
,
tp_size
),
num_experts
,
bias
,
1
)
def
weight_loader
(
self
,
param
:
nn
.
Parameter
,
loaded_weight
:
torch
.
Tensor
,
expert_idx
:
int
):
shard_size
=
param
.
size
(
1
)
start_idx
=
self
.
tp_rank
*
shard_size
local_shard
=
loaded_weight
[
start_idx
:
start_idx
+
shard_size
,
:]
param
.
data
[
expert_idx
].
copy_
(
local_shard
,
non_blocking
=
True
)
def
forward
(
self
,
x
:
torch
.
Tensor
,
**
kwargs
)
->
torch
.
Tensor
:
w
=
self
.
weight
.
transpose
(
-
1
,
-
2
)
assert
w
.
is_contiguous
()
y
=
matmul_ogs
(
x
,
w
,
self
.
bias
,
**
kwargs
,
)
return
y
class
MergedColumnParallelTritonFusedMoeLinear
(
ColumnParallelTritonFusedMoeLinear
):
def
__init__
(
self
,
in_features
:
int
,
out_feature_list
:
list
[
int
],
num_experts
:
int
,
bias
:
bool
=
False
,
):
self
.
out_feature_list
=
out_feature_list
super
().
__init__
(
in_features
,
sum
(
out_feature_list
),
num_experts
,
bias
)
def
weight_loader
(
self
,
param
:
nn
.
Parameter
,
loaded_weight
:
torch
.
Tensor
,
expert_idx
:
int
,
shard_id
:
int
,
):
param_data
=
param
.
data
shard_offset
=
sum
(
self
.
out_feature_list
[:
shard_id
])
//
self
.
tp_size
shard_size
=
self
.
out_feature_list
[
shard_id
]
//
self
.
tp_size
param_data
=
param_data
.
narrow
(
self
.
tp_dim
,
shard_offset
,
shard_size
)
local_weight
=
loaded_weight
.
chunk
(
self
.
tp_size
,
dim
=
self
.
tp_dim
-
1
)[
self
.
tp_rank
]
param_data
[
expert_idx
].
copy_
(
local_weight
,
non_blocking
=
True
)
vllm/kvprune/layers/rotary_embedding.py
0 → 100644
View file @
2b7160c6
import
math
from
functools
import
lru_cache
import
torch
from
torch
import
nn
def
apply_rotary_emb
(
x
:
torch
.
Tensor
,
cos
:
torch
.
Tensor
,
sin
:
torch
.
Tensor
,
)
->
torch
.
Tensor
:
x1
,
x2
=
torch
.
chunk
(
x
.
float
(),
2
,
dim
=-
1
)
y1
=
x1
*
cos
-
x2
*
sin
y2
=
x2
*
cos
+
x1
*
sin
return
torch
.
cat
((
y1
,
y2
),
dim
=-
1
).
to
(
x
.
dtype
)
class
RotaryEmbedding
(
nn
.
Module
):
def
__init__
(
self
,
head_size
:
int
,
rotary_dim
:
int
,
max_position_embeddings
:
int
,
base
:
float
,
rope_scaling
:
tuple
,
)
->
None
:
super
().
__init__
()
self
.
head_size
=
head_size
assert
rotary_dim
==
head_size
inv_freq
=
1.0
/
(
base
**
(
torch
.
arange
(
0
,
rotary_dim
,
2
,
dtype
=
torch
.
float
)
/
rotary_dim
)
)
if
rope_scaling
is
not
None
:
(
rope_type
,
factor
,
low_freq_factor
,
high_freq_factor
,
original_max_position_embeddings
,
)
=
rope_scaling
assert
rope_type
==
"llama3"
old_context_len
=
original_max_position_embeddings
low_freq_wavelen
=
old_context_len
/
low_freq_factor
high_freq_wavelen
=
old_context_len
/
high_freq_factor
wavelen
=
2
*
math
.
pi
/
inv_freq
inv_freq_llama
=
torch
.
where
(
wavelen
>
low_freq_wavelen
,
inv_freq
/
factor
,
inv_freq
)
smooth_factor
=
(
old_context_len
/
wavelen
-
low_freq_factor
)
/
(
high_freq_factor
-
low_freq_factor
)
smoothed_inv_freq
=
(
1
-
smooth_factor
)
*
inv_freq_llama
/
factor
+
smooth_factor
*
inv_freq_llama
is_medium_freq
=
~
(
wavelen
<
high_freq_wavelen
)
*
~
(
wavelen
>
low_freq_wavelen
)
inv_freq
=
torch
.
where
(
is_medium_freq
,
smoothed_inv_freq
,
inv_freq_llama
)
t
=
torch
.
arange
(
max_position_embeddings
,
dtype
=
torch
.
float
)
freqs
=
torch
.
einsum
(
"i,j -> ij"
,
t
,
inv_freq
)
cos
=
freqs
.
cos
()
sin
=
freqs
.
sin
()
cache
=
torch
.
cat
((
cos
,
sin
),
dim
=-
1
).
unsqueeze_
(
1
)
self
.
register_buffer
(
"cos_sin_cache"
,
cache
,
persistent
=
False
)
# @torch.compile
def
forward
(
self
,
positions
:
torch
.
Tensor
,
query
:
torch
.
Tensor
,
key
:
torch
.
Tensor
,
)
->
tuple
[
torch
.
Tensor
,
torch
.
Tensor
]:
cos_sin
=
self
.
cos_sin_cache
[
positions
]
cos
,
sin
=
cos_sin
.
chunk
(
2
,
dim
=-
1
)
query
=
apply_rotary_emb
(
query
,
cos
,
sin
)
key
=
apply_rotary_emb
(
key
,
cos
,
sin
)
return
query
,
key
@
lru_cache
(
1
)
def
get_rope
(
head_size
:
int
,
rotary_dim
:
int
,
max_position
:
int
,
base
:
float
,
rope_scaling
:
tuple
|
None
=
None
,
):
rotary_emb
=
RotaryEmbedding
(
head_size
,
rotary_dim
,
max_position
,
base
,
rope_scaling
)
return
rotary_emb
vllm/kvprune/layers/sampler.py
0 → 100644
View file @
2b7160c6
import
torch
from
torch
import
nn
class
Sampler
(
nn
.
Module
):
def
__init__
(
self
):
super
().
__init__
()
# @torch.compile
def
forward
(
self
,
logits
:
torch
.
Tensor
,
temperatures
:
torch
.
Tensor
):
temps
=
temperatures
.
view
(
-
1
)
scaled
=
logits
.
float
()
greedy_mask
=
temps
==
0.0
sample_mask
=
~
greedy_mask
if
sample_mask
.
any
():
temps_sample
=
temps
[
sample_mask
].
unsqueeze
(
-
1
)
# [B_sample, 1]
scaled_sample
=
scaled
[
sample_mask
].
div
(
temps_sample
)
# temperature scaling
E
=
torch
.
empty_like
(
scaled_sample
).
exponential_
(
1
).
clamp_min_
(
1e-10
).
log
()
scaled_sample
=
scaled_sample
-
E
scaled
=
scaled
.
clone
()
scaled
[
sample_mask
]
=
scaled_sample
return
scaled
.
argmax
(
dim
=-
1
)
vllm/kvprune/layers/triton_helpers.py
0 → 100644
View file @
2b7160c6
import
torch
import
triton
import
triton.language
as
tl
@
triton
.
jit
def
_masked_index_select_kernel
(
X_ptr
,
IDX_ptr
,
OUT_ptr
,
N
,
stride_xn
,
stride_xh
,
stride_ob
,
stride_oh
,
):
b
=
tl
.
program_id
(
0
)
# which output row (0..B-1)
h
=
tl
.
program_id
(
1
)
idx
=
tl
.
load
(
IDX_ptr
+
b
)
# int32
valid
=
(
idx
>=
0
)
&
(
idx
<
N
)
out_ptrs
=
OUT_ptr
+
b
*
stride_ob
+
h
*
stride_oh
if
not
valid
:
tl
.
store
(
out_ptrs
,
0
)
else
:
x_ptrs
=
X_ptr
+
idx
*
stride_xn
+
h
*
stride_xh
vals
=
tl
.
load
(
x_ptrs
)
tl
.
store
(
out_ptrs
,
vals
)
def
masked_index_select_triton_dim0
(
input
:
torch
.
Tensor
,
index
:
torch
.
Tensor
)
->
torch
.
Tensor
:
"""
X: [N, H] : contiguous in the H dimension
b_m: [B] int32/int64 on same device; out-of-range -> zeros)
Returns: [B, H]
"""
assert
input
.
ndim
==
2
and
index
.
ndim
==
1
N
,
H
=
input
.
shape
B
=
index
.
numel
()
out
=
torch
.
empty
((
B
,
H
),
dtype
=
input
.
dtype
,
device
=
input
.
device
)
_masked_index_select_kernel
[(
B
,
H
)](
input
,
index
,
out
,
N
,
input
.
stride
(
0
),
input
.
stride
(
1
),
out
.
stride
(
0
),
out
.
stride
(
1
),
)
return
out
@
triton
.
jit
def
_masked_index_copy_kernel
(
DST_ptr
,
IDX_ptr
,
SRC_ptr
,
N
,
stride_dn
,
stride_dh
,
stride_sb
,
stride_sh
,
):
b
=
tl
.
program_id
(
0
)
h
=
tl
.
program_id
(
1
)
idx
=
tl
.
load
(
IDX_ptr
+
b
)
valid
=
(
idx
>=
0
)
&
(
idx
<
N
)
if
valid
:
src_ptrs
=
SRC_ptr
+
b
*
stride_sb
+
h
*
stride_sh
dst_ptrs
=
DST_ptr
+
idx
*
stride_dn
+
h
*
stride_dh
tl
.
store
(
dst_ptrs
,
tl
.
load
(
src_ptrs
))
def
masked_index_copy_triton_dim0
(
dst
:
torch
.
Tensor
,
index
:
torch
.
Tensor
,
src
:
torch
.
Tensor
):
"""
In-place: dst.index_copy_(0, index, src) but masked:
- rows with index[b] < 0 or >= dst.shape[0] are skipped (no write).
Shapes:
dst: [N, H]
src: [B, H]
index: [B]
"""
assert
dst
.
ndim
==
2
and
src
.
ndim
==
2
and
index
.
ndim
==
1
N
,
H
=
dst
.
shape
B
,
Hs
=
src
.
shape
assert
Hs
==
H
and
index
.
numel
()
==
B
_masked_index_copy_kernel
[(
B
,
H
)](
dst
,
index
,
src
,
N
,
dst
.
stride
(
0
),
dst
.
stride
(
1
),
src
.
stride
(
0
),
src
.
stride
(
1
),
)
Prev
1
…
6
7
8
9
10
11
12
13
14
…
16
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment