Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
2b7160c6
Commit
2b7160c6
authored
Apr 23, 2026
by
chenzk
Browse files
vllm kvprune:v1.0.0
parent
fa718036
Changes
305
Hide whitespace changes
Inline
Side-by-side
Showing
5 changed files
with
572 additions
and
0 deletions
+572
-0
vllm/kvprune_legacy_save/integration/config_adapter.py
vllm/kvprune_legacy_save/integration/config_adapter.py
+116
-0
vllm/kvprune_legacy_save/integration/v1_tp_runner.py
vllm/kvprune_legacy_save/integration/v1_tp_runner.py
+203
-0
vllm/kvprune_legacy_save/integration/vllm_model_access.py
vllm/kvprune_legacy_save/integration/vllm_model_access.py
+46
-0
vllm/kvprune_legacy_save/integration/weight_tie.py
vllm/kvprune_legacy_save/integration/weight_tie.py
+192
-0
vllm/kvprune_legacy_save/kv_cache/__init__.py
vllm/kvprune_legacy_save/kv_cache/__init__.py
+15
-0
No files found.
Too many changes to show.
To preserve performance only
305 of 305+
files are displayed.
Plain diff
Email patch
vllm/kvprune_legacy_save/integration/config_adapter.py
0 → 100644
View file @
2b7160c6
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""Build :class:`vllm.kvprune.config.engine_config.LLMConfig` from :class:`VllmConfig`."""
from
__future__
import
annotations
import
os
from
pathlib
import
Path
from
vllm.config
import
VllmConfig
from
vllm.kvprune.config.engine_config
import
LLMConfig
,
KvpruneAttentionSchedule
from
vllm.logger
import
init_logger
# Module-level logger for this adapter, created through vLLM's ``init_logger``.
logger = init_logger(__name__)
def _attention_schedule_from_env() -> KvpruneAttentionSchedule:
    """Resolve :class:`KvpruneAttentionSchedule` from environment variables.

    Primary (``VLLM_KVPRUNE_ATTENTION_SCHEDULE``), compared after stripping and
    lower-casing:

    - ``fa_triton`` — FA prefill, Triton decode (default). Aliases: ``fa_prefill``,
      ``default``.
    - ``pdtriton`` — Triton prefill + Triton decode. Aliases: ``triton``,
      ``triton_prefill``, ``compactor_prefill``, ``pd_triton``.
    - ``pdfa`` — FA prefill + FA decode (KV stores still Triton). Aliases:
      ``fa_full``, ``fa_both``.

    An unrecognised non-empty primary value logs a warning and falls back to the
    default schedule.

    Legacy (``VLLM_KVPRUNE_ATTENTION_BACKEND``) is consulted only when the primary
    variable is unset or blank: ``flash``/``fa``/``flash_attention``/
    ``flashattention`` → ``fa_triton``; ``compactor``/``triton``/
    ``compactor_triton`` → ``pdtriton``. An unrecognised legacy value logs a
    warning and falls back to the default.

    When both variables are unset or blank the default ``fa_triton`` schedule
    (``FA_PREFILL_TRITON_DECODE``) is returned.
    """
    default_schedule = KvpruneAttentionSchedule.FA_PREFILL_TRITON_DECODE

    schedule = os.environ.get("VLLM_KVPRUNE_ATTENTION_SCHEDULE", "").strip().lower()
    if schedule:
        if schedule in ("fa_triton", "fa_prefill", "default"):
            return default_schedule
        if schedule in (
            "pdtriton",
            "pd_triton",
            "triton",
            "triton_prefill",
            "compactor_prefill",
        ):
            return KvpruneAttentionSchedule.TRITON_PREFILL_TRITON_DECODE
        if schedule in ("pdfa", "fa_full", "fa_both"):
            return KvpruneAttentionSchedule.PDFA
        logger.warning(
            "Unknown VLLM_KVPRUNE_ATTENTION_SCHEDULE=%r; using FA_PREFILL_TRITON_DECODE",
            schedule,
        )
        return default_schedule

    # The primary variable is blank, so the legacy variable gets a say. Treating
    # the blank primary value as an alias of the default would return before the
    # legacy variable is ever read, which makes the legacy mapping dead code.
    backend = os.environ.get("VLLM_KVPRUNE_ATTENTION_BACKEND", "").strip().lower()
    if not backend:
        # Nothing configured anywhere: keep the documented default.
        return default_schedule
    if backend in ("flash", "fa", "flash_attention", "flashattention"):
        return default_schedule
    if backend in ("compactor", "triton", "compactor_triton"):
        return KvpruneAttentionSchedule.TRITON_PREFILL_TRITON_DECODE
    logger.warning(
        "Unknown VLLM_KVPRUNE_ATTENTION_BACKEND=%r; using FA_PREFILL_TRITON_DECODE",
        backend,
    )
    return default_schedule
def
_compactor_kvcache_page_size
(
vllm_block_size
:
int
|
None
)
->
int
:
"""Tokens per physical KV page for compactor :class:`LLMConfig`.
vLLM ``block_size`` is often 16; compactor ``head_sparse_decode_attention`` requires
``PAGE_SIZE % 32 == 0`` (see ``kvprune/attention/sparse_decode_kernel.py``). Standalone
compactor-vllm defaults to 128. Round up to the next multiple of 32 when needed.
"""
if
vllm_block_size
is
None
:
return
128
bs
=
int
(
vllm_block_size
)
if
bs
<=
0
:
return
128
if
bs
%
32
==
0
:
return
bs
return
((
bs
+
31
)
//
32
)
*
32
def vllm_config_to_llm_config(vc: VllmConfig) -> LLMConfig:
    """Translate a vLLM engine config into a compactor :class:`LLMConfig`.

    Values taken from ``vc``: model name, ``max_model_len``, ``hf_config``,
    ``gpu_memory_utilization``, ``tensor_parallel_size``, ``max_num_seqs``
    (256 when the scheduler config lacks it) and the cache ``block_size``
    (16 when ``None``), which is converted by
    :func:`_compactor_kvcache_page_size`.

    Values fixed here: ``nccl_port=1218``, ``enforce_eager=False``, ``eos=-1``,
    ``eos_token_ids=None``, ``leverage_sketch_size=48`` and
    ``attention_backend=None``. The attention schedule comes from
    :func:`_attention_schedule_from_env`.

    Args:
        vc: The vLLM engine configuration.

    Returns:
        A freshly built :class:`LLMConfig`.
    """
    model_cfg = vc.model_config
    cache_cfg = vc.cache_config
    parallel_cfg = vc.parallel_config
    scheduler_cfg = vc.scheduler_config

    vllm_block = cache_cfg.block_size
    if vllm_block is None:
        vllm_block = 16
    seq_limit = getattr(scheduler_cfg, "max_num_seqs", 256)

    # ``model_config.enforce_eager`` (v1) is deliberately NOT forwarded into the
    # compactor :class:`LLMConfig`; the two flags are independent. v1 uses its
    # flag only to skip *v1* ``capture_model()``, whereas kvprune
    # :class:`~vllm.kvprune.core.model_runner.ModelRunner` uses
    # :attr:`LLMConfig.enforce_eager` only for *compactor* decode CUDA graphs.
    # Shared-weight setup in ``compactor_shared`` defaults compactor to eager
    # decode; see ``VLLM_KVPRUNE_COMPACTOR_CUDA_GRAPH`` (default try graphs) /
    # ``VLLM_KVPRUNE_COMPACTOR_ENFORCE_EAGER``.

    # When the model string names a local checkpoint directory (one containing
    # ``config.json``), forward its resolved path so compactor skips redundant
    # Hub fetches. Filesystem errors simply leave the path unset.
    model_name = str(model_cfg.model)
    local_checkpoint: str | None = None
    try:
        if model_name:
            candidate = Path(model_name)
            if candidate.is_dir() and (candidate / "config.json").is_file():
                local_checkpoint = str(candidate.resolve())
    except OSError:
        pass

    return LLMConfig(
        model=model_name,
        path=local_checkpoint,
        nccl_port=1218,
        max_num_seqs=seq_limit,
        max_model_len=model_cfg.max_model_len,
        gpu_memory_utilization=cache_cfg.gpu_memory_utilization,
        tensor_parallel_size=parallel_cfg.tensor_parallel_size,
        enforce_eager=False,
        hf_config=model_cfg.hf_config,
        eos=-1,
        eos_token_ids=None,
        kvcache_page_size=_compactor_kvcache_page_size(vllm_block),
        leverage_sketch_size=48,
        attention_schedule=_attention_schedule_from_env(),
        attention_backend=None,
    )
vllm/kvprune_legacy_save/integration/v1_tp_runner.py
0 → 100644
View file @
2b7160c6
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""TP>1: one kvprune :class:`~vllm.kvprune.core.model_runner.ModelRunner` per vLLM worker.
Invoked via v1 ``collective_rpc("kvprune_v1_compressed_generate", ...)`` so every tensor-
parallel rank participates in the same compactor forward/broadcast sequence as the
standalone multi-process compactor.
Compactor decode CUDA graphs (when not ``enforce_eager``) capture the full decode step
including ``compute_logits``. To force eager on embedded TP workers, set
``VLLM_KVPRUNE_TP_EMBEDDED_GRAPH=0`` or ``VLLM_KVPRUNE_COMPACTOR_ENFORCE_EAGER=1``.
Peer/master session boundaries use TP-group ``broadcast``/``barrier`` (see
``ModelRunner.maybe_release_peers``), not ``multiprocessing.Event`` — RPC payloads must
be picklable across worker processes.
"""
from
__future__
import
annotations
import
os
from
typing
import
Any
import
torch
import
torch.nn
as
nn
from
vllm.kvprune.compression.compression_config
import
(
BatchCompressionParams
,
SequenceCompressionParams
,
)
from
vllm.kvprune.config.sampling_params
import
SamplingParams
as
CompactorSamplingParams
from
vllm.kvprune.core.compression_bridge
import
(
compression_method_id_to_enum
,
compression_method_str_to_id
,
)
from
vllm.kvprune.core.model_runner
import
ModelRunner
from
vllm.kvprune.integration.config_adapter
import
vllm_config_to_llm_config
from
vllm.kvprune.utils.kv_dist
import
barrier_sync
from
vllm.kvprune.integration.weight_tie
import
(
delegate_kvprune_compute_logits_to_vllm
,
delegate_kvprune_embed_tokens_to_vllm
,
tie_kvprune_rope_buffers_from_vllm
,
tie_kvprune_weights_from_vllm
,
)
from
vllm.kvprune.models
import
MODEL_REGISTRY
from
vllm.kvprune.utils.sequence
import
Sequence
# Attribute name under which the per-worker kvprune ``ModelRunner`` is cached on
# the vLLM worker object (read and written by ``_get_or_create_runner``).
_ATTR = "_kvprune_tp_embedded_runner"
def
_apply_compactor_env_overrides
(
cfg
:
Any
)
->
None
:
"""Match :func:`~vllm.kvprune.integration.compactor_shared.create_compactor_engine_with_shared_weights` caps."""
_cap
=
os
.
environ
.
get
(
"VLLM_KVPRUNE_COMPACTOR_MAX_NUM_SEQS"
,
"32"
).
strip
()
if
_cap
:
lim
=
int
(
_cap
)
if
lim
>
0
:
cfg
.
max_num_seqs
=
min
(
cfg
.
max_num_seqs
,
lim
)
_ce
=
os
.
environ
.
get
(
"VLLM_KVPRUNE_COMPACTOR_ENFORCE_EAGER"
,
""
).
strip
().
lower
()
if
_ce
in
(
"1"
,
"true"
,
"yes"
):
cfg
.
enforce_eager
=
True
elif
_ce
in
(
"0"
,
"false"
,
"no"
):
cfg
.
enforce_eager
=
False
else
:
_dg
=
os
.
environ
.
get
(
"VLLM_KVPRUNE_COMPACTOR_CUDA_GRAPH"
,
"1"
).
strip
().
lower
()
cfg
.
enforce_eager
=
_dg
in
(
"0"
,
"false"
,
"no"
)
def _build_sequences(payload: dict[str, Any]) -> list[Sequence]:
    """Turn an RPC payload into one compactor :class:`Sequence` per prompt.

    The payload supplies three parallel lists under ``"prompt_token_ids"``,
    ``"sampling_params"`` and ``"compression_params"``; entry ``i`` of each
    describes prompt ``i``. ``protected_first_tokens`` and
    ``protected_last_tokens`` default to 16 and 64 when absent.

    When the protected head and tail together cover the whole prompt, the
    sequence's ``compression_ratio`` is forced to ``1.0``.

    Args:
        payload: Dictionary carrying the three lists described above.

    Returns:
        The sequences, in prompt order.
    """
    prompts: list[list[int]] = payload["prompt_token_ids"]
    sampling_specs: list[dict[str, Any]] = payload["sampling_params"]
    compression_specs: list[dict[str, Any]] = payload["compression_params"]

    built: list[Sequence] = []
    for idx, token_ids in enumerate(prompts):
        sampling_spec = sampling_specs[idx]
        sampling = CompactorSamplingParams(
            temperature=float(sampling_spec["temperature"]),
            max_new_tokens=int(sampling_spec["max_new_tokens"]),
        )

        compression_spec = compression_specs[idx]
        compression = SequenceCompressionParams(
            compression_ratio=float(compression_spec["compression_ratio"]),
            protected_first_tokens=int(
                compression_spec.get("protected_first_tokens", 16)
            ),
            protected_last_tokens=int(
                compression_spec.get("protected_last_tokens", 64)
            ),
        )

        # Nothing is left to prune once the protected spans cover the prompt.
        protected_total = (
            compression.protected_first_tokens + compression.protected_last_tokens
        )
        if protected_total >= len(token_ids):
            compression.compression_ratio = 1.0

        built.append(
            Sequence(
                prompt_token_ids=list(token_ids),
                sampling_params=sampling,
                compression_params=compression,
            )
        )
    return built
