# SPDX-FileCopyrightText: Copyright (c) 1993-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
import contextlib
import logging
from typing import Optional
import torch
from transformers import AutoModelForCausalLM, Cache, DynamicCache, Pipeline, QuantizedCache
from transformers.pipelines import PIPELINE_REGISTRY
from transformers.pipelines.base import GenericTensor
from kvpress.presses.base_press import BasePress
from kvpress.presses.decoding_press import DecodingPress
from kvpress.presses.dms_press import DMSPress
from kvpress.presses.finch_press import FinchPress
from kvpress.presses.key_rerotation_press import KeyRerotationPress
from kvpress.presses.prefill_decoding_press import PrefillDecodingPress
logger = logging.getLogger(__name__)
class KVPressTextGenerationPipeline(Pipeline):
"""
Pipeline for key-value cache compression in causal language models.
Enables efficient processing of long contexts by applying KV cache compression
during pre-filling, then generating answers using greedy decoding.
Example:
```python
pipeline = KVPressTextGenerationPipeline(model=model, tokenizer=tokenizer)
press = SnapKVPress(compression_ratio=0.5)
result = pipeline(context="Long text...", question="A question about the long context.", press=press)
```
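When several questions concern the same context, pass them as a list; the compressed
cache is reused for each question (a sketch with illustrative values):
```python
results = pipeline(context="Long text...", questions=["First question?", "Second question?"], press=press)
answers = results["answers"]  # one answer per question
```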
"""
def _sanitize_parameters(
self,
question: Optional[str] = None,
questions: Optional[list[str]] = None,
answer_prefix: Optional[str] = None,
press: Optional[BasePress] = None,
max_new_tokens: int = 50,
max_context_length: Optional[int] = None,
enable_thinking: bool = False,
cache: Optional[Cache] = None,
**kwargs,
):
"""
Sanitize the input parameters for the pipeline.
The user can either provide a single question or a list of questions to be asked about the context.
Parameters
----------
question : str, optional
The question to be asked about the context. Exclusive with `questions`.
questions : list[str], optional
A list of questions to be asked about the context. Exclusive with `question`.
answer_prefix : str, optional
The prefix to be added to the generated answer.
press : BasePress, optional
The key-value cache compression method to apply during pre-filling.
Accepts any KVPress compression method (SnapKVPress, KnormPress,
ExpectedAttentionPress, BlockPress, AdaKVPress, ComposedPress, etc.).
If None, no compression is applied.
max_new_tokens : int, optional
The maximum number of new tokens to generate for each answer.
max_context_length : int, optional
The maximum number of tokens in the context. Defaults to the maximum length supported by the model.
enable_thinking : bool, default=False
Whether to enable thinking in the chat template (the chat template must support this argument).
cache : Cache, optional
The cache to use for the forward pass. Defaults to None (DynamicCache).
**kwargs : dict
Additional keyword arguments, currently ignored.
Returns
-------
Tuple[dict, dict, dict]
A tuple containing three dictionaries:
- preprocess_kwargs: The keyword arguments for the preprocess function.
- forward_kwargs: The keyword arguments for the forward function.
- postprocess_kwargs: The keyword arguments for the postprocess function.
"""
answer_prefix = answer_prefix or ""
postprocess_kwargs = {"single_question": questions is None}
assert question is None or questions is None, "Either question or questions should be provided, not both."
questions = questions or ([question] if question else [""])
if max_context_length is None:
max_context_length = min(self.tokenizer.model_max_length, int(1e10)) # 1e10 to avoid overflow
preprocess_kwargs = {
"questions": questions,
"answer_prefix": answer_prefix,
"max_context_length": max_context_length,
"enable_thinking": enable_thinking,
}
forward_kwargs = {"press": press, "max_new_tokens": max_new_tokens, "cache": cache}
return preprocess_kwargs, forward_kwargs, postprocess_kwargs
def preprocess(
self,
context: str,
questions: list[str],
answer_prefix: str,
max_context_length: int,
enable_thinking: bool = False,
):
"""
Apply chat template and tokenize the context and questions.
Prepares input text for KV cache compression and generation by applying
appropriate chat templates and tokenizing. Handles models with and without
chat templates.
Parameters
----------
context : str
Long context text to be compressed using the press method.
questions : list[str]
Questions to be asked about the context.
answer_prefix : str
Optional prefix for generated answers.
max_context_length : int
Maximum tokens allowed in context (truncated if exceeded).
enable_thinking : bool
Whether to enable thinking in the chat template (chat template must support this argument)
Returns
-------
dict[str, GenericTensor]
Dictionary with "context_ids" and "questions_ids" tensors.
"""
# Apply chat template if available
if self.tokenizer.chat_template is None:
bos_token = getattr(self.tokenizer, "bos_token", "") or ""  # some tokenizers set bos_token to None
context = bos_token + context
question_suffix = "\n" # to separate the question from the answer
else:
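# Append a unique separator to the user turn so that, after templating, the string can be split
# back into the context prefix and the generation-prompt suffix (question_suffix) below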
separator = "#" * (len(context) + 10)
context = self.tokenizer.apply_chat_template(
[{"role": "user", "content": context + separator}],
add_generation_prompt=True,
tokenize=False,
enable_thinking=enable_thinking,
)
context, question_suffix = context.split(separator)
# Add question_suffix and answer prefix
# e.g. for llama3.1, question_suffix="<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n")
questions = [question + question_suffix + answer_prefix for question in questions]
# Tokenize the context and questions
context_ids = self.tokenizer.encode(context, return_tensors="pt", add_special_tokens=False)
question_ids = [
self.tokenizer.encode(question, return_tensors="pt", add_special_tokens=False) for question in questions
]
# Truncate context
if context_ids.shape[1] > max_context_length:
logger.warning(
f"Context length has been truncated from {context_ids.shape[1]} to {max_context_length} tokens."
)
context_ids = context_ids[:, :max_context_length]
return {"context_ids": context_ids, "questions_ids": question_ids}
def _forward(
self,
input_tensors: dict[str, GenericTensor],
max_new_tokens: int = 50,
press: Optional[BasePress] = None,
cache: Optional[Cache] = None,
):
"""
Execute KV cache compression and text generation pipeline.
Performs context compression using the press method during pre-filling,
then generates answers using greedy decoding.
Parameters
----------
input_tensors : dict[str, GenericTensor]
Tokenized inputs with "context_ids" and "questions_ids".
max_new_tokens : int, default=50
Maximum tokens to generate for each answer.
press : BasePress, optional
Compression method for context pre-filling. If None, no compression.
cache : Cache, optional
Cache object for forward pass. If None, creates new DynamicCache.
Returns
-------
list[str]
Generated answers for each input question.
"""
if isinstance(press, (DecodingPress, PrefillDecodingPress)) and len(input_tensors["questions_ids"]) > 1:
raise ValueError(
"DecodingPress is not compatible with multiple questions. Please specify a single question."
)
context_ids = input_tensors["context_ids"].to(self.model.device)
context_length = context_ids.shape[1]
# Prefilling using the press on the context
if cache is None:
cache = DynamicCache()
# We only perform prefill compression if the press is a prefill press
perform_prefill_compression = press is not None and not isinstance(press, DecodingPress)
with press(self.model) if perform_prefill_compression else contextlib.nullcontext():
# We run the model without the lm head for pre-filling.
self.model.model(
input_ids=context_ids,
past_key_values=cache,
)
logger.debug(f"Context Length: {context_length}")
logger.debug(f"Compressed Context Length: {cache.get_seq_length()}")
# We only perform decoding compression if the press is a decoding or prefill decoding press
perform_decoding_compression = press is not None and isinstance(press, (DecodingPress, PrefillDecodingPress))
if isinstance(press, DMSPress):
perform_decoding_compression = press.decoding
with press(self.model) if perform_decoding_compression else contextlib.nullcontext():
# Greedy decoding for each question
answers = []
for question_ids in input_tensors["questions_ids"]:
if isinstance(press, KeyRerotationPress) or (isinstance(press, FinchPress) and press.rerotate_keys):
context_length = cache.get_seq_length()
cache_seq_lengths = [cache.get_seq_length(layer_idx) for layer_idx in range(len(cache))]
answer = self.generate_answer(
question_ids=question_ids.to(self.model.device),
cache=cache,
context_length=context_length,
max_new_tokens=max_new_tokens,
)
self._remove_answer_from_cache(cache, cache_seq_lengths)
answers.append(answer)
return answers
def _remove_answer_from_cache(self, cache: Cache, cache_seq_lengths: list[int]):
for layer_idx, sequence_length in enumerate(cache_seq_lengths):
cache.layers[layer_idx].keys = cache.layers[layer_idx].keys[:, :, :sequence_length]
cache.layers[layer_idx].values = cache.layers[layer_idx].values[:, :, :sequence_length]
if isinstance(cache, QuantizedCache):
for layer_idx, sequence_length in enumerate(cache_seq_lengths):
cache.layers[layer_idx]._quantized_keys = cache.layers[layer_idx]._quantized_keys[
:, :, :sequence_length
]
cache.layers[layer_idx]._quantized_values = cache.layers[layer_idx]._quantized_values[
:, :, :sequence_length
]
def generate_answer(
self, question_ids: torch.Tensor, cache: Cache, context_length: int, max_new_tokens: int
) -> str:
"""
Generate an answer to a question using greedy decoding.
Parameters
----------
question_ids : torch.Tensor
The tokenized question.
cache : Cache
The compressed key-value cache.
context_length : int
The length of the context.
max_new_tokens : int
The maximum number of new tokens to generate.
Returns
-------
str
The generated answer.
"""
position_ids = torch.arange(
context_length, context_length + question_ids.shape[1], device=self.model.device
).unsqueeze(0)
# Forward pass on the question tokens; position ids continue from the end of the (uncompressed) context
outputs = self.model(
input_ids=question_ids.to(self.model.device),
past_key_values=cache,
position_ids=position_ids,
num_logits_to_keep=1,
)
position_ids = position_ids[:, -1:] + 1
generated_ids = [outputs.logits[0, -1].argmax()]
should_stop_token_ids = self.model.generation_config.eos_token_id
if not isinstance(should_stop_token_ids, list):
should_stop_token_ids = [should_stop_token_ids]
for i in range(max_new_tokens - 1):
outputs = self.model(
input_ids=generated_ids[-1].unsqueeze(0).unsqueeze(0),
past_key_values=cache,
position_ids=position_ids + i,
)
new_id = outputs.logits[0, -1].argmax()
generated_ids.append(new_id)
if new_id.item() in should_stop_token_ids:
break
answer = str(self.tokenizer.decode(torch.stack(generated_ids), skip_special_tokens=True))
return answer
def postprocess(self, model_outputs, single_question):
if single_question:
return {"answer": model_outputs[0]}
return {"answers": model_outputs}
PIPELINE_REGISTRY.register_pipeline(
"kv-press-text-generation",
pipeline_class=KVPressTextGenerationPipeline,
pt_model=AutoModelForCausalLM,
)
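# Usage sketch (illustrative; assumes a loaded model and tokenizer, and the task name registered above):
#   from transformers import pipeline
#   from kvpress import SnapKVPress
#   pipe = pipeline("kv-press-text-generation", model=model, tokenizer=tokenizer)
#   answer = pipe("Long context...", question="A question?", press=SnapKVPress(compression_ratio=0.5))["answer"]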
# SPDX-FileCopyrightText: Copyright (c) 1993-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
from dataclasses import dataclass
import torch
from kvpress.presses.base_press import BasePress
from kvpress.presses.scorer_press import ScorerPress
@dataclass
class AdaKVPress(BasePress):
"""
AdaKV: Adaptive head-wise KV cache compression.
Performs head-specific compression by selecting top-k tokens across all heads
based on importance scores. Applies safeguards to ensure each head retains
a minimum fraction of tokens.
Based on AdaKV (https://arxiv.org/abs/2407.11550).
Parameters
----------
press : ScorerPress
The underlying scorer press used to compute per-token importance scores.
ObservedAttentionPress is currently not supported, as AdaKVPress requires a non-eager attention implementation.
alpha_safeguard : float, default=0.20
Minimum fraction of KV pairs that each head must retain.
Ensures no attention head is compressed too aggressively. Even if tokens
receive low global importance scores, each head retains at least this
fraction of its original tokens.
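Examples
--------
A minimal usage sketch, assuming ``ExpectedAttentionPress`` as the wrapped scorer
(``model``, ``input_ids`` and ``cache`` are placeholders):

>>> from kvpress import AdaKVPress, ExpectedAttentionPress
>>> press = AdaKVPress(press=ExpectedAttentionPress(compression_ratio=0.5), alpha_safeguard=0.2)
>>> with press(model):
...     model(input_ids, past_key_values=cache)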
"""
press: ScorerPress
alpha_safeguard: float = 0.20
def __post_init__(self):
assert isinstance(self.press, ScorerPress), "AdaKVPress requires a ScorerPress as input"
assert 0 <= self.alpha_safeguard <= 1, "alpha_safeguard should be in [0, 1]"
def post_init_from_model(self, model):
self.press.post_init_from_model(model)
@property
def compression_ratio(self):
return self.press.compression_ratio
@compression_ratio.setter
def compression_ratio(self, value):
self.press.compression_ratio = value
def compress(self, module, hidden_states, keys, values, attentions, kwargs):
if self.compression_ratio == 0:
return keys, values
assert module.config._attn_implementation != "eager", "eager mode not supported"
# Compute scores
scores = self.press.score(module, hidden_states, keys, values, attentions, kwargs)
bsz, num_key_value_heads, k_len = scores.shape
# Make sure to keep at least alpha * (1 - compression_ratio) KV pairs per head
n_kept = int(k_len * (1 - self.compression_ratio)) # ScorerPress definition
n_safe = int(n_kept * self.alpha_safeguard)
top_indices = torch.topk(scores, n_safe, dim=-1).indices
scores.scatter_(-1, top_indices, torch.finfo(scores.dtype).max)
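# Safeguarded tokens now hold the maximum representable score, so the global bottom-k below can never prune them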
# Compute bottom-k across heads
n_pruned = num_key_value_heads * (k_len - n_kept)
indices = torch.topk(-scores.reshape(bsz, -1), n_pruned, dim=1).indices.flatten()
# Save indices to mask during the attention mechanism. Please refer to attention_patch.py for more details
batch_indices = torch.arange(bsz).repeat_interleave(n_pruned)
head_indices = indices // k_len
seq_indices = indices % k_len
module.masked_key_indices = (batch_indices, head_indices, seq_indices)
return keys, values
# SPDX-FileCopyrightText: Copyright (c) 1993-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
import logging
from contextlib import contextmanager
from dataclasses import dataclass
from typing import Generator
import torch
from torch import nn
from transformers import (
Gemma3ForConditionalGeneration,
LlamaForCausalLM,
MistralForCausalLM,
Phi3ForCausalLM,
PreTrainedModel,
QuantizedCache,
Qwen2ForCausalLM,
Qwen3ForCausalLM,
)
from kvpress.utils import extract_keys_and_values
logger = logging.getLogger(__name__)
SUPPORTED_MODELS = (
LlamaForCausalLM,
MistralForCausalLM,
Phi3ForCausalLM,
Qwen2ForCausalLM,
Qwen3ForCausalLM,
Gemma3ForConditionalGeneration,
)
@dataclass
class BasePress:
"""
Base class for all KV cache compression methods.
This class provides the foundation for implementing various key-value cache compression
techniques. Subclasses must implement the `compress` method to define their specific
compression logic.
The compression is applied only during pre-filling (not during generation).
"""
def post_init_from_model(self, model: PreTrainedModel):
"""
Optional method to initialize press parameters from the model
"""
pass
def compress(
self,
module: nn.Module,
hidden_states: torch.Tensor,
keys: torch.Tensor,
values: torch.Tensor,
attentions: torch.Tensor,
kwargs: dict,
) -> tuple[torch.Tensor, torch.Tensor]:
"""
The core logic of the compression method.
Parameters
----------
module : nn.Module
The transformer attention layer where compression is applied.
hidden_states : torch.Tensor
Hidden states of the current layer with shape (batch_size, seq_len, hidden_dim).
These represent the input to the attention layer.
keys : torch.Tensor
Key tensors from the KV cache with shape (batch_size, num_kv_heads, seq_len, head_dim).
These are keys ready for compression.
values : torch.Tensor
Value tensors from the KV cache with shape (batch_size, num_kv_heads, seq_len, head_dim).
These are values ready for compression.
attentions : torch.Tensor
Attention weights from the layer with shape (batch_size, num_heads, seq_len, seq_len).
May be None if attention weights are not computed or needed.
kwargs : dict
Additional keyword arguments from the forward pass.
Returns
-------
tuple[torch.Tensor, torch.Tensor]
A tuple containing the compressed keys and values tensors. The returned tensors
should have reduced sequence length dimension compared to the input tensors.
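Examples
--------
A minimal sketch of a custom press that keeps only the most recent tokens
(hypothetical subclass, for illustration only):

>>> from dataclasses import dataclass
>>> @dataclass
... class RecentTokensPress(BasePress):
...     compression_ratio: float = 0.0
...     def compress(self, module, hidden_states, keys, values, attentions, kwargs):
...         n_kept = int(keys.shape[2] * (1 - self.compression_ratio))
...         return keys[:, :, -n_kept:], values[:, :, -n_kept:]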
"""
raise NotImplementedError("compress method must be implemented in subclass")
def forward_hook(self, module: nn.Module, input: list[torch.Tensor], kwargs: dict, output: list):
"""
Default forward hook called after the forward pass of an attention layer.
This hook automatically applies compression during the pre-filling phase by:
1. Checking if we're still in pre-filling (not generation) phase
2. Extracting keys and values from the cache (handling quantization)
3. Calling the compress method to reduce the cache size
4. Updating the cache with compressed keys and values
The hook ensures compression is only applied during pre-filling and correctly
handles both quantized and unquantized caches.
Parameters
----------
module : nn.Module
The transformer attention layer.
input : list[torch.Tensor]
Input tensors to the forward pass of the attention layer. This parameter
is provided by PyTorch's hook mechanism but not used in the default implementation.
kwargs : dict
Keyword arguments passed to the attention layer's forward method, including:
- hidden_states: Input embeddings to the attention layer
- past_key_values: The KV cache object being modified
- cache_position: Position indices indicating where we are in the sequence
- position_embeddings: RoPE embeddings if applicable
output : list
Output from the attention layer's forward pass. Contains:
- [0]: Hidden states output
- [1]: Attention weights (may be None)
Returns
-------
list
The potentially modified output from the forward pass. This
is the same as the input output, but the underlying cache has been compressed in-place.
"""
hidden_states = kwargs["hidden_states"]
cache = kwargs["past_key_values"]
cache_layer = cache.layers[module.layer_idx]
q_len = hidden_states.shape[1]
# Don't compress after pre-filling
if kwargs["cache_position"][-1] > q_len:
return output
keys, values = extract_keys_and_values(cache, module.layer_idx)
keys, values = self.compress(module, hidden_states, keys, values, output[1], kwargs)
if isinstance(cache, QuantizedCache):
cache_layer._quantized_keys = cache_layer._quantize(keys, axis=cache_layer.axis_key)
cache_layer._quantized_values = cache_layer._quantize(values, axis=cache_layer.axis_value)
cache_layer.keys = torch.zeros(0, dtype=keys.dtype, device=keys.device) # type: ignore[index]
cache_layer.values = torch.zeros(0, dtype=keys.dtype, device=keys.device) # type: ignore[index]
cache_layer.cumulative_length = keys.shape[2]
else:
cache_layer.keys = keys
cache_layer.values = values
return output
@contextmanager
def __call__(self, model: PreTrainedModel) -> Generator:
"""
Context manager to apply a compression method to a model.
This method registers forward hooks on all attention layers of the model to enable
automatic KV cache compression during the pre-filling phase. The hooks are automatically
removed when exiting the context manager.
Apply this context manager during the pre-filling phase to compress the context.
Parameters
----------
model : PreTrainedModel
The transformer model to apply compression to.
Examples
--------
>>> from kvpress import KnormPress
>>> press = KnormPress(compression_ratio=0.5)
>>> with press(model):
... # Forward pass with compression applied
... outputs = model(input_ids, past_key_values=cache)
"""
if not isinstance(model, SUPPORTED_MODELS):
logger.warning(f"Model {type(model)} not tested, supported models: {SUPPORTED_MODELS}")
if isinstance(model, Gemma3ForConditionalGeneration):
logger.warning_once("Compression in Gemma3 is only applied to layers without sliding window attention")
self.post_init_from_model(model)
hooks = []
try:
language_model = model.model.language_model if hasattr(model.model, "language_model") else model.model
for layer in language_model.layers:
if isinstance(model, Gemma3ForConditionalGeneration) and layer.self_attn.is_sliding:
# Skip layers with sliding window attention, only for Gemma3
continue
layer.self_attn.rotary_emb = language_model.rotary_emb
hooks.append(layer.self_attn.register_forward_hook(self.forward_hook, with_kwargs=True))
yield
finally:
for forward_hook in hooks:
forward_hook.remove()
# SPDX-FileCopyrightText: Copyright (c) 1993-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
from dataclasses import dataclass
import torch
from torch import nn
from torch.nn import functional as F
from kvpress.presses.scorer_press import ScorerPress
@dataclass
class KeyDiffPress(ScorerPress):
"""
KeyDiff: Key similarity-based KV cache compression.
Evicts tokens based on key vector similarity to average key pattern.
Identifies tokens with most similar keys to average and removes them,
keeping tokens with more distinctive key vectors.
Based on KeyDiff (https://arxiv.org/abs/2504.15364).
Note: The original press in the KeyDiff paper implements a block-wise iterative compression.
In KVPress, the iterative compression is implemented in the BlockPress class.
Therefore, to replicate the paper's implementation, please use:
`press = BlockPress(press=KeyDiffPress(compression_ratio=0.x), block_size=N)`
Parameters
----------
compression_ratio : float, default=0.0
Fraction of key-value pairs to remove during compression.
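Examples
--------
A sketch of the block-wise variant described above (``block_size=128`` is illustrative):

>>> from kvpress import BlockPress, KeyDiffPress
>>> press = BlockPress(press=KeyDiffPress(compression_ratio=0.5), block_size=128)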
"""
def score(
self,
module: nn.Module,
hidden_states: torch.Tensor,
keys: torch.Tensor,
values: torch.Tensor,
attentions: torch.Tensor,
kwargs,
) -> torch.Tensor:
anchor = F.normalize(keys, p=2, dim=-1).mean(dim=2, keepdim=True)
return -F.cosine_similarity(keys, anchor, dim=-1)