"""
Copyright 2023-2024 SGLang Team
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""

import multiprocessing as mp
import unittest

import torch

from sglang.test.runners import DEFAULT_PROMPTS, HFRunner, SRTRunner

MODELS = [
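    # (model_path, tp_size, long_context_tolerance, prefill_tolerance, output_tolerance, rouge_threshold)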
    ("meta-llama/Meta-Llama-3.1-8B-Instruct", 1, 1.1, 3e-2, 4e-2, 1),
    ("google/gemma-2-2b", 1, 3, 3e-2, 5e-2, 1),
    ("Alibaba-NLP/gte-Qwen2-1.5B-instruct", 1, None, 6e-2, 4e-2, 1),
]
TORCH_DTYPES = [torch.float16]


def lcs(X, Y):
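    """Length of the longest common subsequence of X and Y (classic DP table)."""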
    m = len(X)
    n = len(Y)
    L = [[0] * (n + 1) for _ in range(m + 1)]

    for i in range(m + 1):
        for j in range(n + 1):
            if i == 0 or j == 0:
                L[i][j] = 0
            elif X[i - 1] == Y[j - 1]:
                L[i][j] = L[i - 1][j - 1] + 1
            else:
                L[i][j] = max(L[i - 1][j], L[i][j - 1])

    return L[m][n]


def calculate_rouge_l(output_strs_list1, output_strs_list2):
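    """Character-level ROUGE-L F-measure for each pair of strings.

    Example (hypothetical inputs): calculate_rouge_l(["abcd"], ["abce"]) == [0.75].
    """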
    rouge_l_scores = []

    for s1, s2 in zip(output_strs_list1, output_strs_list2):
        lcs_len = lcs(s1, s2)
        precision = lcs_len / len(s1) if len(s1) > 0 else 0
        recall = lcs_len / len(s2) if len(s2) > 0 else 0
        if precision + recall > 0:
            fmeasure = (2 * precision * recall) / (precision + recall)
        else:
            fmeasure = 0.0
        rouge_l_scores.append(fmeasure)

    return rouge_l_scores


class TestGenerationModels(unittest.TestCase):
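    """Compares SGLang runtime (SRT) generation against HuggingFace transformers."""
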
    def assert_close_prefill_logits_and_output_strs(
        self,
        prompts,
        model_path,
        tp_size,
        torch_dtype,
        max_new_tokens,
        prefill_tolerance,
        output_tolerance,
        rouge_threshold,
        long_context_tolerance,
    ) -> None:
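        """Run the same prompts through HF and SRT and compare the results.

        Prefill and output top logprobs must be element-wise close (for short
        prompts), and the decoded output strings must reach the ROUGE-L threshold.
        """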
        # The gte-Qwen2 model skips the last default prompt.
        if model_path == "Alibaba-NLP/gte-Qwen2-1.5B-instruct":
            prompts = prompts[:-1]
        with HFRunner(
            model_path, torch_dtype=torch_dtype, is_generation=True
        ) as hf_runner:
            hf_outputs = hf_runner.forward(prompts, max_new_tokens=max_new_tokens)

        with SRTRunner(
            model_path,
            tp_size=tp_size,
            torch_dtype=torch_dtype,
            is_generation=True,
        ) as srt_runner:
            srt_outputs = srt_runner.forward(prompts, max_new_tokens=max_new_tokens)

        for i in range(len(prompts)):
            # input logprobs comparison
            hf_logprobs = torch.Tensor(hf_outputs.top_input_logprobs[i])
            srt_logprobs = torch.Tensor(srt_outputs.top_input_logprobs[i])
            input_len = hf_logprobs.shape[0]
            print(
                "prefill logprobs max_diff", torch.max(abs(hf_logprobs - srt_logprobs))
            )
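            # Only enforce the element-wise prefill tolerance for short prompts;
            # longer prefills are assumed to accumulate larger numerical differences.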
            if input_len <= 100:
                assert torch.all(
                    abs(hf_logprobs - srt_logprobs) < prefill_tolerance
                ), f"prefill logprobs are not all close with model_path={model_path} prompts={prompts} prefill_tolerance={prefill_tolerance}"

            # output logprobs comparison
            hf_logprobs = torch.Tensor(hf_outputs.top_output_logprobs[i])
            srt_logprobs = torch.Tensor(srt_outputs.top_output_logprobs[i])
            # print(
            #     "output logprobs diff",
            #     [
            #         float(torch.max(abs(hf_logprobs[j] - srt_logprobs[j])))
            #         for j in range(max_new_tokens)
            #     ],
            # )
            print(
                "output logprobs max_diff", torch.max(abs(hf_logprobs - srt_logprobs))
            )
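            # Apply the same short-prompt gate as for the prefill logprobs above.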
            if input_len <= 100:
                assert torch.all(
                    abs(hf_logprobs - srt_logprobs) < output_tolerance
                ), f"output logprobs are not all close with model_path={model_path} prompts={prompts}... output_tolerance={output_tolerance}"

        # output strings comparison
        print(f"hf_outputs.output_strs={hf_outputs.output_strs}")
        print(f"srt_outputs.output_strs={srt_outputs.output_strs}")
        rouge_l_scores = calculate_rouge_l(
            hf_outputs.output_strs, srt_outputs.output_strs
        )
        print(f"rouge_l_scores={rouge_l_scores}")
        assert all(
            score >= rouge_threshold for score in rouge_l_scores
        ), f"Not all ROUGE-L scores are at least rouge_threshold={rouge_threshold}"

    def test_prefill_logits_and_output_strs(self):
        for (
            model,
            tp_size,
            long_context_tolerance,
            prefill_tolerance,
            output_tolerance,
            rouge_threshold,
        ) in MODELS:
            for torch_dtype in TORCH_DTYPES:
                max_new_tokens = 32
                self.assert_close_prefill_logits_and_output_strs(
                    DEFAULT_PROMPTS,
                    model,
                    tp_size,
                    torch_dtype,
                    max_new_tokens,
                    prefill_tolerance=prefill_tolerance,
                    output_tolerance=output_tolerance,
                    rouge_threshold=rouge_threshold,
                    long_context_tolerance=long_context_tolerance,
                )


if __name__ == "__main__":
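    # Use the "spawn" start method (typically required when subprocesses use CUDA);
    # ignore the error if the start method was already set.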
    try:
        mp.set_start_method("spawn")
    except RuntimeError:
        pass

    unittest.main()