def _batch_compression_from_payload(payload: dict[str, Any]) -> BatchCompressionParams:
    """Derive batch-level compression settings from the payload.

    The first entry of ``payload["compression_params"]`` whose
    ``compression_ratio`` is below ``1.0`` decides the batch's compression
    method (``"none"`` when that entry has no ``compression_method`` key).
    When no entry compresses, a default :class:`BatchCompressionParams` is
    returned.

    Args:
        payload: Dictionary carrying a ``"compression_params"`` list of dicts.

    Returns:
        The batch compression parameters.
    """
    first_compressing = next(
        (
            spec
            for spec in payload["compression_params"]
            if float(spec["compression_ratio"]) < 1.0
        ),
        None,
    )
    if first_compressing is None:
        return BatchCompressionParams()

    method_name = str(first_compressing.get("compression_method", "none"))
    method_id = compression_method_str_to_id(method_name)
    return BatchCompressionParams(
        compression_method=compression_method_id_to_enum(method_id)
    )
def _get_or_create_runner(worker: Any, payload: dict[str, Any]) -> ModelRunner:
    """Return this worker's kvprune :class:`ModelRunner`, building it on first use.

    The runner is cached on ``worker`` under the ``_ATTR`` attribute, so later
    calls return the cached instance and ignore ``payload``.

    On first use this validates the parallel layout, instantiates the matching
    kvprune model from ``MODEL_REGISTRY``, ties its weights and RoPE buffers to
    the worker's vLLM model, delegates embedding and logits computation to
    vLLM, and constructs the runner for the current tensor-parallel rank.

    Args:
        worker: vLLM worker exposing ``vllm_config`` and ``get_model()``.
        payload: RPC payload; ``payload["eos_token_ids"]`` seeds the EOS
            settings of the compactor config.

    Returns:
        The cached or newly created runner.

    Raises:
        NotImplementedError: If pipeline or data parallel size is not 1.
        RuntimeError: If the live TP world size disagrees with the config.
        ValueError: If the HF ``model_type`` is not in ``MODEL_REGISTRY``.
    """
    cached = getattr(worker, _ATTR, None)
    if cached is not None:
        return cached

    # Imported lazily, only when a runner actually has to be built.
    from vllm.distributed.parallel_state import (
        get_tensor_model_parallel_rank,
        get_tensor_model_parallel_world_size,
    )

    vllm_cfg = worker.vllm_config
    parallel = vllm_cfg.parallel_config
    if parallel.pipeline_parallel_size != 1 or parallel.data_parallel_size != 1:
        raise NotImplementedError(
            "KV-prune TP compressed generate requires pipeline_parallel_size=1 and "
            f"data_parallel_size=1; got PP={parallel.pipeline_parallel_size}, "
            f"DP={parallel.data_parallel_size}."
        )

    live_tp_size = get_tensor_model_parallel_world_size()
    if live_tp_size != parallel.tensor_parallel_size:
        raise RuntimeError(
            f"parallel_state TP world size {live_tp_size} != config.tensor_parallel_size "
            f"{parallel.tensor_parallel_size}"
        )

    hf_config = vllm_cfg.model_config.hf_config
    architecture = getattr(hf_config, "model_type", None)
    if architecture not in MODEL_REGISTRY:
        raise ValueError(
            f"KV-prune TP path: unsupported model_type={architecture!r}; "
            f"registry has {sorted(MODEL_REGISTRY)}"
        )

    llm_cfg = vllm_config_to_llm_config(vllm_cfg)
    # De-duplicate and order the EOS ids; the smallest becomes the primary EOS.
    llm_cfg.eos_token_ids = sorted({int(tok) for tok in payload["eos_token_ids"]})
    llm_cfg.eos = int(llm_cfg.eos_token_ids[0])
    _apply_compactor_env_overrides(llm_cfg)

    host_model = worker.get_model()
    compactor_model: nn.Module = MODEL_REGISTRY[architecture](hf_config)

    # Order matters: alias parameters first, then move whatever is left to the
    # host model's device/dtype, then copy RoPE buffers into the moved module.
    tie_kvprune_weights_from_vllm(host_model, compactor_model)
    reference_param = next(host_model.parameters())
    compactor_model.to(device=reference_param.device, dtype=reference_param.dtype)
    tie_kvprune_rope_buffers_from_vllm(host_model, compactor_model)
    delegate_kvprune_embed_tokens_to_vllm(host_model, compactor_model)
    delegate_kvprune_compute_logits_to_vllm(host_model, compactor_model)

    rank = get_tensor_model_parallel_rank()
    cuda_device = torch.device(f"cuda:{torch.cuda.current_device()}")
    shared_kwargs = {
        "external_model": compactor_model,
        "embedded_in_vllm_worker": True,
        "device": cuda_device,
    }
    if rank == 0:
        # Rank 0 is constructed with an empty ``peer_events`` list.
        runner = ModelRunner(llm_cfg, rank=0, peer_events=[], **shared_kwargs)
    else:
        # Other ranks are constructed with ``batch_ready=None``.
        runner = ModelRunner(llm_cfg, rank=rank, batch_ready=None, **shared_kwargs)

    setattr(worker, _ATTR, runner)
    return runner
