"""
Copyright 2023-2024 SGLang Team
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""

"""TokenizerManager is a process that tokenizes the text."""

import asyncio
import concurrent.futures
import dataclasses
import logging
import multiprocessing as mp
import os
from typing import Dict, List, Tuple, Union

import numpy as np
import transformers
import uvloop
import zmq
import zmq.asyncio
from fastapi import BackgroundTasks

from sglang.srt.hf_transformers_utils import (
    get_config,
    get_context_length,
    get_processor,
    get_tokenizer,
)
from sglang.srt.managers.io_struct import (
    AbortReq,
    BatchEmbeddingOut,
    BatchStrOut,
    BatchTokenIDOut,
    EmbeddingReqInput,
    FlushCacheReq,
    GenerateReqInput,
    TokenizedEmbeddingReqInput,
    TokenizedGenerateReqInput,
)
from sglang.srt.mm_utils import expand2square, process_anyres_image
from sglang.srt.sampling_params import SamplingParams
from sglang.srt.server_args import PortArgs, ServerArgs
from sglang.srt.utils import is_generation_model, is_multimodal_model, load_image
from sglang.utils import get_exception_traceback

asyncio.set_event_loop_policy(uvloop.EventLoopPolicy())

logger = logging.getLogger(__name__)


@dataclasses.dataclass
class ReqState:
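    """Hold the streamed outputs and completion status of one in-flight request."""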
    out_list: List
    finished: bool
    event: asyncio.Event


class TokenizerManager:
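    """Tokenize requests, forward them to the router, and relay detokenized
    outputs back to the waiting callers."""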
    def __init__(
        self,
        server_args: ServerArgs,
        port_args: PortArgs,
        model_overide_args: dict = None,
    ):
        self.server_args = server_args

        context = zmq.asyncio.Context(2)
        self.recv_from_detokenizer = context.socket(zmq.PULL)
        self.recv_from_detokenizer.bind(f"tcp://127.0.0.1:{port_args.tokenizer_port}")

        self.send_to_router = context.socket(zmq.PUSH)
        self.send_to_router.connect(f"tcp://127.0.0.1:{port_args.controller_port}")

        self.model_path = server_args.model_path
        self.served_model_name = server_args.served_model_name
        self.hf_config = get_config(
            self.model_path,
            trust_remote_code=server_args.trust_remote_code,
            model_overide_args=model_overide_args,
        )
        self.is_generation = is_generation_model(self.hf_config.architectures)

        if server_args.context_length is not None:
            self.context_len = server_args.context_length
        else:
            self.context_len = get_context_length(self.hf_config)

        # Text-only models never create the multimodal process pool below, so
        # default the executor to None for get_pixel_values' fallback check.
        self.executor = None
        if server_args.skip_tokenizer_init:
            self.tokenizer = self.processor = None
        else:
            if is_multimodal_model(self.model_path):
                self.processor = get_processor(
                    server_args.tokenizer_path,
                    tokenizer_mode=server_args.tokenizer_mode,
                    trust_remote_code=server_args.trust_remote_code,
                )
                self.tokenizer = self.processor.tokenizer
                os.environ["TOKENIZERS_PARALLELISM"] = "false"
                self.executor = concurrent.futures.ProcessPoolExecutor(
                    initializer=init_global_processor,
                    mp_context=mp.get_context("fork"),
                    initargs=(server_args,),
                )
            else:
                self.tokenizer = get_tokenizer(
                    server_args.tokenizer_path,
                    tokenizer_mode=server_args.tokenizer_mode,
                    trust_remote_code=server_args.trust_remote_code,
                )

        self.to_create_loop = True
        self.rid_to_state: Dict[str, ReqState] = {}

    async def get_pixel_values(self, image_data):
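        """Preprocess image data, offloading to the process pool when one exists."""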
        aspect_ratio = getattr(self.hf_config, "image_aspect_ratio", None)
        grid_pinpoints = (
            self.hf_config.image_grid_pinpoints if aspect_ratio == "anyres" else None
        )
        if self.executor is not None:
            loop = asyncio.get_event_loop()
            return await loop.run_in_executor(
                self.executor,
                get_pixel_values,
                image_data,
                aspect_ratio,
                grid_pinpoints,
            )
        else:
            return get_pixel_values(
                image_data, aspect_ratio, grid_pinpoints, self.processor
            )

    async def generate_request(
        self, obj: Union[GenerateReqInput, EmbeddingReqInput], request=None
    ):
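        """Dispatch a generation or embedding request to the single or batch handler."""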
        if self.to_create_loop:
            self.create_handle_loop()

        obj.post_init()
        is_single = obj.is_single

        if is_single:
            async for response in self._handle_single_request(obj, request):
                yield response
        else:
            if hasattr(obj, "stream") and obj.stream:
                raise ValueError("Streaming is not supported for batch mode.")

            async for response in self._handle_batch_request(obj, request):
                yield response

    async def _handle_single_request(
        self,
        obj: Union[GenerateReqInput, EmbeddingReqInput],
        request,
        index=None,
        is_cache_for_prefill=False,
    ):
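        """Tokenize one prompt and send it to the router.

        With is_cache_for_prefill, send a prefill-only request (max_new_tokens=0)
        so parallel sampling can reuse the cached common prefix.
        """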
        if not is_cache_for_prefill:  # The normal case with a single prompt
            not_use_index = index is None

            rid = obj.rid if not_use_index else obj.rid[index]
            input_text = obj.text if not_use_index else obj.text[index]
            if obj.input_ids is None:
                assert self.tokenizer is not None
                input_ids = self.tokenizer.encode(input_text)
            else:
                input_ids = obj.input_ids if not_use_index else obj.input_ids[index]

            self._validate_input_length(input_ids)

            sampling_params = self._get_sampling_params(
                obj.sampling_params if not_use_index else obj.sampling_params[index]
            )

            if self.is_generation:
                pixel_values, image_hash, image_size = await self._get_pixel_values(
                    obj.image_data if not_use_index else obj.image_data[index]
                )
                return_logprob = (
                    obj.return_logprob if not_use_index else obj.return_logprob[index]
                )
                logprob_start_len = (
                    obj.logprob_start_len
                    if not_use_index
                    else obj.logprob_start_len[index]
                )
                if return_logprob and logprob_start_len == -1:
                    logprob_start_len = len(input_ids) - 1

                top_logprobs_num = (
                    obj.top_logprobs_num
                    if not_use_index
                    else obj.top_logprobs_num[index]
                )
        else:  # A prefill request to cache the common prompt for parallel sampling
            assert self.is_generation
            if obj.text is not None:
                if isinstance(obj.text, list):
                    input_text = obj.text[index]
                    rid = obj.rid[index]
                else:
                    input_text = obj.text
                    rid = obj.rid[0]
                if self.tokenizer is not None:
                    input_ids = self.tokenizer.encode(input_text)
                else:
                    assert obj.input_ids is not None
                    input_ids = obj.input_ids
                    if isinstance(obj.input_ids, list) and isinstance(
                        obj.input_ids[0], list
                    ):
                        # when obj["input_ids"] is List[List[int]]
                        input_ids = obj.input_ids[index]
                        rid = obj.rid[index]
                    else:
                        input_ids = obj.input_ids
                        rid = obj.rid[0]
            else:
                input_text = None
                if isinstance(obj.input_ids, list) and isinstance(
                    obj.input_ids[0], list
                ):
                    # when obj["input_ids"] is List[List[int]]
                    input_ids = obj.input_ids[index]
                    rid = obj.rid[index]
                else:
                    input_ids = obj.input_ids
                    rid = obj.rid[0]

            sampling_params = SamplingParams(**obj.sampling_params[0])
            sampling_params.max_new_tokens = 0
            pixel_values, image_hash, image_size = await self._get_pixel_values(
                obj.image_data[0]
            )
            return_logprob = obj.return_logprob[0]
            logprob_start_len = obj.logprob_start_len[0]
            top_logprobs_num = obj.top_logprobs_num[0]

        if self.is_generation:
            if return_logprob and logprob_start_len == -1:
                logprob_start_len = len(input_ids) - 1
            tokenized_obj = TokenizedGenerateReqInput(
                rid,
                input_text,
                input_ids,
                pixel_values,
                image_hash,
                image_size,
                sampling_params,
                return_logprob,
                logprob_start_len,
                top_logprobs_num,
                obj.stream,
            )
        else:  # is embedding
            tokenized_obj = TokenizedEmbeddingReqInput(
                rid,
                input_text,
                input_ids,
                sampling_params,
            )

        self.send_to_router.send_pyobj(tokenized_obj)

        event = asyncio.Event()
        state = ReqState([], False, event)
        self.rid_to_state[rid] = state
        if not is_cache_for_prefill:
            async for response in self._wait_for_response(
                event, state, obj, rid, request
            ):
                yield response
        else:
            assert self.is_generation
            await self._wait_for_cache_prefill_response(event, state, obj, rid, request)
            yield input_ids

    async def _handle_batch_request(
        self, obj: Union[GenerateReqInput, EmbeddingReqInput], request
    ):
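        """Tokenize a batch of prompts, send all requests, then collect all outputs."""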
        batch_size = obj.batch_size
        if self.is_generation:
            parallel_sample_num = obj.parallel_sample_num

            if parallel_sample_num != 1:
                # Send prefill requests to cache the common input
                parallel_sample_num += 1
                input_id_result = [] if obj.input_ids is None else None
                for i in range(batch_size):
                    async for input_id in self._handle_single_request(
                        obj, request, index=i, is_cache_for_prefill=True
                    ):
                        if input_id_result is not None:
                            input_id_result.append(input_id)
                if input_id_result is not None and len(input_id_result) > 1:
                    obj.input_ids = input_id_result
                elif input_id_result is not None:
                    obj.input_ids = input_id_result[0]
        else:
            parallel_sample_num = 1

        # First send out all requests
        for i in range(batch_size):
            for j in range(parallel_sample_num):
                if j == 0 and parallel_sample_num != 1:
                    continue
                index = i * parallel_sample_num + j
                if parallel_sample_num != 1:
                    # The prefill requests above consumed the first rids, so with
                    # parallel sampling the effective index is
                    # j + i * (parallel_sample_num - 1) + batch_size - 1.
                    index += batch_size - 1 - i
                rid = obj.rid[index]
                if parallel_sample_num == 1:
                    # No parallel sampling: take the i-th prompt directly.
                    if obj.input_ids is None:
                        input_text = obj.text[i]
                        input_ids = self.tokenizer.encode(obj.text[i])
                    else:
                        input_text = None
                        input_ids = obj.input_ids[i]
                else:
                    assert obj.input_ids is not None
                    if batch_size == 1:
                        input_text = None
                        input_ids = obj.input_ids
                    else:
                        input_text = None
                        input_ids = obj.input_ids[i]
                sampling_params = self._get_sampling_params(obj.sampling_params[index])

                if self.is_generation:
                    if obj.return_logprob[index] and obj.logprob_start_len[index] == -1:
                        obj.logprob_start_len[index] = len(input_ids) - 1
                    pixel_values, image_hash, image_size = await self._get_pixel_values(
                        obj.image_data[index]
                    )

                    tokenized_obj = TokenizedGenerateReqInput(
                        rid,
                        input_text,
                        input_ids,
                        pixel_values,
                        image_hash,
                        image_size,
                        sampling_params,
                        obj.return_logprob[index],
                        obj.logprob_start_len[index],
                        obj.top_logprobs_num[index],
                        obj.stream,
                    )
                else:
                    tokenized_obj = TokenizedEmbeddingReqInput(
                        rid,
                        input_text,
                        input_ids,
                        sampling_params,
                    )
                self.send_to_router.send_pyobj(tokenized_obj)

                event = asyncio.Event()
                state = ReqState([], False, event)
                self.rid_to_state[rid] = state

        # Then wait for all responses
        output_list = []
        for i in range(batch_size):
            for j in range(parallel_sample_num):
                if j == 0 and parallel_sample_num != 1:
                    continue
                index = i * parallel_sample_num + j
                if parallel_sample_num != 1:
                    index += batch_size - 1 - i
                rid = obj.rid[index]
                state = self.rid_to_state[rid]

                while True:
                    try:
                        await asyncio.wait_for(state.event.wait(), timeout=4)
                        break
                    except asyncio.TimeoutError:
                        if request is not None and await request.is_disconnected():
                            for rid in obj.rid:
                                self.abort_request(rid)
                            raise ValueError(f"Abort request {rid}")
                        continue
                if self.is_generation:
                    output_list.append(
                        self.convert_logprob_style(
                            state.out_list[-1],
                            obj.return_logprob[index],
                            obj.top_logprobs_num[index],
                            obj.return_text_in_logprobs,
                        )
                    )
                else:
                    output_list.append(state.out_list[-1])
                assert state.finished
                del self.rid_to_state[rid]
        yield output_list

    def _validate_input_length(self, input_ids: List[int]):
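        """Raise if the prompt does not fit in the model's context window."""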
        if len(input_ids) >= self.context_len:
            raise ValueError(
                f"The input ({len(input_ids)} tokens) is longer than the "
                f"model's context length ({self.context_len} tokens)."
            )

    def _get_sampling_params(self, sampling_params_data: dict):
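        """Construct SamplingParams; normalize and verify them unless max_new_tokens is 0."""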
        sampling_params = SamplingParams(**sampling_params_data)
        if sampling_params.max_new_tokens != 0:
            sampling_params.normalize(self.tokenizer)
            sampling_params.verify()
        return sampling_params

    async def _get_pixel_values(self, image_data):
        if isinstance(image_data, list) and len(image_data) > 0:
            return await self.get_pixel_values(image_data[0])
        elif isinstance(image_data, str):
            return await self.get_pixel_values(image_data)
        else:
            return None, None, None

    async def _wait_for_response(
        self,
        event: asyncio.Event,
        state: ReqState,
        obj: Union[GenerateReqInput, EmbeddingReqInput],
        rid: str,
        request,
    ):
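        """Stream outputs for one request until it finishes or the client disconnects."""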
        while True:
            try:
                await asyncio.wait_for(event.wait(), timeout=4)
            except asyncio.TimeoutError:
                if request is not None and await request.is_disconnected():
                    self.abort_request(rid)
                    raise ValueError(f"Abort request {rid}")
                continue

            if self.is_generation:
                out = self.convert_logprob_style(
                    state.out_list[-1],
                    obj.return_logprob,
                    obj.top_logprobs_num,
                    obj.return_text_in_logprobs,
                )
            else:  # isinstance(obj, EmbeddingReqInput)
                out = state.out_list[-1]

            # Log requests
            if self.server_args.log_requests and state.finished:
                if obj.text is None:
                    in_obj = {"input_ids": obj.input_ids}
                else:
                    in_obj = {"text": obj.text}
                logger.info(f"in={in_obj}, out={out}")

            state.out_list = []
            if state.finished:
                del self.rid_to_state[rid]
                yield out
                break

            event.clear()
            yield out

    async def _wait_for_cache_prefill_response(
        self,
        event: asyncio.Event,
        state: ReqState,
        obj: GenerateReqInput,
        rid: str,
        request,
    ):
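        """Wait until a cache-warming prefill request has finished."""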
        while True:
            try:
                await asyncio.wait_for(state.event.wait(), timeout=4)
                break
            except asyncio.TimeoutError:
                if request is not None and await request.is_disconnected():
                    for rid in obj.rid:
                        self.abort_request(rid)
                    raise ValueError(f"Abort request {rid}")
                continue

        assert state.finished
        del self.rid_to_state[rid]

    def flush_cache(self):
        req = FlushCacheReq()
        self.send_to_router.send_pyobj(req)

    def abort_request(self, rid: str):
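        """Drop local state for a request and ask the router to abort it."""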
        if rid not in self.rid_to_state:
            return
        del self.rid_to_state[rid]
        req = AbortReq(rid)
        self.send_to_router.send_pyobj(req)

    def create_abort_task(self, obj: GenerateReqInput):
        # Abort the request if the client is disconnected.
        async def abort_request():
            await asyncio.sleep(3)
            if obj.is_single:
                self.abort_request(obj.rid)
            else:
                for rid in obj.rid:
                    self.abort_request(rid)

        background_tasks = BackgroundTasks()
        background_tasks.add_task(abort_request)
        return background_tasks

    def create_handle_loop(self):
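        """Start the background task that consumes detokenizer output (run once)."""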
        self.to_create_loop = False
        loop = asyncio.get_event_loop()
        loop.create_task(self.handle_loop())

    async def handle_loop(self):
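        """Receive batched outputs from the detokenizer and wake waiting requests."""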
        while True:
            recv_obj: Union[BatchStrOut, BatchEmbeddingOut, BatchTokenIDOut] = (
                await self.recv_from_detokenizer.recv_pyobj()
            )
            assert isinstance(
                recv_obj, (BatchStrOut, BatchEmbeddingOut, BatchTokenIDOut)
            ), f"Unexpected obj received: {type(recv_obj)}"
            for i, rid in enumerate(recv_obj.rids):
                state = self.rid_to_state.get(rid, None)
                if state is None:
                    continue

                recv_obj.meta_info[i]["id"] = rid
                if isinstance(recv_obj, BatchStrOut):
                    out_dict = {
                        "text": recv_obj.output_strs[i],
                        "meta_info": recv_obj.meta_info[i],
                    }
                elif isinstance(recv_obj, BatchTokenIDOut):
                    read_start = 0 if i == 0 else recv_obj.read_offsets[i - 1]
                    out_dict = {
                        "token_ids": recv_obj.decode_ids[
                            read_start : recv_obj.read_offsets[i]
                        ],
                        "meta_info": recv_obj.meta_info[i],
                    }

                else:
                    assert isinstance(recv_obj, BatchEmbeddingOut)
                    out_dict = {
                        "embedding": recv_obj.embeddings[i],
                        "meta_info": recv_obj.meta_info[i],
                    }
                state.out_list.append(out_dict)
                state.finished = recv_obj.finished_reason[i] is not None
                state.event.set()

    def convert_logprob_style(
        self,
        ret: dict,
        return_logprob: bool,
        top_logprobs_num: int,
        return_text_in_logprobs: bool,
Liangsheng Yin's avatar
Liangsheng Yin committed
573
    ):
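        """Optionally attach detokenized text to the logprob entries of a response."""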
        if return_logprob:
            ret["meta_info"]["input_token_logprobs"] = self.detokenize_logprob_tokens(
                ret["meta_info"]["input_token_logprobs"], return_text_in_logprobs
            )
            ret["meta_info"]["output_token_logprobs"] = self.detokenize_logprob_tokens(
                ret["meta_info"]["output_token_logprobs"], return_text_in_logprobs
            )

            if top_logprobs_num > 0:
                ret["meta_info"]["input_top_logprobs"] = (
                    self.detokenize_top_logprobs_tokens(
                        ret["meta_info"]["input_top_logprobs"],
                        return_text_in_logprobs,
                    )
                )
                ret["meta_info"]["output_top_logprobs"] = (
                    self.detokenize_top_logprobs_tokens(
                        ret["meta_info"]["output_top_logprobs"], return_text_in_logprobs
                    )
                )
        return ret

    def detokenize_logprob_tokens(
        self, token_logprobs: List[Tuple[float, int]], decode_to_text: bool
    ):
        if not decode_to_text:
            return [(logprob, token_id, None) for logprob, token_id in token_logprobs]

        assert self.tokenizer is not None
        token_ids = [tid for _, tid in token_logprobs]
        token_texts = self.tokenizer.batch_decode(token_ids)
        return [
            (logprob, token_id, token_text)
            for (logprob, token_id), token_text in zip(token_logprobs, token_texts)
        ]

    def detokenize_top_logprobs_tokens(self, top_logprobs, decode_to_text: bool):
        # TODO: The current implementation only batches the detokenization for top-k tokens per single position.
        # We should batch all top-k tokens in all positions.
        for i, token_top_logprobs in enumerate(top_logprobs):
            if token_top_logprobs:
                top_logprobs[i] = self.detokenize_logprob_tokens(
                    token_top_logprobs, decode_to_text
                )
        return top_logprobs


global global_processor


def init_global_processor(server_args: ServerArgs):
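    """Initialize the HF processor inside each process-pool worker."""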
    global global_processor
    transformers.logging.set_verbosity_error()
    global_processor = get_processor(
        server_args.tokenizer_path,
        tokenizer_mode=server_args.tokenizer_mode,
        trust_remote_code=server_args.trust_remote_code,
    )


def get_pixel_values(
    image_data, image_aspect_ratio=None, image_grid_pinpoints=None, processor=None
):
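    """Load an image and compute its fp16 pixel values, hash, and size."""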
    try:
        processor = processor or global_processor
        image, image_size = load_image(image_data)
        if image_size is not None:
            image_hash = hash(image_data)
            pixel_values = processor.image_processor(image)["pixel_values"]
            for idx in range(len(pixel_values)):
                pixel_values[idx] = pixel_values[idx].astype(np.float16)
            pixel_values = np.stack(pixel_values, axis=0)
            return pixel_values, image_hash, image_size
        else:
            image_hash = hash(image_data)
            if image_aspect_ratio == "pad":
                image = expand2square(
                    image,
                    tuple(int(x * 255) for x in processor.image_processor.image_mean),
                )
                pixel_values = processor.image_processor(image)["pixel_values"][0]
            elif image_aspect_ratio == "anyres":
                pixel_values = process_anyres_image(
                    image, processor.image_processor, image_grid_pinpoints
                )
            else:
                pixel_values = processor.image_processor(image)["pixel_values"][0]
            pixel_values = pixel_values.astype(np.float16)
            return pixel_values, image_hash, image.size
    except Exception:
        print("Exception in TokenizerManager:\n" + get_exception_traceback())