# SPDX-License-Identifier: Apache-2.0

import os
from itertools import cycle

import pytest

from vllm import SamplingParams

from ...utils import models_path_prefix
from ..utils import maybe_enable_chunked_prefill
from .conftest import run_equality_correctness_test


@pytest.mark.parametrize(
    "common_llm_kwargs",
    [{
        "model_name": os.path.join(models_path_prefix, "JackFram/llama-160m"),

        # Skip cuda graph recording for fast test.
        "enforce_eager": True,
    }])
@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}])
@pytest.mark.parametrize("baseline_llm_kwargs", [{}])
@pytest.mark.parametrize(
    "test_llm_kwargs",
    [{
        "speculative_model": os.path.join(models_path_prefix,
                                          "JackFram/llama-68m"),
        "num_speculative_tokens": 3,
        "disable_logprobs_during_spec_decoding": False,
    }, {
        "speculative_model": os.path.join(models_path_prefix,
                                          "JackFram/llama-68m"),
        "num_speculative_tokens": 3,
        "disable_logprobs_during_spec_decoding": True,
    }])
@pytest.mark.parametrize("batch_size", [8])
@pytest.mark.parametrize(
    "output_len",
    [
        # Use smaller output len for fast test.
        7,
    ])
@pytest.mark.parametrize("seed", [1])
@pytest.mark.parametrize("logprobs", [1, 6])
@pytest.mark.parametrize("prefill_chunk_size", [-1, 4, 12])
def test_logprobs_equality(vllm_runner, common_llm_kwargs,
                           per_test_common_llm_kwargs, baseline_llm_kwargs,
                           test_llm_kwargs, batch_size: int, output_len: int,
                           seed: int, logprobs: int, prefill_chunk_size: int):
    """Verify output logprobs are equal with and without speculative decoding,
    as well as with and without chunked prefill.

    The same ``logprobs`` value is requested for both sampled-token logprobs
    and prompt logprobs. The ``disable_logprobs`` flag handed to the
    comparison helper mirrors the
    ``disable_logprobs_during_spec_decoding`` entry of ``test_llm_kwargs``.
    """
    # The return value is discarded, so any effect of this helper comes from
    # mutating ``common_llm_kwargs`` in place.
    # NOTE(review): that dict object is shared by every parametrized case of
    # this test -- confirm the helper also resets the chunked-prefill settings
    # when ``prefill_chunk_size`` is -1.
    maybe_enable_chunked_prefill(prefill_chunk_size, common_llm_kwargs)

    run_equality_correctness_test(vllm_runner,
                                  common_llm_kwargs,
                                  per_test_common_llm_kwargs,
                                  baseline_llm_kwargs,
                                  test_llm_kwargs,
                                  batch_size,
                                  output_len,
                                  seed,
                                  temperature=0.0,
                                  logprobs=logprobs,
                                  prompt_logprobs=logprobs,
                                  disable_logprobs=test_llm_kwargs[
                                      'disable_logprobs_during_spec_decoding'])


@pytest.mark.parametrize(
    "common_llm_kwargs",
    [{
        "model_name": os.path.join(models_path_prefix, "JackFram/llama-68m"),

        # Skip cuda graph recording for fast test.
        "enforce_eager": True,
    }])
@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}])
@pytest.mark.parametrize("baseline_llm_kwargs", [{}])
@pytest.mark.parametrize(
    "test_llm_kwargs",
    [{
        "speculative_model": os.path.join(models_path_prefix,
                                          "JackFram/llama-160m"),
        "num_speculative_tokens": 3,
        "disable_logprobs_during_spec_decoding": False,
    }, {
        "speculative_model": os.path.join(models_path_prefix,
                                          "JackFram/llama-160m"),
        "num_speculative_tokens": 6,
        "disable_logprobs_during_spec_decoding": False,
    }])
@pytest.mark.parametrize("batch_size", [8])
@pytest.mark.parametrize(
    "output_len",
    [
        # Use smaller output len for fast test.
        32,
    ])
