# utils.py
from typing import Dict, List, Tuple

TokensText = Tuple[List[int], str]


def check_outputs_equal(outputs_0_lst: List[TokensText], outputs_1_lst: List[TokensText],
                        name_0: str, name_1: str) -> None:
    """
    Compare the two sequences generated by different models,
    which should be equal.

    Args:
        outputs_0_lst: One ``(token_ids, text)`` pair per prompt from the
            first model.
        outputs_1_lst: One ``(token_ids, text)`` pair per prompt from the
            second model, in the same prompt order.
        name_0: Label for the first model, used in failure messages.
        name_1: Label for the second model, used in failure messages.

    Raises:
        AssertionError: If the two lists differ in length, or if the text or
            the token ids differ for any prompt.
    """
    assert len(outputs_0_lst) == len(outputs_1_lst), (
        f"Number of outputs differs: "
        f"{name_0} has {len(outputs_0_lst)}, {name_1} has {len(outputs_1_lst)}")

    for prompt_idx, (outputs_0, outputs_1) in enumerate(zip(outputs_0_lst, outputs_1_lst)):
        output_ids_0, output_str_0 = outputs_0
        output_ids_1, output_str_1 = outputs_1

        assert output_str_0 == output_str_1, (f"Test{prompt_idx}:"
                                              f"\n{name_0}:\t{output_str_0!r}"
                                              f"\n{name_1}:\t{output_str_1!r}")
        # This check is only reached when the decoded strings already match,
        # so the message must show the token ids: they are the only thing
        # that can still differ here.
        assert output_ids_0 == output_ids_1, (f"Test{prompt_idx}:"
                                              f"\n{name_0}:\t{output_str_0!r}\t{output_ids_0!r}"
                                              f"\n{name_1}:\t{output_str_1!r}\t{output_ids_1!r}")


TokensTextLogprobs = Tuple[List[int], str, List[Dict[int, float]]]


def _describe_divergence(prompt_idx, name_0, text_0, name_1, text_1):
    """Format the failure report listing both models' generated text."""
    return (f"Test{prompt_idx}:"
            f"\n{name_0}:\t{text_0!r}"
            f"\n{name_1}:\t{text_1!r}")


def check_logprobs_close(outputs_0_lst: List[TokensTextLogprobs],
                         outputs_1_lst: List[TokensTextLogprobs], name_0: str, name_1: str):
    """
    Check that two models' generations agree up to a tolerable divergence.

    The token sequences for each prompt are walked in lockstep. At the first
    position where the tokens differ, each model's token must be a key of the
    other model's logprobs entry for that position; later positions are not
    examined because the sequences are expected to diverge from there on.

    Raises AssertionError if the two lists differ in length or if a diverging
    token is missing from the other model's logprobs at that position.
    """
    assert len(outputs_0_lst) == len(outputs_1_lst)

    for prompt_idx, (result_0, result_1) in enumerate(zip(outputs_0_lst, outputs_1_lst)):
        token_ids_0, text_0, top_logprobs_0 = result_0
        token_ids_1, text_1, top_logprobs_1 = result_1

        for position, (token_0, token_1) in enumerate(zip(token_ids_0, token_ids_1)):
            # Matching tokens need no further checking.
            if token_0 == token_1:
                continue

            # First divergence: each model's pick must still be a candidate
            # the other model considered at this position. The message is
            # only built when an assertion actually fails.
            assert token_0 in top_logprobs_1[position], _describe_divergence(
                prompt_idx, name_0, text_0, name_1, text_1)
            assert token_1 in top_logprobs_0[position], _describe_divergence(
                prompt_idx, name_0, text_0, name_1, text_1)

            # Everything after the first divergence is incomparable.
            break