tokenizer_manager.py 45.6 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
13
# Copyright 2023-2024 SGLang Team
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
14
"""TokenizerManager is a process that tokenizes the text."""
15

Lianmin Zheng's avatar
Lianmin Zheng committed
16
import asyncio
17
18
import copy
import dataclasses
19
import logging
Lianmin Zheng's avatar
Lianmin Zheng committed
20
import os
21
import pickle
22
23
import signal
import sys
24
import threading
25
import time
26
import uuid
27
from collections import deque
28
29
from datetime import datetime
from http import HTTPStatus
30
31
32
33
34
35
36
37
38
39
40
41
from typing import (
    Any,
    Awaitable,
    Deque,
    Dict,
    Generic,
    List,
    Optional,
    Tuple,
    TypeVar,
    Union,
)
Lianmin Zheng's avatar
Lianmin Zheng committed
42

43
import fastapi
Lianmin Zheng's avatar
Lianmin Zheng committed
44
45
46
import uvloop
import zmq
import zmq.asyncio
47
from fastapi import BackgroundTasks
Liangsheng Yin's avatar
Liangsheng Yin committed
48

49
from sglang.srt.aio_rwlock import RWLock
50
from sglang.srt.configs.model_config import ModelConfig
51
52
53
54
55
56
from sglang.srt.disaggregation.utils import (
    DisaggregationMode,
    KVClassType,
    TransferBackend,
    get_kv_class,
)
57
from sglang.srt.hf_transformers_utils import get_processor, get_tokenizer
Lianmin Zheng's avatar
Lianmin Zheng committed
58
from sglang.srt.managers.io_struct import (
59
    AbortReq,
60
    BatchEmbeddingOut,
61
    BatchMultimodalOut,
Lianmin Zheng's avatar
Lianmin Zheng committed
62
    BatchStrOut,
63
    BatchTokenIDOut,
64
    CloseSessionReqInput,
65
    ConfigureLoggingReq,
66
    EmbeddingReqInput,
67
    ExpertDistributionReq,
68
    ExpertDistributionReqOutput,
69
    FlushCacheReq,
Lianmin Zheng's avatar
Lianmin Zheng committed
70
    GenerateReqInput,
71
72
    GetInternalStateReq,
    GetInternalStateReqOutput,
73
74
    GetWeightsByNameReqInput,
    GetWeightsByNameReqOutput,
75
    HealthCheckOutput,
76
77
    InitWeightsUpdateGroupReqInput,
    InitWeightsUpdateGroupReqOutput,
78
79
    OpenSessionReqInput,
    OpenSessionReqOutput,
80
    ProfileReq,
81
82
    ProfileReqOutput,
    ProfileReqType,
83
84
85
86
    ReleaseMemoryOccupationReqInput,
    ReleaseMemoryOccupationReqOutput,
    ResumeMemoryOccupationReqInput,
    ResumeMemoryOccupationReqOutput,
87
88
89
    SessionParams,
    TokenizedEmbeddingReqInput,
    TokenizedGenerateReqInput,
Chayenne's avatar
Chayenne committed
90
91
    UpdateWeightFromDiskReqInput,
    UpdateWeightFromDiskReqOutput,
92
93
    UpdateWeightsFromDistributedReqInput,
    UpdateWeightsFromDistributedReqOutput,
94
95
    UpdateWeightsFromTensorReqInput,
    UpdateWeightsFromTensorReqOutput,
Lianmin Zheng's avatar
Lianmin Zheng committed
96
)
Mick's avatar
Mick committed
97
98
99
100
101
from sglang.srt.managers.multimodal_processor import (
    get_dummy_processor,
    get_mm_processor,
    import_processors,
)
102
103
from sglang.srt.metrics.collector import TokenizerMetricsCollector
from sglang.srt.sampling.sampling_params import SamplingParams
Lianmin Zheng's avatar
Lianmin Zheng committed
104
from sglang.srt.server_args import PortArgs, ServerArgs
105
106
107
108
109
from sglang.srt.utils import (
    dataclass_to_string_truncated,
    get_zmq_socket,
    kill_process_tree,
)
110
from sglang.utils import TypeBasedDispatcher, get_exception_traceback
Lianmin Zheng's avatar
Lianmin Zheng committed
111
112
113

logger = logging.getLogger(__name__)

# Install uvloop's policy so event loops created by asyncio in this process
# use the uvloop implementation.
asyncio.set_event_loop_policy(uvloop.EventLoopPolicy())

Lianmin Zheng's avatar
Lianmin Zheng committed
116

117
118
119
120
121
122
123
124
125
126
127
@dataclasses.dataclass
class ReqState:
    """Store the state of a request.

    One instance is kept per in-flight request id in
    ``TokenizerManager.rid_to_state``.
    """

    # Outputs received for this request that the waiting coroutine has not
    # consumed yet (_wait_one_response reads the last entry and then resets it).
    out_list: List
    # True once the final output for the request has been received.
    finished: bool
    # Awaited by _wait_one_response; presumably set by the output handler when
    # out_list changes (handler not visible here).
    event: asyncio.Event
    # The original (untokenized) request object.
    obj: Any

    # For metrics
    created_time: float
    finished_time: float = 0.0
    first_token_time: float = 0.0
    last_time: float = 0.0
    last_completion_tokens: int = 1

    # For streaming output
    last_output_offset: int = 0


137
138
class TokenizerManager:
    """TokenizerManager is a process that tokenizes the text."""
139

