Commit be42b641 authored by chenzk's avatar chenzk
Browse files

v1.0

parents
# SPDX-FileCopyrightText: Copyright (c) 1993-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright (c) 1993-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
import pandas as pd
from rouge import Rouge
scorer = Rouge()
def calculate_metrics(df: pd.DataFrame) -> list[dict]:
    """
    Score every row of the DataFrame with the module-level ROUGE scorer.

    Parameters
    ----------
    df : pd.DataFrame
        DataFrame with string columns "needle" and "predicted_answer". Both values are
        stripped of surrounding whitespace before scoring.

    Returns
    -------
    list[dict]
        One entry per row, in row order: the first element of the result of
        ``scorer.get_scores(needle, predicted_answer)``.
    """
    return [
        scorer.get_scores(row["needle"].strip(), row["predicted_answer"].strip())[0]
        for _, row in df.iterrows()
    ]
# SPDX-FileCopyrightText: Copyright (c) 1993-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
import logging
from typing import Optional
import pandas as pd
from transformers import PreTrainedTokenizer
logger = logging.getLogger(__name__)
def insert_needle_in_haystack(
    df: pd.DataFrame,
    tokenizer: PreTrainedTokenizer,
    max_context_length: int,
    needle_depth: int | list[int],
    context_wrapper: str = "This is a very long story book: <book> {context} </book>.",
    needle_text: Optional[str] = None,
    answer_prefix: Optional[str] = None,
    question_text: Optional[str] = None,
) -> pd.DataFrame:
    """
    Inserts the "needle" string into the "context" of the first row of the DataFrame at specified depths.
    A new row is created for each depth, and a new DataFrame made of these rows is returned.

    Only the first row of ``df`` is read; it provides the haystack and the default values of the
    needle, question, answer prefix and "max_new_tokens".

    Parameters
    ----------
    df : pd.DataFrame
        The input DataFrame containing at least the columns "context" and "max_new_tokens", plus
        "needle", "question" and "answer_prefix" when the matching argument is not provided.
    tokenizer : PreTrainedTokenizer
        The tokenizer used to encode and decode the context and needle.
    max_context_length : int
        The maximum allowed length (in tokens) for the context, including the needle.
    needle_depth : int | list[int]
        A percentage, or a list of percentages (0-100), indicating how deep into the context
        the needle should be inserted.
    context_wrapper : str
        Template applied to the decoded context; it must contain a "{context}" placeholder.
    needle_text : Optional[str]
        The text to insert as the needle. If None (or empty), the first row's "needle" column is used.
    answer_prefix : Optional[str]
        The prefix to add to the answer. If None (or empty), the first row's "answer_prefix" column is used.
    question_text : Optional[str]
        The text to insert as the question. If None (or empty), the first row's "question" column is used.

    Returns
    -------
    pd.DataFrame
        A DataFrame with one row per requested depth and the columns "context" (wrapped context
        including the needle), "needle", "needle_depth", "question", "answer_prefix" and
        "max_new_tokens" (copied from the first row of ``df``).

    Raises
    ------
    ValueError
        If ``max_context_length`` is too small to hold the needle plus the reserved overhead.
    """
    # Positional access (.iloc) so that the first row is used even when the index has no label 0,
    # e.g. after filtering or shuffling the DataFrame
    original_context = df["context"].iloc[0]
    needle_text = needle_text or df["needle"].iloc[0]
    question_text = question_text or df["question"].iloc[0]
    answer_prefix = answer_prefix or df["answer_prefix"].iloc[0]
    max_new_tokens = df["max_new_tokens"].iloc[0]

    logger.info("Preparing dataset for inference. Needle: %s", needle_text)
    tokenized_needle = tokenizer.encode(needle_text, add_special_tokens=False)

    # Account for system prompts and other overhead
    context_length_limit = max_context_length - len(tokenized_needle) - 150
    if context_length_limit < 0:
        # A negative bound would make the slice below drop tokens from the END of the context
        # instead of capping its length
        raise ValueError(
            f"max_context_length={max_context_length} is too small: the needle "
            f"({len(tokenized_needle)} tokens) plus 150 tokens of overhead do not fit"
        )

    # Tokenize the original context once and truncate it to the available budget
    tokenized_context = tokenizer.encode(original_context, add_special_tokens=False)[:context_length_limit]

    depths = [needle_depth] if isinstance(needle_depth, int) else needle_depth
    new_rows = []
    for depth in depths:
        # Token position of the needle for this depth (depth is a percentage of the context length)
        needle_index = int(len(tokenized_context) * depth / 100)
        # Build a new token sequence with the needle spliced in, then go back to text
        new_tokenized_context = tokenized_context[:needle_index] + tokenized_needle + tokenized_context[needle_index:]
        decoded_context = tokenizer.decode(new_tokenized_context, skip_special_tokens=True)
        new_rows.append(
            {
                "context": context_wrapper.format(context=decoded_context),
                "needle": needle_text,
                "needle_depth": depth,
                "question": question_text,
                "answer_prefix": answer_prefix,
                "max_new_tokens": max_new_tokens,
            }
        )

    # Create the new DataFrame from the list of rows
    return pd.DataFrame(new_rows)
