test_press_call.py 1.6 KB
Newer Older
chenzk's avatar
v1.0  
chenzk committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
# SPDX-FileCopyrightText: Copyright (c) 1993-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0


from transformers import DynamicCache

from kvpress import KnormPress
from tests.fixtures import unit_test_model  # noqa: F401


def test_context_manager_adds_and_removes_hook(unit_test_model):  # noqa: F811
    """Check that the press context adds one forward hook per attention module and removes it on exit.

    Args:
        unit_test_model: Model fixture imported from ``tests.fixtures``.
    """
    press = KnormPress(compression_ratio=0.2)

    with press(unit_test_model):
        for layer in unit_test_model.model.layers:
            assert len(layer.self_attn._forward_hooks) == 1

    # Inspect the same module that was counted inside the context: looking at the
    # parent layer's hooks would pass even if the attention hook were never removed.
    for layer in unit_test_model.model.layers:
        assert len(layer.self_attn._forward_hooks) == 0


def test_context_manager_applies_compression(unit_test_model):  # noqa: F811
    """Check that the cache is shortened inside the press context and full-length outside it.

    Args:
        unit_test_model: Model fixture imported from ``tests.fixtures``.
    """
    press = KnormPress(compression_ratio=0.2)

    # Forward pass while the press is active.
    with press(unit_test_model):
        input_ids = unit_test_model.dummy_inputs["input_ids"].to(unit_test_model.device)
        pressed_output = unit_test_model(input_ids, past_key_values=DynamicCache())
        pressed_cache = pressed_output.past_key_values

    seq_len = input_ids.shape[-1]
    # Expected length after pressing; 0.8 presumably mirrors 1 - compression_ratio.
    kept_len = int(seq_len * 0.8)

    for cache_layer in pressed_cache.layers:
        assert cache_layer.keys.shape[2] == kept_len == pressed_cache.get_seq_length()
        assert cache_layer.values.shape[2] == kept_len == pressed_cache.get_seq_length()

    # Same forward pass outside the context: every position must still be cached.
    input_ids = unit_test_model.dummy_inputs["input_ids"].to(unit_test_model.device)
    plain_output = unit_test_model(input_ids, past_key_values=DynamicCache())
    plain_cache = plain_output.past_key_values

    for cache_layer in plain_cache.layers:
        assert cache_layer.keys.shape[2] == seq_len == plain_cache.get_seq_length()
        assert cache_layer.values.shape[2] == seq_len == plain_cache.get_seq_length()