# SPDX-FileCopyrightText: Copyright (c) 1993-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0


import logging

import pytest
import torch
from transformers import AutoTokenizer, DynamicCache, QuantizedCache
from transformers.utils import is_flash_attn_2_available, is_optimum_quanto_available

from kvpress import ExpectedAttentionPress
from kvpress.pipeline import KVPressTextGenerationPipeline
from tests.fixtures import danube_500m_model  # noqa: F401
from tests.fixtures import kv_press_danube_pipeline  # noqa: F401
from tests.fixtures import unit_test_model  # noqa: F401
from tests.fixtures import kv_press_llama3_2_flash_attn_pipeline, kv_press_unit_test_pipeline  # noqa: F401


def test_pipeline(kv_press_unit_test_pipeline, caplog):  # noqa: F811
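    """Compress a short context with ExpectedAttentionPress, check that a single string answer is returned,
    and that the debug logs report the context length before and after compression."""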
    with caplog.at_level(logging.DEBUG):
        context = "This is a test article. It was written on 2022-01-01."
        questions = ["When was this article written?"]
        press = ExpectedAttentionPress(compression_ratio=0.4)
        answers = kv_press_unit_test_pipeline(context, questions=questions, press=press)["answers"]

    assert len(answers) == 1
    assert isinstance(answers[0], str)

    messages = [record.message for record in caplog.records]
    assert "Context Length: 23" in messages, messages
    assert "Compressed Context Length: 13" in messages, messages


def test_pipeline_with_cache(kv_press_unit_test_pipeline):  # noqa: F811
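    """Run the pipeline with an explicitly provided DynamicCache and check that a string answer is returned."""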
    context = "This is a test article. It was written on 2022-01-01."
    questions = ["When was this article written?"]
    press = ExpectedAttentionPress(compression_ratio=0.4)
    cache = DynamicCache()
    answers = kv_press_unit_test_pipeline(context, questions=questions, press=press, cache=cache)["answers"]

    assert len(answers) == 1
    assert isinstance(answers[0], str)


class TestPipelineFA2:
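    """Tests comparing Flash Attention 2 and SDPA outputs, run only when a GPU and flash_attn are available."""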
    @pytest.mark.skipif(not torch.cuda.is_available(), reason="GPU is not available")
    @pytest.mark.skipif(not is_flash_attn_2_available(), reason="flash_attn is not installed")
    @pytest.mark.parametrize("compression_ratio", [0.0, 0.2])
    @pytest.mark.xfail(reason="Known issue (https://github.com/huggingface/transformers/issues/42550)", strict=False)
    def test_pipeline_fa2(self, kv_press_llama3_2_flash_attn_pipeline, compression_ratio):  # noqa: F811
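        """Generate the same answer with Flash Attention 2 and with SDPA, then check that both outputs match
        and contain the expected sentence from the context."""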
        context = "This is a test article. It was written on 2022-01-01."
        questions = ["Repeat the last sentence"]
        press = ExpectedAttentionPress(compression_ratio=compression_ratio)
        cache = DynamicCache()
        answers = kv_press_llama3_2_flash_attn_pipeline(
            context, questions=questions, press=press, cache=cache, max_new_tokens=6
        )["answers"]

        assert len(answers) == 1
        assert isinstance(answers[0], str)

        kv_press_llama3_2_flash_attn_pipeline.model.set_attn_implementation("sdpa")
        press = ExpectedAttentionPress(compression_ratio=compression_ratio)
        cache = DynamicCache()
        answers_sdpa = kv_press_llama3_2_flash_attn_pipeline(
            context, questions=questions, press=press, cache=cache, max_new_tokens=6
        )["answers"]
        kv_press_llama3_2_flash_attn_pipeline.model.set_attn_implementation("flash_attention_2")

        assert (
            answers_sdpa[0] == answers[0]
        ), f"Answers from SDPA and Flash Attention 2 should be the same. \n{answers_sdpa[0]}\n{answers[0]}"
        assert "This is a test" in answers[0], f"The answer should contain the context sentence, but got {answers[0]}."


@pytest.mark.parametrize("question", ["When was this article written?", ""])
def test_pipeline_single_or_no_question(kv_press_unit_test_pipeline, question, caplog):  # noqa: F811
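    """Exercise the single-question API (`question=`) with and without a question, asserting a string answer
    and the expected compression logs."""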
    with caplog.at_level(logging.DEBUG):
        context = "This is a test article. It was written on 2022-01-01."
        press = ExpectedAttentionPress(compression_ratio=0.4)
        answer = kv_press_unit_test_pipeline(context, question=question, press=press)["answer"]

    assert isinstance(answer, str)

    messages = [record.message for record in caplog.records]
    assert "Context Length: 23" in messages, messages
    assert "Compressed Context Length: 13" in messages, messages


def test_pipeline_no_press_works(kv_press_unit_test_pipeline, caplog):  # noqa: F811
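    """Check that the pipeline runs without a press (i.e. without compression) and raises no error."""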
    context = "This is a test article. It was written on 2022-01-01."
    question = "When was this article written?"
    kv_press_unit_test_pipeline(context, question=question)


def test_pipeline_answer_is_correct(danube_500m_model, caplog):  # noqa: F811
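    """Check that the danube 500m model answers the date question exactly and that the logs report
    the expected compressed context length."""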
    with caplog.at_level(logging.DEBUG):
        answers = generate_answer(danube_500m_model)

    for answer in answers:
        assert answer == "This article was written on January 1, 2022."

    messages = [record.message for record in caplog.records]
    assert "Context Length: 28" in messages
    assert "Compressed Context Length: 16" in messages


@pytest.mark.skipif(not is_optimum_quanto_available(), reason="Optimum Quanto is not available")
def test_pipeline_with_quantized_cache(kv_press_danube_pipeline, caplog):  # noqa: F811
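    """Run the danube pipeline with a 4-bit quanto QuantizedCache and check that the exact expected answer
    and the expected compression logs are produced."""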
    with caplog.at_level(logging.DEBUG):
        context = "This is a test article. It was written on 2022-01-01."
        questions = ["When was this article written?"]
        press = ExpectedAttentionPress(compression_ratio=0.4)
        cache = QuantizedCache(backend="quanto", config=kv_press_danube_pipeline.model.config, nbits=4)
        answers = kv_press_danube_pipeline(context, questions=questions, press=press, cache=cache)["answers"]

    assert len(answers) == 1
    assert isinstance(answers[0], str)

    for answer in answers:
        assert answer == "This article was written on January 1, 2022."

    messages = [record.message for record in caplog.records]
    assert "Context Length: 28" in messages
    assert "Compressed Context Length: 16" in messages


def test_pipeline_compresses_context(unit_test_model, caplog):  # noqa: F811
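    """Check that the pipeline answers both questions and that the logs show the context was compressed."""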
    with caplog.at_level(logging.DEBUG):
        answers = generate_answer(unit_test_model)

    assert len(answers) == 2
    assert isinstance(answers[0], str)

    messages = [record.message for record in caplog.records]
    assert "Context Length: 23" in messages, messages
    assert "Compressed Context Length: 13" in messages, messages


@torch.no_grad()
def test_pipeline_context_cache_is_invariant(unit_test_model):  # noqa: F811
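    """Check that generate_answer followed by _remove_answer_from_cache leaves the context keys and values
    in the cache unchanged."""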
    model = unit_test_model
    questions = ["When was this article written?"]
    tokenizer = AutoTokenizer.from_pretrained(model.config.name_or_path)
    device = model.device

    compression_pipeline = KVPressTextGenerationPipeline(model=model, tokenizer=tokenizer, device=device)
    input_ids_question = tokenizer(questions[0], return_tensors="pt", add_special_tokens=False)["input_ids"].to(device)

    seq_len = 256
    past_key_values: DynamicCache = model(
        input_ids=torch.randint(0, 1000, (1, seq_len), device=device), past_key_values=DynamicCache()
    ).past_key_values
    assert past_key_values.get_seq_length() == seq_len

    keys = [layer.keys.clone() for layer in past_key_values.layers]
    values = [layer.values.clone() for layer in past_key_values.layers]
    cache_seq_lengths = [past_key_values.get_seq_length(layer_idx) for layer_idx in range(len(past_key_values))]
    compression_pipeline.generate_answer(input_ids_question, past_key_values, context_length=22, max_new_tokens=10)
    compression_pipeline._remove_answer_from_cache(past_key_values, cache_seq_lengths)
    assert past_key_values.get_seq_length() == seq_len
    assert all([torch.allclose(key, layer.keys) for key, layer in zip(keys, past_key_values.layers)])
    assert all([torch.allclose(value, layer.values) for value, layer in zip(values, past_key_values.layers)])


def generate_answer(model):
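    """Helper: build a KVPressTextGenerationPipeline for `model` and answer the same question twice
    with ExpectedAttentionPress."""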
    device = model.device
    context = "This is a test article. It was written on 2022-01-01."
    questions = ["When was this article written?", "When was this article written?"]
    press = ExpectedAttentionPress(compression_ratio=0.4)
    tokenizer = AutoTokenizer.from_pretrained(model.config.name_or_path)
    answers = KVPressTextGenerationPipeline(model=model, tokenizer=tokenizer, device=device)(
        context, questions=questions, press=press
    )["answers"]
    return answers