Lianmin Zheng's avatar
Lianmin Zheng committed
140
141
142
143
144
    def __init__(
        self,
        server_args: ServerArgs,
        port_args: PortArgs,
    ):
        """Set up IPC with the scheduler, the model config and tokenizer state.

        Args:
            server_args: Server configuration; model, tokenizer, metrics,
                data-parallel (``dp_size``) and disaggregation settings are
                read from it.
            port_args: IPC endpoint names. ``tokenizer_ipc_name`` is opened as
                a PULL socket (``recv_from_detokenizer``) and
                ``scheduler_input_ipc_name`` as a PUSH socket
                (``send_to_scheduler``).
        """
        # Parse args
        self.server_args = server_args
        self.enable_metrics = server_args.enable_metrics
        self.log_requests = server_args.log_requests
        self.log_requests_level = server_args.log_requests_level

        # Init inter-process communication
        context = zmq.asyncio.Context(2)
        self.recv_from_detokenizer = get_zmq_socket(
            context, zmq.PULL, port_args.tokenizer_ipc_name, True
        )
        self.send_to_scheduler = get_zmq_socket(
            context, zmq.PUSH, port_args.scheduler_input_ipc_name, True
        )

        # Read model args
        self.model_path = server_args.model_path
        self.served_model_name = server_args.served_model_name
        self.model_config = ModelConfig(
            server_args.model_path,
            trust_remote_code=server_args.trust_remote_code,
            revision=server_args.revision,
            context_length=server_args.context_length,
            model_override_args=server_args.json_model_override_args,
            is_embedding=server_args.is_embedding,
            enable_multimodal=server_args.enable_multimodal,
            dtype=server_args.dtype,
            quantization=server_args.quantization,
        )

        self.is_generation = self.model_config.is_generation
        self.is_image_gen = self.model_config.is_image_gen
        self.context_len = self.model_config.context_len
        self.image_token_id = self.model_config.image_token_id

        if self.model_config.is_multimodal:
            import_processors()
            _processor = get_processor(
                server_args.tokenizer_path,
                tokenizer_mode=server_args.tokenizer_mode,
                trust_remote_code=server_args.trust_remote_code,
                revision=server_args.revision,
                use_fast=not server_args.disable_fast_image_processor,
            )

            # The multimodal processor is created even when
            # skip_tokenizer_init=True so that images are still encoded in that
            # mode. NOTE(review): any executor used to parallelize image
            # pre-processing would live inside get_mm_processor (not shown).
            self.mm_processor = get_mm_processor(
                self.model_config.hf_config, server_args, _processor
            )

            if server_args.skip_tokenizer_init:
                self.tokenizer = self.processor = None
            else:
                self.processor = _processor
                self.tokenizer = self.processor.tokenizer
                os.environ["TOKENIZERS_PARALLELISM"] = "false"
        else:
            self.mm_processor = get_dummy_processor()

            if server_args.skip_tokenizer_init:
                self.tokenizer = self.processor = None
            else:
                self.tokenizer = get_tokenizer(
                    server_args.tokenizer_path,
                    tokenizer_mode=server_args.tokenizer_mode,
                    trust_remote_code=server_args.trust_remote_code,
                    revision=server_args.revision,
                )

        # Store states
        self.no_create_loop = False
        # rid -> per-request bookkeeping; filled by _send_one_request and
        # removed when the request finishes or is aborted.
        self.rid_to_state: Dict[str, ReqState] = {}
        self.gracefully_exit = False
        self.last_receive_tstamp = 0
        self.dump_requests_folder = ""  # By default do not dump
        self.dump_requests_threshold = 1000
        self.dump_request_list: List[Tuple] = []
        # (max_length, skip_names, out_skip_names) used when logging requests.
        self.log_request_metadata = self.get_log_request_metadata()

        # Reader/writer lock between request handling and weight updates:
        # generate_request holds the reader side, weight-update methods hold
        # the writer side.
        self.model_update_lock = RWLock()
        # Future that update_weights_from_disk waits on for the scheduler's reply.
        self.model_update_result: Optional[Awaitable[UpdateWeightFromDiskReqOutput]] = (
            None
        )
        self.asyncio_tasks = set()

        # For session info
        self.session_futures = {}  # session_id -> asyncio event

        # Set after scheduler is initialized
        self.max_req_input_len = None

        # Metrics
        if self.enable_metrics:
            self.metrics_collector = TokenizerMetricsCollector(
                labels={
                    "model_name": self.server_args.served_model_name,
                    # TODO: Add lora name/path in the future,
                },
            )

        # Communicators: each sends a control request to the scheduler and is
        # awaited for the replies (a list; presumably one per data-parallel rank,
        # given dp_size — see _Communicator, not shown).
        self.init_weights_update_group_communicator = _Communicator(
            self.send_to_scheduler, server_args.dp_size
        )
        self.update_weights_from_distributed_communicator = _Communicator(
            self.send_to_scheduler, server_args.dp_size
        )
        self.update_weights_from_tensor_communicator = _Communicator(
            self.send_to_scheduler, server_args.dp_size
        )
        self.get_weights_by_name_communicator = _Communicator(
            self.send_to_scheduler, server_args.dp_size
        )
        self.release_memory_occupation_communicator = _Communicator(
            self.send_to_scheduler, server_args.dp_size
        )
        self.resume_memory_occupation_communicator = _Communicator(
            self.send_to_scheduler, server_args.dp_size
        )
        self.start_profile_communicator = _Communicator(
            self.send_to_scheduler, server_args.dp_size
        )
        self.get_internal_state_communicator = _Communicator(
            self.send_to_scheduler, server_args.dp_size
        )
        self.expert_distribution_communicator = _Communicator(
            self.send_to_scheduler, server_args.dp_size
        )

        # Maps the type of each incoming result object to its handler. Must be
        # built after the communicators above, whose handle_recv it references.
        self._result_dispatcher = TypeBasedDispatcher(
            [
                (
                    (
                        BatchStrOut,
                        BatchEmbeddingOut,
                        BatchTokenIDOut,
                        BatchMultimodalOut,
                    ),
                    self._handle_batch_output,
                ),
                (OpenSessionReqOutput, self._handle_open_session_req_output),
                (
                    UpdateWeightFromDiskReqOutput,
                    self._handle_update_weights_from_disk_req_output,
                ),
                (
                    InitWeightsUpdateGroupReqOutput,
                    self.init_weights_update_group_communicator.handle_recv,
                ),
                (
                    UpdateWeightsFromDistributedReqOutput,
                    self.update_weights_from_distributed_communicator.handle_recv,
                ),
                (
                    UpdateWeightsFromTensorReqOutput,
                    self.update_weights_from_tensor_communicator.handle_recv,
                ),
                (
                    GetWeightsByNameReqOutput,
                    self.get_weights_by_name_communicator.handle_recv,
                ),
                (
                    ReleaseMemoryOccupationReqOutput,
                    self.release_memory_occupation_communicator.handle_recv,
                ),
                (
                    ResumeMemoryOccupationReqOutput,
                    self.resume_memory_occupation_communicator.handle_recv,
                ),
                (
                    ProfileReqOutput,
                    self.start_profile_communicator.handle_recv,
                ),
                (
                    GetInternalStateReqOutput,
                    self.get_internal_state_communicator.handle_recv,
                ),
                (
                    ExpertDistributionReqOutput,
                    self.expert_distribution_communicator.handle_recv,
                ),
                # Health-check replies need no handling here.
                (HealthCheckOutput, lambda x: None),
            ]
        )

        self.disaggregation_mode = DisaggregationMode(
            self.server_args.disaggregation_mode
        )
        self.transfer_backend = TransferBackend(
            self.server_args.disaggregation_transfer_backend
        )
        # For disaggregation, start the KV bootstrap server on prefill.
        if self.disaggregation_mode == DisaggregationMode.PREFILL:
            # Only the prefill tokenizer manager starts a bootstrap server; its
            # class is chosen by the configured transfer backend.
            kv_bootstrap_server_class = get_kv_class(
                self.transfer_backend, KVClassType.BOOTSTRAP_SERVER
            )
            self.bootstrap_server = kv_bootstrap_server_class(
                self.server_args.disaggregation_bootstrap_port
            )

