# Copyright 2023-2024 SGLang Team
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""
Usage:

To test a specific model:
1. Add it to ALL_OTHER_MODELS
2. Run `ONLY_RUN=Qwen/Qwen2-1.5B python3 -m unittest test_generation_models.TestGenerationModels.test_others`
"""

import dataclasses
import multiprocessing as mp
import os
import unittest
from typing import List

import torch

from sglang.test.runners import DEFAULT_PROMPTS, HFRunner, SRTRunner
from sglang.test.test_utils import calculate_rouge_l, is_in_ci


@dataclasses.dataclass
class ModelCase:
    model_path: str
    tp_size: int = 1
    # Maximum allowed elementwise difference between HF and SRT top logprobs
    prefill_tolerance: float = 5e-2
    decode_tolerance: float = 5e-2
    # Minimum acceptable ROUGE-L score between HF and SRT output strings
    rouge_l_tolerance: float = 1
    skip_long_prompt: bool = False


# Popular models that run on the CI
CI_MODELS = [
    ModelCase("meta-llama/Llama-3.1-8B-Instruct"),
    ModelCase("google/gemma-2-2b"),
]

# All other models that do not run on the CI
ALL_OTHER_MODELS = [
    ModelCase("Qwen/Qwen2-1.5B"),
    ModelCase("Qwen/Qwen2.5-14B-Instruct"),
    ModelCase("HuggingFaceTB/SmolLM-135M-Instruct", skip_long_prompt=True),
    ModelCase("allenai/OLMo-1B-0724-hf", decode_tolerance=8e-2, skip_long_prompt=True),
    ModelCase("THUDM/glm-4-9b-chat"),
    ModelCase("openai-community/gpt2"),
    ModelCase("microsoft/Phi-3-small-8k-instruct"),
    ModelCase("allenai/OLMo-2-1124-7B-Instruct", skip_long_prompt=True),
]
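# A hypothetical new entry for step 1 of the usage docstring (placeholder
# model path and overrides) might look like:
#   ModelCase("my-org/my-model", tp_size=2, decode_tolerance=8e-2, skip_long_prompt=True)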

TORCH_DTYPES = [torch.float16]


class TestGenerationModels(unittest.TestCase):

    @classmethod
    def setUpClass(cls):
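        # Use the "spawn" start method: forked workers cannot safely
        # re-initialize CUDA once the parent process has touched it.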
        mp.set_start_method("spawn", force=True)

    def assert_close_logits_and_output_strs(
        self,
        prompts: List[str],
        model_case: ModelCase,
        torch_dtype: torch.dtype,
    ) -> None:
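        """Compare HF and SRT outputs for the same prompts.

        Top input/output logprobs must agree elementwise within the per-model
        tolerances, and output strings must be close under ROUGE-L.
        """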
        model_path = model_case.model_path
        prefill_tolerance, decode_tolerance, rouge_l_tolerance = (
            model_case.prefill_tolerance,
            model_case.decode_tolerance,
            model_case.rouge_l_tolerance,
        )
        max_new_tokens = 32

        with HFRunner(
            model_path,
            torch_dtype=torch_dtype,
            model_type="generation",
        ) as hf_runner:
            hf_outputs = hf_runner.forward(prompts, max_new_tokens=max_new_tokens)

        with SRTRunner(
            model_path,
            tp_size=model_case.tp_size,
            torch_dtype=torch_dtype,
            model_type="generation",
        ) as srt_runner:
            srt_outputs = srt_runner.forward(prompts, max_new_tokens=max_new_tokens)

        for i in range(len(prompts)):
            # Compare input logprobs
            hf_logprobs = torch.Tensor(hf_outputs.top_input_logprobs[i])
            srt_logprobs = torch.Tensor(srt_outputs.top_input_logprobs[i])
            input_len = hf_logprobs.shape[0]
            print(
                "prefill logprobs max_diff", torch.max(abs(hf_logprobs - srt_logprobs))
            )
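            # The elementwise check is only applied to short prompts,
            # presumably because numerical differences accumulate with length.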
            if input_len <= 100:
                assert torch.all(abs(hf_logprobs - srt_logprobs) < prefill_tolerance), (
                    f"prefill logprobs are not all close with model_path={model_path} prompts={prompts} "
                    f"prefill_tolerance={prefill_tolerance}."
                    f"{hf_logprobs=}, {srt_logprobs=}"
                )

            # Compare output logprobs
            hf_logprobs = torch.Tensor(hf_outputs.top_output_logprobs[i])
            srt_logprobs = torch.Tensor(srt_outputs.top_output_logprobs[i])

            print(
                "decode logprobs max_diff", torch.max(abs(hf_logprobs - srt_logprobs))
            )
            if input_len <= 100:
                assert torch.all(abs(hf_logprobs - srt_logprobs) < decode_tolerance), (
                    f"decode logprobs are not all close with model_path={model_path} prompts={prompts} "
                    f"decode_tolerance={decode_tolerance}."
                    f"{hf_logprobs=}, {srt_logprobs=}"
                )

        # Compare output strings
        print(f"{hf_outputs.output_strs=}")
        print(f"{srt_outputs.output_strs=}")
        rouge_l_scores = calculate_rouge_l(
            hf_outputs.output_strs, srt_outputs.output_strs
        )
        print(f"{rouge_l_scores=}")
        assert all(
            score >= rouge_l_tolerance for score in rouge_l_scores
        ), f"Not all ROUGE-L scores are greater than rouge_l_tolerance={rouge_l_tolerance}"

    def test_ci_models(self):
        for model_case in CI_MODELS:
            for torch_dtype in TORCH_DTYPES:
                # Skip long prompts for models that do not have a long context
                prompts = DEFAULT_PROMPTS
                if model_case.skip_long_prompt:
                    prompts = [p for p in DEFAULT_PROMPTS if len(p) < 1000]

                # Assert the logits and output strs are close
                self.assert_close_logits_and_output_strs(
                    prompts, model_case, torch_dtype
                )

    def test_others(self):
        if is_in_ci():
            return

        for model_case in ALL_OTHER_MODELS:
            # Only run the model specified by the ONLY_RUN environment variable
            if (
                "ONLY_RUN" in os.environ
                and os.environ["ONLY_RUN"] != model_case.model_path
            ):
                continue

            # Skip long prompts for models that do not have a long context
            prompts = DEFAULT_PROMPTS
            if model_case.skip_long_prompt:
                prompts = [p for p in DEFAULT_PROMPTS if len(p) < 1000]

            # Assert the logits and output strs are close
            self.assert_close_logits_and_output_strs(prompts, model_case, torch.float16)


if __name__ == "__main__":
    unittest.main()