ModelZoo / kvpress-SnapKV-Qwen3-8B_pytorch / Commits

Commit c16d506e, authored Mar 10, 2026 by chenzk
v1.0

Changes: 137
Showing 20 changed files with 3027 additions and 0 deletions (+3027 -0)
kvpress/kvpress/pipeline.py                                 +327  -0
kvpress/kvpress/presses/__init__.py                         +2    -0
kvpress/kvpress/presses/adakv_press.py                      +78   -0
kvpress/kvpress/presses/base_press.py                       +201  -0
kvpress/kvpress/presses/block_press.py                      +98   -0
kvpress/kvpress/presses/chunk_press.py                      +87   -0
kvpress/kvpress/presses/chunkkv_press.py                    +125  -0
kvpress/kvpress/presses/compactor_press.py                  +122  -0
kvpress/kvpress/presses/composed_press.py                   +62   -0
kvpress/kvpress/presses/criticalkv_press.py                 +194  -0
kvpress/kvpress/presses/cur_press.py                        +67   -0
kvpress/kvpress/presses/decoding_press.py                   +224  -0
kvpress/kvpress/presses/dms_press.py                        +127  -0
kvpress/kvpress/presses/duo_attention_press.py              +210  -0
kvpress/kvpress/presses/expected_attention_press.py         +165  -0
kvpress/kvpress/presses/expected_attention_with_stats.py    +289  -0
kvpress/kvpress/presses/fastkvzip_press.py                  +287  -0
kvpress/kvpress/presses/finch_press.py                      +166  -0
kvpress/kvpress/presses/key_rerotation_press.py             +150  -0
kvpress/kvpress/presses/keydiff_press.py                    +46   -0
kvpress/kvpress/pipeline.py  (new file, mode 100644)

# SPDX-FileCopyrightText: Copyright (c) 1993-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0

import contextlib
import logging
from typing import Optional

import torch
from transformers import AutoModelForCausalLM, Cache, DynamicCache, Pipeline, QuantizedCache
from transformers.pipelines import PIPELINE_REGISTRY
from transformers.pipelines.base import GenericTensor

from kvpress.presses.base_press import BasePress
from kvpress.presses.decoding_press import DecodingPress
from kvpress.presses.dms_press import DMSPress
from kvpress.presses.finch_press import FinchPress
from kvpress.presses.key_rerotation_press import KeyRerotationPress
from kvpress.presses.prefill_decoding_press import PrefillDecodingPress

logger = logging.getLogger(__name__)


class KVPressTextGenerationPipeline(Pipeline):
    """
    Pipeline for key-value cache compression in causal language models.

    Enables efficient processing of long contexts by applying KV cache compression
    during pre-filling, then generating answers using greedy decoding.

    Example:
        ```python
        pipeline = KVPressTextGenerationPipeline(model=model, tokenizer=tokenizer)
        press = SnapKVPress(compression_ratio=0.5)
        result = pipeline(context="Long text...", question="A question about the long context.", press=press)
        ```
    """

    def _sanitize_parameters(
        self,
        question: Optional[str] = None,
        questions: Optional[list[str]] = None,
        answer_prefix: Optional[str] = None,
        press: Optional[BasePress] = None,
        max_new_tokens: int = 50,
        max_context_length: Optional[int] = None,
        enable_thinking: bool = False,
        cache: Optional[Cache] = None,
        **kwargs,
    ):
        """
        Sanitize the input parameters for the pipeline.
        The user can either provide a single question or a list of questions to be asked about the context.

        Parameters
        ----------
        question : str, optional
            The question to be asked about the context. Exclusive with `questions`.
        questions : list[str], optional
            A list of questions to be asked about the context. Exclusive with `question`.
        answer_prefix : str, optional
            The prefix to be added to the generated answer.
        press : BasePress, optional
            The key-value cache compression method to apply during pre-filling.
            Accepts any KVPress compression method (SnapKVPress, KnormPress,
            ExpectedAttentionPress, BlockPress, AdaKVPress, ComposedPress, etc.).
            If None, no compression is applied.
        max_new_tokens : int, optional
            The maximum number of new tokens to generate for each answer.
        max_context_length : int, optional
            The maximum number of tokens in the context. Defaults to the maximum length supported by the model.
        enable_thinking : bool, optional
            Whether to enable thinking in the chat template (the chat template must support this argument).
        cache : Cache, optional
            The cache to use for the forward pass. Defaults to None (DynamicCache).
        **kwargs : dict
            Additional keyword arguments, currently ignored.

        Returns
        -------
        Tuple[dict, dict, dict]
            A tuple containing three dictionaries:
                - preprocess_kwargs: The keyword arguments for the preprocess function.
                - forward_kwargs: The keyword arguments for the forward function.
                - postprocess_kwargs: The keyword arguments for the postprocess function.
        """
        answer_prefix = answer_prefix or ""
        postprocess_kwargs = {"single_question": questions is None}
        assert question is None or questions is None, "Either question or questions should be provided, not both."
        questions = questions or ([question] if question else [""])
        if max_context_length is None:
            max_context_length = min(self.tokenizer.model_max_length, int(1e10))  # 1e10 to avoid overflow
        preprocess_kwargs = {
            "questions": questions,
            "answer_prefix": answer_prefix,
            "max_context_length": max_context_length,
            "enable_thinking": enable_thinking,
        }
        forward_kwargs = {"press": press, "max_new_tokens": max_new_tokens, "cache": cache}
        return preprocess_kwargs, forward_kwargs, postprocess_kwargs

    def preprocess(
        self,
        context: str,
        questions: list[str],
        answer_prefix: str,
        max_context_length: int,
        enable_thinking: bool = False,
    ):
        """
        Apply chat template and tokenize the context and questions.

        Prepares input text for KV cache compression and generation by applying
        appropriate chat templates and tokenizing. Handles models with and without
        chat templates.

        Parameters
        ----------
        context : str
            Long context text to be compressed using the press method.
        questions : list[str]
            Questions to be asked about the context.
        answer_prefix : str
            Optional prefix for generated answers.
        max_context_length : int
            Maximum tokens allowed in context (truncated if exceeded).
        enable_thinking : bool
            Whether to enable thinking in the chat template (the chat template must support this argument).

        Returns
        -------
        dict[str, GenericTensor]
            Dictionary with "context_ids" and "questions_ids" tensors.
        """
        # Apply chat template if available
        if self.tokenizer.chat_template is None:
            bos_token = getattr(self.tokenizer, "bos_token", "")
            context = bos_token + context
            question_suffix = "\n"  # to separate the question from the answer
        else:
            separator = "#" * (len(context) + 10)
            context = self.tokenizer.apply_chat_template(
                [{"role": "user", "content": context + separator}],
                add_generation_prompt=True,
                tokenize=False,
                enable_thinking=enable_thinking,
            )
            context, question_suffix = context.split(separator)

        # Add question_suffix and answer prefix
        # e.g. for llama3.1, question_suffix="<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n"
        questions = [question + question_suffix + answer_prefix for question in questions]

        # Tokenize the context and questions
        context_ids = self.tokenizer.encode(context, return_tensors="pt", add_special_tokens=False)
        question_ids = [
            self.tokenizer.encode(question, return_tensors="pt", add_special_tokens=False) for question in questions
        ]

        # Truncate context
        if context_ids.shape[1] > max_context_length:
            logger.warning(
                f"Context length has been truncated from {context_ids.shape[1]} to {max_context_length} tokens."
            )
            context_ids = context_ids[:, :max_context_length]

        return {"context_ids": context_ids, "questions_ids": question_ids}

    def _forward(
        self,
        input_tensors: dict[str, GenericTensor],
        max_new_tokens: int = 50,
        press: Optional[BasePress] = None,
        cache: Optional[Cache] = None,
    ):
        """
        Execute KV cache compression and text generation pipeline.

        Performs context compression using the press method during pre-filling,
        then generates answers using greedy decoding.

        Parameters
        ----------
        input_tensors : dict[str, GenericTensor]
            Tokenized inputs with "context_ids" and "questions_ids".
        max_new_tokens : int, default=50
            Maximum tokens to generate for each answer.
        press : BasePress, optional
            Compression method for context pre-filling. If None, no compression.
        cache : Cache, optional
            Cache object for forward pass. If None, creates new DynamicCache.

        Returns
        -------
        list[str]
            Generated answers for each input question.
        """
        if isinstance(press, (DecodingPress, PrefillDecodingPress)) and len(input_tensors["questions_ids"]) > 1:
            raise ValueError(
                "DecodingPress is not compatible with multiple questions. Please specify a single question."
            )

        context_ids = input_tensors["context_ids"].to(self.model.device)
        context_length = context_ids.shape[1]

        # Prefilling using the press on the context
        if cache is None:
            cache = DynamicCache()

        # We only perform prefill compression if the press is a prefill press
        perform_prefill_compression = press is not None and not isinstance(press, DecodingPress)
        with press(self.model) if perform_prefill_compression else contextlib.nullcontext():
            # We run the model without the lm head for pre-filling.
            self.model.model(
                input_ids=context_ids,
                past_key_values=cache,
            )

        logger.debug(f"Context Length: {context_length}")
        logger.debug(f"Compressed Context Length: {cache.get_seq_length()}")

        # We only perform decoding compression if the press is a decoding or prefill decoding press
        perform_decoding_compression = press is not None and isinstance(press, (DecodingPress, PrefillDecodingPress))
        if isinstance(press, DMSPress):
            perform_decoding_compression = press.decoding

        with press(self.model) if perform_decoding_compression else contextlib.nullcontext():
            # Greedy decoding for each question
            answers = []
            for question_ids in input_tensors["questions_ids"]:
                if isinstance(press, KeyRerotationPress) or (isinstance(press, FinchPress) and press.rerotate_keys):
                    context_length = cache.get_seq_length()
                cache_seq_lengths = [cache.get_seq_length(layer_idx) for layer_idx in range(len(cache))]
                answer = self.generate_answer(
                    question_ids=question_ids.to(self.model.device),
                    cache=cache,
                    context_length=context_length,
                    max_new_tokens=max_new_tokens,
                )
                self._remove_answer_from_cache(cache, cache_seq_lengths)
                answers.append(answer)

        return answers

    def _remove_answer_from_cache(self, cache: Cache, cache_seq_lengths: list[int]):
        for layer_idx, sequence_length in enumerate(cache_seq_lengths):
            cache.layers[layer_idx].keys = cache.layers[layer_idx].keys[:, :, :sequence_length]
            cache.layers[layer_idx].values = cache.layers[layer_idx].values[:, :, :sequence_length]
        if isinstance(cache, QuantizedCache):
            for layer_idx, sequence_length in enumerate(cache_seq_lengths):
                cache.layers[layer_idx]._quantized_keys = cache.layers[layer_idx]._quantized_keys[
                    :, :, :sequence_length
                ]
                cache.layers[layer_idx]._quantized_values = cache.layers[layer_idx]._quantized_values[
                    :, :, :sequence_length
                ]

    def generate_answer(
        self, question_ids: torch.Tensor, cache: Cache, context_length: int, max_new_tokens: int
    ) -> str:
        """
        Generate an answer to a question using greedy decoding.

        Parameters
        ----------
        question_ids : torch.Tensor
            The tokenized question.
        cache : Cache
            The compressed key-value cache.
        context_length : int
            The length of the context.
        max_new_tokens : int
            The maximum number of new tokens to generate.

        Returns
        -------
        str
            The generated answer.
        """
        position_ids = torch.arange(
            context_length, context_length + question_ids.shape[1], device=self.model.device
        ).unsqueeze(0)

        # if the user doesn't provide a question, skip forward pass
        outputs = self.model(
            input_ids=question_ids.to(self.model.device),
            past_key_values=cache,
            position_ids=position_ids,
            num_logits_to_keep=1,
        )

        position_ids = position_ids[:, -1:] + 1
        generated_ids = [outputs.logits[0, -1].argmax()]

        should_stop_token_ids = self.model.generation_config.eos_token_id
        if not isinstance(should_stop_token_ids, list):
            should_stop_token_ids = [should_stop_token_ids]

        for i in range(max_new_tokens - 1):
            outputs = self.model(
                input_ids=generated_ids[-1].unsqueeze(0).unsqueeze(0),
                past_key_values=cache,
                position_ids=position_ids + i,
            )
            new_id = outputs.logits[0, -1].argmax()
            generated_ids.append(new_id)
            if new_id.item() in should_stop_token_ids:
                break

        answer = str(self.tokenizer.decode(torch.stack(generated_ids), skip_special_tokens=True))
        return answer

    def postprocess(self, model_outputs, single_question):
        if single_question:
            return {"answer": model_outputs[0]}
        return {"answers": model_outputs}


PIPELINE_REGISTRY.register_pipeline(
    "kv-press-text-generation",
    pipeline_class=KVPressTextGenerationPipeline,
    pt_model=AutoModelForCausalLM,
)
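To make the registration above concrete, here is a minimal usage sketch (not part of the committed file). It assumes the kvpress package exports SnapKVPress at the top level, as the class docstring suggests, and the Qwen3-8B checkpoint name is only an illustration matching this repository's name.

```python
# Minimal usage sketch for the pipeline registered above (illustrative, not part of this commit).
from transformers import pipeline
from kvpress import SnapKVPress  # assumed top-level export, per the docstring example

# Any supported causal LM works; the Qwen3-8B name below is only an example.
pipe = pipeline(
    "kv-press-text-generation",
    model="Qwen/Qwen3-8B",
    device_map="auto",
    torch_dtype="auto",
)

press = SnapKVPress(compression_ratio=0.5)  # prune ~50% of KV pairs during pre-filling
result = pipe(
    "A very long context ...",
    question="What is this document about?",
    press=press,
    max_new_tokens=64,
)
print(result["answer"])
```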
kvpress/kvpress/presses/__init__.py  (new file, mode 100644)

# SPDX-FileCopyrightText: Copyright (c) 1993-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
kvpress/kvpress/presses/adakv_press.py  (new file, mode 100644)

# SPDX-FileCopyrightText: Copyright (c) 1993-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0

from dataclasses import dataclass

import torch

from kvpress.presses.base_press import BasePress
from kvpress.presses.scorer_press import ScorerPress


@dataclass
class AdaKVPress(BasePress):
    """
    AdaKV: Adaptive head-wise KV cache compression.

    Performs head-specific compression by selecting top-k tokens across all heads
    based on importance scores. Applies safeguards to ensure each head retains
    a minimum fraction of tokens.

    Based on AdaKV (https://arxiv.org/abs/2407.11550).

    Parameters
    ----------
    press : ScorerPress
        The underlying scoring method used to compute importance scores.
        AdaKVPress and ObservedAttentionPress are currently not supported as inner presses.
    alpha_safeguard : float, default=0.20
        Minimum fraction of KV pairs that each head must retain.
        Ensures no attention head is compressed too aggressively. Even if tokens
        receive low global importance scores, each head retains at least this
        fraction of its original tokens.
    """

    press: ScorerPress
    alpha_safeguard: float = 0.20

    def __post_init__(self):
        assert isinstance(self.press, ScorerPress), "AdaKVPress requires a ScorerPress as input"
        assert 0 <= self.alpha_safeguard <= 1, "alpha_safeguard should be in [0, 1]"

    def post_init_from_model(self, model):
        self.press.post_init_from_model(model)

    @property
    def compression_ratio(self):
        return self.press.compression_ratio

    @compression_ratio.setter
    def compression_ratio(self, value):
        self.press.compression_ratio = value

    def compress(self, module, hidden_states, keys, values, attentions, kwargs):
        if self.compression_ratio == 0:
            return keys, values

        assert module.config._attn_implementation != "eager", "eager mode not supported"

        # Compute scores
        scores = self.press.score(module, hidden_states, keys, values, attentions, kwargs)
        bsz, num_key_value_heads, k_len = scores.shape

        # Make sure to keep at least alpha * (1 - compression_ratio) KV pairs per head
        n_kept = int(k_len * (1 - self.compression_ratio))  # ScorerPress definition
        n_safe = int(n_kept * self.alpha_safeguard)
        top_indices = torch.topk(scores, n_safe, dim=-1).indices
        scores.scatter_(-1, top_indices, torch.finfo(scores.dtype).max)

        # Compute bottom-k across heads
        n_pruned = num_key_value_heads * (k_len - n_kept)
        indices = torch.topk(-scores.reshape(bsz, -1), n_pruned, dim=1).indices.flatten()

        # Save indices to mask during the attention mechanism. Please refer to attention_patch.py for more details
        batch_indices = torch.arange(bsz).repeat_interleave(n_pruned)
        head_indices = indices // k_len
        seq_indices = indices % k_len
        module.masked_key_indices = (batch_indices, head_indices, seq_indices)
        return keys, values
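A short construction sketch (not part of the committed file) showing how AdaKVPress wraps a scorer press. It assumes ExpectedAttentionPress, imported elsewhere in this commit, is exported from the kvpress package.

```python
# Illustrative sketch: head-wise compression around a scorer press.
from kvpress import AdaKVPress, ExpectedAttentionPress  # assumed top-level exports

inner = ExpectedAttentionPress(compression_ratio=0.5)
press = AdaKVPress(press=inner, alpha_safeguard=0.2)

# compression_ratio is delegated to the wrapped ScorerPress:
assert press.compression_ratio == 0.5
press.compression_ratio = 0.25  # also updates inner.compression_ratio
```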
kvpress/kvpress/presses/base_press.py  (new file, mode 100644)

# SPDX-FileCopyrightText: Copyright (c) 1993-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0

import logging
from contextlib import contextmanager
from dataclasses import dataclass
from typing import Generator

import torch
from torch import nn
from transformers import (
    Gemma3ForConditionalGeneration,
    LlamaForCausalLM,
    MistralForCausalLM,
    Phi3ForCausalLM,
    PreTrainedModel,
    QuantizedCache,
    Qwen2ForCausalLM,
    Qwen3ForCausalLM,
)

from kvpress.utils import extract_keys_and_values

logger = logging.getLogger(__name__)

SUPPORTED_MODELS = (
    LlamaForCausalLM,
    MistralForCausalLM,
    Phi3ForCausalLM,
    Qwen2ForCausalLM,
    Qwen3ForCausalLM,
    Gemma3ForConditionalGeneration,
)


@dataclass
class BasePress:
    """
    Base class for all KV cache compression methods.

    This class provides the foundation for implementing various key-value cache compression
    techniques. Subclasses must implement the `compress` method to define their specific
    compression logic.

    The compression is applied only during pre-filling (not during generation).
    """

    def post_init_from_model(self, model: PreTrainedModel):
        """
        Optional method to initialize press parameters from the model
        """
        pass

    def compress(
        self,
        module: nn.Module,
        hidden_states: torch.Tensor,
        keys: torch.Tensor,
        values: torch.Tensor,
        attentions: torch.Tensor,
        kwargs: dict,
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """
        The core logic of the compression method.

        Parameters
        ----------
        module : nn.Module
            The transformer attention layer where compression is applied.
        hidden_states : torch.Tensor
            Hidden states of the current layer with shape (batch_size, seq_len, hidden_dim).
            These represent the input to the attention layer.
        keys : torch.Tensor
            Key tensors from the KV cache with shape (batch_size, num_kv_heads, seq_len, head_dim).
            These are keys ready for compression.
        values : torch.Tensor
            Value tensors from the KV cache with shape (batch_size, num_kv_heads, seq_len, head_dim).
            These are values ready for compression.
        attentions : torch.Tensor
            Attention weights from the layer with shape (batch_size, num_heads, seq_len, seq_len).
            May be None if attention weights are not computed or needed.
        kwargs : dict
            Additional keyword arguments from the forward pass.

        Returns
        -------
        tuple[torch.Tensor, torch.Tensor]
            A tuple containing the compressed keys and values tensors. The returned tensors
            should have reduced sequence length dimension compared to the input tensors.
        """
        raise NotImplementedError("compress method must be implemented in subclass")

    def forward_hook(self, module: nn.Module, input: list[torch.Tensor], kwargs: dict, output: list):
        """
        Default forward hook called after the forward pass of an attention layer.

        This hook automatically applies compression during the pre-filling phase by:
        1. Checking if we're still in pre-filling (not generation) phase
        2. Extracting keys and values from the cache (handling quantization)
        3. Calling the compress method to reduce the cache size
        4. Updating the cache with compressed keys and values

        The hook ensures compression is only applied during pre-filling and correctly
        handles both quantized and unquantized caches.

        Parameters
        ----------
        module : nn.Module
            The transformer attention layer.
        input : list[torch.Tensor]
            Input tensors to the forward pass of the attention layer. This parameter
            is provided by PyTorch's hook mechanism but not used in the default implementation.
        kwargs : dict
            Keyword arguments passed to the attention layer's forward method, including:
            - hidden_states: Input embeddings to the attention layer
            - past_key_values: The KV cache object being modified
            - cache_position: Position indices indicating where we are in the sequence
            - position_embeddings: RoPE embeddings if applicable
        output : list
            Output from the attention layer's forward pass. Contains:
            - [0]: Hidden states output
            - [1]: Attention weights (may be None)

        Returns
        -------
        list
            The potentially modified output from the forward pass. This
            is the same as the input output, but the underlying cache has been compressed in-place.
        """
        hidden_states = kwargs["hidden_states"]
        cache = kwargs["past_key_values"]
        cache_layer = cache.layers[module.layer_idx]
        q_len = hidden_states.shape[1]

        # Don't compress after pre-filling
        if kwargs["cache_position"][-1] > q_len:
            return output

        keys, values = extract_keys_and_values(cache, module.layer_idx)
        keys, values = self.compress(module, hidden_states, keys, values, output[1], kwargs)

        if isinstance(cache, QuantizedCache):
            cache_layer._quantized_keys = cache_layer._quantize(keys, axis=cache_layer.axis_key)
            cache_layer._quantized_values = cache_layer._quantize(values, axis=cache_layer.axis_value)
            cache_layer.keys = torch.zeros(0, dtype=keys.dtype, device=keys.device)  # type: ignore[index]
            cache_layer.values = torch.zeros(0, dtype=keys.dtype, device=keys.device)  # type: ignore[index]
            cache_layer.cumulative_length = keys.shape[2]
        else:
            cache_layer.keys = keys
            cache_layer.values = values

        return output

    @contextmanager
    def __call__(self, model: PreTrainedModel) -> Generator:
        """
        Context manager to apply a compression method to a model.

        This method registers forward hooks on all attention layers of the model to enable
        automatic KV cache compression during the pre-filling phase. The hooks are automatically
        removed when exiting the context manager.

        Apply this context manager during the pre-filling phase to compress the context.

        Parameters
        ----------
        model : PreTrainedModel
            The transformer model to apply compression to.

        Examples
        --------
        >>> from kvpress import KnormPress
        >>> press = KnormPress(compression_ratio=0.5)
        >>> with press(model):
        ...     # Forward pass with compression applied
        ...     outputs = model(input_ids, past_key_values=cache)
        """
        if not isinstance(model, SUPPORTED_MODELS):
            logger.warning(f"Model {type(model)} not tested, supported models: {SUPPORTED_MODELS}")

        if isinstance(model, Gemma3ForConditionalGeneration):
            logger.warning_once("Compression in Gemma3 is only applied to layer without sliding window attention")

        self.post_init_from_model(model)

        hooks = []
        try:
            language_model = model.model.language_model if hasattr(model.model, "language_model") else model.model
            for layer in language_model.layers:
                if isinstance(model, Gemma3ForConditionalGeneration) and layer.self_attn.is_sliding:
                    # Skip layers with sliding window attention, only for Gemma3
                    continue
                layer.self_attn.rotary_emb = language_model.rotary_emb
                hooks.append(layer.self_attn.register_forward_hook(self.forward_hook, with_kwargs=True))
            yield
        finally:
            for forward_hook in hooks:
                forward_hook.remove()
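To illustrate the compress contract and the context-manager usage documented above, here is a toy press (not part of the committed file) that simply keeps the most recent tokens. The presses in this commit select tokens by importance scores instead; this is only a sketch of the subclassing pattern.

```python
# Toy subclass illustrating the BasePress contract (sketch, not part of this commit).
from dataclasses import dataclass

import torch
from torch import nn

from kvpress.presses.base_press import BasePress


@dataclass
class RecencyPress(BasePress):
    """Keep only the most recent (1 - compression_ratio) fraction of KV pairs."""

    compression_ratio: float = 0.5

    def compress(self, module, hidden_states, keys, values, attentions, kwargs):
        n_kept = int(keys.shape[2] * (1 - self.compression_ratio))
        # keys/values have shape (batch_size, num_kv_heads, seq_len, head_dim); slice the sequence axis.
        return keys[:, :, -n_kept:], values[:, :, -n_kept:]


# Applied exactly like the built-in presses, via the context manager defined above:
# with RecencyPress(compression_ratio=0.5)(model):
#     model(input_ids, past_key_values=cache)
```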
kvpress/kvpress/presses/block_press.py  (new file, mode 100644)

# SPDX-FileCopyrightText: Copyright (c) 1993-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0

from dataclasses import dataclass

import torch
from torch import nn

from kvpress.presses.base_press import BasePress
from kvpress.presses.scorer_press import ScorerPress


@dataclass
class BlockPress(BasePress):
    """
    BlockPress: Block-wise iterative KV cache compression.

    Simulates block prompt processing described in the KeyDiff paper (https://arxiv.org/abs/2504.15364).
    Segments the input sequence into non-overlapping blocks and compresses iteratively.

    ⚠️ This is not a true chunked-prefill implementation: all inputs are computed in a single
    forward pass before block-wise scoring and pruning.

    Parameters
    ----------
    press : ScorerPress
        The underlying scoring method used to evaluate token importance within each block.
    block_size : int, default=128
        Size of each block for iterative compression.
    """

    press: ScorerPress
    block_size: int = 128

    def __post_init__(self):
        assert isinstance(self.press, ScorerPress), "BlockPress requires a ScorerPress"

    def post_init_from_model(self, model):
        self.press.post_init_from_model(model)

    @property
    def compression_ratio(self):
        return self.press.compression_ratio

    @compression_ratio.setter
    def compression_ratio(self, value):
        self.press.compression_ratio = value

    def compress(
        self,
        module: nn.Module,
        hidden_states: torch.Tensor,
        keys: torch.Tensor,
        values: torch.Tensor,
        attentions: torch.Tensor,
        kwargs: dict,
    ) -> tuple[torch.Tensor, torch.Tensor]:
        if self.press.compression_ratio == 0:
            return keys, values

        assert attentions is None, "BlockPress does not support attentions."

        bsz, num_key_value_heads, k_len, head_dim = keys.shape
        block_size = self.block_size if self.block_size < k_len else k_len
        n_kept = int(k_len * (1 - self.compression_ratio))
        kept_indices = torch.arange(n_kept, device=keys.device).expand(bsz, num_key_value_heads, -1)

        # Reshape hidden states to match the kept_indices
        states = hidden_states.view(bsz, k_len, num_key_value_heads, -1).transpose(1, 2)

        for i in range(n_kept, k_len, block_size):
            end = min(i + block_size, k_len)
            current_indices = torch.arange(i, end, device=keys.device).expand(bsz, num_key_value_heads, -1)
            current_indices = torch.cat([kept_indices, current_indices], dim=-1)

            # Gather hidden states for the selected indices, then restore the shape
            # Check tests/presses/test_block_press.py for correctness verification of gathered hidden states
            current_states = states.gather(2, current_indices.unsqueeze(-1).expand(-1, -1, -1, states.shape[-1]))
            current_states = current_states.transpose(1, 2).reshape(bsz, -1, hidden_states.shape[-1])

            scores = self.press.score(
                module,
                current_states,
                keys.gather(2, current_indices.unsqueeze(-1).expand(-1, -1, -1, head_dim)),
                values.gather(2, current_indices.unsqueeze(-1).expand(-1, -1, -1, head_dim)),
                attentions,
                kwargs,
            )

            topk_indices = scores.topk(n_kept, dim=-1).indices
            kept_indices = current_indices.gather(-1, topk_indices)

        kept_indices = kept_indices.unsqueeze(-1).expand(-1, -1, -1, head_dim)
        keys = keys.gather(2, kept_indices).contiguous()
        values = values.gather(2, kept_indices).contiguous()

        return keys, values
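A construction sketch (not part of the committed file) for BlockPress. It assumes KeyDiffPress, added in this commit as keydiff_press.py, is exported from the kvpress package, matching the KeyDiff paper referenced in the docstring.

```python
# Sketch: block-wise scoring with a KeyDiff-style scorer.
from kvpress import BlockPress, KeyDiffPress  # assumed top-level exports

# Tokens are scored and pruned 128 at a time, on top of the tokens kept so far.
press = BlockPress(press=KeyDiffPress(compression_ratio=0.5), block_size=128)
# `press` can then be passed to the kv-press-text-generation pipeline shown earlier.
```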
kvpress/kvpress/presses/chunk_press.py  (new file, mode 100644)

# SPDX-FileCopyrightText: Copyright (c) 1993-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0

from dataclasses import dataclass

import torch
from torch import nn

from kvpress.presses.base_press import BasePress
from kvpress.presses.scorer_press import ScorerPress


@dataclass
class ChunkPress(BasePress):
    """
    ChunkPress: Uniform compression through independent chunk processing.

    This wrapper enhances any ScorerPress by applying compression independently
    to fixed-size chunks of the sequence. Unlike global compression methods that
    may concentrate selection in high-importance regions, ChunkPress ensures
    uniform compression across the entire context by processing each chunk separately.

    Based on FINCH (https://direct.mit.edu/tacl/article/doi/10.1162/tacl_a_00716/125280).

    Parameters
    ----------
    press : ScorerPress
        The underlying scoring method to apply to each chunk independently.
    chunk_length : int, default=1024
        Length of each chunk for independent compression.
    """

    press: ScorerPress
    chunk_length: int = 1024

    def __post_init__(self):
        assert isinstance(self.press, ScorerPress), "ChunkPress requires a ScorerPress as input"

    def post_init_from_model(self, model):
        self.press.post_init_from_model(model)

    @property
    def compression_ratio(self):
        return self.press.compression_ratio

    @compression_ratio.setter
    def compression_ratio(self, value):
        self.press.compression_ratio = value

    def compress(
        self,
        module: nn.Module,
        hidden_states: torch.Tensor,
        keys: torch.Tensor,
        values: torch.Tensor,
        attentions: torch.Tensor,
        kwargs: dict,
    ) -> tuple[torch.Tensor, torch.Tensor]:
        if self.press.compression_ratio == 0:
            return keys, values

        assert attentions is None, "ChunkPress does not support attentions."

        kv_len = keys.shape[2]
        indices = []
        for i in range(0, kv_len, self.chunk_length):
            chunk_scores = self.press.score(
                module,
                hidden_states[:, i : i + self.chunk_length],
                keys[:, :, i : i + self.chunk_length],
                values[:, :, i : i + self.chunk_length],
                attentions,
                kwargs,
            )
            chunk_length = keys[:, :, i : i + self.chunk_length].shape[2]
            n_kept = max(1, int(chunk_length * (1 - self.press.compression_ratio)))
            chunk_indices = i + chunk_scores.topk(n_kept, dim=-1).indices
            indices.append(chunk_indices)

        indices = torch.cat(indices, dim=-1)
        indices = indices.unsqueeze(-1).expand(-1, -1, -1, module.head_dim)

        keys = keys.gather(2, indices).contiguous()
        values = values.gather(2, indices).contiguous()

        return keys, values
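A construction sketch (not part of the committed file) for ChunkPress. KnormPress is used as the inner scorer here because it only needs key norms; it is assumed to be exported from kvpress, as the BasePress docstring example suggests.

```python
# Sketch: uniform compression across 1024-token chunks.
from kvpress import ChunkPress, KnormPress  # assumed top-level exports

# Each chunk keeps the same fraction of tokens, so no region of the context is wiped out entirely.
press = ChunkPress(press=KnormPress(compression_ratio=0.5), chunk_length=1024)
```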
kvpress/kvpress/presses/chunkkv_press.py  (new file, mode 100644)

# SPDX-FileCopyrightText: Copyright (c) 1993-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0

from dataclasses import dataclass

import torch
from torch import nn

from kvpress.presses.base_press import BasePress
from kvpress.presses.scorer_press import ScorerPress


@dataclass
class ChunkKVPress(BasePress):
    """
    ChunkKV: Semantic-preserving compression with chunk-wise token selection.

    Enhances any ScorerPress by applying chunk-wise token selection instead of
    global selection. Computes global importance scores, then selects tokens
    chunk by chunk to preserve semantic coherence within local contexts.

    Based on ChunkKV (https://arxiv.org/abs/2502.00299).

    Parameters
    ----------
    press : ScorerPress
        The underlying scoring method used to compute global importance scores.
    chunk_length : int, default=20
        Length of each chunk for token selection.
        Sequence is divided into chunks of this size, with tokens selected
        proportionally from each chunk based on global importance scores.
    """

    press: ScorerPress
    chunk_length: int = 20

    def __post_init__(self):
        assert isinstance(self.press, ScorerPress), "ChunkKVPress requires a ScorerPress as input"

    def post_init_from_model(self, model):
        self.press.post_init_from_model(model)

    @property
    def compression_ratio(self):
        return self.press.compression_ratio

    @compression_ratio.setter
    def compression_ratio(self, value):
        self.press.compression_ratio = value

    def compress(
        self,
        module: nn.Module,
        hidden_states: torch.Tensor,
        keys: torch.Tensor,
        values: torch.Tensor,
        attentions: torch.Tensor,
        kwargs: dict,
    ) -> tuple[torch.Tensor, torch.Tensor]:
        if self.press.compression_ratio == 0:
            return keys, values

        assert attentions is None, "ChunkPress does not support attentions."

        kv_len = keys.shape[2]

        # 1. Calculate global scores first
        global_scores = self.press.score(
            module,
            hidden_states,
            keys,
            values,
            attentions,
            kwargs,
        )

        # 2. Calculate actual number of complete chunks and remaining tokens
        num_complete_chunks = kv_len // self.chunk_length
        remaining_tokens = kv_len % self.chunk_length

        # If we have no complete chunks, delegate to the underlying scorer press
        if num_complete_chunks == 0:
            return self.press.compress(module, hidden_states, keys, values, attentions, kwargs)

        # Reshape complete chunks for score calculation
        if num_complete_chunks > 0:
            main_scores = global_scores[..., : num_complete_chunks * self.chunk_length]
            main_chunk_scores = main_scores.sum(dim=1).view(-1, num_complete_chunks, self.chunk_length)
            main_chunk_scores = main_chunk_scores.mean(dim=-1)
        else:
            main_chunk_scores = torch.empty((global_scores.shape[0], 0), device=global_scores.device)

        # Handle remaining tokens if any
        if remaining_tokens > 0:
            remaining_scores = global_scores[..., -remaining_tokens:]
            remaining_chunk_score = remaining_scores.sum(dim=1).mean(dim=-1, keepdim=True)
            chunk_scores = torch.cat([main_chunk_scores, remaining_chunk_score], dim=-1)
        else:
            chunk_scores = main_chunk_scores

        # 3. Calculate number of chunks to keep
        n_chunks_kept = max(
            1, int((num_complete_chunks + (remaining_tokens > 0)) * (1 - self.press.compression_ratio))
        )
        top_chunks = chunk_scores.topk(n_chunks_kept, dim=-1)

        # 4. Create indices for selected chunks
        indices = []
        for chunk_idx in top_chunks.indices[0]:
            if chunk_idx < num_complete_chunks:
                # For complete chunks
                start_idx = chunk_idx * self.chunk_length
                chunk_indices = torch.arange(start_idx, start_idx + self.chunk_length, device=keys.device)
            else:
                # For the remaining partial chunk
                chunk_indices = torch.arange(num_complete_chunks * self.chunk_length, kv_len, device=keys.device)
            indices.append(chunk_indices)

        indices = torch.cat(indices).sort()[0]
        indices = indices.view(1, 1, -1, 1).expand(keys.shape[0], keys.shape[1], -1, module.head_dim)

        # 5. Use gather to collect selected keys and values
        keys = keys.gather(2, indices).contiguous()
        values = values.gather(2, indices).contiguous()

        return keys, values
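A construction sketch (not part of the committed file) for ChunkKVPress. It assumes SnapKVPress is exported from kvpress, as in the pipeline docstring; any ScorerPress can serve as the global scorer.

```python
# Sketch: keep whole 20-token chunks ranked by their mean global importance score.
from kvpress import ChunkKVPress, SnapKVPress  # assumed top-level exports

press = ChunkKVPress(press=SnapKVPress(compression_ratio=0.5), chunk_length=20)
```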
kvpress/kvpress/presses/compactor_press.py  (new file, mode 100644)

# SPDX-FileCopyrightText: Copyright (c) 1993-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0

from dataclasses import dataclass
from typing import Optional

import torch
from torch import nn
from torch.nn import functional as F

from kvpress.presses.leverage_press import LeverageScorePress
from kvpress.presses.non_causal_attention_press import NonCausalAttnPress
from kvpress.presses.scorer_press import ScorerPress


@dataclass
class CompactorPress(ScorerPress):
    """
    Compactor: Calibrated Query-Agnostic KV Cache Compression with Approximate Leverage Scores

    Compactor blends: geometry-based outlier scores via (approximate) statistical leverage on key
    embeddings; and non-causal, chunked attention. Currently only supports prefill. The presented
    version slightly differs from the paper in that: (1) we set blending=compression_ratio by default,
    which is a good heuristic and should work for most users, and (2) we use a Cholesky
    decomposition to compute the leverage scores. Please see the paper for an in-depth discussion. The
    press is implemented as a wrapper that combines ``NonCausalAttnPress`` and
    ``LeverageScorePress`` scores.

    References:
        - Chari & Van Durme (2025): "Compactor: Calibrated Query-Agnostic KV Cache
          Compression with Approximate Leverage Scores" (https://arxiv.org/pdf/2507.08143v1)

    Parameters
    ----------
    compression_ratio : float, default ``0.0``
        Fraction of key-value pairs to remove during compression.
    sink_size_start : int, default ``8``
        Number of initial sink tokens to always protect.
    sink_size_end : int, default ``4``
        Number of most-recent tokens to always protect.
    chunk_size : int, default ``256``
        Chunk size used in non-causal attention.
    sketch_dimension : int, default ``48``
        Size of the Gaussian sketch.
    blending : Optional[float], default ``None``
        Weight for blending scores in the final output. If ``None``,
        it is set to ``compression_ratio``, which tends to be a good default.

    Output
    ------
    score(...) returns a tensor of shape (B, H_kv, S) with higher values
    indicating more important tokens for retention.
    """

    sink_size_start: int = 8
    sink_size_end: int = 4
    chunk_size: int = 256
    sketch_dimension: int = 48
    blending: Optional[float] = None
    _leverage_press: Optional[LeverageScorePress] = None
    _non_causal_press: Optional[NonCausalAttnPress] = None

    def __post_init__(self):
        # build child presses if not provided
        self._leverage_press = LeverageScorePress(
            compression_ratio=self.compression_ratio, sketch_dimension=self.sketch_dimension
        )
        self._non_causal_press = NonCausalAttnPress(
            compression_ratio=self.compression_ratio, chunk_size=self.chunk_size
        )

    def __setattr__(self, name, value):
        object.__setattr__(self, name, value)
        if name == "compression_ratio":
            if "_leverage_press" in self.__dict__:
                self._leverage_press.compression_ratio = value
            if "_non_causal_press" in self.__dict__:
                self._non_causal_press.compression_ratio = value
        if name == "sketch_dimension":
            if "_leverage_press" in self.__dict__:
                self._leverage_press.sketch_dimension = value
        if name == "chunk_size":
            if "_non_causal_press" in self.__dict__:
                self._non_causal_press.chunk_size = value

    def score(
        self,
        module: nn.Module,
        hidden_states: torch.Tensor,
        keys: torch.Tensor,
        values: torch.Tensor,
        attentions: torch.Tensor,
        kwargs,
    ) -> torch.Tensor:
        """Blend leverage and non-causal attention into final importance scores"""
        n_queries = hidden_states.shape[-2]
        assert keys.shape[-2] == n_queries, "CompactorPress only supports prefill at the moment"

        left_keep = min(self.sink_size_start, n_queries)
        right_keep = min(self.sink_size_end, max(0, n_queries - left_keep))
        start_idx, end_idx = left_keep, (None if right_keep == 0 else -right_keep)

        hs = hidden_states[:, start_idx:end_idx]
        keys = keys[..., start_idx:end_idx, :]
        values = values[..., start_idx:end_idx, :]
        cos, sin = kwargs["position_embeddings"]
        sliced_kwargs = {"position_embeddings": (cos[..., start_idx:end_idx, :], sin[..., start_idx:end_idx, :])}

        l_scores = self._leverage_press.score(
            module=module, hidden_states=hs, keys=keys, values=values, attentions=attentions, kwargs=sliced_kwargs
        )
        attn_scores = self._non_causal_press.score(
            module=module, hidden_states=hs, keys=keys, values=values, attentions=attentions, kwargs=sliced_kwargs
        )

        # sanity check. this breaks when not in prefill
        assert attn_scores.shape == l_scores.shape, "CompactorPress only supports prefill at the moment"

        blending = self.blending if self.blending is not None else self.compression_ratio
        blending = 0.35 if blending is None else blending
        scores = blending * l_scores + attn_scores

        # protect sinks by padding
        scores = F.pad(scores, (left_keep, right_keep), value=scores.detach().max())
        return scores
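Unlike the wrapper presses above, CompactorPress is itself a ScorerPress and builds its two child presses internally, so a construction sketch (not part of the committed file) only needs its own hyperparameters.

```python
# Sketch: Compactor with its default leverage/attention blend.
from kvpress import CompactorPress  # assumed top-level export

press = CompactorPress(compression_ratio=0.5, chunk_size=256, sketch_dimension=48)
# blending is None by default, so score() falls back to compression_ratio (0.5 here).
```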
kvpress/kvpress/presses/composed_press.py  (new file, mode 100644)

# SPDX-FileCopyrightText: Copyright (c) 1993-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0

from dataclasses import dataclass

from kvpress.presses.adakv_press import AdaKVPress
from kvpress.presses.base_press import BasePress
from kvpress.presses.kvzip_press import KVzipPress


@dataclass
class ComposedPress(BasePress):
    """
    Composed compression: Chain multiple compression methods sequentially.

    Applies multiple compression methods in sequence, with each method operating
    on the output of the previous one. Useful for combining complementary approaches
    like sequence + dimension compression.

    Example:
        ```python
        press = ComposedPress([
            SnapKVPress(compression_ratio=0.3),
            ThinKPress(key_channel_compression_ratio=0.2)
        ])
        ```

    AdaKVPress and KVzipPress are currently not supported.

    ⚠️ ComposedPress may fail if a press depends on features beyond keys and values
    (e.g., hidden states or attention weights). For example, combining KnormPress
    with ObservedAttentionPress fails because KnormPress prunes keys and values,
    but ObservedAttentionPress then receives the original attention weights.

    Parameters
    ----------
    presses : list[BasePress]
        List of compression methods to apply sequentially.
        Methods are applied in order, with each operating on the compressed output
        of the previous method. The final compression ratio is one minus the product
        of the retained fractions of all presses.
    """

    presses: list[BasePress]

    def __post_init__(self):
        self.compression_ratio = None
        assert not any(
            isinstance(press, (AdaKVPress, KVzipPress)) for press in self.presses
        ), "ComposedPress cannot contain AdaKVPress or KVzipPress"

    def post_init_from_model(self, model):
        for press in self.presses:
            press.post_init_from_model(model)

    def forward_hook(self, module, input, kwargs, output):
        retained_fraction = 1.0
        for press in self.presses:
            output = press.forward_hook(module, input, kwargs, output)
            retained_fraction *= 1 - press.compression_ratio  # type: ignore
        self.compression_ratio = 1 - retained_fraction
        return output
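As a complement to the docstring example, the sketch below (not part of the committed file) spells out the arithmetic that forward_hook performs: it multiplies the retained fractions of the wrapped presses and stores one minus that product.

```python
# Effective ratio of two chained presses (illustrative numbers).
r1, r2 = 0.3, 0.2
retained = (1 - r1) * (1 - r2)   # 0.7 * 0.8 = 0.56 of the cache survives both presses
effective_ratio = 1 - retained   # 0.44, i.e. 44% of the KV entries are pruned overall
print(effective_ratio)
```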
kvpress/kvpress/presses/criticalkv_press.py  (new file, mode 100644)

# SPDX-FileCopyrightText: Copyright (c) 1993-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0

import logging
from dataclasses import dataclass

import torch
from transformers.models.llama.modeling_llama import repeat_kv

from kvpress.presses.base_press import BasePress
from kvpress.presses.expected_attention_press import ExpectedAttentionPress
from kvpress.presses.scorer_press import ScorerPress

logger = logging.getLogger(__name__)


class CriticalKVPress(ScorerPress):
    """
    CriticalKV: Two-stage compression with output projection weighting.

    Enhances existing scoring methods by rescaling scores using the L1 norm
    of the output projection applied to values (Wo @ values).

    Based on CriticalKV (https://arxiv.org/abs/2502.03805).

    Parameters
    ----------
    press : ScorerPress
        Base scoring method to enhance with output projection weighting.
    epsilon : float, default=1e-4
        Small value for numerical stability in score rescaling.
    first_stage_ratio : float, default=0.5
        Fraction of compression budget allocated to first stage selection.
        Remaining budget used in second stage with output projection weighting.
    """

    def __init__(self, press: ScorerPress, epsilon: float = 1e-4, first_stage_ratio: float = 0.5):
        self.press = press
        self.epsilon = epsilon
        self.first_stage_ratio = first_stage_ratio

        assert isinstance(self.press, ScorerPress), "CriticalKVPress requires a ScorerPress as input"
        if isinstance(self.press, ExpectedAttentionPress) and self.press.use_vnorm:
            logger.warning("use_vnorm should be disabled for CriticalKVPress")

    def post_init_from_model(self, model):
        self.press.post_init_from_model(model)

    @property  # type: ignore[misc]
    def compression_ratio(self):
        return self.press.compression_ratio

    @compression_ratio.setter
    def compression_ratio(self, value):
        self.press.compression_ratio = value

    @staticmethod
    def vwl1norm(values, module):
        bsz, num_key_value_heads, k_len, _ = values.shape
        num_key_value_groups = module.config.num_attention_heads // num_key_value_heads

        Wo = module.o_proj.weight.transpose(0, 1)
        Wo = Wo.view(module.config.num_attention_heads, module.config.head_dim, module.config.hidden_size)
        V = repeat_kv(values, num_key_value_groups)

        # We use head-wise computation instead of direct matmul to reduce the memory usage of WoV.
        # Future kernel fusion optimization could eliminate these intermediate variables to enhance performance.
        head_WoV_norm_list = []
        for head in range(V.size(1)):
            head_WoV = V[:, head, :, ...].matmul(Wo[head, ...].unsqueeze(0))
            head_WoV_norm = torch.norm(head_WoV, p=1, dim=-1)
            head_WoV_norm_list.append(head_WoV_norm)

        # b_size, num_heads, k_len, k_len
        WoV_norm = torch.stack(head_WoV_norm_list, dim=1)
        WoV_norm = WoV_norm.view(bsz, num_key_value_heads, module.num_key_value_groups, k_len).mean(dim=2)
        return WoV_norm

    def score(self, module, hidden_states, keys, values, attentions, kwargs):
        # Stage 1
        scores = self.press.score(module, hidden_states, keys, values, attentions, kwargs)
        k_len = keys.shape[2]
        selection_budget = int((1 - self.compression_ratio) * k_len * self.first_stage_ratio)
        top_k_index = torch.topk(scores, selection_budget, sorted=True, dim=-1).indices

        # Stage 2
        projected_norm = self.vwl1norm(values, module)
        scores = (scores + self.epsilon) * projected_norm

        # Merge the two stages
        scores.scatter_(-1, top_k_index, torch.finfo(scores.dtype).max)
        return scores


@dataclass
class CriticalAdaKVPress(BasePress):
    """
    CriticalAdaKV: Combined two-stage compression with adaptive head-wise selection.

    Combines output projection weighting from CriticalKV with adaptive head-wise
    compression from AdaKV. Provides both accurate importance estimation and
    head-specific compression adaptation.

    Based on CriticalAdaKV (https://arxiv.org/abs/2502.03805).

    Parameters
    ----------
    press : ScorerPress
        The underlying scoring method used to evaluate token importance.
    alpha_safeguard : float, default=0.20
        Minimum fraction of KV pairs that each head must retain. (see AdaKVPress)
    epsilon : float, default=1e-4
        Small value for numerical stability in score rescaling.
    first_stage_ratio : float, default=0.5
        Fraction of compression budget allocated to first stage selection.
    """

    press: ScorerPress = None
    alpha_safeguard: float = 0.20
    epsilon: float = 1e-4
    first_stage_ratio: float = 0.5

    def __post_init__(self):
        assert 0 <= self.alpha_safeguard <= 1, "alpha_safeguard should be in [0, 1]"
        assert isinstance(self.press, ScorerPress), "CriticalAdaKVPress requires a ScorerPress as input"
        if isinstance(self.press, ExpectedAttentionPress) and self.press.use_vnorm:
            logger.warning("use_vnorm should be disabled for CriticalAdaKVPress")

    @property
    def compression_ratio(self):
        return self.press.compression_ratio

    @compression_ratio.setter
    def compression_ratio(self, value):
        self.press.compression_ratio = value

    def compress(self, module, hidden_states, keys, values, attentions, kwargs):
        if self.compression_ratio == 0:
            return keys, values

        assert module.config._attn_implementation != "eager", "eager mode not supported"

        # Compute scores
        scores = self.press.score(module, hidden_states, keys, values, attentions, kwargs)
        bsz, num_key_value_heads, k_len = scores.shape

        # Make sure to keep at least alpha * (1 - compression_ratio) KV pairs per head
        n_kept = int(k_len * (1 - self.compression_ratio))  # ScorerPress definition
        n_safe = int(n_kept * self.alpha_safeguard)
        top_indices = torch.topk(scores, n_safe, dim=-1).indices
        scores.scatter_(-1, top_indices, torch.finfo(scores.dtype).max)

        ############################
        # Start of CriticalKV code #
        ############################

        # Budget allocation
        budget_scores = scores.scatter(-1, top_indices, torch.finfo(scores.dtype).max)
        budget_scores = budget_scores.reshape(bsz, -1)
        top_indices = torch.topk(budget_scores, n_kept * num_key_value_heads, dim=-1).indices
        top_indices_head_idx = top_indices // k_len
        head_budgets = torch.zeros(num_key_value_heads, device=keys.device, dtype=torch.int64)
        head_budgets.scatter_add_(0, top_indices_head_idx.flatten(), torch.ones_like(top_indices_head_idx.flatten()))

        # Stage 1
        head_selection_budget_1st = (head_budgets * self.first_stage_ratio).to(torch.int64).tolist()
        top_k_index = torch.topk(scores, max(head_selection_budget_1st), sorted=True, dim=-1).indices
        for head_idx in range(num_key_value_heads):
            phase1_budget = head_selection_budget_1st[head_idx]
            scores[:, head_idx, :].scatter_(
                -1, top_k_index[:, head_idx, :phase1_budget], torch.finfo(scores.dtype).max
            )

        # Stage 2
        projected_norm = CriticalKVPress.vwl1norm(values, module)
        scores = (scores + self.epsilon) * projected_norm
        top_k_index = torch.topk(scores, max(head_budgets), sorted=True, dim=-1).indices
        for head_idx in range(num_key_value_heads):
            budget = head_budgets[head_idx]
            scores[:, head_idx, :].scatter_(-1, top_k_index[:, head_idx, :budget], torch.finfo(scores.dtype).max)

        ##########################
        # End of CriticalKV code #
        ##########################

        # Compute bottom-k across heads
        n_pruned = num_key_value_heads * (k_len - n_kept)
        indices = torch.topk(-scores.reshape(bsz, -1), n_pruned, dim=1).indices.flatten()

        # Save indices to mask during the attention mechanism. Please refer to attention_patch.py for more details
        batch_indices = torch.arange(bsz).repeat_interleave(n_pruned)
        head_indices = indices // k_len
        seq_indices = indices % k_len
        module.masked_key_indices = (batch_indices, head_indices, seq_indices)

        return keys, values
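A construction sketch (not part of the committed file) for the two classes above. It assumes ExpectedAttentionPress is exported from kvpress and exposes a use_vnorm flag, as the warnings in this file imply.

```python
# Sketch: CriticalKV wrapping an expected-attention scorer.
from kvpress import CriticalAdaKVPress, CriticalKVPress, ExpectedAttentionPress  # assumed exports

scorer = ExpectedAttentionPress(compression_ratio=0.5, use_vnorm=False)  # use_vnorm off, per the warning above
press = CriticalKVPress(scorer)

# Head-wise variant with the same scorer and the AdaKV safeguard:
ada_press = CriticalAdaKVPress(press=scorer, alpha_safeguard=0.2)
```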
kvpress/kvpress/presses/cur_press.py  (new file, mode 100644)

# SPDX-FileCopyrightText: Copyright (c) 1993-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0

import math
from dataclasses import dataclass
from typing import Literal

import torch
import torch.nn.functional as F

from kvpress.presses.scorer_press import ScorerPress


@dataclass
class CURPress(ScorerPress):
    """
    Press based on `CurDKV` (https://arxiv.org/abs/2509.15038) which computes approximate leverage scores
    for keys (k2) and values (v2) and combines them to prune the KV cache.

    If `use_random_leverage` is True (default is False), keys and values are first
    multiplied by a random projection matrix G.
    If `use_local_approximation` is True (default), the scores are averaged over a
    local window of size `local_window_size`.
    Depending on `leverage_type`, returns either k2, v2, (k2 + v2) / 2, or k2 * v2 (default).
    Finally, the first `num_sinks` tokens are set to 1.0 to preserve some initial "attention sinks".
    """

    num_sinks: int = 4
    leverage_type: Literal["key", "value", "kv_avg", "kv_product"] = "kv_product"
    use_random_leverage: bool = False
    use_local_approximation: bool = True
    local_window_size: int = 16

    def score(self, module, hidden_states, keys, values, attentions, kwargs):
        if self.use_random_leverage:
            r = 20
            G = torch.randn(keys.shape[-1], r, device=keys.device) / math.sqrt(r)
            keys = keys @ G
            values = values @ G

        k2 = (keys**2).sum(dim=-1)
        v2 = (values**2).sum(dim=-1)

        if self.use_local_approximation:
            b, h, n = k2.shape
            w = self.local_window_size
            k2 = F.pad(k2, (0, (w - n % w) % w)).reshape(b, h, -1, w)
            k2 = (k2 / k2.sum(dim=-1, keepdim=True)).reshape(b, h, -1)[:, :, :n]
            v2 = F.pad(v2, (0, (w - n % w) % w)).reshape(b, h, -1, w)
            v2 = (v2 / v2.sum(dim=-1, keepdim=True)).reshape(b, h, -1)[:, :, :n]

        if self.leverage_type == "key":
            scores = k2
        elif self.leverage_type == "value":
            scores = v2
        elif self.leverage_type == "kv_avg":
            scores = (k2 + v2) / 2
        elif self.leverage_type == "kv_product":
            scores = k2 * v2
        else:
            raise ValueError("Unknown leverage type: choose from 'kv_avg', 'key', 'value' or 'kv_product'")

        scores /= scores.sum(dim=-1, keepdim=True)
        scores[:, :, : self.num_sinks] = 1.0
        return scores
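CURPress is a standalone ScorerPress, so a construction sketch (not part of the committed file) only needs its own parameters; compression_ratio is assumed to be inherited from ScorerPress.

```python
# Sketch: leverage-score pruning with attention sinks preserved.
from kvpress import CURPress  # assumed top-level export

press = CURPress(
    compression_ratio=0.5,
    leverage_type="kv_product",  # combine key and value leverage scores
    num_sinks=4,                 # always keep the first 4 tokens
)
```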
kvpress/kvpress/presses/decoding_press.py
0 → 100644
View file @
c16d506e
# SPDX-FileCopyrightText: Copyright (c) 1993-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
import
logging
from
collections
import
defaultdict
from
dataclasses
import
dataclass
import
torch
import
torch.nn
as
nn
from
transformers.cache_utils
import
QuantizedCache
from
kvpress.presses.adakv_press
import
AdaKVPress
from
kvpress.presses.base_press
import
BasePress
from
kvpress.presses.scorer_press
import
ScorerPress
from
kvpress.utils
import
extract_keys_and_values
logger
=
logging
.
getLogger
(
__name__
)
@
dataclass
class
DecodingPress
(
BasePress
):
"""
A press that only operates during decoding phase and maintains a running buffer of hidden states.
This press accumulates hidden states during decoding and applies compression every N steps
using a scorer press to determine which tokens to keep.
Parameters
----------
base_press : ScorerPress
The scorer press used to compute importance scores for tokens.
compression_interval : int, default=512
Number of decoding steps between compression, i.e. compression will be applied every compression_interval steps.
target_size : int, default=2048
Target number of tokens to keep after compression.
hidden_states_buffer_size : int, default=256
Maximum number of hidden states to keep before compression. Larger values use more GPU memory.
Note: Some presses don't need buffered hidden states and can set this to 0 to use only the
current hidden state for compression scoring.
"""
    base_press: ScorerPress | AdaKVPress
    compression_interval: int = 512
    target_size: int = 2048
    hidden_states_buffer_size: int = 256

    def __post_init__(self):
        # Buffer to store hidden states during decoding (per layer)
        assert isinstance(self.base_press, (ScorerPress, AdaKVPress)), "DecodingPress requires a ScorerPress as input"
        self.hidden_states_buffer = defaultdict(list)  # Per-layer buffer
        self.layer_step_counts = defaultdict(int)  # Track step count per layer
        assert self.compression_interval > 0, "compression_interval must be greater than 0"
        assert self.target_size > 0, "target_size must be greater than 0"
        if self.base_press.compression_ratio:
            logger.warning(
                f"compression_ratio is set for base press ({self.base_press.compression_ratio}). "
                f"This will be overridden by the decoding press."
            )

    def post_init_from_model(self, model):
        self.base_press.post_init_from_model(model)

    def compress(
        self,
        module: nn.Module,
        hidden_states: torch.Tensor,
        keys: torch.Tensor,
        values: torch.Tensor,
        attentions: torch.Tensor,
        kwargs: dict,
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """
        Delegate compression to the base press during decoding phase.

        Args:
            module: The transformer module being compressed
            hidden_states: Buffered hidden states from recent decoding steps (shape: [batch, buffer_len, hidden_dim])
            keys: Key cache from all previous steps including current (shape: [batch, n_heads, seq_len, head_dim])
            values: Value cache from all previous steps including current (shape: [batch, n_heads, seq_len, head_dim])
            attentions: Attention weights (shape varies by implementation)
            kwargs: Additional keyword arguments

        Returns:
            tuple[torch.Tensor, torch.Tensor]: Compressed (keys, values) tensors

        Note:
            **Sequence length alignment**: During decoding compression, `hidden_states` contains the
            buffered hidden states from recent decoding steps (buffer_len tokens), while `keys` and
            `values` contain the full sequence history (seq_len tokens). The base press implementation
            should use keys.shape[2] for full sequence length calculations. The buffered hidden_states
            provide context for the most recent tokens when computing compression scores.

        Performance Note:
            It would be possible to speed up compression during decoding for certain scorer presses by
            storing existing scores in a buffer (e.g. KNormPress) and reusing them in subsequent compressions.
        """
        k_len = keys.shape[2]
        target_compression_ratio = self._find_target_compression_ratio(k_len, self.target_size)
        logger.debug(f"Compressing {k_len} to {self.target_size} with ratio {target_compression_ratio}")

        original_compression_ratio = self.base_press.compression_ratio
        self.base_press.compression_ratio = target_compression_ratio
        result = self.base_press.compress(module, hidden_states, keys, values, attentions, kwargs)
        self.base_press.compression_ratio = original_compression_ratio
        return result

    def forward_hook(self, module: nn.Module, input: list[torch.Tensor], kwargs: dict, output: list):
        """
        Forward hook that manages decoding-specific compression logic.

        This hook:
        1. Detects when we're in decoding phase (not prefilling)
        2. Accumulates hidden states in a buffer
        3. Applies compression every N steps
        4. Clears the buffer after compression
        """
        hidden_states = kwargs["hidden_states"]
        cache = kwargs["past_key_values"]
        q_len = hidden_states.shape[1]
        layer_idx = module.layer_idx

        # Only operate during decoding phase (after prefilling)
        if kwargs["cache_position"][-1] <= q_len:
            # We're still in prefilling phase, don't do anything
            return output

        # Add current hidden states to buffer for this layer
        self.hidden_states_buffer[layer_idx].append(hidden_states.detach().clone())
        self.layer_step_counts[layer_idx] += 1

        # Apply compression if we've reached the compression step threshold
        if (self.layer_step_counts[layer_idx] >= self.compression_interval) or (q_len >= self.target_size):
            logger.debug(
                f"Applying decoding compression: layer_step_count ({self.layer_step_counts[layer_idx]}) "
                f">= compression_steps ({self.compression_interval})"  # noqa: E501
            )
            cache_layer = cache.layers[module.layer_idx]
            keys, values = extract_keys_and_values(cache, module.layer_idx)

            # Get attention weights from output
            attentions = output[1] if len(output) > 1 and output[1] is not None else None

            # Apply compression using buffered hidden states for this layer
            buffered_hidden_states = torch.cat(self.hidden_states_buffer[layer_idx], dim=1)
            keys, values = self.compress(module, buffered_hidden_states, keys, values, attentions, kwargs)
            logger.debug(f"Applied decoding compression: keys.shape: {keys.shape}, values.shape: {values.shape}")

            # Update cache with compressed keys and values
            if isinstance(cache, QuantizedCache):
                cache_layer._quantized_keys = cache_layer._quantize(keys, axis=cache_layer.axis_key)
                cache_layer._quantized_values = cache_layer._quantize(values, axis=cache_layer.axis_value)
                cache_layer.keys = torch.zeros(0, dtype=keys.dtype, device=keys.device)  # type: ignore[index]
                cache_layer.values = torch.zeros(0, dtype=keys.dtype, device=keys.device)  # type: ignore[index]
                cache_layer.cumulative_length = keys.shape[2]
            else:
                cache_layer.keys = keys
                cache_layer.values = values

            # Reset step count and clear buffer for this layer
            self.layer_step_counts[layer_idx] = 0
            # Always clear the buffer after compression - otherwise there's a mismatch between
            # hidden states buffer and kv cache
            self.hidden_states_buffer[layer_idx] = []

        self.hidden_states_buffer[layer_idx] = (
            self.hidden_states_buffer[layer_idx][-self.hidden_states_buffer_size:]
            if self.hidden_states_buffer_size > 0
            else []
        )
        return output

    def reset(self):
        """Reset the decoding press state."""
        self.hidden_states_buffer = defaultdict(list)
        self.layer_step_counts = defaultdict(int)

    def _find_target_compression_ratio(self, q_len: int, target_tokens: int) -> float:
        """
        Find the compression ratio that results in exactly target_tokens after int() rounding.

        Args:
            q_len: Current sequence length
            target_tokens: Desired number of tokens after compression

        Returns:
            Compression ratio that gives exactly target_tokens
        """
        if q_len <= target_tokens:
            return 0.0

        # Start with theoretical ratio
        ratio = 1.0 - (target_tokens / q_len)

        # Binary search to handle int() rounding
        low, high = 0.0, 1.0
        max_iterations = 20
        iteration = 0
        while iteration < max_iterations:
            n_kept = int(q_len * (1 - ratio))
            if n_kept == target_tokens:
                break
            elif n_kept > target_tokens:
                # Need more compression
                low = ratio
                ratio = (ratio + high) / 2
            else:
                # Need less compression
                high = ratio
                ratio = (low + ratio) / 2
            iteration += 1

        final_n_kept = int(q_len * (1 - ratio))
        if final_n_kept != target_tokens:
            logger.warning(
                f"Binary search failed: q_len={q_len}, target={target_tokens}, got={final_n_kept}, ratio={ratio}"
            )
        return ratio
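A minimal usage sketch for the press above: it assumes `kvpress` exposes `DecodingPress` and `KnormPress` at the package level; the model name, interval, and target size are illustrative, not prescribed by this file.

```python
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

from kvpress import DecodingPress, KnormPress  # assumed package-level exports

model_name = "Qwen/Qwen3-8B"  # illustrative checkpoint
model = AutoModelForCausalLM.from_pretrained(model_name, dtype=torch.bfloat16, device_map="auto")
tokenizer = AutoTokenizer.from_pretrained(model_name)

# Wrap a ScorerPress: every 512 decoding steps the KV cache is pressed back to ~2048 tokens
press = DecodingPress(base_press=KnormPress(), compression_interval=512, target_size=2048)

inputs = tokenizer("Write a long story about KV caches.", return_tensors="pt").to(model.device)
with press(model):  # hooks are only active inside the context manager
    outputs = model.generate(**inputs, max_new_tokens=4096)
print(tokenizer.decode(outputs[0], skip_special_tokens=True))
```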
kvpress/kvpress/presses/dms_press.py
0 → 100644
# SPDX-FileCopyrightText: Copyright (c) 1993-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
from dataclasses import dataclass, field
from typing import Optional

import torch
import torch.nn as nn

from kvpress.presses.base_press import BasePress
from kvpress.presses.scorer_press import ScorerPress
from kvpress.utils import extract_keys_and_values


@dataclass
class DMSPress(BasePress):
    """
    Based on Dynamic Memory Sparsification (DMS, https://arxiv.org/abs/2506.05345) inference.
    Wraps a ScorerPress and evicts keys/values with scores below a given threshold.
    This press implements a dense-prefill version of DMS, not the sparse-prefill version.

    Unlike most presses that use a fixed compression_ratio, DMSPress uses a score threshold
    to determine which KV pairs to evict. This allows for adaptive compression where the actual
    compression ratio depends on the input content.

    Importantly, this press can be used both during prefilling and during decoding (if decoding=True).
    A sliding window protects the most recent tokens from eviction, ensuring that recently
    generated tokens are always available for attention.

    Parameters
    ----------
    press : ScorerPress
        The underlying scorer press used to compute importance scores for each token.
    threshold : float, optional
        Tokens with scores below this threshold are evicted. The optimal threshold
        depends on the scorer press being used.
    sliding_window_size : int, default=128
        Number of recent tokens protected from eviction.
    decoding : bool, default=False
        If True, compression is also applied during the decoding phase (token generation).
        If False, compression only occurs during prefill.
    """

    press: ScorerPress
    threshold: Optional[float] = None
    sliding_window_size: int = 128
    decoding: bool = False
    scores_buffer: dict[int, torch.Tensor] = field(default_factory=dict, init=False, repr=False)
    compression_ratios: dict[int, float] = field(default_factory=dict, init=False, repr=False)

    def post_init_from_model(self, model):
        self.press.post_init_from_model(model)

    @property
    def compression_ratio(self):
        """Average compression ratio across all layers (computed after forward pass)."""
        assert len(self.compression_ratios) > 0, "Forward pass must be run to compute the compression ratio"
        return sum(self.compression_ratios.values()) / len(self.compression_ratios)

    @compression_ratio.setter
    def compression_ratio(self, value):
        """Compression ratio is read-only since it depends on threshold and input content."""
        raise AttributeError(f"compression ratio cannot be set for {type(self).__name__}")

    def forward_hook(self, module: nn.Module, input: list[torch.Tensor], kwargs: dict, output: list):
        hidden_states = kwargs["hidden_states"]
        cache = kwargs["past_key_values"]
        q_len = hidden_states.shape[1]
        cache_len = kwargs["cache_position"][-1] + 1
        prefilling = cache_len == q_len

        # Extract layer index as int for type safety
        layer_idx: int = module.layer_idx  # type: ignore[assignment]

        # Reset the scores buffer and compression ratios if we are in prefilling
        if prefilling and (layer_idx == 0):
            self.scores_buffer.clear()
            self.compression_ratios.clear()

        # Skip compression during decoding if not enabled
        if not prefilling and not self.decoding:
            return output

        # Compute importance scores for the new tokens using the underlying scorer press
        keys, values = extract_keys_and_values(cache, layer_idx)
        scores = self.press.score(module, hidden_states, keys[:, :, -q_len:], values[:, :, -q_len:], None, kwargs)

        # Accumulate scores in the buffer: reset during prefill, append during decoding
        if prefilling:
            self.scores_buffer[layer_idx] = scores
        else:
            self.scores_buffer[layer_idx] = torch.cat([self.scores_buffer[layer_idx], scores], dim=-1)

        # Once the buffer exceeds the sliding window, evict tokens with low scores
        if self.scores_buffer[layer_idx].shape[-1] > self.sliding_window_size:
            # Determine how many tokens have left the sliding window and can be evicted
            n_to_evict = self.scores_buffer[layer_idx].shape[-1] - self.sliding_window_size
            scores_to_evict = self.scores_buffer[layer_idx][..., :n_to_evict]
            self.scores_buffer[layer_idx] = self.scores_buffer[layer_idx][..., n_to_evict:]

            # Find tokens below threshold: returns (batch_idx, head_idx, token_idx) tuples
            new_masked_key_indices = list(torch.where(scores_to_evict < self.threshold))
            if len(new_masked_key_indices[0]) > 0:
                # Convert buffer-relative indices to cache-absolute indices
                # During prefill shift=0; during decoding we offset by the number of previously processed tokens
                shift = cache_len - scores_to_evict.shape[2] - self.sliding_window_size
                new_masked_key_indices[-1] += shift

                # Merge new masked indices with existing ones
                if module.masked_key_indices is None:
                    module.masked_key_indices = new_masked_key_indices  # type: ignore[assignment]
                else:
                    module.masked_key_indices = list(  # type: ignore[assignment]
                        torch.cat([i, new_i])
                        for i, new_i in zip(module.masked_key_indices, new_masked_key_indices)
                    )

        # Track compression ratio as the fraction of masked tokens
        if module.masked_key_indices is not None:
            bsz, num_key_value_heads, cache_len, _ = keys.shape
            n_masked = len(module.masked_key_indices[0])  # type: ignore[index]
            self.compression_ratios[layer_idx] = n_masked / (bsz * num_key_value_heads * cache_len)
        else:
            self.compression_ratios[layer_idx] = 0

        return output
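A short sketch of the threshold-based usage described in the docstring, reusing `model` and `inputs` from the DecodingPress sketch above; it assumes `DMSPress` and `SnapKVPress` are package-level exports, and the threshold value is purely illustrative.

```python
from kvpress import DMSPress, SnapKVPress  # assumed package-level exports

# Tokens that leave the 128-token sliding window with a score below the threshold are masked out
press = DMSPress(press=SnapKVPress(), threshold=1e-3, sliding_window_size=128, decoding=True)

with press(model):
    outputs = model.generate(**inputs, max_new_tokens=512)

# The achieved ratio depends on the input and is only defined after the forward pass
print(f"average compression ratio: {press.compression_ratio:.2f}")
```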
kvpress/kvpress/presses/duo_attention_press.py
0 → 100644
# SPDX-FileCopyrightText: Copyright (c) 1993-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
from dataclasses import dataclass, field
from io import StringIO

import numpy as np
import requests  # type: ignore[import-untyped]
import torch
from cachetools import LRUCache, cached  # type: ignore[import-untyped]
from datasets import load_dataset
from transformers import AutoTokenizer
from transformers.models.gemma3.modeling_gemma3 import Gemma3Attention
from transformers.models.llama.modeling_llama import apply_rotary_pos_emb
from transformers.models.qwen3.modeling_qwen3 import Qwen3Attention

from kvpress.presses.base_press import BasePress

PATTERNS_DICT = {
    "togethercomputer/Llama-2-7B-32K-Instruct": "Llama-2-7B-32K-Instruct/lr%3D0.02-reg%3D0.05-ctx%3D1000_32000-multi_passkey10",  # noqa: E501
    "gradientai/Llama-3-8B-Instruct-Gradient-1048k": "Llama-3-8B-Instruct-Gradient-1048k/lr%3D0.02-reg%3D0.05-ctx%3D1000_32000-multi_passkey10",  # noqa: E501
    "gradientai/Llama-3-8B-Instruct-Gradient-4194k": "Llama-3-8B-Instruct-Gradient-4194k/lr%3D0.02-reg%3D0.05-ctx%3D1000_32000-multi_passkey10",  # noqa: E501
    "meta-llama/Meta-Llama-3.1-8B-Instruct": "Meta-Llama-3.1-8B-Instruct/lr=0.02-reg=0.05-ctx=1000_128000-multi_passkey10",  # noqa: E501
    "mistralai/Mistral-7B-Instruct-v0.2": "Mistral-7B-Instruct-v0.2/lr%3D0.02-reg%3D0.05-ctx%3D1000_32000-multi_passkey10",  # noqa: E501
    "mistralai/Mistral-7B-Instruct-v0.3": "Mistral-7B-Instruct-v0.3/lr%3D0.02-reg%3D0.05-ctx%3D1000_32000-multi_passkey10",  # noqa: E501
}

cache = LRUCache(maxsize=128)


@dataclass
class DuoAttentionPress(BasePress):
    """
    DuoAttention: Hybrid attention with retrieval and streaming heads.

    Splits attention heads into two types:
    - retrieval heads (use full KV cache) and
    - streaming heads (use only sink + recent tokens).
    Different heads have different attention patterns - some benefit from full context while others work well with
    limited context.

    Uses pre-computed attention patterns for supported models, falls back to
    on-the-fly computation for unsupported models.

    Based on DuoAttention (https://arxiv.org/abs/2410.10819).

    Parameters
    ----------
    head_compression_ratio : float, default=0.0
        Fraction of attention heads to convert to streaming heads.
        Controls balance between retrieval (full cache) and streaming (limited cache) heads.
    on_the_fly_scoring : bool, default=False
        Whether to compute attention patterns on-the-fly using random samples.
        If True, computes patterns instead of loading pre-computed ones.
    compression_ratio_ : float
        Actual compression ratio achieved (computed during forward pass).
    recent_size : int
        Size of recent token window for streaming heads (determined automatically).
    sink_size : int
        Number of initial tokens preserved for streaming heads (determined automatically).
    streaming_mask : torch.Tensor
        Binary mask indicating which heads are streaming heads.
    """

    head_compression_ratio: float = 0.0
    on_the_fly_scoring: bool = False
    compression_ratio_: float = field(init=False, default=None)
    recent_size: int = field(init=False, default=None)
    sink_size: int = field(init=False, default=None)
    streaming_mask: torch.Tensor = field(init=False, default=None)

    def post_init_from_model(self, model):
        """
        Initialize sink_size, recent_size, and streaming_mask from a model
        """
        # Load attention pattern from the DuoAttention repo
        if self.on_the_fly_scoring:
            self.sink_size, self.recent_size, head_scores = 128, 256, duo_attention_on_the_fly(model)
        else:
            self.sink_size, self.recent_size, head_scores = self.load_attention_pattern(model)

        # Define retrieval and streaming heads through a binary mask
        n_pruned = round(head_scores.size * self.head_compression_ratio)
        self.streaming_mask = torch.zeros(head_scores.shape, dtype=bool, device=model.device)
        if n_pruned > 0:
            indices = np.argsort(head_scores, axis=None)[:n_pruned]
            self.streaming_mask[np.unravel_index(indices, head_scores.shape)] = True

    @property
    def compression_ratio(self) -> float:
        assert self.compression_ratio_ is not None, "Forward pass must be run to compute the compression ratio"
        return self.compression_ratio_

    @compression_ratio.setter
    def compression_ratio(self, value):
        raise AttributeError(f"compression ratio cannot be set for {type(self).__name__}")

    def compress(self, module, hidden_states, keys, values, attentions, kwargs):
        assert module.config._attn_implementation != "eager", "eager mode not supported"
        if self.streaming_mask is None:
            raise ValueError(
                "Streaming mask not initialized. Make sure to call post_init_from_model to initialize this press."
            )
        k_len = keys.shape[2]
        if (self.head_compression_ratio > 0) or (k_len > (self.sink_size + self.recent_size)):
            # Save indices to mask during the attention mechanism. Please refer to attention_patch.py for more details
            masked_keys = torch.zeros_like(keys[..., 0], dtype=torch.bool)
            masked_keys[:, self.streaming_mask[module.layer_idx], self.sink_size:-self.recent_size] = True
            module.masked_key_indices = torch.nonzero(masked_keys, as_tuple=True)

        # Compute the compression ratio
        self.compression_ratio_ = self.streaming_mask.float().mean().item()
        self.compression_ratio_ *= 1 - (self.sink_size + self.recent_size) / k_len
        return keys, values

    @staticmethod
    @cached(cache, key=lambda model: model.config.name_or_path)
    def load_attention_pattern(model):
        """
        Load the attention pattern from the DuoAttention repo
        """
        assert (
            model.config.name_or_path in PATTERNS_DICT
        ), f"Checkpoint {model.config.name_or_path} not in {list(PATTERNS_DICT.keys())}"

        base_url = "https://raw.githubusercontent.com/mit-han-lab/duo-attention/refs/heads/main/attn_patterns"
        url = f"{base_url}/{PATTERNS_DICT[model.config.name_or_path]}/"

        # Load config
        config = requests.get(url + "config.json").json()

        # Load head scores and clip as in duo_attn.utils.load_attn_pattern
        text = requests.get(url + "full_attention_heads.tsv").text
        head_scores = np.loadtxt(StringIO(text), dtype=float, delimiter="\t")
        head_scores = np.clip(head_scores, 0, 1)
        return config["sink_size"], config["recent_size"], head_scores


@cached(cache, key=lambda model, num_samples=50, q_len=500: (model.config.name_or_path, num_samples, q_len))
def duo_attention_on_the_fly(model, num_samples=50, q_len=500):
    """
    New experimental method to quickly compute DuoAttention scores:
    - Compute the mean query and key on num_samples random samples from BookSum
    - Repeat the mean query and key q_len times and apply RoPE to get (Q, K)
    - Compute the attention weights for (Q[-1], K) and compute the "area under the cumulated attention curve"
    These scores could also be saved to avoid recomputing them but this method is still experimental
    """
    tokenizer = AutoTokenizer.from_pretrained(model.config.name_or_path)
    num_heads = model.config.num_attention_heads
    num_key_value_heads = model.config.num_key_value_heads
    num_key_value_groups = num_heads // num_key_value_heads

    # Load data
    dataset = load_dataset("kmfoda/booksum", split="train").to_pandas()
    texts = dataset.sample(num_samples, random_state=42)["chapter"].tolist()

    # Initialize variables
    position_ids = torch.arange(q_len).unsqueeze(0)
    scores = torch.zeros((model.config.num_hidden_layers, num_key_value_heads))

    # Compute scores
    for text in texts:
        with torch.no_grad():
            # Compute hidden states
            inputs = tokenizer(text, return_tensors="pt").to(model.device)
            hidden_states = list(model(**inputs, output_hidden_states=True).hidden_states[:-1])

        for layer_idx, h in enumerate(hidden_states):
            module = model.model.layers[layer_idx]
            d = module.self_attn.head_dim
            h = module.input_layernorm(h)

            # Mean query
            q = module.self_attn.q_proj(h)
            q = q.view(1, q.shape[1], -1, d)
            if isinstance(module, (Gemma3Attention, Qwen3Attention)):
                q = module.q_norm(q)
            q = q.mean(dim=1, keepdim=True)
            q = q.repeat(1, q_len, 1, 1).transpose(1, 2)

            # Mean key
            k = module.self_attn.k_proj(h)
            k = k.view(1, k.shape[1], -1, d)
            if isinstance(module, (Gemma3Attention, Qwen3Attention)):
                k = module.k_norm(k)
            k = k.mean(dim=1, keepdim=True)
            k = k.repeat(1, q_len, 1, 1).transpose(1, 2)

            # Apply RoPE
            cos, sin = model.model.rotary_emb(h, position_ids.to(h.device))
            q, k = apply_rotary_pos_emb(q, k, cos, sin)
            k = k.repeat_interleave(num_key_value_groups, dim=1)

            # Compute attention weights for the last token
            attn_weights = torch.matmul(q[:, :, -1:, :], k.transpose(2, 3)) / (d**0.5)
            attn_weights = attn_weights.softmax(dim=-1, dtype=torch.float32).squeeze()

            # Compute score: area under the cumulated attention curve
            s = torch.cumsum(attn_weights, dim=1).mean(1)
            s = s.view(-1, num_key_value_groups).mean(1)

            # Store the scores
            scores[layer_idx] += s.cpu() / num_samples

    return scores.numpy()
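Sketch of the two initialization paths (pre-computed pattern vs. on-the-fly scoring), reusing `model` and `inputs` from the first sketch; `DuoAttentionPress` is assumed to be a package-level export and the ratio is illustrative.

```python
from kvpress import DuoAttentionPress  # assumed package-level export

# Pre-computed patterns only exist for the checkpoints in PATTERNS_DICT; any other model
# needs on_the_fly_scoring=True, which estimates head scores from BookSum samples.
press = DuoAttentionPress(head_compression_ratio=0.5, on_the_fly_scoring=True)
press.post_init_from_model(model)  # fills sink_size, recent_size and streaming_mask

with press(model):
    outputs = model.generate(**inputs, max_new_tokens=128)
print(f"effective compression ratio: {press.compression_ratio:.2f}")
```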
kvpress/kvpress/presses/expected_attention_press.py
0 → 100644
# SPDX-FileCopyrightText: Copyright (c) 1993-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
import math
from dataclasses import dataclass

import torch
from torch import nn
from torch.nn import functional as F
from transformers.models.llama.modeling_llama import repeat_kv

from kvpress.presses.scorer_press import ScorerPress
from kvpress.utils import get_prerope_query_states


@dataclass
class ExpectedAttentionPress(ScorerPress):
    """
    Expected attention-based KV cache compression.

    Computes importance scores based on expected attention that future queries
    will pay to current key-value pairs. Uses statistical modeling of query
    patterns and RoPE rotation matrices to predict future attention.

    In particular:
    1. Compute the mean and covariance matrix of the queries before RoPE.
    2. Compute the RoPE rotation matrix R on the next n_future_positions and average it.
    3. Apply R to the mean and covariance matrices of the queries.
    4. As attention A = exp(Q @ K / sqrt(d)), we compute the expected attention
       E(A) = exp(K @ mean.T / sqrt(d) + 1/2 K @ cov @ K.T / d)
    5. Rescale the scores using (scores + epsilon) * ||V||_2

    Parameters
    ----------
    compression_ratio : float, default=0.0
        Fraction of key-value pairs to remove during compression.
    n_future_positions : int, default=512
        Number of future positions to consider when computing expected attention.
    n_sink : int, default=4
        Number of initial tokens to exclude from compression (sink tokens).
        Preserves first few tokens due to "sink attention" phenomenon where models
        assign high attention to early tokens regardless of semantic importance.
    use_covariance : bool, default=True
        Whether to include covariance information in expected attention computation.
        When True, uses both mean and covariance of query distributions for more
        accurate but computationally expensive scoring. When False, uses only mean.
    use_vnorm : bool, default=True
        Whether to rescale scores using value vector norms.
        Rescales expected attention scores by L2 norm of corresponding value vectors:
        (scores + epsilon) * ||V||₂. Accounts for magnitude of attended information.
    epsilon : float, default=0.0
        Small constant added to scores before value norm rescaling for numerical stability.
    """

    compression_ratio: float = 0.0
    n_future_positions: int = 512
    n_sink: int = 4
    use_covariance: bool = True
    use_vnorm: bool = True
    epsilon: float = 0.0

    def get_query_statistics(self, module: nn.Module, hidden_states: torch.Tensor):
        """
        Compute the mean and covariance matrix of the queries
        """
        q_len = hidden_states.shape[1]

        # Remove first hidden_states that likely contain outliers
        h = hidden_states[:, self.n_sink:]
        query_states = get_prerope_query_states(module, h)

        # Query mean
        mu = query_states.mean(dim=2, keepdim=True)

        # Query covariance
        cov = None
        if self.use_covariance:
            centered_states = query_states - mu
            cov = torch.einsum("bnsi,bnsj->bnij", centered_states, centered_states) / h.shape[1]
        mu = mu.squeeze(2)

        # Apply RoPE to the mean and covariance matrix of the queries
        mu, cov = self.apply_avg_rope(module, mu, cov, q_len)
        return mu, cov

    def apply_avg_rope(self, module: nn.Module, mu: torch.Tensor, cov: torch.Tensor, q_len: int):
        """
        Apply average RoPE to the mean and covariance matrix of the queries

        Parameters
        ----------
        module : nn.Module
            The module to apply RoPE to.
        mu : torch.Tensor
            The mean of the queries.
        cov : torch.Tensor
            The covariance matrix of the queries.
        q_len : int
            The length of the queries.

        Returns
        -------
        mu : torch.Tensor
            The mean of the queries after RoPE.
        cov : torch.Tensor
            The covariance matrix of the queries after RoPE.
        """
        position_ids = torch.arange(q_len, q_len + self.n_future_positions).unsqueeze(0).to(mu.device)
        head_dim = module.head_dim
        cos, sin = module.rotary_emb(mu, position_ids)
        cos, sin = cos[0], sin[0]
        Id = torch.eye(head_dim, device=cos.device, dtype=cos.dtype)
        P = torch.zeros((head_dim, head_dim), device=cos.device, dtype=cos.dtype)
        P[head_dim // 2:, :head_dim // 2], P[:head_dim // 2, head_dim // 2:] = torch.eye(head_dim // 2), -torch.eye(head_dim // 2)  # noqa: E501
        R = cos.unsqueeze(1) * Id + sin.unsqueeze(1) * P
        R = R.mean(dim=0).to(mu.device)
        mu = torch.matmul(mu, R.T)
        if cov is not None:
            cov = torch.matmul(R, torch.matmul(cov, R.T))
        return mu, cov

    def score(
        self,
        module: nn.Module,
        hidden_states: torch.Tensor,
        keys: torch.Tensor,
        values: torch.Tensor,
        attentions: torch.Tensor,
        kwargs,
    ) -> torch.Tensor:
        # Remove sink tokens
        assert keys.size(2) > self.n_sink, f"Input should contain more tokens than n_sink={self.n_sink}"
        keys = keys[:, :, self.n_sink:]
        values = values[:, :, self.n_sink:]

        # Compute query statistics
        mean_query, cov_query = self.get_query_statistics(module, hidden_states)

        # Compute scores
        bsz, num_key_value_heads, q_len, d = keys.shape
        num_key_value_groups = module.config.num_attention_heads // num_key_value_heads
        keys = repeat_kv(keys, num_key_value_groups).transpose(2, 3)
        scores = torch.matmul(mean_query.unsqueeze(2), keys).squeeze(2) / math.sqrt(d)
        if self.use_covariance:
            scores += torch.einsum("bhin, bhij, bhjn->bhn", keys, cov_query, keys) / d / 2
        scores = F.softmax(scores, dim=-1)

        # Average scores across groups
        scores = scores.view(bsz, num_key_value_heads, num_key_value_groups, q_len)
        scores = scores.mean(dim=2)

        # Rescale scores by the norm of the values
        if self.use_vnorm:
            scores = (scores + self.epsilon) * values.norm(dim=-1)

        # Add back the sink tokens. Use max score to make sure they are not pruned.
        scores = F.pad(scores, (self.n_sink, 0), value=scores.max().item())
        return scores
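Step 4 of the docstring uses the moment-generating function of a Gaussian: if future queries are modeled as q ~ N(mu, cov), then E[exp(kᵀq/√d)] = exp(kᵀmu/√d + kᵀ·cov·k/(2d)). A self-contained Monte Carlo check of that identity on synthetic tensors (not part of the press itself):

```python
import torch

torch.manual_seed(0)
d = 64
mu = torch.randn(d)
A = torch.randn(d, d) / d**0.5
cov = A @ A.T  # a valid (positive definite) covariance matrix
k = torch.randn(d)

# Monte Carlo estimate of E[exp(k @ q / sqrt(d))] for q ~ N(mu, cov)
q = torch.distributions.MultivariateNormal(mu, cov).sample((200_000,))
mc = torch.exp(q @ k / d**0.5).mean()

# Closed form used by the press: exp(k @ mu / sqrt(d) + 0.5 * k @ cov @ k / d)
closed = torch.exp(k @ mu / d**0.5 + 0.5 * k @ cov @ k / d)
print(mc.item(), closed.item())  # should agree up to sampling noise
```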
kvpress/kvpress/presses/expected_attention_with_stats.py
0 → 100644
# SPDX-FileCopyrightText: Copyright (c) 1993-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
import importlib
import os
from contextlib import contextmanager
from dataclasses import dataclass, field
from typing import Optional

import fire
import torch
from datasets import load_dataset
from huggingface_hub import PyTorchModelHubMixin, get_collection
from torch import nn
from tqdm import tqdm
from transformers import AutoModelForCausalLM, AutoTokenizer, PreTrainedModel

from kvpress.presses.expected_attention_press import ExpectedAttentionPress


@dataclass
class ExpectedAttentionStatsPress(ExpectedAttentionPress):
    """
    Expected attention press that automatically loads pre-computed query statistics.

    Parameters
    ----------
    compression_ratio : float, default=0.0
        Fraction of key-value pairs to remove during compression.
    n_future_positions : int, default=512
        Number of future positions to consider when computing expected attention.
    n_sink : int, default=4
        Number of initial tokens to exclude from compression (sink tokens).
    use_covariance : bool, default=True
        Whether to include covariance information in expected attention computation.
    use_vnorm : bool, default=True
        Whether to rescale scores using value vector norms.
    epsilon : float, default=0.0
        Small constant added to scores before value norm rescaling.
    dataset_name : str, default="kmfoda/booksum"
        Dataset used to compute the statistics.
    num_samples : int, default=100
        Number of samples used to compute the statistics.
    sample_seq_len : int, default=1000
        Sequence length used to compute the statistics.
    """

    # Override parent defaults to enable stats by default
    sample_seq_len: int = 1000
    num_samples: int = 100
    dataset_name: str = "kmfoda/booksum"
    stats_folder: Optional[str] = None
    mu: torch.Tensor = field(init=False, default=None)  # initialized in post_init_from_model
    cov: torch.Tensor = field(init=False, default=None)  # initialized in post_init_from_model

    def get_query_statistics(self, module: nn.Module, hidden_states: torch.Tensor):
        """
        Override the parent method to use the pre-computed query statistics.
        """
        q_len = hidden_states.shape[1]
        layer_idx = module.layer_idx
        mu, cov = self.apply_avg_rope(module, self.mu[layer_idx], self.cov[layer_idx], q_len)  # type: ignore
        return mu.unsqueeze(0), cov.unsqueeze(0)

    @staticmethod
    def available_stats():
        collection = get_collection("alessiodevoto/expectedattentionstats-68b0248d519303713320e2cf")
        return [x.item_id for x in collection.items]

    def post_init_from_model(self, model):
        """
        Automatically load or compute query statistics for the model.
        """
        if self.mu is None and self.cov is None:
            if self.stats_folder is not None:
                stats = ExpectedAttentionStats.from_pretrained(self.stats_folder)
            else:
                stats = self._maybe_load_stats_from_hub(model)
            self.mu = stats.query_mean.data.to(model.device, dtype=model.dtype)
            self.cov = stats.query_cov.data.to(model.device, dtype=model.dtype)

    def _maybe_load_stats_from_hub(self, model: PreTrainedModel):
        """Load statistics from the Hugging Face Hub."""
        stats_id = ExpectedAttentionStats(
            model_name=model.config.name_or_path,
            num_layers=model.config.num_hidden_layers,
            num_heads=model.config.num_attention_heads,
            head_dim=model.config.head_dim,
            dataset_name=self.dataset_name,
            num_samples=self.num_samples,
            sample_seq_len=self.sample_seq_len,
            n_sink=self.n_sink,
        ).stats_id()
        try:
            return ExpectedAttentionStats.from_pretrained(stats_id)
        except ValueError:
            raise ValueError(
                f"No statistics found for model {stats_id} on the Hub. Please compute them first. "
                "You can do so by running the following code: "
                "```"
                "python expected_attention_with_stats.py --model_name <model_name>"
                "```"
            )


class ExpectedAttentionStats(torch.nn.Module, PyTorchModelHubMixin):
    """
    Module that stores the mean and covariance matrix of the queries, possibly uploaded to the HF hub.
    """

    def __init__(
        self,
        num_layers: int,
        num_heads: int,
        head_dim: int,
        dataset_name: str,
        model_name: str,
        num_samples: int,
        sample_seq_len: int,
        n_sink: int,
    ):
        super().__init__()
        self.query_mean = torch.nn.Parameter(torch.zeros(num_layers, num_heads, head_dim))
        self.query_cov = torch.nn.Parameter(torch.zeros(num_layers, num_heads, head_dim, head_dim))
        self.dataset_name = dataset_name
        self.model_name = model_name
        self.num_samples = num_samples
        self.sample_seq_len = sample_seq_len
        self.n_sink = n_sink

    def stats_id(self) -> str:
        """Generate the statistics ID for the model and configuration."""
        return f"alessiodevoto/exp_att_stats_{self.model_name.replace('/', '_')}_{self.dataset_name.replace('/', '_')}_{self.num_samples}_{self.sample_seq_len}_{self.n_sink}"  # noqa: E501


# The code below is used to collect statistics on a dataset.
@contextmanager
def patch_rotary_embedding(model):
    """
    A context manager to dynamically patch the `apply_rotary_pos_emb` function
    for any supported model architecture. It captures the query states before
    rotary embeddings are applied.

    Args:
        model (PreTrainedModel): The transformer model instance.

    Yields:
        list: A list that will be populated with the captured query tensors.
    """
    # Dynamically find the model's specific "modeling" module
    try:
        module_path = model.__class__.__module__
        modeling_module = importlib.import_module(module_path)
    except Exception as e:
        raise RuntimeError(f"Failed to import module for {model.__class__.__name__}: {e}")

    # Check for the target function and save the original
    target_function = "apply_rotary_pos_emb"
    if not hasattr(modeling_module, target_function):
        raise AttributeError(
            f"Model architecture '{model.config.model_type}' is not supported. "
            f"The module '{module_path}' does not contain '{target_function}'."
        )
    original_function = getattr(modeling_module, target_function)

    captured_tensors = []

    def patched_function(q_embed, k_embed, *args, **kwargs):
        # Capture the query tensor before RoPE is applied
        captured_tensors.append(q_embed.detach().cpu())
        q_embed, k_embed = original_function(q_embed, k_embed, *args, **kwargs)
        return q_embed, k_embed

    # Apply the patch
    setattr(modeling_module, target_function, patched_function)
    try:
        yield captured_tensors
    finally:
        setattr(modeling_module, target_function, original_function)


@torch.inference_mode()
def collect_queries(
    model: PreTrainedModel,
    dataset_name: str,
    num_samples: int,
    sample_seq_len: int,
    n_sink: int,
    text_column: str = "chapter",
) -> tuple[list[torch.Tensor], torch.Tensor, torch.Tensor]:
    """
    Collects query representations from a transformer model using a calibration dataset.

    This function runs the model on a small number of samples from the "kmfoda/booksum" dataset,
    capturing the query tensors before rotary positional embeddings are applied. It trims the
    input text to a maximum length (`sample_seq_len`), skips the first `n_sink` tokens (to avoid outliers),
    and returns the collected queries.

    Args:
        model (PreTrainedModel): The transformer model instance.
        dataset_name (str): Name of the dataset to use for collecting statistics.
        num_samples (int): Number of samples to use from the calibration dataset.
        sample_seq_len (int): Maximum sequence length to consider for each sample.
        n_sink (int): Number of initial tokens to exclude from the collected queries.
        text_column (str): Name of the column in the dataset containing the text to tokenize.

    Returns:
        list or tuple:
            collected_queries (list): List of query tensors, each of shape (num_layers, num_heads, seq_len, head_dim)
            mean_query (torch.Tensor): Mean query vector for each layer and head.
            cov_query (torch.Tensor): Covariance matrix of queries for each layer and head.
    """
    # Load dataset and tokenizer
    tokenizer = AutoTokenizer.from_pretrained(model.config.name_or_path)
    dataset = load_dataset(dataset_name, split=f"train[:{num_samples}]")

    # Cut to max sample_seq_len
    dataset = dataset.map(lambda x: {text_column: x[text_column][:sample_seq_len]})

    collected_queries = []
    for text in tqdm(dataset[text_column], desc="Collecting queries"):
        inputs = tokenizer(text, return_tensors="pt").to(model.device)
        with patch_rotary_embedding(model) as captured_queries:
            model(**inputs)
        collected_queries.append(torch.cat(captured_queries, dim=0)[:, :, n_sink:, :])

    cat_queries = torch.cat(collected_queries, dim=-2)
    mean_query = cat_queries.mean(dim=-2)

    # compute covariance manually
    centered_queries = cat_queries - mean_query.unsqueeze(-2)
    N = cat_queries.shape[-2]
    cov_query = (centered_queries.transpose(-2, -1) @ centered_queries) / (N - 1)

    return collected_queries, mean_query, cov_query


def main(
    model_name: str = "meta-llama/Llama-3.1-8B-Instruct",
    output_path: str = ".",
    dataset_name: str = "kmfoda/booksum",
    num_samples: int = 100,
    sample_seq_len: int = 1000,
    n_sink: int = 4,
    text_column: str = "chapter",
    device_map: str = "auto",
):
    """
    Collect query statistics for a transformer model and save them.

    Args:
        model_name: Name of the model to collect statistics for
        output_path: Directory to save the statistics
        dataset_name: Dataset to use for collecting statistics
        num_samples: Number of samples to use from the dataset
        sample_seq_len: Sequence length for each sample
        n_sink: Number of initial tokens to exclude
        text_column: Column name containing text in the dataset
        device_map: Device mapping for the model
    """
    model = AutoModelForCausalLM.from_pretrained(model_name, device_map=device_map, dtype=torch.bfloat16).eval()
    _, mu, cov = collect_queries(model, dataset_name, num_samples, sample_seq_len, n_sink, text_column)
    stats = ExpectedAttentionStats(
        num_layers=model.config.num_hidden_layers,
        num_heads=model.config.num_attention_heads,
        head_dim=model.config.head_dim,
        dataset_name=dataset_name,
        model_name=model_name,
        num_samples=num_samples,
        sample_seq_len=sample_seq_len,
        n_sink=n_sink,
    )
    stats.query_mean.data = mu
    stats.query_cov.data = cov
    output_path = os.path.join(output_path, stats.stats_id())
    stats.save_pretrained(output_path)
    print(f"Statistics saved to: {output_path}")


if __name__ == "__main__":
    fire.Fire(main)
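The workflow implied by the file above, as a sketch: compute statistics once via the fire CLI, then point the press at the saved folder (or omit `stats_folder` to try the Hub collection). The output folder name is produced by `stats_id()` and appears here only as a placeholder; the import path follows the file location above, and `model` is assumed loaded as in the first sketch.

```python
# 1) Collect and save statistics (run once per model):
#    python expected_attention_with_stats.py --model_name meta-llama/Llama-3.1-8B-Instruct --output_path ./stats
# 2) Use them at compression time:
from kvpress.presses.expected_attention_with_stats import ExpectedAttentionStatsPress

press = ExpectedAttentionStatsPress(compression_ratio=0.5, stats_folder="./stats/<stats_id>")  # placeholder path
press.post_init_from_model(model)  # loads query_mean / query_cov onto the model's device and dtype
```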
kvpress/kvpress/presses/fastkvzip_press.py
0 → 100644
# SPDX-FileCopyrightText: Copyright (c) 1993-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
import logging
import math
import os
import re
from contextlib import contextmanager
from dataclasses import dataclass, field
from typing import Generator

import torch
from huggingface_hub import hf_hub_download
from torch import nn
from transformers import AutoConfig, Gemma3ForConditionalGeneration, PreTrainedModel
from transformers.models.qwen3.modeling_qwen3 import Qwen3RMSNorm

from kvpress.presses.base_press import SUPPORTED_MODELS, BasePress

logger = logging.getLogger(__name__)


class FastKVzipGate(nn.Module):
    """
    Fast KVzip gate architecture (https://arxiv.org/abs/2601.17668).
    """

    def __init__(
        self,
        index: int,
        input_dim: int,
        nhead: int,
        ngroup: int,
        dtype: torch.dtype,
        output_dim: int = 16,
        sink: int = 16,
    ):
        super().__init__()
        self.index = index
        self.output_dim = output_dim
        self.nhead = nhead
        self.ngroup = ngroup
        self.sink = sink
        self.q_proj = nn.Linear(input_dim, nhead * ngroup * output_dim, bias=True, dtype=dtype)
        self.k_proj = nn.Linear(input_dim, nhead * output_dim, bias=False, dtype=dtype)
        self.q_norm = Qwen3RMSNorm(output_dim)
        self.k_norm = Qwen3RMSNorm(output_dim)
        self.k_base = nn.Parameter(torch.zeros([nhead, 1, sink, output_dim]))
        self.b = nn.Parameter(torch.zeros([nhead, 1, ngroup], dtype=dtype))
        self.d = math.sqrt(self.output_dim)

    def forward(self, hidden_states: torch.Tensor):
        hidden_states = hidden_states.squeeze(0)  # bsz = 1
        nseq = hidden_states.shape[0]  # sequence x dim
        hidden_shape = (nseq, self.nhead, -1, self.output_dim)

        queries = self.q_norm(self.q_proj(hidden_states).view(hidden_shape))
        keys = self.k_norm(self.k_proj(hidden_states).view(hidden_shape))
        queries = queries.transpose(0, 1).transpose(-1, -2)
        keys = keys.transpose(0, 1)

        # head x seq x 1 x group
        logit = torch.matmul(keys, queries) / self.d + self.b.unsqueeze(2)
        # head x 1 x sink x group
        logit_base = torch.matmul(self.k_base, queries) / self.d

        score = 1 / (1 + torch.exp(logit_base - logit).sum(2, keepdim=True))
        score = score.mean(-1)  # n_head, seq, 1
        return score.squeeze(-1).unsqueeze(0)  # bsz x n_head x seq

    def extra_repr(self):
        # Customize the print output
        repr_str = f"index={self.index}, output_dim={self.output_dim}, nhead={self.nhead}, ngroup={self.ngroup}\n"
        if self.sink != 0:
            repr_str += f"k_base shape: {self.k_base.shape}\n"
        repr_str += f"b shape: {self.b.shape}\n"
        return repr_str


def load_fastkvzip(model_name: str = "Qwen/Qwen3-8B", device: str = "cuda"):
    """Load trained gate weights"""
    if not model_name:
        raise AssertionError("Model_name is empty. Please check load_gate.")

    state_dict, gate_id = get_gate_weight(model_name)
    dtype = state_dict[0]["q_proj.weight"].dtype
    head_group_outdim, input_dim = state_dict[0]["q_proj.weight"].shape
    head_outdim, _ = state_dict[0]["k_proj.weight"].shape
    output_dim = state_dict[0]["q_norm.weight"].shape[-1]
    nhead = head_outdim // output_dim
    ngroup = head_group_outdim // head_outdim

    m = re.search(r"sink(\d+)", gate_id)
    sink = int(m.group(1)) if m else 0

    modules = []
    for idx, weight in enumerate(state_dict):
        module = FastKVzipGate(idx, input_dim, nhead, ngroup, dtype, output_dim, sink).to(device)
        module.load_state_dict(weight)
        modules.append(module)
    print(f"load gate {gate_id} ({module})")
    return modules


def get_gate_id(model_name: str):
    """Get the gate id from model names"""
    config = AutoConfig.from_pretrained(model_name)
    if hasattr(config, "text_config"):
        config = config.text_config
    ngroup = config.num_attention_heads // config.num_key_value_heads
    file_name = f"q{ngroup}_dim16_sink16"
    model_name = model_name.split("/")[-1].lower()
    gate_id = os.path.join(model_name, file_name + ".pt")
    return gate_id


def get_gate_weight(model_name: str):
    """Load trained gate weights from HuggingFace"""
    gate_id = get_gate_id(model_name)
    file_path = hf_hub_download(repo_id="Jang-Hyun/Fast-KVzip", filename=gate_id, repo_type="model")
    # Load the PyTorch tensor/dictionary
    weights = torch.load(file_path, weights_only=False)["module"]
    return weights, gate_id


@dataclass
class FastKVzipPress(BasePress):
    """
    Fast KVzip estimates KV importance scores using gates trained on KVzip scores.
    In this code, we implement Fast KVzip with minimal changes to this repository.
    For a fully optimized implementation with actual compression and chunked-prefill,
    please refer to the original repository (https://github.com/Janghyun1230/FastKVzip).

    Based on Fast KVzip (https://arxiv.org/abs/2601.17668).
    Authors: Jang-Hyun Kim, Dongyoon Han, Sangdoo Yun
    Affiliation: NAVER AI Lab

    Parameters
    ----------
    compression_ratio : float, default=0.0
        Fraction of key-value pairs to remove during compression.
    layerwise : bool, default=False
        Whether to enable uniform compression ratios across layers.
        When False, while the overall KV cache compression ratio is maintained,
        each layer has a different compression ratio.
    n_sink : int, default=4
        Number of initial tokens to preserve as attention sinks.
    window_size : int, default=4096
        Number of tokens in the local window retained during chunked prefilling.
    window_ratio : float, default=0.02
        Fraction of the context length used to calculate the local window size retained during short-context prefilling.
    """

    compression_ratio: float = 0.0
    layerwise: bool = False
    n_sink: int = 4
    window_size: int = 4096  # for chunked prefilling with long contexts
    window_ratio: float = 0.02
    gates: list[nn.Module] | None = field(init=False, default=None)
    score_val: list[torch.Tensor] | torch.Tensor | None = field(init=False, default=None)

    def post_init_from_model(self, model):
        """
        Automatically load gates for the model.
        """
        if self.gates is None:
            try:
                self.gates = load_fastkvzip(model_name=model.config.name_or_path, device=model.device)
            except Exception as e:
                raise RuntimeError(
                    "The gates for the given model are not released! "
                    "Please check the available models at: "
                    "https://huggingface.co/Jang-Hyun/Fast-KVzip/tree/main"
                ) from e

    @contextmanager
    def __call__(self, model: PreTrainedModel) -> Generator:
        """
        Context manager that handles both initial prefilling and Fast KVzip scoring/compression.

        This overrides the base class __call__ method to implement the Fast KVzip algorithm:
        1. First yield: allows initial prefilling with context and KV importance scoring via gates
        2. After yield: performs KV eviction based on the importance scores
        """
        if not isinstance(model, SUPPORTED_MODELS):
            logger.warning(f"Model {type(model)} not tested, supported models: {SUPPORTED_MODELS}")

        self.post_init_from_model(model)

        hooks = []
        try:
            self.score_val = [None for _ in range(len(model.model.layers))]  # reset every prefilling
            language_model = model.model.language_model if hasattr(model.model, "language_model") else model.model
            for layer in language_model.layers:
                if isinstance(model, Gemma3ForConditionalGeneration) and layer.self_attn.is_sliding:
                    # Skip layers with sliding window attention, only for Gemma3
                    continue
                layer.self_attn.rotary_emb = language_model.rotary_emb
                hooks.append(layer.self_attn.register_forward_hook(self.forward_hook, with_kwargs=True))

            yield

            self.compress_post(model)  # Perform compression
        finally:
            for hook in hooks:
                hook.remove()

    def forward_hook(self, module: nn.Module, input: list[torch.Tensor], kwargs: dict, output: list):
        """
        Override the forward_hook of BasePress.
        During the forward_hook, Fast KVzip calculates importance scores,
        aggregates scores across all layers, and then performs compression.
        """
        hidden_states = kwargs["hidden_states"]
        q_len = hidden_states.shape[1]

        # Don't compress after pre-filling
        if kwargs["cache_position"][-1] > q_len:
            return output

        self._score_fast(module, hidden_states)
        return output

    def _score_fast(self, module: nn.Module, hidden_states: torch.Tensor):
        """
        Calculate the KV importance scores.
        """
        layer_idx = int(module.layer_idx)
        self.gates[layer_idx] = self.gates[layer_idx].to(hidden_states.device)
        scores = self.gates[layer_idx](hidden_states)
        scores[:, :, :self.n_sink] = 1.0

        ctx_len = scores.size(-1)
        if ctx_len < 32000:
            window_size = int(ctx_len * self.window_ratio)
        else:
            window_size = self.window_size
        scores[:, :, -window_size:] = 1.0

        self.score_val[layer_idx] = scores

    def compress_post(self, model: PreTrainedModel):
        """
        Obtain the indices of KV pairs to be evicted.
        Adopted from adakv_press.compress (fake compression). KVzip does not rely on safeguards.
        """
        self.score_val = torch.stack(self.score_val, dim=0)
        if self.compression_ratio > 0:
            n_layer, bsz, num_key_value_heads, ctx_len = self.score_val.shape

            # calculate the pruned KV pairs across layers
            if self.layerwise:
                nl = int(bsz * num_key_value_heads * ctx_len * self.compression_ratio)
                n_pruned_layers = nl * torch.ones(n_layer, device=self.score_val.device, dtype=torch.int)
            else:
                n_pruned_indices = int(self.score_val.numel() * self.compression_ratio)
                pruned_indices = torch.topk(-self.score_val.reshape(-1), n_pruned_indices).indices
                n_tokens_per_layer = bsz * num_key_value_heads * ctx_len
                n_pruned_layers = torch.bincount(pruned_indices // n_tokens_per_layer, minlength=n_layer).int()

            for layer in model.model.layers:
                module = layer.self_attn
                layer_idx = int(module.layer_idx)
                assert module.config._attn_implementation != "eager", "eager mode not supported"

                scores = self.score_val[layer_idx]

                # Compute bottom-k across heads
                n_pruned = n_pruned_layers[layer_idx].cpu()
                indices = torch.topk(-scores.reshape(bsz, -1), n_pruned, dim=1).indices.flatten().cpu()

                # Save indices to mask during the attention mechanism. Please refer to attention_patch.py for details
                batch_indices = torch.arange(bsz, device=n_pruned.device).repeat_interleave(n_pruned)
                head_indices = indices // ctx_len
                seq_indices = indices % ctx_len
                module.masked_key_indices = (batch_indices, head_indices, seq_indices)
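A prefill-time sketch following the `__call__` docstring above (scoring during prefill, eviction on exit). It reuses `model` and `tokenizer` from the first sketch, assumes a long `context` string, and uses the module path implied by the file location; gates are only released for a subset of models.

```python
from transformers import DynamicCache

from kvpress.presses.fastkvzip_press import FastKVzipPress  # import path per the file location above

press = FastKVzipPress(compression_ratio=0.5)  # gates auto-download from Jang-Hyun/Fast-KVzip

context_ids = tokenizer(context, return_tensors="pt").input_ids.to(model.device)
cache = DynamicCache()
with press(model):  # gate scores are collected layer by layer during this prefill
    model(context_ids, past_key_values=cache)
# On exit, compress_post() marks the lowest-scored KV pairs via masked_key_indices
```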
kvpress/kvpress/presses/finch_press.py
0 → 100644
# SPDX-FileCopyrightText: Copyright (c) 1993-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
from contextlib import contextmanager
from dataclasses import dataclass, field

import torch
from torch.nn import functional as F

from kvpress.presses.base_press import BasePress
from kvpress.presses.key_rerotation_press import KeyRerotationPress
from kvpress.presses.snapkv_press import SnapKVPress


@dataclass
class FinchPress(BasePress):
    """
    FINCH: Prompt-guided Key-Value Cache Compression.

    SnapKV-style compression with dynamic window sizing based on delimiter tokens.
    Requires input format: `context + delimiter_token + question`. The delimiter
    separates context from query, allowing dynamic window size determination.
    Use `update_model_and_tokenizer` method to set delimiter token before use.

    Based on FINCH (https://direct.mit.edu/tacl/article/doi/10.1162/tacl_a_00716/125280).

    Parameters
    ----------
    compression_ratio : float, default=0.0
        Fraction of key-value pairs to remove during compression.
    chunk_length : int, optional
        Length of chunks for optional chunked compression. None processes entire context at once.
    normalize_scores : bool, default=True
        Whether to normalize attention scores by number of non-zero weights.
    rerotate_keys : bool, default=True
        Whether to rerotate keys after compression using RoPE for proper positional encoding.
    delimiter_token : str
        Delimiter token string separating context from query (set automatically).
    delimiter_token_id : int
        Token ID for delimiter token (set automatically).
    window_size : int
        Dynamically determined window size based on delimiter position (set automatically).
    """

    compression_ratio: float = 0.0
    chunk_length: int = None
    normalize_scores: bool = True
    rerotate_keys: bool = True
    delimiter_token: str = field(default=None, init=False)
    delimiter_token_id: int = field(default=None, init=False)
    window_size: int = field(default=None, init=False)

    def score(self, module, hidden_states, keys, values, attentions, kwargs):
        """
        Similar to SnapKVPress except it adds a normalization step before averaging on the context window.
        """
        bsz, num_key_value_heads, k_len, _ = keys.shape
        num_key_value_groups = module.config.num_attention_heads // num_key_value_heads

        if attentions is not None:
            attn_weights = attentions[..., -self.window_size:, :-self.window_size]
        else:
            attn_weights = SnapKVPress.compute_window_attention(
                module, hidden_states, keys, self.window_size, kwargs["position_embeddings"]
            )

        if self.normalize_scores:
            non_zero_counts = torch.arange(k_len - self.window_size, k_len)[None, None, :, None]
            non_zero_counts = non_zero_counts.to(attn_weights.device)
            attn_weights = attn_weights * non_zero_counts

        # Average per group
        scores = attn_weights.mean(dim=-2)
        scores = scores.view(bsz, num_key_value_heads, num_key_value_groups, k_len - self.window_size)
        scores = scores.mean(dim=2)

        # Add back the observation window. Use max score to make sure the window is not pruned.
        scores = F.pad(scores, (0, self.window_size), value=scores.max().item())
        return scores

    def compress(self, module, hidden_states, keys, values, attentions, kwargs):
        """
        Scores are computed by chunks, keys and values are then compressed and re-rotated.
        """
        if self.compression_ratio == 0:
            return keys, values

        assert self.window_size is not None, "window_size must be provided"

        # Compute scores
        scores = self.score(module, hidden_states, keys, values, attentions, kwargs)

        # Compute indices to keep (optionally by chunks)
        k_len = keys.shape[2]  # Use actual sequence length from keys instead of hidden_states
        if self.chunk_length is None:
            n_kept = int(k_len * (1 - self.compression_ratio))
            indices = scores.topk(n_kept, dim=-1).indices
        else:
            assert self.chunk_length > self.window_size / (1 - self.compression_ratio)
            indices = []
            for i in range(0, k_len, self.chunk_length):
                chunk_scores = scores[:, :, i:i + self.chunk_length]
                n_kept = max(1, int(chunk_scores.shape[2] * (1 - self.compression_ratio)))
                chunk_indices = i + chunk_scores.topk(n_kept, dim=-1).indices
                indices.append(chunk_indices)
            indices = torch.cat(indices, dim=-1)

        if self.rerotate_keys:
            indices = torch.sort(indices, dim=2).values
            keys = KeyRerotationPress.rerotate_keys(module, indices, keys)
            indices = indices.unsqueeze(-1).expand(-1, -1, -1, module.head_dim)
        else:
            indices = indices.unsqueeze(-1).expand(-1, -1, -1, module.head_dim)
            keys = keys.gather(2, indices).contiguous()
        values = values.gather(2, indices).contiguous()

        return keys, values

    def embed_token_forward_hook(self, module, input, output):
        """
        Forward hook to detect a delimiter token between the context and the window
        """
        if input[0].shape[1] > 1 and self.delimiter_token_id in input[0][0]:  # prefilling
            assert len(input[0]) == 1, "Only batch size 1 is supported."

            # Find the delimiter token and compute the window size
            delim_tokens = input[0][0] == self.delimiter_token_id
            assert delim_tokens.sum() == 1, "Only one delimiter token should be present."
            context_length = int(torch.nonzero(delim_tokens)[0].item())
            self.window_size = len(input[0][0]) - 1 - context_length
            assert self.window_size > 0, "No window detected (window size must be > 0)."

            # Remove the delimiter token from the output
            output = output[:, ~delim_tokens]
        return output

    def update_model_and_tokenizer(self, model, tokenizer, delimiter_token: str = "<|finch_sep|>"):
        """
        Set the delimiter token and update the tokenizer accordingly.
        This method should be called before calling the press.
        """
        self.delimiter_token = delimiter_token
        if delimiter_token not in tokenizer.get_vocab():
            tokenizer.add_special_tokens({"additional_special_tokens": [delimiter_token]})
        self.delimiter_token_id = tokenizer.convert_tokens_to_ids(delimiter_token)  # type: ignore
        # update model embeddings
        model.resize_token_embeddings(len(tokenizer))
        return tokenizer

    @contextmanager
    def __call__(self, model):
        # The user should set the delimiter_token_id before calling the press.
        if self.delimiter_token_id is None:
            raise ValueError(
                """No delimiter token ID provided.
                Use the update_model_and_tokenizer method before calling the press."""
            )
        with super().__call__(model):
            try:
                hook = model.model.embed_tokens.register_forward_hook(self.embed_token_forward_hook)
                yield
            finally:
                hook.remove()
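A sketch of the delimiter workflow required by FinchPress, reusing `model` and `tokenizer` from the first sketch and assuming `context` and `question` strings are available; the import path follows the file location above and the compression ratio is illustrative.

```python
from kvpress.presses.finch_press import FinchPress  # import path per the file location above

press = FinchPress(compression_ratio=0.7)
tokenizer = press.update_model_and_tokenizer(model, tokenizer)  # registers "<|finch_sep|>"

# FINCH expects: context + delimiter + question (batch size 1)
prompt = context + press.delimiter_token + question
inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
with press(model):
    outputs = model.generate(**inputs, max_new_tokens=128)
```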
kvpress/kvpress/presses/key_rerotation_press.py
0 → 100644
# SPDX-FileCopyrightText: Copyright (c) 1993-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
from dataclasses import dataclass

import torch
from torch import nn
from transformers.models.llama.modeling_llama import rotate_half

from kvpress.presses.base_press import BasePress
from kvpress.presses.scorer_press import ScorerPress


@dataclass
class KeyRerotationPress(BasePress):
    """
    Key Rerotation: RoPE-aware compression wrapper for maintaining positional encoding.

    Enhances any ScorerPress by applying key rerotation after compression to maintain
    proper RoPE (Rotary Position Embedding) representations. When tokens are pruned,
    remaining tokens need positional encodings adjusted for their new positions.

    This method is used in several key-value cache compression methods, such as
    - SinkCache implementation in Hugging Face's transformers library
    - FINCH: Prompt-guided Key-Value Cache Compression for Large Language Models

    Parameters
    ----------
    press : ScorerPress
        The underlying scoring method to enhance with key rerotation.
        Rerotation is applied after the press determines which tokens to keep.
    """

    press: ScorerPress

    def __post_init__(self):
        assert isinstance(self.press, ScorerPress)

    def post_init_from_model(self, model):
        self.press.post_init_from_model(model)

    @property
    def compression_ratio(self):
        return self.press.compression_ratio

    @compression_ratio.setter
    def compression_ratio(self, value):
        self.press.compression_ratio = value

    @staticmethod
    def _rerotate_cos_sin(x, inv_freq, selected_positions):
        """
        Compute cosine and sine rotary positional embeddings required to
        re-rotate pruned keys back into the canonical RoPE space.

        Parameters
        ----------
        x : torch.Tensor
            Any key-like tensor that provides ``dtype`` and ``device``.
            Shape ``(bsz, num_key_value_heads, q_len, d)``.
        inv_freq : torch.Tensor
            ``module.rotary_emb.inv_freq``. Shape ``(d//2,)``.
        selected_positions : torch.Tensor
            Indices of the *kept* tokens.
            Shape ``(bsz, num_key_value_heads, n_kept)``.

        Returns
        -------
        cos, sin : torch.Tensor
            Cosine and sine embeddings, each of shape
            ``(bsz, num_key_value_heads, n_kept, d)``, matching ``dtype``/``device`` of ``x``.
        """
        bsz, num_key_value_heads, n_kept = selected_positions.shape
        device = selected_positions.device
        device_type = x.device.type
        dtype = x.dtype

        # Original positional indices
        idx = torch.arange(0, n_kept, device=device)  # (n_kept,)
        idx = idx.unsqueeze(0)  # (1, n_kept)
        inv_freq = inv_freq[None, None, :, None].float().expand(bsz, num_key_value_heads, -1, 1)
        idx = idx[:, None, :].float().expand(bsz, num_key_value_heads, n_kept)

        # Compute delta between original and selected positions
        delta_pos = idx - selected_positions  # (bsz, num_key_value_heads, n_kept)
        delta_pos = delta_pos.unsqueeze(2)  # (bsz, num_key_value_heads, 1, n_kept)

        device_type = device_type if isinstance(device_type, str) and device_type != "mps" else "cpu"
        with torch.autocast(device_type=device_type, enabled=False):
            # Compute the new freq by scaling inv_freq by delta
            freqs = delta_pos.float() * inv_freq.float()  # (bsz, num_key_value_heads, d//2, n_kept)
            freqs = freqs.transpose(2, 3)  # (bsz, num_key_value_heads, n_kept, d//2)
            emb = torch.cat((freqs, freqs), dim=-1)
            # Compute cosine and sine required to re-rotate keys to selected positions
            cos = emb.cos().contiguous()
            sin = emb.sin().contiguous()
        return cos.to(dtype=dtype), sin.to(dtype=dtype)

    @staticmethod
    def rerotate_keys(
        module: nn.Module,
        indices: torch.Tensor,
        keys: torch.Tensor,
    ) -> torch.Tensor:
        """
        Rerotate keys to have a uniform RoPE representation of keys after pruning.

        Parameters
        ----------
        module : nn.Module
            The model module containing the rotary embedding.
        indices : torch.Tensor
            Indices of the kept tokens after pruning.
        keys : torch.Tensor
            The keys tensor to be rerotated.

        Returns
        -------
        torch.Tensor
            The rerotated keys tensor of shape
            ``(bsz, num_heads, n_kept, d)``.
        """
        new_cos, new_sin = KeyRerotationPress._rerotate_cos_sin(keys, module.rotary_emb.inv_freq, indices)
        indices = indices.unsqueeze(-1).expand(-1, -1, -1, module.head_dim)
        keys = keys.gather(2, indices).contiguous()
        return (keys * new_cos) + (rotate_half(keys) * new_sin)

    def compress(
        self,
        module: nn.Module,
        hidden_states: torch.Tensor,
        keys: torch.Tensor,
        values: torch.Tensor,
        attentions: torch.Tensor,
        kwargs: dict,
    ) -> tuple[torch.Tensor, torch.Tensor]:
        if self.press.compression_ratio == 0:
            return keys, values

        # Compute scores from base press
        scores = self.press.score(module, hidden_states, keys, values, attentions, kwargs)

        # Get indices of KV pairs with the lowest scores
        q_len = keys.shape[2]
        n_kept = int(q_len * (1 - self.press.compression_ratio))
        indices = scores.topk(n_kept, dim=-1).indices
        indices = torch.sort(indices, dim=2).values

        keys = self.rerotate_keys(module, indices, keys)
        indices = indices.unsqueeze(-1).expand(-1, -1, -1, module.head_dim)
        values = values.gather(2, indices).contiguous()
        return keys, values
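A sketch of wrapping a scorer with key rerotation, reusing `model` and `inputs` from the first sketch; `KeyRerotationPress` and `KnormPress` are assumed package-level exports.

```python
from kvpress import KeyRerotationPress, KnormPress  # assumed package-level exports

# Keys of the kept tokens are re-rotated so their RoPE phases match contiguous positions 0..n_kept-1
press = KeyRerotationPress(press=KnormPress(compression_ratio=0.5))
with press(model):
    outputs = model.generate(**inputs, max_new_tokens=128)
```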
kvpress/kvpress/presses/keydiff_press.py
0 → 100644
# SPDX-FileCopyrightText: Copyright (c) 1993-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
from dataclasses import dataclass

import torch
from torch import nn
from torch.nn import functional as F

from kvpress.presses.scorer_press import ScorerPress


@dataclass
class KeyDiffPress(ScorerPress):
    """
    KeyDiff: Key similarity-based KV cache compression.

    Evicts tokens based on key vector similarity to average key pattern.
    Identifies tokens with most similar keys to average and removes them,
    keeping tokens with more distinctive key vectors.

    Based on KeyDiff (https://arxiv.org/abs/2504.15364).

    Note: The original press in the KeyDiff paper implements a block-wise iterative compression.
    In KVPress, the iterative compression is implemented in the BlockPress class.
    Therefore, to replicate the paper's implementation, please use:
    `press = BlockPress(press=KeyDiffPress(compression_ratio=0.x), block_size=N)`

    Parameters
    ----------
    compression_ratio : float, default=0.0
        Fraction of key-value pairs to remove during compression.
    """

    def score(
        self,
        module: nn.Module,
        hidden_states: torch.Tensor,
        keys: torch.Tensor,
        values: torch.Tensor,
        attentions: torch.Tensor,
        kwargs,
    ) -> torch.Tensor:
        anchor = F.normalize(keys, p=2, dim=-1).mean(dim=2, keepdim=True)
        return -F.cosine_similarity(keys, anchor, dim=-1)
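Following the note in the docstring, the block-wise iterative variant from the paper is obtained by wrapping the press in `BlockPress`; `block_size` below is illustrative, and both classes are assumed to be package-level exports.

```python
from kvpress import BlockPress, KeyDiffPress  # assumed package-level exports

# Single-shot global scoring
press = KeyDiffPress(compression_ratio=0.5)

# Block-wise iterative compression, matching the paper's setup as suggested above
press = BlockPress(press=KeyDiffPress(compression_ratio=0.5), block_size=1024)
```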