# Copyright 2023-2024 SGLang Team
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""TokenizerManager is a process that tokenizes the text."""

import asyncio
import copy
import dataclasses
import logging
import os
import pickle
import signal
import sys
import threading
import time
import uuid
from datetime import datetime
from http import HTTPStatus
from typing import Any, Awaitable, Dict, Generic, List, Optional, Tuple, TypeVar, Union

import fastapi
import uvloop
import zmq
import zmq.asyncio
from fastapi import BackgroundTasks

from sglang.srt.aio_rwlock import RWLock
from sglang.srt.configs.model_config import ModelConfig
from sglang.srt.hf_transformers_utils import get_processor, get_tokenizer
from sglang.srt.managers.image_processor import (
    get_dummy_image_processor,
    get_image_processor,
)
from sglang.srt.managers.io_struct import (
    AbortReq,
    BatchEmbeddingOut,
    BatchStrOut,
    BatchTokenIDOut,
    CloseSessionReqInput,
    ConfigureLoggingReq,
    EmbeddingReqInput,
    FlushCacheReq,
    GenerateReqInput,
    GetWeightsByNameReqInput,
    GetWeightsByNameReqOutput,
    InitWeightsUpdateGroupReqInput,
    InitWeightsUpdateGroupReqOutput,
    OpenSessionReqInput,
    OpenSessionReqOutput,
    ProfileReq,
    ReleaseMemoryOccupationReqInput,
    ReleaseMemoryOccupationReqOutput,
    ResumeMemoryOccupationReqInput,
    ResumeMemoryOccupationReqOutput,
    SessionParams,
    TokenizedEmbeddingReqInput,
    TokenizedGenerateReqInput,
    UpdateWeightFromDiskReqInput,
    UpdateWeightFromDiskReqOutput,
    UpdateWeightsFromDistributedReqInput,
    UpdateWeightsFromDistributedReqOutput,
    UpdateWeightsFromTensorReqInput,
    UpdateWeightsFromTensorReqOutput,
)
from sglang.srt.metrics.collector import TokenizerMetricsCollector
from sglang.srt.sampling.sampling_params import SamplingParams
from sglang.srt.server_args import PortArgs, ServerArgs
from sglang.srt.utils import (
    dataclass_to_string_truncated,
    get_zmq_socket,
    kill_process_tree,
)
from sglang.utils import get_exception_traceback

asyncio.set_event_loop_policy(uvloop.EventLoopPolicy())

logger = logging.getLogger(__name__)


@dataclasses.dataclass
class ReqState:
    """Store the state of a request."""

    out_list: List
    finished: bool
    event: asyncio.Event
    obj: Any

    # For metrics
    created_time: float
    first_token_time: Optional[float] = None

    # For streaming output
    last_output_offset: int = 0


class TokenizerManager:
    """TokenizerManager is a process that tokenizes the text."""
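    # Typical flow (illustrative sketch only, not the actual server wiring): the
    # HTTP server owns one TokenizerManager and iterates the async generator
    # returned by `generate_request`, e.g.
    #
    #     manager = TokenizerManager(server_args, port_args)
    #     async for out in manager.generate_request(GenerateReqInput(text="Hi")):
    #         ...  # each `out` is a dict with "text" / "token_ids" / "embedding" plus "meta_info"
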
    def __init__(
        self,
        server_args: ServerArgs,
        port_args: PortArgs,
    ):
        # Parse args
        self.server_args = server_args
        self.enable_metrics = server_args.enable_metrics
        self.log_requests = server_args.log_requests

        # Init inter-process communication
        context = zmq.asyncio.Context(2)
        self.recv_from_detokenizer = get_zmq_socket(
            context, zmq.PULL, port_args.tokenizer_ipc_name, True
        )
        self.send_to_scheduler = get_zmq_socket(
            context, zmq.PUSH, port_args.scheduler_input_ipc_name, True
        )

        # Read model args
        self.model_path = server_args.model_path
        self.served_model_name = server_args.served_model_name
        self.model_config = ModelConfig(
            server_args.model_path,
            trust_remote_code=server_args.trust_remote_code,
            revision=server_args.revision,
            context_length=server_args.context_length,
            model_override_args=server_args.json_model_override_args,
            is_embedding=server_args.is_embedding,
            dtype=server_args.dtype,
            quantization=server_args.quantization,
        )

        self.is_generation = self.model_config.is_generation
        self.context_len = self.model_config.context_len
        self.image_token_id = self.model_config.image_token_id

        # Create image processor placeholder
        self.image_processor = get_dummy_image_processor()

        # Create tokenizer
        if server_args.skip_tokenizer_init:
            self.tokenizer = self.processor = None
        else:
            if self.model_config.is_multimodal:
                self.processor = get_processor(
                    server_args.tokenizer_path,
                    tokenizer_mode=server_args.tokenizer_mode,
                    trust_remote_code=server_args.trust_remote_code,
                )
                self.tokenizer = self.processor.tokenizer
                os.environ["TOKENIZERS_PARALLELISM"] = "false"

                # We want to parallelize the image pre-processing so we create an executor for it
                self.image_processor = get_image_processor(
                    self.model_config.hf_config, server_args, self.processor
                )
            else:
                self.tokenizer = get_tokenizer(
                    server_args.tokenizer_path,
                    tokenizer_mode=server_args.tokenizer_mode,
                    trust_remote_code=server_args.trust_remote_code,
                )

        # Store states
        self.to_create_loop = True
        self.rid_to_state: Dict[str, ReqState] = {}
        self.dump_requests_folder = ""  # By default do not dump
        self.dump_requests_threshold = 1000
        self.dump_request_list: List[Tuple] = []

        # The event to notify the weight sync is finished.
        self.model_update_lock = RWLock()
        self.model_update_result: Optional[Awaitable[UpdateWeightFromDiskReqOutput]] = (
            None
        )
        self.asyncio_tasks = set()

        # For session info
        self.session_futures = {}  # session_id -> asyncio event

        # Others
        self.gracefully_exit = False
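
        # Each _Communicator below sends one request to the scheduler and waits
        # for `dp_size` replies (one per data-parallel replica) before returning;
        # see the _Communicator helper at the end of this file.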
        self.init_weights_update_group_communicator = _Communicator(
            self.send_to_scheduler, server_args.dp_size
        )
        self.update_weights_from_distributed_communicator = _Communicator(
            self.send_to_scheduler, server_args.dp_size
        )
        self.update_weights_from_tensor_communicator = _Communicator(
            self.send_to_scheduler, server_args.dp_size
        )
        self.get_weights_by_name_communicator = _Communicator(
            self.send_to_scheduler, server_args.dp_size
        )
        self.release_memory_occupation_communicator = _Communicator(
            self.send_to_scheduler, server_args.dp_size
        )
        self.resume_memory_occupation_communicator = _Communicator(
            self.send_to_scheduler, server_args.dp_size
        )

        # Set after scheduler is initialized
        self.max_req_input_len = None

        # Metrics
        if self.enable_metrics:
            self.metrics_collector = TokenizerMetricsCollector(
                labels={
                    "model_name": self.server_args.served_model_name,
                    # TODO: Add lora name/path in the future,
                },
            )

    async def generate_request(
        self,
        obj: Union[GenerateReqInput, EmbeddingReqInput],
        request: Optional[fastapi.Request] = None,
    ):
        created_time = time.time()

        self.auto_create_handle_loop()

        if isinstance(obj, EmbeddingReqInput) and self.is_generation:
            raise ValueError(
                "This model does not appear to be an embedding model by default. "
                "Please add `--is-embedding` when launching the server or try another model."
            )

        obj.normalize_batch_and_arguments()

        if self.log_requests:
            logger.info(f"Receive: obj={dataclass_to_string_truncated(obj)}")

        async with self.model_update_lock.reader_lock:
            is_single = obj.is_single
            if is_single:
                tokenized_obj = await self._tokenize_one_request(obj)
                self._send_one_request(obj, tokenized_obj, created_time)
                async for response in self._wait_one_response(obj, request):
                    yield response
            else:
                async for response in self._handle_batch_request(
                    obj, request, created_time
                ):
                    yield response

    async def _tokenize_one_request(
        self,
        obj: Union[GenerateReqInput, EmbeddingReqInput],
    ):
        """Tokenize one request."""
        # Tokenize
        input_embeds = None
        input_text = obj.text
        if obj.input_embeds is not None:
            if not self.server_args.disable_radix_cache:
                raise ValueError(
                    "input_embeds is provided while disable_radix_cache is False. "
                    "Please add `--disable-radix-cache` when you launch the server "
                    "if you want to use input_embeds as inputs."
                )
            input_embeds = obj.input_embeds
            input_ids = obj.input_ids
        elif obj.input_ids is not None:
            input_ids = obj.input_ids
        else:
            if self.tokenizer is None:
                raise ValueError(
                    "The engine initialized with skip_tokenizer_init=True cannot "
                    "accept text prompts. Please provide input_ids or re-initialize "
                    "the engine with skip_tokenizer_init=False."
                )
            input_ids = self.tokenizer.encode(input_text)

        if self.is_generation:
            # TODO: also support getting embeddings for multimodal models
            image_inputs: Dict = await self.image_processor.process_images_async(
                obj.image_data, input_text or input_ids, obj, self.max_req_input_len
            )
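            # If the image processor returns an expanded `input_ids` (e.g. with image
            # placeholder tokens already spliced in -- an assumption about the
            # processor's output), prefer it over the plain text tokens.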
            if image_inputs and "input_ids" in image_inputs:
                input_ids = image_inputs["input_ids"]
            return_logprob = obj.return_logprob
            logprob_start_len = obj.logprob_start_len
            top_logprobs_num = obj.top_logprobs_num
            session_params = (
                SessionParams(**obj.session_params) if obj.session_params else None
            )

        input_token_num = len(input_ids) if input_ids is not None else 0
        if input_token_num >= self.context_len:
            raise ValueError(
                f"The input ({input_token_num} tokens) is longer than the "
                f"model's context length ({self.context_len} tokens)."
            )

        if (
            obj.sampling_params.get("max_new_tokens") is not None
            and obj.sampling_params.get("max_new_tokens") + input_token_num
            >= self.context_len
        ):
            raise ValueError(
                f"Requested token count exceeds the model's maximum context length "
                f"of {self.context_len} tokens. You requested a total of "
                f"{obj.sampling_params.get('max_new_tokens') + input_token_num} "
                f"tokens: {input_token_num} tokens from the input messages and "
                f"{obj.sampling_params.get('max_new_tokens')} tokens for the "
                f"completion. Please reduce the number of tokens in the input "
                f"messages or the completion to fit within the limit."
            )

        # Parse sampling parameters
        sampling_params = SamplingParams(**obj.sampling_params)
        sampling_params.normalize(self.tokenizer)
        sampling_params.verify()

        # Build return object
        if isinstance(obj, GenerateReqInput):
            tokenized_obj = TokenizedGenerateReqInput(
                obj.rid,
                input_text,
                input_ids,
                image_inputs,
                sampling_params,
                return_logprob,
                logprob_start_len,
                top_logprobs_num,
                obj.stream,
                lora_path=obj.lora_path,
                input_embeds=input_embeds,
                session_params=session_params,
            )
        elif isinstance(obj, EmbeddingReqInput):
            tokenized_obj = TokenizedEmbeddingReqInput(
                obj.rid,
                input_text,
                input_ids,
                sampling_params,
            )

        return tokenized_obj

    def _send_one_request(
        self,
        obj: Union[GenerateReqInput, EmbeddingReqInput],
        tokenized_obj: Union[TokenizedGenerateReqInput, TokenizedEmbeddingReqInput],
        created_time: Optional[float] = None,
    ):
        event = asyncio.Event()
        state = ReqState([], False, event, obj, created_time=created_time)
        self.rid_to_state[obj.rid] = state
        self.send_to_scheduler.send_pyobj(tokenized_obj)

    async def _wait_one_response(
        self,
        obj: Union[GenerateReqInput, EmbeddingReqInput],
        request: Optional[fastapi.Request] = None,
    ):
        """Wait for the response of one request."""
        state = self.rid_to_state[obj.rid]
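
        # Poll with a short timeout so that, even when no new output has arrived,
        # we periodically notice a disconnected client and abort the request.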
        while True:
            try:
                await asyncio.wait_for(state.event.wait(), timeout=4)
            except asyncio.TimeoutError:
                if request is not None and await request.is_disconnected():
                    self.abort_request(obj.rid)
                    raise ValueError(f"Abort request {obj.rid}")
                continue

            out = state.out_list[-1]

            state.out_list = []
            if state.finished:
                if self.log_requests:
                    msg = f"Finish: obj={dataclass_to_string_truncated(obj)}, out={dataclass_to_string_truncated(out)}"
                    logger.info(msg)
                del self.rid_to_state[obj.rid]

                # Check if this was an abort/error created by scheduler
                if isinstance(out["meta_info"].get("finish_reason"), dict):
                    finish_reason = out["meta_info"]["finish_reason"]
                    if (
                        finish_reason.get("type") == "abort"
                        and finish_reason.get("status_code") == HTTPStatus.BAD_REQUEST
                    ):
                        raise ValueError(finish_reason["message"])

                yield out
                break

            state.event.clear()

            if obj.stream:
                yield out
            else:
                if request is not None and await request.is_disconnected():
                    self.abort_request(obj.rid)
                    raise ValueError(f"Abort request {obj.rid}")

    async def _handle_batch_request(
        self,
        obj: Union[GenerateReqInput, EmbeddingReqInput],
        request: Optional[fastapi.Request] = None,
        created_time: Optional[float] = None,
    ):
        batch_size = obj.batch_size

        generators = []
        rids = []
        if getattr(obj, "parallel_sample_num", 1) == 1:
            # Send all requests
            for i in range(batch_size):
                tmp_obj = obj[i]
                tokenized_obj = await self._tokenize_one_request(tmp_obj)
                self._send_one_request(tmp_obj, tokenized_obj, created_time)
                generators.append(self._wait_one_response(tmp_obj, request))
                rids.append(tmp_obj.rid)
        else:
            # FIXME: When using batch and parallel_sample_num together, the perf is not optimal.
            if batch_size > 128:
                logger.warning(
                    "Sending a single large batch with parallel sampling (n > 1) has not been well optimized. "
                    "The performance might be better if you just duplicate the requests n times or use "
                    "many threads to send them one by one with parallel sampling (n > 1)."
                )

            # Tokenize all requests
            objs = [obj[i] for i in range(batch_size)]
            tokenized_objs = await asyncio.gather(
                *(self._tokenize_one_request(obj) for obj in objs)
            )

            # Cache the common prefix for parallel sampling
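            # (Each prompt is sent once with max_new_tokens=0 so its prefill is
            # computed and cached by the scheduler, presumably via the prefix/radix
            # cache, before the n parallel samples below reuse it.)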
            for i in range(batch_size):
                tmp_obj = copy.copy(objs[i])
                tokenized_obj = copy.copy(tokenized_objs[i])
                tokenized_obj.rid = tmp_obj.regenerate_rid()
                tokenized_obj.sampling_params = copy.copy(tokenized_obj.sampling_params)
                tokenized_obj.sampling_params.max_new_tokens = 0
                tokenized_obj.stream = False
                self._send_one_request(tmp_obj, tokenized_obj, created_time)
                await self._wait_one_response(tmp_obj, request).__anext__()

            # Expand requests, assign new rids for them, and send them
            for i in range(batch_size):
                for _ in range(obj.parallel_sample_num):
                    tmp_obj = copy.copy(objs[i])
                    tokenized_obj = copy.copy(tokenized_objs[i])
                    tokenized_obj.rid = tmp_obj.regenerate_rid()
                    self._send_one_request(tmp_obj, tokenized_obj, created_time)
                    generators.append(self._wait_one_response(tmp_obj, request))
                    rids.append(tmp_obj.rid)

        # Wait for all requests
        is_stream = hasattr(obj, "stream") and obj.stream
        if not is_stream:
            outputs = await asyncio.gather(*(gen.__anext__() for gen in generators))
            yield outputs
        else:
            rid_to_index = {rid: i for i, rid in enumerate(rids)}
            task_map = {asyncio.create_task(gen.__anext__()): gen for gen in generators}
            while task_map:
                done, _ = await asyncio.wait(
                    task_map.keys(), return_when=asyncio.FIRST_COMPLETED
                )

                for task in done:
                    gen = task_map.pop(task)
                    try:
                        result = task.result()
                        result["index"] = rid_to_index[result["meta_info"]["id"]]
                        yield result
                        new_task = asyncio.create_task(gen.__anext__())
                        task_map[new_task] = gen
                    except StopAsyncIteration:
                        pass

    def flush_cache(self):
        req = FlushCacheReq()
        self.send_to_scheduler.send_pyobj(req)

    def abort_request(self, rid: str):
        if rid not in self.rid_to_state:
            return
        del self.rid_to_state[rid]
        req = AbortReq(rid)
        self.send_to_scheduler.send_pyobj(req)

    def start_profile(self):
        req = ProfileReq.START_PROFILE
        self.send_to_scheduler.send_pyobj(req)

    def stop_profile(self):
        req = ProfileReq.STOP_PROFILE
        self.send_to_scheduler.send_pyobj(req)

    async def update_weights_from_disk(
        self,
        obj: UpdateWeightFromDiskReqInput,
        request: Optional[fastapi.Request] = None,
    ) -> Tuple[bool, str]:
        self.auto_create_handle_loop()

        # default the load format to the server_args
        if obj.load_format is None:
            obj.load_format = self.server_args.load_format
        logger.info("Start update_weights. Load format=%s", obj.load_format)

        # Hold the lock so that weight sync cannot run while requests are in progress.
        async with self.model_update_lock.writer_lock:
            return await self._wait_for_model_update_from_disk(obj)

    async def _wait_for_model_update_from_disk(
        self, obj: UpdateWeightFromDiskReqInput
    ) -> Tuple[bool, str]:
        self.send_to_scheduler.send_pyobj(obj)
        self.model_update_result = asyncio.Future()
        if self.server_args.dp_size == 1:
            result = await self.model_update_result
            if result.success:
                self.served_model_name = obj.model_path
                self.server_args.model_path = obj.model_path
                self.server_args.load_format = obj.load_format
                self.model_path = obj.model_path
            return result.success, result.message
        else:  # self.server_args.dp_size > 1
            self.model_update_tmp = []
            result = await self.model_update_result

            all_success = all([r.success for r in result])
            if all_success is True:
                self.server_args.model_path = obj.model_path
                self.server_args.load_format = obj.load_format
                self.model_path = obj.model_path
            all_message = [r.message for r in result]
            all_message = " | ".join(all_message)
            return all_success, all_message

    async def init_weights_update_group(
        self,
        obj: InitWeightsUpdateGroupReqInput,
        request: Optional[fastapi.Request] = None,
    ) -> Tuple[bool, str]:
        self.auto_create_handle_loop()
        assert (
            self.server_args.dp_size == 1
        ), "dp_size must be 1 for init parameter update group"
        result = (await self.init_weights_update_group_communicator(obj))[0]
        return result.success, result.message

    async def update_weights_from_distributed(
        self,
        obj: UpdateWeightsFromDistributedReqInput,
        request: Optional[fastapi.Request] = None,
    ) -> Tuple[bool, str]:
        self.auto_create_handle_loop()
        assert (
            self.server_args.dp_size == 1
        ), "dp_size must be 1 for update weights from distributed"

        # Hold the writer lock so that weight sync cannot run while requests are in progress.
        async with self.model_update_lock.writer_lock:
            result = (await self.update_weights_from_distributed_communicator(obj))[0]
            return result.success, result.message

    async def update_weights_from_tensor(
        self,
        obj: UpdateWeightsFromTensorReqInput,
        request: Optional[fastapi.Request] = None,
    ) -> Tuple[bool, str]:
        self.auto_create_handle_loop()
        assert (
            self.server_args.dp_size == 1
        ), "dp_size must be 1 for update weights from tensor"

        # Hold the writer lock so that weight sync cannot run while requests are in progress.
        async with self.model_update_lock.writer_lock:
            result = (await self.update_weights_from_tensor_communicator(obj))[0]
            return result.success, result.message

    async def get_weights_by_name(
        self, obj: GetWeightsByNameReqInput, request: Optional[fastapi.Request] = None
    ):
        self.auto_create_handle_loop()
        results = await self.get_weights_by_name_communicator(obj)
        all_parameters = [r.parameter for r in results]
        if self.server_args.dp_size == 1:
            return all_parameters[0]
        else:
            return all_parameters

    async def release_memory_occupation(
        self,
        obj: ReleaseMemoryOccupationReqInput,
        request: Optional[fastapi.Request] = None,
    ):
        self.auto_create_handle_loop()
        await self.release_memory_occupation_communicator(obj)

    async def resume_memory_occupation(
        self,
        obj: ResumeMemoryOccupationReqInput,
        request: Optional[fastapi.Request] = None,
    ):
        self.auto_create_handle_loop()
        await self.resume_memory_occupation_communicator(obj)

    async def open_session(
        self, obj: OpenSessionReqInput, request: Optional[fastapi.Request] = None
    ):
        self.auto_create_handle_loop()

        if obj.session_id is None:
            obj.session_id = uuid.uuid4().hex
        elif obj.session_id in self.session_futures:
            return None

        self.send_to_scheduler.send_pyobj(obj)
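
        # The scheduler replies with an OpenSessionReqOutput; handle_loop resolves
        # the future below with the session id (or None if opening failed).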
        self.session_futures[obj.session_id] = asyncio.Future()
        session_id = await self.session_futures[obj.session_id]
        del self.session_futures[obj.session_id]
        return session_id

    async def close_session(
        self, obj: CloseSessionReqInput, request: Optional[fastapi.Request] = None
    ):
        assert not self.to_create_loop, "close session should not be the first request"
        await self.send_to_scheduler.send_pyobj(obj)

    def configure_logging(self, obj: ConfigureLoggingReq):
        if obj.log_requests is not None:
            self.log_requests = obj.log_requests
        if obj.dump_requests_folder is not None:
            self.dump_requests_folder = obj.dump_requests_folder
        if obj.dump_requests_threshold is not None:
            self.dump_requests_threshold = obj.dump_requests_threshold
        logging.info(f"Config logging: {obj=}")

    def create_abort_task(self, obj: GenerateReqInput):
        # Abort the request if the client is disconnected.
        async def abort_request():
            await asyncio.sleep(1)
            if obj.is_single:
                self.abort_request(obj.rid)
            else:
                for rid in obj.rid:
                    self.abort_request(rid)

        background_tasks = BackgroundTasks()
        background_tasks.add_task(abort_request)
        return background_tasks

    def auto_create_handle_loop(self):
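        # Created lazily on the first request so that the handle loop and the
        # SIGTERM watchdog are scheduled on the event loop in use at that time.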
        if not self.to_create_loop:
            return

        self.to_create_loop = False
        loop = asyncio.get_event_loop()
        self.asyncio_tasks.add(
            loop.create_task(print_exception_wrapper(self.handle_loop))
        )

        # We cannot add signal handler when the tokenizer manager is not in
        # the main thread due to the CPython limitation.
        if threading.current_thread() is threading.main_thread():
            signal_handler = SignalHandler(self)
            loop.add_signal_handler(signal.SIGTERM, signal_handler.signal_handler)
        else:
            logger.warning(
                "Signal handler is not added because the tokenizer manager is "
                "not in the main thread. This disables graceful shutdown of the "
                "tokenizer manager when SIGTERM is received."
            )
        self.asyncio_tasks.add(
            loop.create_task(print_exception_wrapper(self.sigterm_watchdog))
        )

    async def sigterm_watchdog(self):
        while not self.gracefully_exit:
            await asyncio.sleep(5)

        # Drain requests
        while True:
            remain_num_req = len(self.rid_to_state)
            logger.info(
                f"Gracefully exiting... remaining number of requests {remain_num_req}"
            )
            if remain_num_req > 0:
                await asyncio.sleep(5)
            else:
                break

        kill_process_tree(os.getpid(), include_parent=True)
        sys.exit(0)

    async def handle_loop(self):
        """The event loop that handles requests"""

        while True:
            recv_obj: Union[
                BatchStrOut,
                BatchEmbeddingOut,
                BatchTokenIDOut,
                OpenSessionReqOutput,
                UpdateWeightFromDiskReqOutput,
                UpdateWeightsFromDistributedReqOutput,
                UpdateWeightsFromTensorReqOutput,
                GetWeightsByNameReqOutput,
                InitWeightsUpdateGroupReqOutput,
                ReleaseMemoryOccupationReqOutput,
                ResumeMemoryOccupationReqOutput,
            ] = await self.recv_from_detokenizer.recv_pyobj()

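            # Batch outputs below are fanned back out to the per-request ReqState;
            # every other message type is a control-plane reply that resolves a
            # pending future or _Communicator.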
            if isinstance(recv_obj, (BatchStrOut, BatchEmbeddingOut, BatchTokenIDOut)):
                for i, rid in enumerate(recv_obj.rids):
                    state = self.rid_to_state.get(rid, None)
                    if state is None:
                        continue

                    meta_info = {
                        "id": rid,
                        "finish_reason": recv_obj.finished_reasons[i],
                        "prompt_tokens": recv_obj.prompt_tokens[i],
                    }

                    if getattr(state.obj, "return_logprob", False):
                        self.convert_logprob_style(
                            meta_info,
                            state.obj.top_logprobs_num,
                            state.obj.return_text_in_logprobs,
                            recv_obj,
                            i,
                        )

                    if not isinstance(recv_obj, BatchEmbeddingOut):
                        meta_info.update(
                            {
                                "completion_tokens": recv_obj.completion_tokens[i],
                                "cached_tokens": recv_obj.cached_tokens[i],
                            }
                        )

                    if isinstance(recv_obj, BatchStrOut):
                        out_dict = {
                            "text": recv_obj.output_strs[i],
                            "meta_info": meta_info,
                        }
                    elif isinstance(recv_obj, BatchTokenIDOut):
                        out_dict = {
                            "token_ids": recv_obj.output_ids[i],
                            "meta_info": meta_info,
                        }
                    else:
                        assert isinstance(recv_obj, BatchEmbeddingOut)
                        out_dict = {
                            "embedding": recv_obj.embeddings[i],
                            "meta_info": meta_info,
                        }
                    state.out_list.append(out_dict)
                    state.finished = recv_obj.finished_reasons[i] is not None
                    state.event.set()

                    if self.enable_metrics and state.obj.log_metrics:
                        self.collect_metrics(state, recv_obj, i)
                    if (
                        self.dump_requests_folder
                        and state.finished
                        and state.obj.log_metrics
                    ):
                        self.dump_requests(state, out_dict)
            elif isinstance(recv_obj, OpenSessionReqOutput):
                self.session_futures[recv_obj.session_id].set_result(
                    recv_obj.session_id if recv_obj.success else None
                )
            elif isinstance(recv_obj, UpdateWeightFromDiskReqOutput):
                if self.server_args.dp_size == 1:
                    self.model_update_result.set_result(recv_obj)
                else:  # self.server_args.dp_size > 1
                    self.model_update_tmp.append(recv_obj)
                    # set the future once all results are received
                    if len(self.model_update_tmp) == self.server_args.dp_size:
                        self.model_update_result.set_result(self.model_update_tmp)
            elif isinstance(recv_obj, InitWeightsUpdateGroupReqOutput):
                assert (
                    self.server_args.dp_size == 1
                ), "dp_size must be 1 for init parameter update group"
                self.init_weights_update_group_communicator.handle_recv(recv_obj)
            elif isinstance(recv_obj, UpdateWeightsFromDistributedReqOutput):
                assert (
                    self.server_args.dp_size == 1
                ), "dp_size must be 1 for update weights from distributed"
                self.update_weights_from_distributed_communicator.handle_recv(recv_obj)
            elif isinstance(recv_obj, UpdateWeightsFromTensorReqOutput):
                assert (
                    self.server_args.dp_size == 1
                ), "dp_size must be 1 for update weights from tensor"
                self.update_weights_from_tensor_communicator.handle_recv(recv_obj)
            elif isinstance(recv_obj, GetWeightsByNameReqOutput):
                self.get_weights_by_name_communicator.handle_recv(recv_obj)
            elif isinstance(recv_obj, ReleaseMemoryOccupationReqOutput):
                self.release_memory_occupation_communicator.handle_recv(recv_obj)
            elif isinstance(recv_obj, ResumeMemoryOccupationReqOutput):
                self.resume_memory_occupation_communicator.handle_recv(recv_obj)
            else:
                raise ValueError(f"Invalid object: {recv_obj=}")

    def convert_logprob_style(
        self,
        meta_info: dict,
        top_logprobs_num: int,
        return_text_in_logprobs: bool,
        recv_obj: BatchStrOut,
        recv_obj_index: int,
    ):
        meta_info["input_token_logprobs"] = self.detokenize_logprob_tokens(
            recv_obj.input_token_logprobs_val[recv_obj_index],
            recv_obj.input_token_logprobs_idx[recv_obj_index],
            return_text_in_logprobs,
        )
        meta_info["output_token_logprobs"] = self.detokenize_logprob_tokens(
            recv_obj.output_token_logprobs_val[recv_obj_index],
            recv_obj.output_token_logprobs_idx[recv_obj_index],
            return_text_in_logprobs,
        )

        if top_logprobs_num > 0:
            meta_info["input_top_logprobs"] = self.detokenize_top_logprobs_tokens(
                recv_obj.input_top_logprobs_val[recv_obj_index],
                recv_obj.input_top_logprobs_idx[recv_obj_index],
                return_text_in_logprobs,
            )
            meta_info["output_top_logprobs"] = self.detokenize_top_logprobs_tokens(
                recv_obj.output_top_logprobs_val[recv_obj_index],
                recv_obj.output_top_logprobs_idx[recv_obj_index],
                return_text_in_logprobs,
            )

    def detokenize_logprob_tokens(
        self,
        token_logprobs_val: List[float],
        token_logprobs_idx: List[int],
        decode_to_text: bool,
    ):
        if not decode_to_text:
            return [
                (logprob, token_id, None)
                for logprob, token_id in zip(token_logprobs_val, token_logprobs_idx)
            ]
        else:
            assert self.tokenizer is not None
            token_texts = self.tokenizer.batch_decode(token_logprobs_idx)
            return list(zip(token_logprobs_val, token_logprobs_idx, token_texts))

    def detokenize_top_logprobs_tokens(
        self,
        token_logprobs_val: List[float],
        token_logprobs_idx: List[int],
        decode_to_text: bool,
    ):
        # TODO: The current implementation only batches the detokenization for top-k tokens per single position.
        # We should batch all top-k tokens in all positions.
        ret = []
        for i in range(len(token_logprobs_val)):
            if token_logprobs_val[i]:
                ret.append(
                    self.detokenize_logprob_tokens(
                        token_logprobs_val[i], token_logprobs_idx[i], decode_to_text
                    )
                )
            else:
                ret.append(None)
        return ret

    def collect_metrics(self, state: ReqState, recv_obj: BatchStrOut, i: int):
        completion_tokens = (
            recv_obj.completion_tokens[i]
            if getattr(recv_obj, "completion_tokens", None)
            else 0
        )

        if state.first_token_time is None:
            state.first_token_time = time.time()
            self.metrics_collector.observe_time_to_first_token(
                state.first_token_time - state.created_time
            )
        else:
            if completion_tokens >= 2:
                # Compute time_per_output_token for the streaming case
                self.metrics_collector.observe_time_per_output_token(
                    (time.time() - state.first_token_time) / (completion_tokens - 1)
                )

        if state.finished:
            self.metrics_collector.observe_one_finished_request(
                recv_obj.prompt_tokens[i], completion_tokens
            )
            self.metrics_collector.observe_e2e_request_latency(
                time.time() - state.created_time
            )
            # Compute time_per_output_token for the non-streaming case
            if (
                hasattr(state.obj, "stream")
                and not state.obj.stream
                and completion_tokens >= 1
            ):
                self.metrics_collector.observe_time_per_output_token(
                    (time.time() - state.created_time) / completion_tokens
                )

    def dump_requests(self, state: ReqState, out_dict: dict):
        self.dump_request_list.append(
            (state.obj, out_dict, state.created_time, time.time())
        )

        if len(self.dump_request_list) >= self.dump_requests_threshold:
            filename = os.path.join(
                self.dump_requests_folder,
                datetime.now().strftime("%Y-%m-%d_%H-%M-%S") + ".pkl",
            )
            logger.info(f"Dump {len(self.dump_request_list)} requests to {filename}")

            to_dump = self.dump_request_list
            self.dump_request_list = []

            def background_task():
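                # Runs in a worker thread via asyncio.to_thread (see below), so
                # pickling and file I/O do not block the event loop.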
                os.makedirs(self.dump_requests_folder, exist_ok=True)
                with open(filename, "wb") as f:
                    pickle.dump(to_dump, f)

            # Schedule the task to run in the background without awaiting it
            asyncio.create_task(asyncio.to_thread(background_task))


async def print_exception_wrapper(func):
    """
    Sometimes an asyncio task does not surface its exception.
    Wrap the coroutine so the exception is logged and the process is terminated.
    """
    try:
        await func()
    except Exception:
        traceback = get_exception_traceback()
        logger.error(f"TokenizerManager hit an exception: {traceback}")
        kill_process_tree(os.getpid(), include_parent=True)
        sys.exit(1)


class SignalHandler:
    def __init__(self, tokenizer_manager):
        self.tokenizer_manager = tokenizer_manager

    def signal_handler(self, signum=None, frame=None):
        logger.warning(
            f"SIGTERM received. {signum=} {frame=}. Draining requests and shutting down..."
        )
        self.tokenizer_manager.gracefully_exit = True


T = TypeVar("T")


class _Communicator(Generic[T]):
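    # One-shot request/reply helper: __call__ sends a request to the scheduler and
    # waits until handle_recv() has collected `fan_out` replies (one per
    # data-parallel replica). It keeps a single pending future, so callers should
    # not issue overlapping calls on the same instance.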
    def __init__(self, sender, fan_out: int):
        self._sender = sender
        self._fan_out = fan_out
        self._result_future: Optional[asyncio.Future] = None
        self._result_values: Optional[List[T]] = None

    async def __call__(self, obj):
        self._sender.send_pyobj(obj)
        self._result_future = asyncio.Future()
        self._result_values = []
        await self._result_future
        result_values = self._result_values
        self._result_future = self._result_values = None
        return result_values

    def handle_recv(self, recv_obj: T):
        self._result_values.append(recv_obj)
        if len(self._result_values) == self._fan_out:
            self._result_future.set_result(None)