"csrc/vscode:/vscode.git/clone" did not exist on "7e1d5e5308fa3549dfed1821188d588260a03c8a"
Commit c16d506e authored by chenzk's avatar chenzk
Browse files

v1.0

parents
# SPDX-FileCopyrightText: Copyright (c) 1993-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
from dataclasses import dataclass
import torch
from torch import nn
from kvpress.presses.scorer_press import ScorerPress
@dataclass
class KnormPress(ScorerPress):
"""
Key norm-based KV cache compression.
Prunes key-value pairs based on L2 norm of key vectors.
Simple, efficient method requiring only norm calculation.
Based on https://arxiv.org/pdf/2406.11430.
Parameters
----------
compression_ratio : float, default=0.0
Fraction of key-value pairs to remove during compression.
"""
def score(
self,
module: nn.Module,
hidden_states: torch.Tensor,
keys: torch.Tensor,
values: torch.Tensor,
attentions: torch.Tensor,
kwargs,
) -> torch.Tensor:
return -keys.norm(dim=-1)
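# Illustrative sketch: KnormPress ignores the module and hidden states, so its scoring
# rule can be checked directly on random tensors (toy shapes assumed below).
if __name__ == "__main__":
    keys = torch.randn(1, 2, 6, 4)  # (batch, kv_heads, seq_len, head_dim)
    press = KnormPress(compression_ratio=0.5)
    scores = press.score(module=None, hidden_states=None, keys=keys, values=None, attentions=None, kwargs={})
    assert scores.shape == (1, 2, 6)  # tokens with the largest key norms score lowest and are pruned first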
# SPDX-FileCopyrightText: Copyright (c) 1993-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
from dataclasses import dataclass, field
from typing import Literal, Optional
import torch
import torch.nn as nn
from transformers import PretrainedConfig, PreTrainedModel
from kvpress.presses.scorer_press import ScorerPress
class KVzapConfig(PretrainedConfig):
model_type: str = "kvzap"
input_dim: int
output_dim: int
hidden_dim: Optional[int] = None
n_modules: int
class KVzapModel(PreTrainedModel):
config_class = KVzapConfig # type: ignore[assignment]
def __init__(self, config):
super().__init__(config)
self.all_tied_weights_keys = {}
if config.hidden_dim is None:
# Linear model
self.layers = nn.ModuleList(
[nn.Linear(config.input_dim, config.output_dim) for _ in range(config.n_modules)]
)
else:
# 2-layer MLP model
self.layers = nn.ModuleList(
nn.Sequential(
nn.Linear(config.input_dim, config.hidden_dim),
nn.GELU(),
nn.Linear(config.hidden_dim, config.output_dim),
)
for _ in range(config.n_modules)
)
def forward(self, x):
return torch.stack([module(x[:, i, :]) for i, module in enumerate(self.layers)], dim=1)
@dataclass
class KVzapPress(ScorerPress):
"""
KVzap (https://arxiv.org/abs/2601.07891) is a fast approximation of KVzip that works
in both prefilling and decoding. It applies a lightweight surrogate model to the hidden
states to predict importance scores for every KV pair.
KVzapPress is designed to be used in conjunction with DMSPress.
model_type can be either "linear" or "mlp".
"""
model_type: Literal["linear", "mlp"] = "mlp"
kvzap_model_name: Optional[str] = field(default=None, init=False)
def post_init_from_model(self, model):
kvzap_model_name = f"nvidia/KVzap-{self.model_type}-{model.config.name_or_path.split('/')[-1]}"
if kvzap_model_name != self.kvzap_model_name:
self.kvzap_model_name = kvzap_model_name
self.kvzap_model = KVzapModel.from_pretrained(self.kvzap_model_name)
def score(
self,
module: nn.Module,
hidden_states: torch.Tensor,
keys: torch.Tensor,
values: torch.Tensor,
attentions: torch.Tensor,
kwargs: dict,
) -> torch.Tensor:
kvzap_module = self.kvzap_model.layers[module.layer_idx]
kvzap_module = kvzap_module.to(hidden_states.device, dtype=hidden_states.dtype).eval()
scores = kvzap_module(hidden_states).transpose(1, 2)
return scores
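# Illustrative shape check of the surrogate model above, using toy dimensions and
# randomly initialized weights (the press itself loads the released nvidia/KVzap-*
# checkpoints through post_init_from_model).
if __name__ == "__main__":
    config = KVzapConfig(input_dim=8, output_dim=2, n_modules=3)  # hidden_dim=None -> linear heads
    surrogate = KVzapModel(config)
    x = torch.randn(1, 3, 8)  # (batch, n_modules, input_dim), one slice per per-layer head
    assert surrogate(x).shape == (1, 3, 2)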
# SPDX-FileCopyrightText: Copyright (c) 1993-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
import logging
import math
from contextlib import contextmanager
from dataclasses import dataclass
from types import MethodType
from typing import Generator, List
import torch
from torch import nn
from transformers import AutoTokenizer, Gemma3PreTrainedModel, PreTrainedModel, PreTrainedTokenizer, QuantizedCache
from transformers.models.llama.modeling_llama import rotate_half
from kvpress.presses.base_press import SUPPORTED_MODELS, BasePress
from kvpress.utils import extract_keys_and_values, get_prerope_query_states
logger = logging.getLogger(__name__)
@dataclass
class KVzipPress(BasePress):
"""
KVzip identifies the importance of KV pairs through context reconstruction,
enabling effective query-agnostic KV cache compression.
In this code, we implement KVzip with minimal changes to this repository.
For a fully optimized implementation with actual compression,
please refer to the original repository,
which also provides a version without runtime compression overhead (at the cost of performance).
Original repository (https://github.com/snu-mllab/KVzip).
Based on KVzip (https://arxiv.org/abs/2505.23416).
Parameters
----------
compression_ratio : float, default=0.0
Fraction of key-value pairs to remove during compression.
layerwise : bool, default=False
Whether to enforce the same compression ratio for every layer.
When False, the overall KV cache compression ratio is still maintained,
but individual layers may have different compression ratios.
n_sink : int, default=4
Number of initial tokens to preserve as attention sinks.
kvzip_plus_normalization: bool, default=False
Whether to enable KVzip+ normalization.
"""
compression_ratio: float = 0.0
layerwise: bool = False
n_sink: int = 4
kvzip_plus_normalization: bool = False
def __post_init__(self):
assert 0 <= self.compression_ratio < 1, "Compression ratio must be between 0 and 1"
logger.warning(
"KVzipPress requires multiple forward passes for chunked context reconstruction, "
"resulting in a computational overhead of 2–3 times the initial prefilling cost. "
"This significantly increases the overall prefilling time compared to other compression methods, "
"which is inherent to the KVzip algorithm design."
)
self._reset_internal_parameters()
def _reset_internal_parameters(self):
self.context_length = 0
self.prefix_length = 0
self._suffix_ids = None
self._context_ids = None
self._cache = None
self.score_val = None
self.causal_mask_score = None
self.start_idx = 0
self.end_idx = 0
@contextmanager
def __call__(self, model: PreTrainedModel) -> Generator:
"""
Context manager that handles both initial prefilling and KVzip scoring/compression.
This overrides the base class __call__ method to implement the full KVzip algorithm:
1. First yield: allows initial prefilling with context
2. After yield: performs KVzip scoring and compression using context reconstruction
"""
if not isinstance(model, SUPPORTED_MODELS):
logger.warning(f"Model {type(model)} not tested, supported models: {SUPPORTED_MODELS}")
if isinstance(model, Gemma3PreTrainedModel):
raise ValueError("KVzipPress is not supported for Gemma3ForCausalLM")
# Load the tokenizer, used below to build the context-reconstruction prompts
tokenizer = AutoTokenizer.from_pretrained(model.config.name_or_path)
# Get suffix_ids directly using tokenizer's chat template (do this once, not in hook)
if tokenizer.chat_template is None:
prefix_text = ""
suffix_text = "\n" # Default suffix for models without chat template
else:
# Use a dummy context to extract the question suffix from chat template
dummy_context = "dummy context"
separator = "\n" + "#" * len(dummy_context)
temp_context = tokenizer.apply_chat_template(
[{"role": "user", "content": dummy_context + separator}],
add_generation_prompt=True,
tokenize=False,
enable_thinking=False,
)
context, suffix_text = temp_context.split(separator)
prefix_text = context.split(dummy_context)[0]
# Tokenize suffix directly to ids
self.prefix_length = tokenizer.encode(prefix_text, return_tensors="pt", add_special_tokens=False).shape[-1]
self._suffix_ids = tokenizer.encode(suffix_text, return_tensors="pt", add_special_tokens=False)
# Register hook to store the pointer for past_key_values
original_forward = model.model.forward
def wrapped_forward(model_self, *args, **kwargs):
self._context_ids = kwargs["input_ids"]
self._cache = kwargs["past_key_values"]
return original_forward(*args, **kwargs)
model.model.forward = MethodType(wrapped_forward, model.model)
hooks = []
try:
yield
model.model.forward = original_forward # Restore original
# After yield: KVzip scoring and compression phase
if self.compression_ratio > 0 and self._context_ids is not None:
# Now register attention hooks for compression
for layer in model.model.layers:
layer.self_attn.rotary_emb = model.model.rotary_emb
hooks.append(layer.self_attn.register_forward_hook(self.forward_hook, with_kwargs=True))
self._perform_kvzip_compression(model, tokenizer)
finally:
for hook in hooks:
hook.remove()
self._reset_internal_parameters()
def forward_hook(self, module: nn.Module, input: list[torch.Tensor], kwargs: dict, output: list):
"""
Override the forward_hook of BasePress.
During the forward_hook, KVzip only calculates importance scores,
aggregates scores across all layers, and then performs compression.
"""
hidden_states = kwargs["hidden_states"]
cache = kwargs.get("past_key_values", None) or kwargs.get("past_key_value", None)
cache_layer = cache.layers[module.layer_idx]
keys, values = extract_keys_and_values(cache, module.layer_idx)
# Compute importance scores for KV pairs in the prefilled context,
# retaining only the originally prefilled KV pairs.
keys, values = self.score_kvzip(module, hidden_states, keys, values, output[1], kwargs)
if isinstance(cache, QuantizedCache):
# Update cache with compressed keys and values
cache_layer._quantized_keys = cache_layer._quantize(keys, axis=cache_layer.axis_key)
cache_layer._quantized_values = cache_layer._quantize(values, axis=cache_layer.axis_value)
cache_layer.keys = torch.zeros(0, dtype=keys.dtype, device=keys.device) # type: ignore[index]
cache_layer.values = torch.zeros(0, dtype=keys.dtype, device=keys.device) # type: ignore[index]
cache_layer.cumulative_length = keys.shape[2]
else:
cache_layer.keys = keys
cache_layer.values = values
return output
def _perform_kvzip_compression(self, model: PreTrainedModel, tokenizer: PreTrainedTokenizer):
"""
Perform the KVzip scoring and compression algorithm.
"""
# Prepare chunked inputs for context reconstruction
self.context_length = self._context_ids.shape[1]
chunked_context_pairs = self.prepare(model, tokenizer)
# Perform scoring through context reconstruction
# Use the stored cache from the initial forward pass
self.start_idx = self.prefix_length
for prefill_ids, repeat_ids in chunked_context_pairs:
self.end_idx = self.start_idx + prefill_ids.shape[1]
# Pass the cache that was used in the initial forward pass
model(
input_ids=repeat_ids.to(model.device),
past_key_values=self._cache,
num_logits_to_keep=1,
)
self.start_idx = self.end_idx
# Perform final compression
self.compress_post(model)
def _chunk_fn(self, ctx_ids: torch.Tensor, chunk_size: int) -> List[torch.Tensor]:
"""
Chunk input tokens
"""
ctx_len = ctx_ids.shape[1]
if ctx_len > chunk_size:
chunk_num = (ctx_len - 1) // chunk_size + 1
chunked_input_ids = []
for i in range(chunk_num):
start = i * chunk_size
end = (i + 1) * chunk_size
a_ids = ctx_ids[:, start:end]
if a_ids.shape[1] == 0:
continue
chunked_input_ids.append(a_ids)
else:
chunked_input_ids = [ctx_ids]
return chunked_input_ids
def prepare(
self,
model: PreTrainedModel,
tokenizer: PreTrainedTokenizer,
chunk_size: int = 2048,
prev_postfix_size=8,
) -> List[tuple[torch.Tensor, torch.Tensor]]:
"""
Prepare chunked inputs for KV importance scoring with context reconstruction
"""
ctx_ids = self._context_ids[:, self.prefix_length :].to("cpu")
# initialize score values
self.score_val = torch.zeros(
(
model.config.num_hidden_layers,
1,
model.config.num_key_value_heads,
self.context_length,
), # only support batch size of 1
dtype=model.dtype,
device=model.device,
)
self.score_val[..., : self.n_sink] = 1.0
chunked_context_pairs = []
chunked_input_ids = self._chunk_fn(ctx_ids, chunk_size)
for i, a_ids in enumerate(chunked_input_ids):
if i == 0:
prompt = "\n\nRepeat the previous context exactly."
q_ids = tokenizer.encode(prompt, return_tensors="pt", add_special_tokens=False)
else:
prompt = "\n\nRepeat the part of the previous context exactly, starting with"
q_ids = tokenizer.encode(prompt, return_tensors="pt", add_special_tokens=False)
postfix_prev = chunked_input_ids[i - 1][:, -prev_postfix_size:]
q_ids = torch.cat([q_ids, postfix_prev], dim=1)
chunked_context_pairs.append((a_ids, torch.cat([q_ids, self._suffix_ids, a_ids], dim=1)))
return chunked_context_pairs
def _make_mask(self, attn_weights: torch.Tensor, window_size: int):
"""
Define causal mask shared across layers
"""
mask = torch.full((window_size, window_size), torch.finfo(attn_weights.dtype).min, device=attn_weights.device)
mask_cond = torch.arange(mask.size(-1), device=attn_weights.device)
mask.masked_fill_(mask_cond < (mask_cond + 1).view(mask.size(-1), 1), 0)
self.causal_mask_score = mask[None, None, None, :, :]
def _mask_causal(self, attn_weights: torch.Tensor, window_size: int):
"""
Apply causal masking
"""
if self.causal_mask_score is None:
self._make_mask(attn_weights, window_size)
elif self.causal_mask_score.size(-1) != window_size:
self._make_mask(attn_weights, window_size)
self.causal_mask_score = self.causal_mask_score.to(attn_weights.device)
attn_weights[..., -window_size:, -window_size:] += self.causal_mask_score
def score_kvzip(
self,
module: nn.Module,
hidden_states: torch.Tensor,
keys: torch.Tensor,
values: torch.Tensor,
attentions: torch.Tensor,
kwargs,
) -> tuple[torch.Tensor, torch.Tensor]:
"""
Compute the maximum cross-attention scores during context reconstruction,
and return slices of the keys and values containing only the originally prefilled KV pairs,
i.e., excluding KV pairs from repeated contexts.
The computed scores are stored in self.score_val.
"""
bsz, q_len, _ = hidden_states.shape
num_heads = module.config.num_attention_heads
num_heads_kv = module.config.num_key_value_heads
head_dim = module.head_dim
num_key_value_groups = num_heads // num_heads_kv
queries = get_prerope_query_states(module, hidden_states)
# Apply RoPE
cos, sin = kwargs["position_embeddings"]
queries = (queries * cos.unsqueeze(1)) + (rotate_half(queries) * sin.unsqueeze(1))
queries = queries.view(bsz, num_heads_kv, num_key_value_groups, q_len, head_dim)
# Subsample keys
sink = min(self.n_sink, self.start_idx)
ctx_len = self.end_idx - self.start_idx
keys_subsampled = torch.cat(
[
keys[:, :, :sink], # attention sink tokens (generally system prompt)
keys[:, :, self.start_idx : self.end_idx], # KV chunk in the cache
keys[:, :, -q_len:], # KV repeat chunk
],
dim=2,
)
keys_subsampled = keys_subsampled.unsqueeze(2).transpose(-2, -1).contiguous()
# Compute attention
attn_weights = torch.matmul(queries, keys_subsampled) / math.sqrt(head_dim)
self._mask_causal(attn_weights, q_len)
attn_weights = nn.functional.softmax(attn_weights, dim=-1)
if self.kvzip_plus_normalization:
# Divide by ||h|| (by row)
h_norm = torch.norm(hidden_states, dim=-1)
attn_weights = torch.einsum("b h g t i, b t -> b h g t i", attn_weights, 1 / h_norm)
# Multiply by ||WoV|| (by column)
Wo = module.o_proj.weight.transpose(0, 1)
Wo = Wo.view(num_heads_kv, num_key_value_groups, module.head_dim, module.config.hidden_size)
values_subsampled = torch.cat(
[values[:, :, :sink], values[:, :, self.start_idx : self.end_idx], values[:, :, -q_len:]], dim=2
)
values_subsampled = values_subsampled.unsqueeze(2).transpose(-2, -1).contiguous()
V = values_subsampled.repeat_interleave(module.num_key_value_groups, axis=2)
WoV_norm = torch.einsum("h g i j, b h g i t -> b h g t j", Wo, V).norm(dim=-1)
attn_weights = torch.einsum("b h g i t, b h g t -> b h g i t", attn_weights, WoV_norm)
attn_weights = attn_weights[..., sink : sink + ctx_len]
scores = attn_weights.amax(dim=(-3, -2)) # max over group, q
layer_idx = int(module.layer_idx)
self.score_val[layer_idx][..., self.start_idx : self.end_idx] = scores # update score
# Retain the originally prefilled context KV pairs and exclude KV pairs from the repeated context
keys, values = keys[:, :, : self.context_length], values[:, :, : self.context_length]
return keys, values
def compress_post(self, model: PreTrainedModel):
"""
Obtain the indices of KV pairs to be evicted.
Adapted from adakv_press.compress (fake compression). KVzip does not rely on safeguards.
"""
if self.compression_ratio > 0:
n_layer, bsz, num_key_value_heads, ctx_len = self.score_val.shape
# calculate the pruned KV pairs across layers
if self.layerwise:
nl = int(bsz * num_key_value_heads * ctx_len * self.compression_ratio)
n_pruned_layers = nl * torch.ones(n_layer, device=self.score_val.device, dtype=torch.int)
else:
n_pruned_indices = int(self.score_val.numel() * self.compression_ratio)
pruned_indices = torch.topk(-self.score_val.reshape(-1), n_pruned_indices).indices
n_tokens_per_layer = bsz * num_key_value_heads * ctx_len
n_pruned_layers = torch.bincount(pruned_indices // n_tokens_per_layer, minlength=n_layer).int()
for layer in model.model.layers:
module = layer.self_attn
layer_idx = int(module.layer_idx)
assert module.config._attn_implementation != "eager", "eager mode not supported"
scores = self.score_val[layer_idx]
# Compute bottom-k across heads
n_pruned = n_pruned_layers[layer_idx].cpu()
indices = torch.topk(-scores.reshape(bsz, -1), n_pruned, dim=1).indices.flatten().cpu()
# Save indices to mask during the attention mechanism. Please refer to attention_patch.py for details
batch_indices = torch.arange(bsz, device=n_pruned.device).repeat_interleave(n_pruned)
head_indices = indices // ctx_len
seq_indices = indices % ctx_len
module.masked_key_indices = (batch_indices, head_indices, seq_indices)
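# Usage sketch (assumes the kvpress text-generation pipeline is registered and a
# supported chat model can be loaded; the model name is only an example). As the
# warning in __post_init__ notes, KVzip re-prefills the context in chunks, so
# expect roughly 2-3x the usual prefilling time.
if __name__ == "__main__":
    from transformers import pipeline
    press = KVzipPress(compression_ratio=0.5, layerwise=False, n_sink=4)
    pipe = pipeline("kv-press-text-generation", model="Qwen/Qwen2.5-1.5B-Instruct")
    context = "A long document whose KV cache should be compressed once during prefilling ..."
    print(pipe(context, question="What is this document about?", press=press)["answer"])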
# SPDX-FileCopyrightText: Copyright (c) 1993-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
from dataclasses import dataclass
import torch
from torch import nn
from kvpress.presses.scorer_press import ScorerPress
@dataclass
class LagKVPress(ScorerPress):
"""
LagKV: Lag-relative information-based KV cache compression.
Compresses KV cache by leveraging lag-relative information between sequence
partitions. Divides sequence into partitions and uses subsequent partitions
as references for scoring tokens in prior partitions.
Based on LagKV (https://arxiv.org/abs/2504.04704).
Parameters
----------
compression_ratio : float, default=0.0
Fraction of key-value pairs to remove during compression.
n_sink : int, default=4
Number of initial tokens to preserve as attention sinks.
lag_size : int, default=128
Size of each partition for lag-relative scoring.
Sequence is divided into partitions of this size, with each partition
serving as reference for scoring tokens in the previous partition.
cross_scoring : bool, default=False
Whether to enable cross-partition scoring (experimental).
When True, scoring considers cross-partition dependencies rather than
limiting to within-partition relationships. Useful with AdaKVPress.
"""
compression_ratio: float = 0.0
n_sink: int = 4
lag_size: int = 128
cross_scoring: bool = False
def score(
self,
module: nn.Module,
hidden_states: torch.Tensor,
keys: torch.Tensor,
values: torch.Tensor,
attentions: torch.Tensor,
kwargs,
) -> torch.Tensor:
bsz, num_key_value_heads, q_len, d = keys.shape
if q_len < self.n_sink + 2 * self.lag_size:
# no compression
score = torch.ones((bsz, num_key_value_heads, q_len), dtype=keys.dtype, device=keys.device)
if q_len > self.n_sink:
# make sure the sliding part will be selected.
score[:, :, self.n_sink :] = (
torch.arange(q_len - self.n_sink, device=keys.device) / (q_len - self.n_sink)
).to(keys.dtype)
return score
end_idx = self.n_sink + ((q_len - self.n_sink) // self.lag_size) * self.lag_size
tail_len = self.lag_size + q_len - end_idx
key_score = self._get_states_score(
keys[:, :, self.n_sink : end_idx].view(bsz, num_key_value_heads, -1, self.lag_size, d)
)
value_score = self._get_states_score(
values[:, :, self.n_sink : end_idx].view(bsz, num_key_value_heads, -1, self.lag_size, d)
)
# score is in range [0, 1]
score = (key_score + value_score) / 2
if not self.cross_scoring:
score = score.argsort(dim=-1).argsort(dim=-1) / self.lag_size
score = score.to(keys.dtype)
# the parts should always keep
sink_shape = (bsz, num_key_value_heads, self.n_sink)
sink_score = torch.ones(sink_shape, dtype=score.dtype, device=score.device)
tail_shape = (bsz, num_key_value_heads, tail_len)
tail_score = torch.ones(tail_shape, dtype=score.dtype, device=score.device)
score = torch.cat((sink_score, score.reshape(bsz, num_key_value_heads, -1), tail_score), dim=-1)
return score
def _get_states_score(self, target_v):
"""evaluate the scores of keys and values for each token"""
ref = target_v[:, :, 1:, :, :]
v = target_v[:, :, :-1, :, :]
# lag-relative information
min_r = ref.min(dim=-2).values.unsqueeze(-2).expand(-1, -1, -1, self.lag_size, -1)
max_r = ref.max(dim=-2).values.unsqueeze(-2).expand(-1, -1, -1, self.lag_size, -1)
score = ((v - min_r) / (max_r - min_r)).std(dim=-1).softmax(dim=-1)
return score
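# Illustrative check on random tensors (toy shapes): with n_sink=4 and lag_size=128,
# a 388-token sequence yields three partitions; the sink tokens and the last partition
# (used only as a reference) receive the maximum score.
if __name__ == "__main__":
    keys = torch.randn(1, 2, 4 + 3 * 128, 64)
    values = torch.randn_like(keys)
    press = LagKVPress(compression_ratio=0.5, n_sink=4, lag_size=128)
    scores = press.score(module=None, hidden_states=None, keys=keys, values=values, attentions=None, kwargs={})
    assert scores.shape == keys.shape[:-1]
    assert (scores[..., :4] == 1).all() and (scores[..., -128:] == 1).all()  # sinks and tail always scored highest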
# SPDX-FileCopyrightText: Copyright Vivek Chari
# SPDX-License-Identifier: Apache-2.0
import math
from dataclasses import dataclass
import torch
from torch import nn
from kvpress.presses.scorer_press import ScorerPress
from kvpress.utils import get_prerope_key_states
@dataclass
class LeverageScorePress(ScorerPress):
"""
Approximate leverage-score scorer on pre-RoPE keys.
Computes geometry-based outlier scores via (approximate) statistical leverage
on key embeddings using a right Gaussian sketch. Scores are z-score normalized.
This version differs slightly from the paper in that we use a Cholesky
decomposition to compute the leverage scores. Please see the paper for an
in-depth discussion.
References:
- Chari & Van Durme (2025): "Compactor: Calibrated Query-Agnostic KV Cache
Compression with Approximate Leverage Scores" (https://arxiv.org/pdf/2507.08143v1)
Parameters
----------
sketch_dimension : int, default ``48``
Size of Gaussian sketch.
Output
------
score(...) returns a tensor of shape (B, H_kv, S) with higher values
indicating more important tokens for retention.
Notes
-----
Currently only supports prefill.
"""
sketch_dimension: int = 48
@staticmethod
def chol_with_jitter(G: torch.Tensor, jitter: float = 0.0, max_tries: int = 5):
"""cholesky factorization with adaptive jitter."""
identity = torch.eye(G.shape[-1], device=G.device, dtype=G.dtype)
cur = float(jitter)
for _ in range(max_tries):
L, info = torch.linalg.cholesky_ex(G + cur * identity, upper=False)
if bool((info == 0).all()):
return L
cur = max(1e-8, (1e-2 if cur == 0.0 else 10.0 * cur))
raise RuntimeError(f"Cholesky failed after {max_tries} tries.")
@staticmethod
def compute_leverage_scores(key_states: torch.Tensor, sketch_dimension: int) -> torch.Tensor:
"""
Approximate leverage scores on pre-RoPE keys via right Gaussian sketching. We
use a Cholesky solve to do this efficiently.
"""
d, k = key_states.shape[-1], sketch_dimension
# Right Gaussian sketch; see the paper for a theoretical analysis of this *right* sketch.
Phi = torch.randn(
key_states.shape[0],
key_states.shape[1],
d,
k,
device=key_states.device,
dtype=key_states.dtype,
) * (1 / math.sqrt(k))
# sequence-centering then sketch.
X = key_states - key_states.mean(dim=-2, keepdim=True)
X = torch.matmul(X, Phi).to(torch.float32) # (B,H,S,k)
XT = X.transpose(-2, -1) # (B,H,k,S)
G = XT @ X # (X^T X) / (B,H,k,k)
# After sketching, we want to compute leverage scores given by
# diag(X (X^T X)^{-1} X^T). But we don't want to form (X^T X)^{-1}
# explicitly because it is slow and numerically unstable, so we
# instead compute a Cholesky decomp G = (X^T X) = LL^T.
L = LeverageScorePress.chol_with_jitter(0.5 * (G + G.transpose(-2, -1)), jitter=1e-2, max_tries=5) # (B,H,k,k)
# we use torch.cholesky_solve(XT, L) to find Y such that GY = X^T
# given a cholesky factor L of G=LL^T (i.e we find Y = G^{-1}X^T)
inv_Xt = torch.cholesky_solve(XT, L, upper=False) # (X^TX)^{-1} X^T / (B,H,k,S)
# we can now compute the leverage scores as: diag(X (X^T X)^{-1} X^T)
# without materializing the full S x S matrix.
scores = (X * inv_Xt.transpose(-2, -1)).sum(dim=-1).clamp_min(0) # (B,H,S)
return scores
def score(
self,
module: nn.Module,
hidden_states: torch.Tensor,
keys: torch.Tensor,
values: torch.Tensor,
attentions: torch.Tensor,
kwargs,
) -> torch.Tensor:
n_queries = hidden_states.shape[-2]
assert keys.shape[-2] == n_queries, "LeverageScorePress only supports prefill "
# pre-RoPE keys from the hidden states for the current layer
pre_rope_keys = get_prerope_key_states(module, hidden_states) # (B,H_kv,S,d)
scores = self.compute_leverage_scores(pre_rope_keys, self.sketch_dimension) # (B,H_kv,S)
z_scores = (scores - scores.mean()) / scores.std().clamp_min(1e-6)
return z_scores
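# Illustrative check on random keys (toy shapes): leverage scores are non-negative
# and, for a full-rank sketch, sum to roughly the sketch dimension per head.
if __name__ == "__main__":
    keys = torch.randn(1, 2, 512, 128)
    scores = LeverageScorePress.compute_leverage_scores(keys, sketch_dimension=48)
    assert scores.shape == (1, 2, 512)
    print(scores.sum(dim=-1))  # ~48 per head (trace of the hat matrix of the sketched keys)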
# SPDX-FileCopyrightText: Copyright Vivek Chari
# SPDX-License-Identifier: Apache-2.0
import math
from dataclasses import dataclass
import torch
from torch import nn
from torch.nn import functional as F
from transformers.models.llama.modeling_llama import repeat_kv, rotate_half
from kvpress.presses.scorer_press import ScorerPress
from kvpress.utils import get_prerope_query_states
@dataclass
class NonCausalAttnPress(ScorerPress):
"""
Non-causal, chunked attention scorer.
This press implements the non-causal, chunked attention, sum-over-queries scoring
used in Compactor. Scores are z-normalized.
References:
- Chari & Van Durme (2025): "Compactor: Calibrated Query-Agnostic KV Cache
Compression with Approximate Leverage Scores" (https://arxiv.org/pdf/2507.08143v1)
Parameters
----------
chunk_size : int, default ``256``
Chunk size used in non-causal attention.
Output
------
score(...) returns a tensor of shape (B, H_kv, S) with higher values
indicating more important tokens for retention.
Notes
-----
Only supports prefill.
"""
chunk_size: int = 256
@staticmethod
def non_causal_chunked_attn(q: torch.Tensor, k: torch.Tensor, chunk_size: int) -> torch.Tensor:
"""Compute non-causal, chunked attention column-sums over the sequence.
The sequence is left/right padded to a multiple of ``chunk_size`` and then
processed in fixed-size tiles.
Parameters
----------
q, k : torch.Tensor, shape (B, H, S, d)
Query/Key tensors for a single layer/head group.
chunk_size : int
Size of the chunk used to tile the sequence axis.
Returns
-------
torch.Tensor, shape (B, H, S)
Column-wise non-causal attention accumulations per key position.
"""
assert chunk_size > 0, "chunk_size must be positive"
assert q.shape[-2] == k.shape[-2], "only used in prefill"
B, H, S, d = k.shape
# pad to a multiple of chunk_size for easy reshaping
S_pad = math.ceil(S / chunk_size) * chunk_size
pad_len = S_pad - S
if pad_len > 0:
q_padded = torch.cat([q, torch.zeros(B, H, pad_len, d, device=q.device, dtype=q.dtype)], dim=2)
k_padded = torch.cat([k, torch.zeros(B, H, pad_len, d, device=k.device, dtype=k.dtype)], dim=2)
last_chunk_start = (S // chunk_size) * chunk_size
in_valid = torch.arange(last_chunk_start, S_pad, device=q.device) >= S
query_mask = key_mask = in_valid.view(1, 1, chunk_size).expand(B, H, chunk_size)
else:
q_padded, k_padded = q, k
last_chunk_start = ((S - 1) // chunk_size) * chunk_size
in_valid = torch.arange(last_chunk_start, S_pad, device=q.device) >= S
query_mask = key_mask = in_valid.view(1, 1, chunk_size).expand(B, H, chunk_size)
num_chunks = S_pad // chunk_size
# (B, H, num_chunks, chunk_size, d)
q_chunks = q_padded.view(B, H, num_chunks, chunk_size, d)
k_chunks = k_padded.view(B, H, num_chunks, chunk_size, d)
# (B, H, num_chunks, chunk_size, chunk_size)
dots = torch.matmul(q_chunks, k_chunks.transpose(-2, -1))
dots[:, :, -1].masked_fill_(query_mask.unsqueeze(-1), 0)
dots[:, :, -1].masked_fill_(key_mask.unsqueeze(-2), torch.finfo(dots.dtype).min)  # exclude padded keys from the softmax
attn = torch.softmax(dots.to(torch.float32), dim=-1)
# sum over query and trim padding
return attn.sum(dim=-2).view(B, H, S_pad)[..., :S]
def score(
self,
module: nn.Module,
hidden_states: torch.Tensor,
keys: torch.Tensor,
values: torch.Tensor,
attentions: torch.Tensor,
kwargs,
) -> torch.Tensor:
n_queries = hidden_states.shape[-2]
assert keys.shape[-2] == n_queries, "NonCausalAttnPress only supports prefill"
cos, sin = kwargs["position_embeddings"]
q = get_prerope_query_states(module, hidden_states) # (B, H_q, S, d)
q_len = q.shape[-2]
num_kv_groups = q.shape[1] // values.shape[1]
# apply RoPE to the queries for the last q_len positions
q = (q * cos[:, -q_len:, :].unsqueeze(1)) + (rotate_half(q) * sin[:, -q_len:, :].unsqueeze(1))
A = self.non_causal_chunked_attn(q, repeat_kv(keys, num_kv_groups), self.chunk_size) # (B, H_q, S)
# average across query-head groups back to H_kv
A = A.view(A.shape[0], values.shape[1], -1, A.shape[-1]).mean(dim=-2) # (B, H_kv, S)
scores = A * values.norm(dim=-1) # (B, H_kv, S)
scores = F.avg_pool1d(scores, kernel_size=3, padding=1, stride=1)
z_scores = (scores - scores.mean()) / scores.std().clamp_min(1e-6)  # z-score normalization
return z_scores
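# Illustrative check of the chunked-attention helper on random tensors (toy shapes):
# with S=300 and chunk_size=256 the sequence is padded to 512, processed as two tiles,
# and the padded positions are trimmed from the output.
if __name__ == "__main__":
    q = torch.randn(1, 2, 300, 64)
    k = torch.randn(1, 2, 300, 64)
    col_sums = NonCausalAttnPress.non_causal_chunked_attn(q, k, chunk_size=256)
    assert col_sums.shape == (1, 2, 300)  # accumulated attention mass per key position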
# SPDX-FileCopyrightText: Copyright (c) 1993-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
from dataclasses import dataclass
import torch
from torch import nn
from kvpress.presses.scorer_press import ScorerPress
@dataclass
class ObservedAttentionPress(ScorerPress):
"""
Observed attention-based KV cache compression.
Computes importance scores based on actual attention weights observed during
forward pass. Score for each key-value pair is the average attention weight
it receives from all query tokens.
Requires: attn_implementation="eager".
Related to H2O (https://arxiv.org/abs/2306.14048).
Parameters
----------
compression_ratio : float, default=0.0
Fraction of key-value pairs to remove during compression.
"""
compression_ratio: float = 0.0
def score(
self,
module: nn.Module,
hidden_states: torch.Tensor,
keys: torch.Tensor,
values: torch.Tensor,
attentions: torch.Tensor,
kwargs,
) -> torch.Tensor:
assert attentions is not None, 'Set attn_implementation="eager" to use this hook'
scores = attentions.sum(2)
bsz, num_key_value_heads, n_tokens, _ = keys.shape
n_tokens_in_sum = torch.arange(n_tokens, 0, -1).to(attentions.device, attentions.dtype)
scores = scores / n_tokens_in_sum
scores = scores.view(bsz, num_key_value_heads, -1, n_tokens).mean(2)
return scores
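# Illustrative check with a fake attention tensor (toy shapes, 4 query heads grouped
# into 2 KV heads): with eager attention the hook receives attentions of shape
# (B, H, S, S) and each key's score is its attention mass averaged over the queries
# that can attend to it.
if __name__ == "__main__":
    attentions = torch.rand(1, 4, 6, 6).tril()  # causal-looking toy attention weights
    keys = torch.randn(1, 2, 6, 8)
    press = ObservedAttentionPress(compression_ratio=0.5)
    scores = press.score(module=None, hidden_states=None, keys=keys, values=None, attentions=attentions, kwargs={})
    assert scores.shape == (1, 2, 6)  # one score per KV head and token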
# SPDX-FileCopyrightText: Copyright (c) 1993-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
import inspect
import logging
from dataclasses import dataclass
from typing import List
import torch
from torch import nn
from kvpress.presses.base_press import BasePress
from kvpress.presses.scorer_press import ScorerPress
logger = logging.getLogger(__name__)
@dataclass
class PerLayerCompressionPress(BasePress):
"""
Per-layer compression: Apply different compression ratios to different layers.
Wrapper that applies layer-specific compression ratios using any underlying
ScorerPress method. Different layers may have different importance patterns,
so layer-specific compression can improve quality-efficiency trade-offs.
**Important**: Experimental feature that only works with flash attention.
Parameters
----------
press : ScorerPress
The underlying scoring method to apply with layer-specific compression ratios.
compression_ratios : List[float]
List of compression ratios to apply to each layer.
Length should match number of model layers. Each value between 0.0-1.0
represents fraction of tokens to remove for that layer.
"""
press: ScorerPress
compression_ratios: List[float]
def __post_init__(self):
logger.warning(
"Per layer compression wrapper is an experimental feature and only works with flash attention. "
"Please make sure that the model uses flash attention."
)
assert (
"compression_ratio"
in inspect.signature(
self.press.__init__ # type:ignore[misc]
).parameters
), f"compression_ratio can't be set in the provided press: {self.press.__class__}"
assert isinstance(self.press, ScorerPress), "PerLayerCompressionPress requires a ScorerPress as input"
def forward_hook(self, module: nn.Module, input: list[torch.Tensor], kwargs: dict, output: list):
original_compression_ratio = self.press.compression_ratio # type:ignore[index]
self.press.compression_ratio = self.compression_ratios[module.layer_idx] # type:ignore[index]
output = self.press.forward_hook(module, input, kwargs, output)
self.press.compression_ratio = original_compression_ratio # type:ignore[attr-defined]
return output
@property
def compression_ratio(self):
return sum(self.compression_ratios) / len(self.compression_ratios)
@compression_ratio.setter
def compression_ratio(self, value):
raise AttributeError(f"compression ratio cannot be set for {type(self).__name__}")
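# Construction sketch with illustrative numbers: a linear schedule for a hypothetical
# 32-layer model, wrapping KnormPress. As warned above, this wrapper assumes flash attention.
if __name__ == "__main__":
    from kvpress import KnormPress
    n_layers = 32  # must match model.config.num_hidden_layers
    ratios = [0.8 * i / (n_layers - 1) for i in range(n_layers)]  # compress deeper layers more aggressively
    press = PerLayerCompressionPress(press=KnormPress(), compression_ratios=ratios)
    assert abs(press.compression_ratio - sum(ratios) / n_layers) < 1e-9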
# SPDX-FileCopyrightText: Copyright (c) 1993-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
import logging
from contextlib import contextmanager
from dataclasses import dataclass
from typing import Optional
import torch
import torch.nn as nn
from transformers import PreTrainedModel
from kvpress.presses.base_press import BasePress
from kvpress.presses.decoding_press import DecodingPress
logger = logging.getLogger(__name__)
@dataclass
class PrefillDecodingPress(BasePress):
"""
A wrapper press that combines separate prefilling and decoding compression strategies.
This press acts as a single press interface but internally delegates to different
presses based on the current phase (prefilling vs decoding). During prefilling,
it uses the prefilling_press. During decoding, it uses the decoding_press.
Parameters
----------
prefilling_press : BasePress, optional
Press to use during the prefilling phase. If None, no compression is applied during prefilling.
decoding_press : DecodingPress, optional
Press to use during the decoding phase. If None, no compression is applied during decoding.
"""
prefilling_press: Optional[BasePress] = None
decoding_press: Optional[DecodingPress] = None
def post_init_from_model(self, model):
if self.prefilling_press is not None:
self.prefilling_press.post_init_from_model(model)
if self.decoding_press is not None:
self.decoding_press.post_init_from_model(model)
def compress(
self,
module: nn.Module,
hidden_states: torch.Tensor,
keys: torch.Tensor,
values: torch.Tensor,
attentions: torch.Tensor,
kwargs: dict,
) -> tuple[torch.Tensor, torch.Tensor]:
q_len = hidden_states.shape[1]
# Determine if we're in prefilling or decoding phase
if kwargs["cache_position"][-1] <= q_len and self.prefilling_press is not None:
return self.prefilling_press.compress(module, hidden_states, keys, values, attentions, kwargs)
elif self.decoding_press is not None:
return self.decoding_press.compress(module, hidden_states, keys, values, attentions, kwargs)
# No compression applied
logger.warning("No compression applied during prefill or decoding phase")
return keys, values
def forward_hook(self, module: nn.Module, input: list[torch.Tensor], kwargs: dict, output: list):
"""
Forward hook that delegates to the appropriate press based on current phase.
"""
hidden_states = kwargs["hidden_states"]
q_len = hidden_states.shape[1]
# Determine if we're in prefilling or decoding phase
if kwargs["cache_position"][-1] <= q_len and self.prefilling_press is not None:
return self.prefilling_press.forward_hook(module, input, kwargs, output)
elif self.decoding_press is not None:
return self.decoding_press.forward_hook(module, input, kwargs, output)
# No hook applied
return output
@contextmanager
def __call__(self, model: PreTrainedModel):
try:
with super().__call__(model):
yield
finally:
# Reset decoding press if it exists
if self.decoding_press is not None:
self.decoding_press.reset()
# SPDX-FileCopyrightText: Copyright (c) 1993-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
import logging
from dataclasses import dataclass
import torch
from torch import nn
from kvpress.presses.snapkv_press import SnapKVPress
logger = logging.getLogger(__name__)
@dataclass
class PyramidKVPress(SnapKVPress):
"""
PyramidKV: Layer-wise adaptive KV cache allocation with pyramid structure.
Dynamically adjusts KV cache sizes across transformer layers, allocating
more tokens to lower layers and fewer to higher layers. Based on the
observation that lower layers need more context while higher layers
can work with less.
Based on PyramidKV (https://arxiv.org/abs/2406.02069).
Parameters
----------
compression_ratio : float, default=0.0
Fraction of key-value pairs to remove during compression.
window_size : int, default=64
Base window size for attention computation, used in pyramid budget calculation.
kernel_size : int, default=5
Size of the pooling kernel for attention smoothing (inherited from SnapKV).
beta : int, default=20
Hyperparameter controlling the pyramid's shape and steepness.
Larger values create steeper pyramids with more dramatic differences between
layers. Smaller values create gentler, more balanced allocation across layers.
"""
compression_ratio: float = 0.0
window_size: int = 64
kernel_size: int = 5
beta: int = 20
def get_layer_budget(
self,
module: nn.Module,
q_len: int,
) -> int:
"""
Compute the budget for each layer based on the pyramid shape.
We use the budget calculation formula from:
https://github.com/Zefan-Cai/KVCache-Factory/blob/main/pyramidkv/pyramidkv_utils.py#L197
This implementation always applies compression_ratio,
instead of disabling compression or keeping a fixed budget for short queries as the original code does.
max_capacity_prompt is calculated as:
max_num + min_num = (max_capacity_prompt - window_size) * 2
total_kvcache_size = (max_num + min_num) * num_layers / 2
                   = (max_capacity_prompt - window_size) * num_layers
total_kvcache_size = query_length * num_layers * (1 - compression_ratio)
max_capacity_prompt = window_size + query_length * (1 - compression_ratio)
"""
assert self.beta >= 1, "beta must be >= 1"
# Ensure the total budget meets the compression_ratio requirements
max_capacity_prompt = self.window_size + q_len * (1 - self.compression_ratio)
min_num = (max_capacity_prompt - self.window_size) / self.beta
max_num = (max_capacity_prompt - self.window_size) * 2 - min_num
if max_num >= q_len - self.window_size:
max_num = q_len - self.window_size
min_num = (max_capacity_prompt - self.window_size) * 2 - max_num
if not (q_len >= max_num >= min_num >= self.window_size):
# Fall back to SnapKV
return round(q_len * (1 - self.compression_ratio))
steps = (max_num - min_num) / (module.config.num_hidden_layers - 1)
return round(max_num - module.layer_idx * steps)
def compress(
self,
module: nn.Module,
hidden_states: torch.Tensor,
keys: torch.Tensor,
values: torch.Tensor,
attentions: torch.Tensor,
kwargs: dict,
) -> tuple[torch.Tensor, torch.Tensor]:
if self.compression_ratio == 0:
return keys, values
# Compute scores
scores = self.score(module, hidden_states, keys, values, attentions, kwargs)
# Get indices of KV pairs with the lowest scores
k_len = keys.shape[2]
n_kept = self.get_layer_budget(module, k_len)
indices = scores.topk(n_kept, dim=-1).indices
indices = indices.unsqueeze(-1).expand(-1, -1, -1, module.head_dim)
# Prune keys and values
keys = keys.gather(2, indices).contiguous()
values = values.gather(2, indices).contiguous()
return keys, values
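# Worked example of the budget formula above, using illustrative numbers and a
# stand-in module object: q_len=8192, compression_ratio=0.75, 32 layers gives
# max_capacity_prompt=2112, min_num=102.4 and max_num=3993.6, so lower layers keep
# ~3994 KV pairs, the top layer ~102, and the average stays near 2048 = 25% of q_len.
if __name__ == "__main__":
    from types import SimpleNamespace
    press = PyramidKVPress(compression_ratio=0.75, window_size=64, beta=20)
    cfg = SimpleNamespace(num_hidden_layers=32)
    budgets = [press.get_layer_budget(SimpleNamespace(config=cfg, layer_idx=i), q_len=8192) for i in range(32)]
    assert budgets[0] > budgets[-1]  # lower layers keep more KV pairs
    print(budgets[0], budgets[-1], sum(budgets) / 32)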
# SPDX-FileCopyrightText: Copyright (c) 1993-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
from dataclasses import dataclass, field
from functools import cache
import torch
from huggingface_hub import PyTorchModelHubMixin, get_collection
from kvpress.presses.scorer_press import ScorerPress
class QFilters(torch.nn.Module, PyTorchModelHubMixin):
def __init__(self, num_layers: int, num_kv_heads: int, kv_head_dim: int):
super().__init__()
self.q_filters = torch.nn.Parameter(torch.randn(num_layers, num_kv_heads, kv_head_dim))
@dataclass
class QFilterPress(ScorerPress):
"""
Q-Filter: Learned filter-based KV cache compression.
This method uses pre-trained learned filters (Q-filters) to score and compress
key-value pairs. Unlike heuristic-based methods, Q-filters are learned vectors
that identify important tokens for a specific model architecture.
The method works by:
1. Loading pre-trained Q-filter parameters for the specific model
2. Computing dot products between keys and the learned filters
3. Using these dot products as importance scores for compression
4. Pruning tokens with the lowest filter response scores
Key characteristics:
- Uses learned parameters rather than heuristics
- Model-specific filters optimized for each architecture
- Potentially more accurate than generic scoring methods
- Requires pre-trained filter parameters to be available
The Q-filters are automatically loaded based on the model name and are
expected to be available in a Hugging Face model collection.
Based on Q-Filter (https://arxiv.org/abs/2503.02812).
Parameters
----------
compression_ratio : float, default=0.0
Fraction of key-value pairs to remove during compression.
"""
q_filters: QFilters = field(init=False, default=None)
def post_init_from_model(self, model):
model_name = model.config.name_or_path.split("/")[-1]
self.q_filters = self.load_q_filters(model_name)
self.q_filters = self.q_filters.to(model.dtype)
@staticmethod
@cache
def load_q_filters(model_name):
model_name = model_name if "Meta-Llama-3.1-405B" in model_name else model_name.replace("Meta-Llama", "Llama")
try:
return QFilters.from_pretrained(f"nthngdy/{model_name}_qfilt").q_filters
except TypeError:
raise ValueError(
f"Could not load Q-filters for {model_name}. Available models: {QFilterPress.available_qfilters()}"
)
@staticmethod
def available_qfilters():
collection = get_collection("nthngdy/q-filters-67a4994dcb302a3d37f3d119", token=False)
return [x.item_id.split("/")[-1][:-6] for x in collection.items]
def score(self, module, hidden_states, keys, values, attentions, kwargs):
if self.q_filters is None:
raise ValueError(
"Q-filters not loaded. If you are using a wrapper press, make sure to call post_init_from_model."
)
q_filter = self.q_filters[module.layer_idx][None, :, None] # type: ignore
q_filter = q_filter.to(keys.device)
scores = -(q_filter * keys).sum(dim=-1)
return scores
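# Usage sketch: the filters are downloaded from the Hugging Face Hub based on the model
# name, so this press only works for models with published Q-filters (listing them
# requires network access).
if __name__ == "__main__":
    print(QFilterPress.available_qfilters())  # model names with published Q-filters
    press = QFilterPress(compression_ratio=0.5)
    # post_init_from_model(model) must run with the target model before scoring, as the
    # error message in `score` above points out for wrapper presses.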
# SPDX-FileCopyrightText: Copyright (c) 1993-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
from dataclasses import dataclass
from typing import Optional
import torch
from torch import nn
from kvpress.presses.scorer_press import ScorerPress
@dataclass
class RandomPress(ScorerPress):
"""
Random KV cache compression for baseline comparison.
Randomly selects which key-value pairs to prune. Useful for establishing baseline
performance metrics and validating other compression methods.
Parameters
----------
compression_ratio : float, default=0.0
Fraction of key-value pairs to remove during compression.
seed : int, optional
Random seed for reproducible compression results.
"""
compression_ratio: float = 0.0
seed: Optional[int] = None
def score(
self,
module: nn.Module,
hidden_states: torch.Tensor,
keys: torch.Tensor,
values: torch.Tensor,
attentions: torch.Tensor,
kwargs,
) -> torch.Tensor:
generator = None
if self.seed is not None:
generator = torch.Generator()
generator.manual_seed(self.seed)
return torch.rand(*keys.shape[:-1], generator=generator, device=keys.device, dtype=keys.dtype)
# SPDX-FileCopyrightText: Copyright (c) 1993-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
import logging
from dataclasses import dataclass
import torch
from torch import nn
from kvpress.presses.base_press import BasePress
logger = logging.getLogger(__name__)
@dataclass
class ScorerPress(BasePress):
"""
Base class for score-based KV cache compression methods.
This class assigns scores to key-value pairs and prunes those with the lowest scores.
Subclasses implement the `score` method to define how importance is calculated.
Parameters
----------
compression_ratio : float, default=0.0
Fraction of key-value pairs to remove during compression.
"""
compression_ratio: float = 0.0
def __post_init__(self):
assert 0 <= self.compression_ratio < 1, "Compression ratio must be between 0 and 1"
def score(
self,
module: nn.Module,
hidden_states: torch.Tensor,
keys: torch.Tensor,
values: torch.Tensor,
attentions: torch.Tensor,
kwargs,
) -> torch.Tensor:
"""
Compute importance scores for each key-value pair.
This method must be implemented by subclasses to define how the importance
of each token position is calculated. Higher scores indicate more important
tokens that should be kept during compression.
Parameters
----------
module : nn.Module
The transformer attention layer where scoring is applied.
hidden_states : torch.Tensor
Input embeddings with shape (batch_size, seq_len, hidden_dim).
keys : torch.Tensor
Key tensors with shape (batch_size, num_kv_heads, seq_len, head_dim).
values : torch.Tensor
Value tensors with shape (batch_size, num_kv_heads, seq_len, head_dim).
attentions : torch.Tensor
Attention weights with shape (batch_size, num_heads, seq_len, seq_len).
May be None if not computed or needed by the scoring method.
kwargs : dict
Additional arguments from the forward pass, including cache and position info.
Returns
-------
torch.Tensor
Importance scores with shape (batch_size, num_kv_heads, seq_len).
Higher scores indicate more important tokens. The tokens with the
lowest scores will be pruned during compression.
"""
raise NotImplementedError
def compress(
self,
module: nn.Module,
hidden_states: torch.Tensor,
keys: torch.Tensor,
values: torch.Tensor,
attentions: torch.Tensor,
kwargs: dict,
) -> tuple[torch.Tensor, torch.Tensor]:
if self.compression_ratio == 0:
return keys, values
# Compute scores
scores = self.score(module, hidden_states, keys, values, attentions, kwargs)
# Get indices of KV pairs with the lowest scores
k_len = keys.shape[2]
n_kept = int(k_len * (1 - self.compression_ratio))
indices = scores.topk(n_kept, dim=-1).indices
indices = indices.unsqueeze(-1).expand(-1, -1, -1, module.head_dim)
# Prune keys and values
keys = keys.gather(2, indices).contiguous()
values = values.gather(2, indices).contiguous()
return keys, values
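# Minimal subclass sketch illustrating the contract documented above: a hypothetical
# press that keeps the most recent tokens by scoring each position with its index.
if __name__ == "__main__":
    @dataclass
    class RecencyPress(ScorerPress):
        """Toy example: higher scores for more recent positions, so older tokens are pruned first."""
        def score(self, module, hidden_states, keys, values, attentions, kwargs):
            bsz, num_kv_heads, seq_len, _ = keys.shape
            return torch.arange(seq_len, device=keys.device, dtype=keys.dtype).expand(bsz, num_kv_heads, seq_len)
    press = RecencyPress(compression_ratio=0.5)
    assert press.score(None, None, torch.randn(1, 2, 8, 4), None, None, {}).shape == (1, 2, 8)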
# SPDX-FileCopyrightText: Copyright (c) 1993-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
import logging
from dataclasses import dataclass
import torch
from torch import nn
from kvpress.presses.base_press import BasePress
from kvpress.presses.snapkv_press import SnapKVPress
logger = logging.getLogger(__name__)
@dataclass
class SimLayerKVPress(BasePress):
"""
SimLayerKV: Simple layer-wise KV cache compression.
Identifies "lazy" layers that can work effectively with reduced KV cache sizes.
If a layer is considered "lazy", we only keep the initial and recent KV pairs.
Otherwise, we keep all KV pairs.
Recommended lazy_threshold values: Llama3 (0.9), Llama2 (0.65), Mistral (0.8), Qwen (0.85).
Based on SimLayerKV (https://arxiv.org/abs/2410.13846).
Parameters
----------
lazy_threshold : float, default=1.0
Threshold for identifying lazy layers based on attention concentration.
Layer is lazy if sum(attention_weights[last_tokens -> initial+recent]) > threshold.
Lower values identify more layers as lazy (more aggressive compression).
n_last : int, default=1
Number of last tokens to analyze for lazy layer identification.
n_recent : int, default=1024
Number of recent tokens to preserve in lazy layers.
n_initial : int, default=4
Number of initial tokens to preserve in lazy layers (sink tokens).
"""
lazy_threshold: float = 1.0
n_last: int = 1 # n_last=1 to match SKLV-decode
n_recent: int = 1024
n_initial: int = 4
def __post_init__(self):
assert 0.0 <= self.lazy_threshold <= 1.0, "lazy_threshold should be in [0, 1]"
self.compression_ratios = []
def is_lazy(
self,
module: nn.Module,
hidden_states: torch.Tensor,
keys: torch.Tensor,
position_embeddings: torch.Tensor,
) -> bool:
"""
Compute the attention weights of the last tokens over the initial and recent tokens.
The layer is considered lazy if the sum of these attention weights is above the lazy_threshold.
"""
attn_weights = SnapKVPress.compute_window_attention(
module, hidden_states, keys, self.n_last, position_embeddings
)
attn_weights = attn_weights.mean((0, 1, 2)) # mean over bsz, heads and window size
score = attn_weights[: self.n_initial].sum() + attn_weights[-self.n_recent :].sum()
return score.item() > self.lazy_threshold
@property
def compression_ratio(self):
if len(self.compression_ratios) > 0:
return sum(self.compression_ratios) / len(self.compression_ratios)
else:
raise ValueError("Forward pass must be run to compute the compression ratio")
@compression_ratio.setter
def compression_ratio(self, value):
raise AttributeError(f"compression ratio cannot be set for {type(self).__name__}")
def compress(
self,
module: nn.Module,
hidden_states: torch.Tensor,
keys: torch.Tensor,
values: torch.Tensor,
attentions: torch.Tensor,
kwargs: dict,
) -> tuple[torch.Tensor, torch.Tensor]:
# Initialize the compression ratios
if module.layer_idx == 0:
self.compression_ratios = []
# Check if compression is needed
k_len = keys.shape[2]
min_length = self.n_initial + self.n_recent + self.n_last
if k_len <= min_length:
logger.warning(f"Sequence length is shorter than {min_length}: no compression applied")
if (self.lazy_threshold == 1.0) or (k_len <= min_length):
self.compression_ratios.append(0.0)
return keys, values
# Compression
if self.is_lazy(module, hidden_states, keys, kwargs["position_embeddings"]):
# If layer is lazy, only keep the initial and recent KV pairs
keys = torch.cat([keys[:, :, : self.n_initial], keys[:, :, -self.n_recent + self.n_last :]], dim=2)
values = torch.cat([values[:, :, : self.n_initial], values[:, :, -self.n_recent + self.n_last :]], dim=2)
self.compression_ratios.append((k_len - self.n_initial - self.n_recent + 1) / k_len)
else:
self.compression_ratios.append(0.0)
return keys, values
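# Usage sketch with the thresholds recommended in the docstring (e.g. ~0.9 for Llama 3
# models; this is an assumption to tune per model family). The realized compression
# ratio is only known after a forward pass.
if __name__ == "__main__":
    press = SimLayerKVPress(lazy_threshold=0.9, n_initial=4, n_recent=1024)
    # After prefilling a model under `with press(model): ...`, the averaged ratio is
    # available as press.compression_ratio.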
# SPDX-FileCopyrightText: Copyright (c) 1993-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
import math
from dataclasses import dataclass
import torch
from torch import nn
from torch.nn import functional as F
from transformers.models.llama.modeling_llama import repeat_kv, rotate_half
from kvpress.presses.scorer_press import ScorerPress
from kvpress.utils import get_prerope_query_states
@dataclass
class SnapKVPress(ScorerPress):
"""
SnapKV: Attention-based KV cache compression using recent token patterns.
Uses attention patterns of the most recent tokens to estimate importance
of previous key-value pairs.
Based on SnapKV (https://arxiv.org/abs/2404.14469).
Parameters
----------
compression_ratio : float, default=0.0
Fraction of key-value pairs to remove during compression.
window_size : int, default=64
Number of recent tokens to use for computing attention-based importance scores.
kernel_size : int, default=5
Size of the pooling kernel applied to attention weights for smoothing.
"""
compression_ratio: float = 0.0
window_size: int = 64
kernel_size: int = 5
@staticmethod
def compute_window_attention(module, hidden_states, keys, window_size, position_embeddings):
"""
Compute the last window_size queries and associated attention weights for the first q_len - window_size keys.
"""
bsz, _, k_len, _ = keys.shape
num_heads = module.config.num_attention_heads
head_dim = module.head_dim
num_key_value_groups = num_heads // module.config.num_key_value_heads
# Get last window_size queries
query_states = get_prerope_query_states(module, hidden_states[:, -window_size:])
# Apply RoPE
cos, sin = position_embeddings
cos, sin = cos[:, -window_size:], sin[:, -window_size:]
query_states = (query_states * cos.unsqueeze(1)) + (rotate_half(query_states) * sin.unsqueeze(1))
# Compute attention for first q_len - window_size tokens
key_states = repeat_kv(keys, num_key_value_groups)
attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(head_dim)
attention_mask = torch.ones_like(attn_weights) * float("-inf")
attention_mask = torch.triu(attention_mask, diagonal=k_len - window_size + 1)
attn_weights += attention_mask
attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype)
attn_weights = attn_weights[..., :-window_size]
return attn_weights
def score(
self,
module: nn.Module,
hidden_states: torch.Tensor,
keys: torch.Tensor,
values: torch.Tensor,
attentions: torch.Tensor,
kwargs,
) -> torch.Tensor:
bsz, num_key_value_heads, k_len, _ = keys.shape
num_key_value_groups = module.config.num_attention_heads // num_key_value_heads
assert (
hidden_states.shape[1] > self.window_size
), f"Query length {hidden_states.shape[1]} should be greater than the window size {self.window_size}"
if attentions is not None:
attn_weights = attentions[..., -self.window_size :, : -self.window_size]
else:
attn_weights = self.compute_window_attention(
module, hidden_states, keys, self.window_size, kwargs["position_embeddings"]
)
scores = attn_weights.mean(dim=-2)
scores = F.avg_pool1d(scores, kernel_size=self.kernel_size, padding=self.kernel_size // 2, stride=1)
# Average per group (https://github.com/FasterDecoding/SnapKV/issues/22)
scores = scores.view(bsz, num_key_value_heads, num_key_value_groups, k_len - self.window_size)
scores = scores.mean(2)
# Add back the observation window. Use max score to make sure the window is not pruned.
scores = F.pad(scores, (0, self.window_size), value=scores.max().item())
return scores
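# End-to-end usage sketch, assuming the kvpress text-generation pipeline is registered
# and a supported instruct model can be downloaded (the model name is only an example).
# Note that the prefilled context must be longer than window_size tokens.
if __name__ == "__main__":
    from transformers import pipeline
    press = SnapKVPress(compression_ratio=0.5, window_size=64, kernel_size=5)
    pipe = pipeline("kv-press-text-generation", model="Qwen/Qwen2.5-1.5B-Instruct")
    context = "A long context (more than window_size tokens) to compress during prefilling ..."
    print(pipe(context, question="What is the context about?", press=press)["answer"])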
# SPDX-FileCopyrightText: Copyright (c) 1993-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
from dataclasses import dataclass
import torch
from torch import nn
from kvpress.presses.scorer_press import ScorerPress
@dataclass
class StreamingLLMPress(ScorerPress):
"""
StreamingLLM: Window-based KV cache compression with sink tokens.
Implements sliding window approach preserving first few tokens (sink tokens)
and most recent tokens, while pruning middle tokens.
Based on StreamingLLM (https://arxiv.org/abs/2309.17453).
To fully match the implementation described in the paper, use the KeyRerotationPress wrapper (see issue #158).
Parameters
----------
compression_ratio : float, default=0.0
Fraction of key-value pairs to remove during compression.
n_sink : int, default=4
Number of initial tokens to always preserve (sink tokens).
These tokens are never pruned and serve as "attention sinks" that help
maintain model stability. Language models often assign high attention
weights to early tokens regardless of semantic content.
"""
compression_ratio: float = 0.0
n_sink: int = 4
def score(
self,
module: nn.Module,
hidden_states: torch.Tensor,
keys: torch.Tensor,
values: torch.Tensor,
attentions: torch.Tensor,
kwargs,
) -> torch.Tensor:
k_len = keys.shape[2]
assert k_len > self.n_sink, f"Input should contain more tokens than n_sink={self.n_sink}"
n_pruned = k_len - int(k_len * (1 - self.compression_ratio))
scores = torch.ones_like(keys[..., 0])
scores[:, :, self.n_sink : self.n_sink + n_pruned] = 0
return scores
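# Illustrative check on a toy 10-token sequence: with compression_ratio=0.5 and n_sink=2,
# the 5 tokens immediately after the sinks are scored 0 (pruned), while the sink tokens
# and the most recent tokens keep the maximum score.
if __name__ == "__main__":
    keys = torch.randn(1, 1, 10, 4)
    press = StreamingLLMPress(compression_ratio=0.5, n_sink=2)
    scores = press.score(module=None, hidden_states=None, keys=keys, values=None, attentions=None, kwargs={})
    assert scores[0, 0].tolist() == [1, 1, 0, 0, 0, 0, 0, 1, 1, 1]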
# SPDX-FileCopyrightText: Copyright (c) 1993-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
from dataclasses import dataclass
import torch
from torch import nn
from transformers.models.llama.modeling_llama import rotate_half
from kvpress.presses.base_press import BasePress
from kvpress.utils import get_prerope_query_states
@dataclass
class ThinKPress(BasePress):
"""
ThinK: Channel-wise key compression for transformer attention.
ThinK compresses the dimensions of the keys, and not the sequence length.
Hence it can be combined with any other press that compresses the sequence length, e.g.
press = ComposedPress([SnapKVPress(0.5), ThinKPress(0.5)])
Here, we zero out the pruned dimensions resulting in no memory gain (the shape of the keys remains the same).
To achieve memory savings, several options can be considered (see https://github.com/NVIDIA/kvpress/pull/18/),
we might implement them in the future, especially if other similar presses are requested.
This press has been reviewed by Yuhui Xu, first author of the ThinK paper.
Based on ThinK (https://arxiv.org/pdf/2407.21018).
Parameters
----------
key_channel_compression_ratio : float, default=0.0
Fraction of key channels (dimensions) to remove during compression.
window_size : int, default=32
Number of recent tokens to use for computing key channel importance.
"""
key_channel_compression_ratio: float = 0.0
window_size: int = 32
def compute_window_queries(self, module, hidden_states, position_embeddings):
"""
Re-compute the last window_size query states
"""
# Get last self.window_size queries
query_states = get_prerope_query_states(module, hidden_states[:, -self.window_size :])
# Apply RoPE
cos, sin = position_embeddings
cos, sin = cos[:, -self.window_size :], sin[:, -self.window_size :]
query_states = (query_states * cos.unsqueeze(1)) + (rotate_half(query_states) * sin.unsqueeze(1))
return query_states
def compress(
self,
module: nn.Module,
hidden_states: torch.Tensor,
keys: torch.Tensor,
values: torch.Tensor,
attentions: torch.Tensor,
kwargs: dict,
) -> tuple[torch.Tensor, torch.Tensor]:
"""
If other similar presses are requested, we might create a generic compress method for dimension pruning
to avoid code duplication.
"""
if self.key_channel_compression_ratio == 0:
return keys, values
# Compute scores per dimension
bsz, num_key_value_heads, k_len, head_dim = keys.shape
num_key_value_groups = module.config.num_attention_heads // num_key_value_heads
queries = self.compute_window_queries(module, kwargs["hidden_states"], kwargs["position_embeddings"])
queries_norm = torch.pow(queries, 2).mean(dim=2) # (bsz, num_heads, head_dim)
queries_norm = queries_norm.view(bsz, num_key_value_heads, num_key_value_groups, module.head_dim).mean(2)
keys_norm = torch.pow(keys, 2).mean(dim=2)
key_scores = queries_norm * keys_norm # (bsz, num_key_value_heads, head_dim)
# Prune dimensions with the lowest scores by setting them to 0
n_pruned = int(head_dim * self.key_channel_compression_ratio)
indices = key_scores.topk(n_pruned, dim=-1, largest=False).indices
indices = indices.unsqueeze(2).expand(-1, -1, k_len, -1)
keys = keys.scatter_(-1, indices, 0)
return keys, values
@property
def compression_ratio(self):
return self.key_channel_compression_ratio / 2
@compression_ratio.setter
def compression_ratio(self, value):
raise AttributeError(f"compression ratio cannot be set for {type(self).__name__}")
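# --------------------------------------------------------------------------------------
# Hedged usage sketch (editorial addition, not part of the original file). As noted in
# the docstring above, ThinK prunes key channels rather than the sequence length, so it
# composes with a sequence-length press; assuming ComposedPress and SnapKVPress are
# exported by kvpress:
if __name__ == "__main__":
    from kvpress import ComposedPress, SnapKVPress, ThinKPress

    # Drop 50% of the KV pairs with SnapKV, then zero out 50% of the key channels.
    press = ComposedPress(
        [SnapKVPress(compression_ratio=0.5), ThinKPress(key_channel_compression_ratio=0.5)]
    )
    # `press` can then be passed to the kv-press-text-generation pipeline, e.g.
    # pipe(context, question=question, press=press)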
# SPDX-FileCopyrightText: Copyright (c) 1993-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
from dataclasses import dataclass
import torch
import torch.nn.functional as F
from torch import nn
from kvpress.presses.scorer_press import ScorerPress
from kvpress.presses.snapkv_press import SnapKVPress
@dataclass
class TOVAPress(ScorerPress):
"""
TOVA: Token Omission Via Attention for KV cache compression.
Uses attention weights of the last token (averaged across heads) to estimate
importance of previous key-value pairs. The last token's attention pattern
provides a good indicator of which historical tokens are important.
Based on TOVA (https://arxiv.org/abs/2401.06104).
Official implementation can be found here: https://github.com/schwartz-lab-NLP/TOVA/blob/main/src/tova_cache.py
Parameters
----------
compression_ratio : float, default=0.0
Fraction of key-value pairs to remove during compression.
"""
compression_ratio: float = 0.0
def score(
self,
module: nn.Module,
hidden_states: torch.Tensor,
keys: torch.Tensor,
values: torch.Tensor,
attentions: torch.Tensor,
kwargs,
) -> torch.Tensor:
if attentions is not None:
attn_weights = attentions[..., -1:, :-1]
else:
attn_weights = SnapKVPress.compute_window_attention(
module, hidden_states, keys, 1, kwargs["position_embeddings"]
)
# Average across heads and repeat num_key_value_heads times
scores = attn_weights.mean(1)
scores = scores.repeat(1, keys.shape[1], 1)
# Add back the last token, using the max score to make sure the window is not pruned.
# This is a very slight difference from TOVA, which does not enforce it, but the
# last attention weight is usually very high, so it should not change the results.
scores = F.pad(scores, (0, 1), value=scores.max().item())
return scores
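# --------------------------------------------------------------------------------------
# Hedged usage sketch (editorial addition, not part of the original file). Assuming the
# class above is exported from kvpress as `TOVAPress`, it is used like any other
# ScorerPress with a fixed compression ratio:
if __name__ == "__main__":
    from kvpress import TOVAPress

    # Remove the 50% of KV pairs that receive the least attention from the last token.
    press = TOVAPress(compression_ratio=0.5)
    # pipe(context, question=question, press=press)  # with the kv-press-text-generation pipeline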
# SPDX-FileCopyrightText: Copyright (c) 1993-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
import torch
from torch import nn
from transformers import Cache, QuantizedCache
from transformers.models.gemma3.modeling_gemma3 import Gemma3Attention
from transformers.models.phi3.modeling_phi3 import Phi3Attention
from transformers.models.qwen3.modeling_qwen3 import Qwen3Attention
def get_prerope_query_states(module: nn.Module, hidden_states: torch.Tensor) -> torch.Tensor:
"""
Extracts the query states from a given attention module and hidden states tensor.
This function supports multiple attention module types: Phi3Attention, Qwen3Attention, Gemma3Attention,
and Llama-like modules. It handles the appropriate projection and reshaping to obtain the query states
in the expected format.
Parameters
----------
module : nn.Module
The attention module from which to extract query states. Must be one of
Phi3Attention, Qwen3Attention, Gemma3Attention, or a Llama-like attention module
with a 'q_proj' attribute.
hidden_states : torch.Tensor
The input hidden states of shape (batch_size, seq_len, hidden_dim).
Returns
-------
query_states : torch.Tensor
The extracted query states of shape (batch_size, num_heads, seq_len, head_dim).
"""
bsz, q_len, _ = hidden_states.shape
num_heads = module.config.num_attention_heads
head_dim = module.head_dim
if isinstance(module, Phi3Attention):
qkv = module.qkv_proj(hidden_states)
query_states = qkv[..., : num_heads * head_dim]
elif hasattr(module, "q_proj"):
# Assume Llama-like attention layer
query_states = module.q_proj(hidden_states)
else:
raise NotImplementedError(f"Press not yet implemented for {module.__class__}.")
query_states = query_states.view(bsz, q_len, num_heads, head_dim).transpose(1, 2)
# Support for Qwen3 and Gemma3 QK norm
if isinstance(module, (Qwen3Attention, Gemma3Attention)):
query_states = module.q_norm(query_states)
return query_states
def get_prerope_key_states(module: nn.Module, hidden_states: torch.Tensor) -> torch.Tensor:
"""
Extracts the key states from a given attention module and hidden states tensor.
This function supports multiple attention module types: Phi3Attention, Qwen3Attention, Gemma3Attention,
and Llama-like modules. It handles the appropriate projection and reshaping to obtain the key states
in the expected format.
Parameters
----------
module : nn.Module
The attention module from which to extract key states. Must be one of
Phi3Attention, Qwen3Attention, Gemma3Attention, or a Llama-like attention module
with a 'k_proj' attribute.
hidden_states : torch.Tensor
The input hidden states of shape (batch_size, seq_len, hidden_dim).
Returns
-------
key_states : torch.Tensor
The extracted key states of shape (batch_size, num_key_value_heads, seq_len, head_dim).
"""
bsz, k_len, _ = hidden_states.shape
head_dim = module.head_dim
if isinstance(module, Phi3Attention):
qkv = module.qkv_proj(hidden_states)
query_pos = module.config.num_attention_heads * module.head_dim
key_states = qkv[..., query_pos : query_pos + module.num_key_value_heads * module.head_dim]
elif hasattr(module, "k_proj"):
# Assume Llama-like attention layer
key_states = module.k_proj(hidden_states)
else:
raise NotImplementedError(f"Press not yet implemented for {module.__class__}.")
key_states = key_states.view(bsz, k_len, -1, head_dim).transpose(1, 2)
# Support for Qwen3 and Gemma3 QK norm
if isinstance(module, (Qwen3Attention, Gemma3Attention)):
key_states = module.k_norm(key_states)
return key_states
def dequantize_layer(cache_layer) -> tuple[torch.Tensor, torch.Tensor]:
keys = cache_layer._dequantize(cache_layer._quantized_keys)
values = cache_layer._dequantize(cache_layer._quantized_values)
return keys, values
def extract_keys_and_values(cache: Cache, layer_idx: int) -> tuple[torch.Tensor, torch.Tensor]:
"""
Extracts the keys and values from a given cache layer,
handling both quantized and unquantized caches.
"""
if isinstance(cache, QuantizedCache):
keys, values = dequantize_layer(cache.layers[layer_idx])
else:
keys = cache.layers[layer_idx].keys
values = cache.layers[layer_idx].values
return keys, values
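# --------------------------------------------------------------------------------------
# Hedged, self-contained check (editorial addition, not part of the original file):
# for a regular (unquantized) cache, extract_keys_and_values simply returns the tensors
# stored for the requested layer. This assumes a transformers version where Cache
# exposes `.layers`, as used in the functions above.
if __name__ == "__main__":
    from transformers import DynamicCache

    cache = DynamicCache()
    keys_in, values_in = torch.randn(1, 2, 8, 16), torch.randn(1, 2, 8, 16)
    cache.update(keys_in, values_in, layer_idx=0)
    keys_out, values_out = extract_keys_and_values(cache, layer_idx=0)
    assert torch.equal(keys_out, keys_in) and torch.equal(values_out, values_in)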
# KVzap
[![KVzap collection](https://img.shields.io/badge/🤗%20Hugging%20Face-Collection-orange)](https://huggingface.co/collections/nvidia/kvzap)
[![arXiv](https://img.shields.io/badge/arXiv-2601.07891-b31b1b.svg)](https://arxiv.org/abs/2601.07891)
[KVzap](https://arxiv.org/abs/2601.07891) is a fast approximation of [KVzip](https://arxiv.org/abs/2505.23416) that works in both prefilling and decoding. It applies a lightweight surrogate model to the hidden states to predict importance scores, and removes the KV pairs with a score below a given threshold, following the Dynamic Memory Sparsification ([DMS](https://arxiv.org/abs/2506.05345)) inference strategy.
## Usage
KVzap is designed to be used by combining `KVzapPress` with `DMSPress` from kvpress:
```python
import requests
from transformers import pipeline
from kvpress import KVzapPress, DMSPress
model = "Qwen/Qwen3-8B"
pipe = pipeline("kv-press-text-generation", model=model, device_map="auto", dtype="auto")
press = DMSPress(KVzapPress(model_type="mlp"), threshold=-4)
# Prefilling compression only, thinking disabled
press.decoding = False
context = requests.get("https://arxiv.org/abs/2601.07891").text
question = "\n What is this article about in 2 sentences ?"
answer = pipe(context, question=question, press=press)["answer"]
print(f"Compression ratio: {press.compression_ratio:.2%}\nAnswer: {answer}")
# Prefilling and decoding compression, thinking enabled
press.decoding = True
prompt = "What is the best hardware to run LLMs and why ?"
answer = pipe(prompt, press=press, enable_thinking=True, max_new_tokens=2000)["answer"]
print(f"Compression ratio: {press.compression_ratio:.2%}\nAnswer: {answer}")
```
The `KVzapPress` inherits from the `ScorerPress` class and only predicts the scores for every KV pair. The `DMSPress` then prunes the KV pairs with a score below a given threshold, rather than using a fixed compression ratio.
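For intuition, here is a minimal, purely illustrative sketch of this score-then-threshold flow (toy scores, not the actual `DMSPress` internals): every KV pair whose predicted score falls below the threshold is removed, and the compression ratio is the fraction removed.
```python
import torch

# Toy scores for 5 KV pairs, as a KVzapPress-like scorer might produce (illustrative only)
scores = torch.tensor([[-6.2, -1.3, -5.0, -0.7, -3.9]])
threshold = -4  # same threshold as in the usage example above

keep_mask = scores > threshold                              # tensor([[False, True, False, True, True]])
compression_ratio = 1 - keep_mask.float().mean().item()     # fraction of KV pairs removed -> 0.40
```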
Supported base models are listed in the [KVzap collection](https://huggingface.co/collections/nvidia/kvzap); support can easily be extended to other models by following the instructions in the [training section](#training).
## Training
Training uses the [Nemotron-Pretraining-Dataset-sample](https://huggingface.co/datasets/nvidia/Nemotron-Pretraining-Dataset-sample) to extract KVzip+ scores and train surrogate models.
To reproduce the training or train your own model, use the following command:
```bash
pip install skorch scikit-learn
python train.py --model_name <model_name> --output_dir <output_dir>
```
Run `python train.py --help` for all options.
## Evaluation
Evaluation can be reproduced by using the [kvpress evaluation CLI](../evaluation).
We provide a specific script to evaluate KVzap on the AIME25 benchmark using `model.generate` directly to enable sampling-based decoding rather than greedy decoding:
```bash
python evaluate_aime.py <model_type> --threshold <threshold> --model_name <base_model_name>
```
where `<model_type>` is the type of KVzap model to use ("mlp", "linear", or "no_press") and `<base_model_name>` is the name of the base model (e.g. "Qwen/Qwen3-8B").