"""TokenizerManager is a process that tokenizes the text."""

import asyncio
import concurrent.futures
import dataclasses
import logging
import multiprocessing as mp
import os
from typing import Dict, List

import numpy as np
import transformers
import uvloop
import zmq
import zmq.asyncio
from fastapi import BackgroundTasks

from sglang.srt.hf_transformers_utils import (
    get_config,
    get_context_length,
    get_processor,
    get_tokenizer,
)
from sglang.srt.managers.io_struct import (
    AbortReq,
    BatchStrOut,
    FlushCacheReq,
    GenerateReqInput,
    TokenizedGenerateReqInput,
)
from sglang.srt.mm_utils import expand2square, process_anyres_image
from sglang.srt.sampling_params import SamplingParams
from sglang.srt.server_args import PortArgs, ServerArgs
from sglang.srt.utils import is_multimodal_model, load_image
from sglang.utils import get_exception_traceback

asyncio.set_event_loop_policy(uvloop.EventLoopPolicy())

logger = logging.getLogger(__name__)


@dataclasses.dataclass
class ReqState:
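    """State of one in-flight request: collected outputs, a finished flag, and
    an event that wakes coroutines waiting for new output."""
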
    out_list: List
    finished: bool
    event: asyncio.Event


class TokenizerManager:
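    """Tokenizes incoming requests, forwards them to the controller/router over
    ZeroMQ, and routes detokenized outputs back to the awaiting callers."""
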
    def __init__(
        self,
        server_args: ServerArgs,
        port_args: PortArgs,
        model_overide_args: dict = None,
    ):
        self.server_args = server_args

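        # Inter-process wiring: pull detokenized results from the detokenizer,
        # push tokenized requests to the controller (router).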
        context = zmq.asyncio.Context(2)
        self.recv_from_detokenizer = context.socket(zmq.PULL)
        self.recv_from_detokenizer.bind(f"tcp://127.0.0.1:{port_args.tokenizer_port}")

        self.send_to_router = context.socket(zmq.PUSH)
        self.send_to_router.connect(f"tcp://127.0.0.1:{port_args.controller_port}")

        self.model_path = server_args.model_path
        self.hf_config = get_config(
            self.model_path,
            trust_remote_code=server_args.trust_remote_code,
            model_overide_args=model_overide_args,
        )
        if server_args.context_length is not None:
            self.context_len = server_args.context_length
        else:
            self.context_len = get_context_length(self.hf_config)

        if is_multimodal_model(self.model_path):
            self.processor = get_processor(
                server_args.tokenizer_path,
                tokenizer_mode=server_args.tokenizer_mode,
                trust_remote_code=server_args.trust_remote_code,
            )
            self.tokenizer = self.processor.tokenizer
            os.environ["TOKENIZERS_PARALLELISM"] = "false"
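            # Image preprocessing is CPU-bound, so offload it to forked worker
            # processes, each holding its own copy of the HF processor.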
            self.executor = concurrent.futures.ProcessPoolExecutor(
                initializer=init_global_processor,
                mp_context=mp.get_context("fork"),
                initargs=(server_args,),
            )
        else:
            self.tokenizer = get_tokenizer(
                server_args.tokenizer_path,
                tokenizer_mode=server_args.tokenizer_mode,
                trust_remote_code=server_args.trust_remote_code,
            )
            # Text-only model: no processor or worker pool. get_pixel_values
            # checks self.executor before dispatching to it.
            self.processor = None
            self.executor = None

        self.to_create_loop = True
        self.rid_to_state: Dict[str, ReqState] = {}

    async def get_pixel_values(self, image_data):
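        """Convert raw image data into pixel values, offloading the work to the
        process pool when one exists."""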
        aspect_ratio = getattr(self.hf_config, "image_aspect_ratio", None)
        grid_pinpoints = (
            self.hf_config.image_grid_pinpoints if aspect_ratio == "anyres" else None
        )
        if self.executor is not None:
            loop = asyncio.get_event_loop()
            return await loop.run_in_executor(
                self.executor,
                get_pixel_values,
                image_data,
                aspect_ratio,
                grid_pinpoints,
            )
        else:
            return get_pixel_values(
                image_data, aspect_ratio, grid_pinpoints, self.processor
            )

    async def generate_request(self, obj: GenerateReqInput, request=None):
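        """Tokenize obj, dispatch it, and yield output dicts as they arrive
        (several per request when streaming)."""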
        if self.to_create_loop:
            self.create_handle_loop()

        obj.post_init()
        is_single = obj.is_single

        if is_single:
            async for response in self._handle_single_request(obj, request):
                yield response
        else:
            if obj.stream:
                raise ValueError("Streaming is not supported for batch mode.")

            async for response in self._handle_batch_request(obj, request):
                yield response

    async def _handle_single_request(
        self, obj, request, index=None, is_cache_for_prefill=False
    ):
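        """Tokenize and dispatch a single request. When is_cache_for_prefill is
        set, a zero-generation request is sent only to warm the prefix cache,
        and the computed input_ids are yielded back to the batch handler."""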
        if not is_cache_for_prefill:
            rid = obj.rid if index is None else obj.rid[index]
            input_text = obj.text if index is None else obj.text[index]
            if obj.input_ids is None:
                input_ids = self.tokenizer.encode(input_text)
            else:
                input_ids = obj.input_ids if index is None else obj.input_ids[index]

            self._validate_input_length(input_ids)
            sampling_params = self._get_sampling_params(
                obj.sampling_params if index is None else obj.sampling_params[index]
            )
            pixel_values, image_hash, image_size = await self._get_pixel_values(
                obj.image_data if index is None else obj.image_data[index]
            )
            return_logprob = (
                obj.return_logprob if index is None else obj.return_logprob[index]
            )
            logprob_start_len = (
                obj.logprob_start_len if index is None else obj.logprob_start_len[index]
            )
            top_logprobs_num = (
                obj.top_logprobs_num if index is None else obj.top_logprobs_num[index]
            )
        else:
            if isinstance(obj.text, list):
                input_text = obj.text[index]
                rid = obj.rid[index]
            else:
                input_text = obj.text
                rid = obj.rid[0]
            input_ids = self.tokenizer.encode(input_text)
            sampling_params = SamplingParams(**obj.sampling_params[0])
            sampling_params.max_new_tokens = 0
            pixel_values, image_hash, image_size = await self._get_pixel_values(
                obj.image_data[0]
            )
            return_logprob = obj.return_logprob[0]
            logprob_start_len = obj.logprob_start_len[0]
            top_logprobs_num = obj.top_logprobs_num[0]

        tokenized_obj = TokenizedGenerateReqInput(
            rid,
            input_text,
            input_ids,
            pixel_values,
            image_hash,
            image_size,
            sampling_params,
            return_logprob,
            logprob_start_len,
            top_logprobs_num,
            obj.stream,
        )
        self.send_to_router.send_pyobj(tokenized_obj)

        event = asyncio.Event()
        state = ReqState([], False, event)
        self.rid_to_state[rid] = state
        if not is_cache_for_prefill:
            async for response in self._wait_for_response(
                event, state, obj, rid, request
            ):
                yield response
        else:
            await self._wait_for_cache_prefill_response(event, state, obj, rid, request)
            yield input_ids

    async def _handle_batch_request(self, obj: GenerateReqInput, request):
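        """Dispatch a whole batch. For parallel sampling (n > 1), prefill
        requests are sent first so the shared prefix is cached, then the real
        sampling requests are sent and awaited."""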
        batch_size = obj.batch_size
        parallel_sample_num = obj.sampling_params[0].get("n", 1)

        if parallel_sample_num != 1:
            # Send prefill requests to cache the common input
            parallel_sample_num += 1
            input_id_result = [] if obj.input_ids is None else None
            for i in range(batch_size):
                async for input_id in self._handle_single_request(
                    obj, request, index=i, is_cache_for_prefill=True
                ):
                    if input_id_result is not None:
                        input_id_result.append(input_id)
            if input_id_result is not None:
                obj.input_ids = (
                    input_id_result if len(input_id_result) > 1 else input_id_result[0]
                )

        # First send out all requests
        for i in range(batch_size):
            for j in range(parallel_sample_num):
                if j == 0 and parallel_sample_num != 1:
                    continue
                index = i * parallel_sample_num + j
                if parallel_sample_num != 1:
                    # With parallel sampling, the first batch_size rids were
                    # consumed by the prefill requests, so the effective index
                    # is j + i * (parallel_sample_num - 1) + batch_size - 1.
                    index += batch_size - 1 - i
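                    # e.g. batch_size=2, n=2 (parallel_sample_num=3 after the
                    # prefill bump): rids 0-1 belong to the prefill requests
                    # and the sampled requests map to rids 2, 3, 4, 5.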
                rid = obj.rid[index]
                if parallel_sample_num == 1:
                    # Non-parallel sampling (e.g., a select operation)
                    if obj.input_ids is None:
                        input_text = obj.text[i]
                        input_ids = self.tokenizer.encode(obj.text[i])
                    else:
                        input_text = None
                        input_ids = obj.input_ids[i]
                else:
                    if batch_size == 1:
                        input_text = obj.text
                        input_ids = obj.input_ids
                    else:
                        input_text = obj.text[i]
                        input_ids = obj.input_ids[i]
                sampling_params = self._get_sampling_params(obj.sampling_params[index])
                pixel_values, image_hash, image_size = await self._get_pixel_values(
                    obj.image_data[index]
                )

                tokenized_obj = TokenizedGenerateReqInput(
                    rid,
                    input_text,
                    input_ids,
                    pixel_values,
                    image_hash,
                    image_size,
                    sampling_params,
                    obj.return_logprob[index],
                    obj.logprob_start_len[index],
                    obj.top_logprobs_num[index],
                    obj.stream,
                )
                self.send_to_router.send_pyobj(tokenized_obj)

                event = asyncio.Event()
                state = ReqState([], False, event)
                self.rid_to_state[rid] = state

        # Then wait for all responses
        output_list = []
        for i in range(batch_size):
            for j in range(parallel_sample_num):
                if j == 0 and parallel_sample_num != 1:
                    continue
                index = i * parallel_sample_num + j
                if parallel_sample_num != 1:
                    index += batch_size - 1 - i
                rid = obj.rid[index]
                state = self.rid_to_state[rid]

                while True:
                    try:
                        await asyncio.wait_for(state.event.wait(), timeout=4)
                        break
                    except asyncio.TimeoutError:
                        if request is not None and await request.is_disconnected():
                            for rid in obj.rid:
                                self.abort_request(rid)
                            raise ValueError(f"Abort request {rid}")
                        continue
                output_list.append(
                    self.convert_logprob_style(
                        state.out_list[-1],
                        obj.return_logprob[index],
                        obj.top_logprobs_num[index],
                        obj.return_text_in_logprobs,
                    )
                )
                assert state.finished
                del self.rid_to_state[rid]

        yield output_list

    def _validate_input_length(self, input_ids: List[int]):
        if len(input_ids) >= self.context_len:
            raise ValueError(
                f"The input ({len(input_ids)} tokens) is longer than the "
                f"model's context length ({self.context_len} tokens)."
            )

    def _get_sampling_params(self, sampling_params_data: dict):
        sampling_params = SamplingParams(**sampling_params_data)
        if sampling_params.max_new_tokens != 0:
            sampling_params.normalize(self.tokenizer)
            sampling_params.verify()
        return sampling_params

    async def _get_pixel_values(self, image_data):
        if isinstance(image_data, list) and len(image_data) > 0:
            return await self.get_pixel_values(image_data[0])
        elif isinstance(image_data, str):
            return await self.get_pixel_values(image_data)
        else:
            return None, None, None

    async def _wait_for_response(
        self,
        event: asyncio.Event,
        state: ReqState,
        obj: GenerateReqInput,
        rid: str,
        request,
    ):
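        # Poll with a short timeout so client disconnects are noticed even
        # while the request is still queued or running.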
        while True:
            try:
                await asyncio.wait_for(event.wait(), timeout=4)
            except asyncio.TimeoutError:
                if request is not None and await request.is_disconnected():
                    self.abort_request(rid)
                    raise ValueError(f"Abort request {rid}")
                continue

            out = self.convert_logprob_style(
                state.out_list[-1],
                obj.return_logprob,
                obj.top_logprobs_num,
                obj.return_text_in_logprobs,
            )

            if self.server_args.log_requests and state.finished:
                logger.info(f"in={obj.text}, out={out}")

            state.out_list = []
            if state.finished:
                del self.rid_to_state[rid]
                yield out
                break

            event.clear()
            yield out

    async def _wait_for_cache_prefill_response(
        self,
        event: asyncio.Event,
        state: ReqState,
        obj: GenerateReqInput,
        rid: str,
        request,
    ):
        while True:
            try:
                await asyncio.wait_for(state.event.wait(), timeout=4)
                break
            except asyncio.TimeoutError:
                if request is not None and await request.is_disconnected():
                    for rid in obj.rid:
                        self.abort_request(rid)
                    raise ValueError(f"Abort request {rid}")
                continue

        assert state.finished
        del self.rid_to_state[rid]

    def flush_cache(self):
        req = FlushCacheReq()
        self.send_to_router.send_pyobj(req)

    def abort_request(self, rid: str):
        if rid not in self.rid_to_state:
            return
        del self.rid_to_state[rid]
        req = AbortReq(rid)
        self.send_to_router.send_pyobj(req)

    def create_abort_task(self, obj: GenerateReqInput):
        # Abort the request if the client is disconnected.
        async def abort_request():
            await asyncio.sleep(3)
            if obj.is_single:
                self.abort_request(obj.rid)
            else:
                for rid in obj.rids:
                    self.abort_request(rid)

        background_tasks = BackgroundTasks()
        background_tasks.add_task(abort_request)
        return background_tasks

    def create_handle_loop(self):
        self.to_create_loop = False
        loop = asyncio.get_event_loop()
        loop.create_task(self.handle_loop())

    async def handle_loop(self):
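        """Background task: receive detokenized batches and fan the outputs out
        to the per-request states."""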
        while True:
            recv_obj: BatchStrOut = await self.recv_from_detokenizer.recv_pyobj()
            assert isinstance(recv_obj, BatchStrOut)

            for i, rid in enumerate(recv_obj.rids):
                state = self.rid_to_state.get(rid, None)
                if state is None:
                    continue

                recv_obj.meta_info[i]["id"] = rid
                out_dict = {
                    "text": recv_obj.output_strs[i],
                    "meta_info": recv_obj.meta_info[i],
                }
                state.out_list.append(out_dict)
                state.finished = recv_obj.finished_reason[i] is not None
                state.event.set()

    def convert_logprob_style(
        self,
        ret: dict,
        return_logprob: bool,
        top_logprobs_num: int,
        return_text_in_logprobs: bool,
    ):
        if return_logprob:
            ret["meta_info"]["input_token_logprobs"] = self.detokenize_logprob_tokens(
                ret["meta_info"]["input_token_logprobs"], return_text_in_logprobs
            )
            ret["meta_info"]["output_token_logprobs"] = self.detokenize_logprob_tokens(
                ret["meta_info"]["output_token_logprobs"], return_text_in_logprobs
            )

            if top_logprobs_num > 0:
                ret["meta_info"]["input_top_logprobs"] = (
                    self.detokenize_top_logprobs_tokens(
                        ret["meta_info"]["input_top_logprobs"],
                        return_text_in_logprobs,
                    )
                )
                ret["meta_info"]["output_top_logprobs"] = (
                    self.detokenize_top_logprobs_tokens(
                        ret["meta_info"]["output_top_logprobs"], return_text_in_logprobs
                    )
                )
        return ret

    def detokenize_logprob_tokens(self, token_logprobs, decode_to_text: bool):
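        # token_logprobs is a sequence of (logprob, token_id) pairs; each entry
        # becomes (logprob, token_id, text_or_None).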
        if not decode_to_text:
            return [(logprob, token_id, None) for logprob, token_id in token_logprobs]

        token_ids = [tid for _, tid in token_logprobs]
        token_texts = self.tokenizer.batch_decode(token_ids)
        return [
            (logprob, token_id, token_text)
            for (logprob, token_id), token_text in zip(token_logprobs, token_texts)
        ]

    def detokenize_top_logprobs_tokens(self, top_logprobs, decode_to_text: bool):
        for i, t in enumerate(top_logprobs):
            if t:
                top_logprobs[i] = self.detokenize_logprob_tokens(t, decode_to_text)
        return top_logprobs


global global_processor


def init_global_processor(server_args: ServerArgs):
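    """Process-pool initializer: build one Hugging Face processor per worker."""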
    global global_processor
    transformers.logging.set_verbosity_error()
    global_processor = get_processor(
        server_args.tokenizer_path,
        tokenizer_mode=server_args.tokenizer_mode,
        trust_remote_code=server_args.trust_remote_code,
    )


def get_pixel_values(
    image_data, image_aspect_ratio=None, image_grid_pinpoints=None, processor=None
):
    try:
        processor = processor or global_processor
        image, image_size = load_image(image_data)
        if image_size is not None:
            image_hash = hash(image_data)
            pixel_values = processor.image_processor(image)["pixel_values"]
            pixel_values = [v.astype(np.float16) for v in pixel_values]
            pixel_values = np.stack(pixel_values, axis=0)
            return pixel_values, image_hash, image_size
        else:
            image_hash = hash(image_data)
            if image_aspect_ratio == "pad":
                image = expand2square(
                    image,
                    tuple(int(x * 255) for x in processor.image_processor.image_mean),
                )
                pixel_values = processor.image_processor(image)["pixel_values"][0]
            elif image_aspect_ratio == "anyres":
                pixel_values = process_anyres_image(
                    image, processor.image_processor, image_grid_pinpoints
                )
            else:
                pixel_values = processor.image_processor(image)["pixel_values"][0]
            pixel_values = pixel_values.astype(np.float16)
            return pixel_values, image_hash, image.size
    except Exception:
        print("Exception in TokenizerManager:\n" + get_exception_traceback())
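

# Rough usage sketch (hypothetical values; the actual wiring lives in
# sglang.srt.server, which starts this manager alongside the router and
# detokenizer processes):
#
#   server_args = ServerArgs(model_path="meta-llama/Llama-2-7b-chat-hf")
#   port_args = PortArgs(...)  # ports must match the peer processes
#   manager = TokenizerManager(server_args, port_args)
#   async for out in manager.generate_request(
#       GenerateReqInput(text="Hello", sampling_params={"max_new_tokens": 8})
#   ):
#       print(out["text"])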