test_logprobs.py 3.72 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

import pytest

from vllm import SamplingParams
from vllm.logprobs import FlattenLogprobs

# Checkpoint(s) the test below is parametrized over.
MODELS = ["distilbert/distilgpt2"]

# Generation length per request (passed as ``SamplingParams.max_tokens``).
MAX_TOKENS = 5

# Top-k logprobs requested for sampled tokens and for prompt tokens,
# respectively (``SamplingParams.logprobs`` / ``prompt_logprobs``).
NUM_TOP_LOGPROBS, NUM_PROMPT_LOGPROBS = 5, 7

# Engine-level ``max_logprobs`` setting, sized to cover the larger request.
MAX_LOGPROBS = max(NUM_PROMPT_LOGPROBS, NUM_TOP_LOGPROBS)


def _check_position_logprobs(
    token, logprobs, num_logprobs, where, *, require_top_rank=False
):
    """Assert that the logprobs returned for a single position are well formed.

    Args:
        token: Token id actually present at this position (prompt or sampled).
        logprobs: Mapping of token id -> logprob entry exposing ``.rank``.
        num_logprobs: Number of top logprobs that were requested.
        where: Human-readable location, included in assertion messages.
        require_top_rank: If True, ``token`` must be the rank-1 entry
            (used for greedy sampling); otherwise any rank >= 1 is accepted.
    """
    # The logprob of the actual token must always be returned.
    logprob = logprobs.get(token)
    assert logprob is not None, f"{where}: no logprob returned for token {token}"
    if require_top_rank:
        assert logprob.rank == 1, f"{where}: expected rank 1, got {logprob.rank}"
    else:
        assert logprob.rank >= 1, f"{where}: invalid rank {logprob.rank}"
    # Either exactly the requested top-k entries or one extra (presumably the
    # actual token when it ranks outside the top-k).
    assert num_logprobs <= len(logprobs) <= num_logprobs + 1, (
        f"{where}: got {len(logprobs)} logprobs, expected "
        f"{num_logprobs} or {num_logprobs + 1}"
    )
    # Ranks 1..k must all be present, i.e. the top-k were really extracted.
    ranks = {entry.rank for entry in logprobs.values()}
    assert set(range(1, num_logprobs + 1)).issubset(ranks), (
        f"{where}: missing some of the top-{num_logprobs} ranks, got {ranks}"
    )


@pytest.mark.parametrize("model", MODELS)
@pytest.mark.parametrize("dtype", ["half"])
@pytest.mark.parametrize("greedy", [True, False])
@pytest.mark.parametrize("flatten_logprobs", [True, False])
def test_ranks(
    vllm_runner,
    model,
    dtype,
    greedy,
    flatten_logprobs,
    example_prompts,
    monkeypatch: pytest.MonkeyPatch,
):
    """Check container types, ranks and top-k contents of returned logprobs.

    Generation runs with ``VLLM_FLATTEN_LOGPROBS`` set according to
    ``flatten_logprobs``. The test then verifies four things:

    - Prompt and sample logprobs are ``FlattenLogprobs`` or ``list``,
      matching that setting.
    - Every position (except the first prompt token) carries the logprob
      of its actual token.
    - The requested top-k ranks are all present at each position.
    - Under greedy sampling (temperature 0) every sampled token has rank 1.
    """
    monkeypatch.setenv("VLLM_FLATTEN_LOGPROBS", "1" if flatten_logprobs else "0")
    with vllm_runner(model, dtype=dtype, max_logprobs=MAX_LOGPROBS) as vllm_model:
        tokenizer = vllm_model.llm.get_tokenizer()
        example_prompt_tokens = [tokenizer.encode(prompt) for prompt in example_prompts]
        sampling_params = SamplingParams(
            temperature=0.0 if greedy else 1.0,
            top_p=1.0,
            max_tokens=MAX_TOKENS,
            logprobs=NUM_TOP_LOGPROBS,
            prompt_logprobs=NUM_PROMPT_LOGPROBS,
        )
        results = vllm_model.generate_w_logprobs(example_prompts, sampling_params)

    expected_type = FlattenLogprobs if flatten_logprobs else list
    assert len(results) == len(example_prompt_tokens)
    for i, (result, prompt_tokens) in enumerate(zip(results, example_prompt_tokens)):
        decode_tokens, _, decode_logprobs, prompt_logprobs = result

        # Ensure the return type of logprobs is accurate
        assert isinstance(prompt_logprobs, expected_type), (
            f"prompt {i}: prompt logprobs type {type(prompt_logprobs)}"
        )
        assert isinstance(decode_logprobs, expected_type), (
            f"prompt {i}: sample logprobs type {type(decode_logprobs)}"
        )

        ########################
        # Check prompt logprobs
        ########################
        assert len(prompt_tokens) == len(prompt_logprobs), f"prompt {i}"
        # No logprob for first prompt token
        assert not prompt_logprobs[0], f"prompt {i}: first position not empty"
        for position, (token, logprobs) in enumerate(
            zip(prompt_tokens[1:], prompt_logprobs[1:]), start=1
        ):
            _check_position_logprobs(
                token,
                logprobs,
                NUM_PROMPT_LOGPROBS,
                f"prompt {i}, prompt position {position}",
            )

        ########################
        # Check sample logprobs
        ########################
        assert len(decode_tokens) == len(decode_logprobs), f"prompt {i}"
        for position, (token, logprobs) in enumerate(
            zip(decode_tokens, decode_logprobs)
        ):
            # For greedy sampling, every chosen token should be top ranked.
            _check_position_logprobs(
                token,
                logprobs,
                NUM_TOP_LOGPROBS,
                f"prompt {i}, sample position {position}",
                require_top_rank=greedy,
            )