"vllm/model_executor/models/terratorch.py" did not exist on "f1579b229de7b21a9e4aec34be2fab29b2c84675"
test_logprobs.py 9.7 KB
Newer Older
1
2
# SPDX-License-Identifier: Apache-2.0

3
4
5
import os
from itertools import cycle, islice

import pytest

from vllm import SamplingParams

from ..utils import maybe_enable_chunked_prefill
from .conftest import run_equality_correctness_test
from ...utils import models_path_prefix
13

14
# Module-level side effect: executed once when this test module is imported.
# NOTE(review): presumably switches a backend-specific "LLAMA_NN" code path
# off for these tests -- confirm against the code that reads this variable.
os.environ["LLAMA_NN"] = "0"
15
16
17
18

@pytest.mark.parametrize(
    "common_llm_kwargs",
    [{
        "model_name": os.path.join(models_path_prefix, "JackFram/llama-160m"),

        # Skip cuda graph recording for fast test.
        "enforce_eager": True,
    }])
@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}])
@pytest.mark.parametrize("baseline_llm_kwargs", [{}])
@pytest.mark.parametrize("test_llm_kwargs", [{
    "speculative_config": {
        "model": os.path.join(models_path_prefix, "JackFram/llama-68m"),
        "num_speculative_tokens": 3,
        "disable_logprobs": False,
    },
}, {
    "speculative_config": {
        "model": os.path.join(models_path_prefix, "JackFram/llama-68m"),
        "num_speculative_tokens": 3,
        "disable_logprobs": True,
    },
}])
@pytest.mark.parametrize("batch_size", [8])
@pytest.mark.parametrize(
    "output_len",
    [
        # Use smaller output len for fast test.
        7,
    ])
@pytest.mark.parametrize("seed", [1])
@pytest.mark.parametrize("logprobs", [1, 6])
@pytest.mark.parametrize("prefill_chunk_size", [-1, 4, 8])
def test_logprobs_equality(vllm_runner, common_llm_kwargs,
                           per_test_common_llm_kwargs, baseline_llm_kwargs,
                           test_llm_kwargs, batch_size: int, output_len: int,
                           seed: int, logprobs: int, prefill_chunk_size: int):
    """Verify output logprobs are equal with and without speculative decoding,
    as well as with and without chunked prefill.

    The comparison itself is delegated to ``run_equality_correctness_test``,
    which is called with ``temperature=0.0`` and with both sample logprobs
    and prompt logprobs set to ``logprobs``.
    """
    # pytest passes the very same dict object to every parametrized
    # invocation. The helper's return value is unused, so it is expected to
    # configure the kwargs in place; operate on a shallow copy so that the
    # settings chosen for one ``prefill_chunk_size`` cannot leak into the
    # next test case.
    common_llm_kwargs = dict(common_llm_kwargs)
    maybe_enable_chunked_prefill(prefill_chunk_size, common_llm_kwargs)

    spec_config = test_llm_kwargs["speculative_config"]
    run_equality_correctness_test(
        vllm_runner,
        common_llm_kwargs,
        per_test_common_llm_kwargs,
        baseline_llm_kwargs,
        test_llm_kwargs,
        batch_size,
        output_len,
        seed,
        temperature=0.0,
        logprobs=logprobs,
        prompt_logprobs=logprobs,
        disable_logprobs=spec_config["disable_logprobs"])
71
72
73
74
75


@pytest.mark.parametrize(
    "common_llm_kwargs",
    [{
        "model_name": os.path.join(models_path_prefix, "JackFram/llama-68m"),

        # Skip cuda graph recording for fast test.
        "enforce_eager": True,
    }])
@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}])
@pytest.mark.parametrize("baseline_llm_kwargs", [{}])
@pytest.mark.parametrize("test_llm_kwargs", [{
    "speculative_config": {
        "model": os.path.join(models_path_prefix, "JackFram/llama-160m"),
        "num_speculative_tokens": 3,
        "disable_logprobs": False,
    },
}, {
    "speculative_config": {
        "model": os.path.join(models_path_prefix, "JackFram/llama-160m"),
        "num_speculative_tokens": 6,
        "disable_logprobs": False,
    },
}])
@pytest.mark.parametrize("batch_size", [8])
@pytest.mark.parametrize(
    "output_len",
    [
        # Use smaller output len for fast test.
        32,
    ])
@pytest.mark.parametrize("seed", [1])
@pytest.mark.parametrize("logprobs", [1, 6])
def test_logprobs_different_k(vllm_runner, common_llm_kwargs,
                              per_test_common_llm_kwargs, baseline_llm_kwargs,
                              test_llm_kwargs, batch_size: int,
                              output_len: int, seed: int, logprobs: int):
    """Verify logprob greedy equality with different speculation lens.

    Parametrized over ``num_speculative_tokens`` of 3 and 6; the comparison
    is delegated to ``run_equality_correctness_test`` with
    ``temperature=0.0``.
    """
    spec_config = test_llm_kwargs["speculative_config"]
    run_equality_correctness_test(
        vllm_runner,
        common_llm_kwargs,
        per_test_common_llm_kwargs,
        baseline_llm_kwargs,
        test_llm_kwargs,
        batch_size,
        output_len,
        seed,
        temperature=0.0,
        logprobs=logprobs,
        disable_logprobs=spec_config["disable_logprobs"])
124
125
126
127
128


@pytest.mark.parametrize(
    "common_llm_kwargs",
    [{
        "model_name": os.path.join(models_path_prefix, "JackFram/llama-68m"),

        # Skip cuda graph recording for fast test.
        "enforce_eager": True,
    }])
@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}])
@pytest.mark.parametrize("baseline_llm_kwargs", [{}])
@pytest.mark.parametrize(
    "test_llm_kwargs",
    [{
        "speculative_config": {
            "model": os.path.join(models_path_prefix, "JackFram/llama-160m"),
            "num_speculative_tokens": 3,
            "disable_logprobs": False,
            # Artificially limit the draft model max model len; this forces
            # vLLM to skip speculation once the sequences grow beyond 32-k
            # tokens.
            "max_model_len": 32,
        },
    }])
@pytest.mark.parametrize("batch_size", [8])
@pytest.mark.parametrize(
    "output_len",
    [
        # Use smaller output len for fast test.
        32,
    ])
@pytest.mark.parametrize("seed", [1])
@pytest.mark.parametrize("logprobs", [1])
def test_logprobs_when_skip_speculation(vllm_runner, common_llm_kwargs,
                                        per_test_common_llm_kwargs,
                                        baseline_llm_kwargs, test_llm_kwargs,
                                        batch_size: int, output_len: int,
                                        seed: int, logprobs: int):
    """Verify logprobs greedy equality when some sequences skip speculation.

    The speculative config caps the draft model's ``max_model_len`` at 32
    while ``output_len`` is also 32; the comparison is delegated to
    ``run_equality_correctness_test`` with ``temperature=0.0``.
    """
    spec_config = test_llm_kwargs["speculative_config"]
    run_equality_correctness_test(
        vllm_runner,
        common_llm_kwargs,
        per_test_common_llm_kwargs,
        baseline_llm_kwargs,
        test_llm_kwargs,
        batch_size,
        output_len,
        seed,
        temperature=0.0,
        logprobs=logprobs,
        disable_logprobs=spec_config["disable_logprobs"])
178
179
180
181
182


@pytest.mark.parametrize(
    "common_llm_kwargs",
    [{
        "model_name": os.path.join(models_path_prefix, "JackFram/llama-68m"),

        # Skip cuda graph recording for fast test.
        "enforce_eager": True,
    }])
@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}])
@pytest.mark.parametrize("baseline_llm_kwargs", [{}])
@pytest.mark.parametrize("test_llm_kwargs", [{
    "speculative_config": {
        "model": os.path.join(models_path_prefix, "JackFram/llama-160m"),
        "num_speculative_tokens": 3,
        "disable_logprobs": False,
    },
}])
@pytest.mark.parametrize("batch_size", [1])
@pytest.mark.parametrize(
    "output_len",
    [
        # Use smaller output len for fast test.
        32,
    ])
