"""
Usage:

To test a specific model:
1. Add it to ALL_OTHER_MODELS
2. Run `ONLY_RUN=Qwen/Qwen2-1.5B python3 -m unittest test_generation_models.TestGenerationModels.test_others`
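
For example, to cover a hypothetical new model at tensor-parallel size 2, one
might add an entry like

    ModelCase("my-org/my-model-7b", tp_size=2)

(the model path here is a placeholder) to ALL_OTHER_MODELS and run:

    ONLY_RUN=my-org/my-model-7b python3 -m unittest test_generation_models.TestGenerationModels.test_others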
"""

"""
Copyright 2023-2024 SGLang Team
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""

import dataclasses
import multiprocessing as mp
import os
import unittest
from typing import List

import torch

from sglang.test.runners import DEFAULT_PROMPTS, HFRunner, SRTRunner
from sglang.test.test_utils import calculate_rouge_l, is_in_ci


@dataclasses.dataclass
class ModelCase:
    model_path: str
    tp_size: int = 1
    # Max allowed elementwise |HF - SRT| difference for prefill / decode logprobs.
    prefill_tolerance: float = 5e-2
    decode_tolerance: float = 5e-2
    # Min ROUGE-L score between HF and SRT output strings (1 requires identical outputs).
    rouge_l_tolerance: float = 1


# Popular models that run on CI
CI_MODELS = [
    ModelCase("meta-llama/Meta-Llama-3.1-8B-Instruct"),
    ModelCase("google/gemma-2-2b"),
]

# All other models
ALL_OTHER_MODELS = [
    ModelCase("Qwen/Qwen2-1.5B"),
    ModelCase("HuggingFaceTB/SmolLM-135M-Instruct"),
]

TORCH_DTYPES = [torch.float16]


class TestGenerationModels(unittest.TestCase):
    def assert_close_logits_and_output_strs(
        self,
        prompts: List[str],
        model_case: ModelCase,
        torch_dtype: torch.dtype,
    ) -> None:
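        """Run the same prompts through the HF reference implementation and the
        SRT runtime, then compare prefill/decode logprobs elementwise and the
        output strings via ROUGE-L."""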
        model_path = model_case.model_path
        prefill_tolerance, decode_tolerance, rouge_l_tolerance = (
            model_case.prefill_tolerance,
            model_case.decode_tolerance,
            model_case.rouge_l_tolerance,
        )
        max_new_tokens = 32

        # Reference outputs from the HuggingFace implementation.
        with HFRunner(
            model_path,
            torch_dtype=torch_dtype,
            model_type="generation",
        ) as hf_runner:
            hf_outputs = hf_runner.forward(prompts, max_new_tokens=max_new_tokens)

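        # Outputs from the SGLang (SRT) runtime under test.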
        with SRTRunner(
            model_path,
            tp_size=model_case.tp_size,
            torch_dtype=torch_dtype,
            model_type="generation",
        ) as srt_runner:
            srt_outputs = srt_runner.forward(prompts, max_new_tokens=max_new_tokens)

        for i in range(len(prompts)):
            # Compare input (prefill) logprobs
            hf_logprobs = torch.Tensor(hf_outputs.top_input_logprobs[i])
            srt_logprobs = torch.Tensor(srt_outputs.top_input_logprobs[i])
            input_len = hf_logprobs.shape[0]
            print(
                "prefill logprobs max_diff", torch.max(abs(hf_logprobs - srt_logprobs))
            )
            # Only check short prompts elementwise; longer inputs tend to
            # accumulate larger numerical differences.
            if input_len <= 100:
                assert torch.all(abs(hf_logprobs - srt_logprobs) < prefill_tolerance), (
                    f"prefill logprobs are not all close with model_path={model_path} prompts={prompts} "
                    f"prefill_tolerance={prefill_tolerance}. "
                    f"{hf_logprobs=}, {srt_logprobs=}"
                )

            # Compare output (decode) logprobs
            hf_logprobs = torch.Tensor(hf_outputs.top_output_logprobs[i])
            srt_logprobs = torch.Tensor(srt_outputs.top_output_logprobs[i])

            print(
                "decode logprobs max_diff", torch.max(abs(hf_logprobs - srt_logprobs))
            )
            if input_len <= 100:
                assert torch.all(abs(hf_logprobs - srt_logprobs) < decode_tolerance), (
                    f"decode logprobs are not all close with model_path={model_path} prompts={prompts} "
                    f"decode_tolerance={decode_tolerance}. "
                    f"{hf_logprobs=}, {srt_logprobs=}"
                )

        # Compare output strings via ROUGE-L
        print(f"{hf_outputs.output_strs=}")
        print(f"{srt_outputs.output_strs=}")
        rouge_l_scores = calculate_rouge_l(
            hf_outputs.output_strs, srt_outputs.output_strs
        )
        print(f"{rouge_l_scores=}")
        assert all(
            score >= rouge_l_tolerance for score in rouge_l_scores
        ), f"Not all ROUGE-L scores are at least rouge_l_tolerance={rouge_l_tolerance}"

    def test_ci_models(self):
        for model_case in CI_MODELS:
            for torch_dtype in TORCH_DTYPES:
                self.assert_close_logits_and_output_strs(
                    DEFAULT_PROMPTS, model_case, torch_dtype
                )

    def test_others(self):
        # These models are not run in CI.
        if is_in_ci():
            return

        for model_case in ALL_OTHER_MODELS:
            # If ONLY_RUN is set, test only that exact model path.
            if (
                "ONLY_RUN" in os.environ
                and os.environ["ONLY_RUN"] != model_case.model_path
            ):
                continue
            self.assert_close_logits_and_output_strs(
                DEFAULT_PROMPTS, model_case, torch.float16
            )


if __name__ == "__main__":
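    # Use "spawn" because the runners create subprocesses that initialize CUDA,
    # and CUDA cannot be re-initialized in a forked child process.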
    mp.set_start_method("spawn")
    unittest.main()