349
    async def generate_request(
        self,
        obj: Union[GenerateReqInput, EmbeddingReqInput],
        request: Optional[fastapi.Request] = None,
    ):
        """Tokenize ``obj``, hand it to the scheduler and yield its outputs.

        Single requests are tokenized, sent and awaited directly. Batched
        requests are delegated to ``_handle_batch_request``. The reader side of
        ``model_update_lock`` is held for the whole exchange, so weight updates
        (writer side) cannot interleave with it.

        Raises:
            ValueError: if an embedding request is sent to a generation model.
        """
        created_time = time.time()
        self.auto_create_handle_loop()

        is_embedding_req = isinstance(obj, EmbeddingReqInput)
        if is_embedding_req and self.is_generation:
            raise ValueError(
                "This model does not appear to be an embedding model by default. "
                "Please add `--is-embedding` when launching the server or try another model."
            )

        obj.normalize_batch_and_arguments()

        if self.log_requests:
            max_length, skip_names, _ = self.log_request_metadata
            obj_summary = dataclass_to_string_truncated(
                obj, max_length, skip_names=skip_names
            )
            logger.info(f"Receive: obj={obj_summary}")

        async with self.model_update_lock.reader_lock:
            # Pick the response stream first, then drain it in one place.
            if obj.is_single:
                tokenized = await self._tokenize_one_request(obj)
                self._send_one_request(obj, tokenized, created_time)
                responses = self._wait_one_response(obj, request)
            else:
                responses = self._handle_batch_request(obj, request, created_time)
            async for item in responses:
                yield item

    async def _tokenize_one_request(
        self,
        obj: Union[GenerateReqInput, EmbeddingReqInput],
    ):
        """Tokenize one request and build the object sent to the scheduler.

        Token ids come from, in priority order: ``obj.input_embeds`` (with
        ``obj.input_ids`` passed through unchanged), ``obj.input_ids``, or
        ``self.tokenizer.encode(obj.text)``. The multimodal processor may then
        replace them with its own ``"input_ids"``.

        Returns:
            A ``TokenizedGenerateReqInput`` for generate requests or a
            ``TokenizedEmbeddingReqInput`` for embedding requests.

        Raises:
            ValueError: if a generate request reaches a non-generation model,
                if ``input_embeds`` is used with the radix cache enabled, if
                text is sent to an engine without a tokenizer, or if the input
                (plus ``max_new_tokens``) does not fit in the context window.
        """
        # The generation-only fields below are only read when the model is a
        # generation model; without this guard a generate request on such a
        # model would fail later with an opaque UnboundLocalError.
        if isinstance(obj, GenerateReqInput) and not self.is_generation:
            raise ValueError(
                "This model does not appear to be a generation model. "
                "Please send embedding requests instead, or launch the server "
                "without `--is-embedding` if the model supports generation."
            )

        # Tokenize
        input_embeds = None
        input_text = obj.text
        if obj.input_embeds is not None:
            if not self.server_args.disable_radix_cache:
                raise ValueError(
                    "input_embeds is provided while disable_radix_cache is False. "
                    "Please add `--disable-radix-cache` when you launch the server "
                    "if you want to use input_embeds as inputs."
                )
            input_embeds = obj.input_embeds
            input_ids = obj.input_ids
        elif obj.input_ids is not None:
            input_ids = obj.input_ids
        else:
            if self.tokenizer is None:
                raise ValueError(
                    "The engine initialized with skip_tokenizer_init=True cannot "
                    "accept text prompts. Please provide input_ids or re-initialize "
                    "the engine with skip_tokenizer_init=False."
                )
            input_ids = self.tokenizer.encode(input_text)

        image_inputs: Dict = await self.mm_processor.process_mm_data_async(
            obj.image_data, input_text or input_ids, obj, self.max_req_input_len
        )
        # The multimodal processor may rewrite the prompt ids (e.g. to expand
        # image placeholders); its version wins when present.
        if image_inputs and "input_ids" in image_inputs:
            input_ids = image_inputs["input_ids"]

        if self.is_generation:
            return_logprob = obj.return_logprob
            logprob_start_len = obj.logprob_start_len
            top_logprobs_num = obj.top_logprobs_num
            token_ids_logprob = obj.token_ids_logprob
            session_params = (
                SessionParams(**obj.session_params) if obj.session_params else None
            )

        input_token_num = len(input_ids) if input_ids is not None else 0
        if input_token_num >= self.context_len:
            raise ValueError(
                f"The input ({input_token_num} tokens) is longer than the "
                f"model's context length ({self.context_len} tokens)."
            )

        max_new_tokens = obj.sampling_params.get("max_new_tokens")
        if (
            max_new_tokens is not None
            and max_new_tokens + input_token_num >= self.context_len
        ):
            raise ValueError(
                f"Requested token count exceeds the model's maximum context length "
                f"of {self.context_len} tokens. You requested a total of "
                f"{max_new_tokens + input_token_num} "
                f"tokens: {input_token_num} tokens from the input messages and "
                f"{max_new_tokens} tokens for the "
                f"completion. Please reduce the number of tokens in the input "
                f"messages or the completion to fit within the limit."
            )

        # Parse sampling parameters
        sampling_params = SamplingParams(**obj.sampling_params)
        sampling_params.normalize(self.tokenizer)
        sampling_params.verify()

        # Build return object
        if isinstance(obj, GenerateReqInput):
            tokenized_obj = TokenizedGenerateReqInput(
                obj.rid,
                input_text,
                input_ids,
                image_inputs,
                sampling_params,
                return_logprob,
                logprob_start_len,
                top_logprobs_num,
                token_ids_logprob,
                obj.stream,
                bootstrap_host=obj.bootstrap_host,
                bootstrap_room=obj.bootstrap_room,
                lora_path=obj.lora_path,
                input_embeds=input_embeds,
                session_params=session_params,
                custom_logit_processor=obj.custom_logit_processor,
                return_hidden_states=obj.return_hidden_states,
            )
        elif isinstance(obj, EmbeddingReqInput):
            tokenized_obj = TokenizedEmbeddingReqInput(
                obj.rid,
                input_text,
                input_ids,
                image_inputs,
                sampling_params,
            )

        return tokenized_obj

    def _send_one_request(
        self,
        obj: Union[GenerateReqInput, EmbeddingReqInput],
        tokenized_obj: Union[TokenizedGenerateReqInput, TokenizedEmbeddingReqInput],
        created_time: Optional[float] = None,
    ):
        """Register fresh state for ``obj.rid`` and push the request to the scheduler.

        The state is registered before sending so that it already exists when
        outputs for this rid come back.
        """
        self.rid_to_state[obj.rid] = ReqState(
            out_list=[],
            finished=False,
            event=asyncio.Event(),
            obj=obj,
            created_time=created_time,
        )
        self.send_to_scheduler.send_pyobj(tokenized_obj)

    async def _wait_one_response(
        self,
        obj: Union[GenerateReqInput, EmbeddingReqInput],
        request: Optional[fastapi.Request] = None,
    ):
        """Yield the outputs of the request ``obj.rid`` as they arrive.

        Streaming requests yield every intermediate output, while non-streaming
        requests yield only the final one. The client connection is checked
        after every 4-second wait with no new output, and after each
        intermediate output of a non-streaming request. A disconnected client
        gets its request aborted.

        Raises:
            ValueError: if the client disconnected, or if the scheduler finished
                the request with an ``abort`` reason carrying HTTP 400.
        """
        state = self.rid_to_state[obj.rid]

        async def abort_if_disconnected():
            # Only requests that came in over HTTP can be checked.
            if request is None:
                return
            if await request.is_disconnected():
                self.abort_request(obj.rid)
                raise ValueError(
                    "Request is disconnected from the client side. "
                    f"Abort request {obj.rid}"
                )

        while True:
            try:
                await asyncio.wait_for(state.event.wait(), timeout=4)
            except asyncio.TimeoutError:
                await abort_if_disconnected()
                continue

            # Only the newest output is reported; drop everything buffered.
            out = state.out_list[-1]
            state.out_list = []

            if not state.finished:
                state.event.clear()
                if obj.stream:
                    yield out
                else:
                    await abort_if_disconnected()
                continue

            if self.log_requests:
                max_length, skip_names, out_skip_names = self.log_request_metadata
                obj_repr = dataclass_to_string_truncated(
                    obj, max_length, skip_names=skip_names
                )
                msg = f"Finish: obj={obj_repr}"
                if not self.model_config.is_multimodal_gen:
                    out_repr = dataclass_to_string_truncated(
                        out, max_length, skip_names=out_skip_names
                    )
                    msg = f"{msg}, out={out_repr}"
                logger.info(msg)
            del self.rid_to_state[obj.rid]

            # A scheduler-side abort with HTTP 400 is surfaced as an error.
            finish_reason = out["meta_info"].get("finish_reason")
            if (
                isinstance(finish_reason, dict)
                and finish_reason.get("type") == "abort"
                and finish_reason.get("status_code") == HTTPStatus.BAD_REQUEST
            ):
                raise ValueError(finish_reason["message"])

            yield out
            return
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629

    async def _handle_batch_request(
        self,
        obj: Union[GenerateReqInput, EmbeddingReqInput],
        request: Optional[fastapi.Request] = None,
        created_time: Optional[float] = None,
    ):
        """Send every sub-request of a batched ``obj`` and yield their outputs.

        Without parallel sampling, each sub-request is tokenized and sent in
        order. With ``parallel_sample_num > 1``, each prompt is first sent once
        with ``max_new_tokens = 0`` so that its prefix is cached. It is then
        duplicated ``parallel_sample_num`` times, each copy with a fresh rid.

        Non-streaming: yields one list holding the output of every sub-request,
        in send order. Streaming: yields each output as it arrives, tagged with
        ``"index"`` (the sub-request's position in send order).

        In streaming mode, if iteration stops early (an error from one
        sub-request, or the consumer closing this generator), pending waits are
        cancelled and unfinished sub-requests are aborted. They are not left
        running in the background.
        """
        batch_size = obj.batch_size

        generators = []
        rids = []
        if getattr(obj, "parallel_sample_num", 1) == 1:
            # Send all requests
            for i in range(batch_size):
                tmp_obj = obj[i]
                tokenized_obj = await self._tokenize_one_request(tmp_obj)
                self._send_one_request(tmp_obj, tokenized_obj, created_time)
                generators.append(self._wait_one_response(tmp_obj, request))
                rids.append(tmp_obj.rid)
        else:
            # FIXME: When using batch and parallel_sample_num together, the perf is not optimal.
            if batch_size > 128:
                logger.warning(
                    "Sending a single large batch with parallel sampling (n > 1) has not been well optimized. "
                    "The performance might be better if you just duplicate the requests n times or use "
                    "many threads to send them one by one with parallel sampling (n > 1)."
                )

            # Tokenize all requests
            objs = [obj[i] for i in range(batch_size)]
            tokenized_objs = await asyncio.gather(
                *(self._tokenize_one_request(sub_obj) for sub_obj in objs)
            )

            # Cache the common prefix for parallel sampling
            for sub_obj, sub_tokenized in zip(objs, tokenized_objs):
                tmp_obj = copy.copy(sub_obj)
                tokenized_obj = copy.copy(sub_tokenized)
                tokenized_obj.rid = tmp_obj.regenerate_rid()
                tokenized_obj.sampling_params = copy.copy(tokenized_obj.sampling_params)
                tokenized_obj.sampling_params.max_new_tokens = 0
                tokenized_obj.stream = False
                self._send_one_request(tmp_obj, tokenized_obj, created_time)
                await self._wait_one_response(tmp_obj, request).__anext__()

            # Expand requests, assign new rids for them, and send them
            for sub_obj, sub_tokenized in zip(objs, tokenized_objs):
                for _ in range(obj.parallel_sample_num):
                    tmp_obj = copy.copy(sub_obj)
                    tokenized_obj = copy.copy(sub_tokenized)
                    tokenized_obj.rid = tmp_obj.regenerate_rid()
                    self._send_one_request(tmp_obj, tokenized_obj, created_time)
                    generators.append(self._wait_one_response(tmp_obj, request))
                    rids.append(tmp_obj.rid)

        # Wait for all requests
        is_stream = hasattr(obj, "stream") and obj.stream
        if not is_stream:
            outputs = await asyncio.gather(*(gen.__anext__() for gen in generators))
            yield outputs
            return

        rid_to_index = {rid: i for i, rid in enumerate(rids)}
        task_map = {asyncio.create_task(gen.__anext__()): gen for gen in generators}
        try:
            while task_map:
                done, _ = await asyncio.wait(
                    task_map.keys(), return_when=asyncio.FIRST_COMPLETED
                )

                for task in done:
                    gen = task_map.pop(task)
                    try:
                        result = task.result()
                        result["index"] = rid_to_index[result["meta_info"]["id"]]
                        yield result
                        new_task = asyncio.create_task(gen.__anext__())
                        task_map[new_task] = gen
                    except StopAsyncIteration:
                        pass
        finally:
            # On normal completion task_map is empty and every rid has already
            # been removed from rid_to_state, so this is a no-op. On an early
            # exit it stops the leftover waits and tells the scheduler to drop
            # the unfinished requests (abort_request ignores unknown rids).
            for task in task_map:
                task.cancel()
            for rid in rids:
                self.abort_request(rid)
