# Copyright 2023-2024 SGLang Team
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""
Usage:

To test a specific model:
1. Add it to ALL_OTHER_MODELS
2. Run `ONLY_RUN=Qwen/Qwen2-1.5B python3 -m unittest test_generation_models.TestGenerationModels.test_others`
"""

import dataclasses
import multiprocessing as mp
import os
import unittest
from typing import List

import torch

from sglang.test.runners import DEFAULT_PROMPTS, HFRunner, SRTRunner
from sglang.test.test_utils import calculate_rouge_l, is_in_ci


@dataclasses.dataclass
class ModelCase:
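    """Per-model test configuration.

    prefill_tolerance and decode_tolerance bound the allowed elementwise
    difference between HuggingFace and SRT logprobs; rouge_l_tolerance is
    the minimum ROUGE-L score required between the two runners' outputs.
    """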
    model_path: str
    tp_size: int = 1
    prefill_tolerance: float = 5e-2
    decode_tolerance: float = 5e-2
    rouge_l_tolerance: float = 1
    skip_long_prompt: bool = False


# Popular models that run on the CI
CI_MODELS = [
    ModelCase("meta-llama/Llama-3.1-8B-Instruct"),
    ModelCase("google/gemma-2-2b"),
]

# All other models that do not run on the CI
ALL_OTHER_MODELS = [
    ModelCase("Qwen/Qwen2-1.5B"),
    ModelCase("Qwen/Qwen2.5-14B-Instruct"),
    ModelCase("HuggingFaceTB/SmolLM-135M-Instruct", skip_long_prompt=True),
    ModelCase("allenai/OLMo-1B-0724-hf", decode_tolerance=8e-2, skip_long_prompt=True),
    ModelCase("THUDM/glm-4-9b-chat"),
    ModelCase("openai-community/gpt2"),
    ModelCase("microsoft/Phi-3-small-8k-instruct"),
]

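# Dtypes swept by test_ci_models; test_others always runs torch.float16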
TORCH_DTYPES = [torch.float16]


class TestGenerationModels(unittest.TestCase):

    @classmethod
    def setUpClass(cls):
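        # Use the "spawn" start method so runner subprocesses can initialize
        # CUDA (CUDA cannot be re-initialized in forked processes).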
        mp.set_start_method("spawn", force=True)

    def assert_close_logits_and_output_strs(
        self,
        prompts: List[str],
        model_case: ModelCase,
        torch_dtype: torch.dtype,
    ) -> None:
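        """Run the prompts through both the HuggingFace and SRT runners and
        assert that their logprobs and output strings agree within the
        tolerances configured in model_case."""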
        model_path = model_case.model_path
        prefill_tolerance, decode_tolerance, rouge_l_tolerance = (
            model_case.prefill_tolerance,
            model_case.decode_tolerance,
            model_case.rouge_l_tolerance,
        )
        max_new_tokens = 32

        with HFRunner(
            model_path,
            torch_dtype=torch_dtype,
            model_type="generation",
        ) as hf_runner:
            hf_outputs = hf_runner.forward(prompts, max_new_tokens=max_new_tokens)

        with SRTRunner(
            model_path,
            tp_size=model_case.tp_size,
            torch_dtype=torch_dtype,
            model_type="generation",
        ) as srt_runner:
            srt_outputs = srt_runner.forward(prompts, max_new_tokens=max_new_tokens)

        for i in range(len(prompts)):
            # Compare input logprobs
            hf_logprobs = torch.Tensor(hf_outputs.top_input_logprobs[i])
            srt_logprobs = torch.Tensor(srt_outputs.top_input_logprobs[i])
            input_len = hf_logprobs.shape[0]
            print(
                "prefill logprobs max_diff", torch.max(abs(hf_logprobs - srt_logprobs))
            )
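            # Only assert elementwise closeness on short prompts; differences
            # can accumulate over longer inputs.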
            if input_len <= 100:
                assert torch.all(abs(hf_logprobs - srt_logprobs) < prefill_tolerance), (
                    f"prefill logprobs are not all close with model_path={model_path} prompts={prompts} "
                    f"prefill_tolerance={prefill_tolerance}. "
                    f"{hf_logprobs=}, {srt_logprobs=}"
                )

            # Compare output logprobs
            hf_logprobs = torch.Tensor(hf_outputs.top_output_logprobs[i])
            srt_logprobs = torch.Tensor(srt_outputs.top_output_logprobs[i])

            print(
                "decode logprobs max_diff", torch.max(abs(hf_logprobs - srt_logprobs))
            )
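            # Apply the same short-prompt guard to the decode logprobs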
            if input_len <= 100:
                assert torch.all(abs(hf_logprobs - srt_logprobs) < decode_tolerance), (
                    f"decode logprobs are not all close with model_path={model_path} prompts={prompts} "
                    f"decode_tolerance={decode_tolerance}. "
                    f"{hf_logprobs=}, {srt_logprobs=}"
                )

        # Compare output strings
        print(f"{hf_outputs.output_strs=}")
        print(f"{srt_outputs.output_strs=}")
        rouge_l_scores = calculate_rouge_l(
            hf_outputs.output_strs, srt_outputs.output_strs
        )
        print(f"{rouge_l_scores=}")
        assert all(
            score >= rouge_l_tolerance for score in rouge_l_scores
        ), f"Not all ROUGE-L scores are greater than rouge_l_tolerance={rouge_l_tolerance}"

    def test_ci_models(self):
        for model_case in CI_MODELS:
            for torch_dtype in TORCH_DTYPES:

                # Skip long prompts for models that do not have a long context
                prompts = DEFAULT_PROMPTS
                if model_case.skip_long_prompt:
                    prompts = [p for p in DEFAULT_PROMPTS if len(p) < 1000]

                # Assert the logits and output strs are close
                self.assert_close_logits_and_output_strs(
                    prompts, model_case, torch_dtype
                )

    def test_others(self):
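        # ALL_OTHER_MODELS are not exercised on the CI; skip early there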
        if is_in_ci():
            return

        for model_case in ALL_OTHER_MODELS:
            # Only run the model specified by the ONLY_RUN environment variable
            if (
                "ONLY_RUN" in os.environ
                and os.environ["ONLY_RUN"] != model_case.model_path
            ):
                continue

            # Skip long prompts for models that do not have a long context
            prompts = DEFAULT_PROMPTS
            if model_case.skip_long_prompt:
                prompts = [p for p in DEFAULT_PROMPTS if len(p) < 1000]

            # Assert the logits and output strs are close
            self.assert_close_logits_and_output_strs(prompts, model_case, torch.float16)


if __name__ == "__main__":
    unittest.main()