test_generation_models.py 6.32 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
13
# Copyright 2023-2024 SGLang Team
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
14
15
16
"""
Usage:

Kiv Chen's avatar
Kiv Chen committed
17
18
19
To test a specific model locally:
1. Add it to ALL_MODELS, for example, `ModelCase("Qwen/Qwen2-1.5B")`
2. Run `ONLY_RUN=Qwen/Qwen2-1.5B python3 -m unittest test_generation_models.TestGenerationModels`
20
21
22
"""

import dataclasses
23
import multiprocessing as mp
24
import os
Kiv Chen's avatar
Kiv Chen committed
25
import random
26
import unittest
27
from typing import List
28
29
30

import torch

31
32
33
34
35
36
from sglang.test.runners import (
    DEFAULT_PROMPTS,
    HFRunner,
    SRTRunner,
    check_close_model_outputs,
)
37
from sglang.test.test_utils import CustomTestCase, is_in_ci
38

39

40
41
42
43
44
@dataclasses.dataclass
class ModelCase:
    """One model to compare between the HuggingFace and SRT runners.

    The three tolerances are forwarded to ``check_close_model_outputs`` when
    the model's outputs are compared.
    """

    # Model identifier passed to both HFRunner and SRTRunner.
    model_path: str
    # Tensor-parallel size passed to SRTRunner.
    tp_size: int = 1
    # Tolerance for the prefill comparison.
    prefill_tolerance: float = 5e-2
    # Tolerance for the decode comparison; increased to fix numerical error
    # in issue #8614.
    decode_tolerance: float = 6e-2
    # Tolerance for the ROUGE-L comparison of generated strings. Written as a
    # float literal to match the annotation.
    rouge_l_tolerance: float = 1.0
    # When True, the tests drop long prompts for this model (models without a
    # long context).
    skip_long_prompt: bool = False
    # Forwarded to both HFRunner and SRTRunner.
    trust_remote_code: bool = False


# Popular models exercised on CI, all using the default ModelCase settings.
CI_MODELS = [
    ModelCase(path)
    for path in (
        "meta-llama/Llama-3.1-8B-Instruct",
        "google/gemma-2-2b",
    )
]

# The complete set of models used to test sglang's generation path:
# the CI models first, followed by models that are only run locally.
ALL_MODELS = CI_MODELS + [
    ModelCase("Qwen/Qwen2-1.5B"),
    ModelCase("Qwen/Qwen2.5-14B-Instruct"),
    ModelCase("HuggingFaceTB/SmolLM-135M-Instruct", skip_long_prompt=True),
    ModelCase(
        "allenai/OLMo-1B-0724-hf",
        decode_tolerance=8e-2,
        skip_long_prompt=True,
    ),
    ModelCase("shanearora/2025-sep-a-base-model"),
    ModelCase(
        "THUDM/glm-4-9b-chat",
        tp_size=2,
        trust_remote_code=True,
        skip_long_prompt=True,
    ),
    ModelCase("openai-community/gpt2"),
    ModelCase("microsoft/phi-1_5", trust_remote_code=True),
    ModelCase("adept/persimmon-8b-chat"),
    ModelCase("upstage/SOLAR-10.7B-Instruct-v1.0"),
    ModelCase("inclusionAI/Ling-lite", trust_remote_code=True),
    ModelCase("microsoft/Phi-3-small-8k-instruct", trust_remote_code=True),
    ModelCase("allenai/OLMo-2-1124-7B-Instruct", skip_long_prompt=True),
    ModelCase("ibm-granite/granite-3.0-2b-instruct", skip_long_prompt=True),
    ModelCase(
        "microsoft/Phi-3.5-MoE-instruct",
        tp_size=2,
        trust_remote_code=True,
        skip_long_prompt=True,
    ),
    ModelCase("facebook/opt-125m", skip_long_prompt=True),
    ModelCase(
        "nvidia/Llama-3_3-Nemotron-Super-49B-v1_5",
        tp_size=2,
        trust_remote_code=True,
        skip_long_prompt=True,
    ),
    ModelCase(
        "nvidia/Llama-3_1-Nemotron-Ultra-253B-v1",
        tp_size=8,
        trust_remote_code=True,
        skip_long_prompt=True,
    ),
    ModelCase(
        "nvidia/NVIDIA-Nemotron-Nano-9B-v2",
        trust_remote_code=True,
        skip_long_prompt=True,
    ),
    ModelCase(
        "swiss-ai/Apertus-8B",
        trust_remote_code=True,
        skip_long_prompt=True,
    ),
]

