test_logprobs.py 11.3 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
import math
from itertools import cycle

import pytest

from vllm import SamplingParams

from .conftest import get_logprobs_from_llm_generator


@pytest.mark.parametrize(
    "common_llm_kwargs",
    [
        {
            "model": "JackFram/llama-68m",

            # Eager mode skips CUDA graph recording, which keeps the test fast.
            "enforce_eager": True,

            # Spec decode requires the v2 block manager.
            "use_v2_block_manager": True,
            "max_logprobs": 6,
        },
    ])
@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}])
@pytest.mark.parametrize("baseline_llm_kwargs", [{}])
@pytest.mark.parametrize(
    "test_llm_kwargs",
    [
        {
            "speculative_model": "JackFram/llama-160m",
            "num_speculative_tokens": 3,
        },
    ])
@pytest.mark.parametrize("batch_size", [8])
@pytest.mark.parametrize(
    "output_len",
    [
        # A short output length keeps the test fast.
        7,
    ])
@pytest.mark.parametrize("seed", [1])
def test_logprobs_equality(baseline_llm_generator, test_llm_generator,
                           batch_size: int, output_len: int):
    """Check that greedy logprobs match with and without spec decode.

    The baseline and the speculative generators are handed to
    ``run_greedy_logprobs_correctness_test``, which performs the actual
    comparison using its default ``logprob_rank``.
    """
    run_greedy_logprobs_correctness_test(
        baseline_llm_generator=baseline_llm_generator,
        test_llm_generator=test_llm_generator,
        batch_size=batch_size,
        max_output_len=output_len,
        force_output_len=True,
    )


@pytest.mark.parametrize(
    "common_llm_kwargs",
    [
        {
            "model": "JackFram/llama-68m",

            # Eager mode skips CUDA graph recording, which keeps the test fast.
            "enforce_eager": True,

            # Spec decode requires the v2 block manager.
            "use_v2_block_manager": True,
            "max_logprobs": 6,
        },
    ])
@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}])
@pytest.mark.parametrize("baseline_llm_kwargs", [{}])
@pytest.mark.parametrize(
    "test_llm_kwargs",
    [
        {
            "speculative_model": "JackFram/llama-160m",
            "num_speculative_tokens": 3,
        },
    ])
@pytest.mark.parametrize("batch_size", [1])
@pytest.mark.parametrize("num_logprobs", [6])
@pytest.mark.parametrize(
    "output_len",
    [
        # A short output length keeps the test fast.
        7,
    ])
@pytest.mark.parametrize("seed", [1])
def test_diff_num_logprobs(baseline_llm_generator, test_llm_generator,
                           batch_size: int, output_len: int,
                           num_logprobs: int):
    """Check greedy logprob equality when more than one logprob is requested.

    Same comparison as the single-logprob case, except that
    ``num_logprobs`` is forwarded as the ``logprob_rank`` of the request.
    """
    run_greedy_logprobs_correctness_test(
        baseline_llm_generator=baseline_llm_generator,
        test_llm_generator=test_llm_generator,
        batch_size=batch_size,
        max_output_len=output_len,
        force_output_len=True,
        logprob_rank=num_logprobs,
    )


@pytest.mark.parametrize(
    "common_llm_kwargs",
    [
        {
            "model": "JackFram/llama-68m",

            # Eager mode skips CUDA graph recording, which keeps the test fast.
            "enforce_eager": True,

            # Spec decode requires the v2 block manager.
            "use_v2_block_manager": True
        },
    ])
@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}])
@pytest.mark.parametrize("baseline_llm_kwargs", [{}])
@pytest.mark.parametrize(
    "test_llm_kwargs",
    [
        # Two speculation lengths (k) are exercised: 3 and 6.
        {
            "speculative_model": "JackFram/llama-160m",
            "num_speculative_tokens": 3,
        },
        {
            "speculative_model": "JackFram/llama-160m",
            "num_speculative_tokens": 6,
        },
    ])
@pytest.mark.parametrize("batch_size", [8])
@pytest.mark.parametrize(
    "output_len",
    [
        # A short output length keeps the test fast.
        32,
    ])
@pytest.mark.parametrize("seed", [1])
def test_logprobs_different_k(baseline_llm_generator, test_llm_generator,
                              batch_size: int, output_len: int):
    """Verify greedy logprob equality across different speculation lengths.

    Each parametrized ``num_speculative_tokens`` value is compared against the
    baseline through ``run_greedy_logprobs_correctness_test``.
    """
    run_greedy_logprobs_correctness_test(
        baseline_llm_generator=baseline_llm_generator,
        test_llm_generator=test_llm_generator,
        batch_size=batch_size,
        max_output_len=output_len,
        force_output_len=True,
    )