def run_kvprune_tp_compressed_generate(
    worker: Any, payload: dict[str, Any]
) -> dict[str, Any]:
    """Run one compressed-generation session on this worker (called on every TP rank).

    Every rank obtains its runner, builds the sequences and batch compression
    parameters from ``payload``, and waits on a TP-group barrier. Rank 0 then
    drives ``runner.generate``; all other ranks enter ``runner.run_peer_session``.

    Args:
        worker: The vLLM worker this call executes on.
        payload: RPC payload describing prompts, sampling and compression.

    Returns:
        On rank 0: ``tensor_parallel_rank`` (0), ``prompt_token_ids`` and
        ``completion_token_ids`` (one list per sequence). On other ranks:
        ``tensor_parallel_rank`` and ``ok=True``.
    """
    from vllm.distributed.parallel_state import get_tensor_model_parallel_rank

    rank = get_tensor_model_parallel_rank()
    runner = _get_or_create_runner(worker, payload)
    batch = _build_sequences(payload)
    batch_compression = _batch_compression_from_payload(payload)

    # Line every TP rank up before the session starts.
    barrier_sync(use_tp_group=True)

    if rank != 0:
        runner.run_peer_session()
        return {"tensor_parallel_rank": int(rank), "ok": True}

    runner.generate(batch, batch_compression)
    prompts_out = [list(seq.prompt_token_ids) for seq in batch]
    completions_out = [list(seq.completion_token_ids) for seq in batch]
    return {
        "tensor_parallel_rank": 0,
        "prompt_token_ids": prompts_out,
        "completion_token_ids": completions_out,
    }