# Every model case is compared once per dtype listed here.
TORCH_DTYPES = [
    torch.float16,
]


class TestGenerationModels(CustomTestCase):
    """Check that SRT generation outputs stay close to HuggingFace's.

    For every model case and dtype, both runners generate from the same
    prompts and ``check_close_model_outputs`` compares the results using the
    model case's tolerances.
    """

    @classmethod
    def setUpClass(cls):
        # Let the base class (and any CustomTestCase setup) run first.
        super().setUpClass()
        # Force the "spawn" start method so runner subprocesses do not inherit
        # the parent's state via fork (presumably required for CUDA — confirm).
        mp.set_start_method("spawn", force=True)

    @staticmethod
    def _select_prompts(
        model_case: ModelCase, max_prompt_chars: int = 1000
    ) -> List[str]:
        """Return the prompts to feed to ``model_case``.

        Models flagged with ``skip_long_prompt`` do not have a long context,
        so prompts of ``max_prompt_chars`` characters or more are dropped for
        them. Otherwise ``DEFAULT_PROMPTS`` is returned as-is.
        """
        if not model_case.skip_long_prompt:
            return DEFAULT_PROMPTS
        return [p for p in DEFAULT_PROMPTS if len(p) < max_prompt_chars]

    def assert_close_logits_and_output_strs(
        self,
        prompts: List[str],
        model_case: ModelCase,
        torch_dtype: torch.dtype,
        max_new_tokens: int = 32,
    ) -> None:
        """Run ``prompts`` through HF and SRT and assert their outputs are close.

        Args:
            prompts: Prompts fed to both runners.
            model_case: Supplies the model path, tensor-parallel size,
                ``trust_remote_code`` flag and comparison tolerances.
            torch_dtype: Passed as ``torch_dtype`` to both runners.
            max_new_tokens: Number of new tokens requested from each runner.
        """
        model_path = model_case.model_path

        # Reference outputs from the HuggingFace runner.
        with HFRunner(
            model_path,
            torch_dtype=torch_dtype,
            model_type="generation",
            trust_remote_code=model_case.trust_remote_code,
        ) as hf_runner:
            hf_outputs = hf_runner.forward(prompts, max_new_tokens=max_new_tokens)

        # Outputs from the SGLang runtime under test.
        with SRTRunner(
            model_path,
            tp_size=model_case.tp_size,
            torch_dtype=torch_dtype,
            model_type="generation",
            trust_remote_code=model_case.trust_remote_code,
        ) as srt_runner:
            srt_outputs = srt_runner.forward(prompts, max_new_tokens=max_new_tokens)

        check_close_model_outputs(
            hf_outputs=hf_outputs,
            srt_outputs=srt_outputs,
            prefill_tolerance=model_case.prefill_tolerance,
            decode_tolerance=model_case.decode_tolerance,
            rouge_l_tolerance=model_case.rouge_l_tolerance,
            debug_text=f"model_path={model_path} prompts={prompts}",
        )

    @unittest.skipIf(not is_in_ci(), "Local test should run all models")
    def test_ci_models(self):
        for model_case in CI_MODELS:
            # Prompt selection depends only on the model case, not the dtype.
            prompts = self._select_prompts(model_case)
            for torch_dtype in TORCH_DTYPES:
                self.assert_close_logits_and_output_strs(
                    prompts, model_case, torch_dtype
                )

    @unittest.skipIf(is_in_ci(), "CI only runs selected models for simplicity")
    def test_all_models(self):
        # An unset or empty ONLY_RUN means "run every model".
        only_run = os.environ.get("ONLY_RUN")
        model_cases = [
            model_case
            for model_case in ALL_MODELS
            if not only_run or model_case.model_path == only_run
        ]
        # Without this check, a typo in ONLY_RUN would make the test pass
        # without comparing anything.
        if only_run and not model_cases:
            self.fail(
                f"ONLY_RUN={only_run!r} does not match any model_path in ALL_MODELS"
            )

        for model_case in model_cases:
            prompts = self._select_prompts(model_case)
            for torch_dtype in TORCH_DTYPES:
                self.assert_close_logits_and_output_strs(
                    prompts, model_case, torch_dtype
                )


if __name__ == "__main__":
Mingyi's avatar
Mingyi committed
193
    unittest.main()