utils.py 5.92 KB
Newer Older
1
2
3
import warnings
from typing import Dict, List, Optional, Sequence, Tuple, Union

4
from vllm.sequence import Logprob, SampleLogprobs
5
6
7
8

# Output for a single prompt: (generated token IDs, decoded output text).
TokensText = Tuple[
    List[int],  # generated token IDs
    str,  # decoded output text
]


9
10
11
12
13
14
15
def check_outputs_equal(
    *,
    outputs_0_lst: Sequence[TokensText],
    outputs_1_lst: Sequence[TokensText],
    name_0: str,
    name_1: str,
) -> None:
    """
    Compare the two sequences generated by different models,
    which should be equal.

    Args:
      outputs_0_lst: (token_ids, text) outputs of model #0, one per prompt
      outputs_1_lst: (token_ids, text) outputs of model #1, one per prompt
      name_0: model #0 name, used in failure messages
      name_1: model #1 name, used in failure messages

    Raises:
      AssertionError: if the two lists differ in length, or if the text or
                      token IDs differ for any prompt
    """
    # Report both counts so a length mismatch is diagnosable from the
    # failure message alone.
    assert len(outputs_0_lst) == len(outputs_1_lst), (
        f"Number of outputs differs: {name_0} has {len(outputs_0_lst)}, "
        f"{name_1} has {len(outputs_1_lst)}")

    for prompt_idx, (outputs_0,
                     outputs_1) in enumerate(zip(outputs_0_lst,
                                                 outputs_1_lst)):
        output_ids_0, output_str_0 = outputs_0
        output_ids_1, output_str_1 = outputs_1

        # The text and token outputs should exactly match
        fail_msg = (f"Test{prompt_idx}:"
                    f"\n{name_0}:\t{output_str_0!r}"
                    f"\n{name_1}:\t{output_str_1!r}")

        assert output_str_0 == output_str_1, fail_msg
        assert output_ids_0 == output_ids_1, fail_msg
35
36


37
38
39
# Output for a single prompt: (token IDs, output text, logprobs). The
# logprobs entry is indexed by token position and is either a list of
# {token_id: logprob} dicts, a SampleLogprobs object, or None.
TokensTextLogprobs = Tuple[
    List[int],
    str,
    Optional[Union[List[Dict[int, float]], SampleLogprobs]],
]

# Allow for tokens to be represented as str's rather than IDs
TextTextLogprobs = Tuple[
    List[str],
    str,
    Optional[Union[List[Dict[str, float]], List[Dict[str, Logprob]]]],
]

46

47
48
def _warn_always(message: str, stacklevel: int) -> None:
    """Issue ``message`` as a warning that is shown on every occurrence.

    Args:
      message: warning text
      stacklevel: the stacklevel the *calling* function would pass to
                  ``warnings.warn``; this helper adds one for its own frame
    """
    with warnings.catch_warnings():
        # This ensures that repeated warnings are shown
        # in the output, not just the first occurrence
        warnings.simplefilter("always")

        warnings.warn(message, stacklevel=stacklevel + 1)


def check_logprobs_close(
    *,
    outputs_0_lst: Sequence[Union[TokensTextLogprobs, TextTextLogprobs]],
    outputs_1_lst: Sequence[Union[TokensTextLogprobs, TextTextLogprobs]],
    name_0: str,
    name_1: str,
    num_outputs_0_skip_tokens: int = 0,
    warn_on_mismatch: bool = True,
    always_check_logprobs: bool = False,
) -> None:
    """Compare the logprobs of two sequences generated by different models,
    which should be similar but not necessarily equal.

    For each prompt, tokens are compared position by position. At the first
    position where the tokens differ, each sequence's token must appear in
    the other sequence's logprobs for that position; comparison of that
    prompt then stops, since the sequences diverge from there. If all
    compared tokens match, a text mismatch only produces a warning.

    Args:
      outputs_0_lst: First sequence to compare
      outputs_1_lst: Second sequence to compare
      name_0: sequence #0 name
      name_1: sequence #1 name
      num_outputs_0_skip_tokens: If > 0, specifies the number of initial
                                 sequence #0 tokens & logprobs to discard
                                 before comparison, i.e. all
                                 of sequence #1 will be compared to
                                 sequence #0 beginning at index
                                 num_outputs_0_skip_tokens
      warn_on_mismatch: Issue a warning if there is token-wise or text-wise
                        mismatch between the two sequences
      always_check_logprobs: If true, also check logprobs at positions where
                             the tokens match (every position up to and
                             including the first mismatch)

    Raises:
      AssertionError: if the two lists differ in length, or a logprobs
                      check fails
      ValueError: if num_outputs_0_skip_tokens is negative
    """
    assert len(outputs_0_lst) == len(outputs_1_lst), (
        f"Number of outputs differs: {name_0} has {len(outputs_0_lst)}, "
        f"{name_1} has {len(outputs_1_lst)}")

    # Argument validation is loop-invariant; doing it up front also rejects
    # bad values when there are no outputs to iterate over.
    if num_outputs_0_skip_tokens < 0:
        raise ValueError("num_outputs_0_skip_tokens must be non-negative")

    # Loop through responses to each prompt.
    for prompt_idx, (outputs_0,
                     outputs_1) in enumerate(zip(outputs_0_lst,
                                                 outputs_1_lst)):
        output_ids_0, output_str_0, logprobs_0 = outputs_0
        output_ids_1, output_str_1, logprobs_1 = outputs_1

        # Missing logprobs become per-token None placeholders so that
        # indexing below works; the None asserts then report the failure.
        if logprobs_0 is None:
            logprobs_0 = [None] * len(output_ids_0)
        if logprobs_1 is None:
            logprobs_1 = [None] * len(output_ids_1)

        # Skip specified number of initial sequence #0 tokens
        # & logprobs, leaving output text as-is for simplicity
        # (text mismatches may generate warnings but do not
        # cause the test to fail.)
        output_ids_0 = output_ids_0[num_outputs_0_skip_tokens:]
        logprobs_0 = logprobs_0[num_outputs_0_skip_tokens:]

        # Loop through generated tokens.
        for idx, (output_id_0,
                  output_id_1) in enumerate(zip(output_ids_0, output_ids_1)):

            is_tok_mismatch = output_id_0 != output_id_1

            # If generated tokens don't match
            # or it is desired to always check logprobs,
            # then
            if is_tok_mismatch or always_check_logprobs:
                logprobs_elem_0 = logprobs_0[idx]
                logprobs_elem_1 = logprobs_1[idx]

                # Each predicted token must be in top N logprobs of the other
                fail_msg = (
                    f"Test{prompt_idx}:"
                    f"\nMatched tokens:\t{output_ids_0[:idx]}"
                    f"\n{name_0}:\t{output_str_0!r}\t{logprobs_elem_0}"
                    f"\n{name_1}:\t{output_str_1!r}\t{logprobs_elem_1}")

                assert logprobs_elem_0 is not None, fail_msg
                assert logprobs_elem_1 is not None, fail_msg
                assert output_id_0 in logprobs_elem_1, fail_msg
                assert output_id_1 in logprobs_elem_0, fail_msg

                if is_tok_mismatch:
                    if warn_on_mismatch:
                        _warn_always(fail_msg, stacklevel=2)

                    # Break out since sequences will now diverge. Matching
                    # tokens (checked only due to always_check_logprobs)
                    # must not stop the comparison.
                    break
        else:
            # Reached only when no token mismatch caused a break above.
            if output_str_0 != output_str_1 and warn_on_mismatch:
                # The token outputs exactly match,
                # so the text outputs should exactly match as well
                fail_msg = (f"Test{prompt_idx}:"
                            f"\n{name_0}:\t{output_str_0!r}"
                            f"\n{name_1}:\t{output_str_1!r}")

                _warn_always(fail_msg, stacklevel=2)