test_decoding_compression.py 13.1 KB
Newer Older
chenzk's avatar
v1.0  
chenzk committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
# SPDX-FileCopyrightText: Copyright (c) 1993-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0

"""
Test script to verify that DecodingPress actually compresses during decoding.
"""
import logging

import pytest
import torch
from transformers import DynamicCache, pipeline

from kvpress import (
    CompactorPress,
    DecodingPress,
    KnormPress,
    KVzapPress,
    LeverageScorePress,
    NonCausalAttnPress,
    PrefillDecodingPress,
    PyramidKVPress,
    ScorerPress,
)
from tests.default_presses import default_presses

logger = logging.getLogger(__name__)


@pytest.mark.parametrize("token_buffer_size", [32, 64, 128])
def test_decoding_compression(token_buffer_size):
    """Test that DecodingPress compresses the cache during decoding."""

    # Tiny unit-test model keeps the pipeline run fast.
    pipe = pipeline("kv-press-text-generation", model="MaxJeblick/llama2-0b-unit-test", device_map="auto")

    # KnormPress drops 50% of tokens; compression fires every 4 generated tokens.
    press = DecodingPress(
        base_press=KnormPress(compression_ratio=0.5),
        compression_interval=4,
        target_size=token_buffer_size,
    )

    cache = DynamicCache()

    # Repeated sentence gives a context long enough to trigger compression.
    context = "The quick brown fox jumps over the lazy dog. " * 10
    question = "What animal jumps over the dog?"

    pipe(context, question=question, press=press, cache=cache, max_new_tokens=20)

    # Compression only runs every `compression_interval` steps, so the cache
    # may temporarily exceed the target by up to (interval - 1) tokens.
    max_expected_size = token_buffer_size + press.compression_interval - 1
    for layer_idx, cache_layer in enumerate(cache.layers):
        layer_seq_len = cache_layer.keys.shape[2]
        assert layer_seq_len <= max_expected_size, (
            f"Layer {layer_idx}: Expected cache sequence length to be between {token_buffer_size} "
            f"and {max_expected_size}, but got {layer_seq_len}"
        )


def test_prefill_decoding_press_calls_both_phases():
    """Test that PrefillDecodingPress calls both prefilling and decoding presses."""

    pipe = pipeline("kv-press-text-generation", model="MaxJeblick/llama2-0b-unit-test", device_map="auto")

    # Prefill phase compresses to 60%; decoding phase keeps the cache near 48 tokens.
    combined_press = PrefillDecodingPress(
        prefilling_press=KnormPress(compression_ratio=0.6),
        decoding_press=DecodingPress(base_press=KnormPress(), compression_interval=3, target_size=48),
    )

    # Longer context so both phases have work to do.
    context = "The quick brown fox jumps over the lazy dog. " * 12
    question = "What animal jumps over the dog?"

    cache = DynamicCache()
    pipe(context, question=question, press=combined_press, cache=cache, max_new_tokens=15)

    # The final cache should sit at the decoding target, plus at most
    # (compression_interval - 1) tokens accumulated between compressions.
    target_size = 48  # token_buffer_size from decoding press
    compression_steps = 3  # from the decoding press configuration
    max_expected_size = target_size + compression_steps - 1
    for layer_idx, cache_layer in enumerate(cache.layers):
        layer_seq_len = cache_layer.keys.shape[2]
        assert target_size <= layer_seq_len <= max_expected_size, (
            f"Layer {layer_idx}: Expected final cache size to be between {target_size} "
            f"and {max_expected_size} (decoding target), but got {layer_seq_len}"
        )


def test_decoding_press_without_prefill():
    """Test that DecodingPress works correctly when used standalone (no prefill compression)."""

    pipe = pipeline("kv-press-text-generation", model="MaxJeblick/llama2-0b-unit-test", device_map="auto")

    # Standalone decoding-time compression only: no prefill press involved.
    decoding_press = DecodingPress(base_press=KnormPress(compression_ratio=0.4), compression_interval=5, target_size=64)

    context = "The quick brown fox jumps over the lazy dog. " * 8
    question = "What animal jumps over the dog?"

    cache = DynamicCache()
    pipe(context, question=question, press=decoding_press, cache=cache, max_new_tokens=25)

    # Cache may overshoot the target by up to (compression_interval - 1) tokens
    # between compression steps.
    target_size = 64
    compression_steps = 5  # from the decoding press configuration
    max_expected_size = target_size + compression_steps - 1
    for layer_idx, cache_layer in enumerate(cache.layers):
        layer_seq_len = cache_layer.keys.shape[2]
        assert target_size <= layer_seq_len <= max_expected_size, (
            f"Layer {layer_idx}: Expected cache size to be between {target_size} "
            f"and {max_expected_size}, but got {layer_seq_len}"
        )


