test_logprobs.py 9.8 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3

4
5
6
7
8
9
from itertools import cycle

import pytest

from vllm import SamplingParams

10
from ..utils import maybe_enable_chunked_prefill
11
from .conftest import run_equality_correctness_test
12
13
14
15
16


@pytest.mark.parametrize(
    "common_llm_kwargs",
    [{
        "model_name": "JackFram/llama-160m",

        # Skip cuda graph recording for fast test.
        "enforce_eager": True,

        # The original model is float32, keep it for numerical stability.
        "dtype": "float32",
    }])
@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}])
@pytest.mark.parametrize("baseline_llm_kwargs", [{}])
@pytest.mark.parametrize("test_llm_kwargs", [{
    "speculative_config": {
        "model": "JackFram/llama-68m",
        "num_speculative_tokens": 3,
        "disable_logprobs": False,
    },
}, {
    "speculative_config": {
        "model": "JackFram/llama-68m",
        "num_speculative_tokens": 3,
        "disable_logprobs": True,
    },
}])
@pytest.mark.parametrize("batch_size", [8])
@pytest.mark.parametrize(
    "output_len",
    [
        # Use smaller output len for fast test.
        7,
    ])
@pytest.mark.parametrize("seed", [1])
@pytest.mark.parametrize("logprobs", [1, 6])
@pytest.mark.parametrize("prefill_chunk_size", [-1, 4, 12])
def test_logprobs_equality(vllm_runner, common_llm_kwargs,
                           per_test_common_llm_kwargs, baseline_llm_kwargs,
                           test_llm_kwargs, batch_size: int, output_len: int,
                           seed: int, logprobs: int, prefill_chunk_size: int):
    """Verify output logprobs are equal with and without speculative decoding,
    as well as with and without chunked prefill.

    The same ``logprobs`` value is requested for both sampled-token logprobs
    and prompt logprobs, and decoding is greedy (``temperature=0.0``).
    ``prefill_chunk_size`` is forwarded to ``maybe_enable_chunked_prefill``
    together with the common engine kwargs.
    """
    # pytest passes the very same dict object to every generated test case
    # (no copy is made). The helper below is called only for its in-place
    # effect on the kwargs, so operate on a private shallow copy; otherwise
    # whatever it writes would leak into the other parametrized cases.
    common_llm_kwargs = dict(common_llm_kwargs)
    maybe_enable_chunked_prefill(prefill_chunk_size, common_llm_kwargs)

    # Whether the speculative config under test turns logprobs off.
    disable_logprobs = test_llm_kwargs["speculative_config"][
        "disable_logprobs"]

    run_equality_correctness_test(
        vllm_runner,
        common_llm_kwargs,
        per_test_common_llm_kwargs,
        baseline_llm_kwargs,
        test_llm_kwargs,
        batch_size,
        output_len,
        seed,
        temperature=0.0,
        logprobs=logprobs,
        prompt_logprobs=logprobs,
        disable_logprobs=disable_logprobs)
72
73
74
75
76


@pytest.mark.parametrize("common_llm_kwargs", [
    {
        "model_name": "JackFram/llama-68m",
        # Eager mode: no cuda graph recording, which keeps the test fast.
        "enforce_eager": True,
        # The original model is float32; keep it for numerical stability.
        "dtype": "float32",
    },
])
@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}])
@pytest.mark.parametrize("baseline_llm_kwargs", [{}])
@pytest.mark.parametrize("test_llm_kwargs", [
    {
        "speculative_config": {
            "model": "JackFram/llama-160m",
            "num_speculative_tokens": 3,
            "disable_logprobs": False,
        },
    },
    {
        "speculative_config": {
            "model": "JackFram/llama-160m",
            "num_speculative_tokens": 6,
            "disable_logprobs": False,
        },
    },
])
@pytest.mark.parametrize("batch_size", [8])
# A small output length keeps the test fast.
@pytest.mark.parametrize("output_len", [32])
@pytest.mark.parametrize("seed", [1])
@pytest.mark.parametrize("logprobs", [1, 6])
def test_logprobs_different_k(vllm_runner, common_llm_kwargs,
                              per_test_common_llm_kwargs, baseline_llm_kwargs,
                              test_llm_kwargs, batch_size: int,
                              output_len: int, seed: int, logprobs: int):
    """Verify logprob greedy equality with different speculation lengths.

    The two parametrized speculative configs differ only in
    ``num_speculative_tokens`` (3 and 6).
    """
    spec_config = test_llm_kwargs["speculative_config"]
    logprobs_disabled = spec_config["disable_logprobs"]

    run_equality_correctness_test(
        vllm_runner,
        common_llm_kwargs,
        per_test_common_llm_kwargs,
        baseline_llm_kwargs,
        test_llm_kwargs,
        batch_size,
        output_len,
        seed,
        temperature=0.0,
        logprobs=logprobs,
        disable_logprobs=logprobs_disabled,
    )