@pytest.mark.parametrize("seed", [1])
@pytest.mark.parametrize("logprobs", [1, 6])
def test_logprobs_different_k(vllm_runner, common_llm_kwargs,
                              per_test_common_llm_kwargs, baseline_llm_kwargs,
                              test_llm_kwargs, batch_size: int,
                              output_len: int, seed: int, logprobs: int):
    """Verify logprob greedy equality with different speculation lens.

    The two ``test_llm_kwargs`` cases differ only in
    ``num_speculative_tokens`` (3 and 6); both keep logprobs enabled during
    speculative decoding. Sampling uses ``temperature=0.0``.
    """
    run_equality_correctness_test(vllm_runner,
                                  common_llm_kwargs,
                                  per_test_common_llm_kwargs,
                                  baseline_llm_kwargs,
                                  test_llm_kwargs,
                                  batch_size,
                                  output_len,
                                  seed,
                                  temperature=0.0,
                                  logprobs=logprobs,
                                  disable_logprobs=test_llm_kwargs[
                                      'disable_logprobs_during_spec_decoding'])


@pytest.mark.parametrize(
    "common_llm_kwargs",
    [{
        "model_name": os.path.join(models_path_prefix, "JackFram/llama-68m"),

        # Skip cuda graph recording for fast test.
        "enforce_eager": True,
    }])
@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}])
@pytest.mark.parametrize("baseline_llm_kwargs", [{}])
@pytest.mark.parametrize(
    "test_llm_kwargs",
    [{
        "speculative_model": os.path.join(models_path_prefix,
                                          "JackFram/llama-160m"),
        "num_speculative_tokens": 3,
        "disable_logprobs_during_spec_decoding": False,

        # Artificially limit the draft model max model len; this forces vLLM
        # to skip speculation once the sequences grow beyond 32-k tokens.
        "speculative_max_model_len": 32,
    }])
@pytest.mark.parametrize("batch_size", [8])
@pytest.mark.parametrize(
    "output_len",
    [
        # Use smaller output len for fast test.
        32,
    ])
@pytest.mark.parametrize("seed", [1])
@pytest.mark.parametrize("logprobs", [1])
def test_logprobs_when_skip_speculation(vllm_runner, common_llm_kwargs,
                                        per_test_common_llm_kwargs,
                                        baseline_llm_kwargs, test_llm_kwargs,
                                        batch_size: int, output_len: int,
                                        seed: int, logprobs: int):
    """Verify logprobs greedy equality when some sequences skip speculation.

    The draft model is configured with ``speculative_max_model_len=32`` while
    ``output_len`` is also 32. Sampling uses ``temperature=0.0``.
    """
    run_equality_correctness_test(vllm_runner,
                                  common_llm_kwargs,
                                  per_test_common_llm_kwargs,
                                  baseline_llm_kwargs,
                                  test_llm_kwargs,
                                  batch_size,
                                  output_len,
                                  seed,
                                  temperature=0.0,
                                  logprobs=logprobs,
                                  disable_logprobs=test_llm_kwargs[
                                      'disable_logprobs_during_spec_decoding'])


@pytest.mark.parametrize(
    "common_llm_kwargs",
    [{
        "model_name": os.path.join(models_path_prefix, "JackFram/llama-68m"),

        # Skip cuda graph recording for fast test.
        "enforce_eager": True,
    }])
@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}])
@pytest.mark.parametrize("baseline_llm_kwargs", [{}])
@pytest.mark.parametrize(
    "test_llm_kwargs",
    [{
        "speculative_model": os.path.join(models_path_prefix,
                                          "JackFram/llama-160m"),
        "num_speculative_tokens": 3,
        "disable_logprobs_during_spec_decoding": False,
    }])
@pytest.mark.parametrize("batch_size", [1])
@pytest.mark.parametrize(
    "output_len",
    [
        # Use smaller output len for fast test.
        32,
    ])