vllm/kvprune_legacy_save/integration/vllm_model_access.py
0 → 100644
View file @
2b7160c6
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""Access the in-process vLLM model weights for compactor weight sharing."""
from
__future__
import
annotations
import
torch.nn
as
nn
from
vllm.logger
import
init_logger
# Module-level logger for the model-access helpers, created through vLLM's
# ``init_logger``.
logger = init_logger(__name__)
def extract_vllm_causal_lm(llm: object) -> nn.Module:
    """Fetch the root model module from an in-process v1 ``LLM``.

    Walks ``llm.llm_engine.model_executor.driver_worker.worker`` and returns the
    result of that worker's ``get_model()``.

    Requires ``LLMEngine`` to have been constructed with
    ``multiprocess_mode=False`` so ``model_executor`` lives in-process (set
    ``VLLM_ENABLE_V1_MULTIPROCESSING=0``).

    Args:
        llm: Object exposing ``llm_engine`` (e.g. ``vllm.LLM``).

    Returns:
        Whatever the worker's ``get_model()`` returns.

    Raises:
        RuntimeError: If any link of the attribute chain is missing or ``None``,
            or if the worker has no callable ``get_model``.
    """
    # Each hop pairs the attribute to follow with the error raised when it is
    # absent (or ``None``) on the current object.
    hops = (
        (
            "llm_engine",
            "Expected an object with a ``llm_engine`` attribute (e.g. ``vllm.LLM``).",
        ),
        (
            "model_executor",
            "model_executor is unavailable (multiprocess engine mode). "
            "Set environment variable VLLM_ENABLE_V1_MULTIPROCESSING=0 for "
            "in-process weight sharing.",
        ),
        (
            "driver_worker",
            "Executor has no driver_worker (unexpected executor type for weight sharing).",
        ),
        ("worker", "Worker wrapper has no worker loaded."),
    )

    current = llm
    for attribute, failure in hops:
        current = getattr(current, attribute, None)
        if current is None:
            raise RuntimeError(failure)

    model_getter = getattr(current, "get_model", None)
    if not callable(model_getter):
        raise RuntimeError("Worker does not expose get_model().")
    return model_getter()
vllm/kvprune_legacy_save/integration/weight_tie.py
0 → 100644
View file @
2b7160c6
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""Share vLLM parameter storage with compactor ``MODEL_REGISTRY`` models (TP=1)."""
from
__future__
import
annotations
import
types
import
torch
import
torch.nn
as
nn
from
vllm.kvprune.utils.context
import
get_context
from
vllm.logger
import
init_logger
# Module-level logger for the weight-tying helpers, created through vLLM's
# ``init_logger``.
logger = init_logger(__name__)
def tie_kvprune_weights_from_vllm(
    vllm_model: nn.Module,
    kvprune_model: nn.Module,
    *,
    strict: bool = True,
) -> int:
    """Alias compactor parameters onto the vLLM tensors that share their names.

    For every parameter of ``kvprune_model`` whose name also exists in
    ``vllm_model``, the parameter's ``.data`` is re-pointed at the vLLM
    parameter's ``.data``, so both models use the same storage.

    Args:
        vllm_model: Model returned by ``worker.get_model()`` (e.g.
            ``Qwen3ForCausalLM``).
        kvprune_model: Instance from ``vllm.kvprune.models.MODEL_REGISTRY``.
        strict: If True, raise when a ``kvprune`` parameter name is missing from
            ``vllm_model``. If False, such parameters are skipped.

    Returns:
        The number of parameters tied.

    Raises:
        ValueError: If a name is missing and ``strict`` is set; if a
            same-named pair has different shapes (raised regardless of
            ``strict``); or if nothing was tied at all.
    """
    host_params = dict(vllm_model.named_parameters())
    compactor_params = dict(kvprune_model.named_parameters())

    tied_count = 0
    for param_name, target in compactor_params.items():
        if param_name not in host_params:
            if not strict:
                continue
            raise ValueError(
                f"kvprune parameter {param_name!r} not found in vLLM model; "
                "architecture/layout may differ (disable strict tying only "
                "for expert debugging)."
            )

        source = host_params[param_name]
        if source.shape != target.shape:
            raise ValueError(
                f"Shape mismatch for {param_name}: vllm {source.shape} vs kvprune {target.shape}"
            )

        # Share storage rather than copy: both parameters now view one tensor.
        target.data = source.data
        tied_count += 1

    if not tied_count:
        raise ValueError(
            "No parameters were tied — check that vLLM and kvprune model types match "
            "and use the same state_dict names."
        )

    logger.info(
        "Tied %d parameters from vLLM into compactor model (shared storage).",
        tied_count,
    )
    return tied_count
