test_logprobs.py 11.4 KB
Newer Older
1
2
3
from itertools import cycle

import pytest
4
import os
5
6
7

from vllm import SamplingParams

8
from .conftest import run_equality_correctness_test
9
from ...utils import models_path_prefix
10
11
12
13
14


@pytest.mark.parametrize(
    "common_llm_kwargs",
    [{
        "model_name": os.path.join(models_path_prefix, "JackFram/llama-68m"),

        # Skip cuda graph recording for fast test.
        "enforce_eager": True,

        # Required for spec decode.
        "use_v2_block_manager": True,
    }])
@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}])
@pytest.mark.parametrize("baseline_llm_kwargs", [{}])
@pytest.mark.parametrize(
    "test_llm_kwargs",
    [
        # Same draft model and speculation length; only the logprobs switch
        # differs, so both settings are exercised.
        {
            "speculative_model":
            os.path.join(models_path_prefix, "JackFram/llama-160m"),
            "num_speculative_tokens": 3,
            "disable_logprobs_during_spec_decoding": False,
        },
        {
            "speculative_model":
            os.path.join(models_path_prefix, "JackFram/llama-160m"),
            "num_speculative_tokens": 3,
            "disable_logprobs_during_spec_decoding": True,
        },
    ])
@pytest.mark.parametrize("batch_size", [8])
@pytest.mark.parametrize(
    "output_len",
    [
        # Use smaller output len for fast test.
        7,
    ])
@pytest.mark.parametrize("seed", [1])
@pytest.mark.parametrize("logprobs", [1, 6])
def test_logprobs_equality(vllm_runner, common_llm_kwargs,
                           per_test_common_llm_kwargs, baseline_llm_kwargs,
                           test_llm_kwargs, batch_size: int, output_len: int,
                           seed: int, logprobs: int):
    """Verify output logprobs are equal with and without speculative decoding.

    Runs the shared equality check with greedy sampling (temperature=0.0),
    requesting ``logprobs`` entries for both sampled tokens and prompt
    tokens. The ``disable_logprobs`` argument is forwarded from the
    parametrized ``disable_logprobs_during_spec_decoding`` setting.
    """
    # Forwarded so the helper knows whether this configuration disables
    # logprobs during speculative decoding.
    disable_logprobs = test_llm_kwargs['disable_logprobs_during_spec_decoding']

    run_equality_correctness_test(vllm_runner,
                                  common_llm_kwargs,
                                  per_test_common_llm_kwargs,
                                  baseline_llm_kwargs,
                                  test_llm_kwargs,
                                  batch_size,
                                  output_len,
                                  seed,
                                  temperature=0.0,
                                  logprobs=logprobs,
                                  prompt_logprobs=logprobs,
                                  disable_logprobs=disable_logprobs)
63
64
65
66
67


@pytest.mark.parametrize(
    "common_llm_kwargs",
    [{
        "model_name": os.path.join(models_path_prefix, "JackFram/llama-68m"),

        # Skip cuda graph recording for fast test.
        "enforce_eager": True,

        # Required for spec decode.
        "use_v2_block_manager": True,
    }])
@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}])
@pytest.mark.parametrize("baseline_llm_kwargs", [{}])
@pytest.mark.parametrize(
    "test_llm_kwargs",
    [
        # Two speculation lengths (k=3 and k=6) with logprobs kept enabled.
        {
            "speculative_model":
            os.path.join(models_path_prefix, "JackFram/llama-160m"),
            "num_speculative_tokens": 3,
            "disable_logprobs_during_spec_decoding": False,
        },
        {
            "speculative_model":
            os.path.join(models_path_prefix, "JackFram/llama-160m"),
            "num_speculative_tokens": 6,
            "disable_logprobs_during_spec_decoding": False,
        },
    ])
@pytest.mark.parametrize("batch_size", [8])
@pytest.mark.parametrize(
    "output_len",
    [
        # Keep the output length small so the test stays fast.
        32,
    ])
@pytest.mark.parametrize("seed", [1])
@pytest.mark.parametrize("logprobs", [1, 6])
def test_logprobs_different_k(vllm_runner, common_llm_kwargs,
                              per_test_common_llm_kwargs, baseline_llm_kwargs,
                              test_llm_kwargs, batch_size: int,
                              output_len: int, seed: int, logprobs: int):
    """Verify logprob greedy equality with different speculation lens.

    Runs the shared equality check with greedy sampling (temperature=0.0)
    for each parametrized ``num_speculative_tokens`` value. Unlike
    ``test_logprobs_equality``, prompt logprobs are not requested here.
    """
    # Forwarded so the helper knows whether this configuration disables
    # logprobs during speculative decoding.
    disable_logprobs = test_llm_kwargs['disable_logprobs_during_spec_decoding']

    run_equality_correctness_test(vllm_runner,
                                  common_llm_kwargs,
                                  per_test_common_llm_kwargs,
                                  baseline_llm_kwargs,
                                  test_llm_kwargs,
                                  batch_size,
                                  output_len,
                                  seed,
                                  temperature=0.0,
                                  logprobs=logprobs,
                                  disable_logprobs=disable_logprobs)
115
116
117
118
119


@pytest.mark.parametrize(
    "common_llm_kwargs",
    [{
        "model_name": os.path.join(models_path_prefix, "JackFram/llama-68m"),

        # Skip cuda graph recording for fast test.
        "enforce_eager": True,

        # Required for spec decode.
        "use_v2_block_manager": True,
    }])
@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}])
@pytest.mark.parametrize("baseline_llm_kwargs", [{}])
@pytest.mark.parametrize(
    "test_llm_kwargs",
    [{
        "speculative_model":
        os.path.join(models_path_prefix, "JackFram/llama-160m"),
        "num_speculative_tokens": 3,
        "disable_logprobs_during_spec_decoding": False,

        # Artificially limit the draft model max model len; this forces vLLM
        # to skip speculation once the sequences grow beyond 32-k tokens.
        "speculative_max_model_len": 32,
    }])
@pytest.mark.parametrize("batch_size", [8])
@pytest.mark.parametrize(
    "output_len",
    [
        # Keep the output length small so the test stays fast.
        32,
    ])
@pytest.mark.parametrize("seed", [1])
@pytest.mark.parametrize("logprobs", [1])
def test_logprobs_when_skip_speculation(vllm_runner, common_llm_kwargs,
                                        per_test_common_llm_kwargs,
                                        baseline_llm_kwargs, test_llm_kwargs,
                                        batch_size: int, output_len: int,
                                        seed: int, logprobs: int):
    """Verify logprobs greedy equality when some sequences skip speculation.

    The draft model's maximum length is capped through
    ``speculative_max_model_len`` in the parametrized kwargs. The shared
    equality check is then run with greedy sampling (temperature=0.0).
    """
    # Forwarded so the helper knows whether this configuration disables
    # logprobs during speculative decoding.
    disable_logprobs = test_llm_kwargs['disable_logprobs_during_spec_decoding']

    run_equality_correctness_test(vllm_runner,
                                  common_llm_kwargs,
                                  per_test_common_llm_kwargs,
                                  baseline_llm_kwargs,
                                  test_llm_kwargs,
                                  batch_size,
                                  output_len,
                                  seed,
                                  temperature=0.0,
                                  logprobs=logprobs,
                                  disable_logprobs=disable_logprobs)
169
170
171
172
173


@pytest.mark.parametrize(
    "common_llm_kwargs",
    [{
        "model_name": os.path.join(models_path_prefix, "JackFram/llama-68m"),

        # Skip cuda graph recording for fast test.
        "enforce_eager": True,

        # Required for spec decode.
        "use_v2_block_manager": True,
    }])
@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}])
@pytest.mark.parametrize("baseline_llm_kwargs", [{}])
@pytest.mark.parametrize(
    "test_llm_kwargs",
    [{
        "speculative_model":
        os.path.join(models_path_prefix, "JackFram/llama-160m"),
        "num_speculative_tokens": 3,
        "disable_logprobs_during_spec_decoding": False,
    }])
@pytest.mark.parametrize("batch_size", [1])
@pytest.mark.parametrize(
    "output_len",
    [
        # Keep the output length small so the test stays fast.
        32,
    ])