@pytest.mark.parametrize("seed", [1])
@pytest.mark.parametrize("logprobs", [6])
def test_logprobs_temp_1(vllm_runner, common_llm_kwargs,
                         per_test_common_llm_kwargs, baseline_llm_kwargs,
                         test_llm_kwargs, batch_size: int, output_len: int,
                         seed: int, logprobs: int):
    """Verify at least one logprob result has num_logprobs+1, which tests the
    case where the sampled token is not in top-k logprobs.

    Only the speculative-decoding configuration is run; ``baseline_llm_kwargs``
    and ``seed`` are accepted for the shared parametrization but are not used
    in the body.

    Ideally, this test should validate equality with non-spec by getting
    logprobs. This is left as future improvement.
    """
    temperature = 1.0

    prompts = [
        "Hello, my name is",
        "The president of the United States is",
        "The capital of France is",
        "The future of AI is",
        "San Francisco is know for its",
        "Facebook was created in 2004 by",
        "Curious George is a",
        "Python 3.11 brings improvements to its",
    ]

    # Repeat the prompt list cyclically until exactly batch_size prompts exist.
    prompts = [prompt for prompt, _ in zip(cycle(prompts), range(batch_size))]

    # NOTE(review): `seed` is not forwarded here even though temperature is
    # 1.0 -- confirm the runner seeds sampling, otherwise this may be flaky.
    sampling_params = SamplingParams(
        max_tokens=output_len,
        ignore_eos=True,
        temperature=temperature,
        logprobs=logprobs,
    )

    # Later dicts win on key collisions: test-specific kwargs override the
    # common ones.
    sd_args = {
        **common_llm_kwargs,
        **per_test_common_llm_kwargs,
        **test_llm_kwargs,
    }

    with vllm_runner(**sd_args) as vllm_model:
        sd_outputs = vllm_model.generate_w_logprobs(prompts, sampling_params)

    # NOTE(review): this takes the last element of the runner's return value
    # and measures the length of each item inside it -- confirm against the
    # return shape of generate_w_logprobs.
    num_returned_logprobs = [
        len(seq_logprobs) for seq_logprobs in sd_outputs[-1]
    ]

    # Assert one of the returned logprobs has > num_logprobs (indicating the
    # sampled token is not in top-k).
    assert any(num_returned > logprobs
               for num_returned in num_returned_logprobs)


@pytest.mark.parametrize(
    "common_llm_kwargs",
    [{
        "model_name": os.path.join(models_path_prefix, "JackFram/llama-160m"),

        # Skip cuda graph recording for fast test.
        "enforce_eager": True,
    }])
@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}])
@pytest.mark.parametrize("baseline_llm_kwargs", [{}])
@pytest.mark.parametrize(
    "test_llm_kwargs",
    [{
        "speculative_model": os.path.join(models_path_prefix,
                                          "JackFram/llama-68m"),
        "num_speculative_tokens": 3,
        "disable_logprobs_during_spec_decoding": True,
    }])
@pytest.mark.parametrize("seed", [1])
@pytest.mark.parametrize("batch_size", [4])
@pytest.mark.parametrize(
    "output_len",
    [
        # Use smaller output len for fast test.
        32,
    ])
@pytest.mark.parametrize("logprobs", [0])
def test_logprobs_disabled(vllm_runner, common_llm_kwargs,
                           per_test_common_llm_kwargs, baseline_llm_kwargs,
                           test_llm_kwargs, batch_size: int, output_len: int,
                           seed: int, logprobs: int):
    """Check the behavior when logprobs are disabled.

    Token choices should match with the base model. The run requests
    ``logprobs=0`` with ``disable_logprobs_during_spec_decoding`` set to True
    and samples with ``temperature=0.0``.
    """
    run_equality_correctness_test(vllm_runner,
                                  common_llm_kwargs,
                                  per_test_common_llm_kwargs,
                                  baseline_llm_kwargs,
                                  test_llm_kwargs,
                                  batch_size,
                                  output_len,
                                  seed,
                                  temperature=0.0,
                                  logprobs=logprobs,
                                  disable_logprobs=test_llm_kwargs[
                                      'disable_logprobs_during_spec_decoding'])