def tie_kvprune_rope_buffers_from_vllm(
    vllm_model: nn.Module,
    kvprune_model: nn.Module,
) -> int:
    """Overwrite kvprune's RoPE ``cos_sin_cache`` buffers with vLLM's values.

    :func:`tie_kvprune_weights_from_vllm` only aliases
    :class:`~torch.nn.Parameter` tensors, whereas RoPE tables live in buffers.
    kvprune's simplified ``RotaryEmbedding`` can disagree with vLLM's
    ``rope_parameters`` (YaRN, etc.), so copying ``cos_sin_cache`` after
    ``.to(device, dtype)`` keeps Q/K rotation aligned with the main model.

    Two layouts are accepted per buffer: identical shapes, or a 3-D kvprune
    buffer ``[max_len, 1, rotary_dim]`` paired with a 2-D vLLM buffer
    ``[max_len, rotary_dim]`` (the vLLM tensor is ``unsqueeze(1)``-ed before
    the copy). kvprune buffers with no same-named vLLM buffer are left alone
    and a warning is logged.

    Args:
        vllm_model: Source model whose buffers are read.
        kvprune_model: Destination model whose buffers are written in place.

    Returns:
        The number of buffers copied.

    Raises:
        ValueError: If a matched pair has an incompatible shape or layout.
    """
    host_buffers = dict(vllm_model.named_buffers())

    copied = 0
    for buffer_name, destination in kvprune_model.named_buffers():
        if "cos_sin_cache" not in buffer_name:
            continue
        if buffer_name not in host_buffers:
            logger.warning(
                "kvprune RoPE buffer %r not found in vLLM; leaving kvprune cache",
                buffer_name,
            )
            continue

        origin = host_buffers[buffer_name]
        if origin.shape == destination.shape:
            to_copy = origin
        elif destination.dim() == 3 and origin.dim() == 2:
            # vLLM lacks kvprune's singleton middle axis; the outer axes must
            # agree before the missing axis can be inserted.
            layouts_agree = (
                destination.shape[0] == origin.shape[0]
                and destination.shape[2] == origin.shape[1]
                and destination.shape[1] == 1
            )
            if not layouts_agree:
                raise ValueError(
                    f"cos_sin_cache shape mismatch for {buffer_name!r}: "
                    f"vLLM {tuple(origin.shape)} vs kvprune {tuple(destination.shape)}"
                )
            to_copy = origin.unsqueeze(1)
        else:
            raise ValueError(
                f"Unsupported cos_sin_cache layout for {buffer_name!r}: "
                f"vLLM {tuple(origin.shape)} vs kvprune {tuple(destination.shape)}"
            )

        destination.copy_(to_copy)
        copied += 1

    if copied:
        logger.info(
            "Copied %d RoPE cos_sin_cache buffer(s) from vLLM into kvprune model.",
            copied,
        )
    return copied
