test_generation_models.py 4.29 KB
Newer Older
1
2
3
4
5
"""
Copyright 2023-2024 SGLang Team
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""

import unittest

import torch

from sglang.test.runners import DEFAULT_PROMPTS, HFRunner, SRTRunner

# Each entry is unpacked by the test as
# (model_path, tp_size, long_context_tolerance, prefill_tolerance, rouge_threshold).
MODELS = [
    ("meta-llama/Meta-Llama-3.1-8B-Instruct", 1, 1.1, 3e-2, 1),
    ("google/gemma-2-2b", 1, 3, 3e-2, 1),
    ("Alibaba-NLP/gte-Qwen2-1.5B-instruct", 1, None, 6e-2, 1),
]
# Every model in MODELS is run once per dtype listed here.
TORCH_DTYPES = [torch.float16]


30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
def lcs(X, Y):
    """Return the length of the longest common subsequence of X and Y.

    X and Y are indexable sequences (e.g. strings or lists); their elements
    are compared with ``==``. An empty input yields 0.
    """
    rows, cols = len(X), len(Y)
    # table[r][c] holds the LCS length of X[:r] and Y[:c]; row 0 and
    # column 0 stay at zero because an empty prefix shares nothing.
    table = [[0] * (cols + 1) for _ in range(rows + 1)]

    for r in range(1, rows + 1):
        above = table[r - 1]
        current = table[r]
        for c in range(1, cols + 1):
            if X[r - 1] == Y[c - 1]:
                # Matching elements extend the best result of both prefixes.
                current[c] = above[c - 1] + 1
            else:
                # Otherwise drop the last element of one side or the other.
                current[c] = max(above[c], current[c - 1])

    return table[rows][cols]


def calculate_rouge_l(output_strs_list1, output_strs_list2):
    """Return the ROUGE-L F-measure for each aligned pair of outputs.

    The longest common subsequence is computed by ``lcs`` over the elements
    of each item, so for ``str`` items the comparison is character-level.

    Args:
        output_strs_list1: First collection of outputs; precision is
            measured against the length of these items.
        output_strs_list2: Second collection of outputs; recall is measured
            against the length of these items.

    Returns:
        A list with one F-measure per pair, in input order. A pair in which
        either item is empty scores 0.0.

    Raises:
        ValueError: If the two collections do not contain the same number
            of items. Silently truncating to the shorter one would let a
            missing output go unnoticed by the caller.
    """
    first_outputs = list(output_strs_list1)
    second_outputs = list(output_strs_list2)
    if len(first_outputs) != len(second_outputs):
        raise ValueError(
            "calculate_rouge_l needs the same number of outputs on both sides, "
            f"got {len(first_outputs)} and {len(second_outputs)}"
        )

    rouge_l_scores = []
    for s1, s2 in zip(first_outputs, second_outputs):
        if len(s1) == 0 or len(s2) == 0:
            # Nothing can be shared with an empty item, so the score is zero
            # and the division by a zero length is avoided.
            rouge_l_scores.append(0.0)
            continue

        lcs_len = lcs(s1, s2)
        precision = lcs_len / len(s1)
        recall = lcs_len / len(s2)
        if precision + recall > 0:
            fmeasure = (2 * precision * recall) / (precision + recall)
        else:
            fmeasure = 0.0
        rouge_l_scores.append(fmeasure)

    return rouge_l_scores


63
class TestGenerationModels(unittest.TestCase):
    """Check that SRTRunner agrees with HFRunner for every entry in MODELS."""

    def assert_close_prefill_logits_and_output_strs(
        self,
        prompts,
        model_path,
        tp_size,
        torch_dtype,
        max_new_tokens,
        prefill_tolerance,
        rouge_threshold,
        long_context_tolerance,
    ) -> None:
        """Run one model through both runners and compare their results.

        Args:
            prompts: Prompts forwarded to both runners.
            model_path: Model identifier handed to HFRunner and SRTRunner.
            tp_size: Passed to SRTRunner as ``tp_size``.
            torch_dtype: Dtype passed to both runners.
            max_new_tokens: Passed to both ``forward`` calls.
            prefill_tolerance: Exclusive upper bound on the absolute
                difference between the two runners' input logprobs.
            rouge_threshold: Minimum ROUGE-L score required for each pair
                of output strings.
            long_context_tolerance: Accepted for interface compatibility but
                not used by any check here. NOTE(review): some MODELS
                entries set it to None, so it cannot be used as a numeric
                bound without further changes.

        Raises:
            AssertionError: If any compared logprob differs by
                ``prefill_tolerance`` or more, or any ROUGE-L score is below
                ``rouge_threshold``.
        """
        if model_path == "Alibaba-NLP/gte-Qwen2-1.5B-instruct":
            # The last prompt is dropped for this model.
            # NOTE(review): presumably it is too long for the model - confirm.
            prompts = prompts[:-1]

        with HFRunner(
            model_path, torch_dtype=torch_dtype, is_generation=True
        ) as hf_runner:
            hf_outputs = hf_runner.forward(prompts, max_new_tokens=max_new_tokens)

        with SRTRunner(
            model_path,
            tp_size=tp_size,
            torch_dtype=torch_dtype,
            is_generation=True,
        ) as srt_runner:
            srt_outputs = srt_runner.forward(prompts, max_new_tokens=max_new_tokens)

        # Both runners must return one set of input logprobs per prompt;
        # without this check a short list would be silently skipped by zip().
        self.assertEqual(len(hf_outputs.top_input_logprobs), len(prompts))
        self.assertEqual(len(srt_outputs.top_input_logprobs), len(prompts))

        for hf_top_logprobs, srt_top_logprobs in zip(
            hf_outputs.top_input_logprobs, srt_outputs.top_input_logprobs
        ):
            hf_logprobs = torch.Tensor(hf_top_logprobs)
            srt_logprobs = torch.Tensor(srt_top_logprobs)
            abs_diff = abs(hf_logprobs - srt_logprobs)

            print("max_diff", torch.max(abs_diff))
            # Only inputs with at most 100 logprob rows are checked; longer
            # ones are printed above but not asserted.
            if hf_logprobs.shape[0] <= 100:
                # assertTrue (unlike a bare assert) still runs under -O.
                self.assertTrue(
                    bool(torch.all(abs_diff < prefill_tolerance)),
                    "prefill logprobs are not all close",
                )

        print(hf_outputs.output_strs)
        print(srt_outputs.output_strs)
        rouge_l_scores = calculate_rouge_l(
            hf_outputs.output_strs, srt_outputs.output_strs
        )
        self.assertTrue(
            all(score >= rouge_threshold for score in rouge_l_scores),
            f"Not all ROUGE-L scores are greater than {rouge_threshold}",
        )

    def test_prefill_logits_and_output_strs(self):
        """Compare both runners for every model in MODELS and every dtype."""
        import multiprocessing as mp

        try:
            mp.set_start_method("spawn")
        except RuntimeError:
            # The start method was already set for this process; keep it.
            pass

        max_new_tokens = 8
        for (
            model,
            tp_size,
            long_context_tolerance,
            prefill_tolerance,
            rouge_threshold,
        ) in MODELS:
            for torch_dtype in TORCH_DTYPES:
                self.assert_close_prefill_logits_and_output_strs(
                    DEFAULT_PROMPTS,
                    model,
                    tp_size,
                    torch_dtype,
                    max_new_tokens,
                    prefill_tolerance=prefill_tolerance,
                    rouge_threshold=rouge_threshold,
                    long_context_tolerance=long_context_tolerance,
                )


# Run the test suite when this file is executed directly as a script;
# warnings="ignore" is forwarded to unittest.main as its warning filter.
if __name__ == "__main__":
    unittest.main(warnings="ignore")