tokenizer_manager.py 29.2 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
13
# Copyright 2023-2024 SGLang Team
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
Lianmin Zheng's avatar
Lianmin Zheng committed
14
"""TokenizerManager is a process that tokenizes the text."""
15

Lianmin Zheng's avatar
Lianmin Zheng committed
16
import asyncio
17
import copy
Lianmin Zheng's avatar
Lianmin Zheng committed
18
import dataclasses
19
import logging
Lianmin Zheng's avatar
Lianmin Zheng committed
20
import os
21
22
import signal
import sys
23
import time
24
import uuid
25
from typing import Dict, List, Optional, Tuple, Union
Lianmin Zheng's avatar
Lianmin Zheng committed
26

27
import fastapi
Lianmin Zheng's avatar
Lianmin Zheng committed
28
29
30
import uvloop
import zmq
import zmq.asyncio
31
from fastapi import BackgroundTasks
Liangsheng Yin's avatar
Liangsheng Yin committed
32

33
34
from sglang.srt.configs.model_config import ModelConfig
from sglang.srt.hf_transformers_utils import get_processor, get_tokenizer
35
36
37
38
from sglang.srt.managers.image_processor import (
    get_dummy_image_processor,
    get_image_processor,
)
Lianmin Zheng's avatar
Lianmin Zheng committed
39
from sglang.srt.managers.io_struct import (
40
    AbortReq,
41
    BatchEmbeddingOut,
Lianmin Zheng's avatar
Lianmin Zheng committed
42
    BatchStrOut,
43
    BatchTokenIDOut,
44
    CloseSessionReqInput,
45
    EmbeddingReqInput,
46
    FlushCacheReq,
Lianmin Zheng's avatar
Lianmin Zheng committed
47
    GenerateReqInput,
48
49
    GetWeightsByNameReqInput,
    GetWeightsByNameReqOutput,
50
51
    InitWeightsUpdateGroupReqInput,
    InitWeightsUpdateGroupReqOutput,
52
53
    OpenSessionReqInput,
    OpenSessionReqOutput,
54
    ProfileReq,
55
    TokenizedEmbeddingReqInput,
Lianmin Zheng's avatar
Lianmin Zheng committed
56
    TokenizedGenerateReqInput,
Chayenne's avatar
Chayenne committed
57
58
    UpdateWeightFromDiskReqInput,
    UpdateWeightFromDiskReqOutput,
59
60
    UpdateWeightsFromDistributedReqInput,
    UpdateWeightsFromDistributedReqOutput,
Lianmin Zheng's avatar
Lianmin Zheng committed
61
)
62
from sglang.srt.metrics.collector import TokenizerMetricsCollector
63
from sglang.srt.sampling.sampling_params import SamplingParams
Lianmin Zheng's avatar
Lianmin Zheng committed
64
from sglang.srt.server_args import PortArgs, ServerArgs
65
from sglang.srt.utils import get_zmq_socket, kill_process_tree
Lianmin Zheng's avatar
Lianmin Zheng committed
66
67
68

logger = logging.getLogger(__name__)

# Install uvloop's event-loop policy process-wide at import time, so event
# loops obtained through asyncio after this module is imported come from uvloop.
asyncio.set_event_loop_policy(uvloop.EventLoopPolicy())


@dataclasses.dataclass
class ReqState:
    """Store the state of a single in-flight request.

    ``out_list`` accumulates output dicts that the request's waiter has not
    consumed yet, ``finished`` becomes True once the final output arrived,
    and ``event`` is set whenever new output is appended so the waiter wakes up.
    """

    # Output buffering and wake-up signalling.
    out_list: List
    finished: bool
    event: asyncio.Event

    # Wall-clock timestamps (from time.time()) used for metrics.
    created_time: float
    first_token_time: Optional[float] = dataclasses.field(default=None)


