"""
Copyright 2023-2024 SGLang Team
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""

import multiprocessing as mp
import unittest

import torch

from sglang.test.runners import DEFAULT_PROMPTS, HFRunner, SRTRunner

MODELS = [
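    # (model_path, tp_size, long_context_tolerance, prefill_tolerance, output_tolerance, rouge_threshold)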
    ("meta-llama/Meta-Llama-3.1-8B-Instruct", 1, 1.1, 3e-2, 4e-2, 1),
    ("google/gemma-2-2b", 1, 3, 3e-2, 5e-2, 1),
    ("Alibaba-NLP/gte-Qwen2-1.5B-instruct", 1, None, 6e-2, 4e-2, 1),
]
TORCH_DTYPES = [torch.float16]


def lcs(X, Y):
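    """Return the length of the longest common subsequence of X and Y (dynamic programming)."""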
    m = len(X)
    n = len(Y)
    L = [[0] * (n + 1) for _ in range(m + 1)]

    for i in range(m + 1):
        for j in range(n + 1):
            if i == 0 or j == 0:
                L[i][j] = 0
            elif X[i - 1] == Y[j - 1]:
                L[i][j] = L[i - 1][j - 1] + 1
            else:
                L[i][j] = max(L[i - 1][j], L[i][j - 1])

    return L[m][n]


def calculate_rouge_l(output_strs_list1, output_strs_list2):
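    """Compute a ROUGE-L style F-measure for each pair of strings, using character-level LCS length."""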
    rouge_l_scores = []

    for s1, s2 in zip(output_strs_list1, output_strs_list2):
        lcs_len = lcs(s1, s2)
        precision = lcs_len / len(s1) if len(s1) > 0 else 0
        recall = lcs_len / len(s2) if len(s2) > 0 else 0
        if precision + recall > 0:
            fmeasure = (2 * precision * recall) / (precision + recall)
        else:
            fmeasure = 0.0
        rouge_l_scores.append(fmeasure)

    return rouge_l_scores


class TestGenerationModels(unittest.TestCase):
    def assert_close_prefill_logits_and_output_strs(
        self,
        prompts,
        model_path,
        tp_size,
        torch_dtype,
        max_new_tokens,
        prefill_tolerance,
        output_tolerance,
        rouge_threshold,
        long_context_tolerance,
    ) -> None:
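        """Run the same prompts through HFRunner and SRTRunner, then compare prefill/output logprobs and decoded strings."""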
        if model_path == "Alibaba-NLP/gte-Qwen2-1.5B-instruct":
            prompts = prompts[:-1]

        with HFRunner(
            model_path, torch_dtype=torch_dtype, is_generation=True
        ) as hf_runner:
            hf_outputs = hf_runner.forward(prompts, max_new_tokens=max_new_tokens)

        with SRTRunner(
            model_path,
            tp_size=tp_size,
            torch_dtype=torch_dtype,
            is_generation=True,
        ) as srt_runner:
            srt_outputs = srt_runner.forward(prompts, max_new_tokens=max_new_tokens)

        for i in range(len(prompts)):
            # input logprobs comparison
            hf_logprobs = torch.Tensor(hf_outputs.top_input_logprobs[i])
            srt_logprobs = torch.Tensor(srt_outputs.top_input_logprobs[i])
            input_len = hf_logprobs.shape[0]
            print(
                "prefill logprobs max_diff", torch.max(abs(hf_logprobs - srt_logprobs))
            )
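            # only assert element-wise closeness for short prompts; longer inputs are just logged above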
            if input_len <= 100:
                assert torch.all(
                    abs(hf_logprobs - srt_logprobs) < prefill_tolerance
                ), f"prefill logprobs are not all close with model_path={model_path} prompts={prompts} prefill_tolerance={prefill_tolerance}"

            # output logprobs comparison
            hf_logprobs = torch.Tensor(hf_outputs.top_output_logprobs[i])
            srt_logprobs = torch.Tensor(srt_outputs.top_output_logprobs[i])
            # print(
            #     "output logprobs diff",
            #     [
            #         float(torch.max(abs(hf_logprobs[j] - srt_logprobs[j])))
            #         for j in range(max_new_tokens)
            #     ],
            # )
            print(
                "output logprobs max_diff", torch.max(abs(hf_logprobs - srt_logprobs))
            )
            if input_len <= 100:
                assert torch.all(
                    abs(hf_logprobs - srt_logprobs) < output_tolerance
                ), f"output logprobs are not all close with model_path={model_path} prompts={prompts}... output_tolerance={output_tolerance}"

        # output strings comparison
        print(f"hf_outputs.output_strs={hf_outputs.output_strs}")
        print(f"srt_outputs.output_strs={srt_outputs.output_strs}")
        rouge_l_scores = calculate_rouge_l(
            hf_outputs.output_strs, srt_outputs.output_strs
        )
        print(f"rouge_l_scores={rouge_l_scores}")
        assert all(
            score >= rouge_threshold for score in rouge_l_scores
        ), f"Not all ROUGE-L scores are greater than rouge_threshold={rouge_threshold}"

    def test_prefill_logits_and_output_strs(self):
        for (
            model,
            tp_size,
            long_context_tolerance,
            prefill_tolerance,
            output_tolerance,
            rouge_threshold,
        ) in MODELS:
            for torch_dtype in TORCH_DTYPES:
                max_new_tokens = 32
                self.assert_close_prefill_logits_and_output_strs(
                    DEFAULT_PROMPTS,
                    model,
                    tp_size,
                    torch_dtype,
                    max_new_tokens,
                    prefill_tolerance=prefill_tolerance,
                    output_tolerance=output_tolerance,
                    rouge_threshold=rouge_threshold,
                    long_context_tolerance=long_context_tolerance,
                )


if __name__ == "__main__":
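    # prefer the "spawn" start method (forked workers cannot re-initialize CUDA)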
    try:
        mp.set_start_method("spawn")
    except RuntimeError:
        pass

    unittest.main()