630

631
632
    def flush_cache(self):
        """Send a cache-flush request to the scheduler without waiting for a reply."""
        self.send_to_scheduler.send_pyobj(FlushCacheReq())
Liangsheng Yin's avatar
Liangsheng Yin committed
634

635
    def abort_request(self, rid: str):
        """Forget request ``rid`` locally and ask the scheduler to abort it.

        Unknown rids (never sent, already finished or already aborted) are
        ignored, so calling this more than once is harmless.
        """
        try:
            del self.rid_to_state[rid]
        except KeyError:
            return
        self.send_to_scheduler.send_pyobj(AbortReq(rid))
641

642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
    async def start_profile(
        self,
        output_dir: Optional[str] = None,
        num_steps: Optional[int] = None,
        activities: Optional[List[str]] = None,
    ):
        """Ask the scheduler to start profiling and return the first reply.

        Raises:
            RuntimeError: with the reply's message if profiling did not start.
        """
        profile_req = ProfileReq(
            type=ProfileReqType.START_PROFILE,
            output_dir=output_dir,
            num_steps=num_steps,
            activities=activities,
        )
        replies = await self.start_profile_communicator(profile_req)
        first_reply = replies[0]
        if first_reply.success:
            return first_reply
        raise RuntimeError(first_reply.message)
658
659

    def stop_profile(self):
        """Send a stop-profiling request to the scheduler without waiting for a reply."""
        self.send_to_scheduler.send_pyobj(ProfileReq(type=ProfileReqType.STOP_PROFILE))

663
664
    async def start_expert_distribution_record(self):
        """Send START_RECORD through the expert-distribution communicator and await it."""
        command = ExpertDistributionReq.START_RECORD
        await self.expert_distribution_communicator(command)
665