def delegate_kvprune_embed_tokens_to_vllm(
    vllm_model: nn.Module,
    kvprune_model: nn.Module,
) -> bool:
    """Make kvprune's ``model.embed_tokens`` forward call vLLM's embedding module.

    Even with tied weights, kvprune's simplified contiguous
    ``VocabParallelEmbedding`` (``vocab_start = rank * partition``) can disagree
    with vLLM's padded vocabulary and org/added shard ranges, producing invalid
    indices for ``F.embedding`` on non-zero TP ranks (``index_copy_`` /
    device-side assert). Routing the forward through vLLM's embedding module
    keeps masks and indices aligned with the main model while parameters remain
    shared storage.

    Args:
        vllm_model: Model expected to expose ``model.embed_tokens``.
        kvprune_model: Model whose ``model.embed_tokens.forward`` is replaced.

    Returns:
        ``True`` when the forward was replaced; ``False`` when either model has
        no ``model`` attribute or either ``embed_tokens`` is missing.
    """
    if not (hasattr(vllm_model, "model") and hasattr(kvprune_model, "model")):
        return False

    host_embedding = getattr(vllm_model.model, "embed_tokens", None)
    compactor_embedding = getattr(kvprune_model.model, "embed_tokens", None)
    if host_embedding is None or compactor_embedding is None:
        logger.warning(
            "delegate_kvprune_embed_tokens_to_vllm: embed_tokens missing; skipped"
        )
        return False

    def _embed_via_vllm(_bound_module: nn.Module, token_ids):
        # The bound kvprune module is ignored; vLLM's embedding does the work.
        return host_embedding(token_ids)

    # Bind on the instance so only this module's forward is affected.
    compactor_embedding.forward = types.MethodType(
        _embed_via_vllm, compactor_embedding
    )
    logger.info(
        "kvprune model.embed_tokens forward delegated to vLLM (correct vocab-parallel masks)."
    )
    return True
