Commit c16d506e authored by chenzk

v1.0
# SPDX-FileCopyrightText: Copyright (c) 1993-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
"""
Data collection utilities for KVzap training.
This module provides functions to:
1. Load and preprocess the Nemotron dataset
2. Tokenize prompts with the KVzip repeat method
3. Extract KVzip+ scores from a model using forward hooks
"""
import pandas as pd
import torch
from datasets import load_dataset
from tqdm.auto import tqdm
from transformers import PreTrainedModel, PreTrainedTokenizerBase
from transformers.integrations.finegrained_fp8 import FP8Linear
from transformers.models.llama.modeling_llama import repeat_kv
def load_nemotron_dataset(
tokenizer: PreTrainedTokenizerBase,
min_tokens: int = 750,
max_tokens: int = 1250,
n_train_per_subset: int = 500,
n_test_per_subset: int = 5,
) -> pd.DataFrame:
"""
Load and preprocess the Nemotron dataset for KVzap training.
The function:
1. Loads the nvidia/Nemotron-Pretraining-Dataset-sample dataset (multilingual and multi-domain)
    2. Filters samples to keep only those with sequence length strictly between min_tokens and max_tokens
       (keeps sequence lengths in a narrow range so attention weight denominators aren't influenced by length)
3. Splits into train/test with balanced sampling across subsets
Parameters
----------
tokenizer : AutoTokenizer
Tokenizer to use for computing sequence lengths
min_tokens : int, optional
Minimum number of tokens per sample, by default 750
max_tokens : int, optional
Maximum number of tokens per sample, by default 1250
n_train_per_subset : int, optional
Maximum training samples per subset, by default 500
n_test_per_subset : int, optional
Maximum test samples per subset, by default 5
Returns
-------
pd.DataFrame
DataFrame with columns: text, length, subset, split
"""
subsets = [
"Nemotron-CC-MATH",
"Nemotron-CC-High-Quality",
"Nemotron-CC-High-Quality-Synthetic",
"Nemotron-CC-Diverse-QA",
"Nemotron-CC-Translated-Diverse-QA",
"Nemotron-Synthetic-Code",
"Nemotron-SFT-Code",
"Nemotron-SFT-General",
"Nemotron-SFT-MATH",
]
# 1. Load all subsets and concatenate them
df_list = []
for subset in tqdm(subsets, desc="Loading subsets"):
df = load_dataset("nvidia/Nemotron-Pretraining-Dataset-sample", subset, split="train").to_pandas()
df["length"] = df["text"].apply(lambda x: len(tokenizer.encode(x)))
df["subset"] = subset
df_list.append(df)
df = pd.concat(df_list)
# 2. Remove the samples that are too short or too long
sub_df = df[(max_tokens > df["length"]) & (df["length"] > min_tokens)]
# 3. Split into train and test
    df_test = sub_df.groupby("subset").head(n_test_per_subset).copy()
    df_test["split"] = "test"
    df_train = sub_df.drop(df_test.index).groupby("subset").head(n_train_per_subset).copy()
    df_train["split"] = "train"
df = pd.concat([df_test, df_train]).reset_index(drop=True)
return df
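# Illustrative usage (a minimal sketch; the checkpoint name is only an example):
#
#     from transformers import AutoTokenizer
#     tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen3-8B")
#     df = load_nemotron_dataset(tokenizer, min_tokens=750, max_tokens=1250)
#     assert {"text", "length", "subset", "split"} <= set(df.columns)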
def repeat_prompt_tokenization(
tokenizer: PreTrainedTokenizerBase, prompt: str
) -> tuple[torch.Tensor, int, int, int, int]:
"""
Tokenize a prompt using the KVzip repeat method.
Builds an extended prompt following the KVzip methodology:
```
user: <prompt>
Repeat the previous context exactly.
assistant: <prompt>
```
Parameters
----------
tokenizer : AutoTokenizer
Tokenizer to use
prompt : str
The input prompt text
Returns
-------
tuple[torch.Tensor, int, int, int, int]
- input_ids: Tokenized input tensor
- start_prompt: Start index of the original prompt
- end_prompt: End index of the original prompt
- start_repeated_prompt: Start index of the repeated prompt
- end_repeated_prompt: End index of the repeated prompt
"""
# Repeat the prompt using the chat template
prompt = prompt.strip()
messages = [
{"role": "user", "content": prompt + "\n\nRepeat the previous context exactly."},
{"role": "assistant", "content": prompt},
]
# Tokenize
prompt_with_repeat = tokenizer.apply_chat_template(messages, tokenize=False)
outputs = tokenizer(prompt_with_repeat, return_tensors="pt", return_offsets_mapping=True)
# Get the start and end indexes of the prompt and the repeated prompt
# The tokenizer might add newlines at the beginning and end of the prompt
prefix, repeat, _ = prompt_with_repeat.split(prompt)
m = outputs.offset_mapping[0, :, 0]
m = torch.cat([m, torch.tensor([len(prompt_with_repeat)])])
start_prompt = int(torch.where(m >= len(prefix))[0][0].item())
end_prompt = int(torch.where(m >= len(prefix) + len(prompt))[0][0].item())
start_repeated_prompt = int(torch.where(m >= len(prefix) + len(prompt) + len(repeat))[0][0].item())
end_repeated_prompt = int(torch.where(m >= len(prefix) + 2 * len(prompt) + len(repeat))[0][0].item())
return outputs.input_ids, start_prompt, end_prompt, start_repeated_prompt, end_repeated_prompt
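# Illustrative example (the exact indices depend on the tokenizer's chat template):
#
#     ids, s, e, sr, er = repeat_prompt_tokenization(tokenizer, "The sky is blue.")
#     tokenizer.decode(ids[0, s:e])    # ~ the original prompt in the user turn
#     tokenizer.decode(ids[0, sr:er])  # ~ the same prompt repeated in the assistant turn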
class KVzapDataCollector:
"""
Collects KVzip+ importance scores from a language model using forward hooks.
Parameters
----------
model : AutoModelForCausalLM
The language model to extract scores from
tokenizer : AutoTokenizer
Tokenizer matching the model
Example
-------
>>> collector = KVzapDataCollector(model, tokenizer)
>>> X, y = collector.collect(df, n_tokens=500)
"""
def __init__(self, model: PreTrainedModel, tokenizer: PreTrainedTokenizerBase):
self.model = model
self.tokenizer = tokenizer
# Hook communication state (replaces global variables)
self._data: list = []
self._start_prompt: int = 0
self._end_prompt: int = 0
self._start_repeated_prompt: int = 0
self._end_repeated_prompt: int = 0
def _forward_hook(self, module, input, kwargs, output):
"""
Forward hook to extract KVzip+ scores from the extended prompt.
This hook computes importance scores for each key-value pair based on:
1. Attention weights from repeated prompt tokens to original prompt tokens
2. Normalized by hidden state norms
3. Weighted by output projection norms
Results are stored in self._data as tuples of (hidden_states, scores).
"""
# Get variables
hidden_states = kwargs["hidden_states"]
values = kwargs["past_key_values"].layers[module.layer_idx].values
attn_weights = output[1]
# Initialize scores with attention weights
scores = attn_weights
# Divide by ||h|| (by row)
h_norm = torch.norm(hidden_states, dim=-1)
scores = torch.einsum("b h t i, b t -> b h t i", scores, 1 / h_norm)
# Multiply by ||WoV|| (by column)
Wo = module.o_proj.weight.transpose(0, 1)
V = repeat_kv(values, module.num_key_value_groups)
if isinstance(module.o_proj, FP8Linear):
scale = module.o_proj.weight_scale_inv.to(V.dtype).transpose(0, 1)
scale = scale.repeat_interleave(module.o_proj.block_size[0], dim=0)
scale = scale.repeat_interleave(module.o_proj.block_size[1], dim=1)
Wo = Wo.to(V.dtype) * scale
Wo = Wo.view(module.config.num_attention_heads, V.shape[-1], module.config.hidden_size)
WoV_norm = torch.einsum("h i j, b h t i -> b h t j", Wo.to(dtype=V.dtype), V).norm(dim=-1)
scores = torch.einsum("b h t i, b h i -> b h t i", scores, WoV_norm)
        # Take the max over the repeated-prompt (query) tokens and over the KV groups for each prompt token
scores = scores[
:, :, self._start_repeated_prompt : self._end_repeated_prompt, self._start_prompt : self._end_prompt
].amax(dim=2)
scores = scores.view(
scores.shape[0], module.config.num_key_value_heads, module.num_key_value_groups, scores.shape[2]
).amax(dim=2)
# Apply log
scores = torch.log(scores)
# Store the results
self._data.append((hidden_states[0, self._start_prompt : self._end_prompt, :].cpu(), scores[0].T.cpu()))
return output
def _register_hooks(self) -> list:
"""
Register forward hooks on all attention layers to extract KVzip+ scores.
Returns
-------
list
List of hook handles (can be used to remove hooks later)
"""
handles = []
for layer in self.model.model.layers: # type: ignore[attr-defined]
handle = layer.self_attn.register_forward_hook(self._forward_hook, with_kwargs=True)
handles.append(handle)
return handles
def collect(self, df: pd.DataFrame, n_tokens: int = 500) -> tuple[torch.Tensor, torch.Tensor]:
"""
Collect training data by extracting KVzip+ scores from text samples.
For each text sample in the dataset, this function:
1. Applies the KVzip repeat prompt method
2. Runs a forward pass to extract attention-based importance scores
3. Randomly samples n_tokens tokens per sample
Parameters
----------
df : pd.DataFrame
Dataset with a "text" column containing the samples
n_tokens : int, optional
Number of tokens to sample per text sample, by default 500
Returns
-------
tuple[torch.Tensor, torch.Tensor]
- X: Hidden states tensor of shape (n_samples * n_tokens, n_layers, hidden_size)
- y: Score tensor of shape (n_samples * n_tokens, n_layers, n_kv_heads)
"""
# Register hooks
handles = self._register_hooks()
try:
config = self.model.model.config # type: ignore[attr-defined]
n_layers = config.num_hidden_layers
X = torch.zeros(len(df) * n_tokens, n_layers, config.hidden_size, dtype=self.model.dtype)
y = torch.zeros(len(df) * n_tokens, n_layers, config.num_key_value_heads, dtype=self.model.dtype)
for i, text in tqdm(enumerate(df["text"]), total=len(df), desc="Extracting scores"):
# Get the scores using the repeat prompt method
tokens, self._start_prompt, self._end_prompt, self._start_repeated_prompt, self._end_repeated_prompt = (
repeat_prompt_tokenization(self.tokenizer, text)
)
self._data = []
with torch.no_grad():
self.model.model(tokens.to(self.model.device)) # type: ignore[attr-defined]
# Sample n_tokens tokens randomly
mask = torch.randperm(len(self._data[0][0]))[:n_tokens]
for layer_idx, (X_, y_) in enumerate(self._data):
X[i * n_tokens : (i + 1) * n_tokens, layer_idx] = X_[mask]
y[i * n_tokens : (i + 1) * n_tokens, layer_idx] = y_[mask]
return X, y
finally:
# Clean up hooks
for handle in handles:
handle.remove()
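# End-to-end sketch (assumes an eager-attention model so the hooks receive attention
# weights, as in the training script; names are placeholders):
#
#     model = AutoModelForCausalLM.from_pretrained(name, attn_implementation="eager")
#     tokenizer = AutoTokenizer.from_pretrained(name)
#     df = load_nemotron_dataset(tokenizer)
#     X, y = KVzapDataCollector(model, tokenizer).collect(df, n_tokens=500)
#     # X: (n_samples * n_tokens, n_layers, hidden_size); y: (..., n_kv_heads)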
# SPDX-FileCopyrightText: Copyright (c) 1993-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
import json
import uuid
from contextlib import nullcontext
from pathlib import Path
from datasets import load_dataset
from tqdm import tqdm
from transformers import AutoModelForCausalLM, AutoTokenizer
from kvpress import DMSPress, KVzapPress
def calculate_metrics(df):
"""
Calculate metrics for the AIME25 benchmark.
"""
correct = 0
answered = 0
for _, row in df.iterrows():
try:
y_pred = str(row["predicted_answer"].split("boxed{")[-1].split("}")[0])
y_true = str(row["answer"])
score = int(y_pred == y_true)
except IndexError:
score = 0
correct += score
answered += "boxed{" in row["predicted_answer"]
return {"correct": correct, "answered": answered, "accuracy": correct / len(df), "total": len(df)}
def evaluate(
kvzap_model_type: str,
threshold: float = 0.0,
model_name: str = "Qwen/Qwen3-8B",
device: str = "cuda:0",
max_new_tokens: int = 32000,
):
"""Evaluate KVzap on the AIME25 benchmark using model.generate instead of the
KVpress pipeline in order to use sampling parameters and not greedy decoding.
Parameters
----------
kvzap_model_type : str
Model type - "mlp", "linear", or "no_press"
threshold : float, optional
Threshold for KVzap scores, by default 0.0
model_name : str, optional
HuggingFace model name, by default "Qwen/Qwen3-8B"
device : str, optional
Device to use, by default "cuda:0"
max_new_tokens : int, optional
Maximum number of tokens to generate, by default 32000
"""
# Create press
press: DMSPress | type[nullcontext[None]]
if kvzap_model_type == "no_press":
press = nullcontext
else:
press = DMSPress(
KVzapPress(model_type=kvzap_model_type),
threshold=threshold,
decoding=True,
)
# Load tokenizer, model and dataset
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(model_name, dtype="auto").to(device)
df = load_dataset("alessiodevoto/aime25", split="test").to_pandas()
# Run evaluation
for idx, row in tqdm(df.iterrows(), total=len(df)):
# Tokenize question
messages = [{"role": "user", "content": row["question"]}]
tokens = tokenizer.apply_chat_template(messages, return_tensors="pt", add_generation_prompt=True)
tokens = tokens.to(model.device)
with press(model):
# Generation config from model card: https://huggingface.co/Qwen/Qwen3-32B
            output_tokens = model.generate(
                tokens, do_sample=True, temperature=0.6, top_p=0.95, top_k=20, min_p=0.0, max_new_tokens=max_new_tokens
            )
answer = tokenizer.decode(output_tokens[0, tokens.shape[1] :])
df.loc[idx, "predicted_answer"] = answer
if isinstance(press, DMSPress):
df.loc[idx, "compression_ratio"] = press.compression_ratio
else:
df.loc[idx, "compression_ratio"] = 0
# Save results in a new directory
dir_id = uuid.uuid4().hex
output_dir = Path(
f"results/aime25__{model_name.replace('/', '--')}__kvzap_{kvzap_model_type}__{threshold:.2f}/{dir_id}"
)
output_dir.mkdir(parents=True, exist_ok=True)
df.to_csv(output_dir / "predictions.csv", index=False)
# Calculate and save metrics
metrics = calculate_metrics(df)
with open(output_dir / "metrics.json", "w") as f:
json.dump(metrics, f)
print(f"Results saved to {output_dir}")
print(f"Metrics: {metrics}")
if __name__ == "__main__":
import fire
fire.Fire(evaluate)
# SPDX-FileCopyrightText: Copyright (c) 1993-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
"""
Training script for KVzap models.
This module provides functions to train KVzap models (MLP and Linear) that predict
KVzip+ importance scores from hidden states. The trained models can be used with
KVzapPress to compress the KV cache during inference.
"""
from pathlib import Path
import numpy as np
import torch
from sklearn.linear_model import Ridge
from skorch import NeuralNetRegressor
from skorch.callbacks import GradientNormClipping, LRScheduler
from skorch.dataset import ValidSplit
from torch import nn
from tqdm.auto import tqdm
from transformers import AutoModelForCausalLM, AutoTokenizer, FineGrainedFP8Config
from kvpress.presses.kvzap_press import KVzapConfig, KVzapModel
from kvzap.data import KVzapDataCollector, load_nemotron_dataset
def train_mlp(
X: torch.Tensor,
y: torch.Tensor,
hidden_dim: int,
device: str,
max_epochs: int = 10,
lr: float = 1e-3,
batch_size: int = 512,
) -> KVzapModel:
"""
Train a two-layer MLP model to predict KVzip+ scores from hidden states.
Parameters
----------
X : torch.Tensor
Input hidden states of shape (n_samples, n_layers, hidden_size)
y : torch.Tensor
Target scores of shape (n_samples, n_layers, n_kv_heads)
hidden_dim : int
Hidden dimension of the MLP
device : str
Device to train on (e.g., "cuda:0")
max_epochs : int, optional
Maximum training epochs, by default 10
lr : float, optional
Learning rate, by default 1e-3
batch_size : int, optional
Batch size, by default 512
Returns
-------
KVzapModel
Trained MLP model
"""
mlp = KVzapModel(
KVzapConfig(input_dim=X.shape[2], hidden_dim=hidden_dim, output_dim=y.shape[2], n_modules=X.shape[1])
)
mlp.to(device, dtype=X.dtype)
net = NeuralNetRegressor(
mlp,
max_epochs=max_epochs,
criterion=nn.MSELoss(),
lr=lr,
optimizer=torch.optim.AdamW,
iterator_train__shuffle=True,
device=device,
batch_size=batch_size,
callbacks=[
LRScheduler(policy="CosineAnnealingLR", T_max=max_epochs),
GradientNormClipping(gradient_clip_value=1.0),
],
train_split=ValidSplit(0.05, random_state=42),
)
net.fit(X, y)
return mlp
def train_linear(X: torch.Tensor, y: torch.Tensor) -> KVzapModel:
"""
Train a linear model to predict KVzip+ scores from hidden states.
Parameters
----------
X : torch.Tensor
Input hidden states of shape (n_samples, n_layers, hidden_size)
y : torch.Tensor
Target scores of shape (n_samples, n_layers, n_kv_heads)
Returns
-------
KVzapModel
Trained linear model
"""
# Train a linear model for each layer
params = []
for layer_idx in tqdm(range(X.shape[1]), desc="Training linear models"):
linear = Ridge()
linear.fit(X[:, layer_idx].float(), y[:, layer_idx].float())
params.append((linear.coef_, linear.intercept_))
# Load the parameters into a KVzapModel
linear_model = KVzapModel(
KVzapConfig(input_dim=X.shape[2], hidden_dim=None, output_dim=y.shape[2], n_modules=X.shape[1])
)
for layer_idx, (W, b) in enumerate(params):
W = torch.tensor(np.atleast_2d(W), dtype=X.dtype)
b = torch.tensor(np.atleast_1d(b), dtype=X.dtype)
linear_model.layers[layer_idx].weight.data = W # type: ignore[index]
linear_model.layers[layer_idx].bias.data = b # type: ignore[index]
return linear_model
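# Shape sketch (illustrative): for X of shape (N, L, D) and y of shape (N, L, H),
# each of the L Ridge fits maps (N, D) -> (N, H), so linear.coef_ has shape (H, D)
# and linear.intercept_ has shape (H,), matching the nn.Linear layer they are loaded into.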
def train(
model_name: str,
output_dir: str,
# Dataset parameters
min_tokens: int = 750,
max_tokens: int = 1250,
n_train_per_subset: int = 500,
n_test_per_subset: int = 5,
n_tokens: int = 500,
fp8: bool = False,
# MLP training parameters
hidden_dim: int = 512,
max_epochs: int = 15,
lr: float = 5e-3,
batch_size: int = 512,
device: str = "cuda:0",
):
"""
Train KVzap models (MLP and linear) for a given language model.
This function:
1. Loads the model and tokenizer
2. Loads and preprocesses the Nemotron dataset
3. Extracts KVzip+ scores using the repeat prompt method
4. Trains both 2-layer MLP and linear models
5. Saves models and predictions to the output directory
Parameters
----------
model_name : str
HuggingFace model name (e.g., "Qwen/Qwen3-8B")
output_dir : str
Directory to save trained models and predictions
min_tokens : int, optional
Minimum tokens per sample, by default 750
max_tokens : int, optional
Maximum tokens per sample, by default 1250
n_train_per_subset : int, optional
Training samples per dataset subset, by default 500
n_test_per_subset : int, optional
Test samples per dataset subset, by default 5
n_tokens : int, optional
Tokens to sample per text sample, by default 500
fp8 : bool, optional
Whether to use FP8 quantization to run the model, by default False
hidden_dim : int, optional
Hidden dimension for MLP model, by default 512
max_epochs : int, optional
Maximum training epochs for MLP, by default 15
lr : float, optional
Learning rate for MLP training, by default 5e-3
batch_size : int, optional
Batch size for MLP training, by default 512
device : str, optional
Device to use for training the MLP, by default "cuda:0"
"""
# Verify input parameters
assert n_tokens < min_tokens, "n_tokens must be less than min_tokens"
output_path = Path(output_dir)
output_path.mkdir(parents=True, exist_ok=True)
assert output_path.is_dir() and not list(output_path.iterdir()), "Output directory is not empty"
# Load model and tokenizer
print(f"Loading model {model_name} and tokenizer")
quantization_config = FineGrainedFP8Config() if fp8 else None
model = AutoModelForCausalLM.from_pretrained(
model_name,
dtype="auto",
device_map="auto",
attn_implementation="eager",
quantization_config=quantization_config,
)
tokenizer = AutoTokenizer.from_pretrained(model_name)
# Load dataset
print("Loading dataset")
df = load_nemotron_dataset(tokenizer, min_tokens, max_tokens, n_train_per_subset, n_test_per_subset)
print(f"Loaded {len(df)} samples (train: {(df['split'] == 'train').sum()}, test: {(df['split'] == 'test').sum()})")
# Extract scores using KVzapDataCollector
print("Extracting KVzip+ scores")
collector = KVzapDataCollector(model, tokenizer)
X, y = collector.collect(df, n_tokens)
# Free GPU memory
del model
torch.cuda.empty_cache()
# Split data into train and test
n_test = n_tokens * (df["split"] == "test").sum()
X_train, X_test = X[n_test:], X[:n_test]
y_train, y_test = y[n_test:], y[:n_test]
# Train MLP and linear models
print("Training MLP and linear models")
mlp = train_mlp(X_train, y_train, hidden_dim, device, max_epochs, lr, batch_size)
linear = train_linear(X_train, y_train)
linear.to(device)
# Evaluate and save models and predictions
print("Evaluating and saving models and predictions")
for module, name in [(mlp, "mlp"), (linear, "linear")]:
with torch.no_grad():
y_pred = module(X_test.to(device))
# Save model and predictions
module.save_pretrained(output_path / name)
np.save(output_path / name / "true.npy", y_test.cpu().float().numpy())
np.save(output_path / name / "pred.npy", y_pred.cpu().float().numpy())
print(f"Training complete. Models saved to {output_path}")
if __name__ == "__main__":
import fire
fire.Fire(train)
# Notebooks
This folder contains several Jupyter notebooks that demonstrate various features and functionalities of the kvpress package.
Below is a list of the notebooks along with a brief explanation of their content:
## [wikipedia_demo.ipynb](wikipedia_demo.ipynb) [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/drive/1JNvaTKuuAHrl49dYB9-mdEH_y52Ib-NP?usp=drive_link)
This notebook introduces the kvpress package by compressing Nvidia's Wikipedia article.
## [expected_attention.ipynb](expected_attention.ipynb)
This notebook illustrates the usage of the `ExpectedAttentionPress` class. It explains how to compute scores based on the expected attention on future positions and demonstrates the steps involved in the process.
## [new_press.ipynb](new_press.ipynb) [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/drive/1ld6u2OnVUpGryBGDdanjjDrf6j7TD0oA?usp=drive_link)
This notebook provides an overview of how to create a new press. It explains the underlying mechanism of key-value compression and how it can be applied to transformer models.
## [per_layer_compression_demo.ipynb](per_layer_compression_demo.ipynb)
This notebook provides a demonstration of the per-layer compression feature. It shows how to improve the overall compression ratio by applying a different compression ratio to each layer of the model.
## [speed_and_memory.ipynb](speed_and_memory.ipynb)
This notebook demonstrates how to measure the memory and throughput gains of the kvpress package.
{
"cells": [
{
"cell_type": "markdown",
"id": "d5b82a94ac3477fd",
"metadata": {},
"source": [
"# Using decoding compression on the AIME25 Math Dataset\n",
"\n",
"This notebook demonstrates how to compress during text generation.\n",
"We use `nvidia/OpenMath-Nemotron-7B` to solve math problems from the AIME25 dataset. For each problem, the model generates an answer in a boxed format (e.g., `\\boxed{42}`).\n",
"\n",
"To optimize memory usage during long-context generation, the notebook applies key-value cache compression during decoding.\n",
"Compression periodically reduces the cache size by keeping only the most relevant tokens, enabling efficient inference without sacrificing answer quality."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "4cd1c9e43c5ca1bf",
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"/data/projects/kvpress/.venv/lib/python3.12/site-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n",
" from .autonotebook import tqdm as notebook_tqdm\n"
]
}
],
"source": [
"from transformers import pipeline\n",
"from datasets import load_dataset\n",
"\n",
"from kvpress import (\n",
" DecodingPress,\n",
" ExpectedAttentionPress,\n",
" KnormPress,\n",
" ObservedAttentionPress,\n",
" RandomPress,\n",
" SnapKVPress,\n",
" StreamingLLMPress,\n",
" TOVAPress,\n",
")"
]
},
{
"cell_type": "code",
"execution_count": 2,
"id": "553f7d7f-7a2f-456d-a4b1-f38f3ead8767",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"{'problem': 'Find the sum of all integer bases $b>9$ for which $17_b$ is a divisor of $97_b.$',\n",
" 'answer': '70',\n",
" 'id': '0'}"
]
},
"execution_count": 2,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# Load a sample\n",
"dataset = load_dataset(\"math-ai/aime25\")\n",
"sample = dataset[\"test\"][0]\n",
"sample"
]
},
{
"cell_type": "code",
"execution_count": 3,
"id": "02b0af83-e71a-4129-9692-921d2bc16cb8",
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"Loading checkpoint shards: 100%|██████████| 2/2 [00:00<00:00, 64.85it/s]\n",
"Device set to use cuda:0\n"
]
}
],
"source": [
"# Load the pipeline\n",
"device = \"cuda:0\"\n",
"ckpt = \"nvidia/OpenMath-Nemotron-7B\"\n",
"attn_implementation = \"flash_attention_2\"\n",
"pipe = pipeline(\"kv-press-text-generation\", model=ckpt, device=device, dtype=\"auto\", model_kwargs={\"attn_implementation\":attn_implementation})"
]
},
{
"cell_type": "code",
"execution_count": 4,
"id": "96e1f430-4ece-49dd-a5fb-68ae3c2ee8b8",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Question: Find the sum of all integer bases $b>9$ for which $17_b$ is a divisor of $97_b.$\n",
"Answer: 70\n",
"Prediction: <think>\n",
"Okay, so I need to find all integer bases b greater than 9 where the number 17 in base b divides the number 97 in base b. Then, sum all those bases. Hmm, let me think step by step.\n",
"\n",
"First, I should recall how numbers are represented in different bases. The number 17 in base b is equal to 1 times b plus 7 times 1, right? So that's 1*b + 7. Similarly, 97 in base b would be 9*b + 7. So, translating both numbers to base 10, we have:\n",
"\n",
"17_b = b + 7\n",
"\n",
"97_b = 9b + 7\n",
"\n",
"The problem states that 17_b must divide 97_b. So, in mathematical terms, this means that (9b + 7) divided by (b + 7) should result in an integer. So, (9b + 7) must be divisible by (b + 7). \n",
"\n",
"Let me write that as an equation:\n",
"\n",
"(9b + 7) ÷ (b + 7) = integer.\n",
"\n",
"To find when this division results in an integer, maybe I can perform the division and see what the remainder is. If the remainder is zero, then it's divisible. Let's try polynomial division or maybe manipulate the expression.\n",
"\n",
"Let me rewrite 9b + 7. Let's see, if I write 9b + 7 as 9*(b + 7) minus something. Let's compute:\n",
"\n",
"9*(b + 7) = 9b + 63\n",
"\n",
"But 9b + 7 is the original numerator. So, subtracting 9*(b + 7) from 9b +7 gives:\n",
"\n",
"9b +7 - (9b +63) = 7 - 63 = -56\n",
"\n",
"Therefore, 9b +7 = 9*(b +7) -56\n",
"\n",
"So, (9b +7)/(b +7) = 9 - 56/(b +7)\n",
"\n",
"For this to be an integer, 56/(b +7) must be an integer. Therefore, (b +7) must be a divisor of 56. \n",
"\n",
"But since b is a base greater than 9, the digits in the numbers 17_b and 97_b must be valid. In base b, the digits can go from 0 to b-1. So, in 17_b, the digits are 1 and 7. Since the base is greater than 9, 7 is a valid digit. Similarly, in 97_b, the digits are 9 and 7. So, 9 must be less than b. Therefore, the base b must be greater than 9, which is already given. So, the constraints are:\n",
"\n",
"1. b > 9\n",
"\n",
"2. (b +7) divides 56.\n",
"\n",
"So, first, let's find all divisors of 56. The positive divisors of 56 are:\n",
"\n",
"1, 2, 4, 7, 8, 14, 28, 56.\n",
"\n",
"But since b +7 must be one of these divisors, and b >9, then b +7 must be a divisor of 56 greater than 9 +7 =16. Wait, because b >9, so b >=10, so b +7 >=17. Therefore, the divisors of 56 that are greater than or equal to 17 are 28 and 56. Wait, let's check:\n",
"\n",
"Divisors of 56: 1, 2, 4, 7, 8, 14, 28, 56.\n",
"\n",
"So, divisors >=17 are 28 and 56. Therefore, b +7 can be 28 or 56. Therefore, solving for b:\n",
"\n",
"If b +7 =28, then b=21\n",
"\n",
"If b +7=56, then b=49\n",
"\n",
"So, the possible bases are 21 and 49. Then, the sum of these bases is 21 +49=70.\n",
"\n",
"Wait, but let me check if there are any other divisors. Wait, 14 is a divisor, but 14 is less than 17 (since b >=10, so b +7 >=17). So, 14 is too small. Similarly, 8,7, etc. So, only 28 and 56. Therefore, the possible bases are 21 and 49. So, sum is 70.\n",
"\n",
"But wait, let me verify this. Let's check for b=21:\n",
"\n",
"17 in base 21 is 1*21 +7=28\n",
"\n",
"97 in base21 is 9*21 +7=196\n",
"\n",
"196 divided by28 is 7, which is an integer. So that works.\n",
"\n",
"For b=49:\n",
"\n",
"17 in base49 is 1*49 +7=56\n",
"\n",
"97 in base49 is 9*49 +7=448\n",
"\n",
"448 divided by56 is 8, which is an integer. So that works too.\n",
"\n",
"So, both bases are valid, and their sum is 21 +49=70.\n",
"\n",
"Wait, but hold on. Let me check if there are any negative divisors. Since 56 is positive, but divisors can be negative. However, since b is a base, it must be a positive integer greater than 9. So, b +7 must be positive. Therefore, we can ignore negative divisors. So, only positive divisors. So, 28 and 56 are the only ones. Therefore, the answer is 70.\n",
"\n",
"But let me just make sure I didn't miss any divisors. Let's list all divisors again:\n",
"\n",
"1, 2, 4, 7, 8, 14, 28, 56. So, yes, only 28 and 56 are >=17. So, that's correct.\n",
"\n",
"Therefore, the sum is 21 +49=70. So, the answer is 70.\n",
"\n",
"**Final Answer**\n",
"\\boxed{70}\n",
"</think>To solve the problem, we need to find all integer bases \\( b > 9 \\) for which \\( 17_b \\) is a divisor of \\( 97_b \\).\n",
"\n",
"First, we convert the numbers from base \\( b \\) to base 10:\n",
"- \\( 17_b = 1 \\cdot b + 7 = b + 7 \\)\n",
"- \\( 97_b = 9 \\cdot b + 7 = 9b + 7 \\)\n",
"\n",
"We need \\( 9b + 7 \\) to be divisible by \\( b + 7 \\). This can be expressed as:\n",
"\\[\n",
"\\frac{9b + 7}{b + 7} = 9 - \\frac{56}{b + 7}\n",
"\\]\n",
"For this to be an integer, \\( \\frac{56}{b + 7} \\) must be an integer. Therefore, \\( b + 7 \\) must be a divisor of 56.\n",
"\n",
"The positive divisors of 56 are:\n",
"\\[\n",
"1, 2, 4, 7, 8, 14, 28, 56\n",
"\\]\n",
"Since \\( b > 9 \\), we need \\( b + 7 \\geq 17 \\). The divisors of 56 that are greater than or equal to 17 are 28 and 56.\n",
"\n",
"Solving for \\( b \\):\n",
"- If \\( b + 7 = 28 \\), then \\( b = 21 \\)\n",
"- If \\( b + 7 = 56 \\), then \\( b = 49 \\)\n",
"\n",
"We verify these bases:\n",
"- For \\( b = 21 \\):\n",
" \\[\n",
" 17_{21} = 1 \\cdot 21 + 7 = 28\n",
" \\]\n",
" \\[\n",
" 97_{21} = 9 \\cdot 21 + 7 = 196\n",
" \\]\n",
" \\[\n",
" 196 \\div 28 = 7 \\quad \\text{(an integer)}\n",
" \\]\n",
"\n",
"- For \\( b = 49 \\):\n",
" \\[\n",
" 17_{49} = 1 \\cdot 49 + 7 = 56\n",
" \\]\n",
" \\[\n",
" 97_{49} = 9 \\cdot 49 + 7 = 448\n",
" \\]\n",
" \\[\n",
" 448 \\div 56 = 8 \\quad \\text{(an integer)}\n",
" \\]\n",
"\n",
"Both bases are valid. The sum of these bases is:\n",
"\\[\n",
"21 + 49 = 70\n",
"\\]\n",
"\n",
"Thus, the final answer is:\n",
"\\[\n",
"\\boxed{70}\n",
"\\]\n"
]
}
],
"source": [
"# Run the pipeline without compression\n",
"\n",
"question = sample[\"problem\"]\n",
"true_answer = sample[\"answer\"]\n",
"pred_answer = pipe(\" \", question=question, press=None, max_new_tokens=16_000)[\"answer\"]\n",
"\n",
"print(f\"Question: {question}\")\n",
"print(f\"Answer: {true_answer}\")\n",
"print(f\"Prediction: {pred_answer}\")"
]
},
{
"cell_type": "code",
"execution_count": 5,
"id": "7f4fa366-7e62-443a-b5c8-274128fe6237",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Question: Find the sum of all integer bases $b>9$ for which $17_b$ is a divisor of $97_b.$\n",
"Answer: 70\n",
"Prediction: <think>\n",
"Okay, so I need to find all integer bases b greater than 9 where the number 17 in base b divides the number 97 in base b. Then, sum all those bases. Hmm, let me think step by step.\n",
"\n",
"First, I should recall how numbers are represented in different bases. The number 17 in base b is equal to 1 times b plus 7 times 1, right? So that's 1*b + 7. Similarly, 97 in base b would be 9*b + 7. So, translating both numbers to base 10, we have:\n",
"\n",
"17_b = b + 7\n",
"\n",
"97_b = 9b + 7\n",
"\n",
"The problem states that 17_b must divide 97_b. So, in mathematical terms, this means that (9b + 7) divided by (b + 7) should result in an integer. So, (9b + 7) must be divisible by (b + 7). \n",
"\n",
"Let me write that as an equation:\n",
"\n",
"(9b + 7) ÷ (b + 7) = integer.\n",
"\n",
"To find when this division results in an integer, maybe I can perform the division and see what the remainder is. If the remainder is zero, then it's divisible. Let's try polynomial division or maybe manipulate the expression.\n",
"\n",
"Let me rewrite 9b + 7. Let's see, if I write 9b + 7 as 9*(b + 7) minus something. Let's compute:\n",
"\n",
"9*(b + 7) = 9b + 63\n",
"\n",
"But 9b + 7 is the original numerator. So, subtracting 9*(b + 7) from 9b +7 gives:\n",
"\n",
"9b +7 - (9b +63) = 7 - 63 = -56\n",
"\n",
"Therefore, 9b +7 = 9*(b +7) -56\n",
"\n",
"So, (9b +7)/(b +7) = 9 - 56/(b +7)\n",
"\n",
"For this to be an integer, 56/(b +7) must be an integer. Therefore, (b +7) must be a divisor of 56. \n",
"\n",
"But since b is a base greater than 9, the digits in the numbers 17_b and 97_b must be valid. In base b, the digits can go from 0 to b-1. So, in 17_b, the digits are 1 and 7. Since the base is greater than 9, 7 is a valid digit. Similarly, in 97_b, the digits are 9 and 7. So, 9 must be less than b. Therefore, the base b must be greater than 9, which is already given. So, the constraints are:\n",
"\n",
"1. b > 9\n",
"\n",
"2. (b +7) divides 56.\n",
"\n",
"So, first, let's find all divisors of 56. The positive divisors of 56 are:\n",
"\n",
"1, 2, 4, 7, 8, 14, 28, 56.\n",
"\n",
"But since b +7 must be one of these divisors, and b >9, then b +7 must be a divisor of 56 greater than 9 +7 =16. Wait, because b >9, so b >=10, so b +7 >=17. Therefore, the divisors of 56 that are greater than or equal to 17 are 28 and 56. Wait, let's check:\n",
"\n",
"Divisors of 56: 1, 2, 4, 7, 8, 14, 28, 56.\n",
"\n",
"So, divisors >=17 are 28 and 56. Therefore, b +7 can be 28 or 56. Therefore, solving for b:\n",
"\n",
"If b +7 =28, then b=21\n",
"\n",
"If b +7=56, then b=49\n",
"\n",
"So, the possible bases are 21 and 49. Then, the sum of these bases is 21 +49=70.\n",
"\n",
"Wait, but let me check if there are any other divisors. Wait, 14 is a divisor, but 14 is less than 17 (since b >=10, so b +7 >=17). So, 14 is too small. Similarly, 8,7, etc. So, only 28 and 56. Therefore, the possible bases are 21 and 49. So, sum is 70.\n",
"\n",
"But wait, let me verify this. Let's check for b=21:\n",
"\n",
"17 in base 21 is 2*21 +7=49. Wait, no, 17 in base 21 is 1*21 +7=28. 97 in base 21 is 9*21 +7=196. Then, 196 divided by 28 is 7, which is an integer. So that works.\n",
"\n",
"For b=49: 17 in base 49 is 1*49 +7=56. 97 in base 49 is 9*49 +7=448. 448 divided by 56 is 8, which is an integer. So that works too.\n",
"\n",
"So, both bases 21 and 49 are valid. Therefore, the sum is 21 +49=70.\n",
"\n",
"Wait, but hold on. Let me check if there are any other divisors. Wait, 56 is a divisor of 56, but 56 is 56. So, if b +7=56, then b=49. Which we already have. Similarly, 28 gives b=21. Are there any negative divisors? But since b is a base, it must be a positive integer greater than 9, so negative divisors don't make sense here. So, only 28 and 56. Therefore, the answer is 70.\n",
"\n",
"But wait, let me check if there's a mistake here. Let me think again. The problem says \"all bases b>9\". So, the possible divisors of 56 are 1,2,4,7,8,14,28,56. But since b +7 must be one of these, and b>9, so b +7 must be at least 17. So, the possible divisors are 28 and 56. Therefore, b=21 and 49. So, sum is 70. That seems correct.\n",
"\n",
"But let me check if there's another way to approach this. For example, maybe using modular arithmetic. Let's see.\n",
"\n",
"We have 9b +7 ≡0 mod (b +7). Let me express 9b +7 in terms of (b +7). Let's write 9b +7 =9*(b +7) - 9*7 +7=9*(b +7) -63 +7=9*(b +7) -56. Therefore, 9b +7 ≡ -56 mod (b +7). Therefore, -56 ≡0 mod (b +7), which implies that (b +7) divides 56. So, same result as before. Therefore, the same answer.\n",
"\n",
"Therefore, the possible values of b are 21 and 49, sum is 70. So, the answer is 70. Therefore, the sum is 70. So, I think that's correct.\n",
"\n",
"**Final Answer**\n",
"The sum of all such bases is \\boxed{70}.\n",
"</think>To solve the problem, we need to find all integer bases \\( b > 9 \\) such that \\( 17_b \\) divides \\( 97_b \\).\n",
"\n",
"First, we convert the numbers from base \\( b \\) to base 10:\n",
"- \\( 17_b \\) in base 10 is \\( 1 \\cdot b + 7 = b + 7 \\).\n",
"- \\( 97_b \\) in base 10 is \\( 9 \\cdot b + 7 = 9b + 7 \\).\n",
"\n",
"We need \\( 9b + 7 \\) to be divisible by \\( b + 7 \\). This can be expressed as:\n",
"\\[\n",
"9b + 7 \\equiv 0 \\pmod{b + 7}\n",
"\\]\n",
"\n",
"Rewriting \\( 9b + 7 \\) in terms of \\( b + 7 \\):\n",
"\\[\n",
"9b + 7 = 9(b + 7) - 56\n",
"\\]\n",
"Thus, we have:\n",
"\\[\n",
"9b + 7 \\equiv -56 \\pmod{b + 7}\n",
"\\]\n",
"This implies that \\( b + 7 \\) must be a divisor of 56. The positive divisors of 56 are:\n",
"\\[\n",
"1, 2, 4, 7, 8, 14, 28, 56\n",
"\\]\n",
"\n",
"Since \\( b > 9 \\), \\( b + 7 \\) must be at least 17. The divisors of 56 that are greater than or equal to 17 are:\n",
"\\[\n",
"28 \\quad \\text{and} \\quad 56\n",
"\\]\n",
"\n",
"Solving for \\( b \\):\n",
"- If \\( b + 7 = 28 \\), then \\( b = 21 \\).\n",
"- If \\( b + 7 = 56 \\), then \\( b = 49 \\).\n",
"\n",
"Thus, the valid bases are \\( b = 21 \\) and \\( b = 49 \\).\n",
"\n",
"Summing these bases:\n",
"\\[\n",
"21 + 49 = 70\n",
"\\]\n",
"\n",
"Therefore, the sum of all valid bases \\( b \\) is:\n",
"\\[\n",
"\\boxed{70}\n",
"\\]\n"
]
}
],
"source": [
"# Run the pipeline with compression\n",
"\n",
"compression_interval = 1024 # compress every compression_steps\n",
"target_size = 512 # number of tokens to keep after compression. Note that actual cache size lies in [target_size, compression_interval]\n",
"\n",
"press = DecodingPress(base_press=ExpectedAttentionPress(), compression_interval=compression_interval, target_size=target_size)\n",
"\n",
"question = sample[\"problem\"]\n",
"true_answer = sample[\"answer\"]\n",
"pred_answer = pipe(\" \", question=question, press=press, max_new_tokens=16_000)[\"answer\"]\n",
"\n",
"print(f\"Question: {question}\")\n",
"print(f\"Answer: {true_answer}\")\n",
"print(f\"Prediction: {pred_answer}\")"
]
}
],
"metadata": {
"kernelspec": {
"display_name": ".venv",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.12.3"
}
},
"nbformat": 4,
"nbformat_minor": 5
}
{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Creating a new press\n",
"\n",
"In this guide, we will walk you through the process of creating a new press."
]
},
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [],
"source": [
"from dataclasses import dataclass\n",
"from contextlib import contextmanager\n",
"\n",
"import torch\n",
"from torch import nn\n",
"from transformers import pipeline\n",
"\n",
"from kvpress import BasePress, KnormPress, ScorerPress"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"You are attempting to use Flash Attention 2.0 with a model not initialized on GPU. Make sure to move the model to GPU after initializing it on CPU with `model.to('cuda')`.\n",
"Device set to use cuda:0\n"
]
}
],
"source": [
"# Load pipeline\n",
"\n",
"device = \"cuda:0\"\n",
"ckpt = \"Qwen/Qwen2.5-1.5B-Instruct\"\n",
"attn_implementation = \"flash_attention_2\"\n",
"pipe = pipeline(\"kv-press-text-generation\", model=ckpt, device=device, dtype=\"auto\", model_kwargs={\"attn_implementation\":attn_implementation})"
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {},
"outputs": [],
"source": [
"# Load data\n",
"\n",
"context = \"In this step-by-step guide, you will learn how to create a new press in kvpress !\"\n",
"question = \"\\nWhat is the purpose of this guide?\"\n",
"tokens = pipe.tokenizer(context, return_tensors=\"pt\").to(device)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## 1. Understanding how press work"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"A press registers a forward hook to each attention layer during the pre-filling phase. Immediately after the forward pass, the hook is called, and it compresses the KV cache."
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"Starting from v4.46, the `logits` model output will have the same type as the model (except at train time, where it will always be FP32)\n",
"The `seen_tokens` attribute is deprecated and will be removed in v4.41. Use the `cache_position` model input instead.\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Cache shape w/o press: torch.Size([1, 2, 20, 128])\n",
"Cache shape w/ press: torch.Size([1, 2, 15, 128])\n",
"\n",
"The purpose of this step-by-step guide is to provide instructions on how to create a new press in kvpress. The guide is designed to help users understand the process of setting up a new press in the kvpress platform.\n"
]
}
],
"source": [
"compression_ratio = 0.25\n",
"press = KnormPress(compression_ratio)\n",
"\n",
"with torch.no_grad():\n",
" outputs_without_press = pipe.model(**tokens, output_hidden_states=True)\n",
"\n",
"with torch.no_grad(), press(pipe.model):\n",
" output_with_press = pipe.model(**tokens)\n",
"\n",
"print(f\"Cache shape w/o press: {outputs_without_press.past_key_values[0][0].shape}\")\n",
"print(f\"Cache shape w/ press: {output_with_press.past_key_values[0][0].shape}\\n\")\n",
"\n",
"# The `KVPressTextGenerationPipeline` simply applies the `press` as above on the context tokens (see `_forward` method for more details).\n",
"print(pipe(context, question=question, press=press)[\"answer\"])"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## 2. Creating your own press\n"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### 2.1 Updating the `score` method"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"The easiest way to create a new press is to create a class that inherits from `ScorerPress` and implement a `score` method that computes the score for each key-value pair.\n",
"\n",
"The arguments of the `score` method are obtained from the forward hook:\n",
"- `module`: the attention layer\n",
"- `hidden_states`: the input of the attention layer\n",
"- `keys` and `values`: the key-value pairs from the attention layer\n",
"- `attentions`: the attention weights, only available with `attn_implementation=\"eager\"`\n",
"\n",
"In this first example, we will reproduce the `KnormPress` where the score of a key-value pair is simply the opposite of the norm of the key vector."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"class MyKnormPress(ScorerPress):\n",
" def score(\n",
" self,\n",
" module: nn.Module,\n",
" hidden_states: torch.Tensor,\n",
" keys: torch.Tensor,\n",
" values: torch.Tensor,\n",
" attentions: torch.Tensor,\n",
" kwargs,\n",
" ) -> torch.Tensor:\n",
"\n",
" scores = -keys.norm(dim=-1)\n",
"\n",
" # For demonstration, we show some details on the shape for the first layer\n",
" if module.layer_idx == 0:\n",
" print(f\"module: {module}\")\n",
" print(f\"Number of key value heads: {module.config.num_key_value_heads}\")\n",
" print(f\"Sequence length: {hidden_states.shape[1]}\")\n",
" print()\n",
" print(f\"hidden_states shape: {hidden_states.shape}\")\n",
" print(f\"keys shape: {keys.shape}\") # shape (bhnd)\n",
" print(f\"values shape: {values.shape}\") # shape (bhnd)\n",
" print(f\"score shape: {scores.shape}\") # shape (bhn)\n",
" print()\n",
" \n",
" return scores\n",
"\n",
"\n",
"press = MyKnormPress(compression_ratio)\n",
"print(pipe(context, question=question, press=press)[\"answer\"])"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### 2.2 Updating the `compress` method "
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"The `compress` method defined in the `BasePress` contains the core logic of the compression and returns compressed keys and values. For instance, in the `ScorerPress` the `compress` calls the `score` method (which is specific to `ScorerPress`) and prune the key-value pairs based on the scores.\n",
"\n",
"The following example will show how it works. We will re-implement the `StreamingLLMPress` in a more compact way."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"@dataclass\n",
"class MyStreamingLLMPress(BasePress):\n",
" n_first: int = 1\n",
" n_last: int = 8\n",
"\n",
" def compress(\n",
" self,\n",
" module: nn.Module,\n",
" hidden_states: torch.Tensor,\n",
" keys: torch.Tensor,\n",
" values: torch.Tensor,\n",
" attentions: torch.Tensor,\n",
" kwargs: dict,\n",
" ) -> tuple[torch.Tensor, torch.Tensor]:\n",
"\n",
" mask = torch.ones(keys.shape[-2], dtype=torch.bool, device=keys.device)\n",
" mask[self.n_first : -self.n_last] = False\n",
" return keys[:, :, mask, :], values[:, :, mask, :]\n",
"\n",
"\n",
"for n_last in [2, 4, 8]:\n",
" press = MyStreamingLLMPress(n_last=n_last)\n",
" print(f\"\\nn_last: {n_last}\")\n",
" print(f\"Last tokens seen by the model: {pipe.tokenizer.decode(tokens.input_ids[0, -n_last:])}\")\n",
" print(f\"Answer: {pipe(context, question=question, press=press)['answer']}\")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Note that in the `compress` method is itself used in the `forward_hook` method which ensures quantization is handled properly and that the compression is only performed during prefilling. While we don't recommend to change the `forward_hook` method directly, you can still modify it if you need to !"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### 2.3 Head-wise compression\n",
"\n",
"Since 0.2.0, kvpress support head-wise compression, where the KV cache of each head might be compressed by a different compression ratio. \n",
"\n",
"To achieve proper head-wise compression, one should implement a new kernel for attention along with a custom cache class. Instead, the current implementation fakes head-wise compression by updating the pruned keys by a fake key so that the output of the attention layer is not affected. This is implemented through `kvpress.attention_patch.patch_attention_functions`.\n",
"\n",
"To implement a method that compresses the KV cache head-wise, one should instantiate the `masked_key_indices` as outlined below."
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"\n",
"compression_ratio: 0\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Answer: The purpose of this step-by-step guide is to provide a comprehensive and easy-to-follow tutorial on how to create a new press in the KVPress platform. The guide is designed to help users understand the process of setting up a new press, including the\n",
"\n",
"compression_ratio: 0.25\n",
"Answer: The purpose of this guide is to provide a step-by-step process for creating a new press in KVPRESS, which is a popular open-source web server. The guide will cover the necessary steps to set up and configure a new press, including installing\n",
"\n",
"compression_ratio: 0.9\n",
"Answer: This guide is not a guide. It is a guide. It is a guide. It is a guide. It is a guide. It is a guide. It is a guide. It is a guide. It is a guide. It is a\n"
]
}
],
"source": [
"@dataclass\n",
"class RandomHeadPress(BasePress):\n",
"\n",
" compression_ratio: float = 0.0\n",
"\n",
" def compress(self, module, hidden_states, keys, values, attentions, kwargs):\n",
" assert keys.shape[0] == 1, \"Only batch size 1 is supported\"\n",
" scores = torch.rand(keys.shape[:-1], device=keys.device)\n",
" mask = scores < torch.quantile(scores, self.compression_ratio)\n",
" module.masked_key_indices = torch.nonzero(mask, as_tuple=True)\n",
" \n",
" return keys, values\n",
"\n",
"for compression_ratio in [0, 0.25, 0.9]:\n",
" press = RandomHeadPress(compression_ratio)\n",
" print(f\"\\ncompression_ratio: {compression_ratio}\")\n",
" print(f\"Answer: {pipe(context, question=question, press=press)['answer']}\")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## 3. Contributing to kvpress"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"All presses should be stored in the `presses` directory. Before opening a pull request with your new press, make sure to \n",
"- register it in the `__init__.py` file of repository\n",
"- register the press in [default_presses.py](tests/default_presses.py)\n",
"- update the README"
]
}
],
"metadata": {
"kernelspec": {
"display_name": ".venv",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.12.3"
}
},
"nbformat": 4,
"nbformat_minor": 2
}
{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"In this notebook, we showcase how to use the improve retrieval performance using per-layer compression."
]
},
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [],
"source": [
"import numpy as np\n",
"import torch\n",
"from transformers import pipeline\n",
"\n",
"from kvpress import (\n",
" ExpectedAttentionPress,\n",
" KnormPress,\n",
" ObservedAttentionPress,\n",
" RandomPress,\n",
" SnapKVPress,\n",
" StreamingLLMPress,\n",
" PerLayerCompressionPress,\n",
")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Load the pipeline and data"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"You are attempting to use Flash Attention 2.0 with a model not initialized on GPU. Make sure to move the model to GPU after initializing it on CPU with `model.to('cuda')`.\n"
]
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "14ee6cc96fce42cfb6e75b2964fbda04",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"Loading checkpoint shards: 0%| | 0/2 [00:00<?, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"# Load pipeline\n",
"\n",
"device = \"cuda:0\"\n",
"ckpt = \"microsoft/Phi-3.5-mini-instruct\"\n",
"attn_implementation = \"flash_attention_2\"\n",
"pipe = pipeline(\"kv-press-text-generation\", model=ckpt, device=device, dtype=\"auto\", model_kwargs={\"attn_implementation\":attn_implementation})"
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {},
"outputs": [],
"source": [
"import datasets \n",
"\n",
"df = datasets.load_dataset(\"simonjegou/ruler\", \"4096\")[\"test\"].to_pandas()\n",
"df = df.loc[df[\"task\"] == \"niah_single_3\"].reset_index(drop=True)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Use the pipeline with a press"
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {},
"outputs": [],
"source": [
"# Pick a press with a compression ratio, you can run the following cells with different presses\n",
"compression_ratio = 0.3\n",
"press = ExpectedAttentionPress(compression_ratio)"
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"The `seen_tokens` attribute is deprecated and will be removed in v4.41. Use the `cache_position` model input instead.\n",
"Starting from v4.46, the `logits` model output will have the same type as the model (except at train time, where it will always be FP32)\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Question: What is the special magic uuid for amused-quart mentioned in the provided text? \n",
"Answer: 1ff49b78-8946-4e85-b59c-de66bacfb3d0\n",
"Prediction: The special magic uuid for amused-quart mentioned in the text is: 1ff49b78-8946-4e85-b63d-a7e3c0a1c\n",
"Correctly predicted: False\n"
]
}
],
"source": [
"# Run the pipeline on a single question\n",
"idx = 0\n",
"context = df.iloc[idx][\"context\"] \n",
"question = df.iloc[idx][\"question\"] \n",
"true_answer = df.iloc[idx][\"answer\"][0]\n",
"\n",
"pred_answer = pipe(context, question=question, press=press)[\"answer\"]\n",
"\n",
"print(f\"Question: {question}\")\n",
"print(f\"Answer: {true_answer}\")\n",
"print(f\"Prediction: {pred_answer}\")\n",
"print(f\"Correctly predicted: {true_answer in pred_answer}\")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Apply per-layer-compression with the same overall compression ratio"
]
},
{
"cell_type": "code",
"execution_count": 6,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"0.3028125\n"
]
}
],
"source": [
"# Each layer is compressed differently, some layers have higher compression ratios, other less.\n",
"# The mean compression ratio is the same as for the original press\n",
"\n",
"PHI_35_COMPRESSION_RATIOS = [0.37, 0.3, 0.37, 0.37, 0.37, 0.37, 0.07, 0.37, 0.29, 0.37, 0.36,\n",
" 0.13, 0.37, 0.0, 0.37, 0.37, 0.37, 0.36, 0.28, 0.0, 0.09, 0.37,\n",
" 0.37, 0.37, 0.37, 0.37, 0.04, 0.37, 0.37, 0.37, 0.37, 0.37]\n",
"print(np.mean(PHI_35_COMPRESSION_RATIOS))"
]
},
{
"cell_type": "code",
"execution_count": 9,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"Per layer compression wrapper is an experimental feature and only works with flash attention. Please make sure that the model uses flash attention.\n"
]
}
],
"source": "press_per_layer = PerLayerCompressionPress(ExpectedAttentionPress(compression_ratio), PHI_35_COMPRESSION_RATIOS)"
},
{
"cell_type": "code",
"execution_count": 10,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Question: What is the special magic uuid for amused-quart mentioned in the provided text? \n",
"Answer: 1ff49b78-8946-4e85-b59c-de66bacfb3d0\n",
"Prediction: The special magic uuid mentioned in the text for amused-quart is: 1ff49b78-8946-4e85-b59c-de66bacfb3d0\n",
"Correctly predicted: True\n"
]
}
],
"source": [
"pred_answer = pipe(context, question=question, press=press_per_layer)[\"answer\"]\n",
"\n",
"print(f\"Question: {question}\")\n",
"print(f\"Answer: {true_answer}\")\n",
"print(f\"Prediction: {pred_answer}\")\n",
"print(f\"Correctly predicted: {true_answer in pred_answer}\")"
]
}
],
"metadata": {
"kernelspec": {
"display_name": ".venv",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.12"
}
},
"nbformat": 4,
"nbformat_minor": 4
}
{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"In this notebook, we showcase how to use the KVpress pipelines by answering questions about NVIDIA Wikipedia article."
]
},
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [],
"source": [
"import requests\n",
"from bs4 import BeautifulSoup\n",
"\n",
"from transformers import pipeline\n",
"\n",
"from kvpress import ExpectedAttentionPress, KnormPress, RandomPress"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Load the pipeline and data"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "54a814620e6844aba021ea3376c6f823",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"Loading checkpoint shards: 0%| | 0/5 [00:00<?, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"Device set to use cuda:0\n"
]
}
],
"source": [
"# Load pipeline\n",
"\n",
"device = \"cuda:0\"\n",
"ckpt = \"Qwen/Qwen3-8B\"\n",
"attn_implementation = \"flash_attention_2\" # use \"eager\" for ObservedAttentionPress and \"sdpa\" if you can't use \"flash_attention_2\"\n",
"pipe = pipeline(\"kv-press-text-generation\", model=ckpt, device=device, dtype=\"auto\", model_kwargs={\"attn_implementation\":attn_implementation})"
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Number of tokens: 11686\n"
]
}
],
"source": [
"# Load data\n",
"url = \"https://en.wikipedia.org/wiki/Nvidia\"\n",
"headers = {\"User-Agent\": \"Mozilla/5.0 (X11; Linux x86_64) \"}\n",
"content = requests.get(url, headers=headers).content\n",
"soup = BeautifulSoup(content, \"html.parser\")\n",
"context = \"\".join([p.text for p in soup.find_all(\"p\")]) + \"\\n\\n\"\n",
"tokens = pipe.tokenizer.encode(context, return_tensors=\"pt\").to(device)\n",
"print(f\"Number of tokens: {tokens.size(1)}\")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Use the pipeline with a press"
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {},
"outputs": [],
"source": [
"# Pick a press with a compression ratio, you can run the following cells with different presses\n",
"compression_ratio = 0.5\n",
"press = ExpectedAttentionPress(compression_ratio)\n",
"# press = KnormPress(compression_ratio)\n",
"# press = RandomPress(compression_ratio)"
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Question: Complete this sentence: The Nvidia GeForce Partner Program was a ...\n",
"Answer: marketing program designed to provide partnering companies with benefits such as public relations support, video game bundling, and marketing development funds.\n",
"Prediction: The Nvidia GeForce Partner Program was a marketing initiative designed to provide partnering companies with benefits such as public relations support, video game bundling, and marketing development funds, but it became controversial due to allegations of anti-competitive practices.\n"
]
}
],
"source": [
"# Run the pipeline on a single question\n",
"\n",
"question = \"Complete this sentence: The Nvidia GeForce Partner Program was a ...\"\n",
"true_answer = \"marketing program designed to provide partnering companies with benefits such as public relations support, video game bundling, and marketing development funds.\"\n",
"pred_answer = pipe(context, question=question, press=press)[\"answer\"]\n",
"\n",
"print(f\"Question: {question}\")\n",
"print(f\"Answer: {true_answer}\")\n",
"print(f\"Prediction: {pred_answer}\")"
]
},
{
"cell_type": "code",
"execution_count": 6,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Question: What happened on March 1, 2024?\n",
"Answer: Nvidia became the third company in the history of the United States to close with a market capitalization in excess of $2 trillion\n",
"Prediction: On **March 1, 2024**, **Nvidia** was ranked **#3** on **Forbes' \"Best Places to Work\" list**. This recognition highlighted the company's strong workplace culture, employee satisfaction, and\n",
"\n",
"Question: What was the unofficial company motto of Nvidia during the early days?\n",
"Answer: Our company is thirty days from going out of business\n",
"Prediction: The unofficial company motto of Nvidia during the early days was **\"A flywheel to reach large markets funding huge R&D to solve massive computational problems.\"** This motto was inspired by the concept of a flywheel, which is a device that stores rotational\n",
"\n"
]
}
],
"source": [
"# Run the pipeline on multiple questions, the context will be compressed only once\n",
"\n",
"questions = [\n",
" \"What happened on March 1, 2024?\",\n",
" \"What was the unofficial company motto of Nvidia during the early days?\",\n",
"]\n",
"\n",
"true_answers = [\n",
" \"Nvidia became the third company in the history of the United States to close with a market capitalization in excess of $2 trillion\",\n",
" \"Our company is thirty days from going out of business\",\n",
"]\n",
"\n",
"pred_answers = pipe(context, questions=questions, press=press)[\"answers\"]\n",
"for question, pred_answer, true_answer in zip(questions, pred_answers, true_answers):\n",
" print(f\"Question: {question}\")\n",
" print(f\"Answer: {true_answer}\")\n",
" print(f\"Prediction: {pred_answer}\")\n",
" print()"
]
},
{
"cell_type": "code",
"execution_count": 7,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Question: What is GTC ?\n",
"Answer: Nvidia's GPU Technology Conference (GTC) is a series of technical conferences held around the world.\n",
"Prediction w/o prefix: **GTC** stands for **GPU Technology Conference**, which is a major annual event organized by **NVIDIA**. It is a technical conference that\n",
"Prediction w/ prefix : Come on you don't know GTC ? Everyone knows GTC is the biggest AI conference in the world. It's held by NVIDIA, right? I mean, it's like the Super Bowl of\n"
]
}
],
"source": [
"# Use an answer prefix and limit the number of tokens in the answer\n",
"\n",
"question = \"What is GTC ?\"\n",
"true_answer = \"Nvidia's GPU Technology Conference (GTC) is a series of technical conferences held around the world.\"\n",
"answer_prefix = \"Come on you don't know GTC ? Everyone\"\n",
"max_new_tokens = 30\n",
"\n",
"pred_answer_with_prefix = pipe(context, question=question, answer_prefix=answer_prefix, press=press, max_new_tokens=max_new_tokens)[\"answer\"]\n",
"pred_answer_without_prefix = pipe(context, question=question, press=press, max_new_tokens=max_new_tokens)[\"answer\"]\n",
"\n",
"print(f\"Question: {question}\")\n",
"print(f\"Answer: {true_answer}\")\n",
"print(f\"Prediction w/o prefix: {pred_answer_without_prefix}\")\n",
"print(f\"Prediction w/ prefix : {answer_prefix + pred_answer_with_prefix}\")"
]
}
],
"metadata": {
"kernelspec": {
"display_name": ".venv",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.12.3"
}
},
"nbformat": 4,
"nbformat_minor": 4
}
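The notebook above reduces to a short script; the following sketch mirrors its choices (checkpoint, device, press, compression ratio), with `context` left as a placeholder for any long document.

```python
# A condensed sketch of the notebook's pipeline usage; assumes a CUDA GPU
# with flash-attn installed, as in the notebook.
from transformers import pipeline

from kvpress import ExpectedAttentionPress

pipe = pipeline(
    "kv-press-text-generation",
    model="Qwen/Qwen3-8B",
    device="cuda:0",
    dtype="auto",
    model_kwargs={"attn_implementation": "flash_attention_2"},
)
press = ExpectedAttentionPress(compression_ratio=0.5)

context = "..."  # placeholder: any long document
answer = pipe(context, question="What is this document about?", press=press)["answer"]
print(answer)
```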
[project]
name = "kvpress"
version = "0.5.1"
description = "Efficiently compress the KV cache of any pretrained transformer"
authors = [
{ name = "Simon Jegou" },
{ name = "Maximilian Jeblick" },
{ name = "Alessio Devoto" },
{ name = "David Austin" },
]
requires-python = ">=3.10"
readme = "README.md"
dependencies = [
# "numpy>=2.0.0,<3",
"numpy==1.26.3",
"torch>=2.3.1,<3",
"transformers>=5.0.0",
"datasets>=2.21.0,<3",
"pandas>=2.2.2,<3",
"accelerate>=1.0.0,<2",
"requests>=2.32.3,<3",
"cachetools>=5.5.2,<6",
"fire>=0.6.0,<0.7",
]
[project.optional-dependencies]
eval = [
"rouge>=1.0.1,<2",
"nltk>=3.9.1,<4",
"tqdm>=4.66.4,<5",
"scipy>=1.13.1,<2",
"bert-score>=0.3.13,<0.4",
"jieba>=0.42.1",
"fuzzywuzzy>=0.18.0",
]
flash-attn = [
"flash-attn"
]
[dependency-groups]
dev = [
"pytest>=7.0.0,<8",
"flake8>=7.0.0,<8",
"isort>=5.13.2,<6",
"black>=24.8.0,<25",
"mypy>=1.13.0,<2",
"pytest-cov>=5.0.0,<6",
"pytest-dependency>=0.6.0,<0.7",
"pytest-html>=4.1.1, <5.0.0",
"types-pyyaml~=6.0",
"ipykernel>=6.29.4,<7",
"bs4>=0.0.2,<0.0.3",
"nvitop>=1.3.2,<2",
"matplotlib>=3.9.0,<4",
"sentencepiece>=0.2.0,<0.3",
"protobuf>=5.27.2,<6",
]
[build-system]
requires = ["hatchling"]
build-backend = "hatchling.build"
[tool.uv]
no-build-isolation-package = ["flash-attn"]
[tool.black]
line-length = 120
target_version = ["py310"]
exclude = "(.eggs|.git|.hg|.mypy_cache|.nox|.tox|venv|.venv|doc-venv|.svn|_build|buck-out|build|dist|notebooks|tools|tmp|bundles)"
[tool.isort]
multi_line_output = 3
include_trailing_comma = true
force_grid_wrap = 0
use_parentheses = true
ensure_newline_before_comments = true
line_length = 120
skip = ["venv", ".venv"]
[tool.mypy]
ignore_missing_imports = true
allow_redefinition = true
strict_optional = false
exclude = "(.eggs|.git|.hg|.mypy_cache|.nox|.tox|venv|.venv|doc-venv|.svn|_build|buck-out|build|dist|notebooks|tools|tmp|tests|bundles|.pytest_cache|reports)"
disable_error_code = ["union-attr", "operator", "call-overload", "arg-type"]
[[tool.mypy.overrides]]
module = "kvpress.presses.base_press"
disable_error_code = ["attr-defined"]
[[tool.mypy.overrides]]
module = "kvpress.pipeline"
disable_error_code = ["attr-defined", "assignment", "override"]
# SPDX-FileCopyrightText: Copyright (c) 1993-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
import numpy as np
from kvpress import (
CompactorPress,
CURPress,
DuoAttentionPress,
ExpectedAttentionPress,
ExpectedAttentionStatsPress,
FastKVzipPress,
KeyDiffPress,
KnormPress,
KVzapPress,
KVzipPress,
LagKVPress,
LeverageScorePress,
NonCausalAttnPress,
PyramidKVPress,
QFilterPress,
RandomPress,
SimLayerKVPress,
SnapKVPress,
StreamingLLMPress,
ThinKPress,
TOVAPress,
)
from kvpress.presses.fastkvzip_press import FastKVzipGate
from kvpress.presses.kvzap_press import KVzapConfig, KVzapModel
class TestDuoAttentionPress(DuoAttentionPress):
@staticmethod
def load_attention_pattern(model):
n_layers, n_heads = model.config.num_hidden_layers, model.config.num_key_value_heads
return 2, 2, np.random.rand(n_layers, n_heads)
class TestKVzapPress(KVzapPress):
"""Test version of KVzapPress that creates a mock model instead of loading from HuggingFace."""
def post_init_from_model(self, model):
config = KVzapConfig(
input_dim=model.config.hidden_size,
output_dim=model.config.num_key_value_heads,
hidden_dim=None, # Use linear model for testing
n_modules=model.config.num_hidden_layers,
)
self.kvzap_model = KVzapModel(config)
class TestFastKVzipPress(FastKVzipPress):
"""Test version of FastKVzipPress that creates a mock model instead of loading from HuggingFace."""
def post_init_from_model(self, model):
if self.gates is None:
dtype = model.config.dtype
input_dim = model.config.hidden_size
ngroup = model.config.num_attention_heads // model.config.num_key_value_heads
nhead = model.config.num_key_value_heads
self.gates = []
for idx in range(model.config.num_hidden_layers):
module = FastKVzipGate(idx, input_dim, nhead, ngroup, dtype).to(model.device)
self.gates.append(module)
# contains all presses to be tested
# kwargs should be ordered from easy to hard compression
default_presses = [
{"cls": TestDuoAttentionPress, "kwargs": [{"head_compression_ratio": 0.2}, {"head_compression_ratio": 0.8}]},
{"cls": KnormPress, "kwargs": [{"compression_ratio": 0.2}, {"compression_ratio": 0.8}]},
{"cls": ExpectedAttentionPress, "kwargs": [{"compression_ratio": 0.2}, {"compression_ratio": 0.8}]},
{"cls": ExpectedAttentionStatsPress, "kwargs": [{"compression_ratio": 0.2}, {"compression_ratio": 0.8}]},
{"cls": RandomPress, "kwargs": [{"compression_ratio": 0.2}, {"compression_ratio": 0.8}]},
{"cls": StreamingLLMPress, "kwargs": [{"compression_ratio": 0.2}, {"compression_ratio": 0.8}]},
{"cls": QFilterPress, "kwargs": [{"compression_ratio": 0.2}, {"compression_ratio": 0.8}]},
{
"cls": SnapKVPress,
"kwargs": [{"compression_ratio": 0.2, "window_size": 2}, {"compression_ratio": 0.8, "window_size": 2}],
},
{"cls": TOVAPress, "kwargs": [{"compression_ratio": 0.2}, {"compression_ratio": 0.8}]},
{
"cls": ThinKPress,
"kwargs": [
{"key_channel_compression_ratio": 0.2, "window_size": 2},
{"key_channel_compression_ratio": 0.8, "window_size": 2},
],
},
{
"cls": SimLayerKVPress,
"kwargs": [
{"lazy_threshold": 0.8, "n_initial": 1, "n_recent": 1, "n_last": 1},
{"lazy_threshold": 0.2, "n_initial": 1, "n_recent": 1, "n_last": 1},
],
},
{
"cls": PyramidKVPress,
"kwargs": [{"compression_ratio": 0.2, "window_size": 2}, {"compression_ratio": 0.8, "window_size": 2}],
},
{
"cls": LagKVPress,
"kwargs": [
{"compression_ratio": 0.5, "n_sink": 16, "lag_size": 128},
{"compression_ratio": 0.8, "n_sink": 16, "lag_size": 128},
],
},
{"cls": KeyDiffPress, "kwargs": [{"compression_ratio": 0.2}, {"compression_ratio": 0.8}]},
{
"cls": KVzipPress,
"kwargs": [{"compression_ratio": 0.5, "layerwise": False}, {"compression_ratio": 0.8, "layerwise": True}],
},
{"cls": TestFastKVzipPress, "kwargs": [{"compression_ratio": 0.2}, {"compression_ratio": 0.8}]},
{"cls": CURPress, "kwargs": [{"compression_ratio": 0.2}, {"compression_ratio": 0.8}]},
{"cls": TestKVzapPress, "kwargs": [{"compression_ratio": 0.2}, {"compression_ratio": 0.8}]},
{
"cls": CompactorPress,
"kwargs": [
{
"compression_ratio": 0.5,
"sink_size_start": 1,
"sink_size_end": 1,
"chunk_size": 256,
},
{"compression_ratio": 0.8, "sink_size_start": 0, "sink_size_end": 0, "chunk_size": 256},
],
},
{
"cls": LeverageScorePress,
"kwargs": [
{"compression_ratio": 0.8, "sketch_dimension": 48},
],
},
{
"cls": NonCausalAttnPress,
"kwargs": [
{
"compression_ratio": 0.5,
"chunk_size": 256,
},
],
},
]
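A quick sanity check on this registry (a hypothetical sketch, not part of the suite) is to construct every press with its first, easiest kwargs before the heavier pipeline tests run:

```python
# Hypothetical sanity check: instantiate each registered press with its
# "easy" kwargs and report the class name. Not part of the test suite.
from tests.default_presses import default_presses

for entry in default_presses:
    press = entry["cls"](**entry["kwargs"][0])
    print(f"{type(press).__name__}: {entry['kwargs'][0]}")
```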
# SPDX-FileCopyrightText: Copyright (c) 1993-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
import pytest
import torch
from transformers import AutoModelForCausalLM, pipeline
def get_device():
"""Helper function that returns the appropriate device (GPU if available, otherwise CPU)"""
return "cuda:0" if torch.cuda.is_available() else "cpu"
@pytest.fixture(scope="session")
def unit_test_model():
model = AutoModelForCausalLM.from_pretrained("MaxJeblick/llama2-0b-unit-test").eval()
return model.to(get_device())
@pytest.fixture(scope="session")
def unit_test_model_output_attention():
model = AutoModelForCausalLM.from_pretrained("MaxJeblick/llama2-0b-unit-test", attn_implementation="eager").eval()
return model.to(get_device())
@pytest.fixture(scope="session")
def danube_500m_model():
model = AutoModelForCausalLM.from_pretrained("h2oai/h2o-danube3-500m-chat").eval()
return model.to(get_device())
@pytest.fixture(scope="session")
def kv_press_unit_test_pipeline():
return pipeline(
"kv-press-text-generation",
model="maxjeblick/llama2-0b-unit-test",
device=get_device(),
)
@pytest.fixture(scope="session")
def kv_press_danube_pipeline():
return pipeline(
"kv-press-text-generation",
model="h2oai/h2o-danube3-500m-chat",
device=get_device(),
)
@pytest.fixture(scope="session")
def kv_press_adaptive_pipeline():
"""Flexible pipeline that uses GPU+flash attention if available, otherwise CPU"""
device = get_device()
ckpt = "meta-llama/Llama-3.2-1B-Instruct"
# Use flash attention only if GPU is available
model_kwargs = {}
if torch.cuda.is_available():
model_kwargs["attn_implementation"] = "flash_attention_2"
pipe = pipeline(
"kv-press-text-generation",
model=ckpt,
device=device,
dtype="auto",
model_kwargs=model_kwargs,
)
return pipe
@pytest.fixture(scope="class")
def kv_press_llama3_1_flash_attn_pipeline():
device = "cuda:0"
ckpt = "meta-llama/Llama-3.1-8B-Instruct"
attn_implementation = "flash_attention_2"
pipe = pipeline(
"kv-press-text-generation",
model=ckpt,
device=device,
model_kwargs={"attn_implementation": attn_implementation, "dtype": torch.bfloat16},
)
return pipe
@pytest.fixture(scope="class")
def kv_press_llama3_2_flash_attn_pipeline():
device = "cuda:0"
ckpt = "meta-llama/Llama-3.2-1B-Instruct"
attn_implementation = "flash_attention_2"
pipe = pipeline(
"kv-press-text-generation",
model=ckpt,
device=device,
model_kwargs={"attn_implementation": attn_implementation, "dtype": torch.bfloat16},
)
return pipe
@pytest.fixture(scope="class")
def kv_press_qwen3_flash_attn_pipeline():
device = "cuda:0"
ckpt = "Qwen/Qwen3-4B-Instruct-2507"
attn_implementation = "flash_attention_2"
pipe = pipeline(
"kv-press-text-generation",
model=ckpt,
device=device,
model_kwargs={"attn_implementation": attn_implementation, "dtype": torch.bfloat16},
)
return pipe
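pytest injects these session-scoped fixtures by argument name; a minimal consumer (a hypothetical test, for illustration only) looks like this:

```python
# Minimal sketch of a test consuming one of the fixtures above. The
# import/noqa pattern follows the suite's own convention.
from tests.fixtures import kv_press_unit_test_pipeline  # noqa: F401


def test_pipeline_answers(kv_press_unit_test_pipeline):  # noqa: F811
    out = kv_press_unit_test_pipeline(
        "Paris is the capital of France.", question="What is the capital of France?"
    )
    assert isinstance(out["answer"], str)
```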
# SPDX-FileCopyrightText: Copyright (c) 1993-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
import datasets
import pytest
import torch
from transformers import DynamicCache, QuantizedCache
from transformers.utils import is_flash_attn_2_available, is_optimum_quanto_available
from kvpress import QFilterPress
from tests.default_presses import default_presses
from tests.fixtures import kv_press_llama3_2_flash_attn_pipeline, kv_press_qwen3_flash_attn_pipeline # noqa: F401
@pytest.fixture(scope="session")
def df_ruler():
df = datasets.load_dataset("simonjegou/ruler", "4096")["test"].to_pandas()
df = df.loc[df["task"] == "niah_multikey_1"].reset_index(drop=True)
return df
class TestRuler:
@pytest.mark.skipif(not torch.cuda.is_available(), reason="GPU is not available")
@pytest.mark.skipif(not is_flash_attn_2_available(), reason="flash_attn is not installed")
@pytest.mark.parametrize("press_dict", default_presses)
@pytest.mark.parametrize("cache", ["dynamic", "quantized"])
@pytest.mark.parametrize("compression_ratio", [0, 0.1])
def test_ruler_is_correct(
self, kv_press_qwen3_flash_attn_pipeline, df_ruler, press_dict, cache, compression_ratio # noqa: F811
):
cls = press_dict["cls"]
kwargs = press_dict["kwargs"][0]
press = cls(**kwargs)
if not hasattr(cls, "compression_ratio"):
pytest.skip(reason="Press does not support compression_ratio")
try:
# set compression ratio to a small value for testing
# we don't want to max out compression, but rather test if cache compression works
press.compression_ratio = compression_ratio
except AttributeError:
pass  # compression_ratio may be read-only (e.g. derived from other parameters)
if cache == "dynamic":
cache = DynamicCache()
elif cache == "quantized" and is_optimum_quanto_available():
cache = QuantizedCache(backend="quanto", config=kv_press_qwen3_flash_attn_pipeline.model.config, nbits=4)
elif cache == "quantized" and not is_optimum_quanto_available():
pytest.skip("Quanto is not installed")
else:
raise ValueError(f"Unknown cache type: {cache}")
idx = 6  # the Qwen model answers this sample correctly for all press configurations
context = df_ruler.iloc[idx]["context"]
question = df_ruler.iloc[idx]["question"]
true_answer = df_ruler.iloc[idx]["answer"][0]
if isinstance(press, QFilterPress):
# QFilterPress doesn't support Qwen3 4B. Will be tested in the next test class.
return
else:
pred_answer = kv_press_qwen3_flash_attn_pipeline(context, question=question, press=press, cache=cache)[
"answer"
]
assert true_answer in pred_answer
class TestRulerForQFilter:
@pytest.mark.skipif(not torch.cuda.is_available(), reason="GPU is not available")
@pytest.mark.skipif(not is_flash_attn_2_available(), reason="flash_attn is not installed")
@pytest.mark.parametrize("cache", ["dynamic", "quantized"])
@pytest.mark.parametrize("compression_ratio", [0, 0.1])
def test_ruler_is_correct_for_qfilter(
self, kv_press_llama3_2_flash_attn_pipeline, df_ruler, cache, compression_ratio # noqa: F811
):
cls = QFilterPress
kwargs = {"compression_ratio": 0.2}
press = cls(**kwargs)
if not hasattr(cls, "compression_ratio"):
pytest.skip(reason="Press does not support compression_ratio")
try:
# set compression ratio to a small value for testing
# we don't want to max out compression, but rather test if cache compression works
press.compression_ratio = compression_ratio
except AttributeError:
pass  # compression_ratio may be read-only (e.g. derived from other parameters)
if cache == "dynamic":
cache = DynamicCache()
elif cache == "quantized" and is_optimum_quanto_available():
cache = QuantizedCache(backend="quanto", config=kv_press_llama3_2_flash_attn_pipeline.model.config, nbits=4)
elif cache == "quantized" and not is_optimum_quanto_available():
pytest.skip("Quanto is not installed")
else:
raise ValueError(f"Unknown cache type: {cache}")
idx = 0
context = df_ruler.iloc[idx]["context"]
question = df_ruler.iloc[idx]["question"]
true_answer = df_ruler.iloc[idx]["answer"][0]
pred_answer = kv_press_llama3_2_flash_attn_pipeline(context, question=question, press=press, cache=cache)[
"answer"
]
assert true_answer in pred_answer
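Both test classes above repeat the same cache-construction branching; a hypothetical helper (a refactoring sketch, not part of the suite) could factor it out:

```python
import pytest
from transformers import DynamicCache, QuantizedCache
from transformers.utils import is_optimum_quanto_available


def make_cache(kind, model_config):
    """Hypothetical helper mirroring the branching duplicated in both
    Ruler test classes above."""
    if kind == "dynamic":
        return DynamicCache()
    if kind == "quantized":
        if not is_optimum_quanto_available():
            pytest.skip("Quanto is not installed")
        return QuantizedCache(backend="quanto", config=model_config, nbits=4)
    raise ValueError(f"Unknown cache type: {kind}")
```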
# SPDX-FileCopyrightText: Copyright (c) 1993-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
from dataclasses import dataclass
import torch
import torch.nn as nn
from transformers import DynamicCache
from kvpress.presses.block_press import BlockPress
from kvpress.presses.scorer_press import ScorerPress
from tests.fixtures import unit_test_model # noqa: F401
@dataclass
class HiddenStatesPress(ScorerPress): # dummy press using hidden states
def score(
self,
module: nn.Module,
hidden_states: torch.Tensor,
keys: torch.Tensor,
values: torch.Tensor,
attentions: torch.Tensor,
kwargs,
) -> torch.Tensor:
return hidden_states.mean(-1).unsqueeze(1).expand_as(keys.norm(dim=-1))
def test_block_press_is_streaming_top_k(unit_test_model): # noqa: F811
"""
Test that BlockPress correctly applies the compression ratio and keeps the output consistent.
"""
press = HiddenStatesPress(compression_ratio=0.5)
generator = torch.Generator().manual_seed(0)
input_ids = torch.randint(0, 1024, (1, 256), generator=generator).to(unit_test_model.device)
keys_hash = []
values_hash = []
for block_size in [2, 4, 8, 128, 256]:
composed_press = BlockPress(press=press, block_size=block_size)
with composed_press(unit_test_model):
cache = DynamicCache()
unit_test_model(input_ids, past_key_values=cache)
assert cache.get_seq_length() == 128
keys = torch.cat([cache.layers[layer_idx].keys for layer_idx in range(len(cache.layers))])
values = torch.cat([cache.layers[layer_idx].values for layer_idx in range(len(cache.layers))])
keys_hash.append(keys.sum().item())
values_hash.append(values.sum().item())
with press(unit_test_model):
cache = DynamicCache()
unit_test_model(input_ids, past_key_values=cache)
assert cache.get_seq_length() == 128
keys = torch.cat([cache.layers[layer_idx].keys for layer_idx in range(len(cache.layers))])
values = torch.cat([cache.layers[layer_idx].values for layer_idx in range(len(cache.layers))])
keys_hash.append(keys.sum().item())
values_hash.append(values.sum().item())
keys_tensor = torch.tensor(keys_hash)
values_tensor = torch.tensor(values_hash)
assert torch.allclose(keys_tensor, keys_tensor[-1])
assert torch.allclose(values_tensor, values_tensor[-1])
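The invariant behind this test, that block-by-block pruning matches a single global selection when per-token scores are fixed, can be illustrated in plain Python (a standalone sketch, not part of the suite):

```python
# Keeping a running top-k after each block yields the same set as one
# global top-k, because a globally top-k element can never be evicted:
# fewer than k higher-scoring elements exist in any subset.
import heapq

scores = [0.3, 0.9, 0.1, 0.7, 0.5, 0.8]
k = 3
global_topk = set(heapq.nlargest(k, scores))

running: list[float] = []
for start in range(0, len(scores), 2):  # process in blocks of 2
    running.extend(scores[start:start + 2])
    running = heapq.nlargest(k, running)

assert set(running) == global_topk
```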
# SPDX-FileCopyrightText: Copyright (c) 1993-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
import torch
from kvpress import CompactorPress, LeverageScorePress, NonCausalAttnPress
from tests.fixtures import unit_test_model # noqa: F401
def test_compactor_press(unit_test_model): # noqa: F811
for press in [
CompactorPress(0.5, sink_size_start=0, sink_size_end=0),
CompactorPress(0.2, sink_size_start=8, sink_size_end=4),
]:
with press(unit_test_model):
input_ids = torch.arange(10, 40).to(unit_test_model.device)
unit_test_model(input_ids.unsqueeze(0), use_cache=True)
def test_leverage_press(unit_test_model): # noqa: F811
for press in [
LeverageScorePress(0.5, sketch_dimension=48),
LeverageScorePress(0.5, sketch_dimension=64),
]:
with press(unit_test_model):
input_ids = torch.arange(10, 40).to(unit_test_model.device)
unit_test_model(input_ids.unsqueeze(0), use_cache=True)
def test_non_causal_attn_press(unit_test_model): # noqa: F811
for press in [
NonCausalAttnPress(0.5, chunk_size=128),
NonCausalAttnPress(0.5, chunk_size=256),
]:
with press(unit_test_model):
input_ids = torch.arange(10, 40).to(unit_test_model.device)
unit_test_model(input_ids.unsqueeze(0), use_cache=True)