666
667
    async def stop_expert_distribution_record(self):
        """Send STOP_RECORD through the expert-distribution communicator and await it."""
        command = ExpertDistributionReq.STOP_RECORD
        await self.expert_distribution_communicator(command)
668

669
670
    async def dump_expert_distribution_record(self):
        """Send DUMP_RECORD through the expert-distribution communicator and await it."""
        command = ExpertDistributionReq.DUMP_RECORD
        await self.expert_distribution_communicator(command)
671

Chayenne's avatar
Chayenne committed
672
673
674
675
    async def update_weights_from_disk(
        self,
        obj: UpdateWeightFromDiskReqInput,
        request: Optional[fastapi.Request] = None,
    ) -> Tuple[bool, str, Any]:
        """Reload model weights from disk.

        ``obj.load_format`` defaults to the server's configured load format
        (``obj`` is updated in place). The writer side of ``model_update_lock``
        is held for the whole update; ``generate_request`` holds the reader
        side, so weight sync does not run while requests are in progress.

        Returns:
            ``(success, message, num_paused_requests)`` from
            ``_wait_for_model_update_from_disk``. With ``dp_size > 1``, the last
            element is a per-rank list.
        """
        self.auto_create_handle_loop()

        # default the load format to the server_args
        if obj.load_format is None:
            obj.load_format = self.server_args.load_format
        logger.info("Start update_weights. Load format=%s", obj.load_format)

        async with self.model_update_lock.writer_lock:
            return await self._wait_for_model_update_from_disk(obj)
689

690
691
    async def _wait_for_model_update_from_disk(
        self, obj: UpdateWeightFromDiskReqInput
    ) -> Tuple[bool, str, Any]:
        """Send ``obj`` to the scheduler and wait for the weight-update result.

        On success, the manager's model bookkeeping (served model name, model
        path, load format) is pointed at the new weights. This happens the same
        way for every ``dp_size``.

        Returns:
            With ``dp_size == 1``: ``(success, message, num_paused_requests)``.
            With ``dp_size > 1``: the result future holds one reply per rank, and
            this returns ``(all ranks succeeded, " | "-joined messages,
            per-rank paused-request counts)``.
        """
        self.send_to_scheduler.send_pyobj(obj)
        # No await happens between the send and this assignment, so the reply
        # cannot be processed before the future exists.
        self.model_update_result = asyncio.Future()
        if self.server_args.dp_size == 1:
            result = await self.model_update_result
            if result.success:
                self._record_updated_model(obj)
            return result.success, result.message, result.num_paused_requests

        # dp_size > 1. NOTE(review): model_update_tmp is presumably where the
        # reply handler (not shown) accumulates per-rank replies; reset it
        # before waiting.
        self.model_update_tmp = []
        results = await self.model_update_result

        all_success = all(r.success for r in results)
        if all_success:
            self._record_updated_model(obj)
        all_message = " | ".join(r.message for r in results)
        all_paused_requests = [r.num_paused_requests for r in results]
        return all_success, all_message, all_paused_requests

    def _record_updated_model(self, obj: UpdateWeightFromDiskReqInput) -> None:
        """Point the manager's model bookkeeping at the weights just loaded from ``obj``."""
        self.served_model_name = obj.model_path
        self.server_args.model_path = obj.model_path
        self.server_args.load_format = obj.load_format
        self.model_path = obj.model_path
716

717
718
719
720
    async def init_weights_update_group(
        self,
        obj: InitWeightsUpdateGroupReqInput,
        request: Optional[fastapi.Request] = None,
    ) -> Tuple[bool, str]:
        """Ask the scheduler to set up the weight-update group.

        Only supported with a single data-parallel rank.

        Returns:
            ``(success, message)`` from the first scheduler reply.
        """
        self.auto_create_handle_loop()
        assert (
            self.server_args.dp_size == 1
        ), "dp_size must be 1 for init parameter update group"
        replies = await self.init_weights_update_group_communicator(obj)
        first_reply = replies[0]
        return first_reply.success, first_reply.message

    async def update_weights_from_distributed(
        self,
        obj: UpdateWeightsFromDistributedReqInput,
        request: Optional[fastapi.Request] = None,
    ) -> Tuple[bool, str]:
        """Update model weights via the distributed weights-update group.

        Only supported with a single data-parallel rank. Returns the
        ``(success, message)`` pair from the first communicator reply.
        """
        self.auto_create_handle_loop()
        assert (
            self.server_args.dp_size == 1
        ), "dp_size must be 1 for update weights from distributed"

        # Holding the writer lock means that weight sync
        # cannot run while requests are in progress.
        async with self.model_update_lock.writer_lock:
            result = (await self.update_weights_from_distributed_communicator(obj))[0]
            return result.success, result.message
744

745
746
747
748
749
750
751
752
    async def update_weights_from_tensor(
        self,
        obj: UpdateWeightsFromTensorReqInput,
        request: Optional[fastapi.Request] = None,
    ) -> Tuple[bool, str]:
        """Update model weights from tensors carried in ``obj``.

        Only supported with a single data-parallel rank. Returns the
        ``(success, message)`` pair from the first communicator reply.
        """
        self.auto_create_handle_loop()
        # The message previously said "from distributed" (copy-paste slip).
        assert (
            self.server_args.dp_size == 1
        ), "dp_size must be 1 for update weights from tensor"

        # Holding the writer lock means that weight sync
        # cannot run while requests are in progress.
        async with self.model_update_lock.writer_lock:
            result = (await self.update_weights_from_tensor_communicator(obj))[0]
            return result.success, result.message

761
762
763
    async def get_weights_by_name(
        self, obj: GetWeightsByNameReqInput, request: Optional[fastapi.Request] = None
    ):
        """Fetch a named parameter from the model.

        Returns the single rank's ``parameter`` when ``dp_size == 1``;
        otherwise a list with one ``parameter`` per communicator reply.
        """
        self.auto_create_handle_loop()
        replies = await self.get_weights_by_name_communicator(obj)
        parameters = [reply.parameter for reply in replies]
        return parameters[0] if self.server_args.dp_size == 1 else parameters

772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
    async def release_memory_occupation(
        self,
        obj: ReleaseMemoryOccupationReqInput,
        request: Optional[fastapi.Request] = None,
    ):
        """Forward ``obj`` through the release-memory communicator and await the reply.

        The reply content is discarded; the coroutine returns ``None``.
        """
        self.auto_create_handle_loop()
        communicator = self.release_memory_occupation_communicator
        await communicator(obj)

    async def resume_memory_occupation(
        self,
        obj: ResumeMemoryOccupationReqInput,
        request: Optional[fastapi.Request] = None,
    ):
        """Forward ``obj`` through the resume-memory communicator and await the reply.

        The reply content is discarded; the coroutine returns ``None``.
        """
        self.auto_create_handle_loop()
        communicator = self.resume_memory_occupation_communicator
        await communicator(obj)

