Commit be42b641 authored by chenzk's avatar chenzk
Browse files

v1.0

parents
# SPDX-FileCopyrightText: Copyright (c) 1993-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright (c) 1993-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
import pandas as pd
from rouge import Rouge
scorer = Rouge()
def calculate_metrics(df: pd.DataFrame) -> list[dict]:
scores = []
for index, row in df.iterrows():
score = scorer.get_scores(row["needle"].strip(), row["predicted_answer"].strip())[0]
scores.append(score)
return scores
# SPDX-FileCopyrightText: Copyright (c) 1993-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
import logging
from typing import Optional
import pandas as pd
from transformers import PreTrainedTokenizer
logger = logging.getLogger(__name__)
def insert_needle_in_haystack(
df: pd.DataFrame,
tokenizer: PreTrainedTokenizer,
max_context_length: int,
needle_depth: int | list[int],
context_wrapper: str = "This is a very long story book: <book> {context} </book>.",
needle_text: Optional[str] = None,
answer_prefix: Optional[str] = None,
question_text: Optional[str] = None,
) -> pd.DataFrame:
"""
Inserts the "needle" string into the "context" of each row in the DataFrame at specified depths.
A new row is created for each depth, and the DataFrame is returned with these new rows.
Parameters
----------
df : pd.DataFrame
The input DataFrame containing at least the columns "context" and "needle".
tokenizer : PreTrainedTokenizer
The tokenizer used to encode and decode the context and needle.
max_context_length : int
The maximum allowed length (in tokens) for the context, including the needle.
needle_depths : int | list[int]
A list of percentages (0-100) indicating how deep into the context the needle should be inserted.
needle_text : Optional[str]
The text to insert as the needle. If None, the first row's "needle" column is used.
answer_prefix : Optional[str]
The prefix to add to the answer. If None, the first row's "answer_prefix" column is used.
question_text : Optional[str]
The text to insert as the question. If None, the first row's "question" column is used.
max_new_tokens : int
The maximum number of new tokens to generate. If None, the first row's "max_new_tokens" column is used.
Returns
-------
pd.DataFrame
A DataFrame with the "context" column modified to include the needle, with a new row
for each specified depth.
"""
# Store the original context and needle to be reused for each depth
original_context = df["context"][0]
needle_text = needle_text or df["needle"][0]
question_text = question_text or df["question"][0]
answer_prefix = answer_prefix or df["answer_prefix"][0]
max_new_tokens = df["max_new_tokens"][0]
logger.info(f"Preparing dataset for inference. Needle: {needle_text}")
tokenized_needle = tokenizer.encode(needle_text, add_special_tokens=False)
# Account for system prompts and other overhead
context_length_limit = max_context_length - len(tokenized_needle) - 150
# Tokenize the original context once
tokenized_context = tokenizer.encode(original_context, add_special_tokens=False)[:context_length_limit]
# Initialize a list to hold the new rows
new_rows = []
needle_depth = [needle_depth] if isinstance(needle_depth, int) else needle_depth
for depth in needle_depth:
# Calculate the insertion index based on the current depth
needle_index = int(len(tokenized_context) * depth / 100)
# Create a new tokenized context with the needle inserted
new_tokenized_context = tokenized_context[:needle_index] + tokenized_needle + tokenized_context[needle_index:]
decoded_context = tokenizer.decode(new_tokenized_context, skip_special_tokens=True)
final_context = context_wrapper.format(context=decoded_context)
new_row = {
"context": final_context,
"needle": needle_text,
"needle_depth": depth,
"question": question_text,
"answer_prefix": answer_prefix,
"max_new_tokens": max_new_tokens,
}
new_rows.append(new_row)
# Create the new DataFrame from the list of rows
result_df = pd.DataFrame(new_rows)
return result_df
# RULER dataset
[RULER](https://arxiv.org/abs/2404.06654) generates synthetic examples to evaluate long-context language models with configurable sequence length (from 4k tokens to 128k tokens) and task complexity. It contains a set of 13 tasks grouped in 4 categories (needle in the haystack, question answering, multi-hop tracing and aggregation).
## Hugging Face dataset
The Hugging Face dataset for RULER can be found [here](https://huggingface.co/datasets/simonjegou/ruler). To reproduce this dataset,
1. Install the [RULER repository](https://github.com/hsiehjackson/RULER) and download the necessary data files (see 1. Download data in the README)
2. Copy paste the `generate.sh` from this repository to `$RULER/scripts`, set the `DATA_DIR` variable to your desired location of the RULER data files and run the script
3. Run `create_huggingface_dataset.py` with the correct data_dir and repo_id variables
Notes : by default we use `meta-llama/Meta-Llama-3.1-8B` as the tokenizer, while in the original RULER paper, the tokenizer depends on the model used for evaluation. Results may not be directly comparable to the original RULER benchmark. But as our focus is to evaluate the performance of a given model for different compression ratios, we believe this simplification is acceptable.
\ No newline at end of file
# SPDX-FileCopyrightText: Copyright (c) 1993-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright (c) 1993-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
import re
import pandas as pd
def string_match_part(preds, refs):
score = (
sum([max([1.0 if r.lower() in pred.lower() else 0.0 for r in ref]) for pred, ref in zip(preds, refs)])
/ len(preds)
* 100
)
return round(score, 2)
def string_match_all(preds, refs):
score = (
sum(
[sum([1.0 if r.lower() in pred.lower() else 0.0 for r in ref]) / len(ref) for pred, ref in zip(preds, refs)]
)
/ len(preds)
* 100
)
return round(score, 2)
def calculate_metrics(df: pd.DataFrame) -> dict:
scores = {}
np_pattern = re.compile(r"[\x00-\x1f]")
df["predicted_answer"] = df["predicted_answer"].apply(lambda x: np_pattern.sub("", x.strip()).strip())
for task, df_task in df.groupby("task"):
task_category = task.split("_")[0]
metric_fn = string_match_part if task_category == "qa" else string_match_all
preds = df_task["predicted_answer"].tolist()
refs = df_task["answer"].tolist()
score = metric_fn(preds, refs)
scores[task] = {"string_match": score}
return scores
# SPDX-FileCopyrightText: Copyright (c) 1993-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
import re
from pathlib import Path
import pandas as pd
from datasets import Dataset
# Source: https://github.com/hsiehjackson/RULER/blob/main/scripts/data/synthetic/constants.py
QUESTION_PATTERNS = {
"niah": re.compile(r"What (?:is|are all) the special magic"),
"vt": re.compile(r"Question: Find all variables that are assigned the value"),
"cwe": re.compile(r"Question: What are the 10 most common words in the above list\?"),
"fwe": re.compile(r"Question: Do not provide any explanation\."),
"qa": re.compile(r"Answer the question based on the given documents\."),
}
ANSWER_PATTERNS = {
"niah": re.compile(r"The special magic"),
"vt": re.compile(r"Answer:"),
"cwe": re.compile(r"Answer:"),
"fwe": re.compile(r"Answer:"),
"qa": re.compile(r"Answer:"),
}
# Source: https://github.com/hsiehjackson/RULER/blob/main/scripts/data/synthetic/constants.py
MAX_NEW_TOKENS = {
"niah": 128,
"vt": 30,
"cwe": 120,
"fwe": 50,
"qa": 32,
}
def get_dataframe(path):
"""
Parse the data from the provided path and return a DataFrame with the context, question, answers and task
"""
assert re.match(r".*\/\d+$", str(path)), "The path should must ends with the context length (e.g. with /4096)"
df_list = []
for task_path in Path(path).glob("**/*.jsonl"):
# Load dataframe
df = pd.read_json(task_path, lines=True)
task = task_path.parent.stem
question_pattern = QUESTION_PATTERNS[task.split("_")[0]]
answer_pattern = ANSWER_PATTERNS[task.split("_")[0]]
# Split the context and the question based on the pattern
def split_context_question(text):
idx = list(question_pattern.finditer(text))[-1].start()
context, qa = text[:idx], text[idx:]
idx = answer_pattern.search(qa).start()
question, answer = qa[:idx], qa[idx:]
return context, question, answer
df["context"], df["question"], df["answer_prefix"] = zip(*df["input"].apply(split_context_question))
df["task"] = task
df["max_new_tokens"] = MAX_NEW_TOKENS[task.split("_")[0]]
df_list.append(df)
# Concatenate all the dataframes
df = pd.concat(df_list)
df = df[["context", "question", "answer_prefix", "outputs", "task", "max_new_tokens"]]
df = df.rename(columns={"outputs": "answer"}).reset_index(drop=True)
return df
if __name__ == "__main__":
data_dir = Path("/mnt/workspace/projects/RULER/scripts/data/data/") # output of the generate.sh script
repo_id = "simonjegou/ruler"
# Loop over all the context lengths
for path in data_dir.glob("*/"):
context_length = path.stem
print(f"Processing context length {context_length}")
df = get_dataframe(path)
dataset = Dataset.from_pandas(df)
dataset.push_to_hub(repo_id=repo_id, config_name=context_length, split="test")
# The following script prepares the synthetic data benchmark for a given Hugging Face tokenizer and without template
# Before running this script, make sure you downloaded the data as explained in the README:
# cd scripts/data/synthetic/json/
# python download_paulgraham_essay.py
# bash download_qa_dataset.sh
DATA_DIR="data/data"
TOKENIZER_PATH="meta-llama/Meta-Llama-3.1-8B"
SEQ_LENGTHS=(
4096
8192
16384
)
TASKS=(
"niah_single_1"
"niah_single_2"
"niah_single_3"
"niah_multikey_1"
"niah_multikey_2"
"niah_multikey_3"
"niah_multivalue"
"niah_multiquery"
"vt"
"cwe"
"fwe"
"qa_1"
"qa_2"
)
for MAX_SEQ_LENGTH in "${SEQ_LENGTHS[@]}"; do
SAVE_DIR="${DATA_DIR}/${MAX_SEQ_LENGTH}"
for TASK in "${TASKS[@]}"; do
python data/prepare.py \
--save_dir ${SAVE_DIR} \
--benchmark synthetic \
--task ${TASK} \
--tokenizer_path ${TOKENIZER_PATH} \
--tokenizer_type hf \
--max_seq_length ${MAX_SEQ_LENGTH} \
--model_template_type base \
--num_samples 500
done
done
\ No newline at end of file
# Zero Scrolls dataset
[Zero scrolls](https://www.zero.scrolls-benchmark.com/) includes ten natural language tasks across multiple domains, including summarization, question answering, aggregated sentiment classification and information reordering.
## Hugging Face dataset
The Hugging Face dataset for Zero Scrolls can be found [here](https://huggingface.co/datasets/simonjegou/zero_scroll). To reproduce this dataset, simply run the `create_huggingface_dataset.py` script.
## Evaluation
The answer are not provided in the dataset, you will need to submit your predictions to the [Zero Scrolls](https://www.zero.scrolls-benchmark.com/) website to get the results.
# SPDX-FileCopyrightText: Copyright (c) 1993-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright (c) 1993-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
def calculate_metrics(df):
return {}
# SPDX-FileCopyrightText: Copyright (c) 1993-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
import pandas as pd
from datasets import Dataset, load_dataset
MAX_NEW_TOKENS = {
"gov_report": 1024,
"summ_screen_fd": 512,
"qmsum": 512,
"qasper": 128,
"narrative_qa": 64,
"quality": 10,
"musique": 32,
"squality": 512,
"space_digest": 36,
"book_sum_sort": 256,
}
df_list = []
for task, max_new_tokens in MAX_NEW_TOKENS.items():
df = load_dataset("tau/zero_scrolls", task, split="test").to_pandas()
df["context"] = df.apply(lambda x: x["input"][: x["document_end_index"]], axis=1)
df["question"] = df.apply(lambda x: x["input"][x["document_end_index"] : x["query_end_index"]], axis=1)
df["answer_prefix"] = df.apply(lambda x: x["input"][x["query_end_index"] :], axis=1).str.strip()
df["answer"] = ""
df["task"] = task
df["max_new_tokens"] = max_new_tokens
df_list.append(df)
df = pd.concat(df_list)
dataset = Dataset.from_pandas(df)
dataset.push_to_hub(repo_id="zero_scrolls", split="test")
# SPDX-FileCopyrightText: Copyright (c) 1993-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
import json
import logging
import random
import sys
from dataclasses import asdict, dataclass
from pathlib import Path
from typing import Any, Dict, Optional
import numpy as np
import pandas as pd
import torch
import yaml
from benchmarks.needle_in_haystack.utils import insert_needle_in_haystack
from datasets import load_dataset
from evaluate_registry import DATASET_REGISTRY, PRESS_REGISTRY, SCORER_REGISTRY
from fire import Fire
from tqdm import tqdm
from transformers import FineGrainedFP8Config, Pipeline, pipeline
from kvpress import (
ComposedPress,
DecodingPress,
DMSPress,
DuoAttentionPress,
FinchPress,
ObservedAttentionPress,
ScorerPress,
ThinKPress,
)
logger = logging.getLogger(__name__)
@dataclass
class EvaluationConfig:
"""Dataclass to handle all the configuration for the evaluation."""
# Core evaluation parameters
dataset: str = "ruler"
data_dir: Optional[str] = None
model: str = "meta-llama/Meta-Llama-3.1-8B-Instruct"
device: Optional[str] = None
press_name: str = "knorm"
compression_ratio: float = 1.0
key_channel_compression_ratio: Optional[float] = None
threshold: Optional[float] = None
# Dataset and generation parameters
fraction: float = 1.0
max_new_tokens: Optional[int] = None
max_context_length: Optional[int] = None
query_aware: bool = False
needle_depth: Optional[int] = None
# Decoding parameters
compression_interval: Optional[int] = None
target_size: Optional[int] = None
hidden_states_buffer_size: Optional[int] = None
# Output and logging
output_dir: str = "./results"
log_level: str = "INFO"
# Model-specific parameters
model_kwargs: Optional[Dict[str, Any]] = None
# Press information (will be set after press setup)
press_init_command: Optional[str] = None
# For reproducibility
seed: int = 42
# Quantization
fp8: bool = False
def __post_init__(self):
"""Validate configuration after initialization."""
# Validate dataset
assert self.dataset in DATASET_REGISTRY, f"No dataset found for {self.dataset}"
assert self.dataset in SCORER_REGISTRY, f"No scorer found for {self.dataset}"
# Validate press
assert self.press_name in PRESS_REGISTRY, f"Press '{self.press_name}' not found in PRESS_REGISTRY"
if self.press_name == "no_press":
# override compression_ratio to 0.0
logger.info("Using 'no_press' configuration. Overriding compression_ratio to 0.0")
self.compression_ratio = 0.0
# Only validate key_channel_compression_ratio if it's not None
if self.key_channel_compression_ratio is not None:
assert (
0.0 <= self.key_channel_compression_ratio <= 1.0
), f"key_channel_compression_ratio must be between 0.0 and 1.0, got {self.key_channel_compression_ratio}"
# Validate fraction
assert 0.0 < self.fraction <= 1.0, f"fraction must be between 0.0 and 1.0, got {self.fraction}"
# Initialize model_kwargs if None
if self.model_kwargs is None:
self.model_kwargs = {}
if self.dataset == "needle_in_haystack":
assert self.needle_depth is not None, "needle_depth must be set for needle_in_haystack"
assert self.max_context_length is not None, "max_context_length must be set for needle_in_haystack"
def get_results_dir(self, output_dir: Path) -> Path:
"""
Generates the unique save directory and filenames based on configuration parameters.
Parameters
----------
output_dir : Path
The output directory path
Returns
-------
Path
The path to the results directory
"""
# Build directory name components
components = [
self.dataset,
str(self.data_dir) if self.data_dir else "",
self.model.replace("/", "--"),
self.press_name,
f"{self.compression_ratio:.2f}",
]
if self.threshold is not None:
components[-1] = f"{self.threshold:.2f}"
if self.fraction < 1.0:
components.append(f"fraction{self.fraction:.3f}")
if self.max_context_length is not None:
components.append(f"max_context{self.max_context_length}")
if self.query_aware:
components.append("query_aware")
if self.key_channel_compression_ratio is not None:
components.append(f"key_channel_cr{self.key_channel_compression_ratio:.2f}")
if self.needle_depth is not None and self.dataset == "needle_in_haystack":
components.append(f"needle_depth{self.needle_depth}")
dir_name = "__".join(filter(None, components)) # Filter None/empty strings
config_dir = output_dir / dir_name
# Make sure the directory does not exist, if it does, add a number to the end
# This is to avoid overwriting results
if config_dir.exists():
i = 1
while (config_dir / f"{i}").exists():
i += 1
config_dir = config_dir / f"{i}"
config_dir.mkdir(parents=True, exist_ok=True)
return config_dir
def save_config(self, config_filename: Path):
"""
Saves the evaluation configuration to a YAML file.
"""
with open(str(config_filename), "w") as f:
yaml.dump(asdict(self), f, default_flow_style=False, indent=2, sort_keys=False)
def _load_yaml_config(path: str | Path) -> dict:
"""Loads a YAML file. Returns an empty dict if it doesn't exist."""
try:
with open(path, "r") as f:
return yaml.safe_load(f) or {}
except FileNotFoundError:
logger.warning(f"Config file not found at {path}. Using only command-line arguments and defaults.")
return {}
class EvaluationRunner:
"""
EvaluationRunner class that orchestrates the entire evaluation process.
Parameters
----------
config : EvaluationConfig
The configuration for the evaluation run.
The final output will be predictions_<config>.csv and metrics_<config>.json in the output_dir.
If the evaluation files already exist, evaluation will be skipped.
"""
def __init__(self, config: EvaluationConfig):
"""
Initializes the EvaluationRunner with a given configuration.
Parameters
----------
config : EvaluationConfig
The configuration for the evaluation run.
"""
self.config = config
self.pipeline: Optional[Pipeline] = None # Will be set by _setup_model_pipeline()
self.press: None | ScorerPress = None # Will be set by _setup_press()
self.df: Optional[pd.DataFrame] = None # Will be set by _load_dataset()
self._setup_logging()
self._setup_deterministic_seeds()
logger.info(f"Initialized EvaluationRunner with config:\n{json.dumps(asdict(self.config), indent=2)}")
def _setup_deterministic_seeds(self):
"""Set deterministic seeds for reproducible results."""
torch.manual_seed(self.config.seed)
np.random.seed(self.config.seed)
random.seed(self.config.seed)
if torch.cuda.is_available():
torch.cuda.manual_seed(self.config.seed)
torch.cuda.manual_seed_all(self.config.seed)
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False
logger.info(f"Set deterministic seeds to {self.config.seed}")
def _setup_logging(self):
"""Configures the logging level based on the config."""
log_level = self.config.log_level.upper()
handler = logging.StreamHandler()
handler.setFormatter(logging.Formatter("%(asctime)s - %(levelname)s - %(message)s"))
logger.addHandler(handler)
logger.setLevel(log_level)
def _setup_directories(self) -> Path:
"""
Creates the output directory for saving results if it doesn't exist.
Returns
-------
Path
The path to the output directory.
"""
output_dir = Path(self.config.output_dir)
output_dir.mkdir(parents=True, exist_ok=True)
logger.info(f"Output directory set to: {output_dir}")
return output_dir
def _setup_press(self):
"""
Initializes the KVPress instance and applies compression ratios based on its type.
"""
press_name = self.config.press_name
compression_ratio = self.config.compression_ratio
key_channel_compression_ratio = self.config.key_channel_compression_ratio
press = PRESS_REGISTRY[press_name]
# Apply compression ratios based on press type
if isinstance(press, DuoAttentionPress):
press.head_compression_ratio = compression_ratio
logger.info(f"Set DuoAttentionPress head_compression_ratio to {compression_ratio}")
elif isinstance(press, DMSPress):
assert self.config.threshold is not None, "threshold must be set for DMSPress"
press.threshold = self.config.threshold
logger.info(f"Set DMSPress threshold to {press.threshold}")
elif isinstance(press, ComposedPress):
for ps in press.presses:
if isinstance(ps, ThinKPress):
assert (
key_channel_compression_ratio is not None
), "key_channel_compression_ratio must be set for ThinKPress in ComposedPress"
ps.key_channel_compression_ratio = key_channel_compression_ratio
logger.info(f"Set ComposedPress key_channel_compression_ratio to {key_channel_compression_ratio}")
else:
# Check if compression_ratio attribute exists before setting
if hasattr(ps, "compression_ratio"):
ps.compression_ratio = compression_ratio
logger.info(f"Set ComposedPress compression_ratio to {compression_ratio}")
else:
logger.warning(
f"ComposedPress component {ps.__class__.__name__} has no 'compression_ratio' attribute."
)
elif isinstance(press, ThinKPress):
assert key_channel_compression_ratio is not None, "key_channel_compression_ratio must be set for ThinKPress"
press.key_channel_compression_ratio = key_channel_compression_ratio
logger.info(f"Set ThinKPress key_channel_compression_ratio to {key_channel_compression_ratio}")
elif isinstance(press, DecodingPress):
press.compression_interval = self.config.compression_interval or press.compression_interval
press.target_size = self.config.target_size or press.target_size
press.hidden_states_buffer_size = self.config.hidden_states_buffer_size or press.hidden_states_buffer_size
logger.info(
f"Set DecodingPress compression_interval to {self.config.compression_interval}, target_size to {self.config.target_size}, hidden_states_buffer_size to {self.config.hidden_states_buffer_size}"
)
else:
if hasattr(press, "compression_ratio"):
press.compression_ratio = compression_ratio
logger.info(f"Set {press.__class__.__name__} compression_ratio to {compression_ratio}")
else:
logger.warning(
f"Press {press.__class__.__name__} has no 'compression_ratio' attribute. This is expected is you set `no_press`."
)
self.press = press
# Set the press info in the config for saving to YAML
self.config.press_init_command = str(press)
logger.info(f"KV Press '{press_name}' setup.")
def _load_and_prepare_dataset(self):
"""
Loads the dataset specified in the config and applies sampling/filtering.
"""
dataset_name = self.config.dataset
data_dir = str(self.config.data_dir) if self.config.data_dir else None
fraction = self.config.fraction
logger.info(f"Loading dataset: {DATASET_REGISTRY[dataset_name]} (data_dir: {data_dir})")
df = load_dataset(DATASET_REGISTRY[dataset_name], data_dir=data_dir, split="test").to_pandas()
if fraction < 1.0:
original_len = len(df)
df = df.sample(frac=fraction, random_state=self.config.seed)
logger.info(f"Sampled {len(df)} samples ({fraction:.2f}) from original {original_len} samples.")
logger.info(f"Dataset loaded with {len(df)} entries.")
# if we have needle in a haystack, we need to insert it in the context
if self.config.dataset == "needle_in_haystack":
df = insert_needle_in_haystack(
df, self.pipeline.tokenizer, self.config.max_context_length, self.config.needle_depth
)
if isinstance(self.press, FinchPress):
if not self.config.query_aware:
logger.error("FinchPress requires 'query_aware' to be set to True.")
raise ValueError("FinchPress requires query_aware to be set to True")
# FinchPress uses a delimiter token to separate context and question
# So we need to update the tokenizer and the model embeddings.
logger.info("FinchPress detected, updating model and tokenizer with delimiter token.")
self.press.update_model_and_tokenizer(self.pipeline.model, self.pipeline.tokenizer) # type: ignore[attr-defined]
df["context"] = df["context"] + self.press.delimiter_token # type: ignore[attr-defined, index]
if self.config.query_aware:
logger.info("Query-aware compression: including question in context for compression.")
df["context"] = df["context"] + df["question"] # type: ignore[index]
df["question"] = "" # type: ignore[index]
self.df = df
logger.info(f"Dataset processed with {len(self.df)} entries.")
def _setup_model_pipeline(self):
model_name = self.config.model
device = self.config.device
if device is None:
device = "auto" if torch.cuda.is_available() else "cpu"
logger.info(f"No device specified, auto-detected device: {device}")
model_kwargs = self.config.model_kwargs or {}
if self.config.fp8:
model_kwargs["quantization_config"] = FineGrainedFP8Config()
logger.info("FP8 quantization enabled.")
if isinstance(self.press, ObservedAttentionPress):
model_kwargs["attn_implementation"] = "eager"
logger.info("ObservedAttentionPress detected, setting attn_implementation to 'eager'.")
else:
try:
import flash_attn # noqa: F401
model_kwargs["attn_implementation"] = "flash_attention_2"
logger.info("Flash Attention 2 detected, setting attn_implementation to 'flash_attention_2'.")
except ImportError:
logger.info("Flash Attention 2 not available, using default attn_implementation.")
pass
logger.info(f"Loading model pipeline for: {model_name} on device: {device} with model_kwargs: {model_kwargs}")
pipeline_kwargs = {
"model": model_name,
"model_kwargs": model_kwargs,
"trust_remote_code": True,
}
if device == "auto":
pipeline_kwargs["device_map"] = "auto"
else:
pipeline_kwargs["device"] = device
self.pipeline = pipeline("kv-press-text-generation", **pipeline_kwargs)
self.pipeline.model.eval()
logger.info("Model pipeline loaded.")
@torch.inference_mode()
def _run_inference(self):
"""
Executes the inference process on the prepared dataset using the model pipeline.
"""
self.df["predicted_answer"] = None # type: ignore[index]
if isinstance(self.press, DecodingPress):
logger.info("DecodingPress detected, running inference for each context-question pair.")
for index, row in tqdm(self.df.iterrows(), total=len(self.df), desc="Running Inference"):
context = row["context"]
question = row["question"]
answer_prefix = row["answer_prefix"]
max_new_tokens = self.config.max_new_tokens or row["max_new_tokens"]
output = self.pipeline(
context,
question=question,
answer_prefix=answer_prefix,
press=self.press,
max_new_tokens=max_new_tokens,
max_context_length=self.config.max_context_length,
)
self.df.loc[index, "predicted_answer"] = output["answer"] # type: ignore[union-attr]
torch.cuda.empty_cache() # Clear CUDA cache to free up memory
else:
df_context_grouped = self.df.groupby("context") # type: ignore[union-attr]
assert all(
df_context_grouped["answer_prefix"].nunique() == 1
), "Inconsistent 'answer_prefix' within the same context group detected."
logger.info("Starting inference...")
for context, df_group in tqdm(
df_context_grouped, total=self.df["context"].nunique(), desc="Running Inference"
): # type: ignore[union-attr]
questions = df_group["question"].to_list()
# Use max_new_tokens from config, or fallback to dataset's default for the task
max_new_tokens = self.config.max_new_tokens or df_group["max_new_tokens"].iloc[0]
answer_prefix = df_group["answer_prefix"].iloc[0]
output = self.pipeline( # type: ignore[misc]
context,
questions=questions,
answer_prefix=answer_prefix,
press=self.press,
max_new_tokens=max_new_tokens,
max_context_length=self.config.max_context_length,
)
self.df.loc[df_group.index, "predicted_answer"] = output["answers"] # type: ignore[union-attr]
# Store the actual compression ratio used (if the press has one)
self.df.loc[df_group.index, "compression_ratio"] = (
self.press.compression_ratio if self.press is not None else 0.0 # type: ignore[attr-defined]
) # type: ignore[union-attr, attr-defined]
torch.cuda.empty_cache() # Clear CUDA cache to free up memory
logger.info("Inference completed.")
def _save_results(self, save_filename: Path):
"""
Saves the predicted answers and compression ratios to a CSV file.
Parameters
----------
save_filename : Path
The full path including filename to save the CSV.
"""
if save_filename.exists():
logger.warning(f"Results CSV already exists at {save_filename}. Overwriting.")
self.df[list(set(self.df.columns) - set(["context"]))].to_csv(
str(save_filename), index=False
) # type: ignore[index]
logger.info(f"Results saved to {save_filename}")
def _calculate_and_save_metrics(self, save_filename: Path):
"""
Calculates evaluation metrics and saves them to a JSON file.
Parameters
----------
save_filename : Path
The base filename (e.g., CSV path) to derive the JSON path from.
"""
dataset_name = self.config.dataset
scorer = SCORER_REGISTRY[dataset_name]
logger.info(f"Calculating metrics for dataset: {dataset_name}")
metrics = scorer(self.df) # type: ignore[call-arg]
with open(str(save_filename), "w") as f:
json.dump(metrics, f, indent=4) # Pretty print JSON
logger.info(f"Metrics saved to {save_filename}")
logger.info(f"Metrics:\n{json.dumps(metrics, indent=2)}")
def run_evaluation(self):
"""
Orchestrates the entire evaluation process.
"""
logger.info("Starting evaluation run...")
output_dir = self._setup_directories()
results_dir = self.config.get_results_dir(output_dir)
predictions_filename = results_dir / "predictions.csv"
metrics_filename = results_dir / "metrics.json"
config_filename = results_dir / "config.yaml"
if predictions_filename.exists() and metrics_filename.exists():
logger.info(
f"Evaluation files already exist at \n {predictions_filename} \n {metrics_filename}.\nSkipping..."
)
return
self._setup_press()
self._setup_model_pipeline()
self._load_and_prepare_dataset()
self._run_inference()
self._save_results(predictions_filename)
self._calculate_and_save_metrics(metrics_filename)
self.config.save_config(config_filename)
logger.info("Evaluation run completed successfully.")
# --- Command-Line Interface ---
class CliEntryPoint:
"""
CLI entry point for building configuration and running the evaluation.
This class provides a command-line interface for running KVPress evaluations.
Configuration can be specified via:
1. YAML config file (default: "./evaluate_config.yaml")
2. Command-line arguments (highest priority)
"""
def __call__(self, config_file: Optional[str] = "./evaluate_config.yaml", **cli_overrides):
"""
Builds the configuration and runs the evaluation.
Configuration is built by layering:
1. Default values from EvaluationConfig
2. Values from YAML config file
3. Command-line arguments (highest priority)
"""
# 1. Start with dataclass defaults.
final_args = asdict(EvaluationConfig())
# 2. Layer YAML values on top.
yaml_config = _load_yaml_config(config_file)
final_args.update(yaml_config)
# 3. Layer CLI arguments on top (highest priority).
# Filter out None values from CLI overrides
cli_args = {k: v for k, v in cli_overrides.items() if v is not None}
final_args.update(cli_args)
# 4. Create and validate the final config object.
try:
config = EvaluationConfig(**final_args)
except TypeError as e:
# Provide a user-friendly error for bad arguments.
print(f"Error: Invalid configuration argument provided. {e}", file=sys.stderr)
sys.exit(1)
runner = EvaluationRunner(config)
runner.run_evaluation()
if __name__ == "__main__":
Fire(CliEntryPoint)
dataset="ruler"
data_dir="4096"
model="meta-llama/Meta-Llama-3.1-8B-Instruct"
compression_ratios=(0.1 0.25 0.5)
press_names=("expected_attention" "knorm" "streaming_llm" "snapkv")
# Check if the number of press names is less than or equal to the number of available GPUs
num_gpus=$(nvidia-smi --list-gpus | wc -l)
if [ ${#press_names[@]} -gt $num_gpus ]; then
echo "Error: The number of press names (${#press_names[@]}) exceeds the number of available GPUs ($num_gpus)"
exit 1
fi
# Iterate over press names and compression ratios
for i in "${!press_names[@]}"; do
press="${press_names[$i]}"
# Run each press_name on a different GPU in the background
(
for compression_ratio in "${compression_ratios[@]}"; do
echo "Running press_name: $press with compression_ratio: $compression_ratio on GPU cuda:$i"
python evaluate.py --dataset $dataset --data_dir $data_dir --model $model --press_name $press --compression_ratio $compression_ratio --device "cuda:$i"
done
) &
done
# Wait for all background jobs to finish
wait
echo "All evaluations completed."
# SPDX-FileCopyrightText: Copyright (c) 1993-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
output_dir: "./results"
model: "meta-llama/Meta-Llama-3.1-8B-Instruct"
dataset: "ruler" # see DATASET_REGISTRY in evaluate_registry.py
data_dir: "4096" # Subdirectory of the dataset (if applicable) else leave "null"
press_name: "knorm" # see PRESS_REGISTRY in evaluate_registry.py
compression_ratio: 0.5 # Compression ratio for the press (0.0 to 1.0)
key_channel_compression_ratio: null # For ThinKPress and ComposedPress (0.0 to 1.0)
threshold: null # For DMSPress
fraction: 1.0 # Fraction of dataset to evaluate (0.0 to 1.0), for quick testing
max_new_tokens: null # Maximum new tokens to generate (null = use dataset default)
max_context_length: null # Maximum context length (null = use model maximum)
query_aware: false # Whether to include question in context for query-aware compression
needle_depth: null # Depth (int or list of ints) percentage of the needle in the haystack (0 to 100), only for needle_in_haystack dataset
device: null # Device to use (null = auto-detect, "cuda:0", "cpu", etc.)
fp8: false # Whether to use FP8 quantization (FineGrainedFP8Config() from transformers)
# You can add any model kwargs here.
model_kwargs:
attn_implementation: null
dtype: "auto"
# SPDX-FileCopyrightText: Copyright (c) 1993-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
from benchmarks.aime25.calculate_metrics import calculate_metrics as aime25_scorer
from benchmarks.infinite_bench.calculate_metrics import calculate_metrics as infinite_bench_scorer
from benchmarks.longbench.calculate_metrics import calculate_metrics as longbench_scorer
from benchmarks.longbench.calculate_metrics import calculate_metrics_e as longbench_scorer_e
from benchmarks.longbenchv2.calculate_metrics import calculate_metrics as longbenchv2_scorer
from benchmarks.loogle.calculate_metrics import calculate_metrics as loogle_scorer
from benchmarks.math500.calculate_metrics import calculate_metrics as math500_scorer
from benchmarks.needle_in_haystack.calculate_metrics import calculate_metrics as needle_in_haystack_scorer
from benchmarks.ruler.calculate_metrics import calculate_metrics as ruler_scorer
from benchmarks.zero_scrolls.calculate_metrics import calculate_metrics as zero_scrolls_scorer
from kvpress import (
AdaKVPress,
BlockPress,
ChunkKVPress,
CompactorPress,
ComposedPress,
CriticalAdaKVPress,
CriticalKVPress,
CURPress,
DecodingPress,
DMSPress,
DuoAttentionPress,
ExpectedAttentionPress,
FastKVzipPress,
FinchPress,
KeyDiffPress,
KnormPress,
KVzapPress,
KVzipPress,
LagKVPress,
ObservedAttentionPress,
PyramidKVPress,
QFilterPress,
RandomPress,
SnapKVPress,
StreamingLLMPress,
ThinKPress,
TOVAPress,
)
# These dictionaries define the available datasets, scorers, and KVPress methods for evaluation.
DATASET_REGISTRY = {
"loogle": "simonjegou/loogle",
"ruler": "simonjegou/ruler",
"zero_scrolls": "simonjegou/zero_scrolls",
"infinitebench": "MaxJeblick/InfiniteBench",
"longbench": "Xnhyacinth/LongBench",
"longbench-e": "Xnhyacinth/LongBench",
"longbench-v2": "simonjegou/LongBench-v2",
"needle_in_haystack": "alessiodevoto/paul_graham_essays",
# Datasets used to be used for decoding compression
"aime25": "alessiodevoto/aime25",
"math500": "alessiodevoto/math500",
}
SCORER_REGISTRY = {
"loogle": loogle_scorer,
"ruler": ruler_scorer,
"zero_scrolls": zero_scrolls_scorer,
"infinitebench": infinite_bench_scorer,
"longbench": longbench_scorer,
"longbench-e": longbench_scorer_e,
"longbench-v2": longbenchv2_scorer,
"needle_in_haystack": needle_in_haystack_scorer,
"aime25": aime25_scorer,
"math500": math500_scorer,
}
PRESS_REGISTRY = {
"adakv_snapkv": AdaKVPress(SnapKVPress()),
"block_keydiff": BlockPress(press=KeyDiffPress(), block_size=128),
"chunkkv": ChunkKVPress(press=SnapKVPress(), chunk_length=20),
"critical_adakv_expected_attention": CriticalAdaKVPress(ExpectedAttentionPress(use_vnorm=False)),
"critical_adakv_snapkv": CriticalAdaKVPress(SnapKVPress()),
"critical_expected_attention": CriticalKVPress(ExpectedAttentionPress(use_vnorm=False)),
"critical_snapkv": CriticalKVPress(SnapKVPress()),
"cur": CURPress(),
"duo_attention": DuoAttentionPress(),
"duo_attention_on_the_fly": DuoAttentionPress(on_the_fly_scoring=True),
"expected_attention": AdaKVPress(ExpectedAttentionPress(epsilon=1e-2)),
"fastkvzip": FastKVzipPress(),
"finch": FinchPress(),
"keydiff": KeyDiffPress(),
"kvzip": KVzipPress(),
"kvzip_plus": KVzipPress(kvzip_plus_normalization=True),
"kvzap_linear": DMSPress(press=KVzapPress(model_type="linear")),
"kvzap_mlp": DMSPress(press=KVzapPress(model_type="mlp")),
"kvzap_mlp_head": KVzapPress(model_type="mlp"),
"kvzap_mlp_layer": AdaKVPress(KVzapPress(model_type="mlp")),
"lagkv": LagKVPress(),
"knorm": KnormPress(),
"observed_attention": ObservedAttentionPress(),
"pyramidkv": PyramidKVPress(),
"qfilter": QFilterPress(),
"random": RandomPress(),
"snap_think": ComposedPress([SnapKVPress(), ThinKPress()]),
"snapkv": SnapKVPress(),
"streaming_llm": StreamingLLMPress(),
"think": ThinKPress(),
"tova": TOVAPress(),
"compactor": CompactorPress(),
"adakv_compactor": AdaKVPress(CompactorPress()),
"no_press": None,
"decoding_knorm": DecodingPress(base_press=KnormPress()),
"decoding_streaming_llm": DecodingPress(base_press=StreamingLLMPress()),
"decoding_tova": DecodingPress(base_press=TOVAPress()),
"decoding_qfilter": DecodingPress(base_press=QFilterPress()),
"decoding_adakv_expected_attention_e2": DecodingPress(base_press=AdaKVPress(ExpectedAttentionPress(epsilon=1e-2))),
"decoding_adakv_snapkv": DecodingPress(base_press=AdaKVPress(SnapKVPress())),
"decoding_keydiff": DecodingPress(base_press=KeyDiffPress()),
}
# SPDX-FileCopyrightText: Copyright (c) 1993-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
# Script to run the leaderboard evaluation on 4 GPUs
dataset="ruler"
data_dir="4096"
model="Qwen/Qwen3-8B"
output_dir="./results_lb"
# Loop 1: presses not requiring to include the questions in the compression
press_names=("random" "knorm" "snapkv" "expected_attention" "streaming_llm" "tova" "observed_attention" "qfilter" "pyramidkv" "lagkv" "keydiff" "adakv_compactor" "cur" "duo_attention" "duo_attention_on_the_fly" "kvzip")
python evaluate.py --dataset $dataset --data_dir $data_dir --model $model --press_name no_press --compression_ratio 0.00 --output_dir $output_dir --device "cuda:0"
for press in "${press_names[@]}"; do
python evaluate.py --dataset $dataset --data_dir $data_dir --model $model --press_name $press --compression_ratio 0.25 --output_dir $output_dir --device "cuda:0" &
python evaluate.py --dataset $dataset --data_dir $data_dir --model $model --press_name $press --compression_ratio 0.50 --output_dir $output_dir --device "cuda:1" &
python evaluate.py --dataset $dataset --data_dir $data_dir --model $model --press_name $press --compression_ratio 0.75 --output_dir $output_dir --device "cuda:2" &
python evaluate.py --dataset $dataset --data_dir $data_dir --model $model --press_name $press --compression_ratio 0.875 --output_dir $output_dir --device "cuda:3" &
wait
done
# Use -3, -4, -5, -6 for Qwen3-8B and -6, -7, -8, -9 for Llama-3.1-8B-Instruct
for press in "kvzap_linear" "kvzap_mlp"; do
python evaluate.py --dataset $dataset --data_dir $data_dir --model $model --press_name $press --threshold -3 --output_dir $output_dir --device "cuda:0" &
python evaluate.py --dataset $dataset --data_dir $data_dir --model $model --press_name $press --threshold -4 --output_dir $output_dir --device "cuda:1" &
python evaluate.py --dataset $dataset --data_dir $data_dir --model $model --press_name $press --threshold -5 --output_dir $output_dir --device "cuda:2" &
python evaluate.py --dataset $dataset --data_dir $data_dir --model $model --press_name $press --threshold -6 --output_dir $output_dir --device "cuda:3" &
wait
done
# Loop 2: presses requiring to compress questions
press_names=("snapkv" "adakv_snapkv" "finch" "chunkkv")
for press in "${press_names[@]}"; do
python evaluate.py --dataset $dataset --data_dir $data_dir --model $model --press_name $press --compression_ratio 0.25 --output_dir $output_dir --device "cuda:0" --query_aware &
python evaluate.py --dataset $dataset --data_dir $data_dir --model $model --press_name $press --compression_ratio 0.50 --output_dir $output_dir --device "cuda:1" --query_aware &
python evaluate.py --dataset $dataset --data_dir $data_dir --model $model --press_name $press --compression_ratio 0.75 --output_dir $output_dir --device "cuda:2" --query_aware &
python evaluate.py --dataset $dataset --data_dir $data_dir --model $model --press_name $press --compression_ratio 0.875 --output_dir $output_dir --device "cuda:3" --query_aware &
wait
done
# SPDX-FileCopyrightText: Copyright (c) 1993-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
from kvpress.attention_patch import patch_attention_functions
from kvpress.pipeline import KVPressTextGenerationPipeline
from kvpress.presses.adakv_press import AdaKVPress
from kvpress.presses.base_press import SUPPORTED_MODELS, BasePress
from kvpress.presses.block_press import BlockPress
from kvpress.presses.chunk_press import ChunkPress
from kvpress.presses.chunkkv_press import ChunkKVPress
from kvpress.presses.compactor_press import CompactorPress
from kvpress.presses.composed_press import ComposedPress
from kvpress.presses.criticalkv_press import CriticalAdaKVPress, CriticalKVPress
from kvpress.presses.cur_press import CURPress
from kvpress.presses.decoding_press import DecodingPress
from kvpress.presses.dms_press import DMSPress
from kvpress.presses.duo_attention_press import DuoAttentionPress
from kvpress.presses.expected_attention_press import ExpectedAttentionPress
from kvpress.presses.expected_attention_with_stats import ExpectedAttentionStatsPress
from kvpress.presses.fastkvzip_press import FastKVzipPress
from kvpress.presses.finch_press import FinchPress
from kvpress.presses.key_rerotation_press import KeyRerotationPress
from kvpress.presses.keydiff_press import KeyDiffPress
from kvpress.presses.knorm_press import KnormPress
from kvpress.presses.kvzap_press import KVzapPress
from kvpress.presses.kvzip_press import KVzipPress
from kvpress.presses.lagkv_press import LagKVPress
from kvpress.presses.leverage_press import LeverageScorePress
from kvpress.presses.non_causal_attention_press import NonCausalAttnPress
from kvpress.presses.observed_attention_press import ObservedAttentionPress
from kvpress.presses.per_layer_compression_press import PerLayerCompressionPress
from kvpress.presses.prefill_decoding_press import PrefillDecodingPress
from kvpress.presses.pyramidkv_press import PyramidKVPress
from kvpress.presses.qfilter_press import QFilterPress
from kvpress.presses.random_press import RandomPress
from kvpress.presses.scorer_press import ScorerPress
from kvpress.presses.simlayerkv_press import SimLayerKVPress
from kvpress.presses.snapkv_press import SnapKVPress
from kvpress.presses.streaming_llm_press import StreamingLLMPress
from kvpress.presses.think_press import ThinKPress
from kvpress.presses.tova_press import TOVAPress
# Patch the attention functions to support head-wise compression
patch_attention_functions()
__all__ = [
"CriticalAdaKVPress",
"CriticalKVPress",
"CURPress",
"AdaKVPress",
"BasePress",
"ComposedPress",
"ScorerPress",
"ExpectedAttentionPress",
"KnormPress",
"ObservedAttentionPress",
"RandomPress",
"SimLayerKVPress",
"SnapKVPress",
"StreamingLLMPress",
"ThinKPress",
"TOVAPress",
"KVPressTextGenerationPipeline",
"PerLayerCompressionPress",
"KeyRerotationPress",
"ChunkPress",
"DuoAttentionPress",
"ChunkKVPress",
"QFilterPress",
"PyramidKVPress",
"FinchPress",
"LagKVPress",
"BlockPress",
"KeyDiffPress",
"KVzipPress",
"ExpectedAttentionStatsPress",
"DecodingPress",
"PrefillDecodingPress",
"CompactorPress",
"LeverageScorePress",
"NonCausalAttnPress",
"KVzapPress",
"DMSPress",
"FastKVzipPress",
]
# SPDX-FileCopyrightText: Copyright (c) 1993-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
import torch
from transformers.modeling_utils import ALL_ATTENTION_FUNCTIONS
def search_hyperplane(X, max_iter: int = 1000):
"""
Given a tensor X of shape (bsz, seq_len, head_dim), search for a hyperplane Y (bsz, head_dim)
such that for every i, <X[:, i], Y> <= 0. Returns - 1e5 * Y / ||Y|| ** 2 to ensure exp(<X, Y>) = 0
Raises a ValueError if no such hyperplane is found
Parameters
----------
X : torch.Tensor
Query tensor with shape (batch_size, seq_len, head_dim) representing
the query vectors for which we want to find a nullifying hyperplane.
max_iter : int, default=1000
Maximum number of iterations to search for the hyperplane. If no valid
hyperplane is found within this limit, a ValueError is raised.
Returns
-------
torch.Tensor
Hyperplane tensor with shape (batch_size, head_dim) scaled by -1e5 / ||Y||²
to ensure that exp(<X, Y>) ≈ 0 for all queries in X.
Raises
------
ValueError
If no valid hyperplane is found within max_iter iterations.
"""
Y = X.mean(1) # this initialization is enough for most cases
for _ in range(max_iter):
mask = torch.bmm(X, Y.unsqueeze(-1)) <= 0
if not mask.any():
return -1e5 * Y / Y.norm(dim=-1, keepdim=True) ** 2
Y += (X * mask).sum(1) / mask.sum(1).clamp(min=1)
raise ValueError("Could not find fake keys such that for every query q, exp(<q, k>) = 0")
def attention_patch(func):
"""
Decorator to update the keys before the attention computation at the indices provided in module.masked_key_indices
The keys are updated with a fake key k such that exp(<q, k>) = 0 to fake head-wise compression
This solution is not optimal as it does not reduce peak memory and slightly increases runtime
Parameters
----------
func : callable
The original attention function to be patched. Should accept parameters
(module, query, key, value, attention_mask, dropout, **kwargs).
Returns
-------
callable
The wrapped attention function that supports head-wise key masking.
"""
def wrapper(module, query, key, value, attention_mask, dropout, **kwargs):
if query.shape[2] == key.shape[2]:
# Prefilling
module.masked_key_indices = None
elif getattr(module, "masked_key_indices", None) is not None:
# Decoding: build fake keys k s.t. exp(<q, k>) = 0
bsz, num_heads, seq_len, head_dim = query.shape
num_key_value_heads = key.shape[1]
num_groups = num_heads // num_key_value_heads
# Build a fake key k per key group such that for every query q, exp(<q, k>) = 0
q = query.view(bsz, num_key_value_heads, num_groups, seq_len, head_dim)
q = q.reshape(bsz * num_key_value_heads, num_groups * seq_len, head_dim)
k = search_hyperplane(q)
k = k.view(bsz, num_key_value_heads, head_dim)
# At indices, update the keys to the fake keys
batch_indices, head_indices, seq_indices = module.masked_key_indices
key[batch_indices, head_indices, seq_indices] = k[batch_indices, head_indices]
# see https://github.com/NVIDIA/kvpress/pull/115#issuecomment-3183785597
# cu_seq_lens_k are only in kwargs if model.generate is used.
if "cu_seq_lens_k" in kwargs:
kwargs["cu_seq_lens_k"][-1] = key.shape[-2]
return func(module, query, key, value, attention_mask, dropout, **kwargs)
return wrapper
def patch_attention_functions():
"""
Apply attention patching to all transformer attention functions.
This function automatically patches all attention functions registered in
transformers' ALL_ATTENTION_FUNCTIONS to support head-wise key masking.
It enables KVPress compression methods that require head-specific masking
(like AdaKV) to work correctly during text generation.
The patching is applied globally and affects all transformer models loaded
after this function is called. It's automatically called when importing
kvpress to ensure compatibility with head-wise compression methods.
Notes
-----
This function modifies the global attention functions in the transformers
library. The modifications do not affect models that don't use head-wise compression (i.e. don't have
module.masked_key_indices).
"""
for name, func in ALL_ATTENTION_FUNCTIONS.items():
ALL_ATTENTION_FUNCTIONS[name] = attention_patch(func)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment