# untest_logprobs.py
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3

4
5
6
from itertools import cycle

import pytest
7
import os
8
9
10

from vllm import SamplingParams

11
from ..utils import maybe_enable_chunked_prefill
12
from .conftest import run_equality_correctness_test
13
from ...utils import models_path_prefix
14

# Set once at import time so that every test in this module sees it.
# NOTE(review): LLAMA_NN is not read anywhere in this file; presumably it is a
# backend-specific switch consumed by the runtime -- confirm its meaning.
os.environ.update({"LLAMA_NN": "0"})


@pytest.mark.parametrize(
    "common_llm_kwargs",
    [{
        "model_name": os.path.join(models_path_prefix, "JackFram/llama-160m"),

        # Skip cuda graph recording for fast test.
        "enforce_eager": True,

        # The original model is float32, keep it for numerical stability.
        "dtype": "float32",
    }])
@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}])
@pytest.mark.parametrize("baseline_llm_kwargs", [{}])
@pytest.mark.parametrize("test_llm_kwargs", [{
    "speculative_config": {
        "model": os.path.join(models_path_prefix, "JackFram/llama-68m"),
        "num_speculative_tokens": 3,
        "disable_logprobs": False,
    },
}, {
    "speculative_config": {
        "model": os.path.join(models_path_prefix, "JackFram/llama-68m"),
        "num_speculative_tokens": 3,
        "disable_logprobs": True,
    },
}])
@pytest.mark.parametrize("batch_size", [8])
@pytest.mark.parametrize(
    "output_len",
    [
        # Use smaller output len for fast test.
        7,
    ])
@pytest.mark.parametrize("seed", [1])
@pytest.mark.parametrize("logprobs", [1, 6])
@pytest.mark.parametrize("prefill_chunk_size", [-1, 4, 8])
def test_logprobs_equality(vllm_runner, common_llm_kwargs,
                           per_test_common_llm_kwargs, baseline_llm_kwargs,
                           test_llm_kwargs, batch_size: int, output_len: int,
                           seed: int, logprobs: int, prefill_chunk_size: int):
    """Verify output logprobs are equal with and without speculative decoding,
    as well as with and without chunked prefill.

    Decoding is greedy (temperature 0.0). Both sampled-token logprobs and
    prompt logprobs are requested with the same ``logprobs`` count, and the
    speculative config's ``disable_logprobs`` flag is forwarded to the
    equality helper. ``prefill_chunk_size`` is handed to
    ``maybe_enable_chunked_prefill``; presumably -1 means "chunked prefill
    off" -- confirm in that helper.
    """
    # maybe_enable_chunked_prefill's return value is ignored, so it can only
    # take effect by mutating the dict it is given. pytest passes the *same*
    # parametrize dict object to every generated case, so mutate a per-case
    # copy to keep one case's chunked-prefill settings from leaking into the
    # next.
    common_llm_kwargs = dict(common_llm_kwargs)
    maybe_enable_chunked_prefill(prefill_chunk_size, common_llm_kwargs)

    run_equality_correctness_test(
        vllm_runner,
        common_llm_kwargs,
        per_test_common_llm_kwargs,
        baseline_llm_kwargs,
        test_llm_kwargs,
        batch_size,
        output_len,
        seed,
        temperature=0.0,
        logprobs=logprobs,
        prompt_logprobs=logprobs,
        disable_logprobs=test_llm_kwargs["speculative_config"]
        ["disable_logprobs"])


@pytest.mark.parametrize(
    "common_llm_kwargs",
    [{
        "model_name": os.path.join(models_path_prefix, "JackFram/llama-68m"),

        # Skip cuda graph recording for fast test.
        "enforce_eager": True,

        # The original model is float32, keep it for numerical stability.
        "dtype": "float32",
    }])
@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}])
@pytest.mark.parametrize("baseline_llm_kwargs", [{}])
@pytest.mark.parametrize("test_llm_kwargs", [{
    "speculative_config": {
        "model": os.path.join(models_path_prefix, "JackFram/llama-160m"),
        "num_speculative_tokens": 3,
        "disable_logprobs": False,
    },
}, {
    "speculative_config": {
        "model": os.path.join(models_path_prefix, "JackFram/llama-160m"),
        "num_speculative_tokens": 6,
        "disable_logprobs": False,
    },
}])
@pytest.mark.parametrize("batch_size", [8])
@pytest.mark.parametrize(
    "output_len",
    [
        # Use smaller output len for fast test.
        32,
    ])
@pytest.mark.parametrize("seed", [1])
@pytest.mark.parametrize("logprobs", [1, 6])
def test_logprobs_different_k(vllm_runner, common_llm_kwargs,
                              per_test_common_llm_kwargs, baseline_llm_kwargs,
                              test_llm_kwargs, batch_size: int,
                              output_len: int, seed: int, logprobs: int):
    """Verify logprob greedy equality with different speculation lens.

    Both ``test_llm_kwargs`` cases use the same draft model and differ only in
    ``num_speculative_tokens`` (3 vs. 6). Decoding is greedy
    (temperature 0.0).
    """
    run_equality_correctness_test(
        vllm_runner,
        common_llm_kwargs,
        per_test_common_llm_kwargs,
        baseline_llm_kwargs,
        test_llm_kwargs,
        batch_size,
        output_len,
        seed,
        temperature=0.0,
        logprobs=logprobs,
        disable_logprobs=test_llm_kwargs["speculative_config"]
        ["disable_logprobs"])


@pytest.mark.parametrize(
    "common_llm_kwargs",
    [{
        "model_name": os.path.join(models_path_prefix, "JackFram/llama-68m"),

        # Skip cuda graph recording for fast test.
        "enforce_eager": True,

        # The original model is float32, keep it for numerical stability.
        "dtype": "float32",
    }])
@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}])
@pytest.mark.parametrize("baseline_llm_kwargs", [{}])
@pytest.mark.parametrize(
    "test_llm_kwargs",
    [{
        "speculative_config": {
            "model": os.path.join(models_path_prefix, "JackFram/llama-160m"),
            "num_speculative_tokens": 3,
            "disable_logprobs": False,
            # Artificially limit the draft model's max model len; this forces
            # vLLM to skip speculation once a sequence grows beyond 32 - k
            # tokens, where k is num_speculative_tokens.
            "max_model_len": 32,
        },
    }])
@pytest.mark.parametrize("batch_size", [8])
@pytest.mark.parametrize(
    "output_len",
    [
        # Use smaller output len for fast test.
        32,
    ])
@pytest.mark.parametrize("seed", [1])
@pytest.mark.parametrize("logprobs", [1])
def test_logprobs_when_skip_speculation(vllm_runner, common_llm_kwargs,
                                        per_test_common_llm_kwargs,
                                        baseline_llm_kwargs, test_llm_kwargs,
                                        batch_size: int, output_len: int,
                                        seed: int, logprobs: int):
    """Verify logprobs greedy equality when some sequences skip speculation.

    The draft model's ``max_model_len`` is capped at 32, and ``output_len`` is
    also 32. Sequences therefore outgrow the draft model partway through
    generation and fall back to non-speculative decoding for the remaining
    steps.
    """
    run_equality_correctness_test(
        vllm_runner,
        common_llm_kwargs,
        per_test_common_llm_kwargs,
        baseline_llm_kwargs,
        test_llm_kwargs,
        batch_size,
        output_len,
        seed,
        temperature=0.0,
        logprobs=logprobs,
        disable_logprobs=test_llm_kwargs["speculative_config"]
        ["disable_logprobs"])


@pytest.mark.parametrize(
    "common_llm_kwargs",
    [{
        "model_name": os.path.join(models_path_prefix, "JackFram/llama-68m"),

        # Skip cuda graph recording for fast test.
        "enforce_eager": True,

        # The original model is float32, keep it for numerical stability.
        "dtype": "float32",
    }])
@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}])
@pytest.mark.parametrize("baseline_llm_kwargs", [{}])
@pytest.mark.parametrize("test_llm_kwargs", [{
    "speculative_config": {
        "model": os.path.join(models_path_prefix, "JackFram/llama-160m"),
        "num_speculative_tokens": 3,
        "disable_logprobs": False,
    },
}])
@pytest.mark.parametrize("batch_size", [1])
@pytest.mark.parametrize(
    "output_len",
    [
        # Use smaller output len for fast test.
        32,
    ])
@pytest.mark.parametrize("seed", [1])
@pytest.mark.parametrize("logprobs", [6])
def test_logprobs_temp_1(vllm_runner, common_llm_kwargs,
                         per_test_common_llm_kwargs, baseline_llm_kwargs,
                         test_llm_kwargs, batch_size: int, output_len: int,
                         seed: int, logprobs: int):
    """Verify at least one logprob result has num_logprobs+1, which tests the
    case where the sampled token is not in top-k logprobs.

    Ideally, this test should validate equality with non-spec by getting
    logprobs. This is left as future improvement.
    """
    temperature = 1.0

    prompts = [
        "Hello, my name is",
        "The president of the United States is",
        "The capital of France is",
        "The future of AI is",
        "San Francisco is know for its",
        "Facebook was created in 2004 by",
        "Curious George is a",
        "Python 3.11 brings improvements to its",
    ]

    # Repeat the prompt list cyclically until it has exactly batch_size items.
    prompts = [prompt for prompt, _ in zip(cycle(prompts), range(batch_size))]

    sampling_params = SamplingParams(
        max_tokens=output_len,
        ignore_eos=True,
        temperature=temperature,
        logprobs=logprobs,
    )

    sd_args = {
        **common_llm_kwargs,
        **per_test_common_llm_kwargs,
        **test_llm_kwargs,
    }

    with vllm_runner(**sd_args) as vllm_model:
        sd_outputs = vllm_model.generate_w_logprobs(prompts, sampling_params)

    # NOTE(review): each entry of sd_outputs is assumed to be
    # (token_ids, text, per_token_logprobs[, prompt_logprobs]), where
    # per_token_logprobs holds one {token_id: Logprob} dict per generated
    # token -- confirm against the runner's generate_w_logprobs. Counting the
    # entries of each per-token dict is what the docstring asks for.
    # Measuring len() of whole-sequence objects (as the previous code did)
    # compares ~output_len against `logprobs` and is always true.
    num_returned_logprobs = [
        len(token_logprobs)
        for seq_output in sd_outputs
        for token_logprobs in seq_output[2]
    ]

    # Assert one of the returned logprobs has > num_logprobs (indicating the
    # sampled token is not in top-k).
    assert any(num_returned > logprobs
               for num_returned in num_returned_logprobs)


@pytest.mark.parametrize(
    "common_llm_kwargs",
    [{
        "model_name": os.path.join(models_path_prefix, "JackFram/llama-160m"),

        # Skip cuda graph recording for fast test.
        "enforce_eager": True,

        # The original model is float32, keep it for numerical stability.
        "dtype": "float32",
    }])
@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}])
@pytest.mark.parametrize("baseline_llm_kwargs", [{}])
@pytest.mark.parametrize("test_llm_kwargs", [{
    "speculative_config": {
        "model": os.path.join(models_path_prefix, "JackFram/llama-68m"),
        "num_speculative_tokens": 3,
        "disable_logprobs": True,
    },
}])
@pytest.mark.parametrize("seed", [1])
@pytest.mark.parametrize("batch_size", [4])
@pytest.mark.parametrize(
    "output_len",
    [
        # Use smaller output len for fast test.
        32,
    ])
@pytest.mark.parametrize("logprobs", [0])
def test_logprobs_disabled(vllm_runner, common_llm_kwargs,
                           per_test_common_llm_kwargs, baseline_llm_kwargs,
                           test_llm_kwargs, batch_size: int, output_len: int,
                           seed: int, logprobs: int):
    """Check the behavior when logprobs are disabled.
    Token choices should match with the base model.

    The speculative config sets ``disable_logprobs=True`` and ``logprobs`` is
    0. Decoding is greedy (temperature 0.0), so the speculative and baseline
    runs are directly comparable.
    """
    run_equality_correctness_test(
        vllm_runner,
        common_llm_kwargs,
        per_test_common_llm_kwargs,
        baseline_llm_kwargs,
        test_llm_kwargs,
        batch_size,
        output_len,
        seed,
        temperature=0.0,
        logprobs=logprobs,
        disable_logprobs=test_llm_kwargs["speculative_config"]
        ["disable_logprobs"])