# RULER dataset
[RULER](https://arxiv.org/abs/2404.06654) generates synthetic examples to evaluate long-context language models with configurable sequence length (from 4k tokens to 128k tokens) and task complexity. It contains a set of 13 tasks grouped in 4 categories (needle in the haystack, question answering, multi-hop tracing and aggregation).
## Hugging Face dataset
The Hugging Face dataset for RULER can be found [here](https://huggingface.co/datasets/simonjegou/ruler). To reproduce this dataset,
1. Install the [RULER repository](https://github.com/hsiehjackson/RULER) and download the necessary data files (see 1. Download data in the README)
2. Copy the `generate.sh` script from this repository to `$RULER/scripts`, set the `DATA_DIR` variable to the desired location of the RULER data files, and run the script
3. Run `create_huggingface_dataset.py` with the correct data_dir and repo_id variables
Note: by default we use `meta-llama/Meta-Llama-3.1-8B` as the tokenizer, whereas in the original RULER paper the tokenizer depends on the model being evaluated. Results may therefore not be directly comparable to the original RULER benchmark. However, since our focus is to evaluate the performance of a given model across different compression ratios, we believe this simplification is acceptable.
\ No newline at end of file
# SPDX-FileCopyrightText: Copyright (c) 1993-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright (c) 1993-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
import re
import pandas as pd
def string_match_part(preds, refs):
    """
    Percentage of predictions that contain at least one of their reference strings.

    Matching is a case-insensitive substring test. ``preds`` is a sequence of strings and
    ``refs`` a sequence of lists of strings, paired position by position.

    Returns
    -------
    float
        Score between 0 and 100, rounded to 2 decimals.
    """
    hits = 0.0
    for pred, ref in zip(preds, refs):
        # 1.0 as soon as any reference occurs in the prediction, else 0.0
        hits += max(1.0 if r.lower() in pred.lower() else 0.0 for r in ref)
    return round(hits / len(preds) * 100, 2)
def string_match_all(preds, refs):
    """
    Average fraction of reference strings found in each prediction, as a percentage.

    Matching is a case-insensitive substring test. ``preds`` is a sequence of strings and
    ``refs`` a sequence of lists of strings, paired position by position.

    Returns
    -------
    float
        Score between 0 and 100, rounded to 2 decimals.
    """
    per_sample = []
    for pred, ref in zip(preds, refs):
        found = sum(1.0 if r.lower() in pred.lower() else 0.0 for r in ref)
        # Fraction of this sample's references that appear in the prediction
        per_sample.append(found / len(ref))
    return round(sum(per_sample) / len(preds) * 100, 2)
def calculate_metrics(df: pd.DataFrame) -> dict:
    """
    Compute the string-match score of every task present in the DataFrame.

    The "predicted_answer" column is cleaned IN PLACE: surrounding whitespace is stripped and
    control characters (\\x00-\\x1f) are removed. Tasks whose name starts with "qa" are scored
    with ``string_match_part``, all other tasks with ``string_match_all``.

    Parameters
    ----------
    df : pd.DataFrame
        DataFrame with the columns "predicted_answer", "answer" and "task".

    Returns
    -------
    dict
        ``{task: {"string_match": score}}`` for every task found in ``df``.
    """
    control_chars = re.compile(r"[\x00-\x1f]")

    def _clean(text):
        return control_chars.sub("", text.strip()).strip()

    df["predicted_answer"] = df["predicted_answer"].apply(_clean)

    scores = {}
    for task, df_task in df.groupby("task"):
        # The task family is the part of the task name before the first underscore
        is_qa = task.split("_")[0] == "qa"
        metric_fn = string_match_part if is_qa else string_match_all
        predictions = df_task["predicted_answer"].tolist()
        references = df_task["answer"].tolist()
        scores[task] = {"string_match": metric_fn(predictions, references)}
    return scores
# SPDX-FileCopyrightText: Copyright (c) 1993-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
import re
from pathlib import Path
import pandas as pd
from datasets import Dataset
# Source: https://github.com/hsiehjackson/RULER/blob/main/scripts/data/synthetic/constants.py
# Regexes marking where the question starts inside a raw RULER "input" string.
# Keys are task-family prefixes, i.e. the part of the task name before the first "_"
# (get_dataframe looks them up with task.split("_")[0]); the LAST match in the input is used.
QUESTION_PATTERNS = {
"niah": re.compile(r"What (?:is|are all) the special magic"),
"vt": re.compile(r"Question: Find all variables that are assigned the value"),
"cwe": re.compile(r"Question: What are the 10 most common words in the above list\?"),
"fwe": re.compile(r"Question: Do not provide any explanation\."),
"qa": re.compile(r"Answer the question based on the given documents\."),
}
# Regexes marking where the answer prefix starts; searched in the text that follows the question start.
# Same keys as QUESTION_PATTERNS.
ANSWER_PATTERNS = {
"niah": re.compile(r"The special magic"),
"vt": re.compile(r"Answer:"),
"cwe": re.compile(r"Answer:"),
"fwe": re.compile(r"Answer:"),
"qa": re.compile(r"Answer:"),
}
# Source: https://github.com/hsiehjackson/RULER/blob/main/scripts/data/synthetic/constants.py
# Value written to the "max_new_tokens" column of every sample, per task-family prefix.
MAX_NEW_TOKENS = {
"niah": 128,
"vt": 30,
"cwe": 120,
"fwe": 50,
"qa": 32,
}
def _split_input(text, question_pattern, answer_pattern):
    """Split a raw RULER input into its (context, question, answer_prefix) parts."""
    # The question starts at the LAST match of the question pattern
    question_start = list(question_pattern.finditer(text))[-1].start()
    context, question_and_answer = text[:question_start], text[question_start:]
    # The answer prefix starts at the first match of the answer pattern after the question start
    answer_start = answer_pattern.search(question_and_answer).start()
    return context, question_and_answer[:answer_start], question_and_answer[answer_start:]


def get_dataframe(path):
    """
    Parse the data from the provided path and return a DataFrame with the context, question, answers and task

    Parameters
    ----------
    path : str or Path
        Directory whose last component is the context length (e.g. ".../4096"). Every "*.jsonl" file
        found below it is loaded; the name of the file's parent directory is used as the task name.

    Returns
    -------
    pd.DataFrame
        DataFrame with the columns "context", "question", "answer_prefix", "answer" (renamed from
        "outputs"), "task" and "max_new_tokens", with a fresh 0..n-1 index.

    Raises
    ------
    AssertionError
        If the path does not end with a numeric component.
    ValueError
        If no .jsonl file is found under the path.
    KeyError
        If a task prefix has no entry in QUESTION_PATTERNS, ANSWER_PATTERNS or MAX_NEW_TOKENS.
    """
    assert re.match(r".*\/\d+$", str(path)), "The path must end with the context length (e.g. with /4096)"

    df_list = []
    for task_path in Path(path).glob("**/*.jsonl"):
        # Load dataframe
        df = pd.read_json(task_path, lines=True)
        task = task_path.parent.stem
        # The patterns and generation budget are shared by all tasks with the same prefix (e.g. "niah")
        task_family = task.split("_")[0]
        question_pattern = QUESTION_PATTERNS[task_family]
        answer_pattern = ANSWER_PATTERNS[task_family]

        # Split the context and the question based on the pattern
        parts = [_split_input(text, question_pattern, answer_pattern) for text in df["input"]]
        df["context"], df["question"], df["answer_prefix"] = zip(*parts)
        df["task"] = task
        df["max_new_tokens"] = MAX_NEW_TOKENS[task_family]
        df_list.append(df)

    if not df_list:
        # pd.concat would raise an opaque "No objects to concatenate" ValueError here
        raise ValueError(f"No .jsonl files found under {path}")

    # Concatenate all the dataframes
    df = pd.concat(df_list)
    df = df[["context", "question", "answer_prefix", "outputs", "task", "max_new_tokens"]]
    df = df.rename(columns={"outputs": "answer"}).reset_index(drop=True)
    return df
if __name__ == "__main__":
    # Root of the generated data: one sub-directory per context length is processed below
    data_dir = Path("/mnt/workspace/projects/RULER/scripts/data/data/")  # output of the generate.sh script
    repo_id = "simonjegou/ruler"

    # Each context length becomes one configuration of the Hugging Face dataset
    for length_dir in data_dir.glob("*/"):
        config_name = length_dir.stem
        print(f"Processing context length {config_name}")
        hub_dataset = Dataset.from_pandas(get_dataframe(length_dir))
        hub_dataset.push_to_hub(repo_id=repo_id, config_name=config_name, split="test")
# The following script prepares the synthetic data benchmark for a given Hugging Face tokenizer and without template
# Before running this script, make sure you downloaded the data as explained in the README:
# cd scripts/data/synthetic/json/
# python download_paulgraham_essay.py
# bash download_qa_dataset.sh

# Root directory of the generated data; one sub-directory is created per sequence length
DATA_DIR="data/data"
TOKENIZER_PATH="meta-llama/Meta-Llama-3.1-8B"

SEQ_LENGTHS=(
    4096
    8192
    16384
)

TASKS=(
    "niah_single_1"
    "niah_single_2"
    "niah_single_3"
    "niah_multikey_1"
    "niah_multikey_2"
    "niah_multikey_3"
    "niah_multivalue"
    "niah_multiquery"
    "vt"
    "cwe"
    "fwe"
    "qa_1"
    "qa_2"
)

for MAX_SEQ_LENGTH in "${SEQ_LENGTHS[@]}"; do
    SAVE_DIR="${DATA_DIR}/${MAX_SEQ_LENGTH}"
    for TASK in "${TASKS[@]}"; do
        # Every expansion is quoted so that a DATA_DIR or TOKENIZER_PATH containing spaces
        # is still passed to prepare.py as a single argument
        python data/prepare.py \
            --save_dir "${SAVE_DIR}" \
            --benchmark synthetic \
            --task "${TASK}" \
            --tokenizer_path "${TOKENIZER_PATH}" \
            --tokenizer_type hf \
            --max_seq_length "${MAX_SEQ_LENGTH}" \
            --model_template_type base \
            --num_samples 500
    done
done
\ No newline at end of file
# Zero Scrolls dataset
[Zero scrolls](https://www.zero.scrolls-benchmark.com/) includes ten natural language tasks across multiple domains, including summarization, question answering, aggregated sentiment classification and information reordering.
## Hugging Face dataset
The Hugging Face dataset for Zero Scrolls can be found [here](https://huggingface.co/datasets/simonjegou/zero_scroll). To reproduce this dataset, simply run the `create_huggingface_dataset.py` script.
## Evaluation
The answers are not provided in the dataset; you will need to submit your predictions to the [Zero Scrolls](https://www.zero.scrolls-benchmark.com/) website to get the results.
# SPDX-FileCopyrightText: Copyright (c) 1993-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright (c) 1993-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
def calculate_metrics(df):
    """
    Placeholder scorer: no metric is computed.

    The argument is ignored and a new empty dict is returned on every call.
    """
    return dict()
# SPDX-FileCopyrightText: Copyright (c) 1993-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
import pandas as pd
from datasets import Dataset, load_dataset
# Generation budget stored in the "max_new_tokens" column, per Zero Scrolls task.
# The keys are also the configuration names loaded from "tau/zero_scrolls".
MAX_NEW_TOKENS = {
    "gov_report": 1024,
    "summ_screen_fd": 512,
    "qmsum": 512,
    "qasper": 128,
    "narrative_qa": 64,
    "quality": 10,
    "musique": 32,
    "squality": 512,
    "space_digest": 36,
    "book_sum_sort": 256,
}


def _context_part(row):
    """Text of "input" up to document_end_index."""
    return row["input"][: row["document_end_index"]]


def _question_part(row):
    """Text of "input" between document_end_index and query_end_index."""
    return row["input"][row["document_end_index"] : row["query_end_index"]]


def _answer_prefix_part(row):
    """Text of "input" from query_end_index to the end."""
    return row["input"][row["query_end_index"] :]


df_list = []
for task, max_new_tokens in MAX_NEW_TOKENS.items():
    df = load_dataset("tau/zero_scrolls", task, split="test").to_pandas()
    # Cut the raw input into its three consecutive parts using the offsets shipped with the dataset
    df["context"] = df.apply(_context_part, axis=1)
    df["question"] = df.apply(_question_part, axis=1)
    df["answer_prefix"] = df.apply(_answer_prefix_part, axis=1).str.strip()
    # The answer column is filled with empty strings
    df["answer"] = ""
    df["task"] = task
    df["max_new_tokens"] = max_new_tokens
    df_list.append(df)

# Merge all tasks and publish them as a single "test" split
df = pd.concat(df_list)
dataset = Dataset.from_pandas(df)
dataset.push_to_hub(repo_id="zero_scrolls", split="test")
This diff is collapsed.
dataset="ruler"
data_dir="4096"
model="meta-llama/Meta-Llama-3.1-8B-Instruct"
compression_ratios=(0.1 0.25 0.5)
press_names=("expected_attention" "knorm" "streaming_llm" "snapkv")

# Each press gets its own GPU, so there must be at least as many GPUs as press names
num_gpus=$(nvidia-smi --list-gpus | wc -l)
num_presses=${#press_names[@]}
if (( num_presses > num_gpus )); then
    echo "Error: The number of press names (${num_presses}) exceeds the number of available GPUs ($num_gpus)"
    exit 1
fi

# Evaluate one press on one GPU, going through the compression ratios one after the other
# $1: press name, $2: GPU index
run_press() {
    local press=$1 gpu=$2 compression_ratio
    for compression_ratio in "${compression_ratios[@]}"; do
        echo "Running press_name: $press with compression_ratio: $compression_ratio on GPU cuda:$gpu"
        python evaluate.py --dataset $dataset --data_dir $data_dir --model $model --press_name $press --compression_ratio $compression_ratio --device "cuda:$gpu"
    done
}

# Start one background job per press; the array index of the press is its GPU index
for gpu_id in "${!press_names[@]}"; do
    run_press "${press_names[$gpu_id]}" "$gpu_id" &
done

# Wait for all background jobs to finish
wait
echo "All evaluations completed."
# SPDX-FileCopyrightText: Copyright (c) 1993-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
output_dir: "./results"
model: "meta-llama/Meta-Llama-3.1-8B-Instruct"
dataset: "ruler" # see DATASET_REGISTRY in evaluate_registry.py
data_dir: "4096" # Subdirectory of the dataset (if applicable) else leave "null"
press_name: "knorm" # see PRESS_REGISTRY in evaluate_registry.py
compression_ratio: 0.5 # Compression ratio for the press (0.0 to 1.0)
key_channel_compression_ratio: null # For ThinKPress and ComposedPress (0.0 to 1.0)
threshold: null # For DMSPress
fraction: 1.0 # Fraction of dataset to evaluate (0.0 to 1.0), for quick testing
max_new_tokens: null # Maximum new tokens to generate (null = use dataset default)
max_context_length: null # Maximum context length (null = use model maximum)
query_aware: false # Whether to include question in context for query-aware compression
needle_depth: null # Depth (int or list of ints) percentage of the needle in the haystack (0 to 100), only for needle_in_haystack dataset
device: null # Device to use (null = auto-detect, "cuda:0", "cpu", etc.)
fp8: false # Whether to use FP8 quantization (FineGrainedFP8Config() from transformers)
# You can add any model kwargs here.
model_kwargs:
attn_implementation: null
dtype: "auto"
# SPDX-FileCopyrightText: Copyright (c) 1993-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
from benchmarks.aime25.calculate_metrics import calculate_metrics as aime25_scorer
from benchmarks.infinite_bench.calculate_metrics import calculate_metrics as infinite_bench_scorer
from benchmarks.longbench.calculate_metrics import calculate_metrics as longbench_scorer
from benchmarks.longbench.calculate_metrics import calculate_metrics_e as longbench_scorer_e
from benchmarks.longbenchv2.calculate_metrics import calculate_metrics as longbenchv2_scorer
from benchmarks.loogle.calculate_metrics import calculate_metrics as loogle_scorer
from benchmarks.math500.calculate_metrics import calculate_metrics as math500_scorer
from benchmarks.needle_in_haystack.calculate_metrics import calculate_metrics as needle_in_haystack_scorer
from benchmarks.ruler.calculate_metrics import calculate_metrics as ruler_scorer
from benchmarks.zero_scrolls.calculate_metrics import calculate_metrics as zero_scrolls_scorer
from kvpress import (
AdaKVPress,
BlockPress,
ChunkKVPress,
CompactorPress,
ComposedPress,
CriticalAdaKVPress,
CriticalKVPress,
CURPress,
DecodingPress,
DMSPress,
DuoAttentionPress,
ExpectedAttentionPress,
FastKVzipPress,
FinchPress,
KeyDiffPress,
KnormPress,
KVzapPress,
KVzipPress,
LagKVPress,
ObservedAttentionPress,
PyramidKVPress,
QFilterPress,
RandomPress,
SnapKVPress,
StreamingLLMPress,
ThinKPress,
TOVAPress,
)
# These dictionaries define the available datasets, scorers, and KVPress methods for evaluation.
# DATASET_REGISTRY maps a dataset name to a dataset repository identifier
# (presumably a Hugging Face Hub id resolved by the evaluation script - confirm in evaluate.py).
# "longbench" and "longbench-e" point to the same repository.
DATASET_REGISTRY = {
"loogle": "simonjegou/loogle",
"ruler": "simonjegou/ruler",
"zero_scrolls": "simonjegou/zero_scrolls",
"infinitebench": "MaxJeblick/InfiniteBench",
"longbench": "Xnhyacinth/LongBench",
"longbench-e": "Xnhyacinth/LongBench",
"longbench-v2": "simonjegou/LongBench-v2",
"needle_in_haystack": "alessiodevoto/paul_graham_essays",
# Datasets used for decoding compression
"aime25": "alessiodevoto/aime25",
"math500": "alessiodevoto/math500",
}
# Maps a dataset name to the calculate_metrics function imported from the matching benchmark package.
# The keys are the same as those of DATASET_REGISTRY; "longbench-e" uses calculate_metrics_e
# from the longbench package.
SCORER_REGISTRY = {
"loogle": loogle_scorer,
"ruler": ruler_scorer,
"zero_scrolls": zero_scrolls_scorer,
"infinitebench": infinite_bench_scorer,
"longbench": longbench_scorer,
"longbench-e": longbench_scorer_e,
"longbench-v2": longbenchv2_scorer,
"needle_in_haystack": needle_in_haystack_scorer,
"aime25": aime25_scorer,
"math500": math500_scorer,
}
# Maps a press name to a pre-configured press object.
# The values are call expressions, so every press is instantiated once, when this module is imported,
# and the same instance is returned on every lookup.
# "no_press" maps to None (presumably the uncompressed baseline - confirm in evaluate.py).
# Names starting with "decoding_" wrap a base press in DecodingPress.
PRESS_REGISTRY = {
"adakv_snapkv": AdaKVPress(SnapKVPress()),
"block_keydiff": BlockPress(press=KeyDiffPress(), block_size=128),
"chunkkv": ChunkKVPress(press=SnapKVPress(), chunk_length=20),
"critical_adakv_expected_attention": CriticalAdaKVPress(ExpectedAttentionPress(use_vnorm=False)),
"critical_adakv_snapkv": CriticalAdaKVPress(SnapKVPress()),
"critical_expected_attention": CriticalKVPress(ExpectedAttentionPress(use_vnorm=False)),
"critical_snapkv": CriticalKVPress(SnapKVPress()),
"cur": CURPress(),
"duo_attention": DuoAttentionPress(),
"duo_attention_on_the_fly": DuoAttentionPress(on_the_fly_scoring=True),
# NOTE: "expected_attention" is ExpectedAttentionPress wrapped in AdaKVPress, not the bare press
"expected_attention": AdaKVPress(ExpectedAttentionPress(epsilon=1e-2)),
"fastkvzip": FastKVzipPress(),
"finch": FinchPress(),
"keydiff": KeyDiffPress(),
"kvzip": KVzipPress(),
"kvzip_plus": KVzipPress(kvzip_plus_normalization=True),
"kvzap_linear": DMSPress(press=KVzapPress(model_type="linear")),
"kvzap_mlp": DMSPress(press=KVzapPress(model_type="mlp")),
"kvzap_mlp_head": KVzapPress(model_type="mlp"),
"kvzap_mlp_layer": AdaKVPress(KVzapPress(model_type="mlp")),
"lagkv": LagKVPress(),
"knorm": KnormPress(),
"observed_attention": ObservedAttentionPress(),
"pyramidkv": PyramidKVPress(),
"qfilter": QFilterPress(),
"random": RandomPress(),
"snap_think": ComposedPress([SnapKVPress(), ThinKPress()]),
"snapkv": SnapKVPress(),
"streaming_llm": StreamingLLMPress(),
"think": ThinKPress(),
"tova": TOVAPress(),
"compactor": CompactorPress(),
"adakv_compactor": AdaKVPress(CompactorPress()),
"no_press": None,
"decoding_knorm": DecodingPress(base_press=KnormPress()),
"decoding_streaming_llm": DecodingPress(base_press=StreamingLLMPress()),
"decoding_tova": DecodingPress(base_press=TOVAPress()),
"decoding_qfilter": DecodingPress(base_press=QFilterPress()),
"decoding_adakv_expected_attention_e2": DecodingPress(base_press=AdaKVPress(ExpectedAttentionPress(epsilon=1e-2))),
"decoding_adakv_snapkv": DecodingPress(base_press=AdaKVPress(SnapKVPress())),
"decoding_keydiff": DecodingPress(base_press=KeyDiffPress()),
}
# SPDX-FileCopyrightText: Copyright (c) 1993-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
# Script to run the leaderboard evaluation on 4 GPUs
dataset="ruler"
data_dir="4096"
model="Qwen/Qwen3-8B"
output_dir="./results_lb"

# Settings evaluated in parallel: entry k of each array runs on GPU cuda:k
compression_ratios=(0.25 0.50 0.75 0.875)
# Use -3, -4, -5, -6 for Qwen3-8B and -6, -7, -8, -9 for Llama-3.1-8B-Instruct
thresholds=(-3 -4 -5 -6)

# Run evaluate.py once with the settings shared by all runs.
# $1: press name, $2: option name, $3: option value, $4: GPU index, remaining arguments: appended as-is
run_eval() {
    local press=$1 option=$2 value=$3 gpu=$4
    shift 4
    python evaluate.py --dataset $dataset --data_dir $data_dir --model $model --press_name $press $option $value --output_dir $output_dir --device "cuda:$gpu" "$@"
}

# no_press run with compression ratio 0.00, executed in the foreground before the loops
run_eval no_press --compression_ratio 0.00 0

# Loop 1: presses not requiring to include the questions in the compression
press_names=("random" "knorm" "snapkv" "expected_attention" "streaming_llm" "tova" "observed_attention" "qfilter" "pyramidkv" "lagkv" "keydiff" "adakv_compactor" "cur" "duo_attention" "duo_attention_on_the_fly" "kvzip")
for press in "${press_names[@]}"; do
    for gpu in "${!compression_ratios[@]}"; do
        run_eval "$press" --compression_ratio "${compression_ratios[$gpu]}" "$gpu" &
    done
    # All four ratios of a press must finish before the next press starts
    wait
done

# Threshold-based presses: one threshold per GPU
for press in "kvzap_linear" "kvzap_mlp"; do
    for gpu in "${!thresholds[@]}"; do
        run_eval "$press" --threshold "${thresholds[$gpu]}" "$gpu" &
    done
    wait
done

# Loop 2: presses requiring to compress questions
press_names=("snapkv" "adakv_snapkv" "finch" "chunkkv")
for press in "${press_names[@]}"; do
    for gpu in "${!compression_ratios[@]}"; do
        run_eval "$press" --compression_ratio "${compression_ratios[$gpu]}" "$gpu" --query_aware &
    done
    wait
done
# SPDX-FileCopyrightText: Copyright (c) 1993-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
from kvpress.attention_patch import patch_attention_functions
from kvpress.pipeline import KVPressTextGenerationPipeline
from kvpress.presses.adakv_press import AdaKVPress
from kvpress.presses.base_press import SUPPORTED_MODELS, BasePress
from kvpress.presses.block_press import BlockPress
from kvpress.presses.chunk_press import ChunkPress
from kvpress.presses.chunkkv_press import ChunkKVPress
from kvpress.presses.compactor_press import CompactorPress
from kvpress.presses.composed_press import ComposedPress
from kvpress.presses.criticalkv_press import CriticalAdaKVPress, CriticalKVPress
from kvpress.presses.cur_press import CURPress
from kvpress.presses.decoding_press import DecodingPress
from kvpress.presses.dms_press import DMSPress
from kvpress.presses.duo_attention_press import DuoAttentionPress
from kvpress.presses.expected_attention_press import ExpectedAttentionPress
from kvpress.presses.expected_attention_with_stats import ExpectedAttentionStatsPress
from kvpress.presses.fastkvzip_press import FastKVzipPress
from kvpress.presses.finch_press import FinchPress
from kvpress.presses.key_rerotation_press import KeyRerotationPress
from kvpress.presses.keydiff_press import KeyDiffPress
from kvpress.presses.knorm_press import KnormPress
from kvpress.presses.kvzap_press import KVzapPress
from kvpress.presses.kvzip_press import KVzipPress
from kvpress.presses.lagkv_press import LagKVPress
from kvpress.presses.leverage_press import LeverageScorePress
from kvpress.presses.non_causal_attention_press import NonCausalAttnPress
from kvpress.presses.observed_attention_press import ObservedAttentionPress
from kvpress.presses.per_layer_compression_press import PerLayerCompressionPress
from kvpress.presses.prefill_decoding_press import PrefillDecodingPress
from kvpress.presses.pyramidkv_press import PyramidKVPress
from kvpress.presses.qfilter_press import QFilterPress
from kvpress.presses.random_press import RandomPress
from kvpress.presses.scorer_press import ScorerPress
from kvpress.presses.simlayerkv_press import SimLayerKVPress
from kvpress.presses.snapkv_press import SnapKVPress
from kvpress.presses.streaming_llm_press import StreamingLLMPress
from kvpress.presses.think_press import ThinKPress
from kvpress.presses.tova_press import TOVAPress
# Patch the attention functions to support head-wise compression
# This call runs as a side effect of importing the package.
patch_attention_functions()
# Public API of the package: the names exported by `from kvpress import *`.
# Every entry is imported above; SUPPORTED_MODELS and patch_attention_functions are imported but not exported.
__all__ = [
"CriticalAdaKVPress",
"CriticalKVPress",
"CURPress",
"AdaKVPress",
"BasePress",
"ComposedPress",
"ScorerPress",
"ExpectedAttentionPress",
"KnormPress",
"ObservedAttentionPress",
"RandomPress",
"SimLayerKVPress",
"SnapKVPress",
"StreamingLLMPress",
"ThinKPress",
"TOVAPress",
"KVPressTextGenerationPipeline",
"PerLayerCompressionPress",
"KeyRerotationPress",
"ChunkPress",
"DuoAttentionPress",
"ChunkKVPress",
"QFilterPress",
"PyramidKVPress",
"FinchPress",
"LagKVPress",
"BlockPress",
"KeyDiffPress",
"KVzipPress",
"ExpectedAttentionStatsPress",
"DecodingPress",
"PrefillDecodingPress",
"CompactorPress",
"LeverageScorePress",
"NonCausalAttnPress",
"KVzapPress",
"DMSPress",
"FastKVzipPress",
]
# SPDX-FileCopyrightText: Copyright (c) 1993-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
import torch
from transformers.modeling_utils import ALL_ATTENTION_FUNCTIONS
def search_hyperplane(X, max_iter: int = 1000):
    """
    Search, for each batch element, a direction Y whose dot product with every vector of X is
    strictly positive, and return it flipped and rescaled as -1e5 * Y / ||Y|| ** 2.

    The returned vector k therefore has a large negative dot product with every row of X,
    which makes exp(<x, k>) numerically zero.

    Parameters
    ----------
    X : torch.Tensor
        Tensor of shape (batch_size, seq_len, head_dim) holding the query vectors.
    max_iter : int, default=1000
        Maximum number of correction steps before giving up.

    Returns
    -------
    torch.Tensor
        Tensor of shape (batch_size, head_dim).

    Raises
    ------
    ValueError
        If no valid direction is found within max_iter iterations.
    """
    # Start from the mean vector: this initialization is enough for most cases
    direction = X.mean(1)
    for _step in range(max_iter):
        # (batch_size, seq_len, 1) flags of the vectors not yet on the positive side
        violating = torch.bmm(X, direction.unsqueeze(-1)) <= 0
        if not violating.any():
            squared_norm = direction.norm(dim=-1, keepdim=True) ** 2
            return -1e5 * direction / squared_norm
        # Move towards the average of the violating vectors (clamp avoids dividing by zero
        # for batch elements that have no violation)
        correction = (X * violating).sum(1) / violating.sum(1).clamp(min=1)
        direction = direction + correction
    raise ValueError("Could not find fake keys such that for every query q, exp(<q, k>) = 0")
def attention_patch(func):
"""
Decorator to update the keys before the attention computation at the indices provided in module.masked_key_indices
The keys are updated with a fake key k such that exp(<q, k>) = 0 to fake head-wise compression
This solution is not optimal as it does not reduce peak memory and slightly increases runtime
Parameters
----------
func : callable
The original attention function to be patched. Should accept parameters
(module, query, key, value, attention_mask, dropout, **kwargs).
Returns
-------
callable
The wrapped attention function that supports head-wise key masking.
Notes
-----
The wrapper reads and writes the attribute ``module.masked_key_indices``: it is reset to None when
query and key have the same size along dimension 2, and otherwise expected to be either None/absent
or a tuple (batch_indices, head_indices, seq_indices). The ``key`` tensor is modified in place.
"""
def wrapper(module, query, key, value, attention_mask, dropout, **kwargs):
# Same size along dim 2 for query and key: treated as the prefilling pass, previous indices are discarded
if query.shape[2] == key.shape[2]:
# Prefilling
module.masked_key_indices = None
# getattr with a default: modules that never received masked_key_indices skip the patch entirely
elif getattr(module, "masked_key_indices", None) is not None:
# Decoding: build fake keys k s.t. exp(<q, k>) = 0
bsz, num_heads, seq_len, head_dim = query.shape
num_key_value_heads = key.shape[1]
num_groups = num_heads // num_key_value_heads
# Build a fake key k per key group such that for every query q, exp(<q, k>) = 0
# The query heads are regrouped per key/value head, then flattened to
# (bsz * num_key_value_heads, num_groups * seq_len, head_dim) for search_hyperplane
q = query.view(bsz, num_key_value_heads, num_groups, seq_len, head_dim)
q = q.reshape(bsz * num_key_value_heads, num_groups * seq_len, head_dim)
k = search_hyperplane(q)
k = k.view(bsz, num_key_value_heads, head_dim)
# At indices, update the keys to the fake keys
# (in-place write into the key tensor received by the wrapper)
batch_indices, head_indices, seq_indices = module.masked_key_indices
key[batch_indices, head_indices, seq_indices] = k[batch_indices, head_indices]
# see https://github.com/NVIDIA/kvpress/pull/115#issuecomment-3183785597
# cu_seq_lens_k are only in kwargs if model.generate is used.
# NOTE(review): indentation is not visible in this view; this update presumably runs on every call
# and not only in the decoding branch - confirm against the original file.
# The last entry of cu_seq_lens_k is overwritten in place with the current key length.
if "cu_seq_lens_k" in kwargs:
kwargs["cu_seq_lens_k"][-1] = key.shape[-2]
return func(module, query, key, value, attention_mask, dropout, **kwargs)
return wrapper
def patch_attention_functions():
    """
    Wrap every attention function registered in transformers with ``attention_patch``.

    Each entry of ALL_ATTENTION_FUNCTIONS is replaced, under the same name, by its patched
    version so that head-wise key masking (used for instance by AdaKV) is applied before
    the attention computation.

    Notes
    -----
    The registry is modified globally. Modules without a ``masked_key_indices`` attribute
    are not affected by the key replacement performed by the patched functions.
    Calling this function again wraps the already patched functions a second time.
    """
    # Snapshot the registry first, then overwrite each entry under its own name
    registered = list(ALL_ATTENTION_FUNCTIONS.items())
    for attn_name, attn_func in registered:
        ALL_ATTENTION_FUNCTIONS[attn_name] = attention_patch(attn_func)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment