# Copyright 2023-2024 SGLang Team
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""TokenizerManager is a process that tokenizes the text."""

import asyncio
import copy
import dataclasses
import json
import logging
import os
import pickle
import signal
import sys
import threading
import time
import uuid
from collections import deque
from datetime import datetime
from http import HTTPStatus
from typing import (
    Any,
    Awaitable,
    Deque,
    Dict,
    Generic,
    List,
    Optional,
    Tuple,
    TypeVar,
    Union,
)

import fastapi
import uvloop
import zmq
import zmq.asyncio
from fastapi import BackgroundTasks

from sglang.srt.aio_rwlock import RWLock
from sglang.srt.configs.model_config import ModelConfig
from sglang.srt.disaggregation.utils import (
    DisaggregationMode,
    KVClassType,
    TransferBackend,
    get_kv_class,
)
from sglang.srt.hf_transformers_utils import (
    get_processor,
    get_tokenizer,
    get_tokenizer_from_processor,
)
from sglang.srt.managers.io_struct import (
    AbortReq,
    BatchEmbeddingOut,
    BatchMultimodalOut,
    BatchStrOut,
    BatchTokenIDOut,
    CloseSessionReqInput,
    ConfigureLoggingReq,
    EmbeddingReqInput,
    ExpertDistributionReq,
    ExpertDistributionReqOutput,
    FlushCacheReqInput,
    FlushCacheReqOutput,
    GenerateReqInput,
    GetInternalStateReq,
    GetInternalStateReqOutput,
    GetWeightsByNameReqInput,
    GetWeightsByNameReqOutput,
    HealthCheckOutput,
    InitWeightsUpdateGroupReqInput,
    InitWeightsUpdateGroupReqOutput,
    OpenSessionReqInput,
    OpenSessionReqOutput,
    ProfileReq,
    ProfileReqOutput,
    ProfileReqType,
    ReleaseMemoryOccupationReqInput,
    ReleaseMemoryOccupationReqOutput,
    ResumeMemoryOccupationReqInput,
    ResumeMemoryOccupationReqOutput,
    SessionParams,
    SetInternalStateReq,
    SetInternalStateReqOutput,
    SlowDownReqInput,
    SlowDownReqOutput,
    TokenizedEmbeddingReqInput,
    TokenizedGenerateReqInput,
    UpdateWeightFromDiskReqInput,
    UpdateWeightFromDiskReqOutput,
    UpdateWeightsFromDistributedReqInput,
    UpdateWeightsFromDistributedReqOutput,
    UpdateWeightsFromTensorReqInput,
    UpdateWeightsFromTensorReqOutput,
)
from sglang.srt.managers.multimodal_processor import (
    get_dummy_processor,
    get_mm_processor,
    import_processors,
)
from sglang.srt.metrics.collector import TokenizerMetricsCollector
from sglang.srt.sampling.sampling_params import SamplingParams
from sglang.srt.server_args import PortArgs, ServerArgs
from sglang.srt.utils import (
    dataclass_to_string_truncated,
    get_zmq_socket,
    kill_process_tree,
)
from sglang.utils import TypeBasedDispatcher, get_exception_traceback

asyncio.set_event_loop_policy(uvloop.EventLoopPolicy())

logger = logging.getLogger(__name__)


@dataclasses.dataclass
class ReqState:
    """Store the state of a request."""

    out_list: List[Dict[Any, Any]]
    finished: bool
    event: asyncio.Event
    obj: Union[GenerateReqInput, EmbeddingReqInput]

    # For metrics
    created_time: float
    finished_time: float = 0.0
    first_token_time: float = 0.0
    last_time: float = 0.0
    last_completion_tokens: int = 1

    # For streaming output
    last_output_offset: int = 0
    # For incremental state update.
    text: str = ""
    output_ids: List[int] = dataclasses.field(default_factory=list)
    input_token_logprobs_val: List[float] = dataclasses.field(default_factory=list)
    input_token_logprobs_idx: List[int] = dataclasses.field(default_factory=list)
    output_token_logprobs_val: List[float] = dataclasses.field(default_factory=list)
    output_token_logprobs_idx: List[int] = dataclasses.field(default_factory=list)
    input_top_logprobs_val: List[List[float]] = dataclasses.field(default_factory=list)
    input_top_logprobs_idx: List[List[int]] = dataclasses.field(default_factory=list)
    output_top_logprobs_val: List[List[float]] = dataclasses.field(default_factory=list)
    output_top_logprobs_idx: List[List[int]] = dataclasses.field(default_factory=list)
    input_token_ids_logprobs_val: List = dataclasses.field(default_factory=list)
    input_token_ids_logprobs_idx: List = dataclasses.field(default_factory=list)
    output_token_ids_logprobs_val: List = dataclasses.field(default_factory=list)
    output_token_ids_logprobs_idx: List = dataclasses.field(default_factory=list)


class TokenizerManager:
    """TokenizerManager is a process that tokenizes the text."""

    def __init__(
        self,
        server_args: ServerArgs,
        port_args: PortArgs,
    ):
        # Parse args
        self.server_args = server_args
        self.enable_metrics = server_args.enable_metrics
        self.log_requests = server_args.log_requests
        self.log_requests_level = server_args.log_requests_level
        self.preferred_sampling_params = (
            json.loads(server_args.preferred_sampling_params)
            if server_args.preferred_sampling_params
            else None
        )

        # Init inter-process communication
        context = zmq.asyncio.Context(2)
        self.recv_from_detokenizer = get_zmq_socket(
            context, zmq.PULL, port_args.tokenizer_ipc_name, True
        )
        self.send_to_scheduler = get_zmq_socket(
            context, zmq.PUSH, port_args.scheduler_input_ipc_name, True
        )

        # Read model args
        self.model_path = server_args.model_path
        self.served_model_name = server_args.served_model_name
        self.model_config = ModelConfig.from_server_args(server_args)

        self.is_generation = self.model_config.is_generation
        self.is_image_gen = self.model_config.is_image_gen
        self.context_len = self.model_config.context_len
        self.image_token_id = self.model_config.image_token_id

        if self.model_config.is_multimodal:
            import_processors()
            _processor = get_processor(
                server_args.tokenizer_path,
                tokenizer_mode=server_args.tokenizer_mode,
                trust_remote_code=server_args.trust_remote_code,
                revision=server_args.revision,
                use_fast=not server_args.disable_fast_image_processor,
            )

            # We want to parallelize the image pre-processing so we create an executor for it
            # We create mm_processor for any skip_tokenizer_init setting to make sure
            # we still encode images even when skip_tokenizer_init is enabled.
            self.mm_processor = get_mm_processor(
                self.model_config.hf_config, server_args, _processor
            )

            if server_args.skip_tokenizer_init:
                self.tokenizer = self.processor = None
            else:
                self.processor = _processor
                self.tokenizer = get_tokenizer_from_processor(self.processor)
                os.environ["TOKENIZERS_PARALLELISM"] = "false"
        else:
            self.mm_processor = get_dummy_processor()

            if server_args.skip_tokenizer_init:
                self.tokenizer = self.processor = None
            else:
                self.tokenizer = get_tokenizer(
                    server_args.tokenizer_path,
                    tokenizer_mode=server_args.tokenizer_mode,
                    trust_remote_code=server_args.trust_remote_code,
                    revision=server_args.revision,
                )

        # Store states
        self.no_create_loop = False
        self.rid_to_state: Dict[str, ReqState] = {}
        self.health_check_failed = False
        self.gracefully_exit = False
        self.last_receive_tstamp = 0
        self.dump_requests_folder = ""  # By default do not dump
        self.dump_requests_threshold = 1000
        self.dump_request_list: List[Tuple] = []
        self.log_request_metadata = self.get_log_request_metadata()

        # The event to notify the weight sync is finished.
        self.model_update_lock = RWLock()
        self.model_update_result: Optional[Awaitable[UpdateWeightFromDiskReqOutput]] = (
            None
        )
        self.asyncio_tasks = set()

        # For session info
        self.session_futures = {}  # session_id -> asyncio event

        # Set after scheduler is initialized
        self.max_req_input_len = None

        # Metrics
        if self.enable_metrics:
            self.metrics_collector = TokenizerMetricsCollector(
                labels={
                    "model_name": self.server_args.served_model_name,
                    # TODO: Add lora name/path in the future,
                },
                bucket_time_to_first_token=self.server_args.bucket_time_to_first_token,
                bucket_e2e_request_latency=self.server_args.bucket_e2e_request_latency,
                bucket_inter_token_latency=self.server_args.bucket_inter_token_latency,
                collect_tokens_histogram=self.server_args.collect_tokens_histogram,
            )

        # Communicators
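        # Each _Communicator below sends one request type to the scheduler and waits
        # for one reply per DP rank (dp_size replies, or a single one for the health
        # check); replies come back through handle_recv via the dispatcher below.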
        self.init_weights_update_group_communicator = _Communicator(
            self.send_to_scheduler, server_args.dp_size
        )
        self.update_weights_from_distributed_communicator = _Communicator(
            self.send_to_scheduler, server_args.dp_size
        )
        self.update_weights_from_tensor_communicator = _Communicator(
            self.send_to_scheduler, server_args.dp_size
        )
        self.get_weights_by_name_communicator = _Communicator(
            self.send_to_scheduler, server_args.dp_size
        )
        self.release_memory_occupation_communicator = _Communicator(
            self.send_to_scheduler, server_args.dp_size
        )
        self.resume_memory_occupation_communicator = _Communicator(
            self.send_to_scheduler, server_args.dp_size
        )
        self.slow_down_communicator = _Communicator(
            self.send_to_scheduler, server_args.dp_size
        )
        self.flush_cache_communicator = _Communicator(
            self.send_to_scheduler, server_args.dp_size
        )
        self.start_profile_communicator = _Communicator(
            self.send_to_scheduler, server_args.dp_size
        )
        self.health_check_communitcator = _Communicator(self.send_to_scheduler, 1)
        self.get_internal_state_communicator = _Communicator(
            self.send_to_scheduler, server_args.dp_size
        )
        self.set_internal_state_communicator = _Communicator(
            self.send_to_scheduler, server_args.dp_size
        )
        self.expert_distribution_communicator = _Communicator(
            self.send_to_scheduler, server_args.dp_size
        )

        self._result_dispatcher = TypeBasedDispatcher(
            [
                (
                    (
                        BatchStrOut,
                        BatchEmbeddingOut,
                        BatchTokenIDOut,
                        BatchMultimodalOut,
                    ),
                    self._handle_batch_output,
                ),
                (AbortReq, self._handle_abort_req),
                (OpenSessionReqOutput, self._handle_open_session_req_output),
                (
                    UpdateWeightFromDiskReqOutput,
                    self._handle_update_weights_from_disk_req_output,
                ),
                (
                    InitWeightsUpdateGroupReqOutput,
                    self.init_weights_update_group_communicator.handle_recv,
                ),
                (
                    UpdateWeightsFromDistributedReqOutput,
                    self.update_weights_from_distributed_communicator.handle_recv,
                ),
                (
                    UpdateWeightsFromTensorReqOutput,
                    self.update_weights_from_tensor_communicator.handle_recv,
                ),
                (
                    GetWeightsByNameReqOutput,
                    self.get_weights_by_name_communicator.handle_recv,
                ),
                (
                    ReleaseMemoryOccupationReqOutput,
                    self.release_memory_occupation_communicator.handle_recv,
                ),
                (
                    ResumeMemoryOccupationReqOutput,
                    self.resume_memory_occupation_communicator.handle_recv,
                ),
                (
                    SlowDownReqOutput,
                    self.slow_down_communicator.handle_recv,
                ),
                (
                    FlushCacheReqOutput,
                    self.flush_cache_communicator.handle_recv,
                ),
                (
                    ProfileReqOutput,
                    self.start_profile_communicator.handle_recv,
                ),
                (
                    GetInternalStateReqOutput,
                    self.get_internal_state_communicator.handle_recv,
                ),
                (
                    SetInternalStateReqOutput,
                    self.set_internal_state_communicator.handle_recv,
                ),
                (
                    ExpertDistributionReqOutput,
                    self.expert_distribution_communicator.handle_recv,
                ),
                (HealthCheckOutput, lambda x: None),
            ]
        )

        # For PD disaggregation
        self.disaggregation_mode = DisaggregationMode(
            self.server_args.disaggregation_mode
        )
        self.transfer_backend = TransferBackend(
            self.server_args.disaggregation_transfer_backend
        )
        # Start the KV bootstrap server on prefill
        if self.disaggregation_mode == DisaggregationMode.PREFILL:
            # Only the prefill TokenizerManager starts the bootstrap server
            kv_bootstrap_server_class = get_kv_class(
                self.transfer_backend, KVClassType.BOOTSTRAP_SERVER
            )
            self.bootstrap_server = kv_bootstrap_server_class(
                self.server_args.disaggregation_bootstrap_port
            )

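    # Main entry point for generation/embedding requests: tokenize obj, forward it to
    # the scheduler, then yield outputs as they arrive back from the detokenizer.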
    async def generate_request(
        self,
        obj: Union[GenerateReqInput, EmbeddingReqInput],
        request: Optional[fastapi.Request] = None,
    ):
        created_time = time.time()

        self.auto_create_handle_loop()

        if isinstance(obj, EmbeddingReqInput) and self.is_generation:
            raise ValueError(
                "This model does not appear to be an embedding model by default. "
                "Please add `--is-embedding` when launching the server or try another model."
            )

        obj.normalize_batch_and_arguments()

        if self.log_requests:
            max_length, skip_names, _ = self.log_request_metadata
            logger.info(
                f"Receive: obj={dataclass_to_string_truncated(obj, max_length, skip_names=skip_names)}"
            )

        async with self.model_update_lock.reader_lock:
            is_single = obj.is_single
            if is_single:
                tokenized_obj = await self._tokenize_one_request(obj)
                self._send_one_request(obj, tokenized_obj, created_time)
                async for response in self._wait_one_response(obj, request):
                    yield response
            else:
                async for response in self._handle_batch_request(
                    obj, request, created_time
                ):
                    yield response

    async def _tokenize_one_request(
        self,
        obj: Union[GenerateReqInput, EmbeddingReqInput],
    ):
        """Tokenize one request."""
        # Tokenize
        input_embeds = None
        input_text = obj.text
        if obj.input_embeds is not None:
            if not self.server_args.disable_radix_cache:
                raise ValueError(
                    "input_embeds is provided while disable_radix_cache is False. "
                    "Please add `--disable-radix-cache` when you launch the server "
                    "if you want to use input_embeds as inputs."
                )
            input_embeds = obj.input_embeds
            input_ids = obj.input_ids
        elif obj.input_ids is not None:
            input_ids = obj.input_ids
        else:
            if self.tokenizer is None:
                raise ValueError(
                    "The engine initialized with skip_tokenizer_init=True cannot "
                    "accept text prompts. Please provide input_ids or re-initialize "
                    "the engine with skip_tokenizer_init=False."
                )
            input_ids = self.tokenizer.encode(input_text)

        image_inputs: Dict = await self.mm_processor.process_mm_data_async(
            image_data=obj.image_data,
            input_text=input_text or input_ids,
            request_obj=obj,
            max_req_input_len=self.max_req_input_len,
        )
        if image_inputs and "input_ids" in image_inputs:
            input_ids = image_inputs["input_ids"]
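            # The multimodal processor may rewrite the prompt (e.g. by inserting image
            # placeholder tokens), in which case its input_ids replace the original ones.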

        self._validate_token_len(obj, input_ids)
        return self._create_tokenized_object(
            obj, input_text, input_ids, input_embeds, image_inputs
        )

    def _validate_token_len(
        self, obj: Union[GenerateReqInput, EmbeddingReqInput], input_ids: List[int]
    ) -> None:
        """Validates that the input token count and the requested token count doesn't exceed the model's context length."""

        input_token_num = len(input_ids) if input_ids is not None else 0
        # Check if input alone exceeds context length
        if input_token_num >= self.context_len:
            raise ValueError(
                f"The input ({input_token_num} tokens) is longer than the "
                f"model's context length ({self.context_len} tokens)."
            )

        # Check total tokens (input + max_new_tokens)
        max_new_tokens = obj.sampling_params.get("max_new_tokens")
        if (
            max_new_tokens is not None
            and (max_new_tokens + input_token_num) >= self.context_len
        ):
            total_tokens = max_new_tokens + input_token_num
            error_msg = (
                f"Requested token count exceeds the model's maximum context length "
                f"of {self.context_len} tokens. You requested a total of {total_tokens} "
                f"tokens: {input_token_num} tokens from the input messages and "
                f"{max_new_tokens} tokens for the completion. Please reduce the number "
                f"of tokens in the input messages or the completion to fit within the limit."
            )
            raise ValueError(error_msg)

    def _create_tokenized_object(
        self,
        obj: Union[GenerateReqInput, EmbeddingReqInput],
        input_text: str,
        input_ids: List[int],
        input_embeds: Optional[Union[List[float], None]] = None,
        image_inputs: Optional[Dict] = None,
    ) -> Union[TokenizedGenerateReqInput, TokenizedEmbeddingReqInput]:
        """Create a tokenized request object from common parameters."""

        if self.is_generation:
            return_logprob = obj.return_logprob
            logprob_start_len = obj.logprob_start_len
            top_logprobs_num = obj.top_logprobs_num
            token_ids_logprob = obj.token_ids_logprob
            session_params = (
                SessionParams(**obj.session_params) if obj.session_params else None
            )
            if (
                obj.custom_logit_processor
                and not self.server_args.enable_custom_logit_processor
            ):
                raise ValueError(
                    "The server is not configured to enable custom logit processor. "
                    "Please set `--enable-custom-logits-processor` to enable this feature."
                )
        # Parse sampling parameters
        # Note: if there are preferred sampling params, we use them if they are not
        # explicitly passed in sampling_params
        if self.preferred_sampling_params:
            sampling_kwargs = {**self.preferred_sampling_params, **obj.sampling_params}
        else:
            sampling_kwargs = obj.sampling_params
        sampling_params = SamplingParams(**sampling_kwargs)
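        # Request-level values win on conflicts, e.g. a preferred {"temperature": 0.6}
        # merged with a request's {"temperature": 0.9} results in temperature=0.9.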
        sampling_params.normalize(self.tokenizer)
        sampling_params.verify()

        # Build return object
        if isinstance(obj, GenerateReqInput):
            tokenized_obj = TokenizedGenerateReqInput(
                obj.rid,
                input_text,
                input_ids,
                image_inputs,
                sampling_params,
                return_logprob,
                logprob_start_len,
                top_logprobs_num,
                token_ids_logprob,
                obj.stream,
                bootstrap_host=obj.bootstrap_host,
                bootstrap_port=obj.bootstrap_port,
                bootstrap_room=obj.bootstrap_room,
                lora_path=obj.lora_path,
                input_embeds=input_embeds,
                session_params=session_params,
                custom_logit_processor=obj.custom_logit_processor,
                return_hidden_states=obj.return_hidden_states,
            )
        elif isinstance(obj, EmbeddingReqInput):
            tokenized_obj = TokenizedEmbeddingReqInput(
                obj.rid,
                input_text,
                input_ids,
                image_inputs,
                sampling_params,
            )

        return tokenized_obj

    async def _batch_tokenize_and_process(
        self, batch_size: int, obj: Union[GenerateReqInput, EmbeddingReqInput]
    ) -> List[Union[TokenizedGenerateReqInput, TokenizedEmbeddingReqInput]]:
        """Handle batch tokenization for text inputs only."""
        logger.debug(f"Starting batch tokenization for {batch_size} text requests")

        # Collect requests and texts
        requests = [obj[i] for i in range(batch_size)]
        texts = [req.text for req in requests]

        # Batch tokenize all texts
        encoded = self.tokenizer(texts)
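        # Passing the whole list to the HF tokenizer encodes all prompts in a single
        # batched call instead of one encode() per request.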
        input_ids_list = encoded["input_ids"]

        # Process all requests
        tokenized_objs = []
        for i, req in enumerate(requests):
            self._validate_token_len(obj[i], input_ids_list[i])
            tokenized_objs.append(
                self._create_tokenized_object(
                    req, req.text, input_ids_list[i], None, None
                )
            )
        logger.debug(f"Completed batch processing for {batch_size} requests")
        return tokenized_objs

    def _validate_batch_tokenization_constraints(
        self, batch_size: int, obj: Union[GenerateReqInput, EmbeddingReqInput]
    ) -> None:
        """Validate constraints for batch tokenization processing."""
        for i in range(batch_size):
            if self.is_generation and obj[i].image_data:
                raise ValueError(
                    "For image input processing do not set `enable_tokenizer_batch_encode`."
                )
            if obj[i].input_ids is not None:
                raise ValueError(
                    "Batch tokenization is not needed for pre-tokenized input_ids. Do not set `enable_tokenizer_batch_encode`."
                )
            if obj[i].input_embeds is not None:
                raise ValueError(
                    "Batch tokenization is not needed for input_embeds. Do not set `enable_tokenizer_batch_encode`."
                )

    def _send_one_request(
        self,
        obj: Union[GenerateReqInput, EmbeddingReqInput],
        tokenized_obj: Union[TokenizedGenerateReqInput, TokenizedEmbeddingReqInput],
        created_time: Optional[float] = None,
    ):
        self.send_to_scheduler.send_pyobj(tokenized_obj)
        state = ReqState([], False, asyncio.Event(), obj, created_time=created_time)
        self.rid_to_state[obj.rid] = state

    async def _wait_one_response(
        self,
        obj: Union[GenerateReqInput, EmbeddingReqInput],
        request: Optional[fastapi.Request] = None,
    ):
        """Wait for the response of one request."""
        state = self.rid_to_state[obj.rid]

        while True:
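            # Wake up periodically even if no new output arrived so that a disconnected
            # client can be detected and its request aborted.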
            try:
                await asyncio.wait_for(state.event.wait(), timeout=4)
            except asyncio.TimeoutError:
                if request is not None and await request.is_disconnected():
                    # Abort the request for disconnected requests (non-streaming, waiting queue)
                    self.abort_request(obj.rid)
                    # Use exception to kill the whole call stack and asyncio task
                    raise ValueError(
                        f"Request is disconnected from the client side (type 1). Abort request {obj.rid=}"
                    )
                continue

            out = state.out_list[-1]

            state.out_list = []
            if state.finished:
                if self.log_requests:
                    max_length, skip_names, out_skip_names = self.log_request_metadata
                    if self.model_config.is_multimodal_gen:
                        msg = f"Finish: obj={dataclass_to_string_truncated(obj, max_length, skip_names=skip_names)}"
                    else:
                        msg = f"Finish: obj={dataclass_to_string_truncated(obj, max_length, skip_names=skip_names)}, out={dataclass_to_string_truncated(out, max_length, skip_names=out_skip_names)}"
                    logger.info(msg)

                # Check if this was an abort/error created by scheduler
                if isinstance(out["meta_info"].get("finish_reason"), dict):
                    finish_reason = out["meta_info"]["finish_reason"]
                    if (
                        finish_reason.get("type") == "abort"
                        and finish_reason.get("status_code") == HTTPStatus.BAD_REQUEST
                    ):
                        raise ValueError(finish_reason["message"])

                yield out
                break

            state.event.clear()

            if obj.stream:
                yield out
            else:
                if request is not None and await request.is_disconnected():
                    # Abort the request for disconnected requests (non-streaming, running)
                    self.abort_request(obj.rid)
                    # Use exception to kill the whole call stack and asyncio task
                    raise ValueError(
                        f"Request is disconnected from the client side (type 3). Abort request {obj.rid=}"
                    )

    async def _handle_batch_request(
        self,
        obj: Union[GenerateReqInput, EmbeddingReqInput],
        request: Optional[fastapi.Request] = None,
        created_time: Optional[float] = None,
    ):
        batch_size = obj.batch_size

        generators = []
        rids = []
        if getattr(obj, "parallel_sample_num", 1) == 1:
            if self.server_args.enable_tokenizer_batch_encode:
                # Validate batch tokenization constraints
                self._validate_batch_tokenization_constraints(batch_size, obj)

                tokenized_objs = await self._batch_tokenize_and_process(batch_size, obj)

                for i, tokenized_obj in enumerate(tokenized_objs):
                    tmp_obj = obj[i]
                    self._send_one_request(tmp_obj, tokenized_obj, created_time)
                    generators.append(self._wait_one_response(tmp_obj, request))
                    rids.append(tmp_obj.rid)
            else:
                # Sequential tokenization and processing
                for i in range(batch_size):
                    tmp_obj = obj[i]
                    tokenized_obj = await self._tokenize_one_request(tmp_obj)
                    self._send_one_request(tmp_obj, tokenized_obj, created_time)
                    generators.append(self._wait_one_response(tmp_obj, request))
                    rids.append(tmp_obj.rid)
        else:
            # FIXME: When using batch and parallel_sample_num together, the perf is not optimal.
            if batch_size > 128:
                logger.warning(
                    "Sending a single large batch with parallel sampling (n > 1) has not been well optimized. "
                    "The performance might be better if you just duplicate the requests n times or use "
                    "many threads to send them one by one with parallel sampling (n > 1)."
                )

            # Tokenize all requests
            objs = [obj[i] for i in range(batch_size)]
            tokenized_objs = await asyncio.gather(
                *(self._tokenize_one_request(obj) for obj in objs)
            )

            # Cache the common prefix for parallel sampling
            for i in range(batch_size):
                tmp_obj = copy.copy(objs[i])
                tokenized_obj = copy.copy(tokenized_objs[i])
                tokenized_obj.rid = tmp_obj.regenerate_rid()
                tokenized_obj.sampling_params = copy.copy(tokenized_obj.sampling_params)
                tokenized_obj.sampling_params.max_new_tokens = 0
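                # A zero-length generation is sent first to warm the prefix cache with
                # the shared prompt, so the n samples expanded below can reuse it.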
                tokenized_obj.stream = False
                self._send_one_request(tmp_obj, tokenized_obj, created_time)
                await self._wait_one_response(tmp_obj, request).__anext__()

            # Expand requests, assign new rids for them, and send them
            for i in range(batch_size):
                for _ in range(obj.parallel_sample_num):
                    tmp_obj = copy.copy(objs[i])
                    tokenized_obj = copy.copy(tokenized_objs[i])
                    tokenized_obj.rid = tmp_obj.regenerate_rid()
                    self._send_one_request(tmp_obj, tokenized_obj, created_time)
                    generators.append(self._wait_one_response(tmp_obj, request))
                    rids.append(tmp_obj.rid)

        # Wait for all requests
        is_stream = hasattr(obj, "stream") and obj.stream
        if not is_stream:
            outputs = await asyncio.gather(*(gen.__anext__() for gen in generators))
            yield outputs
        else:
            rid_to_index = {rid: i for i, rid in enumerate(rids)}
            task_map = {asyncio.create_task(gen.__anext__()): gen for gen in generators}
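            # Fan-in: drain all per-request generators concurrently, yielding whichever
            # produces the next chunk first; exhausted generators drop out of task_map.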
            while task_map:
                done, _ = await asyncio.wait(
                    task_map.keys(), return_when=asyncio.FIRST_COMPLETED
                )

                for task in done:
                    gen = task_map.pop(task)
                    try:
                        result = task.result()
                        result["index"] = rid_to_index[result["meta_info"]["id"]]
                        yield result
                        new_task = asyncio.create_task(gen.__anext__())
                        task_map[new_task] = gen
                    except StopAsyncIteration:
                        pass

    async def flush_cache(self) -> FlushCacheReqOutput:
        return (await self.flush_cache_communicator(FlushCacheReqInput()))[0]

    def abort_request(self, rid: str):
        if rid not in self.rid_to_state:
            return
        req = AbortReq(rid)
        self.send_to_scheduler.send_pyobj(req)

    async def start_profile(
        self,
        output_dir: Optional[str] = None,
        num_steps: Optional[int] = None,
        activities: Optional[List[str]] = None,
        with_stack: Optional[bool] = None,
        record_shapes: Optional[bool] = None,
    ):
        req = ProfileReq(
            type=ProfileReqType.START_PROFILE,
            output_dir=output_dir,
            num_steps=num_steps,
            activities=activities,
            with_stack=with_stack,
            record_shapes=record_shapes,
            profile_id=str(time.time()),
        )
        result = (await self.start_profile_communicator(req))[0]
        if not result.success:
            raise RuntimeError(result.message)
        return result

    def stop_profile(self):
        req = ProfileReq(type=ProfileReqType.STOP_PROFILE)
        self.send_to_scheduler.send_pyobj(req)

    async def start_expert_distribution_record(self):
        await self.expert_distribution_communicator(ExpertDistributionReq.START_RECORD)

    async def stop_expert_distribution_record(self):
        await self.expert_distribution_communicator(ExpertDistributionReq.STOP_RECORD)

    async def dump_expert_distribution_record(self):
        await self.expert_distribution_communicator(ExpertDistributionReq.DUMP_RECORD)

    async def update_weights_from_disk(
        self,
        obj: UpdateWeightFromDiskReqInput,
        request: Optional[fastapi.Request] = None,
    ) -> Tuple[bool, str]:
        self.auto_create_handle_loop()

        # default the load format to the server_args
        if obj.load_format is None:
            obj.load_format = self.server_args.load_format
        logger.info("Start update_weights. Load format=%s", obj.load_format)

        if True:
            # Hold the lock if it is not async. This means that weight sync
            # cannot run while requests are in progress.
            async with self.model_update_lock.writer_lock:
                return await self._wait_for_model_update_from_disk(obj)

    async def _wait_for_model_update_from_disk(
        self, obj: UpdateWeightFromDiskReqInput
    ) -> Tuple[bool, str]:
        self.send_to_scheduler.send_pyobj(obj)
        self.model_update_result = asyncio.Future()
        if self.server_args.dp_size == 1:
            result = await self.model_update_result
            if result.success:
                self.served_model_name = obj.model_path
                self.server_args.model_path = obj.model_path
                self.server_args.load_format = obj.load_format
                self.model_path = obj.model_path
            return result.success, result.message, result.num_paused_requests
        else:  # self.server_args.dp_size > 1
            self.model_update_tmp = []
            result = await self.model_update_result

            all_success = all([r.success for r in result])
            if all_success is True:
                self.server_args.model_path = obj.model_path
                self.server_args.load_format = obj.load_format
                self.model_path = obj.model_path
            all_message = [r.message for r in result]
            all_message = " | ".join(all_message)
            all_paused_requests = [r.num_paused_requests for r in result]
            return all_success, all_message, all_paused_requests

    async def init_weights_update_group(
        self,
        obj: InitWeightsUpdateGroupReqInput,
        request: Optional[fastapi.Request] = None,
    ) -> Tuple[bool, str]:
        self.auto_create_handle_loop()
        assert (
            self.server_args.dp_size == 1
        ), "dp_size must be 1 for init parameter update group"
        result = (await self.init_weights_update_group_communicator(obj))[0]
        return result.success, result.message

    async def update_weights_from_distributed(
        self,
        obj: UpdateWeightsFromDistributedReqInput,
        request: Optional[fastapi.Request] = None,
    ) -> Tuple[bool, str]:
        self.auto_create_handle_loop()
        assert (
            self.server_args.dp_size == 1
        ), "dp_size must be 1 for update weights from distributed"

        # This means that weight sync
        # cannot run while requests are in progress.
        async with self.model_update_lock.writer_lock:
            result = (await self.update_weights_from_distributed_communicator(obj))[0]
            return result.success, result.message

    async def update_weights_from_tensor(
        self,
        obj: UpdateWeightsFromTensorReqInput,
        request: Optional[fastapi.Request] = None,
    ) -> Tuple[bool, str]:
        self.auto_create_handle_loop()
        assert (
            self.server_args.dp_size == 1
        ), "dp_size must be 1 for update weights from tensor"

        # This means that weight sync
        # cannot run while requests are in progress.
        async with self.model_update_lock.writer_lock:
            result = (await self.update_weights_from_tensor_communicator(obj))[0]
            return result.success, result.message

    async def get_weights_by_name(
        self, obj: GetWeightsByNameReqInput, request: Optional[fastapi.Request] = None
    ):
        self.auto_create_handle_loop()
        results = await self.get_weights_by_name_communicator(obj)
        all_parameters = [r.parameter for r in results]
        if self.server_args.dp_size == 1:
            return all_parameters[0]
        else:
            return all_parameters

    async def release_memory_occupation(
        self,
        obj: ReleaseMemoryOccupationReqInput,
        request: Optional[fastapi.Request] = None,
    ):
        self.auto_create_handle_loop()
        await self.release_memory_occupation_communicator(obj)

    async def resume_memory_occupation(
        self,
        obj: ResumeMemoryOccupationReqInput,
        request: Optional[fastapi.Request] = None,
    ):
        self.auto_create_handle_loop()
        await self.resume_memory_occupation_communicator(obj)

    async def slow_down(
        self,
        obj: SlowDownReqInput,
        request: Optional[fastapi.Request] = None,
    ):
        self.auto_create_handle_loop()
        await self.slow_down_communicator(obj)

    async def open_session(
        self, obj: OpenSessionReqInput, request: Optional[fastapi.Request] = None
    ):
        self.auto_create_handle_loop()

        if obj.session_id is None:
            obj.session_id = uuid.uuid4().hex
        elif obj.session_id in self.session_futures:
            return None

        self.send_to_scheduler.send_pyobj(obj)

        self.session_futures[obj.session_id] = asyncio.Future()
        session_id = await self.session_futures[obj.session_id]
        del self.session_futures[obj.session_id]
        return session_id

    async def close_session(
        self, obj: CloseSessionReqInput, request: Optional[fastapi.Request] = None
    ):
        await self.send_to_scheduler.send_pyobj(obj)

    async def get_internal_state(self) -> List[Dict[Any, Any]]:
        req = GetInternalStateReq()
        responses: List[GetInternalStateReqOutput] = (
            await self.get_internal_state_communicator(req)
        )
        # Many DP ranks
        return [res.internal_state for res in responses]

    async def set_internal_state(
        self, obj: SetInternalStateReq
    ) -> SetInternalStateReqOutput:
        responses: List[SetInternalStateReqOutput] = (
            await self.set_internal_state_communicator(obj)
        )
        return [res.internal_state for res in responses]

    def get_log_request_metadata(self):
        max_length = None
        skip_names = None
        out_skip_names = None
        if self.log_requests:
            if self.log_requests_level == 0:
                max_length = 1 << 30
                skip_names = set(
                    [
                        "text",
                        "input_ids",
                        "input_embeds",
                        "image_data",
                        "audio_data",
                        "lora_path",
                    ]
                )
                out_skip_names = set(
                    [
                        "text",
                        "output_ids",
                    ]
                )
            elif self.log_requests_level == 1:
                max_length = 2048
            elif self.log_requests_level == 2:
                max_length = 1 << 30
            else:
                raise ValueError(
                    f"Invalid --log-requests-level: {self.log_requests_level=}"
                )
        return max_length, skip_names, out_skip_names

    def configure_logging(self, obj: ConfigureLoggingReq):
        if obj.log_requests is not None:
            self.log_requests = obj.log_requests
        if obj.log_requests_level is not None:
            self.log_requests_level = obj.log_requests_level
        if obj.dump_requests_folder is not None:
            self.dump_requests_folder = obj.dump_requests_folder
        if obj.dump_requests_threshold is not None:
            self.dump_requests_threshold = obj.dump_requests_threshold
        logging.info(f"Config logging: {obj=}")
        self.log_request_metadata = self.get_log_request_metadata()

    def create_abort_task(self, obj: GenerateReqInput):
        # Abort the request if the client is disconnected.
        async def abort_request():
            await asyncio.sleep(2)
            if obj.is_single:
                self.abort_request(obj.rid)
            else:
                for rid in obj.rid:
                    self.abort_request(rid)

        background_tasks = BackgroundTasks()
        background_tasks.add_task(abort_request)
        return background_tasks

    def auto_create_handle_loop(self):
        if self.no_create_loop:
            return

        self.no_create_loop = True
        loop = asyncio.get_event_loop()
        self.asyncio_tasks.add(
            loop.create_task(print_exception_wrapper(self.handle_loop))
        )
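        # Storing the task in self.asyncio_tasks keeps a strong reference so it is not
        # garbage-collected while running (the event loop holds only weak references).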
        self.event_loop = loop

        # We cannot add signal handler when the tokenizer manager is not in
        # the main thread due to the CPython limitation.
        if threading.current_thread() is threading.main_thread():
            signal_handler = SignalHandler(self)
            loop.add_signal_handler(signal.SIGTERM, signal_handler.sigterm_handler)
            # Update the signal handler for the process. It overrides the sigquit handler in the launch phase.
            loop.add_signal_handler(
                signal.SIGQUIT, signal_handler.running_phase_sigquit_handler
            )
        else:
            logger.warning(
                "Signal handler is not added because the tokenizer manager is "
                "not in the main thread. This disables graceful shutdown of the "
                "tokenizer manager when SIGTERM is received."
            )
        self.asyncio_tasks.add(
            loop.create_task(print_exception_wrapper(self.sigterm_watchdog))
        )

    async def sigterm_watchdog(self):
        while not self.gracefully_exit:
            await asyncio.sleep(5)

        # Drain requests
        while True:
            remain_num_req = len(self.rid_to_state)

            if self.health_check_failed:
                # if health check failed, we should exit immediately
                logger.error(
                    "Signal SIGTERM received while health check failed. Exiting... remaining number of requests: %d",
                    remain_num_req,
                )
                break

            logger.info(
                f"Gracefully exiting... remaining number of requests {remain_num_req}"
            )
            if remain_num_req > 0:
                await asyncio.sleep(5)
            else:
                break

        kill_process_tree(os.getpid(), include_parent=True)
        sys.exit(0)

    async def handle_loop(self):
        """The event loop that handles requests"""

        while True:
            recv_obj = await self.recv_from_detokenizer.recv_pyobj()
            self._result_dispatcher(recv_obj)
            self.last_receive_tstamp = time.time()

    def _handle_batch_output(
        self,
        recv_obj: Union[
            BatchStrOut, BatchEmbeddingOut, BatchMultimodalOut, BatchTokenIDOut
        ],
    ):
        for i, rid in enumerate(recv_obj.rids):
            state = self.rid_to_state.get(rid, None)
            if state is None:
                logger.error(
                    f"Received output for {rid=} but the state was deleted in TokenizerManager."
                )
                continue

            # Build meta_info and return value
            meta_info = {
                "id": rid,
                "finish_reason": recv_obj.finished_reasons[i],
                "prompt_tokens": recv_obj.prompt_tokens[i],
            }

            if getattr(state.obj, "return_logprob", False):
                self.convert_logprob_style(
                    meta_info,
                    state,
                    state.obj.top_logprobs_num,
                    state.obj.token_ids_logprob,
                    state.obj.return_text_in_logprobs
                    and not self.server_args.skip_tokenizer_init,
                    recv_obj,
                    i,
                )

            if not isinstance(recv_obj, BatchEmbeddingOut):
                meta_info.update(
                    {
                        "completion_tokens": recv_obj.completion_tokens[i],
                        "cached_tokens": recv_obj.cached_tokens[i],
                    }
                )

            if getattr(recv_obj, "output_hidden_states", None):
                meta_info["hidden_states"] = recv_obj.output_hidden_states[i]

            if isinstance(recv_obj, BatchStrOut):
                state.text += recv_obj.output_strs[i]
                out_dict = {
                    "text": state.text,
                    "meta_info": meta_info,
                }
            elif isinstance(recv_obj, BatchTokenIDOut):
                if self.server_args.stream_output and state.obj.stream:
                    state.output_ids.extend(recv_obj.output_ids[i])
                    output_token_ids = state.output_ids[state.last_output_offset :]
                    state.last_output_offset = len(state.output_ids)
                else:
                    state.output_ids.extend(recv_obj.output_ids[i])
                    output_token_ids = state.output_ids

                out_dict = {
                    "output_ids": output_token_ids,
                    "meta_info": meta_info,
                }
            elif isinstance(recv_obj, BatchMultimodalOut):
                if isinstance(recv_obj.outputs[i], str):
                    out_dict = {
                        "text": recv_obj.outputs[i],
                        "meta_info": meta_info,
                    }
                else:
                    out_dict = {
                        "outputs": json.dumps(recv_obj.outputs[i]),
                        "meta_info": meta_info,
                    }
            else:
                assert isinstance(recv_obj, BatchEmbeddingOut)
                out_dict = {
                    "embedding": recv_obj.embeddings[i],
                    "meta_info": meta_info,
                }

            state.finished = recv_obj.finished_reasons[i] is not None
            if state.finished:
                if self.server_args.speculative_algorithm:
                    meta_info["spec_verify_ct"] = recv_obj.spec_verify_ct[i]
                state.finished_time = time.time()
                meta_info["e2e_latency"] = state.finished_time - state.created_time
                del self.rid_to_state[rid]

            state.out_list.append(out_dict)
            state.event.set()

            # Log metrics and dump
            if self.enable_metrics and state.obj.log_metrics:
                self.collect_metrics(state, recv_obj, i)
            if self.dump_requests_folder and state.finished and state.obj.log_metrics:
                self.dump_requests(state, out_dict)

    def convert_logprob_style(
        self,
        meta_info: dict,
        state: ReqState,
        top_logprobs_num: int,
        token_ids_logprob: List[int],
        return_text_in_logprobs: bool,
        recv_obj: BatchStrOut,
        recv_obj_index: int,
    ):
        if len(recv_obj.input_token_logprobs_val) > 0:
            state.input_token_logprobs_val.extend(
                recv_obj.input_token_logprobs_val[recv_obj_index]
            )
            state.input_token_logprobs_idx.extend(
                recv_obj.input_token_logprobs_idx[recv_obj_index]
            )
        state.output_token_logprobs_val.extend(
            recv_obj.output_token_logprobs_val[recv_obj_index]
        )
        state.output_token_logprobs_idx.extend(
            recv_obj.output_token_logprobs_idx[recv_obj_index]
        )
        meta_info["input_token_logprobs"] = self.detokenize_logprob_tokens(
            state.input_token_logprobs_val,
            state.input_token_logprobs_idx,
            return_text_in_logprobs,
        )
        meta_info["output_token_logprobs"] = self.detokenize_logprob_tokens(
            state.output_token_logprobs_val,
            state.output_token_logprobs_idx,
            return_text_in_logprobs,
        )

        if top_logprobs_num > 0:
            if len(recv_obj.input_top_logprobs_val) > 0:
                state.input_top_logprobs_val.extend(
                    recv_obj.input_top_logprobs_val[recv_obj_index]
                )
                state.input_top_logprobs_idx.extend(
                    recv_obj.input_top_logprobs_idx[recv_obj_index]
                )
            state.output_top_logprobs_val.extend(
                recv_obj.output_top_logprobs_val[recv_obj_index]
            )
            state.output_top_logprobs_idx.extend(
                recv_obj.output_top_logprobs_idx[recv_obj_index]
            )
            meta_info["input_top_logprobs"] = self.detokenize_top_logprobs_tokens(
                state.input_top_logprobs_val,
                state.input_top_logprobs_idx,
                return_text_in_logprobs,
            )
            meta_info["output_top_logprobs"] = self.detokenize_top_logprobs_tokens(
                state.output_top_logprobs_val,
                state.output_top_logprobs_idx,
                return_text_in_logprobs,
            )

        if token_ids_logprob is not None:
            if len(recv_obj.input_token_ids_logprobs_val) > 0:
                state.input_token_ids_logprobs_val.extend(
                    recv_obj.input_token_ids_logprobs_val[recv_obj_index]
                )
                state.input_token_ids_logprobs_idx.extend(
                    recv_obj.input_token_ids_logprobs_idx[recv_obj_index]
                )
            state.output_token_ids_logprobs_val.extend(
                recv_obj.output_token_ids_logprobs_val[recv_obj_index]
            )
            state.output_token_ids_logprobs_idx.extend(
                recv_obj.output_token_ids_logprobs_idx[recv_obj_index]
            )
            meta_info["input_token_ids_logprobs"] = self.detokenize_top_logprobs_tokens(
                state.input_token_ids_logprobs_val,
                state.input_token_ids_logprobs_idx,
                return_text_in_logprobs,
            )
            meta_info["output_token_ids_logprobs"] = (
                self.detokenize_top_logprobs_tokens(
                    state.output_token_ids_logprobs_val,
                    state.output_token_ids_logprobs_idx,
                    return_text_in_logprobs,
                )
            )

    def detokenize_logprob_tokens(
        self,
        token_logprobs_val: List[float],
        token_logprobs_idx: List[int],
        decode_to_text: bool,
    ):
        if not decode_to_text:
            return [
                (logprob, token_id, None)
                for logprob, token_id in zip(token_logprobs_val, token_logprobs_idx)
            ]
        else:
            assert self.tokenizer is not None
            token_texts = self.tokenizer.batch_decode(token_logprobs_idx)
            return list(zip(token_logprobs_val, token_logprobs_idx, token_texts))

    def detokenize_top_logprobs_tokens(
        self,
        token_logprobs_val: List[float],
        token_logprobs_idx: List[int],
        decode_to_text: bool,
    ):
        # TODO: The current implementation only batches the detokenization of the top-k tokens at each single position.
        # We should batch the top-k tokens across all positions.
        ret = []
        for i in range(len(token_logprobs_val)):
            if token_logprobs_val[i]:
                ret.append(
                    self.detokenize_logprob_tokens(
                        token_logprobs_val[i], token_logprobs_idx[i], decode_to_text
                    )
                )
            else:
                ret.append(None)
        return ret
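
    # A possible batched variant of the TODO above (an illustrative sketch only, for the
    # decode_to_text=True path; this is not the module's implementation): flatten the token
    # ids of every position, decode them with a single tokenizer.batch_decode call, and
    # re-split the decoded texts per position.
    #
    #     flat_idx = [
    #         tid
    #         for vals, idxs in zip(token_logprobs_val, token_logprobs_idx)
    #         if vals
    #         for tid in idxs
    #     ]
    #     flat_texts = iter(self.tokenizer.batch_decode(flat_idx))
    #     return [
    #         [(lp, tid, next(flat_texts)) for lp, tid in zip(vals, idxs)]
    #         if vals
    #         else None
    #         for vals, idxs in zip(token_logprobs_val, token_logprobs_idx)
    #     ]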

    def collect_metrics(self, state: ReqState, recv_obj: BatchStrOut, i: int):
        completion_tokens = (
            recv_obj.completion_tokens[i]
            if getattr(recv_obj, "completion_tokens", None)
            else 0
        )

        if state.first_token_time == 0.0:
            state.first_token_time = state.last_time = time.time()
            state.last_completion_tokens = completion_tokens
            self.metrics_collector.observe_time_to_first_token(
                state.first_token_time - state.created_time
            )
        else:
            num_new_tokens = completion_tokens - state.last_completion_tokens
            if num_new_tokens:
                new_time = time.time()
                interval = new_time - state.last_time
                self.metrics_collector.observe_inter_token_latency(
                    interval,
                    num_new_tokens,
                )
                state.last_time = new_time
                state.last_completion_tokens = completion_tokens

        if state.finished:
            has_grammar = (
                state.obj.sampling_params.get("json_schema", None)
                or state.obj.sampling_params.get("regex", None)
                or state.obj.sampling_params.get("ebnf", None)
                or state.obj.sampling_params.get("structural_tag", None)
            )
            self.metrics_collector.observe_one_finished_request(
                recv_obj.prompt_tokens[i],
                completion_tokens,
                recv_obj.cached_tokens[i],
                state.finished_time - state.created_time,
                has_grammar,
            )

    def dump_requests(self, state: ReqState, out_dict: dict):
        self.dump_request_list.append(
            (state.obj, out_dict, state.created_time, time.time())
        )

        if len(self.dump_request_list) >= self.dump_requests_threshold:
            filename = os.path.join(
                self.dump_requests_folder,
                datetime.now().strftime("%Y-%m-%d_%H-%M-%S") + ".pkl",
            )
            logger.info(f"Dump {len(self.dump_request_list)} requests to {filename}")

            to_dump = self.dump_request_list
            self.dump_request_list = []

            def background_task():
                os.makedirs(self.dump_requests_folder, exist_ok=True)
                with open(filename, "wb") as f:
                    pickle.dump(to_dump, f)

            # Schedule the task to run in the background without awaiting it
            asyncio.create_task(asyncio.to_thread(background_task))
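
    # A hedged sketch of reading one of the dumped .pkl files offline (illustrative only;
    # the file path is a placeholder). Each entry is the (obj, out_dict, created_time,
    # finished_time) tuple appended above, and "e2e_latency" is set for finished requests.
    #
    #     import pickle
    #
    #     with open("/path/to/2025-01-01_00-00-00.pkl", "rb") as f:
    #         dumped = pickle.load(f)
    #     for obj, out_dict, created_time, finished_time in dumped:
    #         print(out_dict["meta_info"]["e2e_latency"])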

    def _handle_abort_req(self, recv_obj):
        self.rid_to_state.pop(recv_obj.rid)

    def _handle_open_session_req_output(self, recv_obj):
        self.session_futures[recv_obj.session_id].set_result(
            recv_obj.session_id if recv_obj.success else None
        )

    def _handle_update_weights_from_disk_req_output(self, recv_obj):
        if self.server_args.dp_size == 1:
            self.model_update_result.set_result(recv_obj)
        else:  # self.server_args.dp_size > 1
            self.model_update_tmp.append(recv_obj)
            # Set the future once all results are received
            if len(self.model_update_tmp) == self.server_args.dp_size:
                self.model_update_result.set_result(self.model_update_tmp)


async def print_exception_wrapper(func):
    """
    Sometimes an asyncio function does not print its exception.
    We add another wrapper here to catch, log, and fail fast on the exception.
    """
    try:
        await func()
    except Exception:
        traceback = get_exception_traceback()
        logger.error(f"TokenizerManager hit an exception: {traceback}")
        kill_process_tree(os.getpid(), include_parent=True)
        sys.exit(1)
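
# A hedged usage sketch of print_exception_wrapper (illustrative only: the event loop
# setup and the `handle_loop` coroutine name are assumptions, not necessarily this
# module's actual wiring):
#
#     loop = asyncio.get_event_loop()
#     loop.create_task(print_exception_wrapper(tokenizer_manager.handle_loop))
#
# Any exception raised inside the wrapped coroutine is then logged, and the whole
# process tree is terminated instead of failing silently.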


class SignalHandler:
    def __init__(self, tokenizer_manager: TokenizerManager):
        self.tokenizer_manager = tokenizer_manager

    def sigterm_handler(self, signum=None, frame=None):
        logger.warning(
            f"SIGTERM received. {signum=} {frame=}. Draining requests and shutting down..."
        )
        self.tokenizer_manager.gracefully_exit = True

    def running_phase_sigquit_handler(self, signum=None, frame=None):
        logger.error(
            "Received sigquit from a child process. It usually means the child failed."
        )
        kill_process_tree(os.getpid())


T = TypeVar("T")


class _Communicator(Generic[T]):
    """Note: The communicator currently runs at most one in-flight request at any time."""

    def __init__(self, sender, fan_out: int):
        self._sender = sender
        self._fan_out = fan_out
        self._result_event: Optional[asyncio.Event] = None
        self._result_values: Optional[List[T]] = None
        self._ready_queue: Deque[asyncio.Event] = deque()

    async def __call__(self, obj):
        ready_event = asyncio.Event()
        if self._result_event is not None or len(self._ready_queue) > 0:
            self._ready_queue.append(ready_event)
            await ready_event.wait()
            assert self._result_event is None
            assert self._result_values is None

        if obj:
            self._sender.send_pyobj(obj)

        self._result_event = asyncio.Event()
        self._result_values = []
        await self._result_event.wait()
        result_values = self._result_values
        self._result_event = self._result_values = None

        if len(self._ready_queue) > 0:
            self._ready_queue.popleft().set()

        return result_values

    def handle_recv(self, recv_obj: T):
        self._result_values.append(recv_obj)
        if len(self._result_values) == self._fan_out:
            self._result_event.set()
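
# A hedged usage sketch of _Communicator (illustrative only: the sender socket, the
# fan_out value, and the FlushCacheReqInput round trip are stand-ins, not a prescribed
# wiring):
#
#     flush_cache_communicator = _Communicator(send_to_scheduler, fan_out=dp_size)
#
#     # The response-routing loop forwards each matching scheduler reply:
#     #     flush_cache_communicator.handle_recv(recv_obj)
#
#     # A caller awaits the fan-out of results for a single request:
#     #     results = await flush_cache_communicator(FlushCacheReqInput())
#
# Because at most one request is in flight, concurrent callers park on `_ready_queue`
# and are woken in FIFO order once the previous call completes.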


# Note: request abort handling logic
# We should handle all of the following cases correctly.
#
# | entrypoint | is_streaming | status          | abort engine    | cancel asyncio task   | rid_to_state                |
# | ---------- | ------------ | --------------- | --------------- | --------------------- | --------------------------- |
# | http       | yes          | waiting queue   | background task | fast api              | del in _handle_abort_req    |
# | http       | yes          | running         | background task | fast api              | del in _handle_batch_output |
# | http       | no           | waiting queue   | type 1          | type 1 exception      | del in _handle_abort_req    |
# | http       | no           | running         | type 3          | type 3 exception      | del in _handle_batch_output |
#
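# A hedged sketch of the streaming HTTP case above (illustrative only: the FastAPI
# handler is hypothetical, and generate_request / abort_request are assumed to be the
# TokenizerManager entrypoints used for this flow):
#
#     async def generate_stream(obj: GenerateReqInput, request: fastapi.Request):
#         try:
#             async for chunk in tokenizer_manager.generate_request(obj, request):
#                 yield chunk
#         except asyncio.CancelledError:
#             # The client disconnected; abort on the engine side so that
#             # _handle_abort_req / _handle_batch_output can clean up rid_to_state.
#             tokenizer_manager.abort_request(obj.rid)
#             raise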