def delegate_kvprune_compute_logits_to_vllm(
    vllm_model: nn.Module,
    kvprune_model: nn.Module,
) -> bool:
    """Replace ``kvprune_model.compute_logits`` with a wrapper around vLLM's.

    Standalone compactor used
    :class:`~vllm.kvprune.layers.embed_head.ParallelLMHead` with ``F.linear`` +
    TP gather, while vLLM applies
    :class:`~vllm.model_executor.layers.logits_processor.LogitsProcessor`
    (gather/all-gather, padded-vocab trim, quant hooks). A mismatch here
    commonly produces garbage token distributions while the rest of the stack
    looks fine. After weight tying, ``vllm_model.compute_logits(hidden)`` uses
    the same lm_head storage as kvprune; only the *application* path changes.

    Args:
        vllm_model: Model providing ``compute_logits``.
        kvprune_model: Model that receives the replacement method.

    Returns:
        ``True`` when the method was installed; ``False`` (with a warning) when
        ``vllm_model`` has no callable ``compute_logits``.
    """
    if not callable(getattr(vllm_model, "compute_logits", None)):
        logger.warning(
            "delegate_kvprune_compute_logits_to_vllm: vLLM model has no compute_logits; skipped"
        )
        return False

    def _logits_via_vllm(_bound_module: nn.Module, hidden_states):
        # Mirror kvprune :class:`~vllm.kvprune.layers.embed_head.ParallelLMHead`:
        # during prefill only the **last** token of each packed sequence gets
        # logits.
        ctx = get_context()
        packed_offsets = ctx.cu_seqlens_q if ctx.is_prefill else None
        if packed_offsets is not None:
            tail_rows = (packed_offsets[1:] - 1).to(torch.long)
            row_count = hidden_states.shape[0]
            if row_count > 0:
                # Keep every gather index inside the available rows.
                tail_rows = tail_rows.clamp(min=0, max=row_count - 1)
            hidden_states = hidden_states[tail_rows].contiguous()

        # vLLM lm_head + gather expect contiguous activations; non-contiguous
        # views have caused garbage logits under TP in edge cases.
        return vllm_model.compute_logits(hidden_states.contiguous())

    kvprune_model.compute_logits = types.MethodType(_logits_via_vllm, kvprune_model)
    return True
vllm/kvprune_legacy_save/kv_cache/__init__.py
0 → 100644
View file @
2b7160c6
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""Paged KV cache helpers and Triton KV store."""
from
vllm.kvprune.kv_cache.store_kv_cache
import
(
decode_store_kv
,
prefill_store_all_kv
,
prefill_store_topk_kv
,
)
# Public API of this package: the three KV-store helpers re-exported from
# ``vllm.kvprune.kv_cache.store_kv_cache``.
__all__ = [
    "decode_store_kv",
    "prefill_store_all_kv",
    "prefill_store_topk_kv",
]
Prev
1
…
12
13
14
15
16
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment