# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

import os

import pytest
import torch

from vllm import SamplingParams

from ..conftest import VllmRunner
from ..utils import models_path_prefix
12

# Model(s) exercised by the parametrized tests in this module. Each entry is
# joined onto ``models_path_prefix`` (presumably a local model root — confirm).
MODELS = [os.path.join(models_path_prefix, "distilbert/distilgpt2")]


@pytest.fixture(scope="function", autouse=True)
def use_v0_only(monkeypatch):
    """Force the V0 engine for every test in this module.

    Autouse and function-scoped: sets ``VLLM_USE_V1=0`` through pytest's
    ``monkeypatch`` fixture, which restores the previous value when each
    test finishes.

    NOTE(review): the earlier rationale said this module is V0-only because
    it "uses dtype=float", but the tests here run with ``dtype="half"`` (or
    the model default). Confirm which V0-only behavior still requires this.
    """
    monkeypatch.setenv('VLLM_USE_V1', '0')


@pytest.mark.parametrize("model", MODELS)
# HF and vLLM both run in this dtype so their logprobs are comparable.
@pytest.mark.parametrize("dtype", ["half"])
@pytest.mark.parametrize("chunked_prefill_token_size", [1, 4, 16, -1])
# Top-k width requested per position; the chosen token may add one entry.
@pytest.mark.parametrize("num_top_logprobs", [0, 6])
@pytest.mark.parametrize("detokenize", [True, False])
def test_get_prompt_logprobs(
    hf_runner,
    vllm_runner,
    model,
    dtype,
    chunked_prefill_token_size: int,
    num_top_logprobs: int,
    detokenize: bool,
    example_prompts,
):
    """Check vLLM sample and prompt logprobs against a HuggingFace reference.

    Runs greedy generation of ``max_tokens`` tokens with both backends and
    verifies that:

    * every output position and every prompt position (after the first)
      carries ``num_top_logprobs`` entries, or one more when the chosen
      token is outside the top-k;
    * the first entry at each output position reproduces the output text
      when detokenizing, or is ``None`` (with empty text) when not;
    * each vLLM logprob matches the HF value for the same token id within
      tolerance;
    * each prompt token id appears in its own prompt-logprob entry.

    ``chunked_prefill_token_size == -1`` disables chunked prefill; any other
    value enables it and uses that value as the batched-token budget (and
    caps ``max_num_seqs`` at it).
    """
    max_num_seqs = 256
    enable_chunked_prefill = False
    max_num_batched_tokens = None
    if chunked_prefill_token_size != -1:
        enable_chunked_prefill = True
        max_num_seqs = min(chunked_prefill_token_size, max_num_seqs)
        max_num_batched_tokens = chunked_prefill_token_size

    max_tokens = 5

    with hf_runner(model, dtype=dtype) as hf_model:
        hf_logprobs = hf_model.generate_greedy_logprobs(
            example_prompts,
            max_tokens=max_tokens,
        )

    with vllm_runner(
            model,
            dtype=dtype,
            max_logprobs=num_top_logprobs,
            enable_chunked_prefill=enable_chunked_prefill,
            max_num_batched_tokens=max_num_batched_tokens,
            max_num_seqs=max_num_seqs,
    ) as vllm_model:
        vllm_sampling_params = SamplingParams(max_tokens=max_tokens,
                                              logprobs=num_top_logprobs,
                                              prompt_logprobs=num_top_logprobs,
                                              temperature=0.0,
                                              detokenize=detokenize)
        vllm_results = vllm_model.llm.generate(
            example_prompts, sampling_params=vllm_sampling_params)

    # One result per prompt from each backend. Checked explicitly because
    # the zip() over both lists below would silently drop unmatched tail
    # entries.
    assert len(vllm_results) == len(example_prompts)
    assert len(hf_logprobs) == len(example_prompts)

    # Test whether logprobs are included in the results.
    for result in vllm_results:
        assert result.prompt_logprobs is not None
        assert result.outputs[0].logprobs is not None
        assert len(result.outputs[0].logprobs) == max_tokens
        for logprobs in result.outputs[0].logprobs:
            # If the output token is not included in the top X
            # logprob, it can return 1 more data
            assert (len(logprobs) == num_top_logprobs
                    or len(logprobs) == num_top_logprobs + 1)

        output_text = result.outputs[0].text
        # First entry of each position's logprob dict; the assertion below
        # expects it to be the token that was actually emitted. Entries are
        # None (not str) when detokenization is disabled.
        first_decoded_tokens = [
            next(iter(top_logprobs.values())).decoded_token
            for top_logprobs in result.outputs[0].logprobs
        ]

        if detokenize:
            output_string_from_most_likely_tokens = "".join(
                first_decoded_tokens)
            assert output_text == output_string_from_most_likely_tokens, (
                "The output text from the top logprob for each token position "
                "should be the same as the output text in the result.")
        else:
            assert output_text == ''
            assert first_decoded_tokens == [None] * max_tokens

        # The first prompt logprob is always None
        assert result.prompt_logprobs[0] is None
        for prompt_logprobs in result.prompt_logprobs[1:]:
            # If the prompt token is not included in the top X
            # logprob, it can return 1 more data
            assert (len(prompt_logprobs) == num_top_logprobs
                    or len(prompt_logprobs) == num_top_logprobs + 1)

    # Test whether prompt logprobs are consistent with HF.
    # NOTE(review): the indexing assumes hf_logprob[0][i] holds the logprob
    # row for prompt position i (from the first, prefill step) and
    # hf_logprob[i][-1] the last row at decode step i; inferred from this
    # usage, confirm against HfRunner.generate_greedy_logprobs.
    for vllm_result, hf_logprob in zip(vllm_results, hf_logprobs):
        # The first prompt logprob is always None, so we compare from 1:.
        vllm_prompt_logprobs = vllm_result.prompt_logprobs[1:]
        for i, vllm_prompt_logprob_dict in enumerate(vllm_prompt_logprobs):
            for token_id, prompt_logprob in vllm_prompt_logprob_dict.items():
                torch.testing.assert_close(prompt_logprob.logprob,
                                           hf_logprob[0][i][token_id].item(),
                                           atol=1e-2,
                                           rtol=1e-2)

        vllm_sample_logprobs = vllm_result.outputs[0].logprobs
        for i, top_logprobs in enumerate(vllm_sample_logprobs):
            for token_id, sample_logprob in top_logprobs.items():
                torch.testing.assert_close(sample_logprob.logprob,
                                           hf_logprob[i][-1][token_id].item(),
                                           atol=1e-1,
                                           rtol=1e-1)
                if detokenize:
                    assert isinstance(sample_logprob.decoded_token, str), (
                        "The token should be decoded by the time it is returned"
                        " to the user.")

    # Test if prompt logprobs are correctly set.
    for vllm_result in vllm_results:
        token_ids = vllm_result.prompt_token_ids
        prompt_logprobs = vllm_result.prompt_logprobs

        # One entry per prompt token, so the zip below does not truncate.
        assert len(prompt_logprobs) == len(token_ids)
        # The first token doesn't have logprob.
        assert prompt_logprobs[0] is None

        for token_id, logprob_dict in zip(token_ids[1:], prompt_logprobs[1:]):
            assert token_id in logprob_dict