128
129
130
131
132


@pytest.mark.parametrize("common_llm_kwargs", [
    {
        "model_name": "JackFram/llama-68m",
        # Eager mode: no cuda graph recording, which keeps the test fast.
        "enforce_eager": True,
        # The original model is float32; keep it for numerical stability.
        "dtype": "float32",
    },
])
@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}])
@pytest.mark.parametrize("baseline_llm_kwargs", [{}])
@pytest.mark.parametrize("test_llm_kwargs", [
    {
        "speculative_config": {
            "model": "JackFram/llama-160m",
            "num_speculative_tokens": 3,
            "disable_logprobs": False,
            # Artificially limit the draft model max model len; this forces
            # vLLM to skip speculation once the sequences grow beyond 32-k
            # tokens.
            "max_model_len": 32,
        },
    },
])
@pytest.mark.parametrize("batch_size", [8])
# A small output length keeps the test fast.
@pytest.mark.parametrize("output_len", [32])
@pytest.mark.parametrize("seed", [1])
@pytest.mark.parametrize("logprobs", [1])
def test_logprobs_when_skip_speculation(vllm_runner, common_llm_kwargs,
                                        per_test_common_llm_kwargs,
                                        baseline_llm_kwargs, test_llm_kwargs,
                                        batch_size: int, output_len: int,
                                        seed: int, logprobs: int):
    """Verify logprobs greedy equality when some sequences skip speculation.

    The speculative config caps the draft model's ``max_model_len`` at 32,
    the same value as ``output_len``.
    """
    spec_config = test_llm_kwargs["speculative_config"]
    logprobs_disabled = spec_config["disable_logprobs"]

    run_equality_correctness_test(
        vllm_runner,
        common_llm_kwargs,
        per_test_common_llm_kwargs,
        baseline_llm_kwargs,
        test_llm_kwargs,
        batch_size,
        output_len,
        seed,
        temperature=0.0,
        logprobs=logprobs,
        disable_logprobs=logprobs_disabled,
    )
185
186
187
188
189


@pytest.mark.parametrize(
    "common_llm_kwargs",
    [{
        "model_name": "JackFram/llama-68m",

        # Skip cuda graph recording for fast test.
        "enforce_eager": True,

        # The original model is float32, keep it for numerical stability.
        "dtype": "float32",
    }])
@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}])
@pytest.mark.parametrize("baseline_llm_kwargs", [{}])
@pytest.mark.parametrize("test_llm_kwargs", [{
    "speculative_config": {
        "model": "JackFram/llama-160m",
        "num_speculative_tokens": 3,
        "disable_logprobs": False,
    },
}])
@pytest.mark.parametrize("batch_size", [1])
@pytest.mark.parametrize(
    "output_len",
    [
        # Use smaller output len for fast test.
        32,
    ])
