# test_logprobs.py
import pytest
import torch

from tests.conftest import VllmRunner
from vllm import SamplingParams

# Model checkpoints fed to ``@pytest.mark.parametrize("model", ...)`` below.
MODELS = [
    "facebook/opt-125m",
]


@pytest.mark.parametrize("model", MODELS)
@pytest.mark.parametrize("dtype", ["half"])
@pytest.mark.parametrize("chunked_prefill_token_size", [1, 4, 16, -1])
@pytest.mark.parametrize("num_top_logprobs", [6])
def test_get_prompt_logprobs(
    hf_runner,
    vllm_runner,
    model,
    dtype,
    chunked_prefill_token_size: int,
    num_top_logprobs: int,
    example_prompts,
):
    """Check vLLM prompt/sample logprobs against a HuggingFace reference.

    The same prompts are run greedily through ``hf_runner`` and
    ``vllm_runner`` for ``max_tokens`` steps. vLLM is asked for
    ``num_top_logprobs`` logprobs for both the prompt and the generated
    tokens. The test asserts that:

    * every result carries prompt and sample logprobs of the expected size;
    * joining the decoded token of the first entry of each sample-logprob
      dict reproduces the returned output text;
    * each prompt token after the first appears in its own logprob dict;
    * every vLLM logprob matches the HF value within ``atol=rtol=1e-2``.

    ``chunked_prefill_token_size == -1`` leaves chunked prefill disabled.
    Any other value enables it, sets ``max_num_batched_tokens`` to that
    value and caps ``max_num_seqs`` at it.
    """
    max_num_seqs = 256
    enable_chunked_prefill = False
    max_num_batched_tokens = None
    if chunked_prefill_token_size != -1:
        enable_chunked_prefill = True
        max_num_seqs = min(chunked_prefill_token_size, max_num_seqs)
        max_num_batched_tokens = chunked_prefill_token_size

    max_tokens = 5
    hf_model = hf_runner(model, dtype=dtype)
    hf_logprobs = hf_model.generate_greedy_logprobs(
        example_prompts,
        max_tokens=max_tokens,
    )
    del hf_model

    vllm_model = vllm_runner(
        model,
        dtype=dtype,
        max_logprobs=num_top_logprobs,
        enable_chunked_prefill=enable_chunked_prefill,
        max_num_batched_tokens=max_num_batched_tokens,
        max_num_seqs=max_num_seqs,
    )
    vllm_sampling_params = SamplingParams(max_tokens=max_tokens,
                                          logprobs=num_top_logprobs,
                                          prompt_logprobs=num_top_logprobs,
                                          temperature=0.0)
    vllm_results = vllm_model.model.generate(
        example_prompts, sampling_params=vllm_sampling_params)
    # Only the returned results are needed from here on; drop the engine
    # reference just as ``hf_model`` is dropped above.
    del vllm_model

    # zip() below silently stops at the shorter input, so pin the counts.
    assert len(vllm_results) == len(hf_logprobs)

    # Test whether logprobs are included in the results and well-formed.
    for result in vllm_results:
        prompt_logprobs = result.prompt_logprobs
        sample_logprobs = result.outputs[0].logprobs
        assert prompt_logprobs is not None
        assert sample_logprobs is not None

        # One dict per generated token, each with exactly the requested
        # number of candidates.
        assert len(sample_logprobs) == max_tokens
        for logprobs in sample_logprobs:
            assert len(logprobs) == num_top_logprobs

        # The first entry of each dict is treated as the most likely token;
        # its decoded pieces, concatenated, must equal the output text.
        output_text = result.outputs[0].text
        output_string_from_most_likely_tokens = "".join(
            next(iter(top_logprobs.values())).decoded_token
            for top_logprobs in sample_logprobs)
        assert output_text == output_string_from_most_likely_tokens, (
            "The output text from the top logprob for each token position "
            "should be the same as the output text in the result.")

        # The first prompt token has nothing before it, so it has no logprob.
        assert prompt_logprobs[0] is None
        # One entry per prompt token, so the zip below covers every position.
        assert len(prompt_logprobs) == len(result.prompt_token_ids)
        for token_id, logprob_dict in zip(result.prompt_token_ids[1:],
                                          prompt_logprobs[1:]):
            # The actual prompt token is always reported; when it is not
            # among the top ``num_top_logprobs`` it adds one extra entry.
            assert len(logprob_dict) in (num_top_logprobs,
                                         num_top_logprobs + 1)
            assert token_id in logprob_dict

    # Test whether prompt and sample logprobs are consistent with HF.
    for vllm_result, hf_logprob in zip(vllm_results, hf_logprobs):
        # NOTE(review): assumes hf_logprob[0] holds the prefill step, with
        # row i scoring prompt token i + 1 -- which is why the vLLM side is
        # compared from index 1. Confirm against HfRunner.
        hf_prompt_logprobs = hf_logprob[0]
        for i, vllm_prompt_logprob_dict in enumerate(
                vllm_result.prompt_logprobs[1:]):
            for token_id, logprob in vllm_prompt_logprob_dict.items():
                torch.testing.assert_close(
                    logprob.logprob,
                    hf_prompt_logprobs[i][token_id].item(),
                    atol=1e-2,
                    rtol=1e-2,
                )

        # Generation step i is compared against the last row of HF step i.
        for i, top_logprobs in enumerate(vllm_result.outputs[0].logprobs):
            for token_id, sample_logprob in top_logprobs.items():
                torch.testing.assert_close(
                    sample_logprob.logprob,
                    hf_logprob[i][-1][token_id].item(),
                    atol=1e-2,
                    rtol=1e-2,
                )
                assert isinstance(sample_logprob.decoded_token, str), (
                    "The token should be decoded by the time it is returned "
                    "to the user.")


def test_max_logprobs():
    """Requests for more logprobs than ``max_logprobs`` are rejected.

    The engine is built with ``max_logprobs=1``. Asking for one logprob
    must succeed, while asking for two must raise ``ValueError``.
    """
    engine = VllmRunner("facebook/opt-125m", max_logprobs=1)

    # Within the configured limit: generation is expected to succeed.
    within_limit = SamplingParams(logprobs=1)
    engine.generate(["Hello world"], sampling_params=within_limit)

    # Over the configured limit: the request must be refused.
    over_limit = SamplingParams(logprobs=2)
    with pytest.raises(ValueError):
        engine.generate(["Hello world"], sampling_params=over_limit)