Commit c16d506e authored by chenzk

v1.0
# SPDX-FileCopyrightText: Copyright (c) 1993-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
from kvpress.presses.duo_attention_press import PATTERNS_DICT, DuoAttentionPress
def test_load_attention_pattern():
for model_name in PATTERNS_DICT:
model = type("model", (), {"config": type("config", (), {"name_or_path": model_name})})()
DuoAttentionPress.load_attention_pattern(model)
# SPDX-FileCopyrightText: Copyright (c) 1993-2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
from kvpress.presses.expected_attention_with_stats import ExpectedAttentionStats, ExpectedAttentionStatsPress
def test_load_stats():
for stats_id in ExpectedAttentionStatsPress.available_stats():
ExpectedAttentionStats.from_pretrained(stats_id)
# SPDX-FileCopyrightText: Copyright (c) 1993-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
import torch
from kvpress import FinchPress
from tests.fixtures import unit_test_model # noqa: F401
def test_finch_press(unit_test_model): # noqa: F811
for press in [
FinchPress(0.5),
FinchPress(0.5, rerotate_keys=False),
FinchPress(0.5, normalize_scores=False),
FinchPress(0.2, chunk_length=5),
]:
press.delimiter_token_id = unit_test_model.config.eos_token_id
with press(unit_test_model):
input_ids = torch.arange(10, 20).to(unit_test_model.device)
input_ids[8] = press.delimiter_token_id
unit_test_model(input_ids.unsqueeze(0))
# SPDX-FileCopyrightText: Copyright (c) 1993-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
import pytest
import torch
from transformers import AutoTokenizer
from transformers.utils import is_flash_attn_2_available
from kvpress import KnormPress
from tests.fixtures import kv_press_qwen3_flash_attn_pipeline # noqa: F401
class TestFlashAttention:
@pytest.mark.skipif(not torch.cuda.is_available(), reason="GPU is not available")
@pytest.mark.skipif(not is_flash_attn_2_available(), reason="flash_attn is not installed")
def test_fa_works(self, kv_press_qwen3_flash_attn_pipeline): # noqa: F811
# test if fa2 runs, see https://github.com/huggingface/transformers/releases/tag/v4.55.2
# and https://github.com/NVIDIA/kvpress/pull/115
model = kv_press_qwen3_flash_attn_pipeline.model
tok = AutoTokenizer.from_pretrained("Qwen/Qwen3-4B-Instruct-2507")
inputs = tok("Hello, how are you? bla bla how are you? this is some text lala ddd", return_tensors="pt").to(
model.device
)
with KnormPress(0.8)(model):
model.generate(**inputs, max_new_tokens=10, do_sample=False)
# SPDX-FileCopyrightText: Copyright (c) 1993-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
import pytest
import torch
from transformers import DynamicCache
from kvpress import AdaKVPress, CriticalAdaKVPress, DMSPress, KnormPress, KVzipPress, RandomPress
from tests.fixtures import kv_press_unit_test_pipeline, unit_test_model # noqa: F401
def compute_masked_percentage(module, batch_size, num_key_value_heads, seq_len):
"""
Compute the percentage of masked indices from module.masked_key_indices.
"""
if module.masked_key_indices is None:
return 0.0
batch_indices, head_indices, seq_indices = module.masked_key_indices
num_masked = len(batch_indices)
total_positions = batch_size * num_key_value_heads * seq_len
masked_percentage = num_masked / total_positions
return masked_percentage
@pytest.mark.parametrize("wrapper_press", [AdaKVPress, CriticalAdaKVPress])
@pytest.mark.parametrize("compression_ratio", [0.2, 0.4, 0.6, 0.8])
def test_wrapper_head_compression(unit_test_model, wrapper_press, compression_ratio): # noqa: F811
p = KnormPress(compression_ratio=compression_ratio)
press = wrapper_press(press=p)
with press(unit_test_model):
input_ids = torch.randint(0, 1024, (1, 128)).to(unit_test_model.device)
unit_test_model(input_ids, past_key_values=DynamicCache()).past_key_values
assert unit_test_model.model.layers[0].self_attn.masked_key_indices is not None
headwise_compression_ratio = 0.0
for layer in unit_test_model.model.layers:
cr = compute_masked_percentage(layer.self_attn, 1, unit_test_model.config.num_key_value_heads, 128)
headwise_compression_ratio += cr
cumulative_compression_ratio = headwise_compression_ratio / len(unit_test_model.model.layers)
assert abs(cumulative_compression_ratio - press.compression_ratio) < 1e-2 # tolerate small differences
# Only for KVzipPress, since it's the only non-wrapper press with head compression (apart from Duo)
@pytest.mark.parametrize("press", [KVzipPress])
@pytest.mark.parametrize("compression_ratio", [0.2, 0.4, 0.6, 0.8])
@pytest.mark.parametrize("layerwise", [True, False])
def test_head_compression(unit_test_model, press, compression_ratio, layerwise): # noqa: F811
press = KVzipPress(compression_ratio=compression_ratio, layerwise=layerwise)
with press(unit_test_model):
input_ids = torch.randint(0, 1024, (1, 128)).to(unit_test_model.device)
unit_test_model(input_ids, past_key_values=DynamicCache()).past_key_values
assert unit_test_model.model.layers[0].self_attn.masked_key_indices is not None
headwise_compression_ratio = 0.0
for layer in unit_test_model.model.layers:
cr = compute_masked_percentage(layer.self_attn, 1, unit_test_model.config.num_key_value_heads, 128)
headwise_compression_ratio += cr
cumulative_compression_ratio = headwise_compression_ratio / len(unit_test_model.model.layers)
assert abs(cumulative_compression_ratio - press.compression_ratio) < 1e-2 # tolerate small differences
def test_dms_press_compression_ratio(kv_press_unit_test_pipeline): # noqa: F811
"""Test that DMSPress.compression_ratio matches the actual masked percentage."""
press = DMSPress(
press=RandomPress(),
threshold=0.5,
sliding_window_size=0,
decoding=True,
)
prompt = "What is the best KV cache compression library in the world ?"
max_new_tokens = 10
kv_press_unit_test_pipeline(prompt, press=press, max_new_tokens=max_new_tokens)
model = kv_press_unit_test_pipeline.model
num_key_value_heads = model.config.num_key_value_heads
# Compute seq_len by reusing the pipeline's preprocess method
preprocessed = kv_press_unit_test_pipeline.preprocess(prompt, [""], answer_prefix="", max_context_length=10000)
seq_len = preprocessed["context_ids"].shape[1] + preprocessed["questions_ids"][0].shape[1] + max_new_tokens - 1
# Compute compression ratio from masked indices
headwise_compression_ratio = 0.0
for layer in model.model.layers:
cr = compute_masked_percentage(layer.self_attn, 1, num_key_value_heads, seq_len)
headwise_compression_ratio += cr
cumulative_compression_ratio = headwise_compression_ratio / len(model.model.layers)
assert cumulative_compression_ratio == press.compression_ratio
# SPDX-FileCopyrightText: Copyright (c) 1993-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Extended to test both the *default* and the *YaRN-scaled* rotary-embedding
# variants with the smallest possible code changes.
import inspect
from copy import deepcopy
from dataclasses import dataclass
import pytest
import torch
from torch import nn
from transformers import Gemma3ForCausalLM
from transformers.models.llama.modeling_llama import LlamaAttention, LlamaForCausalLM, LlamaRotaryEmbedding, rotate_half
from kvpress import KeyRerotationPress, ScorerPress
from tests.fixtures import unit_test_model # noqa: F401
@pytest.mark.parametrize("rope_variant", ["default", "yarn"])
@pytest.mark.parametrize("precision", ["full", "half"])
def test_rerotate_keys_matches_reference_implementation(
unit_test_model: LlamaForCausalLM, # noqa: F811
rope_variant,
precision,
):
"""
Compare KeyRerotationPress' rerotation of keys with the reference
implementation.
Reference path:
1. keys = W_k * hidden_states
2. keys_pruned = prune(keys)
3. keys = RoPE(keys_pruned)
Press path:
1. keys = W_k * hidden_states
2. keys = RoPE(keys)
3. keys_pruned = KeyRerotationPress.rerotate_keys(...)
"""
if rope_variant == "yarn":
layer0 = unit_test_model.model.layers[0]
cfg = deepcopy(layer0.self_attn.config)
cfg.rope_scaling = {
"factor": 4.0,
"original_max_position_embeddings": 32768,
"rope_type": "yarn",
}
cfg.max_position_embeddings = 131072
cfg.rope_theta = 500000.0
try:
unit_test_model.model.rotary_emb = LlamaRotaryEmbedding(cfg, device=unit_test_model.device)
except KeyError:
pytest.skip("YaRN rotary-embedding not available in this transformers version.")
for layer in unit_test_model.model.layers:
if isinstance(unit_test_model, Gemma3ForCausalLM) and layer.is_sliding:
# Skip layers with sliding window attention, only for Gemma3
continue
layer.self_attn.rotary_emb = unit_test_model.model.rotary_emb
if precision == "half" and torch.cuda.is_available():
unit_test_model = unit_test_model.cuda().half()
elif precision == "half":
pytest.skip("Half-precision test skipped because CUDA is not available.")
elif precision == "full":
unit_test_model = unit_test_model.float()
original_press = RandomPressStoreIndices(compression_ratio=0.5)
key_rerotation_press = KeyRerotationPress(press=original_press)
with key_rerotation_press(unit_test_model):
module = unit_test_model.model.layers[0].self_attn
hidden_states = torch.randn(
8, 64, module.config.hidden_size, device=unit_test_model.device, dtype=unit_test_model.dtype
)
keys = get_keys_with_rope(module, hidden_states)
values = torch.randn_like(keys)
# Press result
keys_compressed, _ = key_rerotation_press.compress(
module,
hidden_states,
keys,
values,
attentions=None,
kwargs={},
)
indices = original_press.indices
keys_compressed_ref = compute_rerotated_keys_comparison_implementation(module, hidden_states, indices)
assert torch.allclose(keys_compressed, keys_compressed_ref, atol=1e-6 if precision == "full" else 1e-3)
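# A minimal standalone sketch (not the kvpress implementation) of the rerotation identity the test
# above relies on: a key rotated at its original position can be un-rotated (cos stays, sin flips
# sign) and then re-rotated at its new, contiguous position after pruning. `rerotate_key_sketch`
# is a hypothetical helper; it reuses transformers' rotate_half imported above, and cos/sin are
# assumed to broadcast against the key tensor.
def rerotate_key_sketch(key, cos_old, sin_old, cos_new, sin_new):
    # inverse RoPE at the original position (rotation by -theta)
    key_unrotated = key * cos_old - rotate_half(key) * sin_old
    # forward RoPE at the new position
    return key_unrotated * cos_new + rotate_half(key_unrotated) * sin_new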
def get_keys_with_rope(module, hidden_states):
# Compute keys with RoPE
keys = get_keys_without_pos_embedding(module, hidden_states)
cos, sin = get_rope_embeddings(module, keys)
keys = (keys * cos.unsqueeze(1)) + (rotate_half(keys) * sin.unsqueeze(1))
return keys
@dataclass
class RandomPressStoreIndices(ScorerPress):
compression_ratio: float = 0.0
seed: int = 0
def __post_init__(self):
self.indices = None
super().__post_init__()
def score(
self,
module: nn.Module,
hidden_states: torch.Tensor,
keys: torch.Tensor,
values: torch.Tensor,
attentions: torch.Tensor,
kwargs,
) -> torch.Tensor:
torch.manual_seed(self.seed)
scores = torch.rand(*keys.shape[:-1]).to(keys.device, keys.dtype)
# Get indices of KV pairs with the lowest scores
q_len = hidden_states.shape[1]
n_kept = int(q_len * (1 - self.compression_ratio))
indices = scores.topk(n_kept, dim=-1).indices
indices = torch.sort(indices, dim=2).values
indices = indices.unsqueeze(-1).expand(-1, -1, -1, module.head_dim)
self.indices = indices
return scores
def compute_rerotated_keys_comparison_implementation(module: LlamaAttention, hidden_states, indices):
"""
Computes the rerotated keys for the given indices.
1. keys = W_k * hidden_states
2. keys_pruned = prune(keys)
3. keys = RoPE(keys_pruned)
"""
# 1.
keys = get_keys_without_pos_embedding(module, hidden_states)
# 2.
keys = keys.gather(2, indices).contiguous()
# 3.
cos, sin = get_rope_embeddings(module, keys)
keys = (keys * cos.unsqueeze(1)) + (rotate_half(keys) * sin.unsqueeze(1))
return keys
def get_keys_without_pos_embedding(module, hidden_states):
key_states = module.k_proj(hidden_states)
key_states = key_states.view(
key_states.shape[0], key_states.shape[1], module.config.num_key_value_heads, module.head_dim
).transpose(1, 2)
return key_states
def get_rope_embeddings(module, x):
length = x.shape[2]
# rotary_emb function only needs .device and .dtype, so we can plug in any tensor regardless of shape
if "position_ids" in inspect.signature(module.rotary_emb.forward).parameters:
position_ids = torch.arange(length).unsqueeze(0).to(x.device)
cos, sin = module.rotary_emb(x, position_ids)
else:
cos, sin = module.rotary_emb(x, length)
return cos, sin
# SPDX-FileCopyrightText: Copyright (c) 1993-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
from dataclasses import dataclass
import pytest
import torch
from torch import nn
from transformers import DynamicCache
from kvpress import (
AdaKVPress,
ChunkKVPress,
ChunkPress,
ComposedPress,
CriticalAdaKVPress,
CriticalKVPress,
DMSPress,
FastKVzipPress,
KeyRerotationPress,
KnormPress,
KVzipPress,
ObservedAttentionPress,
ScorerPress,
SnapKVPress,
ThinKPress,
)
from tests.default_presses import default_presses
from tests.fixtures import unit_test_model, unit_test_model_output_attention # noqa: F401
def test_composed_press(unit_test_model): # noqa: F811
press1 = KnormPress(compression_ratio=0.5)
press2 = ThinKPress(key_channel_compression_ratio=0.5, window_size=2)
composed_press = ComposedPress([press1, press2])
with composed_press(unit_test_model):
input_ids = unit_test_model.dummy_inputs["input_ids"].to(unit_test_model.device)
unit_test_model(input_ids, past_key_values=DynamicCache()).past_key_values
def test_chunk_press(unit_test_model): # noqa: F811
press = KnormPress(compression_ratio=0.5)
for chunk_length in [2, 4, 8, 128]:
composed_press = ChunkPress(press=press, chunk_length=chunk_length)
with composed_press(unit_test_model):
input_ids = torch.randint(0, 1024, (1, 256), device=unit_test_model.device)
cache = DynamicCache()
unit_test_model(input_ids, past_key_values=cache).past_key_values
assert cache.get_seq_length() == 128
def test_chunkkv_press(unit_test_model): # noqa: F811
press = SnapKVPress(compression_ratio=0.5)
for chunk_length in [2, 4, 8, 128]:
composed_press = ChunkKVPress(press=press, chunk_length=chunk_length)
with composed_press(unit_test_model):
input_ids = torch.randint(0, 1024, (1, 256), device=unit_test_model.device)
cache = DynamicCache()
unit_test_model(input_ids, past_key_values=cache).past_key_values
assert cache.get_seq_length() == 128
@pytest.mark.parametrize("press_dict", default_presses)
@pytest.mark.parametrize(
"wrapper_press",
[
None,
ComposedPress,
KeyRerotationPress,
AdaKVPress,
ChunkPress,
CriticalKVPress,
CriticalAdaKVPress,
DMSPress,
],
)
def test_presses_run(unit_test_model, press_dict, wrapper_press): # noqa: F811
cls = press_dict["cls"]
for kwargs in press_dict["kwargs"]:
press = cls(**kwargs)
if wrapper_press is not None:
if hasattr(press, "post_init_from_model"):
press.post_init_from_model(unit_test_model)
if issubclass(wrapper_press, ComposedPress):
if isinstance(press, (KVzipPress, FastKVzipPress)):
# KVzipPress and FastKVzipPress are currently not compatible with ComposedPress
return
press = ComposedPress(presses=[press])
elif not isinstance(press, ScorerPress): # remaining wrapper presses only support ScorerPress
return
elif issubclass(wrapper_press, (KeyRerotationPress, AdaKVPress, CriticalKVPress, CriticalAdaKVPress)):
press = wrapper_press(press=press)
elif issubclass(wrapper_press, ChunkPress):
press = ChunkPress(press=press, chunk_length=24)
elif issubclass(wrapper_press, DMSPress):
press = DMSPress(press=press, threshold=-0.5, sliding_window_size=32)
# TODO: Handle post_init_from_model differently
if hasattr(press, "post_init_from_model"):
press.post_init_from_model(unit_test_model)
with press(unit_test_model):
input_ids = torch.randint(0, 1024, (1, 128), device=unit_test_model.device)
unit_test_model(input_ids, past_key_values=DynamicCache()).past_key_values
# Check that the press has a compression_ratio attribute
assert hasattr(press, "compression_ratio")
def test_presses_run_observed_attention(unit_test_model_output_attention): # noqa: F811
for cls in [ObservedAttentionPress]:
for compression_ratio in [0.2, 0.8]:
press = cls(compression_ratio=compression_ratio)
with press(unit_test_model_output_attention):
input_ids = unit_test_model_output_attention.dummy_inputs["input_ids"].to(
unit_test_model_output_attention.device
)
unit_test_model_output_attention(input_ids, past_key_values=DynamicCache()).past_key_values
@dataclass
class StoreKnormPress(ScorerPress):
def __post_init__(self):
self.scores = []
def score(
self,
module: nn.Module,
hidden_states: torch.Tensor,
keys: torch.Tensor,
values: torch.Tensor,
attentions: torch.Tensor,
kwargs,
) -> torch.Tensor:
scores = -keys.norm(dim=-1)
self.scores.append(scores)
return scores
@torch.no_grad()
def test_presses_keep_highest_score(unit_test_model): # noqa: F811
"""
Test that kept keys are those with the highest score
"""
for compression_ratio in [0.0, 0.2, 0.4, 0.6, 0.8]:
press = StoreKnormPress(compression_ratio=compression_ratio)
with press(unit_test_model):
input_ids = torch.randint(0, 3_000, (5, 256), device=unit_test_model.device)
past_key_values = unit_test_model(input_ids, past_key_values=DynamicCache()).past_key_values
keys = [layer.keys for layer in past_key_values.layers]
for scores, key in zip(press.scores, keys):
max_scores = -key.norm(dim=-1)
for batch_idx in range(scores.shape[0]):
for head_idx in range(scores.shape[1]):
assert torch.allclose(
scores[batch_idx, head_idx].sort().values[-max_scores.shape[-1] :],
max_scores[batch_idx, head_idx].sort().values,
)
# SPDX-FileCopyrightText: Copyright (c) 1993-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
import pytest
import torch.nn as nn
from kvpress.presses.pyramidkv_press import PyramidKVPress
class MockConfig:
def __init__(self, num_hidden_layers):
self.num_hidden_layers = num_hidden_layers
class MockModule(nn.Module):
def __init__(self, config, layer_idx):
super().__init__()
self.config = config
self.layer_idx = layer_idx
def scorer_press_layer_budget(q_len, compression_ratio):
return round(q_len * (1 - compression_ratio))
@pytest.mark.parametrize("layer_budget_func", ["pyramidkv_press_layer_budget", "scorer_press_layer_budget"])
@pytest.mark.parametrize("num_hidden_layers", [32, 64, 128])
@pytest.mark.parametrize("compression_ratio", [0.1, 0.25, 0.3, 0.5, 0.6, 0.75, 0.8])
@pytest.mark.parametrize("q_len", [1024, 2787, 4096, 6591, 8192])
def test_mean_layer_budget(layer_budget_func, num_hidden_layers, compression_ratio, q_len):
total_n_kept = 0
if layer_budget_func == "pyramidkv_press_layer_budget":
config = MockConfig(num_hidden_layers)
press = PyramidKVPress()
press.compression_ratio = compression_ratio
for layer_idx in range(num_hidden_layers):
if layer_budget_func == "pyramidkv_press_layer_budget":
module = MockModule(config, layer_idx)
n_kept = press.get_layer_budget(module, q_len)
elif layer_budget_func == "scorer_press_layer_budget":
n_kept = scorer_press_layer_budget(q_len, compression_ratio)
else:
raise ValueError(f"Unsupported layer_budget_func: {layer_budget_func}")
total_n_kept += n_kept
mean_n_kept = total_n_kept / num_hidden_layers
assert mean_n_kept == pytest.approx(q_len * (1 - compression_ratio), rel=1e-3)
# SPDX-FileCopyrightText: Copyright (c) 1993-2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
from kvpress.presses.qfilter_press import QFilterPress
def test_load_qfilters():
for model_name in QFilterPress.available_qfilters():
QFilterPress.load_q_filters(model_name)
# SPDX-FileCopyrightText: Copyright (c) 1993-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
import torch
from kvpress.attention_patch import search_hyperplane
def test_search_hyperplane():
device = "cuda:0" if torch.cuda.is_available() else "cpu"
bsz, seq_len, head_dim = 50, 500, 128
X = torch.rand(bsz, seq_len, head_dim, device=device)
Y = search_hyperplane(X)
# X @ Y should be so negative for every row that exp(X @ Y) underflows to exactly 0
assert torch.exp(torch.bmm(X, Y.unsqueeze(-1))).max() == 0
# SPDX-FileCopyrightText: Copyright (c) 1993-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
"""
Test script to verify that DecodingPress actually compresses during decoding.
"""
import logging
import pytest
import torch
from transformers import DynamicCache, pipeline
from kvpress import (
CompactorPress,
DecodingPress,
KnormPress,
KVzapPress,
LeverageScorePress,
NonCausalAttnPress,
PrefillDecodingPress,
PyramidKVPress,
ScorerPress,
)
from tests.default_presses import default_presses
logger = logging.getLogger(__name__)
@pytest.mark.parametrize("token_buffer_size", [32, 64, 128])
def test_decoding_compression(token_buffer_size):
"""Test that DecodingPress compresses the cache during decoding."""
# Initialize pipeline with a small model
pipe = pipeline("kv-press-text-generation", model="MaxJeblick/llama2-0b-unit-test", device_map="auto")
# Create a DecodingPress with KnormPress
press = DecodingPress(
base_press=KnormPress(compression_ratio=0.5), # Remove 50% of tokens
compression_interval=4, # Compress every 4 tokens
target_size=token_buffer_size,
)
# Create cache
cache = DynamicCache()
# Test context and question
context = "The quick brown fox jumps over the lazy dog. " * 10 # Repeat for longer context
question = "What animal jumps over the dog?"
# Run pipeline
pipe(context, question=question, press=press, cache=cache, max_new_tokens=20)
# Assert that all layers have the expected cache size
for layer_idx, cache_layer in enumerate(cache.layers):
layer_seq_len = cache_layer.keys.shape[2]
# Allow for compression step interval: cache can be up to compression_steps-1 tokens larger
max_expected_size = token_buffer_size + press.compression_interval - 1
assert layer_seq_len <= max_expected_size, (
f"Layer {layer_idx}: Expected cache sequence length to be between {token_buffer_size} "
f"and {max_expected_size}, but got {layer_seq_len}"
)
def test_prefill_decoding_press_calls_both_phases():
"""Test that PrefillDecodingPress calls both prefilling and decoding presses."""
# Initialize pipeline
pipe = pipeline("kv-press-text-generation", model="MaxJeblick/llama2-0b-unit-test", device_map="auto")
# Create PrefillDecodingPress with both presses
combined_press = PrefillDecodingPress(
prefilling_press=KnormPress(compression_ratio=0.6), # Compress to 60% during prefill
decoding_press=DecodingPress(base_press=KnormPress(), compression_interval=3, target_size=48),
)
# Test context and question
context = "The quick brown fox jumps over the lazy dog. " * 12 # Longer context
question = "What animal jumps over the dog?"
# Run pipeline
cache = DynamicCache()
pipe(context, question=question, press=combined_press, cache=cache, max_new_tokens=15)
# Check that cache was compressed during both phases
# Final cache should be compressed to decoding press target size
for layer_idx, cache_layer in enumerate(cache.layers):
layer_seq_len = cache_layer.keys.shape[2]
# Allow for compression step interval: cache can be up to compression_steps-1 tokens larger
target_size = 48 # token_buffer_size from decoding press
compression_steps = 3 # from the decoding press configuration
max_expected_size = target_size + compression_steps - 1
assert target_size <= layer_seq_len <= max_expected_size, (
f"Layer {layer_idx}: Expected final cache size to be between {target_size} "
f"and {max_expected_size} (decoding target), but got {layer_seq_len}"
)
def test_decoding_press_without_prefill():
"""Test that DecodingPress works correctly when used standalone (no prefill compression)."""
# Initialize pipeline
pipe = pipeline("kv-press-text-generation", model="MaxJeblick/llama2-0b-unit-test", device_map="auto")
# Create DecodingPress only
decoding_press = DecodingPress(base_press=KnormPress(compression_ratio=0.4), compression_interval=5, target_size=64)
# Test context and question
context = "The quick brown fox jumps over the lazy dog. " * 8
question = "What animal jumps over the dog?"
# Run pipeline
cache = DynamicCache()
pipe(context, question=question, press=decoding_press, cache=cache, max_new_tokens=25)
# Check that cache was compressed during decoding
for layer_idx, cache_layer in enumerate(cache.layers):
layer_seq_len = cache_layer.keys.shape[2]
# Allow for compression step interval: cache can be up to compression_steps-1 tokens larger
target_size = 64
compression_steps = 5 # from the decoding press configuration
max_expected_size = target_size + compression_steps - 1
assert target_size <= layer_seq_len <= max_expected_size, (
f"Layer {layer_idx}: Expected cache size to be between {target_size} "
f"and {max_expected_size}, but got {layer_seq_len}"
)
def test_prefill_decoding_press_decoding_only():
"""Test PrefillDecodingPress with only decoding press (no prefill compression)."""
# Initialize pipeline
pipe = pipeline("kv-press-text-generation", model="MaxJeblick/llama2-0b-unit-test", device_map="auto")
# Create PrefillDecodingPress with only decoding press
combined_press = PrefillDecodingPress(
prefilling_press=None,
decoding_press=DecodingPress(
base_press=KnormPress(compression_ratio=0.6), compression_interval=4, target_size=56
),
)
# Test context and question
context = "The quick brown fox jumps over the lazy dog. " * 9
question = "What animal jumps over the dog?"
# Run pipeline
cache = DynamicCache()
pipe(context, question=question, press=combined_press, cache=cache, max_new_tokens=12)
# Check that only decoding compression was applied
for layer_idx, cache_layer in enumerate(cache.layers):
layer_seq_len = cache_layer.keys.shape[2]
# Allow for compression step interval: cache can be up to compression_steps-1 tokens larger
target_size = 56
compression_steps = 4 # from the decoding press configuration
max_expected_size = target_size + compression_steps - 1
assert target_size <= layer_seq_len <= max_expected_size, (
f"Layer {layer_idx}: Expected cache size to be between {target_size} "
f"and {max_expected_size}, but got {layer_seq_len}"
)
def test_decoding_press_equivalence():
"""Test that DecodingPress standalone yields same result as PrefillDecodingPress with decoding only."""
# Set random seed for reproducibility
torch.manual_seed(42)
# Initialize pipeline
pipe = pipeline("kv-press-text-generation", model="MaxJeblick/llama2-0b-unit-test", device_map="auto")
# Create standalone decoding press
decoding_press = DecodingPress(base_press=KnormPress(compression_ratio=0.5), compression_interval=3, target_size=52)
# Create PrefillDecodingPress with only decoding press
combined_press = PrefillDecodingPress(
prefilling_press=None,
decoding_press=DecodingPress(
base_press=KnormPress(compression_ratio=0.5), compression_interval=3, target_size=52
),
)
# Test context and question
context = "The quick brown fox jumps over the lazy dog. " * 7
question = "What animal jumps over the dog?"
# Run with standalone decoding press
cache1 = DynamicCache()
result1 = pipe(context, question=question, press=decoding_press, cache=cache1, max_new_tokens=10)
# Run with combined press (decoding only)
cache2 = DynamicCache()
result2 = pipe(context, question=question, press=combined_press, cache=cache2, max_new_tokens=10)
# Compare cache sizes (should be identical)
for layer_idx in range(len(cache1.layers)):
cache1_size = cache1.layers[layer_idx].keys.shape[2]
cache2_size = cache2.layers[layer_idx].keys.shape[2]
assert cache1_size == cache2_size, (
f"Layer {layer_idx}: Standalone decoding cache size {cache1_size} != "
f"combined press cache size {cache2_size}"
)
# Compare generated text results (should be identical)
assert result1["answer"] == result2["answer"], (
f"Generated answers differ:\n"
f"Standalone decoding: '{result1['answer']}'\n"
f"Combined press: '{result2['answer']}'"
)
"""
E AttributeError: 'QFilterPress' object has no attribute 'q_filters'
E Failed: DecodingPress failed with SnapKVPress: shape '[1, 2, 2, 6]' is invalid for input of size 12
> query_states = query_states.view(bsz, window_size, num_heads, head_dim).transpose(1, 2)
E RuntimeError: shape '[1, 2, 2, 6]' is invalid for input of size 12
"""
@pytest.mark.parametrize("press_config", default_presses)
def test_all_presses_work_with_decoding_press(press_config):
"""Test that all default presses work as base presses for DecodingPress."""
# Initialize pipeline
pipe = pipeline("kv-press-text-generation", model="MaxJeblick/llama2-0b-unit-test", device_map="auto")
# Get press class and use the first (easier) configuration
press_cls = press_config["cls"]
press_kwargs = press_config["kwargs"][0] # Use easier compression settings
base_press = press_cls(**press_kwargs)
if not isinstance(base_press, ScorerPress):
logger.info(f"Press {press_cls.__name__} is not a ScorerPress, skipping test")
return
if isinstance(base_press, (PyramidKVPress)):
# PyramidKVPress -> Pyramid shape, not compatible with token_buffer_size=48
logger.info(f"Press {press_cls.__name__} is not supported, skipping test")
return
if isinstance(base_press, (CompactorPress, NonCausalAttnPress, LeverageScorePress)):
# CompactorPress -> Meant for prefill scenario.
logger.info(f"Press {press_cls.__name__} is not supported, skipping test")
return
if isinstance(base_press, KVzapPress):
logger.info(f"Press {press_cls.__name__} is not compatible with DecodingPress, skipping test")
return
# Create DecodingPress with this base press
decoding_press = DecodingPress(base_press=base_press, compression_interval=3, target_size=48)
# Test context and question
context = "The quick brown fox jumps over the lazy dog. " * 8
question = "What animal jumps over the dog?"
# Run pipeline
cache = DynamicCache()
result = pipe(context, question=question, press=decoding_press, cache=cache, max_new_tokens=15)
# Verify compression worked
assert len(result["answer"]) > 0, f"No answer generated with {press_cls.__name__}"
# Check that cache was compressed (allow some tolerance for rounding)
for layer_idx, cache_layer in enumerate(cache.layers):
layer_seq_len = cache_layer.keys.shape[2]
# Allow for compression step interval: cache can be up to compression_steps-1 tokens larger
target_size = 48
compression_steps = 3 # from the decoding press configuration
max_expected_size = target_size + compression_steps - 1
assert (
target_size <= layer_seq_len <= max_expected_size
), f"{press_cls.__name__}: Layer {layer_idx} cache size {layer_seq_len} not in expected range [{target_size}-{max_expected_size}]" # noqa: E501
def test_compression_actually_reduces_memory():
"""Test that compression actually reduces memory usage compared to no compression."""
pipe = pipeline("kv-press-text-generation", model="MaxJeblick/llama2-0b-unit-test", device_map="auto")
context = "The quick brown fox jumps over the lazy dog. " * 15 # Long context
question = "What animal jumps over the dog?"
# Run without compression
cache_uncompressed = DynamicCache()
result_uncompressed = pipe(context, question=question, cache=cache_uncompressed, max_new_tokens=25)
# Run with compression
press = DecodingPress(
base_press=KnormPress(compression_ratio=0.3), # Aggressive compression
compression_interval=3,
target_size=40,
)
cache_compressed = DynamicCache()
result_compressed = pipe(context, question=question, press=press, cache=cache_compressed, max_new_tokens=25)
# Calculate memory usage (approximate)
uncompressed_memory = sum(
(cache_layer.values.numel() + cache_layer.keys.numel()) * cache_layer.keys.element_size()
for cache_layer in cache_uncompressed.layers
)
compressed_memory = sum(
(cache_layer.values.numel() + cache_layer.keys.numel()) * cache_layer.keys.element_size()
for cache_layer in cache_compressed.layers
)
# Compression should significantly reduce memory usage
compression_ratio = compressed_memory / uncompressed_memory
assert compression_ratio < 0.6, (
f"Expected compression ratio < 0.6, but got {compression_ratio:.3f} "
f"(compressed: {compressed_memory} bytes, uncompressed: {uncompressed_memory} bytes)"
)
# Both should still generate reasonable answers
assert len(result_uncompressed["answer"]) > 0
assert len(result_compressed["answer"]) > 0
# SPDX-FileCopyrightText: Copyright (c) 1993-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
from kvpress import KnormPress
from tests.fixtures import kv_press_unit_test_pipeline # noqa: F401
def test_generate(kv_press_unit_test_pipeline): # noqa: F811
context = "This is a test article. It was written on 2022-01-01."
press = KnormPress(compression_ratio=0.4)
# Answer with pipeline
pipe_answer = kv_press_unit_test_pipeline(context, press=press, max_new_tokens=10)["answer"]
# Answer with model.generate
context += "\n" # kv press pipeline automatically adds a newline if no chat template
model = kv_press_unit_test_pipeline.model
tokenizer = kv_press_unit_test_pipeline.tokenizer
with press(model):
inputs = tokenizer(context, return_tensors="pt").to(model.device)
outputs = model.generate(**inputs, max_new_tokens=10, do_sample=False)
generate_answer = tokenizer.decode(outputs[0], skip_special_tokens=True)
generate_answer = generate_answer[len(context) :]
assert pipe_answer == generate_answer
# SPDX-FileCopyrightText: Copyright (c) 1993-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
import torch
from transformers import DynamicCache
from kvpress.presses.knorm_press import KnormPress
from kvpress.presses.per_layer_compression_press import PerLayerCompressionPress
from tests.fixtures import unit_test_model # noqa: F401
def test_per_layer_compression_press(unit_test_model): # noqa: F811
press = PerLayerCompressionPress(compression_ratios=[0.1, 1], press=KnormPress())
with press(unit_test_model):
input_ids = torch.randint(0, 3_000, (5, 256), device=unit_test_model.device)
past_key_values = unit_test_model(input_ids, past_key_values=DynamicCache()).past_key_values
assert past_key_values.layers[0].keys.shape == torch.Size([5, 2, 230, 6])
assert past_key_values.layers[1].keys.shape == torch.Size([5, 2, 0, 6])
# SPDX-FileCopyrightText: Copyright (c) 1993-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
import logging
import pytest
import torch
from transformers import AutoTokenizer, DynamicCache, QuantizedCache
from transformers.utils import is_flash_attn_2_available, is_optimum_quanto_available
from kvpress import ExpectedAttentionPress
from kvpress.pipeline import KVPressTextGenerationPipeline
from tests.fixtures import danube_500m_model # noqa: F401
from tests.fixtures import kv_press_danube_pipeline # noqa: F401
from tests.fixtures import unit_test_model # noqa: F401
from tests.fixtures import kv_press_llama3_2_flash_attn_pipeline, kv_press_unit_test_pipeline # noqa: F401
def test_pipeline(kv_press_unit_test_pipeline, caplog): # noqa: F811
with caplog.at_level(logging.DEBUG):
context = "This is a test article. It was written on 2022-01-01."
questions = ["When was this article written?"]
press = ExpectedAttentionPress(compression_ratio=0.4)
answers = kv_press_unit_test_pipeline(context, questions=questions, press=press)["answers"]
assert len(answers) == 1
assert isinstance(answers[0], str)
messages = [record.message for record in caplog.records]
assert "Context Length: 23" in messages, messages
assert "Compressed Context Length: 13" in messages, messages
def test_pipeline_with_cache(kv_press_unit_test_pipeline): # noqa: F811
context = "This is a test article. It was written on 2022-01-01."
questions = ["When was this article written?"]
press = ExpectedAttentionPress(compression_ratio=0.4)
cache = DynamicCache()
answers = kv_press_unit_test_pipeline(context, questions=questions, press=press, cache=cache)["answers"]
assert len(answers) == 1
assert isinstance(answers[0], str)
class TestPipelineFA2:
@pytest.mark.skipif(not torch.cuda.is_available(), reason="GPU is not available")
@pytest.mark.skipif(not is_flash_attn_2_available(), reason="flash_attn is not installed")
@pytest.mark.parametrize("compression_ratio", [0.0, 0.2])
@pytest.mark.xfail(reason="Known issue (https://github.com/huggingface/transformers/issues/42550)", strict=False)
def test_pipeline_fa2(self, kv_press_llama3_2_flash_attn_pipeline, compression_ratio): # noqa: F811
context = "This is a test article. It was written on 2022-01-01."
questions = ["Repeat the last sentence"]
press = ExpectedAttentionPress(compression_ratio=compression_ratio)
cache = DynamicCache()
answers = kv_press_llama3_2_flash_attn_pipeline(
context, questions=questions, press=press, cache=cache, max_new_tokens=6
)["answers"]
assert len(answers) == 1
assert isinstance(answers[0], str)
kv_press_llama3_2_flash_attn_pipeline.model.set_attn_implementation("sdpa")
press = ExpectedAttentionPress(compression_ratio=compression_ratio)
cache = DynamicCache()
answers_sdpa = kv_press_llama3_2_flash_attn_pipeline(
context, questions=questions, press=press, cache=cache, max_new_tokens=6
)["answers"]
kv_press_llama3_2_flash_attn_pipeline.model.set_attn_implementation("flash_attention_2")
assert (
answers_sdpa[0] == answers[0]
), f"Answers from SDPA and Flash Attention 2 should be the same. \n{answers_sdpa[0]}\n{answers[0]}"
assert "This is a test" in answers[0], f"The answer should contain the context sentence, but got {answers[0]}."
@pytest.mark.parametrize("question", ["When was this article written?", ""])
def test_pipeline_single_or_no_question(kv_press_unit_test_pipeline, question, caplog): # noqa: F811
with caplog.at_level(logging.DEBUG):
context = "This is a test article. It was written on 2022-01-01."
press = ExpectedAttentionPress(compression_ratio=0.4)
answer = kv_press_unit_test_pipeline(context, question=question, press=press)["answer"]
assert isinstance(answer, str)
messages = [record.message for record in caplog.records]
assert "Context Length: 23" in messages, messages
assert "Compressed Context Length: 13" in messages, messages
def test_pipeline_no_press_works(kv_press_unit_test_pipeline, caplog): # noqa: F811
context = "This is a test article. It was written on 2022-01-01."
question = "When was this article written?"
kv_press_unit_test_pipeline(context, question=question)
def test_pipeline_answer_is_correct(danube_500m_model, caplog): # noqa: F811
with caplog.at_level(logging.DEBUG):
answers = generate_answer(danube_500m_model)
for answer in answers:
assert answer == "This article was written on January 1, 2022."
messages = [record.message for record in caplog.records]
assert "Context Length: 28" in messages
assert "Compressed Context Length: 16" in messages
@pytest.mark.skipif(not is_optimum_quanto_available(), reason="Optimum Quanto is not available")
def test_pipeline_with_quantized_cache(kv_press_danube_pipeline, caplog): # noqa: F811
with caplog.at_level(logging.DEBUG):
context = "This is a test article. It was written on 2022-01-01."
questions = ["When was this article written?"]
press = ExpectedAttentionPress(compression_ratio=0.4)
cache = QuantizedCache(backend="quanto", config=kv_press_danube_pipeline.model.config, nbits=4)
answers = kv_press_danube_pipeline(context, questions=questions, press=press, cache=cache)["answers"]
assert len(answers) == 1
assert isinstance(answers[0], str)
for answer in answers:
assert answer == "This article was written on January 1, 2022."
messages = [record.message for record in caplog.records]
assert "Context Length: 28" in messages
assert "Compressed Context Length: 16" in messages
def test_pipeline_compresses_context(unit_test_model, caplog): # noqa: F811
with caplog.at_level(logging.DEBUG):
answers = generate_answer(unit_test_model)
assert len(answers) == 2
assert isinstance(answers[0], str)
messages = [record.message for record in caplog.records]
assert "Context Length: 23" in messages, messages
assert "Compressed Context Length: 13" in messages, messages
@torch.no_grad()
def test_pipeline_context_cache_is_invariant(unit_test_model): # noqa: F811
model = unit_test_model
questions = ["When was this article written?"]
tokenizer = AutoTokenizer.from_pretrained(model.config.name_or_path)
device = model.device
compression_pipeline = KVPressTextGenerationPipeline(model=model, tokenizer=tokenizer, device=device)
input_ids_question = tokenizer(questions[0], return_tensors="pt", add_special_tokens=False)["input_ids"].to(device)
seq_len = 256
past_key_values: DynamicCache = model(
input_ids=torch.randint(0, 1000, (1, seq_len), device=device), past_key_values=DynamicCache()
).past_key_values
assert past_key_values.get_seq_length() == seq_len
keys = [layer.keys.clone() for layer in past_key_values.layers]
values = [layer.values.clone() for layer in past_key_values.layers]
cache_seq_lengths = [past_key_values.get_seq_length(layer_idx) for layer_idx in range(len(past_key_values))]
compression_pipeline.generate_answer(input_ids_question, past_key_values, context_length=22, max_new_tokens=10)
compression_pipeline._remove_answer_from_cache(past_key_values, cache_seq_lengths)
assert past_key_values.get_seq_length() == seq_len
assert all([torch.allclose(key, layer.keys) for key, layer in zip(keys, past_key_values.layers)])
assert all([torch.allclose(value, layer.values) for value, layer in zip(values, past_key_values.layers)])
def generate_answer(model):
device = model.device
context = "This is a test article. It was written on 2022-01-01."
questions = ["When was this article written?", "When was this article written?"]
press = ExpectedAttentionPress(compression_ratio=0.4)
tokenizer = AutoTokenizer.from_pretrained(model.config.name_or_path)
answers = KVPressTextGenerationPipeline(model=model, tokenizer=tokenizer, device=device)(
context, questions=questions, press=press
)["answers"]
return answers
# SPDX-FileCopyrightText: Copyright (c) 1993-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
from transformers import DynamicCache
from kvpress import KnormPress
from tests.fixtures import unit_test_model # noqa: F401
def test_context_manager_adds_and_removes_hook(unit_test_model): # noqa: F811
press = KnormPress(compression_ratio=0.2)
with press(unit_test_model):
for layer in unit_test_model.model.layers:
assert len(layer.self_attn._forward_hooks) == 1
for layer in unit_test_model.model.layers:
assert len(layer.self_attn._forward_hooks) == 0
def test_context_manager_applies_compression(unit_test_model): # noqa: F811
press = KnormPress(compression_ratio=0.2)
with press(unit_test_model):
input_ids = unit_test_model.dummy_inputs["input_ids"].to(unit_test_model.device)
past_key_values = unit_test_model(input_ids, past_key_values=DynamicCache()).past_key_values
seq_len = input_ids.shape[-1]
for layer in past_key_values.layers:
assert layer.keys.shape[2] == int(seq_len * 0.8) == past_key_values.get_seq_length()
assert layer.values.shape[2] == int(seq_len * 0.8) == past_key_values.get_seq_length()
input_ids = unit_test_model.dummy_inputs["input_ids"].to(unit_test_model.device)
past_key_values = unit_test_model(input_ids, past_key_values=DynamicCache()).past_key_values
for layer in past_key_values.layers:
assert layer.keys.shape[2] == seq_len == past_key_values.get_seq_length()
assert layer.values.shape[2] == seq_len == past_key_values.get_seq_length()
# Model unique identifier
modelCode=2171
# Model name
modelName=kvpress-SnapKV-Qwen3-8B_pytorch
# Model description
modelDescription=SnapKV prunes the KV cache using top-k scores obtained by pooling the dot products between recent queries and the keys.
# Run type
processType=Inference
# Algorithm category
appCategory=Dialogue / question answering
# Framework type
frameType=pytorch
# Accelerator type
accelerateType=K100AI
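A minimal sketch of the SnapKV scoring idea described above (hypothetical function and default values, not the kvpress implementation): the most recent window_size queries attend to all keys, the attention each key receives is averaged over that window, and the scores are smoothed with a 1D pooling kernel before keeping the top-k keys per head. Causal masking inside the window is omitted for brevity.
import torch
import torch.nn.functional as F
def snapkv_scores_sketch(queries, keys, window_size=64, kernel_size=5):
    # queries, keys: (batch, num_heads, seq_len, head_dim)
    head_dim = queries.shape[-1]
    q_obs = queries[:, :, -window_size:]  # observation window: the most recent queries
    attn = q_obs @ keys.transpose(-1, -2) / head_dim**0.5  # (batch, heads, window, seq_len)
    attn = attn.softmax(dim=-1).mean(dim=2)  # average attention each key receives from the window
    # smooth along the sequence so neighbours of important keys are also retained
    scores = F.avg_pool1d(attn, kernel_size=kernel_size, stride=1, padding=kernel_size // 2)
    return scores  # (batch, heads, seq_len); keep the top-k positions per head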
from transformers import pipeline, AutoTokenizer
from kvpress import SnapKVPress
model = "Qwen/Qwen3-8B"
tokenizer = AutoTokenizer.from_pretrained(model)
model_kwargs = {"attn_implementation": "flash_attention_2"}
# model_kwargs = {"attn_implementation": "eager"}
pipe = pipeline("kv-press-text-generation", model=model, model_kwargs=model_kwargs)
context = "You are Qwen, created by Alibaba Cloud. You are a helpful assistant."
question = "How large is the area of the United States?"
# Derive the SnapKV observation window from the question length (at most 64 recent tokens)
q_len = int(tokenizer(question, return_tensors="pt").input_ids.shape[1] // 2)
q_len = int(0.1 * q_len) if q_len > 1024 else q_len
window_size = min(64, q_len) if q_len > 1 else 1
# print(f"using window_size: {window_size}")
press = SnapKVPress(compression_ratio=0.5, window_size=window_size)
answer = pipe(context, question=question, press=press, max_new_tokens=512)["answer"]
print("answer: ", answer)