788
789
790
    async def open_session(
        self, obj: OpenSessionReqInput, request: Optional[fastapi.Request] = None
    ):
        """Open a session on the scheduler and return its id.

        A random hex id is generated when ``obj.session_id`` is None. Returns
        ``None`` immediately if an open request for the same id is already
        pending; otherwise waits for the scheduler's reply, whose future
        resolves to the session id on success or ``None`` on failure.
        """
        self.auto_create_handle_loop()

        requested_id = obj.session_id
        if requested_id is None:
            obj.session_id = uuid.uuid4().hex
        elif requested_id in self.session_futures:
            # Another open request for this id is still in flight.
            return None

        self.send_to_scheduler.send_pyobj(obj)

        pending = asyncio.Future()
        self.session_futures[obj.session_id] = pending
        opened_id = await pending
        del self.session_futures[obj.session_id]
        return opened_id

    async def close_session(
        self, obj: CloseSessionReqInput, request: Optional[fastapi.Request] = None
    ):
        """Send a close-session request to the scheduler; no reply is awaited."""
        sender = self.send_to_scheduler
        await sender.send_pyobj(obj)

810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
    async def get_internal_state(self) -> Dict[Any, Any]:
        """Query the scheduler's internal state.

        Returns ``internal_state`` from the first communicator reply; if the
        communicator collects several replies, only the first is used.
        """
        # Every other communicator-based call starts the receive loop first;
        # do the same here so calling this before any other request cannot
        # wait forever for a reply (presumably delivered via handle_loop).
        self.auto_create_handle_loop()
        req = GetInternalStateReq()
        res: List[GetInternalStateReqOutput] = (
            await self.get_internal_state_communicator(req)
        )
        return res[0].internal_state

    def get_log_request_metadata(self):
        """Return ``(max_length, skip_names, out_skip_names)`` for request logging.

        All three are ``None`` when ``log_requests`` is falsy. Otherwise the
        values depend on ``log_requests_level``:

        * 0 -- ``max_length = 1 << 30``; skip the bulky input fields and the
          output text/ids, so only metadata remains;
        * 1 -- ``max_length = 2048``; nothing skipped;
        * 2 -- ``max_length = 1 << 30``; nothing skipped.

        Raises:
            ValueError: if logging is enabled with any other level.
        """
        max_length = None
        skip_names = None
        out_skip_names = None
        if self.log_requests:
            if self.log_requests_level == 0:
                max_length = 1 << 30
                skip_names = {
                    "text",
                    "input_ids",
                    "input_embeds",
                    "image_data",
                    "audio_data",
                    "lora_path",
                }
                out_skip_names = {
                    "text",
                    "output_ids",
                }
            elif self.log_requests_level == 1:
                max_length = 2048
            elif self.log_requests_level == 2:
                max_length = 1 << 30
            else:
                raise ValueError(
                    f"Invalid --log-requests-level: {self.log_requests_level=}"
                )
        return max_length, skip_names, out_skip_names

850
    def configure_logging(self, obj: ConfigureLoggingReq):
        """Apply the logging/dumping settings from ``obj`` at runtime.

        Only fields of ``obj`` that are not ``None`` overwrite the current
        settings. The cached ``log_request_metadata`` is then recomputed so
        the new settings take effect.
        """
        if obj.log_requests is not None:
            self.log_requests = obj.log_requests
        if obj.log_requests_level is not None:
            self.log_requests_level = obj.log_requests_level
        if obj.dump_requests_folder is not None:
            self.dump_requests_folder = obj.dump_requests_folder
        if obj.dump_requests_threshold is not None:
            self.dump_requests_threshold = obj.dump_requests_threshold
        # Use the module logger like the rest of this file, not the root
        # logger, so this message follows the module's logging configuration.
        logger.info(f"Config logging: {obj=}")
        self.log_request_metadata = self.get_log_request_metadata()
861

Lianmin Zheng's avatar
Lianmin Zheng committed
862
    def create_abort_task(self, obj: GenerateReqInput):
        """Return a ``BackgroundTasks`` that aborts ``obj``'s request(s).

        Used to abort the request if the client is disconnected. The task
        waits one second, then calls ``abort_request`` for the single rid or
        for every rid of a batched request.
        """

        async def abort_request():
            await asyncio.sleep(1)
            # Resolve the rid(s) only when the task actually runs.
            rids = [obj.rid] if obj.is_single else obj.rid
            for rid in rids:
                self.abort_request(rid)

        tasks = BackgroundTasks()
        tasks.add_task(abort_request)
        return tasks

876
    def auto_create_handle_loop(self):
        """Start the receive loop and the SIGTERM watchdog exactly once.

        The first call sets ``no_create_loop`` and schedules ``handle_loop``
        and ``sigterm_watchdog`` (each wrapped in ``print_exception_wrapper``)
        on the current event loop, keeping the tasks in ``asyncio_tasks``.
        A SIGTERM handler is installed only on the main thread. Later calls
        return immediately.
        """
        if self.no_create_loop:
            return
        self.no_create_loop = True

        loop = asyncio.get_event_loop()
        self.asyncio_tasks.add(
            loop.create_task(print_exception_wrapper(self.handle_loop))
        )

        # CPython only allows installing signal handlers from the main thread.
        on_main_thread = threading.current_thread() is threading.main_thread()
        if not on_main_thread:
            logger.warning(
                "Signal handler is not added because the tokenizer manager is "
                "not in the main thread. This disables graceful shutdown of the "
                "tokenizer manager when SIGTERM is received."
            )
        else:
            handler = SignalHandler(self)
            loop.add_signal_handler(signal.SIGTERM, handler.signal_handler)

        self.asyncio_tasks.add(
            loop.create_task(print_exception_wrapper(self.sigterm_watchdog))
        )
900
901
902

    async def sigterm_watchdog(self):
        """Wait for ``gracefully_exit``, drain outstanding requests, then exit.

        Polls the flag every 5 seconds. Once it is set, logs the size of
        ``rid_to_state`` every 5 seconds until it is empty, then calls
        ``kill_process_tree`` on this pid (``include_parent=True``) and exits
        with status 0.
        """
        while not self.gracefully_exit:
            await asyncio.sleep(5)

        # Drain: report progress until no tracked request remains.
        while True:
            pending = len(self.rid_to_state)
            logger.info(
                f"Gracefully exiting... remaining number of requests {pending}"
            )
            if pending <= 0:
                break
            await asyncio.sleep(5)

        kill_process_tree(os.getpid(), include_parent=True)
        sys.exit(0)
918

Lianmin Zheng's avatar
Lianmin Zheng committed
919
    async def handle_loop(self):
        """Receive objects from the detokenizer forever and dispatch each one.

        After every dispatched message, ``last_receive_tstamp`` is refreshed
        with the current wall-clock time.
        """
        while True:
            message = await self.recv_from_detokenizer.recv_pyobj()
            self._result_dispatcher(message)
            self.last_receive_tstamp = time.time()
Lianmin Zheng's avatar
Lianmin Zheng committed
926