@pytest.mark.parametrize(
    "common_llm_kwargs",
    [
        {
            "model": "JackFram/llama-68m",

            # Eager mode skips CUDA graph recording, which keeps the test fast.
            "enforce_eager": True,

            # Spec decode requires the v2 block manager.
            "use_v2_block_manager": True
        },
    ])
@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}])
@pytest.mark.parametrize("baseline_llm_kwargs", [{}])
@pytest.mark.parametrize(
    "test_llm_kwargs",
    [
        {
            "speculative_model": "JackFram/llama-160m",
            "num_speculative_tokens": 3,

            # The draft model's max model len is capped artificially so that
            # vLLM has to skip speculation once sequences grow beyond 32-k
            # tokens.
            "speculative_max_model_len": 32,
        },
    ])
@pytest.mark.parametrize("batch_size", [8])
@pytest.mark.parametrize(
    "output_len",
    [
        # A short output length keeps the test fast.
        32,
    ])
@pytest.mark.parametrize("seed", [1])
def test_logprobs_when_skip_speculation(baseline_llm_generator,
                                        test_llm_generator, batch_size: int,
                                        output_len: int):
    """Check greedy logprob equality when some sequences skip speculation.

    The comparison itself is delegated to
    ``run_greedy_logprobs_correctness_test``.
    """
    run_greedy_logprobs_correctness_test(
        baseline_llm_generator=baseline_llm_generator,
        test_llm_generator=test_llm_generator,
        batch_size=batch_size,
        max_output_len=output_len,
        force_output_len=True,
    )


@pytest.mark.parametrize(
    "common_llm_kwargs",
    [{
        "model": "JackFram/llama-68m",

        # Skip cuda graph recording for fast test.
        "enforce_eager": True,

        # Required for spec decode.
        "use_v2_block_manager": True
    }])
@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}])
@pytest.mark.parametrize("baseline_llm_kwargs", [{}])
@pytest.mark.parametrize("test_llm_kwargs", [{
    "speculative_model": "JackFram/llama-160m",
    "num_speculative_tokens": 3,
}])
@pytest.mark.parametrize("batch_size", [1])
@pytest.mark.parametrize(
    "output_len",
    [
        # Use smaller output len for fast test.
        32,
    ])
@pytest.mark.parametrize("seed", [1])
def test_logprobs_temp_1(baseline_llm_generator, test_llm_generator,
                         batch_size: int, output_len: int):
    """Verify at least one position returns more than ``logprob_rank``
    logprobs, which tests the case where the sampled token is not in the
    top-k logprobs.

    Sampling uses temperature 1.0 and only the speculative generator is
    queried; ``baseline_llm_generator`` is requested as a fixture but is not
    used in the body.

    Ideally, this test should validate equality with non-spec by getting
    logprobs. This is left as future improvement.
    """
    # NOTE(review): the parametrized ``batch_size`` (1) is overridden here, so
    # eight prompts are always submitted. The override is kept to preserve the
    # existing behaviour -- confirm whether the parametrization should simply
    # be [8] instead.
    batch_size = 8
    logprob_rank = 5

    prompts = [
        "Hello, my name is",
        "The president of the United States is",
        "The capital of France is",
        "The future of AI is",
        "San Francisco is know for its",
        "Facebook was created in 2004 by",
        "Curious George is a",
        "Python 3.11 brings improvements to its",
    ]

    # Repeat the prompt list as often as needed to reach batch_size entries.
    prompts = [prompt for prompt, _ in zip(cycle(prompts), range(batch_size))]

    sampling_params = SamplingParams(
        max_tokens=output_len,
        # The test wants output_len tokens to be generated, so the EOS token
        # is ignored.
        ignore_eos=True,
        temperature=1.0,
        logprobs=logprob_rank,
    )

    spec_batch_logprobs = get_logprobs_from_llm_generator(
        test_llm_generator, prompts, sampling_params)

    # Number of logprob entries returned at every generated position of every
    # sequence in the batch.
    num_returned_logprobs = [
        len(logprob_dict) for seq_logprobs in spec_batch_logprobs
        for logprob_dict in seq_logprobs
    ]

    # Assert one of the returned logprobs has > num_logprobs (indicating the
    # sampled token is not in top-k).
    assert any(num_returned > logprob_rank
               for num_returned in num_returned_logprobs), (
                   f"expected at least one position with more than "
                   f"{logprob_rank} logprobs, got counts "
                   f"{sorted(set(num_returned_logprobs))}")