def test_max_logprobs():
    """Reject requests above the engine's ``max_logprobs`` limit.

    The runner is built with ``max_logprobs=1``. Requesting one logprob must
    succeed, and requesting two must raise ``ValueError``.
    """
    # Use the runner as a context manager, as the other tests in this module
    # do with ``vllm_runner(...)``, so its exit/cleanup hook runs even when
    # an assertion fails. The previous bare construction never released it.
    # NOTE(review): assumes VllmRunner implements __enter__/__exit__ as the
    # fixture usage elsewhere implies — confirm in conftest.
    with VllmRunner(os.path.join(models_path_prefix, "facebook/opt-125m"),
                    max_logprobs=1) as runner:
        vllm_sampling_params = SamplingParams(logprobs=1)
        # should pass
        runner.generate(["Hello world"], sampling_params=vllm_sampling_params)

        bad_sampling_params = SamplingParams(logprobs=2)
        with pytest.raises(ValueError):
            runner.generate(["Hello world"],
                            sampling_params=bad_sampling_params)


@pytest.mark.parametrize("model", MODELS)
@pytest.mark.parametrize("chunked_prefill_token_size", [1, 4, 16, -1])
@pytest.mark.parametrize("detokenize", [True, False])
def test_none_logprobs(vllm_runner, model, chunked_prefill_token_size: int,
                       detokenize: bool, example_prompts):
    """With ``logprobs=None``, vLLM must report no sample logprobs.

    For every prompt, both ``outputs[0].logprobs`` and
    ``outputs[0].cumulative_logprob`` must be ``None``.
    ``chunked_prefill_token_size == -1`` disables chunked prefill; any other
    value enables it and uses that value as the batched-token budget (and
    caps ``max_num_seqs`` at it).
    """
    max_num_seqs = 256
    enable_chunked_prefill = False
    max_num_batched_tokens = None
    if chunked_prefill_token_size != -1:
        enable_chunked_prefill = True
        max_num_seqs = min(chunked_prefill_token_size, max_num_seqs)
        max_num_batched_tokens = chunked_prefill_token_size
    max_tokens = 5

    with vllm_runner(
            model,
            enable_chunked_prefill=enable_chunked_prefill,
            max_num_batched_tokens=max_num_batched_tokens,
            max_num_seqs=max_num_seqs,
    ) as vllm_model:
        sampling_params_logprobs_none = SamplingParams(max_tokens=max_tokens,
                                                       logprobs=None,
                                                       temperature=0.0,
                                                       detokenize=detokenize)
        results_logprobs_none = vllm_model.llm.generate(
            example_prompts, sampling_params=sampling_params_logprobs_none)

    # Guard against a vacuous pass: an empty result list would make the
    # loop below check nothing.
    assert len(results_logprobs_none) == len(example_prompts)
    for result in results_logprobs_none:
        assert result.outputs[0].logprobs is None
        assert result.outputs[0].cumulative_logprob is None