@pytest.mark.parametrize("seed", [1])
@pytest.mark.parametrize("logprobs", [6])
def test_logprobs_temp_1(vllm_runner, common_llm_kwargs,
                         per_test_common_llm_kwargs, baseline_llm_kwargs,
                         test_llm_kwargs, batch_size: int, output_len: int,
                         seed: int, logprobs: int):
    """Verify at least one logprob result has num_logprobs+1, which tests the
    case where the sampled token is not in top-k logprobs.

    Ideally, this test should validate equality with non-spec by getting
    logprobs. This is left as future improvement.

    ``baseline_llm_kwargs`` and ``seed`` are supplied by the parametrization
    but are not used in the body: only the speculative-decoding engine is
    run, and no seed is passed to ``SamplingParams``.
    """
    temperature = 1.0

    prompts = [
        "Hello, my name is",
        "The president of the United States is",
        "The capital of France is",
        "The future of AI is",
        "San Francisco is know for its",
        "Facebook was created in 2004 by",
        "Curious George is a",
        "Python 3.11 brings improvements to its",
    ]

    # Repeat or truncate the prompt list so it holds exactly batch_size items.
    prompts = [prompt for prompt, _ in zip(cycle(prompts), range(batch_size))]

    sampling_params = SamplingParams(
        max_tokens=output_len,
        ignore_eos=True,
        temperature=temperature,
        logprobs=logprobs,
    )

    # Engine arguments for the speculative-decoding run; later dicts win on
    # key collisions.
    sd_args = {
        **common_llm_kwargs,
        **per_test_common_llm_kwargs,
        **test_llm_kwargs,
    }

    with vllm_runner(**sd_args) as vllm_model:
        sd_outputs = vllm_model.generate_w_logprobs(prompts, sampling_params)

    # NOTE(review): this iterates the *last element* of sd_outputs. If
    # generate_w_logprobs returns one (token_ids, text, logprobs) tuple per
    # prompt, the lengths collected here are those of the tuple members, not
    # of per-token logprob dicts, and the assertion below would pass
    # trivially. Confirm the return shape against the runner and tighten the
    # check if so.
    num_returned_logprobs = [
        len(seq_logprobs) for seq_logprobs in sd_outputs[-1]
    ]

    # Assert one of the returned logprobs has > num_logprobs (indicating the
    # sampled token is not in top-k).
    assert any(num_returned > logprobs
               for num_returned in num_returned_logprobs)
265
266
267
268
269


@pytest.mark.parametrize("common_llm_kwargs", [
    {
        "model_name": "JackFram/llama-160m",
        # Eager mode: no cuda graph recording, which keeps the test fast.
        "enforce_eager": True,
        # The original model is float32; keep it for numerical stability.
        "dtype": "float32",
    },
])
@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}])
@pytest.mark.parametrize("baseline_llm_kwargs", [{}])
@pytest.mark.parametrize("test_llm_kwargs", [
    {
        "speculative_config": {
            "model": "JackFram/llama-68m",
            "num_speculative_tokens": 3,
            "disable_logprobs": True,
        },
    },
])
@pytest.mark.parametrize("seed", [1])
@pytest.mark.parametrize("batch_size", [4])
# A small output length keeps the test fast.
@pytest.mark.parametrize("output_len", [32])
@pytest.mark.parametrize("logprobs", [0])
def test_logprobs_disabled(vllm_runner, common_llm_kwargs,
                           per_test_common_llm_kwargs, baseline_llm_kwargs,
                           test_llm_kwargs, batch_size: int, output_len: int,
                           seed: int, logprobs: int):
    """Check the behavior when logprobs are disabled.

    Token choices should match with the base model. The speculative config
    sets ``disable_logprobs`` to True and the test requests ``logprobs=0``
    with greedy decoding.
    """
    spec_config = test_llm_kwargs["speculative_config"]
    logprobs_disabled = spec_config["disable_logprobs"]

    run_equality_correctness_test(
        vllm_runner,
        common_llm_kwargs,
        per_test_common_llm_kwargs,
        baseline_llm_kwargs,
        test_llm_kwargs,
        batch_size,
        output_len,
        seed,
        temperature=0.0,
        logprobs=logprobs,
        disable_logprobs=logprobs_disabled,
    )