"""
python3 -m unittest test_skip_tokenizer_init.TestSkipTokenizerInit.test_parallel_sample
python3 -m unittest test_skip_tokenizer_init.TestSkipTokenizerInit.run_decode_stream
"""

import json
import unittest
from io import BytesIO

import requests
from PIL import Image
from transformers import AutoProcessor, AutoTokenizer

from sglang.lang.chat_template import get_chat_template_by_model_path
from sglang.srt.utils import kill_process_tree
from sglang.test.test_utils import (
    DEFAULT_IMAGE_URL,
    DEFAULT_SMALL_MODEL_NAME_FOR_TEST,
    DEFAULT_SMALL_VLM_MODEL_NAME_FOR_TEST,
    DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
    DEFAULT_URL_FOR_TEST,
    CustomTestCase,
    popen_launch_server,
)

class TestSkipTokenizerInit(CustomTestCase):
    """End-to-end tests for a server launched with ``--skip-tokenizer-init``.

    With tokenizer init skipped, the server exchanges raw token ids
    (``input_ids`` in, ``output_ids`` out), so this test tokenizes and
    detokenizes on the client side with a HuggingFace tokenizer.
    """

    @classmethod
    def setUpClass(cls):
        # Launch a small text model; --stream-output enables the streaming path
        # exercised by run_decode_stream.
        cls.model = DEFAULT_SMALL_MODEL_NAME_FOR_TEST
        cls.base_url = DEFAULT_URL_FOR_TEST
        cls.process = popen_launch_server(
            cls.model,
            cls.base_url,
            timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
            other_args=["--skip-tokenizer-init", "--stream-output"],
        )
        # NOTE(review): hard-coded stop id used by run_decode_stream, while
        # run_decode stops on self.tokenizer.eos_token_id — confirm whether
        # these are meant to agree for the default small model.
        cls.eos_token_id = [119690]
        cls.tokenizer = AutoTokenizer.from_pretrained(
            DEFAULT_SMALL_MODEL_NAME_FOR_TEST, use_fast=False
        )

    @classmethod
    def tearDownClass(cls):
        # The server runs in a subprocess tree; kill all of it.
        kill_process_tree(cls.process.pid)

    def run_decode(
        self,
        prompt_text="The capital of France is",
        max_new_tokens=32,
        return_logprob=False,
        top_logprobs_num=0,
        n=1,
    ):
        """Send one non-streaming /generate request using raw input_ids and
        validate the response item(s).

        Args:
            prompt_text: Prompt tokenized client-side via get_input_ids.
            max_new_tokens: Generation cap; also the expected output length
                when the finish reason is "length".
            return_logprob: Whether to request and validate logprobs.
            top_logprobs_num: Number of top logprobs to request.
            n: Number of parallel samples; greedy (temperature 0) when n == 1.
        """
        input_ids = self.get_input_ids(prompt_text)

        response = requests.post(
            self.base_url + "/generate",
            json={
                "input_ids": input_ids,
                "sampling_params": {
                    # Greedy decoding for a single sample so parallel samples
                    # (n > 1) can differ from each other.
                    "temperature": 0 if n == 1 else 0.5,
                    "max_new_tokens": max_new_tokens,
                    "n": n,
                    "stop_token_ids": [self.tokenizer.eos_token_id],
                },
                "stream": False,
                "return_logprob": return_logprob,
                "top_logprobs_num": top_logprobs_num,
                "logprob_start_len": 0,
            },
        )
        ret = response.json()
        print(json.dumps(ret, indent=2))

        def assert_one_item(item):
            # Validate a single response item (one sample).
            if item["meta_info"]["finish_reason"]["type"] == "stop":
                self.assertEqual(
                    item["meta_info"]["finish_reason"]["matched"],
                    self.tokenizer.eos_token_id,
                )
            elif item["meta_info"]["finish_reason"]["type"] == "length":
                self.assertEqual(
                    len(item["output_ids"]), item["meta_info"]["completion_tokens"]
                )
                self.assertEqual(len(item["output_ids"]), max_new_tokens)

            # The following invariants hold for every finish reason, so check
            # them unconditionally. (Bug fix: they were previously nested
            # inside the "length" branch and silently skipped on "stop".)
            self.assertEqual(item["meta_info"]["prompt_tokens"], len(input_ids))

            if return_logprob:
                self.assertEqual(
                    len(item["meta_info"]["input_token_logprobs"]),
                    len(input_ids),
                    f'{len(item["meta_info"]["input_token_logprobs"])} mismatch with {len(input_ids)}',
                )
                self.assertEqual(
                    len(item["meta_info"]["output_token_logprobs"]),
                    max_new_tokens,
                )

        # Determine whether to assert a single item or multiple items based on n
        if n == 1:
            assert_one_item(ret)
        else:
            self.assertEqual(len(ret), n)
            for i in range(n):
                assert_one_item(ret[i])

        print("=" * 100)

    def run_decode_stream(self, return_logprob=False, top_logprobs_num=0, n=1):
        """Check that streaming and non-streaming requests with identical
        (greedy) sampling produce identical output token ids.

        The cache is flushed before each request so both runs start from the
        same server state.
        """
        max_new_tokens = 32
        input_ids = self.get_input_ids("The capital of France is")

        # Reference run: non-streaming.
        requests.post(self.base_url + "/flush_cache")
        response = requests.post(
            self.base_url + "/generate",
            json={
                "input_ids": input_ids,
                "sampling_params": {
                    "temperature": 0 if n == 1 else 0.5,
                    "max_new_tokens": max_new_tokens,
                    "n": n,
                    "stop_token_ids": self.eos_token_id,
                },
                "stream": False,
                "return_logprob": return_logprob,
                "top_logprobs_num": top_logprobs_num,
                "logprob_start_len": 0,
            },
        )
        ret = response.json()
        print(json.dumps(ret))
        output_ids = ret["output_ids"]
        print("output from non-streaming request:")
        print(output_ids)
        print(self.tokenizer.decode(output_ids, skip_special_tokens=True))

        # Streaming run with the same parameters.
        requests.post(self.base_url + "/flush_cache")
        response_stream = requests.post(
            self.base_url + "/generate",
            json={
                "input_ids": input_ids,
                "sampling_params": {
                    "temperature": 0 if n == 1 else 0.5,
                    "max_new_tokens": max_new_tokens,
                    "n": n,
                    "stop_token_ids": self.eos_token_id,
                },
                "stream": True,
                "return_logprob": return_logprob,
                "top_logprobs_num": top_logprobs_num,
                "logprob_start_len": 0,
            },
        )

        # Parse the SSE stream: each payload line is b"data: <json>",
        # terminated by b"data: [DONE]".
        response_stream_json = []
        for line in response_stream.iter_lines():
            print(line)
            if line.startswith(b"data: ") and line[6:] != b"[DONE]":
                response_stream_json.append(json.loads(line[6:]))
        out_stream_ids = []
        for x in response_stream_json:
            out_stream_ids += x["output_ids"]
        print("output from streaming request:")
        print(out_stream_ids)
        print(self.tokenizer.decode(out_stream_ids, skip_special_tokens=True))

        assert output_ids == out_stream_ids

    def test_simple_decode(self):
        self.run_decode()

    def test_parallel_sample(self):
        self.run_decode(n=3)

    def test_logprob(self):
        for top_logprobs_num in [0, 3]:
            self.run_decode(return_logprob=True, top_logprobs_num=top_logprobs_num)

    def test_eos_behavior(self):
        # Large budget so generation should end with an EOS "stop" reason.
        self.run_decode(max_new_tokens=256)

    def test_simple_decode_stream(self):
        self.run_decode_stream()

    def get_input_ids(self, prompt_text) -> list[int]:
        """Tokenize prompt_text client-side and return a flat list of ids."""
        input_ids = self.tokenizer(prompt_text, return_tensors="pt")["input_ids"][
            0
        ].tolist()
        return input_ids


