test_generation_models.py 4.96 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
13
# Copyright 2023-2024 SGLang Team
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
14
15
16
17
18
19
20
21
22
"""
Usage:

To test a specific model:
1. Add it to ALL_OTHER_MODELS
2. Run `ONLY_RUN=Qwen/Qwen2-1.5B python3 -m unittest test_generation_models.TestGenerationModels.test_others`
"""

import dataclasses
23
import multiprocessing as mp
24
import os
25
import unittest
26
from typing import List
27
28
29

import torch

30
31
32
33
34
35
36
from sglang.test.runners import (
    DEFAULT_PROMPTS,
    HFRunner,
    SRTRunner,
    check_close_model_outputs,
)
from sglang.test.test_utils import is_in_ci
37

38

39
40
41
42
43
44
45
@dataclasses.dataclass
class ModelCase:
    """One model configuration to compare between HuggingFace and SRT runners."""

    # HuggingFace model identifier (or local path) loaded by both runners.
    model_path: str
    # Tensor-parallel degree passed to SRTRunner.
    tp_size: int = 1
    # Tolerances forwarded to check_close_model_outputs for the
    # prefill logits, decode logits, and ROUGE-L string comparison.
    prefill_tolerance: float = 5e-2
    decode_tolerance: float = 5e-2
    rouge_l_tolerance: float = 1
    # If True, prompts of 1000+ characters are filtered out before the run
    # (for models without a long context window).
    skip_long_prompt: bool = False
    # Forwarded to both runners for models that ship custom modeling code.
    trust_remote_code: bool = False
48
49


50
# Popular models that run on the CI
CI_MODELS = [
    ModelCase(model_path="meta-llama/Llama-3.1-8B-Instruct"),
    ModelCase(model_path="google/gemma-2-2b"),
]

# All other models that do not run on the CI
ALL_OTHER_MODELS = [
    ModelCase(model_path="Qwen/Qwen2-1.5B"),
    ModelCase(model_path="Qwen/Qwen2.5-14B-Instruct"),
    ModelCase(model_path="HuggingFaceTB/SmolLM-135M-Instruct", skip_long_prompt=True),
    ModelCase(
        model_path="allenai/OLMo-1B-0724-hf",
        decode_tolerance=8e-2,
        skip_long_prompt=True,
    ),
    ModelCase(
        model_path="THUDM/glm-4-9b-chat",
        tp_size=2,
        trust_remote_code=True,
        skip_long_prompt=True,
    ),
    ModelCase(model_path="openai-community/gpt2"),
    ModelCase(model_path="microsoft/Phi-3-small-8k-instruct"),
    ModelCase(model_path="allenai/OLMo-2-1124-7B-Instruct", skip_long_prompt=True),
    ModelCase(model_path="ibm-granite/granite-3.0-2b-instruct", skip_long_prompt=True),
]

# Dtypes exercised by the CI test loop.
TORCH_DTYPES = [torch.float16]
73


74
class TestGenerationModels(unittest.TestCase):
    """Checks that SGLang (SRT) generation matches HuggingFace for a set of models."""

    @classmethod
    def setUpClass(cls):
        # The runners launch worker processes; force "spawn" so the start
        # method is consistent regardless of platform defaults.
        mp.set_start_method("spawn", force=True)

    def assert_close_logits_and_output_strs(
        self,
        prompts: List[str],
        model_case: ModelCase,
        torch_dtype: torch.dtype,
    ) -> None:
        """Run *prompts* through both runners and assert their outputs are close.

        Args:
            prompts: Input prompts fed identically to the HF and SRT runners.
            model_case: Model path, tp size, and comparison tolerances.
            torch_dtype: Dtype used by both runners.
        """
        model_path = model_case.model_path
        max_new_tokens = 32

        with HFRunner(
            model_path,
            torch_dtype=torch_dtype,
            model_type="generation",
            trust_remote_code=model_case.trust_remote_code,
        ) as hf_runner:
            hf_outputs = hf_runner.forward(prompts, max_new_tokens=max_new_tokens)

        with SRTRunner(
            model_path,
            tp_size=model_case.tp_size,
            torch_dtype=torch_dtype,
            model_type="generation",
            trust_remote_code=model_case.trust_remote_code,
        ) as srt_runner:
            srt_outputs = srt_runner.forward(prompts, max_new_tokens=max_new_tokens)

        check_close_model_outputs(
            hf_outputs=hf_outputs,
            srt_outputs=srt_outputs,
            prefill_tolerance=model_case.prefill_tolerance,
            decode_tolerance=model_case.decode_tolerance,
            rouge_l_tolerance=model_case.rouge_l_tolerance,
            debug_text=f"model_path={model_path} prompts={prompts}",
        )

    @staticmethod
    def _select_prompts(model_case: ModelCase) -> List[str]:
        """Return the prompts for *model_case*, dropping long ones when requested."""
        # Skip long prompts for models that do not have a long context.
        if model_case.skip_long_prompt:
            return [p for p in DEFAULT_PROMPTS if len(p) < 1000]
        return DEFAULT_PROMPTS

    def test_ci_models(self):
        for model_case in CI_MODELS:
            for torch_dtype in TORCH_DTYPES:
                prompts = self._select_prompts(model_case)
                # Assert the logits and output strs are close.
                self.assert_close_logits_and_output_strs(
                    prompts, model_case, torch_dtype
                )

    def test_others(self):
        if is_in_ci():
            return

        only_run = os.environ.get("ONLY_RUN")
        for model_case in ALL_OTHER_MODELS:
            # When ONLY_RUN is set, run just that model (see module docstring).
            if only_run is not None and only_run != model_case.model_path:
                continue

            prompts = self._select_prompts(model_case)
            # Assert the logits and output strs are close.
            self.assert_close_logits_and_output_strs(prompts, model_case, torch.float16)

154

155
if __name__ == "__main__":
    # Discover and run the tests in this module.
    unittest.main()