test_head_compression.py 4.31 KB
Newer Older
chenzk's avatar
v1.0  
chenzk committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
# SPDX-FileCopyrightText: Copyright (c) 1993-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
import pytest
import torch
from transformers import DynamicCache

from kvpress import AdaKVPress, CriticalAdaKVPress, DMSPress, KnormPress, KVzipPress, RandomPress
from tests.fixtures import kv_press_unit_test_pipeline, unit_test_model  # noqa: F401


def compute_masked_percentage(module, batch_size, num_key_value_heads, seq_len):
    """
    Compute the percentage of masked indices from module.masked_key_indices.
    """
    if module.masked_key_indices is None:
        return 0.0

    batch_indices, head_indices, seq_indices = module.masked_key_indices
    num_masked = len(batch_indices)
    total_positions = batch_size * num_key_value_heads * seq_len
    masked_percentage = num_masked / total_positions
    return masked_percentage


@pytest.mark.parametrize("wrapper_press", [AdaKVPress, CriticalAdaKVPress])
@pytest.mark.parametrize("compression_ratio", [0.2, 0.4, 0.6, 0.8])
def test_wrapper_head_compression(unit_test_model, wrapper_press, compression_ratio):  # noqa: F811
    p = KnormPress(compression_ratio=compression_ratio)
    press = wrapper_press(press=p)
    with press(unit_test_model):
        input_ids = torch.randint(0, 1024, (1, 128)).to(unit_test_model.device)
        unit_test_model(input_ids, past_key_values=DynamicCache()).past_key_values

    assert unit_test_model.model.layers[0].self_attn.masked_key_indices is not None
    headwise_compression_ratio = 0.0
    for layer in unit_test_model.model.layers:
        cr = compute_masked_percentage(layer.self_attn, 1, unit_test_model.config.num_key_value_heads, 128)
        headwise_compression_ratio += cr
    cumulative_compression_ratio = headwise_compression_ratio / len(unit_test_model.model.layers)
    assert abs(cumulative_compression_ratio - press.compression_ratio) < 1e-2  # tolerate small differences


# Only for KVzipPress, since it's the only non-wrapper press with head compression (apart from Duo)
@pytest.mark.parametrize("press", [KVzipPress])
@pytest.mark.parametrize("compression_ratio", [0.2, 0.4, 0.6, 0.8])
@pytest.mark.parametrize("layerwise", [True, False])
def test_head_compression(unit_test_model, press, compression_ratio, layerwise):  # noqa: F811
    press = KVzipPress(compression_ratio=compression_ratio, layerwise=layerwise)
    with press(unit_test_model):
        input_ids = torch.randint(0, 1024, (1, 128)).to(unit_test_model.device)
        unit_test_model(input_ids, past_key_values=DynamicCache()).past_key_values

    assert unit_test_model.model.layers[0].self_attn.masked_key_indices is not None
    headwise_compression_ratio = 0.0
    for layer in unit_test_model.model.layers:
        cr = compute_masked_percentage(layer.self_attn, 1, unit_test_model.config.num_key_value_heads, 128)
        headwise_compression_ratio += cr
    cumulative_compression_ratio = headwise_compression_ratio / len(unit_test_model.model.layers)
    assert abs(cumulative_compression_ratio - press.compression_ratio) < 1e-2  # tolerate small differences


def test_dms_press_compression_ratio(kv_press_unit_test_pipeline):  # noqa: F811
    """Test that DMSPress.compression_ratio matches the actual masked percentage."""
    press = DMSPress(
        press=RandomPress(),
        threshold=0.5,
        sliding_window_size=0,
        decoding=True,
    )

    prompt = "What is the best KV cache compression library in the world ?"
    max_new_tokens = 10
    kv_press_unit_test_pipeline(prompt, press=press, max_new_tokens=max_new_tokens)

    model = kv_press_unit_test_pipeline.model
    num_key_value_heads = model.config.num_key_value_heads

    # Compute seq_len by reusing the pipeline's preprocess method
    preprocessed = kv_press_unit_test_pipeline.preprocess(prompt, [""], answer_prefix="", max_context_length=10000)
    seq_len = preprocessed["context_ids"].shape[1] + preprocessed["questions_ids"][0].shape[1] + max_new_tokens - 1

    # Compute compression ratio from masked indices
    headwise_compression_ratio = 0.0
    for layer in model.model.layers:
        cr = compute_masked_percentage(layer.self_attn, 1, num_key_value_heads, seq_len)
        headwise_compression_ratio += cr
    cumulative_compression_ratio = headwise_compression_ratio / len(model.model.layers)

    assert cumulative_compression_ratio == press.compression_ratio