@pytest.mark.parametrize("seed", [1])
@pytest.mark.parametrize("logprobs", [6])
def test_logprobs_temp_1(vllm_runner, common_llm_kwargs,
                         per_test_common_llm_kwargs, baseline_llm_kwargs,
                         test_llm_kwargs, batch_size: int, output_len: int,
                         seed: int, logprobs: int):
    """Verify at least one logprob result has num_logprobs+1, which tests the
    case where the sampled token is not in top-k logprobs.

    Only the speculative-decoding configuration is run (``temperature=1.0``);
    ``baseline_llm_kwargs`` and ``seed`` are accepted for signature parity
    with the sibling tests but are not used here.

    Ideally, this test should validate equality with non-spec by getting
    logprobs. This is left as future improvement.
    """
    temperature = 1.0

    # The prompt strings are runtime input to the model; keep them verbatim.
    base_prompts = [
        "Hello, my name is",
        "The president of the United States is",
        "The capital of France is",
        "The future of AI is",
        "San Francisco is know for its",
        "Facebook was created in 2004 by",
        "Curious George is a",
        "Python 3.11 brings improvements to its",
    ]

    # Repeat the prompt list as often as needed to reach exactly batch_size.
    prompts = list(islice(cycle(base_prompts), batch_size))

    sampling_params = SamplingParams(
        max_tokens=output_len,
        ignore_eos=True,
        temperature=temperature,
        logprobs=logprobs,
    )

    sd_args = {
        **common_llm_kwargs,
        **per_test_common_llm_kwargs,
        **test_llm_kwargs,
    }

    with vllm_runner(**sd_args) as vllm_model:
        sd_outputs = vllm_model.generate_w_logprobs(prompts, sampling_params)

    # NOTE(review): this iterates the last element of sd_outputs and assumes
    # each item is a per-position collection of logprobs -- confirm against
    # the return shape of generate_w_logprobs.
    num_returned_logprobs = [
        len(seq_logprobs) for seq_logprobs in sd_outputs[-1]
    ]

    # Assert one of the returned logprobs has > num_logprobs (indicating the
    # sampled token is not in top-k).
    assert any(num_returned > logprobs
               for num_returned in num_returned_logprobs)
255
256
257
258
259


@pytest.mark.parametrize(
    "common_llm_kwargs",
    [{
        "model_name": os.path.join(models_path_prefix, "JackFram/llama-160m"),
        # Skip cuda graph recording for fast test.
        "enforce_eager": True,
    }])
@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}])
@pytest.mark.parametrize("baseline_llm_kwargs", [{}])
@pytest.mark.parametrize("test_llm_kwargs", [{
    "speculative_config": {
        "model": os.path.join(models_path_prefix, "JackFram/llama-68m"),
        "num_speculative_tokens": 3,
        "disable_logprobs": True,
    },
}])
@pytest.mark.parametrize("seed", [1])
@pytest.mark.parametrize("batch_size", [4])
@pytest.mark.parametrize(
    "output_len",
    [
        # Use smaller output len for fast test.
        32,
    ])
@pytest.mark.parametrize("logprobs", [0])
def test_logprobs_disabled(vllm_runner, common_llm_kwargs,
                           per_test_common_llm_kwargs, baseline_llm_kwargs,
                           test_llm_kwargs, batch_size: int, output_len: int,
                           seed: int, logprobs: int):
    """Check the behavior when logprobs are disabled.
    Token choices should match with the base model.

    Runs with ``disable_logprobs=True`` in the speculative config and
    ``logprobs=0``; the comparison is delegated to
    ``run_equality_correctness_test`` with ``temperature=0.0``.
    """
    spec_config = test_llm_kwargs["speculative_config"]
    run_equality_correctness_test(
        vllm_runner,
        common_llm_kwargs,
        per_test_common_llm_kwargs,
        baseline_llm_kwargs,
        test_llm_kwargs,
        batch_size,
        output_len,
        seed,
        temperature=0.0,
        logprobs=logprobs,
        disable_logprobs=spec_config["disable_logprobs"])