tokenizer_manager.py
"""
Copyright 2023-2024 SGLang Team
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""

"""TokenizerManager is a process that tokenizes the text."""

import asyncio
import concurrent.futures
import dataclasses
import logging
import multiprocessing as mp
import os
from typing import Dict, List, Tuple, Union

import numpy as np
import transformers
import uvloop
import zmq
import zmq.asyncio
from fastapi import BackgroundTasks

from sglang.srt.hf_transformers_utils import (
    get_config,
    get_context_length,
    get_processor,
    get_tokenizer,
)
from sglang.srt.managers.io_struct import (
    AbortReq,
    BatchEmbeddingOut,
    BatchStrOut,
    BatchTokenIDOut,
    EmbeddingReqInput,
    FlushCacheReq,
    GenerateReqInput,
    TokenizedEmbeddingReqInput,
    TokenizedGenerateReqInput,
    UpdateWeightReqInput,
    UpdateWeightReqOutput,
)
from sglang.srt.mm_utils import expand2square, process_anyres_image
from sglang.srt.sampling.sampling_params import SamplingParams
from sglang.srt.server_args import PortArgs, ServerArgs
from sglang.srt.utils import is_generation_model, is_multimodal_model, load_image
from sglang.utils import get_exception_traceback

asyncio.set_event_loop_policy(uvloop.EventLoopPolicy())

logger = logging.getLogger(__name__)


@dataclasses.dataclass
class ReqState:
    """Store the state of a request."""

    out_list: List
    finished: bool
    event: asyncio.Event


class TokenizerManager:
    """TokenizerManager is a process that tokenizes the text."""

    def __init__(
        self,
        server_args: ServerArgs,
        port_args: PortArgs,
        model_overide_args: dict = None,
    ):
        self.server_args = server_args

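        # Inter-process communication: pull detokenized outputs from the
        # detokenizer and push tokenized requests to the controller.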
        context = zmq.asyncio.Context(2)
        self.recv_from_detokenizer = context.socket(zmq.PULL)
        self.recv_from_detokenizer.bind(f"tcp://127.0.0.1:{port_args.tokenizer_port}")

        self.send_to_router = context.socket(zmq.PUSH)
        self.send_to_router.connect(f"tcp://127.0.0.1:{port_args.controller_port}")

        self.model_path = server_args.model_path
        self.served_model_name = server_args.served_model_name
        self.hf_config = get_config(
            self.model_path,
            trust_remote_code=server_args.trust_remote_code,
            model_overide_args=model_overide_args,
        )
        self.is_generation = is_generation_model(self.hf_config.architectures)

        if server_args.context_length is not None:
            self.context_len = server_args.context_length
        else:
            self.context_len = get_context_length(self.hf_config)

        # No process pool by default; one is created below only for multimodal models.
        self.executor = None

        if server_args.skip_tokenizer_init:
            self.tokenizer = self.processor = None
        else:
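            # Multimodal models need a HF processor plus a process pool for
            # CPU-heavy image preprocessing; text-only models just need a tokenizer.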
            if is_multimodal_model(self.model_path):
                self.processor = get_processor(
                    server_args.tokenizer_path,
                    tokenizer_mode=server_args.tokenizer_mode,
                    trust_remote_code=server_args.trust_remote_code,
                )
                self.tokenizer = self.processor.tokenizer
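                # Disable HF fast-tokenizer parallelism to avoid fork-related
                # warnings/deadlocks in the preprocessing worker processes.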
                os.environ["TOKENIZERS_PARALLELISM"] = "false"
                self.executor = concurrent.futures.ProcessPoolExecutor(
                    initializer=init_global_processor,
                    mp_context=mp.get_context("fork"),
                    initargs=(server_args,),
                )
            else:
                self.tokenizer = get_tokenizer(
                    server_args.tokenizer_path,
                    tokenizer_mode=server_args.tokenizer_mode,
                    trust_remote_code=server_args.trust_remote_code,
                )

        self.to_create_loop = True
        self.rid_to_state: Dict[str, ReqState] = {}

        # State for online model weight updates
        self.model_update_lock = asyncio.Lock()
        self.model_update_result = None

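    # Resolve aspect-ratio and grid-pinpoint settings from the HF config, then
    # preprocess one or more images into pixel-value arrays.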
    async def get_pixel_values(self, image_data, aspect_ratio=None):
        aspect_ratio = (
            getattr(self.hf_config, "image_aspect_ratio", None)
            if aspect_ratio is None
            else aspect_ratio
        )
        grid_pinpoints = (
            self.hf_config.image_grid_pinpoints
            if hasattr(self.hf_config, "image_grid_pinpoints")
            and aspect_ratio is not None
            and "anyres" in aspect_ratio
            else None
        )

        if isinstance(image_data, list) and len(image_data) > 0:
            pixel_values, image_hash, image_size = [], [], []
            if len(image_data) > 1:
                aspect_ratio = "pad"  # LLaVA-OneVision: more than one image means interleaved-image or video mode; anyres is not used.
                for img_data in image_data:
                    pixel_v, image_h, image_s = await self._process_single_image(
                        img_data, aspect_ratio, grid_pinpoints
                    )
                    pixel_values.append(pixel_v)
                    image_hash.append(image_h)
                    image_size.append(image_s)
                pixel_values = np.stack(pixel_values, axis=0)
            else:
                pixel_values, image_hash, image_size = await self._process_single_image(
                    image_data[0], aspect_ratio, grid_pinpoints
                )
                image_hash = [image_hash]
                image_size = [image_size]
        elif isinstance(image_data, str):
            pixel_values, image_hash, image_size = await self._process_single_image(
                image_data, aspect_ratio, grid_pinpoints
            )
            image_hash = [image_hash]
            image_size = [image_size]
        else:
            pixel_values, image_hash, image_size = None, None, None

        return pixel_values, image_hash, image_size

    async def _process_single_image(self, image_data, aspect_ratio, grid_pinpoints):
        if self.executor is not None:
            loop = asyncio.get_event_loop()
            return await loop.run_in_executor(
                self.executor,
                get_pixel_values,
                image_data,
                aspect_ratio,
                grid_pinpoints,
            )
        else:
            return get_pixel_values(
                image_data, aspect_ratio, grid_pinpoints, self.processor
            )

    async def generate_request(
        self, obj: Union[GenerateReqInput, EmbeddingReqInput], request=None
    ):
        if self.to_create_loop:
            self.create_handle_loop()

        while self.model_update_lock.locked():
            await asyncio.sleep(0)

        obj.post_init()
        is_single = obj.is_single

        if is_single:
            async for response in self._handle_single_request(obj, request):
                yield response
        else:
            async for response in self._handle_batch_request(obj, request):
                yield response

    async def _handle_single_request(
        self,
        obj: Union[GenerateReqInput, EmbeddingReqInput],
        request,
        index=None,
        is_cache_for_prefill=False,
    ):
        if not is_cache_for_prefill:  # The normal case with a single prompt
            not_use_index = index is None

            rid = obj.rid if not_use_index else obj.rid[index]
            input_text = obj.text if not_use_index else obj.text[index]
            if obj.input_ids is None:
                assert self.tokenizer is not None
                input_ids = self.tokenizer.encode(input_text)
            else:
                input_ids = obj.input_ids if not_use_index else obj.input_ids[index]

            self._validate_input_length(input_ids)

            sampling_params = self._get_sampling_params(
                obj.sampling_params if not_use_index else obj.sampling_params[index]
            )

            if self.is_generation:
                pixel_values, image_hash, image_size = await self.get_pixel_values(
                    obj.image_data
                )
                return_logprob = (
                    obj.return_logprob if not_use_index else obj.return_logprob[index]
                )
                logprob_start_len = (
                    obj.logprob_start_len
                    if not_use_index
                    else obj.logprob_start_len[index]
                )
                if return_logprob and logprob_start_len == -1:
                    logprob_start_len = len(input_ids) - 1

                top_logprobs_num = (
                    obj.top_logprobs_num
                    if not_use_index
                    else obj.top_logprobs_num[index]
                )
        else:  # A prefill request to cache the common prompt for parallel sampling
            assert self.is_generation
            if obj.text is not None:
                if isinstance(obj.text, list):
                    input_text = obj.text[index]
                    rid = obj.rid[index]
                else:
                    input_text = obj.text
                    rid = obj.rid[0]
                if self.tokenizer is not None:
                    input_ids = self.tokenizer.encode(input_text)
                else:
                    assert obj.input_ids is not None
                    if isinstance(obj.input_ids, list) and isinstance(
                        obj.input_ids[0], list
                    ):
                        # obj.input_ids is a List[List[int]]: one sequence per request
                        input_ids = obj.input_ids[index]
                        rid = obj.rid[index]
                    else:
                        input_ids = obj.input_ids
                        rid = obj.rid[0]
            else:
                input_text = None
                if isinstance(obj.input_ids, list) and isinstance(
                    obj.input_ids[0], list
                ):
                    # obj.input_ids is a List[List[int]]: one sequence per request
                    input_ids = obj.input_ids[index]
                    rid = obj.rid[index]
                else:
                    input_ids = obj.input_ids
                    rid = obj.rid[0]

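            # Prefill-only request: max_new_tokens=0 generates nothing but
            # caches the common prompt for the parallel samples that follow.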
            sampling_params = SamplingParams(**obj.sampling_params[0])
            sampling_params.max_new_tokens = 0
            pixel_values, image_hash, image_size = await self._get_pixel_values(
                obj.image_data[0]
            )
            return_logprob = obj.return_logprob[0]
            logprob_start_len = obj.logprob_start_len[0]
            top_logprobs_num = obj.top_logprobs_num[0]

        if self.is_generation:
            if return_logprob and logprob_start_len == -1:
                logprob_start_len = len(input_ids) - 1
            tokenized_obj = TokenizedGenerateReqInput(
                rid,
                input_text,
                input_ids,
                pixel_values,
                image_hash,
                image_size,
                sampling_params,
                return_logprob,
                logprob_start_len,
                top_logprobs_num,
                obj.stream,
            )
        else:  # is embedding
            tokenized_obj = TokenizedEmbeddingReqInput(
                rid,
                input_text,
                input_ids,
                sampling_params,
            )

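        # Dispatch the tokenized request and register per-request state so that
        # handle_loop() can route detokenized outputs back to this waiter.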
        self.send_to_router.send_pyobj(tokenized_obj)

        event = asyncio.Event()
        state = ReqState([], False, event)
        self.rid_to_state[rid] = state
        if not is_cache_for_prefill:
            async for response in self._wait_for_response(
                event, state, obj, rid, request
            ):
                yield response
        else:
            assert self.is_generation
            await self._wait_for_cache_prefill_response(event, state, obj, rid, request)
            yield input_ids

    async def _handle_batch_request(
        self, obj: Union[GenerateReqInput, EmbeddingReqInput], request
    ):
        batch_size = obj.batch_size
        if self.is_generation:
            parallel_sample_num = obj.parallel_sample_num

            if parallel_sample_num != 1:
                # Send prefill requests to cache the common input
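                # Each prompt gets one extra rid slot for its cache-prefill request.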
                parallel_sample_num += 1
                input_id_result = [] if obj.input_ids is None else None
                for i in range(batch_size):
                    async for input_id in self._handle_single_request(
                        obj, request, index=i, is_cache_for_prefill=True
                    ):
                        if input_id_result is not None:
                            input_id_result.append(input_id)
                if input_id_result is not None and len(input_id_result) > 1:
                    obj.input_ids = input_id_result
                elif input_id_result is not None:
                    obj.input_ids = input_id_result[0]
        else:
            parallel_sample_num = 1

        # First send out all requests
        generators = []
        for i in range(batch_size):
            for j in range(parallel_sample_num):
                if j == 0 and parallel_sample_num != 1:
                    continue
                index = i * parallel_sample_num + j
                if parallel_sample_num != 1:
                    # With parallel sampling, skip past the cache-prefill slots:
                    # the effective index is j + i * (parallel_sample_num - 1) + batch_size - 1.
                    index += batch_size - 1 - i
                rid = obj.rid[index]
                if parallel_sample_num == 1:
                    # No parallel sampling: select the i-th prompt directly.
                    if obj.input_ids is None:
                        input_text = obj.text[i]
                        input_ids = self.tokenizer.encode(obj.text[i])
                    else:
                        input_text = None
                        input_ids = obj.input_ids[i]
                else:
                    assert obj.input_ids is not None
                    if batch_size == 1:
                        input_text = None
                        input_ids = obj.input_ids
                    else:
                        input_text = None
                        input_ids = obj.input_ids[i]
                sampling_params = self._get_sampling_params(obj.sampling_params[index])

                if self.is_generation:
                    if obj.return_logprob[index] and obj.logprob_start_len[index] == -1:
                        obj.logprob_start_len[index] = len(input_ids) - 1
                    pixel_values, image_hash, image_size = await self._get_pixel_values(
                        obj.image_data[index]
                    )

                    tokenized_obj = TokenizedGenerateReqInput(
                        rid,
                        input_text,
                        input_ids,
                        pixel_values,
                        image_hash,
                        image_size,
                        sampling_params,
                        obj.return_logprob[index],
                        obj.logprob_start_len[index],
                        obj.top_logprobs_num[index],
                        obj.stream,
                    )
                else:
                    tokenized_obj = TokenizedEmbeddingReqInput(
                        rid,
                        input_text,
                        input_ids,
                        sampling_params,
                    )
                self.send_to_router.send_pyobj(tokenized_obj)

                event = asyncio.Event()
                state = ReqState([], False, event)
                self.rid_to_state[rid] = state

                generators.append(
                    self._wait_for_response(
                        event,
                        state,
                        obj,
                        rid,
                        request,
                        index=index,
                        response_index=len(generators),
                    )
                )

        # Then process the responses based on streaming option

        is_stream = hasattr(obj, "stream") and obj.stream

        tasks = [asyncio.create_task(gen.__anext__()) for gen in generators]
        output_list = [None] * len(tasks)

        while tasks:
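            # FIRST_COMPLETED wakes as soon as any generator yields, so fast
            # requests are not blocked behind slow ones.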
            done, _ = await asyncio.wait(tasks, return_when=asyncio.FIRST_COMPLETED)

            for task in done:
                cur_index = tasks.index(task)

                try:
                    result = task.result()

                    if is_stream:
                        yield result
                    else:
                        output_list[result["index"]] = result

                    tasks[cur_index] = asyncio.create_task(
                        generators[cur_index].__anext__()
                    )
                except StopAsyncIteration:
                    del generators[cur_index]
                    del tasks[cur_index]

        if not is_stream:
            yield output_list

    def _validate_input_length(self, input_ids: List[int]):
        if len(input_ids) >= self.context_len:
            raise ValueError(
                f"The input ({len(input_ids)} tokens) is longer than the "
                f"model's context length ({self.context_len} tokens)."
            )

    def _get_sampling_params(self, sampling_params_data: dict):
        sampling_params = SamplingParams(**sampling_params_data)
        if sampling_params.max_new_tokens != 0:
            sampling_params.normalize(self.tokenizer)
            sampling_params.verify()
        return sampling_params

    async def _get_pixel_values(self, image_data):
        if isinstance(image_data, list) and len(image_data) > 0:
            return await self.get_pixel_values(image_data[0])
        elif isinstance(image_data, str):
            return await self.get_pixel_values(image_data)
        else:
            return None, None, None

    async def _wait_for_response(
        self,
        event: asyncio.Event,
        state: ReqState,
        obj: Union[GenerateReqInput, EmbeddingReqInput],
        rid: str,
        request,
        index: int = None,
        response_index: int = 0,
    ):
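        # Wake every 4 s to check whether the client has disconnected; if so,
        # abort the request(s) instead of waiting forever.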
        while True:
            try:
                await asyncio.wait_for(event.wait(), timeout=4)
            except asyncio.TimeoutError:
                if request is not None and await request.is_disconnected():
                    for rid in [obj.rid] if obj.is_single else obj.rid:
                        self.abort_request(rid)
                    raise ValueError(f"Abort request {rid}")
                continue

            if self.is_generation:
                out = self.convert_logprob_style(
                    state.out_list[-1],
                    obj.return_logprob if index is None else obj.return_logprob[index],
                    (
                        obj.top_logprobs_num
                        if index is None
                        else obj.top_logprobs_num[index]
                    ),
                    obj.return_text_in_logprobs,
                )
            else:  # isinstance(obj, EmbeddingReqInput)
                out = state.out_list[-1]

            out["index"] = response_index

            # Log requests
            if self.server_args.log_requests and state.finished:
                logger.info(f"in={obj}, out={out}")

            state.out_list = []
            if state.finished:
                del self.rid_to_state[rid]
                yield out
                break

            event.clear()
            yield out

    async def _wait_for_cache_prefill_response(
        self,
        event: asyncio.Event,
        state: ReqState,
        obj: GenerateReqInput,
        rid: str,
        request,
    ):
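        # Poll until the prefill request finishes, aborting on client disconnect.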
        while True:
            try:
                await asyncio.wait_for(state.event.wait(), timeout=4)
                break
            except asyncio.TimeoutError:
                if request is not None and await request.is_disconnected():
                    for rid in obj.rid:
                        self.abort_request(rid)
                    raise ValueError(f"Abort request {rid}")
                continue

        assert state.finished
        del self.rid_to_state[rid]

    def flush_cache(self):
        req = FlushCacheReq()
        self.send_to_router.send_pyobj(req)

    async def update_weights(self, obj: UpdateWeightReqInput, request):
        if self.to_create_loop:
            self.create_handle_loop()

        # Default the load format to the one configured in server_args
        if obj.load_format is None:
            obj.load_format = self.server_args.load_format

        if not self.model_update_lock.locked():
            async with self.model_update_lock:
                # wait for the previous generation requests to finish
                while len(self.rid_to_state) > 0:
                    await asyncio.sleep(0)
                self.send_to_router.send_pyobj(obj)
                self.model_update_result = asyncio.Future()
                result = await self.model_update_result
                if result.success:
                    self.server_args.model_path = obj.model_path
                    self.server_args.load_format = obj.load_format
                    self.model_path = obj.model_path
            return result.success, result.message
        else:
            return False, "Another update is in progress. Please try again later."

    def abort_request(self, rid: str):
        if rid not in self.rid_to_state:
            return
        del self.rid_to_state[rid]
        req = AbortReq(rid)
        self.send_to_router.send_pyobj(req)

    def create_abort_task(self, obj: GenerateReqInput):
        # Abort the request if the client is disconnected.
        async def abort_request():
            await asyncio.sleep(3)
            if obj.is_single:
                self.abort_request(obj.rid)
            else:
                for rid in obj.rid:
                    self.abort_request(rid)

        background_tasks = BackgroundTasks()
        background_tasks.add_task(abort_request)
        return background_tasks

    def create_handle_loop(self):
        if not self.to_create_loop:
            return

        self.to_create_loop = False
        loop = asyncio.get_event_loop()
        loop.create_task(self.handle_loop())

    async def handle_loop(self):
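        # Receive outputs from the detokenizer and fan them out to the
        # per-request states registered in rid_to_state.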
        while True:
            recv_obj: Union[
                BatchStrOut, BatchEmbeddingOut, BatchTokenIDOut, UpdateWeightReqOutput
            ] = await self.recv_from_detokenizer.recv_pyobj()

            if isinstance(recv_obj, UpdateWeightReqOutput):
                self.model_update_result.set_result(recv_obj)
                continue

            assert isinstance(
                recv_obj, (BatchStrOut, BatchEmbeddingOut, BatchTokenIDOut)
            ), f"Unexpected obj received: {type(recv_obj)}"

            for i, rid in enumerate(recv_obj.rids):
                state = self.rid_to_state.get(rid, None)
                if state is None:
                    continue

                recv_obj.meta_info[i]["id"] = rid
                if isinstance(recv_obj, BatchStrOut):
                    out_dict = {
                        "text": recv_obj.output_strs[i],
                        "meta_info": recv_obj.meta_info[i],
                    }
                elif isinstance(recv_obj, BatchTokenIDOut):
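                    # Slice this request's tokens out of the flattened
                    # decode_ids buffer using the cumulative read offsets.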
                    read_start = 0 if i == 0 else recv_obj.read_offsets[i - 1]
                    out_dict = {
                        "token_ids": recv_obj.decode_ids[
                            read_start : recv_obj.read_offsets[i]
                        ],
                        "meta_info": recv_obj.meta_info[i],
                    }

                else:
                    assert isinstance(recv_obj, BatchEmbeddingOut)
                    out_dict = {
                        "embedding": recv_obj.embeddings[i],
                        "meta_info": recv_obj.meta_info[i],
                    }
                state.out_list.append(out_dict)
                state.finished = recv_obj.finished_reason[i] is not None
                state.event.set()

    def convert_logprob_style(
        self,
        ret: dict,
        return_logprob: bool,
        top_logprobs_num: int,
        return_text_in_logprobs: bool,
    ):
        if return_logprob:
            ret["meta_info"]["input_token_logprobs"] = self.detokenize_logprob_tokens(
                ret["meta_info"]["input_token_logprobs"], return_text_in_logprobs
            )
            ret["meta_info"]["output_token_logprobs"] = self.detokenize_logprob_tokens(
                ret["meta_info"]["output_token_logprobs"], return_text_in_logprobs
            )

            if top_logprobs_num > 0:
                ret["meta_info"]["input_top_logprobs"] = (
                    self.detokenize_top_logprobs_tokens(
                        ret["meta_info"]["input_top_logprobs"],
                        return_text_in_logprobs,
                    )
                )
                ret["meta_info"]["output_top_logprobs"] = (
                    self.detokenize_top_logprobs_tokens(
                        ret["meta_info"]["output_top_logprobs"], return_text_in_logprobs
                    )
                )
        return ret

    def detokenize_logprob_tokens(
        self, token_logprobs: List[Tuple[float, int]], decode_to_text: bool
    ):
        if not decode_to_text:
            return [(logprob, token_id, None) for logprob, token_id in token_logprobs]

        assert self.tokenizer is not None
        token_ids = [tid for _, tid in token_logprobs]
        token_texts = self.tokenizer.batch_decode(token_ids)
        return [
            (logprob, token_id, token_text)
            for (logprob, token_id), token_text in zip(token_logprobs, token_texts)
        ]

    def detokenize_top_logprobs_tokens(self, top_logprobs, decode_to_text: bool):
        # TODO: The current implementation only batches the detokenization for top-k tokens per single position.
        # We should batch all top-k tokens in all positions.
        for i, token_top_logprobs in enumerate(top_logprobs):
            if token_top_logprobs:
                top_logprobs[i] = self.detokenize_logprob_tokens(
                    token_top_logprobs, decode_to_text
                )
        return top_logprobs


# Module-level processor handle; initialized in each ProcessPoolExecutor worker.
global_processor = None


def init_global_processor(server_args: ServerArgs):
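    """Initialize a processor in each ProcessPoolExecutor worker process."""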
    global global_processor
    transformers.logging.set_verbosity_error()
    global_processor = get_processor(
        server_args.tokenizer_path,
        tokenizer_mode=server_args.tokenizer_mode,
        trust_remote_code=server_args.trust_remote_code,
    )


def get_pixel_values(
    image_data, image_aspect_ratio=None, image_grid_pinpoints=None, processor=None
):
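    # Runs inside ProcessPoolExecutor workers (falling back to the processor
    # created by init_global_processor) or inline when a processor is passed.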
    try:
        processor = processor or global_processor
        image, image_size = load_image(image_data)
        if image_size is not None:
            image_hash = hash(image_data)
            pixel_values = processor.image_processor(image)["pixel_values"]
            pixel_values = [pv.astype(np.float16) for pv in pixel_values]
            pixel_values = np.stack(pixel_values, axis=0)
            return pixel_values, image_hash, image_size
        else:
            image_hash = hash(image_data)
            if image_aspect_ratio == "pad":
                image = expand2square(
                    image,
                    tuple(int(x * 255) for x in processor.image_processor.image_mean),
                )
                pixel_values = processor.image_processor(image)["pixel_values"][0]
            elif image_aspect_ratio == "anyres" or (
                image_aspect_ratio is not None and "anyres_max" in image_aspect_ratio
            ):
                pixel_values = process_anyres_image(
                    image, processor.image_processor, image_grid_pinpoints
                )
            else:
                pixel_values = processor.image_processor(image)["pixel_values"][0]
            pixel_values = pixel_values.astype(np.float16)
            return pixel_values, image_hash, image.size
    except Exception:
        print("Exception in TokenizerManager:\n" + get_exception_traceback())