@pytest.mark.parametrize("seed", [1])
@pytest.mark.parametrize("logprobs", [6])
def test_logprobs_temp_1(vllm_runner, common_llm_kwargs,
                         per_test_common_llm_kwargs, baseline_llm_kwargs,
                         test_llm_kwargs, batch_size: int, output_len: int,
                         seed: int, logprobs: int):
    """Verify at least one logprob result has num_logprobs+1, which tests the
    case where the sampled token is not in top-k logprobs.

    Ideally, this test should validate equality with non-spec by getting
    logprobs. This is left as future improvement.

    Only the speculative-decoding configuration is run; ``baseline_llm_kwargs``
    and ``seed`` are accepted for parametrization parity with the other tests
    in this module but are not used in the body.
    """
    temperature = 1.0

    prompts = [
        "Hello, my name is",
        "The president of the United States is",
        "The capital of France is",
        "The future of AI is",
        "San Francisco is know for its",
        "Facebook was created in 2004 by",
        "Curious George is a",
        "Python 3.11 brings improvements to its",
    ]

    # Repeat (or truncate) the prompt list so exactly batch_size prompts are
    # submitted.
    prompts = [prompt for prompt, _ in zip(cycle(prompts), range(batch_size))]

    sampling_params = SamplingParams(
        max_tokens=output_len,
        ignore_eos=True,
        temperature=temperature,
        logprobs=logprobs,
    )

    # Later dicts win on key collisions: test-specific kwargs override the
    # per-test common kwargs, which override the common kwargs.
    sd_args = {
        **common_llm_kwargs,
        **per_test_common_llm_kwargs,
        **test_llm_kwargs,
    }

    with vllm_runner(**sd_args) as vllm_model:
        sd_outputs = vllm_model.generate_w_logprobs(prompts, sampling_params)

    # NOTE(review): this assumes iterating sd_outputs[-1] yields one logprob
    # collection per generated position -- confirm against the return shape
    # of generate_w_logprobs.
    num_returned_logprobs = [
        len(seq_logprobs) for seq_logprobs in sd_outputs[-1]
    ]

    # Assert one of the returned logprobs has > num_logprobs (indicating the
    # sampled token is not in top-k).
    assert any(num_returned > logprobs
               for num_returned in num_returned_logprobs)
248
249
250
251
252


@pytest.mark.parametrize(
    "common_llm_kwargs",
    [{
        "model_name": os.path.join(models_path_prefix, "JackFram/llama-160m"),

        # Skip cuda graph recording for fast test.
        "enforce_eager": True,

        # Required for spec decode.
        "use_v2_block_manager": True,
    }])
@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}])
@pytest.mark.parametrize("baseline_llm_kwargs", [{}])
@pytest.mark.parametrize(
    "test_llm_kwargs",
    [{
        "speculative_model":
        os.path.join(models_path_prefix, "JackFram/llama-68m"),
        "num_speculative_tokens": 3,
        "disable_logprobs_during_spec_decoding": True,
    }])
@pytest.mark.parametrize("seed", [1])
@pytest.mark.parametrize("batch_size", [4])
@pytest.mark.parametrize(
    "output_len",
    [
        # Keep the output length small so the test stays fast.
        32,
    ])
@pytest.mark.parametrize("logprobs", [0])
def test_logprobs_disabled(vllm_runner, common_llm_kwargs,
                           per_test_common_llm_kwargs, baseline_llm_kwargs,
                           test_llm_kwargs, batch_size: int, output_len: int,
                           seed: int, logprobs: int):
    """Check the behavior when logprobs are disabled.
    Token choices should match with the base model.

    Here the target model is llama-160m and the draft model is llama-68m,
    the reverse of the other tests in this module. The shared equality check
    is run with greedy sampling (temperature=0.0) and ``logprobs=0``.
    """
    # Forwarded so the helper knows this configuration disables logprobs
    # during speculative decoding.
    disable_logprobs = test_llm_kwargs['disable_logprobs_during_spec_decoding']

    run_equality_correctness_test(vllm_runner,
                                  common_llm_kwargs,
                                  per_test_common_llm_kwargs,
                                  baseline_llm_kwargs,
                                  test_llm_kwargs,
                                  batch_size,
                                  output_len,
                                  seed,
                                  temperature=0.0,
                                  logprobs=logprobs,
                                  disable_logprobs=disable_logprobs)