def test_prefill_decoding_press_decoding_only():
    """Test PrefillDecodingPress with only decoding press (no prefill compression)."""

    pipe = pipeline("kv-press-text-generation", model="MaxJeblick/llama2-0b-unit-test", device_map="auto")

    # prefilling_press=None exercises the decoding-only path of the combined press.
    combined_press = PrefillDecodingPress(
        prefilling_press=None,
        decoding_press=DecodingPress(
            base_press=KnormPress(compression_ratio=0.6), compression_interval=4, target_size=56
        ),
    )

    context = "The quick brown fox jumps over the lazy dog. " * 9
    question = "What animal jumps over the dog?"

    cache = DynamicCache()
    pipe(context, question=question, press=combined_press, cache=cache, max_new_tokens=12)

    # With no prefill press, only decoding compression bounds the cache size:
    # target plus up to (compression_interval - 1) uncompressed tokens.
    target_size = 56
    compression_steps = 4  # from the decoding press configuration
    max_expected_size = target_size + compression_steps - 1
    for layer_idx, cache_layer in enumerate(cache.layers):
        layer_seq_len = cache_layer.keys.shape[2]
        assert target_size <= layer_seq_len <= max_expected_size, (
            f"Layer {layer_idx}: Expected cache size to be between {target_size} "
            f"and {max_expected_size}, but got {layer_seq_len}"
        )


def test_decoding_press_equivalence():
    """Test that DecodingPress standalone yields same result as PrefillDecodingPress with decoding only."""

    # Fixed seed so both runs generate deterministically.
    torch.manual_seed(42)

    pipe = pipeline("kv-press-text-generation", model="MaxJeblick/llama2-0b-unit-test", device_map="auto")

    def build_decoding_press():
        # Single source of truth for the configuration under comparison.
        return DecodingPress(base_press=KnormPress(compression_ratio=0.5), compression_interval=3, target_size=52)

    decoding_press = build_decoding_press()
    combined_press = PrefillDecodingPress(prefilling_press=None, decoding_press=build_decoding_press())

    context = "The quick brown fox jumps over the lazy dog. " * 7
    question = "What animal jumps over the dog?"

    def run_with(press):
        # Fresh cache per run so the two variants don't interfere.
        cache = DynamicCache()
        result = pipe(context, question=question, press=press, cache=cache, max_new_tokens=10)
        return cache, result

    cache1, result1 = run_with(decoding_press)
    cache2, result2 = run_with(combined_press)

    # Per-layer cache sizes must match exactly between the two variants.
    for layer_idx in range(len(cache1.layers)):
        cache1_size = cache1.layers[layer_idx].keys.shape[2]
        cache2_size = cache2.layers[layer_idx].keys.shape[2]
        assert cache1_size == cache2_size, (
            f"Layer {layer_idx}: Standalone decoding cache size {cache1_size} != "
            f"combined press cache size {cache2_size}"
        )

    # The generated text must also be identical.
    assert result1["answer"] == result2["answer"], (
        f"Generated answers differ:\n"
        f"Standalone decoding: '{result1['answer']}'\n"
        f"Combined press: '{result2['answer']}'"
    )


"""
E       AttributeError: 'QFilterPress' object has no attribute 'q_filters'
E           Failed: DecodingPress failed with SnapKVPress: shape '[1, 2, 2, 6]' is invalid for input of size 12
>       query_states = query_states.view(bsz, window_size, num_heads, head_dim).transpose(1, 2)
E       RuntimeError: shape '[1, 2, 2, 6]' is invalid for input of size 12
"""


