utils.py 5.46 KB
Newer Older
1
import warnings
2
from enum import Enum
3
4
5
from typing import Dict, List, Optional, Sequence, Tuple, Union

from vllm.sequence import SampleLogprobs
6
7
8
9

# Per prompt: (generated token ids, decoded text).
TokensText = Tuple[List[int], str]


def check_outputs_equal(
    *,
    outputs_0_lst: Sequence[TokensText],
    outputs_1_lst: Sequence[TokensText],
    name_0: str,
    name_1: str,
):
    """
    Assert that two models produced exactly the same outputs.

    For every prompt, the decoded text is compared first and then the
    generated token ids; both must be identical.

    Arguments:

    * outputs_0_lst: per-prompt (token ids, text) pairs from model #0
    * outputs_1_lst: per-prompt (token ids, text) pairs from model #1
    * name_0: label used for model #0 in failure messages
    * name_1: label used for model #1 in failure messages
    """
    assert len(outputs_0_lst) == len(outputs_1_lst)

    paired = zip(outputs_0_lst, outputs_1_lst)
    for prompt_idx, ((ids_a, text_a), (ids_b, text_b)) in enumerate(paired):
        # Same message for both checks: it shows the prompt index and the
        # two decoded texts side by side.
        message = "".join((
            f"Test{prompt_idx}:",
            f"\n{name_0}:\t{text_a!r}",
            f"\n{name_1}:\t{text_b!r}",
        ))
        assert text_a == text_b, message
        assert ids_a == ids_b, message
36
37


38
39
40
# Per prompt: (generated token ids, decoded text, per-position logprobs).
# The logprobs element may be None; it is then treated as "no logprobs".
TokensTextLogprobs = Tuple[List[int], str,
                           Optional[Union[List[Dict[int, float]],
                                          SampleLogprobs]]]


def check_logprobs_close(
    *,
    outputs_0_lst: Sequence[TokensTextLogprobs],
    outputs_1_lst: Sequence[TokensTextLogprobs],
    name_0: str,
    name_1: str,
    num_outputs_0_skip_tokens: int = 0,
    warn_on_mismatch: bool = True,
):
    """
    Compare the logprobs of two sequences generated by different models,
    which should be similar but not necessarily equal.

    For each prompt, token ids are compared position by position over the
    positions both sequences have (``zip`` stops at the shorter one). At
    the first position where the tokens differ, each sequence's token must
    be present in the other sequence's logprobs at that position; the
    comparison for that prompt then stops because the sequences diverge.
    If no compared token differs but the decoded texts do, only a warning
    is issued.

    Arguments:

    * outputs_0_lst: First sequence to compare
    * outputs_1_lst: Second sequence to compare
    * name_0: sequence #0 name
    * name_1: sequence #1 name
    * num_outputs_0_skip_tokens: If > 0, specifies the number of initial
                                 sequence #0 tokens & logprobs to discard
                                 before comparison, i.e. all
                                 of sequence #1 will be compared to
                                 sequence #0 beginning at index
                                 num_outputs_0_skip_tokens
    * warn_on_mismatch: Issue a warning if there is token-wise or text-wise
                        mismatch between the two sequences

    Raises:

    * AssertionError: if the two lists differ in length, or if at the first
      mismatched position either logprobs entry is None or does not contain
      the other sequence's token
    * ValueError: if num_outputs_0_skip_tokens is negative
    """
    assert len(outputs_0_lst) == len(outputs_1_lst)

    # Loop-invariant argument check, done once up front. This also rejects a
    # negative value when there are no outputs to iterate over.
    if num_outputs_0_skip_tokens < 0:
        raise ValueError("num_outputs_0_skip_tokens must be non-negative")

    # Loop through responses to each prompt.
    for prompt_idx, (outputs_0,
                     outputs_1) in enumerate(zip(outputs_0_lst,
                                                 outputs_1_lst)):
        output_ids_0, output_str_0, logprobs_0 = outputs_0
        output_ids_1, output_str_1, logprobs_1 = outputs_1

        # Pad missing logprobs with one None per token so they can be indexed
        # like real ones; a None only fails the test if it is needed to judge
        # a token mismatch (asserted below).
        if logprobs_0 is None:
            logprobs_0 = [None] * len(output_ids_0)
        if logprobs_1 is None:
            logprobs_1 = [None] * len(output_ids_1)

        # Skip specified number of initial sequence #0 tokens
        # & logprobs, leaving output text as-is for simplicity
        # (text mismatches may generate warnings but do not
        # cause the test to fail.)
        output_ids_0 = output_ids_0[num_outputs_0_skip_tokens:]
        logprobs_0 = logprobs_0[num_outputs_0_skip_tokens:]

        # Loop through generated tokens; only overlapping positions are
        # compared because zip() stops at the shorter sequence.
        for idx, (output_id_0,
                  output_id_1) in enumerate(zip(output_ids_0, output_ids_1)):

            # First position where the two models picked different tokens.
            if output_id_0 != output_id_1:
                logprobs_elem_0 = logprobs_0[idx]
                logprobs_elem_1 = logprobs_1[idx]

                # Each predicted token must be in top N logprobs of the other
                fail_msg = (
                    f"Test{prompt_idx}:"
                    f"\nMatched tokens:\t{output_ids_0[:idx]}"
                    f"\n{name_0}:\t{output_str_0!r}\t{logprobs_elem_0}"
                    f"\n{name_1}:\t{output_str_1!r}\t{logprobs_elem_1}")

                assert logprobs_elem_0 is not None, fail_msg
                assert logprobs_elem_1 is not None, fail_msg
                assert output_id_0 in logprobs_elem_1, fail_msg
                assert output_id_1 in logprobs_elem_0, fail_msg

                if warn_on_mismatch:
                    with warnings.catch_warnings():
                        # This ensures that repeated warnings are shown
                        # in the output, not just the first occurrence
                        warnings.simplefilter("always")

                        # stacklevel=2 attributes the warning to the caller
                        # of this function.
                        warnings.warn(fail_msg, stacklevel=2)

                # Break out since sequences will now diverge.
                break
        else:
            # for-else: reached only when the token loop did not break,
            # i.e. every compared token matched.
            if output_str_0 != output_str_1 and warn_on_mismatch:
                # The token outputs exactly match,
                # so the text outputs should exactly match as well
                fail_msg = (f"Test{prompt_idx}:"
                            f"\n{name_0}:\t{output_str_0!r}"
                            f"\n{name_1}:\t{output_str_1!r}")

                with warnings.catch_warnings():
                    # This ensures that repeated warnings are shown
                    # in the output, not just the first occurrence
                    warnings.simplefilter("always")

                    warnings.warn(fail_msg, stacklevel=2)
139
140
141
142
143
144
145
146
147
148


class DecoderPromptType(Enum):
    """
    For encoder/decoder models only - presumably selects which kind of
    decoder prompt a test pairs with the encoder prompt.

    NOTE(review): the per-member meanings below are inferred from the member
    names only; confirm against the tests that use this enum.
    """
    # presumably: a caller-supplied decoder prompt
    CUSTOM = 1
    # presumably: no decoder prompt at all (None)
    NONE = 2
    # presumably: the empty string as the decoder prompt
    EMPTY_STR = 3