class TestSkipTokenizerInitVLM(TestSkipTokenizerInit):
    """Vision-language variant: reuses the parent's decode tests but builds
    input ids from an image + text prompt via the model's processor."""

    @classmethod
    def setUpClass(cls):
        # Fetch the test image once for the whole class.
        cls.image_url = DEFAULT_IMAGE_URL
        resp = requests.get(cls.image_url)
        cls.image = Image.open(BytesIO(resp.content))

        # Client-side tokenizer/processor for the small VLM under test.
        cls.model = DEFAULT_SMALL_VLM_MODEL_NAME_FOR_TEST
        cls.tokenizer = AutoTokenizer.from_pretrained(cls.model, use_fast=False)
        cls.processor = AutoProcessor.from_pretrained(cls.model, trust_remote_code=True)

        # Launch the server without tokenizer init; it exchanges raw ids.
        cls.base_url = DEFAULT_URL_FOR_TEST
        cls.process = popen_launch_server(
            cls.model,
            cls.base_url,
            timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
            other_args=["--skip-tokenizer-init"],
        )
        cls.eos_token_id = [cls.tokenizer.eos_token_id]

    def get_input_ids(self, _prompt_text) -> list[int]:
        """Build input ids for a fixed image question; the text prompt
        argument is ignored in the VLM case."""
        template = get_chat_template_by_model_path(self.model)
        question = f"{template.image_token}What is in this picture?"
        encoded = self.processor(
            text=[question],
            images=[self.image],
            return_tensors="pt",
        )
        return encoded.input_ids[0].tolist()

    def test_simple_decode_stream(self):
        # TODO mick
        pass

if __name__ == "__main__":
    unittest.main()