@pytest.mark.parametrize("press_config", default_presses)
def test_all_presses_work_with_decoding_press(press_config):
    """Test that all default presses work as base presses for DecodingPress.

    Presses that are incompatible with decoding-time compression are skipped
    via ``pytest.skip`` so they show up as skipped (not silently passed) in
    the test report.
    """

    # Initialize pipeline
    pipe = pipeline("kv-press-text-generation", model="MaxJeblick/llama2-0b-unit-test", device_map="auto")

    # Get press class and use the first (easier) configuration
    press_cls = press_config["cls"]
    press_kwargs = press_config["kwargs"][0]  # Use easier compression settings

    base_press = press_cls(**press_kwargs)
    # Bug fix: previously these early exits used bare `return`, which made
    # incompatible presses report as *passed*. pytest.skip reports them honestly.
    if not isinstance(base_press, ScorerPress):
        pytest.skip(f"Press {press_cls.__name__} is not a ScorerPress")
    if isinstance(base_press, PyramidKVPress):
        # PyramidKVPress -> Pyramid shape, not compatible with token_buffer_size=48
        pytest.skip(f"Press {press_cls.__name__} is not supported")
    if isinstance(base_press, (CompactorPress, NonCausalAttnPress, LeverageScorePress)):
        # CompactorPress -> Meant for prefill scenario.
        pytest.skip(f"Press {press_cls.__name__} is not supported")
    if isinstance(base_press, KVzapPress):
        pytest.skip(f"Press {press_cls.__name__} is not compatible with DecodingPress")

    # Create DecodingPress with this base press
    decoding_press = DecodingPress(base_press=base_press, compression_interval=3, target_size=48)

    # Test context and question
    context = "The quick brown fox jumps over the lazy dog. " * 8
    question = "What animal jumps over the dog?"

    # Run pipeline
    cache = DynamicCache()
    result = pipe(context, question=question, press=decoding_press, cache=cache, max_new_tokens=15)

    # Verify compression worked
    assert len(result["answer"]) > 0, f"No answer generated with {press_cls.__name__}"

    # Check that cache was compressed (allow some tolerance for rounding)
    target_size = 48
    compression_steps = 3  # from the decoding press configuration
    max_expected_size = target_size + compression_steps - 1
    for layer_idx, cache_layer in enumerate(cache.layers):
        layer_seq_len = cache_layer.keys.shape[2]
        # Allow for compression step interval: cache can be up to compression_steps-1 tokens larger
        assert (
            target_size <= layer_seq_len <= max_expected_size
        ), f"{press_cls.__name__}: Layer {layer_idx} cache size {layer_seq_len} not in expected range [{target_size}-{max_expected_size}]"  # noqa: E501


def test_compression_actually_reduces_memory():
    """Test that compression actually reduces memory usage compared to no compression."""

    pipe = pipeline("kv-press-text-generation", model="MaxJeblick/llama2-0b-unit-test", device_map="auto")

    # Long context so the uncompressed cache is substantially larger.
    context = "The quick brown fox jumps over the lazy dog. " * 15
    question = "What animal jumps over the dog?"

    def cache_bytes(cache):
        # Approximate footprint: (#key elements + #value elements) * bytes per element.
        return sum(
            (layer.values.numel() + layer.keys.numel()) * layer.keys.element_size()
            for layer in cache.layers
        )

    # Baseline: no press at all.
    cache_uncompressed = DynamicCache()
    result_uncompressed = pipe(context, question=question, cache=cache_uncompressed, max_new_tokens=25)

    # Aggressive decoding-time compression.
    press = DecodingPress(
        base_press=KnormPress(compression_ratio=0.3),
        compression_interval=3,
        target_size=40,
    )
    cache_compressed = DynamicCache()
    result_compressed = pipe(context, question=question, press=press, cache=cache_compressed, max_new_tokens=25)

    uncompressed_memory = cache_bytes(cache_uncompressed)
    compressed_memory = cache_bytes(cache_compressed)

    # Compression should cut the footprint to well under 60% of baseline.
    compression_ratio = compressed_memory / uncompressed_memory
    assert compression_ratio < 0.6, (
        f"Expected compression ratio < 0.6, but got {compression_ratio:.3f} "
        f"(compressed: {compressed_memory} bytes, uncompressed: {uncompressed_memory} bytes)"
    )

    # Both variants must still produce a non-empty answer.
    assert len(result_uncompressed["answer"]) > 0
    assert len(result_compressed["answer"]) > 0