class TokenizerManager:
    """TokenizerManager is a process that tokenizes the text.

    It tokenizes incoming generate/embedding requests, pushes them to the
    scheduler over ZMQ, and routes the outputs received from the detokenizer
    back to the coroutines waiting on each request (keyed by request id).
    """

    def __init__(
        self,
        server_args: ServerArgs,
        port_args: PortArgs,
    ):
        """Set up IPC sockets, the model config, the tokenizer/processor and
        the request bookkeeping structures."""
        # Parse args
        self.server_args = server_args
        self.enable_metrics = server_args.enable_metrics

        # Init inter-process communication
        context = zmq.asyncio.Context(2)
        self.recv_from_detokenizer = get_zmq_socket(
            context, zmq.PULL, port_args.tokenizer_ipc_name
        )
        self.send_to_scheduler = get_zmq_socket(
            context, zmq.PUSH, port_args.scheduler_input_ipc_name
        )

        # Read model args
        self.model_path = server_args.model_path
        self.served_model_name = server_args.served_model_name
        self.model_config = ModelConfig(
            server_args.model_path,
            trust_remote_code=server_args.trust_remote_code,
            revision=server_args.revision,
            context_length=server_args.context_length,
            model_override_args=server_args.json_model_override_args,
            is_embedding=server_args.is_embedding,
            dtype=server_args.dtype,
            quantization=server_args.quantization,
        )

        self.is_generation = self.model_config.is_generation
        self.context_len = self.model_config.context_len

        # Create image processor placeholder
        self.image_processor = get_dummy_image_processor()

        # Create tokenizer
        if server_args.skip_tokenizer_init:
            self.tokenizer = self.processor = None
        else:
            if self.model_config.is_multimodal:
                self.processor = get_processor(
                    server_args.tokenizer_path,
                    tokenizer_mode=server_args.tokenizer_mode,
                    trust_remote_code=server_args.trust_remote_code,
                )
                self.tokenizer = self.processor.tokenizer
                os.environ["TOKENIZERS_PARALLELISM"] = "false"

                # We want to parallelize the image pre-processing so we create an executor for it
                self.image_processor = get_image_processor(
                    self.model_config.hf_config, server_args, self.processor
                )
            else:
                # Text-only models have no processor; keep the attribute defined
                # so every code path leaves the same set of attributes.
                self.processor = None
                self.tokenizer = get_tokenizer(
                    server_args.tokenizer_path,
                    tokenizer_mode=server_args.tokenizer_mode,
                    trust_remote_code=server_args.trust_remote_code,
                )

        # Store states
        self.to_create_loop = True
        self.rid_to_state: Dict[str, ReqState] = {}

        # For update model weights. The *_result futures are resolved by
        # handle_loop; the *_tmp lists gather per-DP-rank replies when dp_size > 1.
        self.model_update_lock = asyncio.Lock()
        self.model_update_result = None
        self.model_update_tmp = []
        self.init_weights_update_group_result = None
        self.parameter_update_result = None
        self.get_weights_by_name_result = None
        self.get_weights_by_name_tmp = []

        # For session info
        self.session_futures = {}  # session_id -> asyncio.Future resolved by handle_loop

        # Others
        self.gracefully_exit = False

        # Metrics
        if self.enable_metrics:
            self.metrics_collector = TokenizerMetricsCollector(
                labels={
                    "model_name": self.server_args.served_model_name,
                    # TODO: Add lora name/path in the future,
                },
            )

    async def generate_request(
        self,
        obj: Union[GenerateReqInput, EmbeddingReqInput],
        request: Optional[fastapi.Request] = None,
    ):
        """Tokenize a request, send it to the scheduler and yield its outputs.

        Single requests yield one output dict per received chunk; batch
        requests are delegated to ``_handle_batch_request``.

        Raises:
            ValueError: if an embedding request is sent to a generation model.
        """
        created_time = time.time()

        if self.to_create_loop:
            self.create_handle_loop()

        # New requests wait while a weight update holds the lock.
        while self.model_update_lock.locked():
            await asyncio.sleep(0.001)

        if isinstance(obj, EmbeddingReqInput) and self.is_generation:
            raise ValueError(
                "This model does not appear to be an embedding model by default. "
                "Please add `--is-embedding` when launching the server or try another model."
            )

        obj.normalize_batch_and_arguments()
        is_single = obj.is_single
        if is_single:
            tokenized_obj = await self._tokenize_one_request(obj)
            self._send_one_request(obj, tokenized_obj, created_time)
            async for response in self._wait_one_response(obj, request, created_time):
                yield response
        else:
            async for response in self._handle_batch_request(
                obj, request, created_time
            ):
                yield response

    async def _tokenize_one_request(
        self,
        obj: Union[GenerateReqInput, EmbeddingReqInput],
    ):
        """Tokenize one non-batched request.

        Returns:
            A ``TokenizedGenerateReqInput`` for ``GenerateReqInput`` or a
            ``TokenizedEmbeddingReqInput`` for ``EmbeddingReqInput``.

        Raises:
            ValueError: if ``input_embeds`` is given while the radix cache is
                enabled, or if user-provided ``input_ids`` reach the context length.
        """
        # Tokenize
        input_embeds = None
        input_text = obj.text
        if obj.input_embeds is not None:
            if not self.server_args.disable_radix_cache:
                raise ValueError(
                    "input_embeds is provided while disable_radix_cache is False. "
                    "Please add `--disable-radix-cache` when you launch the server "
                    "if you want to use input_embeds as inputs."
                )
            input_embeds = obj.input_embeds
            input_ids = obj.input_ids
        elif obj.input_ids is None:
            input_ids = self.tokenizer.encode(input_text)
        else:
            input_ids = obj.input_ids

        # NOTE(review): the generation-only locals below are bound only when
        # self.is_generation; a GenerateReqInput sent to a non-generation model
        # would fail with UnboundLocalError when building the return object.
        if self.is_generation:
            # TODO: also support getting embeddings for multimodal models
            image_inputs: Dict = await self.image_processor.process_images_async(
                obj.image_data, input_text or input_ids, obj
            )
            if image_inputs and "input_ids" in image_inputs:
                input_ids = image_inputs["input_ids"]
            return_logprob = obj.return_logprob
            logprob_start_len = obj.logprob_start_len
            top_logprobs_num = obj.top_logprobs_num
            session_id = obj.session[0] if obj.session else None
            session_rid = obj.session[1] if obj.session else None

        if obj.input_ids is not None and len(input_ids) >= self.context_len:
            raise ValueError(
                f"The input ({len(input_ids)} tokens) is longer than the "
                f"model's context length ({self.context_len} tokens)."
            )

        # Parse sampling parameters
        sampling_params = SamplingParams(**obj.sampling_params)
        sampling_params.normalize(self.tokenizer)
        sampling_params.verify()

        # Build return object
        if isinstance(obj, GenerateReqInput):
            tokenized_obj = TokenizedGenerateReqInput(
                obj.rid,
                input_text,
                input_ids,
                image_inputs,
                sampling_params,
                return_logprob,
                logprob_start_len,
                top_logprobs_num,
                obj.stream,
                lora_path=obj.lora_path,
                input_embeds=input_embeds,
                session_id=session_id,
                session_rid=session_rid,
            )
        elif isinstance(obj, EmbeddingReqInput):
            tokenized_obj = TokenizedEmbeddingReqInput(
                obj.rid,
                input_text,
                input_ids,
                sampling_params,
            )

        return tokenized_obj

    def _send_one_request(
        self,
        obj: Union[GenerateReqInput, EmbeddingReqInput],
        tokenized_obj,
        created_time: Optional[float] = None,
    ):
        """Register the ReqState for ``obj.rid`` and send the request.

        The state is registered right away (not lazily when the waiter first
        runs) so handle_loop never drops an output that arrives before the
        waiting coroutine has started.
        """
        self.rid_to_state[obj.rid] = ReqState(
            [], False, asyncio.Event(), created_time=created_time
        )
        self.send_to_scheduler.send_pyobj(tokenized_obj)

    async def _wait_one_response(
        self,
        obj: Union[GenerateReqInput, EmbeddingReqInput],
        request: Optional[fastapi.Request] = None,
        created_time: Optional[float] = None,
    ):
        """Wait for and yield the outputs of one request.

        Uses the ReqState registered by ``_send_one_request`` when present,
        otherwise registers a new one. Every 4 seconds without output it checks
        whether the client disconnected and, if so, aborts the request.

        Raises:
            ValueError: when the request is aborted because the client disconnected.
        """
        state = self.rid_to_state.get(obj.rid)
        if state is None:
            state = ReqState([], False, asyncio.Event(), created_time=created_time)
            self.rid_to_state[obj.rid] = state

        while True:
            try:
                await asyncio.wait_for(state.event.wait(), timeout=4)
            except asyncio.TimeoutError:
                if request is not None and await request.is_disconnected():
                    self.abort_request(obj.rid)
                    raise ValueError(f"Abort request {obj.rid}")
                continue

            if isinstance(obj, GenerateReqInput):
                out = self.convert_logprob_style(
                    state.out_list[-1],
                    obj.return_logprob,
                    obj.top_logprobs_num,
                    obj.return_text_in_logprobs,
                )
            else:  # isinstance(obj, (EmbeddingReqInput,))
                out = state.out_list[-1]

            state.out_list = []
            if state.finished:
                if self.server_args.log_requests:
                    # Log requests
                    logger.info(f"in={obj}, out={out}")
                del self.rid_to_state[obj.rid]
                yield out
                break

            state.event.clear()
            yield out

    async def _handle_batch_request(
        self,
        obj: Union[GenerateReqInput, EmbeddingReqInput],
        request: Optional[fastapi.Request] = None,
        created_time: Optional[float] = None,
    ):
        """Fan a batch request out into per-item requests and yield their outputs.

        Non-streaming: yields one list with the final output of every sub-request.
        Streaming: yields each chunk as it arrives, tagged with ``"index"`` (the
        position of its sub-request).
        """
        batch_size = obj.batch_size

        generators = []
        rids = []
        if getattr(obj, "parallel_sample_num", 1) == 1:
            # Send all requests
            try:
                for i in range(batch_size):
                    tmp_obj = obj[i]
                    tokenized_obj = await self._tokenize_one_request(tmp_obj)
                    self._send_one_request(tmp_obj, tokenized_obj, created_time)
                    generators.append(
                        self._wait_one_response(tmp_obj, request, created_time)
                    )
                    rids.append(tmp_obj.rid)
            except BaseException:
                # A later sub-request failed (or we were cancelled): abort the
                # ones already sent so neither the scheduler nor rid_to_state
                # keeps them around, then propagate the error.
                for rid in rids:
                    self.abort_request(rid)
                raise
        else:
            # FIXME: When using batch and parallel_sample_num together, the perf is not optimal.
            if batch_size > 128:
                logger.warning(
                    "Sending a single large batch with parallel sampling (n > 1) has not been well optimized. "
                    "The performance might be better if you just duplicate the requests n times or use "
                    "many threads to send them one by one with parallel sampling (n > 1)."
                )

            # Tokenize all requests
            objs = [obj[i] for i in range(batch_size)]
            tokenized_objs = await asyncio.gather(
                *(self._tokenize_one_request(sub_obj) for sub_obj in objs)
            )

            # Cache the common prefix for parallel sampling: send each prompt once
            # with max_new_tokens=0 and wait for it to finish.
            for sub_obj, sub_tokenized in zip(objs, tokenized_objs):
                tmp_obj = copy.copy(sub_obj)
                tokenized_obj = copy.copy(sub_tokenized)
                tokenized_obj.rid = tmp_obj.regenerate_rid()
                tokenized_obj.sampling_params = copy.copy(tokenized_obj.sampling_params)
                tokenized_obj.sampling_params.max_new_tokens = 0
                tokenized_obj.stream = False
                self._send_one_request(tmp_obj, tokenized_obj, created_time)
                await self._wait_one_response(
                    tmp_obj, request, created_time
                ).__anext__()

            # Expand requests, assign new rids for them, and send them
            for sub_obj, sub_tokenized in zip(objs, tokenized_objs):
                for _ in range(obj.parallel_sample_num):
                    tmp_obj = copy.copy(sub_obj)
                    tokenized_obj = copy.copy(sub_tokenized)
                    tokenized_obj.rid = tmp_obj.regenerate_rid()
                    self._send_one_request(tmp_obj, tokenized_obj, created_time)
                    generators.append(
                        self._wait_one_response(tmp_obj, request, created_time)
                    )
                    rids.append(tmp_obj.rid)

        # Wait for all requests
        is_stream = hasattr(obj, "stream") and obj.stream
        if not is_stream:
            outputs = await asyncio.gather(*(gen.__anext__() for gen in generators))
            yield outputs
        else:
            rid_to_index = {rid: i for i, rid in enumerate(rids)}
            task_map = {asyncio.create_task(gen.__anext__()): gen for gen in generators}
            while task_map:
                done, _ = await asyncio.wait(
                    task_map.keys(), return_when=asyncio.FIRST_COMPLETED
                )

                for task in done:
                    gen = task_map.pop(task)
                    try:
                        result = task.result()
                        result["index"] = rid_to_index[result["meta_info"]["id"]]
                        yield result
                        new_task = asyncio.create_task(gen.__anext__())
                        task_map[new_task] = gen
                    except StopAsyncIteration:
                        # This sub-request's generator is exhausted.
                        pass

    def flush_cache(self):
        """Ask the scheduler to flush its cache."""
        req = FlushCacheReq()
        self.send_to_scheduler.send_pyobj(req)

    def abort_request(self, rid: str):
        """Forget the request ``rid`` locally and tell the scheduler to abort it.

        No-op when ``rid`` is not (or no longer) tracked.
        """
        if rid not in self.rid_to_state:
            return
        del self.rid_to_state[rid]
        req = AbortReq(rid)
        self.send_to_scheduler.send_pyobj(req)

    def start_profile(self):
        """Send a START_PROFILE request to the scheduler."""
        req = ProfileReq.START_PROFILE
        self.send_to_scheduler.send_pyobj(req)

    def stop_profile(self):
        """Send a STOP_PROFILE request to the scheduler."""
        req = ProfileReq.STOP_PROFILE
        self.send_to_scheduler.send_pyobj(req)

    async def update_weights_from_disk(
        self,
        obj: UpdateWeightFromDiskReqInput,
        request: Optional[fastapi.Request] = None,
    ):
        """Reload model weights from disk once all in-flight requests finished.

        Returns:
            ``(success, message)``. With dp_size > 1, success is True only if
            every rank succeeded and the messages are joined with " | ".
        """
        if self.to_create_loop:
            self.create_handle_loop()

        # default the load format to the server_args
        if obj.load_format is None:
            obj.load_format = self.server_args.load_format

        if self.model_update_lock.locked():
            return False, "Another update is in progress. Please try again later."

        async with self.model_update_lock:
            # wait for the previous generation requests to finish
            for _ in range(3):
                while len(self.rid_to_state) > 0:
                    await asyncio.sleep(0.001)
                # FIXME: We add some sleep here to avoid some race conditions.
                # We can use a read-write lock as a better fix.
                await asyncio.sleep(0.01)

            # Prepare the result slots before sending so a reply can never
            # arrive before they exist.
            self.model_update_result = asyncio.Future()
            self.model_update_tmp = []
            self.send_to_scheduler.send_pyobj(obj)

            if self.server_args.dp_size == 1:
                result = await self.model_update_result
                if result.success:
                    self.server_args.model_path = obj.model_path
                    self.server_args.load_format = obj.load_format
                    self.model_path = obj.model_path
                return result.success, result.message
            else:  # self.server_args.dp_size > 1
                result = await self.model_update_result

                all_success = all(r.success for r in result)
                if all_success:
                    self.server_args.model_path = obj.model_path
                    self.server_args.load_format = obj.load_format
                    self.model_path = obj.model_path
                all_message = " | ".join(r.message for r in result)
                return all_success, all_message

    async def init_weights_update_group(
        self,
        obj: InitWeightsUpdateGroupReqInput,
        request: Optional[fastapi.Request] = None,
    ) -> Tuple[bool, str]:
        """Ask the scheduler to initialise the weight-update process group.

        Returns:
            ``(success, message)`` from the scheduler's reply.
        """
        if self.to_create_loop:
            self.create_handle_loop()
        # Validate before sending, otherwise the scheduler would act on the
        # request while handle_loop later fails on the same assertion.
        assert (
            self.server_args.dp_size == 1
        ), "dp_size must be 1 for init parameter update group"

        self.init_weights_update_group_result = asyncio.Future()
        self.send_to_scheduler.send_pyobj(obj)
        result = await self.init_weights_update_group_result
        return result.success, result.message

    async def update_weights_from_distributed(
        self,
        obj: UpdateWeightsFromDistributedReqInput,
        request: Optional[fastapi.Request] = None,
    ):
        """Update weights broadcast through the distributed update group.

        Returns:
            ``(success, message)``; fails fast if another update holds the lock.
        """
        if self.to_create_loop:
            self.create_handle_loop()

        if self.model_update_lock.locked():
            logger.error("Another parameter update is in progress in tokenizer manager")
            return (
                False,
                "Another parameter update is in progress. Please try again later.",
            )

        async with self.model_update_lock:
            # Validate before sending (see init_weights_update_group).
            assert (
                self.server_args.dp_size == 1
            ), "dp_size must be 1 for update weights from distributed"
            self.parameter_update_result = asyncio.Future()
            self.send_to_scheduler.send_pyobj(obj)
            result = await self.parameter_update_result
            return result.success, result.message

    async def get_weights_by_name(
        self, obj: GetWeightsByNameReqInput, request: Optional[fastapi.Request] = None
    ):
        """Fetch a parameter by name: one value for dp_size == 1, else a list
        with one value per DP rank."""
        if self.to_create_loop:
            self.create_handle_loop()

        self.get_weights_by_name_result = asyncio.Future()
        self.get_weights_by_name_tmp = []
        self.send_to_scheduler.send_pyobj(obj)
        result = await self.get_weights_by_name_result
        if self.server_args.dp_size == 1:
            return result.parameter
        return [r.parameter for r in result]

    async def open_session(
        self, obj: OpenSessionReqInput, request: Optional[fastapi.Request] = None
    ):
        """Open a new session under a freshly generated id and return the id
        echoed back by the scheduler."""
        if self.to_create_loop:
            self.create_handle_loop()

        session_id = uuid.uuid4().hex
        obj.session_id = session_id
        future = asyncio.Future()
        self.session_futures[session_id] = future
        self.send_to_scheduler.send_pyobj(obj)
        try:
            return await future
        finally:
            # Always drop the entry, also when the caller is cancelled.
            self.session_futures.pop(session_id, None)

    async def close_session(
        self, obj: CloseSessionReqInput, request: Optional[fastapi.Request] = None
    ):
        """Forward a close-session request to the scheduler."""
        assert not self.to_create_loop, "close session should not be the first request"
        await self.send_to_scheduler.send_pyobj(obj)

    def create_abort_task(self, obj: GenerateReqInput):
        """Return BackgroundTasks that abort ``obj``'s request(s) after 1 second."""

        # Abort the request if the client is disconnected.
        async def abort_request():
            await asyncio.sleep(1)
            if obj.is_single:
                self.abort_request(obj.rid)
            else:
                for rid in obj.rid:
                    self.abort_request(rid)

        background_tasks = BackgroundTasks()
        background_tasks.add_task(abort_request)
        return background_tasks

    def create_handle_loop(self):
        """Start handle_loop and the SIGTERM watchdog on the current event loop (once)."""
        if not self.to_create_loop:
            return

        self.to_create_loop = False
        loop = asyncio.get_event_loop()
        loop.create_task(self.handle_loop())

        signal_handler = SignalHandler(self)
        loop.add_signal_handler(signal.SIGTERM, signal_handler.signal_handler)
        loop.create_task(self.sigterm_watchdog())

    async def sigterm_watchdog(self):
        """After a graceful exit is requested, wait until no request is tracked,
        then kill the process tree and exit."""
        while not self.gracefully_exit:
            await asyncio.sleep(60)

        # drain requests
        while True:
            remain_num_req = len(self.rid_to_state)
            logger.info(
                f"Gracefully exiting... remaining number of requests {remain_num_req}"
            )
            if remain_num_req > 0:
                await asyncio.sleep(5)
            else:
                break

        kill_process_tree(os.getpid(), include_parent=True)
        sys.exit(0)

    async def handle_loop(self):
        """The event loop that handles requests.

        Receives objects from the detokenizer and routes them to the waiting
        request states or to the pending control-request futures.

        Raises:
            ValueError: on an unexpected object type.
        """
        while True:
            recv_obj: Union[
                BatchStrOut,
                BatchEmbeddingOut,
                BatchTokenIDOut,
                OpenSessionReqOutput,
                UpdateWeightFromDiskReqOutput,
                UpdateWeightsFromDistributedReqOutput,
                GetWeightsByNameReqOutput,
                InitWeightsUpdateGroupReqOutput,
            ] = await self.recv_from_detokenizer.recv_pyobj()

            if isinstance(recv_obj, (BatchStrOut, BatchEmbeddingOut, BatchTokenIDOut)):
                self._handle_batch_output(recv_obj)
            elif isinstance(recv_obj, OpenSessionReqOutput):
                # The opener may already have been cancelled and removed its
                # future; ignore late replies instead of crashing this loop.
                future = self.session_futures.get(recv_obj.session_id)
                if future is not None and not future.done():
                    future.set_result(recv_obj.session_id)
            elif isinstance(recv_obj, UpdateWeightFromDiskReqOutput):
                if self.server_args.dp_size == 1:
                    self.model_update_result.set_result(recv_obj)
                else:  # self.server_args.dp_size > 1
                    self.model_update_tmp.append(recv_obj)
                    # set future if the all results are received
                    if len(self.model_update_tmp) == self.server_args.dp_size:
                        self.model_update_result.set_result(self.model_update_tmp)
            elif isinstance(recv_obj, InitWeightsUpdateGroupReqOutput):
                assert (
                    self.server_args.dp_size == 1
                ), "dp_size must be 1 for init parameter update group"
                self.init_weights_update_group_result.set_result(recv_obj)
            elif isinstance(recv_obj, UpdateWeightsFromDistributedReqOutput):
                assert (
                    self.server_args.dp_size == 1
                ), "dp_size must be 1 for update weights from distributed"
                self.parameter_update_result.set_result(recv_obj)
            elif isinstance(recv_obj, GetWeightsByNameReqOutput):
                if self.server_args.dp_size == 1:
                    self.get_weights_by_name_result.set_result(recv_obj)
                else:
                    self.get_weights_by_name_tmp.append(recv_obj)
                    if len(self.get_weights_by_name_tmp) == self.server_args.dp_size:
                        self.get_weights_by_name_result.set_result(
                            self.get_weights_by_name_tmp
                        )
            else:
                raise ValueError(f"Invalid object: {recv_obj=}")

    def _handle_batch_output(
        self, recv_obj: Union[BatchStrOut, BatchEmbeddingOut, BatchTokenIDOut]
    ):
        """Append each item of a batched detokenizer output to its request's
        state and wake up the waiter."""
        for i, rid in enumerate(recv_obj.rids):
            state = self.rid_to_state.get(rid, None)
            if state is None:
                # Already finished or aborted (its state was removed).
                continue

            meta_info = recv_obj.meta_info[i]
            meta_info["id"] = rid
            if isinstance(recv_obj, BatchStrOut):
                out_dict = {
                    "text": recv_obj.output_strs[i],
                    "meta_info": meta_info,
                }
            elif isinstance(recv_obj, BatchTokenIDOut):
                out_dict = {
                    "token_ids": recv_obj.output_ids[i],
                    "meta_info": meta_info,
                }
            else:
                assert isinstance(recv_obj, BatchEmbeddingOut)
                out_dict = {
                    "embedding": recv_obj.embeddings[i],
                    "meta_info": meta_info,
                }
            state.out_list.append(out_dict)
            state.finished = recv_obj.finished_reason[i] is not None
            state.event.set()

            if self.enable_metrics:
                self._record_metrics(state, meta_info)

    def _record_metrics(self, state: ReqState, meta_info: dict):
        """Update the metrics collector after a new output for one request."""
        completion_tokens = meta_info["completion_tokens"]

        if state.first_token_time is None:
            state.first_token_time = time.time()
            self.metrics_collector.observe_time_to_first_token(
                state.first_token_time - state.created_time
            )
        elif completion_tokens >= 2:
            self.metrics_collector.observe_time_per_output_token(
                (time.time() - state.first_token_time) / (completion_tokens - 1)
            )

        if state.finished:
            self.metrics_collector.inc_prompt_tokens(meta_info["prompt_tokens"])
            self.metrics_collector.inc_generation_tokens(completion_tokens)
            self.metrics_collector.observe_e2e_request_latency(
                time.time() - state.created_time
            )
            if completion_tokens >= 1:
                self.metrics_collector.observe_time_per_output_token(
                    (time.time() - state.created_time) / completion_tokens
                )

    def convert_logprob_style(
        self,
        ret: dict,
        return_logprob: bool,
        top_logprobs_num: int,
        return_text_in_logprobs: bool,
    ):
        """Rewrite the logprob entries of ``ret["meta_info"]`` in place into
        ``(logprob, token_id, text_or_None)`` triples and return ``ret``."""
        if return_logprob:
            ret["meta_info"]["input_token_logprobs"] = self.detokenize_logprob_tokens(
                ret["meta_info"]["input_token_logprobs"], return_text_in_logprobs
            )
            ret["meta_info"]["output_token_logprobs"] = self.detokenize_logprob_tokens(
                ret["meta_info"]["output_token_logprobs"], return_text_in_logprobs
            )

            if top_logprobs_num > 0:
                ret["meta_info"]["input_top_logprobs"] = (
                    self.detokenize_top_logprobs_tokens(
                        ret["meta_info"]["input_top_logprobs"],
                        return_text_in_logprobs,
                    )
                )
                ret["meta_info"]["output_top_logprobs"] = (
                    self.detokenize_top_logprobs_tokens(
                        ret["meta_info"]["output_top_logprobs"], return_text_in_logprobs
                    )
                )
        return ret

    def detokenize_logprob_tokens(
        self, token_logprobs: List[Tuple[float, int]], decode_to_text: bool
    ):
        """Turn ``(logprob, token_id)`` pairs into ``(logprob, token_id, text)``
        triples; text is None unless ``decode_to_text``."""
        # TODO(lianmin): This should run on DetokenizerManager
        if not decode_to_text:
            return [(logprob, token_id, None) for logprob, token_id in token_logprobs]

        assert self.tokenizer is not None
        token_ids = [tid for _, tid in token_logprobs]
        token_texts = self.tokenizer.batch_decode(token_ids)
        return [
            (logprob, token_id, token_text)
            for (logprob, token_id), token_text in zip(token_logprobs, token_texts)
        ]

    def detokenize_top_logprobs_tokens(self, top_logprobs, decode_to_text: bool):
        """Apply ``detokenize_logprob_tokens`` to every non-empty position of
        ``top_logprobs`` in place and return the same list."""
        # TODO: The current implementation only batches the detokenization for top-k tokens per single position.
        # We should batch all top-k tokens in all positions.
        for i, token_top_logprobs in enumerate(top_logprobs):
            if token_top_logprobs:
                top_logprobs[i] = self.detokenize_logprob_tokens(
                    token_top_logprobs, decode_to_text
                )
        return top_logprobs


class SignalHandler:
    """Signal callback bound to a TokenizerManager.

    It only flags the manager for a graceful exit; the manager's own
    watchdog task performs the draining and shutdown.
    """

    def __init__(self, tokenizer_manager):
        self.tokenizer_manager = tokenizer_manager

    def signal_handler(self, signum=None, frame=None):
        """Log the received signal and mark the manager as gracefully exiting."""
        manager = self.tokenizer_manager
        logger.warning(
            f"SIGTERM received. {signum=} {frame=}. Draining requests and shutting down..."
        )
        manager.gracefully_exit = True