927
    def _handle_batch_output(
        self,
        recv_obj: Union[
            BatchStrOut, BatchEmbeddingOut, BatchMultimodalOut, BatchTokenIDOut
        ],
    ):
        """Distribute one batched output message to the per-request states.

        For each rid in ``recv_obj.rids`` that still has an entry in
        ``rid_to_state``, this builds a ``meta_info`` dict plus a payload whose
        key depends on the batch type (``text`` for ``BatchStrOut``,
        ``output_ids`` for ``BatchTokenIDOut``, ``embedding`` for
        ``BatchEmbeddingOut``). The payload is appended to ``state.out_list``
        and ``state.event`` is set. Afterwards, metrics collection and request
        dumping run when enabled. ``BatchMultimodalOut`` raises
        ``NotImplementedError``.
        """
        for idx, rid in enumerate(recv_obj.rids):
            state = self.rid_to_state.get(rid)
            if state is None:
                # No live state for this rid; its output is dropped.
                continue

            finish_reason = recv_obj.finished_reasons[idx]

            # meta_info: common fields first, then the optional extras.
            meta_info = {
                "id": rid,
                "finish_reason": finish_reason,
                "prompt_tokens": recv_obj.prompt_tokens[idx],
            }
            req = state.obj
            if getattr(req, "return_logprob", False):
                self.convert_logprob_style(
                    meta_info,
                    req.top_logprobs_num,
                    req.token_ids_logprob,
                    req.return_text_in_logprobs,
                    recv_obj,
                    idx,
                )
            if not isinstance(recv_obj, BatchEmbeddingOut):
                meta_info["completion_tokens"] = recv_obj.completion_tokens[idx]
                meta_info["cached_tokens"] = recv_obj.cached_tokens[idx]
            if getattr(recv_obj, "output_hidden_states", None):
                meta_info["hidden_states"] = recv_obj.output_hidden_states[idx]

            # Payload, keyed by the kind of batch that was received.
            if isinstance(recv_obj, BatchStrOut):
                out_dict = {
                    "text": recv_obj.output_strs[idx],
                    "meta_info": meta_info,
                }
            elif isinstance(recv_obj, BatchTokenIDOut):
                if self.server_args.stream_output and req.stream:
                    # Hand out only the ids not delivered in an earlier chunk.
                    full_ids = recv_obj.output_ids[idx]
                    delta_ids = full_ids[state.last_output_offset :]
                    state.last_output_offset = len(full_ids)
                else:
                    delta_ids = recv_obj.output_ids[idx]
                out_dict = {"output_ids": delta_ids, "meta_info": meta_info}
            elif isinstance(recv_obj, BatchMultimodalOut):
                raise NotImplementedError()
            else:
                assert isinstance(recv_obj, BatchEmbeddingOut)
                out_dict = {
                    "embedding": recv_obj.embeddings[idx],
                    "meta_info": meta_info,
                }

            state.finished = finish_reason is not None
            if state.finished:
                if self.server_args.speculative_algorithm:
                    meta_info["spec_verify_ct"] = recv_obj.spec_verify_ct[idx]
                state.finished_time = time.time()
                meta_info["e2e_latency"] = state.finished_time - state.created_time

            # Publish the chunk and wake whoever is waiting on this request.
            state.out_list.append(out_dict)
            state.event.set()

            if self.enable_metrics and req.log_metrics:
                self.collect_metrics(state, recv_obj, idx)
            if self.dump_requests_folder and state.finished and req.log_metrics:
                self.dump_requests(state, out_dict)

    def convert_logprob_style(
        self,
        meta_info: dict,
        top_logprobs_num: int,
        token_ids_logprob: List[int],
        return_text_in_logprobs: bool,
        recv_obj: BatchStrOut,
        recv_obj_index: int,
    ):
        """Copy one request's logprob data from ``recv_obj`` into ``meta_info``.

        ``input_token_logprobs`` / ``output_token_logprobs`` are always filled.
        ``input_top_logprobs`` / ``output_top_logprobs`` are added when
        ``top_logprobs_num > 0``. ``input_token_ids_logprobs`` /
        ``output_token_ids_logprobs`` are added when ``token_ids_logprob`` is
        not None. Each ``{side}_{name}`` entry is built from
        ``recv_obj.{side}_{name}_val`` / ``recv_obj.{side}_{name}_idx`` at
        ``recv_obj_index``. Token text is decoded only when
        ``return_text_in_logprobs`` is set.
        """

        def fill(name, detokenize):
            # input_* is inserted before output_* to keep the key order stable.
            for side in ("input", "output"):
                meta_info[f"{side}_{name}"] = detokenize(
                    getattr(recv_obj, f"{side}_{name}_val")[recv_obj_index],
                    getattr(recv_obj, f"{side}_{name}_idx")[recv_obj_index],
                    return_text_in_logprobs,
                )

        fill("token_logprobs", self.detokenize_logprob_tokens)
        if top_logprobs_num > 0:
            fill("top_logprobs", self.detokenize_top_logprobs_tokens)
        if token_ids_logprob is not None:
            fill("token_ids_logprobs", self.detokenize_top_logprobs_tokens)

1055
1056
1057
1058
1059
1060
1061
1062
1063
1064
1065
1066
1067
1068
1069
1070
1071
1072
1073
1074
1075
1076
1077
1078
1079
1080
1081
1082
1083
1084
1085
1086
1087
1088
1089
1090
1091
1092
1093
1094
1095
1096
1097
    def detokenize_logprob_tokens(
        self,
        token_logprobs_val: List[float],
        token_logprobs_idx: List[int],
        decode_to_text: bool,
    ):
        """Pair logprobs with their token ids and, optionally, decoded text.

        Returns a list of ``(logprob, token_id, text)`` tuples. ``text`` comes
        from ``self.tokenizer.batch_decode`` when ``decode_to_text`` is true
        and is ``None`` otherwise.
        """
        if decode_to_text:
            assert self.tokenizer is not None
            decoded = self.tokenizer.batch_decode(token_logprobs_idx)
            return list(zip(token_logprobs_val, token_logprobs_idx, decoded))
        return [
            (logprob, token_id, None)
            for logprob, token_id in zip(token_logprobs_val, token_logprobs_idx)
        ]

    def detokenize_top_logprobs_tokens(
        self,
        token_logprobs_val: List[List[float]],
        token_logprobs_idx: List[List[int]],
        decode_to_text: bool,
    ):
        """Detokenize per-position candidate logprobs (top-k or requested ids).

        ``token_logprobs_val[i]`` and ``token_logprobs_idx[i]`` hold the
        candidate logprobs and token ids for position ``i``. Each position with
        a non-empty (truthy) value list is converted with
        ``detokenize_logprob_tokens``. Falsy positions map to ``None``.

        Returns:
            A list with one entry per position.
        """
        # TODO: The current implementation only batches the detokenization for top-k tokens per single position.
        # We should batch all top-k tokens in all positions.
        return [
            self.detokenize_logprob_tokens(vals, ids, decode_to_text) if vals else None
            for vals, ids in zip(token_logprobs_val, token_logprobs_idx)
        ]

    def collect_metrics(self, state: ReqState, recv_obj: BatchStrOut, i: int):
        """Report latency observations for request ``i`` to the metrics collector.

        The first call for a request (``first_token_time == 0.0``) records the
        time to first token. Later calls record inter-token latency whenever
        the completion-token count grew since the previous observation. Once
        the request is finished, a summary is reported: prompt, completion and
        cached token counts plus end-to-end latency.
        """
        if getattr(recv_obj, "completion_tokens", None):
            completion_tokens = recv_obj.completion_tokens[i]
        else:
            # Presumably batches without completion counts (e.g. embeddings).
            completion_tokens = 0

        if state.first_token_time == 0.0:
            now = time.time()
            state.first_token_time = now
            state.last_time = now
            state.last_completion_tokens = completion_tokens
            self.metrics_collector.observe_time_to_first_token(
                now - state.created_time
            )
        else:
            new_tokens = completion_tokens - state.last_completion_tokens
            if new_tokens:
                now = time.time()
                self.metrics_collector.observe_inter_token_latency(
                    now - state.last_time,
                    new_tokens,
                )
                state.last_time = now
                state.last_completion_tokens = completion_tokens

        if state.finished:
            self.metrics_collector.observe_one_finished_request(
                recv_obj.prompt_tokens[i],
                completion_tokens,
                recv_obj.cached_tokens[i],
                state.finished_time - state.created_time,
            )

    def dump_requests(self, state: ReqState, out_dict: dict):
        """Buffer a request record and periodically pickle the buffer to disk.

        Each call appends ``(req_obj, out_dict, created_time, now)``. When the
        buffer reaches ``dump_requests_threshold`` entries, it is swapped for
        an empty list. The old buffer is written to a timestamped ``.pkl`` file
        under ``dump_requests_folder`` on a worker thread, so the event loop is
        not blocked by file I/O.
        """
        self.dump_request_list.append(
            (state.obj, out_dict, state.created_time, time.time())
        )

        if len(self.dump_request_list) < self.dump_requests_threshold:
            return

        # Capture the folder once so the directory we create and the file we
        # write agree even if configure_logging changes it meanwhile.
        folder = self.dump_requests_folder
        filename = os.path.join(
            folder,
            datetime.now().strftime("%Y-%m-%d_%H-%M-%S") + ".pkl",
        )
        logger.info(f"Dump {len(self.dump_request_list)} requests to {filename}")

        to_dump = self.dump_request_list
        self.dump_request_list = []

        def background_task():
            os.makedirs(folder, exist_ok=True)
            with open(filename, "wb") as f:
                pickle.dump(to_dump, f)

        # Schedule the write without awaiting it, but keep a strong reference
        # until it finishes: the event loop holds only weak references to
        # tasks, so an unreferenced task may be garbage-collected mid-run.
        # (Assumes self.asyncio_tasks is a set, as auto_create_handle_loop uses it.)
        task = asyncio.create_task(asyncio.to_thread(background_task))
        self.asyncio_tasks.add(task)
        task.add_done_callback(self.asyncio_tasks.discard)

