"""
Usage:

To test a specific model:
1. Add it to ALL_OTHER_MODELS
2. Run `ONLY_RUN=Qwen/Qwen2-1.5B python3 -m unittest test_generation_models.TestGenerationModels.test_others`
"""

"""
Copyright 2023-2024 SGLang Team
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""

import dataclasses
import multiprocessing as mp
import os
import unittest
from typing import List

import torch

from sglang.test.runners import DEFAULT_PROMPTS, HFRunner, SRTRunner
from sglang.test.test_utils import calculate_rouge_l, is_in_ci


@dataclasses.dataclass
class ModelCase:
    model_path: str
    tp_size: int = 1
    prefill_tolerance: float = 5e-2
    decode_tolerance: float = 5e-2
    rouge_l_tolerance: float = 1
    skip_long_prompt: bool = False
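
# Note: prefill_tolerance and decode_tolerance bound the maximum elementwise
# difference in top logprobs between the HF and SRT runners, while
# rouge_l_tolerance is the minimum acceptable ROUGE-L score between their
# decoded strings (1 requires an exact match). A hypothetical entry that
# loosens the decode check for a tensor-parallel model might look like:
#     ModelCase("org/some-model", tp_size=2, decode_tolerance=1e-1)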


# Popular models that run on the CI
CI_MODELS = [
    ModelCase("meta-llama/Llama-3.1-8B-Instruct"),
    ModelCase("google/gemma-2-2b"),
]

# All other models that do not run on the CI
ALL_OTHER_MODELS = [
    ModelCase("Qwen/Qwen2-1.5B"),
    ModelCase("Qwen/Qwen2.5-14B-Instruct"),
    ModelCase("HuggingFaceTB/SmolLM-135M-Instruct", skip_long_prompt=True),
    ModelCase("allenai/OLMo-1B-0724-hf", decode_tolerance=8e-2, skip_long_prompt=True),
    ModelCase("THUDM/glm-4-9b-chat"),
    ModelCase("openai-community/gpt2")
]

TORCH_DTYPES = [torch.float16]
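# Only float16 is exercised here; other dtypes such as torch.bfloat16 could
# presumably be appended at the cost of a longer run.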


class TestGenerationModels(unittest.TestCase):

    @classmethod
    def setUpClass(cls):
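        # The runners launch worker processes; "spawn" (rather than "fork")
        # avoids re-initializing CUDA in a forked child, which would fail.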
        mp.set_start_method("spawn", force=True)

    def assert_close_logits_and_output_strs(
        self,
        prompts: List[str],
        model_case: ModelCase,
        torch_dtype: torch.dtype,
    ) -> None:
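        """Run the same prompts through an HF reference runner and an SRT
        runner, then check that prefill and decode top logprobs agree
        elementwise and that the decoded strings agree under ROUGE-L."""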
        model_path = model_case.model_path
        prefill_tolerance, decode_tolerance, rouge_l_tolerance = (
            model_case.prefill_tolerance,
            model_case.decode_tolerance,
            model_case.rouge_l_tolerance,
        )
        max_new_tokens = 32

        with HFRunner(
            model_path,
            torch_dtype=torch_dtype,
            model_type="generation",
        ) as hf_runner:
            hf_outputs = hf_runner.forward(prompts, max_new_tokens=max_new_tokens)

        with SRTRunner(
            model_path,
            tp_size=model_case.tp_size,
            torch_dtype=torch_dtype,
            model_type="generation",
        ) as srt_runner:
            srt_outputs = srt_runner.forward(prompts, max_new_tokens=max_new_tokens)
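        # Both runners return per-prompt top input/output logprobs and the
        # decoded output strings, which are compared below.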

        for i in range(len(prompts)):
            # Compare input logprobs
            hf_logprobs = torch.Tensor(hf_outputs.top_input_logprobs[i])
            srt_logprobs = torch.Tensor(srt_outputs.top_input_logprobs[i])
            input_len = hf_logprobs.shape[0]
            print(
                "prefill logprobs max_diff", torch.max(abs(hf_logprobs - srt_logprobs))
            )
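            # Only enforce elementwise closeness for short prompts; longer
            # inputs tend to accumulate larger numerical divergence.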
            if input_len <= 100:
                assert torch.all(abs(hf_logprobs - srt_logprobs) < prefill_tolerance), (
                    f"prefill logprobs are not all close with model_path={model_path} prompts={prompts} "
                    f"prefill_tolerance={prefill_tolerance}."
                    f"{hf_logprobs=}, {srt_logprobs=}"
                )

            # Compare output logprobs
            hf_logprobs = torch.Tensor(hf_outputs.top_output_logprobs[i])
            srt_logprobs = torch.Tensor(srt_outputs.top_output_logprobs[i])

            print(
                "decode logprobs max_diff", torch.max(abs(hf_logprobs - srt_logprobs))
            )
            if input_len <= 100:
                assert torch.all(abs(hf_logprobs - srt_logprobs) < decode_tolerance), (
                    f"decode logprobs are not all close with model_path={model_path} prompts={prompts} "
                    f"decode_tolerance={decode_tolerance}."
                    f"{hf_logprobs=}, {srt_logprobs=}"
                )

        # Compare output strings
        print(f"{hf_outputs.output_strs=}")
        print(f"{srt_outputs.output_strs=}")
        rouge_l_scores = calculate_rouge_l(
            hf_outputs.output_strs, srt_outputs.output_strs
        )
        print(f"{rouge_l_scores=}")
        assert all(
            score >= rouge_l_tolerance for score in rouge_l_scores
        ), f"Not all ROUGE-L scores are greater than rouge_l_tolerance={rouge_l_tolerance}"

    def test_ci_models(self):
        for model_case in CI_MODELS:
            for torch_dtype in TORCH_DTYPES:

                # Skip long prompts for models that do not have a long context
                prompts = DEFAULT_PROMPTS
                if model_case.skip_long_prompt:
                    prompts = [p for p in DEFAULT_PROMPTS if len(p) < 1000]

                # Assert the logits and output strs are close
                self.assert_close_logits_and_output_strs(
                    prompts, model_case, torch_dtype
                )

    def test_others(self):
        if is_in_ci():
            return

        for model_case in ALL_OTHER_MODELS:
            # Only run the model specified by the ONLY_RUN environment variable
            if (
                "ONLY_RUN" in os.environ
                and os.environ["ONLY_RUN"] != model_case.model_path
            ):
                continue

            # Skip long prompts for models that do not have a long context
            prompts = DEFAULT_PROMPTS
            if model_case.skip_long_prompt:
                prompts = [p for p in DEFAULT_PROMPTS if len(p) < 1000]

            # Assert the logits and output strs are close
            self.assert_close_logits_and_output_strs(prompts, model_case, torch.float16)


if __name__ == "__main__":
    unittest.main()