def run_greedy_logprobs_correctness_test(baseline_llm_generator,
                                         test_llm_generator,
                                         batch_size,
                                         max_output_len,
                                         force_output_len: bool,
                                         logprob_rank: int = 1):
    """Helper method that compares the logprobs outputs of both the baseline LLM
    and the test LLM. It asserts greedy equality of the logprobs when the
    temperature is zero.

    Args:
        baseline_llm_generator: Generator fixture for the baseline LLM; passed
            to ``get_logprobs_from_llm_generator``.
        test_llm_generator: Generator fixture for the LLM under test; passed
            to ``get_logprobs_from_llm_generator``.
        batch_size: Number of prompts to submit. The built-in prompt list is
            cycled until this many prompts are collected.
        max_output_len: Forwarded as ``max_tokens`` of the sampling params.
        force_output_len: Forwarded as ``ignore_eos`` of the sampling params.
        logprob_rank: Forwarded as ``logprobs`` of the sampling params.

    Raises:
        AssertionError: If the batch sizes, per-sequence lengths, returned
            ranks or token ids differ, or if a pair of logprobs differs by
            more than an absolute tolerance of 0.1. The message names the
            sequence index and position where the mismatch occurred.
    """
    temperature = 0.0

    prompts = [
        "Hello, my name is",
        "The president of the United States is",
        "The capital of France is",
        "The future of AI is",
        "San Francisco is know for its",
        "Facebook was created in 2004 by",
        "Curious George is a",
        "Python 3.11 brings improvements to its",
    ]

    # Repeat the prompt list as often as needed to reach batch_size entries.
    prompts = [prompt for prompt, _ in zip(cycle(prompts), range(batch_size))]

    # If the test requires that we generated max_output_len tokens, then set the
    # sampling params to ignore eos token.
    ignore_eos = force_output_len

    sampling_params = SamplingParams(
        max_tokens=max_output_len,
        ignore_eos=ignore_eos,
        temperature=temperature,
        logprobs=logprob_rank,
    )

    spec_batch_logprobs = get_logprobs_from_llm_generator(
        test_llm_generator, prompts, sampling_params)
    baseline_batch_logprobs = get_logprobs_from_llm_generator(
        baseline_llm_generator, prompts, sampling_params)

    assert len(baseline_batch_logprobs) == len(prompts)
    assert len(spec_batch_logprobs) == len(prompts)

    # For each sequence in the batch.
    for seq_idx, (baseline_logprobs, spec_logprobs) in enumerate(
            zip(baseline_batch_logprobs, spec_batch_logprobs)):
        assert len(spec_logprobs) == len(baseline_logprobs), (
            f"sequence {seq_idx}: spec produced {len(spec_logprobs)} "
            f"positions, baseline produced {len(baseline_logprobs)}")

        # For each generated position of the sequence.
        for pos, (spec_pos_logprobs, baseline_pos_logprobs) in enumerate(
                zip(spec_logprobs, baseline_logprobs)):
            where = f"sequence {seq_idx}, position {pos}"

            # Map rank to (token id, logprob) for both outputs so that the
            # comparison below is keyed by rank rather than by token id.
            spec_by_rank = {
                entry.rank: (token_id, entry.logprob)
                for token_id, entry in spec_pos_logprobs.items()
            }
            baseline_by_rank = {
                entry.rank: (token_id, entry.logprob)
                for token_id, entry in baseline_pos_logprobs.items()
            }

            # Assert set of ranks returned is equal.
            assert set(spec_by_rank) == set(baseline_by_rank), (
                f"{where}: spec ranks {sorted(spec_by_rank)} != "
                f"baseline ranks {sorted(baseline_by_rank)}")

            # Assert each logprob/token id is correct, keyed by rank.
            for rank in sorted(spec_by_rank):
                spec_token_id, spec_logprob = spec_by_rank[rank]
                baseline_token_id, baseline_logprob = baseline_by_rank[rank]

                assert spec_token_id == baseline_token_id, (
                    f"{where}, rank {rank}: spec token {spec_token_id} != "
                    f"baseline token {baseline_token_id}")
                assert math.isclose(
                    spec_logprob,
                    baseline_logprob,
                    abs_tol=1e-1,
                ), (f"{where}, rank {rank}: spec logprob {spec_logprob} vs "
                    f"baseline logprob {baseline_logprob}")