1147
1148
1149
1150
1151
1152
1153
1154
1155
1156
1157
1158
1159
1160
    def _handle_open_session_req_output(self, recv_obj):
        """Resolve the future that ``open_session`` is awaiting for this session.

        The future receives the session id on success and ``None`` on failure.
        A reply whose future is missing, or already done, is logged and
        dropped. For example, the future is already done when the waiting
        ``open_session`` was cancelled, because cancelling a task cancels the
        future it awaits. Raising here would propagate out of the receive loop
        (assuming it is dispatched from ``handle_loop``), and
        ``print_exception_wrapper`` turns that into a process kill.
        """
        future = self.session_futures.get(recv_obj.session_id)
        if future is None or future.done():
            logger.warning(
                "Dropping open-session reply for unknown or finished session %s",
                recv_obj.session_id,
            )
            return
        future.set_result(recv_obj.session_id if recv_obj.success else None)

    def _handle_update_weights_from_disk_req_output(self, recv_obj):
        """Route an update-weights-from-disk reply to ``model_update_result``.

        With one DP rank the reply resolves the future directly. With several
        ranks, replies are accumulated in ``model_update_tmp``, and the future
        resolves with the whole list once ``dp_size`` replies have arrived.
        """
        if self.server_args.dp_size != 1:
            self.model_update_tmp.append(recv_obj)
            # Resolve the future only once every DP rank has been received.
            if len(self.model_update_tmp) == self.server_args.dp_size:
                self.model_update_result.set_result(self.model_update_tmp)
            return
        self.model_update_result.set_result(recv_obj)

1161

1162
1163
1164
1165
1166
1167
1168
1169
1170
async def print_exception_wrapper(func):
    """Await ``func()`` and turn any escaping exception into a hard shutdown.

    Exceptions raised inside background asyncio tasks are easily lost, so on
    failure this wrapper logs the traceback, calls ``kill_process_tree`` on
    the current pid with ``include_parent=True``, and exits with status 1.
    """
    try:
        await func()
    except Exception:
        tb_text = get_exception_traceback()
        logger.error(f"TokenizerManager hit an exception: {tb_text}")
        kill_process_tree(os.getpid(), include_parent=True)
        sys.exit(1)


1176
class SignalHandler:
    """SIGTERM callback that asks a TokenizerManager to shut down gracefully."""

    def __init__(self, tokenizer_manager: TokenizerManager):
        self.tokenizer_manager = tokenizer_manager

    def signal_handler(self, signum=None, frame=None):
        """Log the signal and set ``gracefully_exit`` on the manager.

        ``sigterm_watchdog`` polls that flag and starts draining once it is set.
        """
        logger.warning(
            f"SIGTERM received. {signum=} {frame=}. Draining requests and shutting down..."
        )
        manager = self.tokenizer_manager
        manager.gracefully_exit = True
1185
1186
1187
1188
1189
1190


# Type of the reply objects gathered by a _Communicator.
T = TypeVar("T")


class _Communicator(Generic[T]):
    """Send a request through ``sender`` and gather ``fan_out`` replies.

    At most one request is in flight at any time. Callers arriving while a
    request is pending (or while others are already queued) wait on an
    ``asyncio.Event`` in FIFO order. Each finishing request releases the next
    waiter.

    NOTE(review): if a task awaiting ``__call__`` is cancelled, the in-flight
    state is never reset and the next queued waiter is never woken, so later
    callers would block; confirm cancellation cannot reach this path.
    """

    def __init__(self, sender, fan_out: int):
        self._sender = sender
        # Number of replies that complete one request.
        self._fan_out = fan_out
        # Set while a request is in flight; None when idle.
        self._result_event: Optional[asyncio.Event] = None
        # Replies collected so far for the in-flight request; None when idle.
        self._result_values: Optional[List[T]] = None
        # Wake-up events of callers waiting for their turn, in arrival order.
        self._ready_queue: Deque[asyncio.Event] = deque()

    async def __call__(self, obj):
        """Send ``obj`` (skipped when falsy) and return the list of replies."""
        if self._result_event is not None or self._ready_queue:
            # Busy: park until the caller ahead of us hands over.
            my_turn = asyncio.Event()
            self._ready_queue.append(my_turn)
            await my_turn.wait()
            assert self._result_event is None
            assert self._result_values is None

        if obj:
            self._sender.send_pyobj(obj)

        done = asyncio.Event()
        self._result_event = done
        self._result_values = []
        await done.wait()

        replies = self._result_values
        self._result_event = None
        self._result_values = None

        if self._ready_queue:
            self._ready_queue.popleft().set()

        return replies

    def handle_recv(self, recv_obj: T):
        """Record one reply; release the caller once ``fan_out`` replies arrived."""
        collected = self._result_values
        collected.append(recv_obj)
        if len(collected) == self._fan_out:
            self._result_event.set()