test_block_press.py 2.49 KB
Newer Older
chenzk's avatar
v1.0  
chenzk committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
# SPDX-FileCopyrightText: Copyright (c) 1993-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0

from dataclasses import dataclass

import torch
import torch.nn as nn
from transformers import DynamicCache

from kvpress.presses.block_press import BlockPress
from kvpress.presses.scorer_press import ScorerPress
from tests.fixtures import unit_test_model  # noqa: F401


@dataclass
class HiddenStatesPress(ScorerPress):
    """Dummy ScorerPress whose token scores come only from the hidden states.

    Every token is scored by the mean of its own hidden-state vector. No other
    token and no key/value content enters that score. The score is then
    broadcast across the head axis so its shape matches ``keys`` without the
    last dimension.
    """

    def score(
        self,
        module: nn.Module,
        hidden_states: torch.Tensor,
        keys: torch.Tensor,
        values: torch.Tensor,
        attentions: torch.Tensor,
        kwargs,
    ) -> torch.Tensor:
        """Return per-token scores broadcast to ``keys.shape[:-1]``.

        NOTE(review): the broadcast implies hidden_states is (batch, seq, hidden)
        and keys is (batch, heads, seq, head_dim) — confirm against ScorerPress.
        """
        per_token = hidden_states.mean(dim=-1)
        # Insert a singleton head axis at position 1, then repeat it for every head.
        return per_token[:, None].expand(keys.shape[:-1])


def _compressed_cache_checksums(model, press, input_ids, expected_seq_len):
    """Prefill ``model`` on ``input_ids`` under ``press`` and return order-invariant cache checksums.

    Asserts that the compressed cache holds ``expected_seq_len`` tokens, then sums all keys
    and all values across every layer. A plain sum does not depend on the order in which
    entries were kept, apart from floating-point rounding. Two runs that keep the same KV
    entries in a different order therefore produce matching checksums.

    Returns:
        ``(keys_sum, values_sum)`` as Python floats.
    """
    with press(model):
        cache = DynamicCache()
        model(input_ids, past_key_values=cache)
        cached_len = cache.get_seq_length()
        assert cached_len == expected_seq_len, f"expected {expected_seq_len} cached tokens, got {cached_len}"
        keys = torch.cat([layer.keys for layer in cache.layers])
        values = torch.cat([layer.values for layer in cache.layers])
        # Sum in float32 so a half-precision cache does not add rounding noise or overflow.
        return keys.float().sum().item(), values.float().sum().item()


def test_block_press_is_streaming_top_k(unit_test_model):  # noqa: F811
    """
    Test that BlockPress keeps the same KV entries as its wrapped press applied in one pass.

    HiddenStatesPress scores each token from its own hidden state alone. Compressing the
    prompt block by block should therefore keep the same entries as compressing the whole
    prompt at once, for every tested block size. This includes block sizes that divide the
    prompt and one equal to its full length. Every run must also keep
    ``(1 - compression_ratio)`` of the prompt tokens.
    """
    press = HiddenStatesPress(compression_ratio=0.5)
    seq_len = 256
    # 256 * (1 - 0.5) = 128 tokens are expected to survive compression.
    expected_seq_len = int(seq_len * (1 - press.compression_ratio))
    generator = torch.Generator().manual_seed(0)
    input_ids = torch.randint(0, 1024, (1, seq_len), generator=generator).to(unit_test_model.device)

    # Reference: the wrapped press applied to the whole prompt at once.
    reference = torch.tensor(_compressed_cache_checksums(unit_test_model, press, input_ids, expected_seq_len))

    for block_size in [2, 4, 8, 128, 256]:
        composed_press = BlockPress(press=press, block_size=block_size)
        checksums = torch.tensor(
            _compressed_cache_checksums(unit_test_model, composed_press, input_ids, expected_seq_len)
        )
        assert torch.allclose(checksums, reference), (
            f"block_size={block_size}: (keys, values) checksums {checksums.tolist()} "
            f"differ from full-sequence reference {reference.tolist()}"
        )