tokenizer_manager.py 63.3 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
13
# Copyright 2023-2024 SGLang Team
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
14
"""TokenizerManager is a process that tokenizes the text."""
15

Lianmin Zheng's avatar
Lianmin Zheng committed
16
import asyncio
17
18
import copy
import dataclasses
19
import json
20
import logging
21
import math
Lianmin Zheng's avatar
Lianmin Zheng committed
22
import os
23
import pickle
24
25
import signal
import sys
26
import threading
27
import time
28
import uuid
29
from collections import deque
30
31
from datetime import datetime
from http import HTTPStatus
32
33
34
35
36
37
38
39
40
41
42
43
from typing import (
    Any,
    Awaitable,
    Deque,
    Dict,
    Generic,
    List,
    Optional,
    Tuple,
    TypeVar,
    Union,
)
Lianmin Zheng's avatar
Lianmin Zheng committed
44

45
import fastapi
46
import torch
Lianmin Zheng's avatar
Lianmin Zheng committed
47
48
49
import uvloop
import zmq
import zmq.asyncio
50
from fastapi import BackgroundTasks
Liangsheng Yin's avatar
Liangsheng Yin committed
51

52
from sglang.srt.aio_rwlock import RWLock
53
from sglang.srt.configs.model_config import ModelConfig
54
55
56
57
58
59
from sglang.srt.disaggregation.utils import (
    DisaggregationMode,
    KVClassType,
    TransferBackend,
    get_kv_class,
)
xm:D's avatar
xm:D committed
60
61
62
63
64
from sglang.srt.hf_transformers_utils import (
    get_processor,
    get_tokenizer,
    get_tokenizer_from_processor,
)
Lianmin Zheng's avatar
Lianmin Zheng committed
65
from sglang.srt.managers.io_struct import (
66
    AbortReq,
67
    BatchEmbeddingOut,
68
    BatchMultimodalOut,
Lianmin Zheng's avatar
Lianmin Zheng committed
69
    BatchStrOut,
70
    BatchTokenIDOut,
71
    CloseSessionReqInput,
72
    ConfigureLoggingReq,
73
    EmbeddingReqInput,
74
    ExpertDistributionReq,
75
    ExpertDistributionReqOutput,
76
77
    FlushCacheReqInput,
    FlushCacheReqOutput,
Lianmin Zheng's avatar
Lianmin Zheng committed
78
    GenerateReqInput,
79
80
    GetInternalStateReq,
    GetInternalStateReqOutput,
81
82
    GetWeightsByNameReqInput,
    GetWeightsByNameReqOutput,
83
    HealthCheckOutput,
84
85
    InitWeightsUpdateGroupReqInput,
    InitWeightsUpdateGroupReqOutput,
86
87
    OpenSessionReqInput,
    OpenSessionReqOutput,
88
    ProfileReq,
89
90
    ProfileReqOutput,
    ProfileReqType,
91
92
93
94
    ReleaseMemoryOccupationReqInput,
    ReleaseMemoryOccupationReqOutput,
    ResumeMemoryOccupationReqInput,
    ResumeMemoryOccupationReqOutput,
95
    SessionParams,
96
97
    SetInternalStateReq,
    SetInternalStateReqOutput,
98
99
    SlowDownReqInput,
    SlowDownReqOutput,
100
101
    TokenizedEmbeddingReqInput,
    TokenizedGenerateReqInput,
Chayenne's avatar
Chayenne committed
102
103
    UpdateWeightFromDiskReqInput,
    UpdateWeightFromDiskReqOutput,
104
105
    UpdateWeightsFromDistributedReqInput,
    UpdateWeightsFromDistributedReqOutput,
106
107
    UpdateWeightsFromTensorReqInput,
    UpdateWeightsFromTensorReqOutput,
Lianmin Zheng's avatar
Lianmin Zheng committed
108
)
Mick's avatar
Mick committed
109
110
111
112
113
from sglang.srt.managers.multimodal_processor import (
    get_dummy_processor,
    get_mm_processor,
    import_processors,
)
114
115
from sglang.srt.metrics.collector import TokenizerMetricsCollector
from sglang.srt.sampling.sampling_params import SamplingParams
Lianmin Zheng's avatar
Lianmin Zheng committed
116
from sglang.srt.server_args import PortArgs, ServerArgs
117
118
from sglang.srt.utils import (
    dataclass_to_string_truncated,
119
    get_bool_env_var,
120
121
122
    get_zmq_socket,
    kill_process_tree,
)
123
from sglang.utils import TypeBasedDispatcher, get_exception_traceback
Lianmin Zheng's avatar
Lianmin Zheng committed
124
125
126

# Install uvloop's event loop policy at import time, so asyncio loops created
# afterwards through the policy are uvloop loops.
asyncio.set_event_loop_policy(uvloop.EventLoopPolicy())

127
128
# Module-level logger, named after this module.
logger = logging.getLogger(__name__)

Lianmin Zheng's avatar
Lianmin Zheng committed
129

130
131
132
133
@dataclasses.dataclass
class ReqState:
    """Store the state of a request."""

    # --- Required fields. Their order matters: instances are built positionally. ---
    # Pending outputs; the waiter reads the last element and then resets the list.
    out_list: List[Dict[Any, Any]]
    # True once the final output for this request has been produced.
    finished: bool
    # Event the waiter blocks on until new output is available.
    event: asyncio.Event
    # The request object this state belongs to.
    obj: Union[GenerateReqInput, EmbeddingReqInput]

    # --- Timing / counters used for metrics. ---
    created_time: float
    finished_time: float = 0.0
    first_token_time: float = 0.0
    last_time: float = 0.0
    last_completion_tokens: int = 1

    # --- Streaming output bookkeeping. ---
    last_output_offset: int = 0

    # --- Accumulators for incremental state updates. ---
    # Each list gets its own fresh instance per ReqState via default_factory.
    text: str = ""
    output_ids: List[int] = dataclasses.field(default_factory=list)
    input_token_logprobs_val: List[float] = dataclasses.field(default_factory=list)
    input_token_logprobs_idx: List[int] = dataclasses.field(default_factory=list)
    output_token_logprobs_val: List[float] = dataclasses.field(default_factory=list)
    output_token_logprobs_idx: List[int] = dataclasses.field(default_factory=list)
    input_top_logprobs_val: List[List[float]] = dataclasses.field(default_factory=list)
    input_top_logprobs_idx: List[List[int]] = dataclasses.field(default_factory=list)
    output_top_logprobs_val: List[List[float]] = dataclasses.field(default_factory=list)
    output_top_logprobs_idx: List[List[int]] = dataclasses.field(default_factory=list)
    input_token_ids_logprobs_val: List = dataclasses.field(default_factory=list)
    input_token_ids_logprobs_idx: List = dataclasses.field(default_factory=list)
    output_token_ids_logprobs_val: List = dataclasses.field(default_factory=list)
    output_token_ids_logprobs_idx: List = dataclasses.field(default_factory=list)
163
164


165
166
class TokenizerManager:
    """TokenizerManager is a process that tokenizes the text."""
167

Lianmin Zheng's avatar
Lianmin Zheng committed
168
169
170
171
172
    def __init__(
        self,
        server_args: ServerArgs,
        port_args: PortArgs,
    ):
        """Initialize tokenizer/processor, IPC sockets, state and communicators.

        Args:
            server_args: Server configuration; read for logging, sampling
                defaults, tokenizer/model settings, metrics buckets, data
                parallel size and disaggregation settings.
            port_args: Provides the IPC endpoint names used for the sockets
                (``tokenizer_ipc_name`` and ``scheduler_input_ipc_name``).
        """
        # Parse args
        self.server_args = server_args
        self.enable_metrics = server_args.enable_metrics
        self.log_requests = server_args.log_requests
        self.log_requests_level = server_args.log_requests_level
        self.preferred_sampling_params = (
            json.loads(server_args.preferred_sampling_params)
            if server_args.preferred_sampling_params
            else None
        )

        # Init inter-process communication
        context = zmq.asyncio.Context(2)
        self.recv_from_detokenizer = get_zmq_socket(
            context, zmq.PULL, port_args.tokenizer_ipc_name, True
        )
        self.send_to_scheduler = get_zmq_socket(
            context, zmq.PUSH, port_args.scheduler_input_ipc_name, True
        )

        # Read model args
        self.model_path = server_args.model_path
        self.served_model_name = server_args.served_model_name
        self.model_config = ModelConfig.from_server_args(server_args)

        self.is_generation = self.model_config.is_generation
        self.is_image_gen = self.model_config.is_image_gen
        self.context_len = self.model_config.context_len
        self.image_token_id = self.model_config.image_token_id

        if self.model_config.is_multimodal:
            import_processors()
            _processor = get_processor(
                server_args.tokenizer_path,
                tokenizer_mode=server_args.tokenizer_mode,
                trust_remote_code=server_args.trust_remote_code,
                revision=server_args.revision,
                use_fast=not server_args.disable_fast_image_processor,
            )

            # The multimodal processor is created regardless of
            # skip_tokenizer_init, so multimodal inputs can still be encoded
            # even when the tokenizer itself is skipped.
            # NOTE(review): presumably the mm processor owns the executor used
            # to parallelize image pre-processing -- confirm in get_mm_processor.
            self.mm_processor = get_mm_processor(
                self.model_config.hf_config, server_args, _processor
            )

            if server_args.skip_tokenizer_init:
                self.tokenizer = self.processor = None
            else:
                self.processor = _processor
                self.tokenizer = get_tokenizer_from_processor(self.processor)
                os.environ["TOKENIZERS_PARALLELISM"] = "false"
        else:
            # Text-only model: there is never a processor. Always define the
            # attribute so that reading `self.processor` cannot raise
            # AttributeError on this path.
            self.mm_processor = self.processor = None

            if server_args.skip_tokenizer_init:
                self.tokenizer = None
            else:
                self.tokenizer = get_tokenizer(
                    server_args.tokenizer_path,
                    tokenizer_mode=server_args.tokenizer_mode,
                    trust_remote_code=server_args.trust_remote_code,
                    revision=server_args.revision,
                )

        # Store states
        self.no_create_loop = False
        self.rid_to_state: Dict[str, ReqState] = {}
        self.health_check_failed = False
        self.gracefully_exit = False
        self.last_receive_tstamp = 0
        self.dump_requests_folder = ""  # By default do not dump
        self.dump_requests_threshold = 1000
        self.dump_request_list: List[Tuple] = []
        self.log_request_metadata = self.get_log_request_metadata()

        # Weight-update synchronization. generate_request holds the reader side
        # of this lock while a request is being served.
        self.model_update_lock = RWLock()
        self.model_update_result: Optional[Awaitable[UpdateWeightFromDiskReqOutput]] = (
            None
        )
        self.asyncio_tasks = set()

        # For session info
        self.session_futures = {}  # session_id -> asyncio event

        # Set after scheduler is initialized
        self.max_req_input_len = None

        # Metrics
        if self.enable_metrics:
            self.metrics_collector = TokenizerMetricsCollector(
                labels={
                    "model_name": self.server_args.served_model_name,
                    # TODO: Add lora name/path in the future,
                },
                bucket_time_to_first_token=self.server_args.bucket_time_to_first_token,
                bucket_e2e_request_latency=self.server_args.bucket_e2e_request_latency,
                bucket_inter_token_latency=self.server_args.bucket_inter_token_latency,
                collect_tokens_histogram=self.server_args.collect_tokens_histogram,
            )

        # Communicators
        def new_communicator(size=server_args.dp_size):
            # Every communicator sends through the scheduler socket; only the
            # second constructor argument differs (dp_size, or 1 for health checks).
            return _Communicator(self.send_to_scheduler, size)

        self.init_weights_update_group_communicator = new_communicator()
        self.update_weights_from_distributed_communicator = new_communicator()
        self.update_weights_from_tensor_communicator = new_communicator()
        self.get_weights_by_name_communicator = new_communicator()
        self.release_memory_occupation_communicator = new_communicator()
        self.resume_memory_occupation_communicator = new_communicator()
        self.slow_down_communicator = new_communicator()
        self.flush_cache_communicator = new_communicator()
        self.profile_communicator = new_communicator()
        # NOTE: the attribute name is misspelled ("communitcator"); it is kept
        # as-is because code outside this method may reference it.
        self.health_check_communitcator = new_communicator(1)
        self.get_internal_state_communicator = new_communicator()
        self.set_internal_state_communicator = new_communicator()
        self.expert_distribution_communicator = new_communicator()

        # Route each message type received back to its handler.
        self._result_dispatcher = TypeBasedDispatcher(
            [
                (
                    (
                        BatchStrOut,
                        BatchEmbeddingOut,
                        BatchTokenIDOut,
                        BatchMultimodalOut,
                    ),
                    self._handle_batch_output,
                ),
                (AbortReq, self._handle_abort_req),
                (OpenSessionReqOutput, self._handle_open_session_req_output),
                (
                    UpdateWeightFromDiskReqOutput,
                    self._handle_update_weights_from_disk_req_output,
                ),
                (
                    InitWeightsUpdateGroupReqOutput,
                    self.init_weights_update_group_communicator.handle_recv,
                ),
                (
                    UpdateWeightsFromDistributedReqOutput,
                    self.update_weights_from_distributed_communicator.handle_recv,
                ),
                (
                    UpdateWeightsFromTensorReqOutput,
                    self.update_weights_from_tensor_communicator.handle_recv,
                ),
                (
                    GetWeightsByNameReqOutput,
                    self.get_weights_by_name_communicator.handle_recv,
                ),
                (
                    ReleaseMemoryOccupationReqOutput,
                    self.release_memory_occupation_communicator.handle_recv,
                ),
                (
                    ResumeMemoryOccupationReqOutput,
                    self.resume_memory_occupation_communicator.handle_recv,
                ),
                (
                    SlowDownReqOutput,
                    self.slow_down_communicator.handle_recv,
                ),
                (
                    FlushCacheReqOutput,
                    self.flush_cache_communicator.handle_recv,
                ),
                (
                    ProfileReqOutput,
                    self.profile_communicator.handle_recv,
                ),
                (
                    GetInternalStateReqOutput,
                    self.get_internal_state_communicator.handle_recv,
                ),
                (
                    SetInternalStateReqOutput,
                    self.set_internal_state_communicator.handle_recv,
                ),
                (
                    ExpertDistributionReqOutput,
                    self.expert_distribution_communicator.handle_recv,
                ),
                # Health-check replies need no processing.
                (HealthCheckOutput, lambda x: None),
            ]
        )

        # For PD disaggregation
        self.disaggregation_mode = DisaggregationMode(
            self.server_args.disaggregation_mode
        )
        self.transfer_backend = TransferBackend(
            self.server_args.disaggregation_transfer_backend
        )
        # Start the KV bootstrap server, only on the prefill tokenizer manager.
        if self.disaggregation_mode == DisaggregationMode.PREFILL:
            kv_bootstrap_server_class = get_kv_class(
                self.transfer_backend, KVClassType.BOOTSTRAP_SERVER
            )
            self.bootstrap_server = kv_bootstrap_server_class(
                self.server_args.disaggregation_bootstrap_port
            )

        self.current_load = 0
        self.current_load_lock = asyncio.Lock()

404
    async def generate_request(
        self,
        obj: Union[GenerateReqInput, EmbeddingReqInput],
        request: Optional[fastapi.Request] = None,
    ):
        """Validate a request, tokenize and submit it, then yield its responses.

        This is an async generator. Single requests are tokenized and sent
        directly; batched requests are delegated to ``_handle_batch_request``.

        Args:
            obj: The generation or embedding request (single or batched).
            request: Optional HTTP request, forwarded so that client
                disconnects can be detected while waiting.

        Raises:
            ValueError: If an embedding request targets a generation model, or
                hidden states are requested while the server does not enable them.
        """
        created_time = time.time()

        self.auto_create_handle_loop()

        if self.is_generation and isinstance(obj, EmbeddingReqInput):
            raise ValueError(
                "This model does not appear to be an embedding model by default. "
                "Please add `--is-embedding` when launching the server or try another model."
            )

        obj.normalize_batch_and_arguments()

        if isinstance(obj, GenerateReqInput):
            return_hidden_states = obj.return_hidden_states
            # The flag is either a scalar or a per-request list.
            if isinstance(return_hidden_states, list):
                has_return_hidden_states = any(return_hidden_states)
            else:
                has_return_hidden_states = return_hidden_states == True  # noqa: E712
            if (
                has_return_hidden_states
                and not self.server_args.enable_return_hidden_states
            ):
                raise ValueError(
                    "return_hidden_states=True requires the server to be started "
                    "with --enable-return-hidden-states (ServerArgs.enable_return_hidden_states)."
                )

        if self.log_requests:
            max_length, skip_names, _ = self.log_request_metadata
            logger.info(
                f"Receive: obj={dataclass_to_string_truncated(obj, max_length, skip_names=skip_names)}"
            )

        # Hold the reader side of the model-update lock for the whole lifetime
        # of the request.
        async with self.model_update_lock.reader_lock:
            if obj.is_single:
                tokenized_obj = await self._tokenize_one_request(obj)
                state = self._send_one_request(obj, tokenized_obj, created_time)
                responses = self._wait_one_response(obj, state, request)
            else:
                responses = self._handle_batch_request(obj, request, created_time)

            async for response in responses:
                yield response

    async def _tokenize_one_request(
        self,
        obj: Union[GenerateReqInput, EmbeddingReqInput],
    ):
        """Turn a single request into its tokenized counterpart.

        Token ids come from, in order of precedence: the request's
        ``input_embeds`` path (ids taken from ``obj.input_ids``), explicit
        ``obj.input_ids``, or encoding ``obj.text`` with the tokenizer.
        Multimodal data, when present, is processed and may override the ids.

        Raises:
            ValueError: If ``input_embeds`` is used while the radix cache is
                enabled, if text is given but no tokenizer was initialized, or
                if the token-length validation fails.
        """
        input_text = obj.text
        input_embeds = obj.input_embeds

        if input_embeds is not None:
            if not self.server_args.disable_radix_cache:
                raise ValueError(
                    "input_embeds is provided while disable_radix_cache is False. "
                    "Please add `--disable-radix-cache` when you launch the server "
                    "if you want to use input_embeds as inputs."
                )
            input_ids = obj.input_ids
        elif obj.input_ids is not None:
            input_ids = obj.input_ids
        else:
            if self.tokenizer is None:
                raise ValueError(
                    "The engine initialized with skip_tokenizer_init=True cannot "
                    "accept text prompts. Please provide input_ids or re-initialize "
                    "the engine with skip_tokenizer_init=False."
                )
            input_ids = self.tokenizer.encode(input_text)

        image_inputs: Optional[Dict] = None
        if self.mm_processor and obj.contains_mm_input():
            image_inputs = await self.mm_processor.process_mm_data_async(
                image_data=obj.image_data,
                input_text=input_text or input_ids,
                request_obj=obj,
                max_req_input_len=self.max_req_input_len,
            )
            # The multimodal processor may return its own token ids.
            if image_inputs and "input_ids" in image_inputs:
                input_ids = image_inputs["input_ids"]

        self._validate_token_len(obj, input_ids)
        return self._create_tokenized_object(
            obj, input_text, input_ids, input_embeds, image_inputs
        )

    def _validate_token_len(
        self, obj: Union[GenerateReqInput, EmbeddingReqInput], input_ids: List[int]
    ) -> None:
        """Reject requests whose token budget does not fit in the context window.

        Two checks are made against ``self.context_len``: the input length
        alone, and the input length plus ``max_new_tokens`` (when that sampling
        parameter is set). Both use ``>=``, so hitting the limit exactly is
        also rejected.

        Raises:
            ValueError: If either check fails.
        """
        # A missing id list counts as zero input tokens.
        input_token_num = 0 if input_ids is None else len(input_ids)

        # Check 1: the input alone.
        if input_token_num >= self.context_len:
            raise ValueError(
                f"The input ({input_token_num} tokens) is longer than the "
                f"model's context length ({self.context_len} tokens)."
            )

        # Check 2: input + requested completion length.
        max_new_tokens = obj.sampling_params.get("max_new_tokens")
        if max_new_tokens is None:
            return

        total_tokens = max_new_tokens + input_token_num
        if total_tokens >= self.context_len:
            raise ValueError(
                f"Requested token count exceeds the model's maximum context length "
                f"of {self.context_len} tokens. You requested a total of {total_tokens} "
                f"tokens: {input_token_num} tokens from the input messages and "
                f"{max_new_tokens} tokens for the completion. Please reduce the number "
                f"of tokens in the input messages or the completion to fit within the limit."
            )

    def _create_tokenized_object(
        self,
        obj: Union[GenerateReqInput, EmbeddingReqInput],
        input_text: str,
        input_ids: List[int],
        input_embeds: Optional[Union[List[float], None]] = None,
        image_inputs: Optional[Dict] = None,
    ) -> Union[TokenizedGenerateReqInput, TokenizedEmbeddingReqInput]:
        """Assemble the tokenized request that is sent to the scheduler.

        Args:
            obj: The original request.
            input_text: Raw prompt text of the request.
            input_ids: Token ids for the prompt.
            input_embeds: Optional precomputed input embeddings.
            image_inputs: Optional processed multimodal inputs.

        Returns:
            A ``TokenizedGenerateReqInput`` or ``TokenizedEmbeddingReqInput``
            matching the type of ``obj``.

        Raises:
            ValueError: If a custom logit processor is supplied while the
                server does not enable that feature.
        """
        # Generation-only options are read only for generation models.
        if self.is_generation:
            return_logprob = obj.return_logprob
            logprob_start_len = obj.logprob_start_len
            top_logprobs_num = obj.top_logprobs_num
            token_ids_logprob = obj.token_ids_logprob

            if obj.session_params:
                session_params = SessionParams(**obj.session_params)
            else:
                session_params = None

            processor_allowed = self.server_args.enable_custom_logit_processor
            if obj.custom_logit_processor and not processor_allowed:
                raise ValueError(
                    "The server is not configured to enable custom logit processor. "
                    "Please set `--enable-custom-logits-processor` to enable this feature."
                )

        # Sampling parameters: server-preferred values act as defaults and are
        # overridden by anything passed explicitly in the request.
        merged_kwargs = (
            {**self.preferred_sampling_params, **obj.sampling_params}
            if self.preferred_sampling_params
            else obj.sampling_params
        )
        sampling_params = SamplingParams(**merged_kwargs)
        sampling_params.normalize(self.tokenizer)
        sampling_params.verify()

        # Build the object matching the request type.
        if isinstance(obj, GenerateReqInput):
            tokenized_obj = TokenizedGenerateReqInput(
                obj.rid,
                input_text,
                input_ids,
                image_inputs,
                sampling_params,
                return_logprob,
                logprob_start_len,
                top_logprobs_num,
                token_ids_logprob,
                obj.stream,
                bootstrap_host=obj.bootstrap_host,
                bootstrap_port=obj.bootstrap_port,
                bootstrap_room=obj.bootstrap_room,
                lora_path=obj.lora_path,
                input_embeds=input_embeds,
                session_params=session_params,
                custom_logit_processor=obj.custom_logit_processor,
                return_hidden_states=obj.return_hidden_states,
                data_parallel_rank=obj.data_parallel_rank,
            )
        elif isinstance(obj, EmbeddingReqInput):
            tokenized_obj = TokenizedEmbeddingReqInput(
                obj.rid,
                input_text,
                input_ids,
                image_inputs,
                sampling_params,
            )

        return tokenized_obj

600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
    async def _batch_tokenize_and_process(
        self, batch_size: int, obj: Union[GenerateReqInput, EmbeddingReqInput]
    ) -> List[Union[TokenizedGenerateReqInput, TokenizedEmbeddingReqInput]]:
        """Handle batch tokenization for text inputs only.

        All texts of the batch are encoded with a single tokenizer call; each
        request is then length-validated and converted to its tokenized object.

        Args:
            batch_size: Number of sub-requests contained in ``obj``.
            obj: The batched request; ``obj[i]`` yields the i-th sub-request.

        Returns:
            One tokenized object per sub-request, in batch order.

        Raises:
            ValueError: If no tokenizer is available, or a sub-request fails
                token-length validation.
        """
        # Same guard (and message) as the single-request path, instead of an
        # opaque "'NoneType' object is not callable" TypeError.
        if self.tokenizer is None:
            raise ValueError(
                "The engine initialized with skip_tokenizer_init=True cannot "
                "accept text prompts. Please provide input_ids or re-initialize "
                "the engine with skip_tokenizer_init=False."
            )

        logger.debug(f"Starting batch tokenization for {batch_size} text requests")

        # Collect requests and texts (each sub-request is materialized once).
        requests = [obj[i] for i in range(batch_size)]
        texts = [req.text for req in requests]

        # Batch tokenize all texts
        encoded = self.tokenizer(texts)
        input_ids_list = encoded["input_ids"]

        # Process all requests, reusing the already-built sub-request for both
        # validation and object creation.
        tokenized_objs = []
        for i, req in enumerate(requests):
            input_ids = input_ids_list[i]
            self._validate_token_len(req, input_ids)
            tokenized_objs.append(
                self._create_tokenized_object(req, req.text, input_ids, None, None)
            )
        logger.debug(f"Completed batch processing for {batch_size} requests")
        return tokenized_objs

    def _validate_batch_tokenization_constraints(
        self, batch_size: int, obj: Union[GenerateReqInput, EmbeddingReqInput]
    ) -> None:
        """Validate constraints for batch tokenization processing.

        Batch tokenization only supports plain-text sub-requests: no image data
        (checked for generation models only), no pre-tokenized ``input_ids``
        and no ``input_embeds``.

        Args:
            batch_size: Number of sub-requests contained in ``obj``.
            obj: The batched request; ``obj[i]`` yields the i-th sub-request.

        Raises:
            ValueError: On the first sub-request that violates a constraint.
        """
        for i in range(batch_size):
            # Build the sub-request once per iteration instead of once per check.
            req = obj[i]
            if self.is_generation and req.image_data:
                raise ValueError(
                    "For image input processing do not set `enable_tokenizer_batch_encode`."
                )
            if req.input_ids is not None:
                raise ValueError(
                    "Batch tokenization is not needed for pre-tokenized input_ids. Do not set `enable_tokenizer_batch_encode`."
                )
            if req.input_embeds is not None:
                raise ValueError(
                    "Batch tokenization is not needed for input_embeds. Do not set `enable_tokenizer_batch_encode`."
                )

644
645
646
647
648
649
    def _send_one_request(
        self,
        obj: Union[GenerateReqInput, EmbeddingReqInput],
        tokenized_obj: Union[TokenizedGenerateReqInput, TokenizedEmbeddingReqInput],
        created_time: Optional[float] = None,
    ):
        """Send a tokenized request to the scheduler and register its state.

        Args:
            obj: The original request; its ``rid`` keys the state table.
            tokenized_obj: The tokenized form pushed to the scheduler socket.
            created_time: Timestamp recorded on the new state for metrics.

        Returns:
            The freshly created ``ReqState`` tracking this request.
        """
        # Send first: if sending raises, no state entry is left behind.
        self.send_to_scheduler.send_pyobj(tokenized_obj)

        new_state = ReqState(
            [],
            False,
            asyncio.Event(),
            obj,
            created_time=created_time,
        )
        self.rid_to_state[obj.rid] = new_state
        return new_state
654
655
656
657

    async def _wait_one_response(
        self,
        obj: Union[GenerateReqInput, EmbeddingReqInput],
        state: ReqState,
        request: Optional[fastapi.Request] = None,
    ):
        """Wait for the response of one request.

        Async generator. It wakes whenever ``state.event`` is set (or every 4
        seconds to poll for client disconnects), yields intermediate outputs
        for streaming requests and always yields the final output.

        Args:
            obj: The request being waited on.
            state: The request's state, filled in by the output handler.
            request: Optional HTTP request used to detect client disconnects.

        Raises:
            ValueError: If the client disconnected (the request is aborted
                first), or the scheduler finished the request with an abort
                carrying a BAD_REQUEST status.
        """
        while True:
            try:
                await asyncio.wait_for(state.event.wait(), timeout=4)
            except asyncio.TimeoutError:
                # Nothing arrived yet: use the wake-up to check on the client.
                if request is not None and await request.is_disconnected():
                    # Abort the request for disconnected requests (non-streaming, waiting queue)
                    self.abort_request(obj.rid)
                    # Use exception to kill the whole call stack and asyncio task
                    raise ValueError(
                        f"Request is disconnected from the client side (type 1). Abort request {obj.rid=}"
                    )
                continue

            # Only the most recent output matters; drop everything buffered.
            out = state.out_list[-1]
            state.out_list = []

            if state.finished:
                if self.log_requests:
                    max_length, skip_names, out_skip_names = self.log_request_metadata
                    msg = f"Finish: obj={dataclass_to_string_truncated(obj, max_length, skip_names=skip_names)}"
                    # Multimodal-generation outputs are left out of the log line.
                    if not self.model_config.is_multimodal_gen:
                        msg += f", out={dataclass_to_string_truncated(out, max_length, skip_names=out_skip_names)}"
                    logger.info(msg)

                # Check if this was an abort/error created by scheduler
                finish_reason = out["meta_info"].get("finish_reason")
                if (
                    isinstance(finish_reason, dict)
                    and finish_reason.get("type") == "abort"
                    and finish_reason.get("status_code") == HTTPStatus.BAD_REQUEST
                ):
                    raise ValueError(finish_reason["message"])

                yield out
                return

            # Not finished yet: re-arm the event before waiting again.
            state.event.clear()

            if obj.stream:
                yield out
            elif request is not None and await request.is_disconnected():
                # Abort the request for disconnected requests (non-streaming, running)
                self.abort_request(obj.rid)
                # Use exception to kill the whole call stack and asyncio task
                raise ValueError(
                    f"Request is disconnected from the client side (type 3). Abort request {obj.rid=}"
                )
711
712
713
714
715
716
717
718
719
720
721
722

    async def _handle_batch_request(
        self,
        obj: Union[GenerateReqInput, EmbeddingReqInput],
        request: Optional[fastapi.Request] = None,
        created_time: Optional[float] = None,
    ):
        """Fan a batched request out into per-item requests and yield their outputs.

        Every item of ``obj`` is tokenized, sent with ``_send_one_request`` and
        followed by a ``_wait_one_response`` generator.  When
        ``parallel_sample_num`` is not 1, each prompt is first sent once with
        ``max_new_tokens = 0`` and awaited, and only then sent
        ``parallel_sample_num`` times under freshly generated rids.

        Yields:
            Non-streaming: one list containing the first output of every
            generator.  Streaming: result dicts as they become available, each
            with an ``"index"`` key giving the position of its rid in the
            order the requests were issued.
        """
        batch_size = obj.batch_size
        generators = []
        rids = []

        def track(item, tokenized_item):
            # Send one request and remember its response generator and rid.
            state = self._send_one_request(item, tokenized_item, created_time)
            generators.append(self._wait_one_response(item, state, request))
            rids.append(item.rid)

        if getattr(obj, "parallel_sample_num", 1) == 1:
            if self.server_args.enable_tokenizer_batch_encode:
                # Validate batch tokenization constraints
                self._validate_batch_tokenization_constraints(batch_size, obj)

                tokenized_objs = await self._batch_tokenize_and_process(batch_size, obj)
                for idx, tokenized_item in enumerate(tokenized_objs):
                    track(obj[idx], tokenized_item)
            else:
                # Sequential tokenization and processing
                for idx in range(batch_size):
                    item = obj[idx]
                    tokenized_item = await self._tokenize_one_request(item)
                    track(item, tokenized_item)
        else:
            # FIXME: When using batch and parallel_sample_num together, the perf is not optimal.
            if batch_size > 128:
                logger.warning(
                    "Sending a single large batch with parallel sampling (n > 1) has not been well optimized. "
                    "The performance might be better if you just duplicate the requests n times or use "
                    "many threads to send them one by one with parallel sampling (n > 1)."
                )

            # Tokenize all requests
            items = [obj[idx] for idx in range(batch_size)]
            tokenized_objs = await asyncio.gather(
                *(self._tokenize_one_request(item) for item in items)
            )

            # Cache the common prefix for parallel sampling
            for item, tokenized_item in zip(items, tokenized_objs):
                prefix_obj = copy.copy(item)
                prefix_tokenized = copy.copy(tokenized_item)
                prefix_tokenized.rid = prefix_obj.regenerate_rid()
                # Copy the sampling params so the shared original is not mutated.
                prefix_tokenized.sampling_params = copy.copy(
                    prefix_tokenized.sampling_params
                )
                prefix_tokenized.sampling_params.max_new_tokens = 0
                prefix_tokenized.stream = False
                state = self._send_one_request(
                    prefix_obj, prefix_tokenized, created_time
                )
                await self._wait_one_response(prefix_obj, state, request).__anext__()

            # Expand requests, assign new rids for them, and send them
            for item, tokenized_item in zip(items, tokenized_objs):
                for _ in range(obj.parallel_sample_num):
                    sample_obj = copy.copy(item)
                    sample_tokenized = copy.copy(tokenized_item)
                    sample_tokenized.rid = sample_obj.regenerate_rid()
                    track(sample_obj, sample_tokenized)

        # Wait for all requests
        if not (hasattr(obj, "stream") and obj.stream):
            outputs = await asyncio.gather(*(gen.__anext__() for gen in generators))
            yield outputs
            return

        rid_to_index = {rid: pos for pos, rid in enumerate(rids)}
        pending = {asyncio.create_task(gen.__anext__()): gen for gen in generators}
        while pending:
            done, _ = await asyncio.wait(
                pending.keys(), return_when=asyncio.FIRST_COMPLETED
            )

            for finished in done:
                gen = pending.pop(finished)
                try:
                    result = finished.result()
                    result["index"] = rid_to_index[result["meta_info"]["id"]]
                    yield result
                    # Re-arm the generator for its next chunk.
                    pending[asyncio.create_task(gen.__anext__())] = gen
                except StopAsyncIteration:
                    # This generator is exhausted; it is simply not re-armed.
                    pass
801

802
    async def flush_cache(self) -> FlushCacheReqOutput:
        """Send a ``FlushCacheReqInput`` through the flush-cache communicator.

        Returns:
            The first element of the communicator's response.
        """
        responses = await self.flush_cache_communicator(FlushCacheReqInput())
        return responses[0]
Liangsheng Yin's avatar
Liangsheng Yin committed
804

805
    def abort_request(self, rid: str):
        """Send an ``AbortReq`` for ``rid`` to the scheduler if it is still tracked.

        Nothing happens when ``rid`` is not a key of ``self.rid_to_state``.
        When metrics are enabled, one aborted request is recorded after the
        abort message has been sent.
        """
        if rid in self.rid_to_state:
            self.send_to_scheduler.send_pyobj(AbortReq(rid))

            if self.enable_metrics:
                self.metrics_collector.observe_one_aborted_request()

814
815
816
817
818
    async def start_profile(
        self,
        output_dir: Optional[str] = None,
        num_steps: Optional[int] = None,
        activities: Optional[List[str]] = None,
        with_stack: Optional[bool] = None,
        record_shapes: Optional[bool] = None,
        profile_by_stage: bool = False,
    ):
        """Build a START_PROFILE request and execute it.

        ``with_stack`` is forced to ``False`` when either the argument is
        ``False`` or the ``SGLANG_PROFILE_WITH_STACK`` environment flag reads
        as ``False``; in every other case (including ``None``) it is ``True``.
        The request's ``profile_id`` is the current ``time.time()`` as a string.

        Returns:
            Whatever ``_execute_profile`` returns for the request.
        """
        self.auto_create_handle_loop()

        env_with_stack: bool = get_bool_env_var("SGLANG_PROFILE_WITH_STACK", "true")
        stack_disabled = with_stack is False or env_with_stack is False
        with_stack = not stack_disabled

        req = ProfileReq(
            type=ProfileReqType.START_PROFILE,
            output_dir=output_dir,
            num_steps=num_steps,
            activities=activities,
            with_stack=with_stack,
            record_shapes=record_shapes,
            profile_by_stage=profile_by_stage,
            profile_id=str(time.time()),
        )
        return await self._execute_profile(req)

    async def stop_profile(self):
        """Build a STOP_PROFILE request and execute it via ``_execute_profile``."""
        self.auto_create_handle_loop()
        stop_req = ProfileReq(type=ProfileReqType.STOP_PROFILE)
        return await self._execute_profile(stop_req)

    async def _execute_profile(self, req: ProfileReq):
        """Send ``req`` through the profile communicator and check the outcome.

        Returns:
            The first element of the communicator's response.

        Raises:
            RuntimeError: if that element's ``success`` is falsy; the error
                text is its ``message``.
        """
        responses = await self.profile_communicator(req)
        result = responses[0]
        if result.success:
            return result
        raise RuntimeError(result.message)
848

849
    async def start_expert_distribution_record(self):
        """Send ``ExpertDistributionReq.START_RECORD`` through the communicator."""
        self.auto_create_handle_loop()
        command = ExpertDistributionReq.START_RECORD
        await self.expert_distribution_communicator(command)
852

853
    async def stop_expert_distribution_record(self):
        """Send ``ExpertDistributionReq.STOP_RECORD`` through the communicator."""
        self.auto_create_handle_loop()
        command = ExpertDistributionReq.STOP_RECORD
        await self.expert_distribution_communicator(command)
856

857
    async def dump_expert_distribution_record(self):
        """Send ``ExpertDistributionReq.DUMP_RECORD`` through the communicator."""
        self.auto_create_handle_loop()
        command = ExpertDistributionReq.DUMP_RECORD
        await self.expert_distribution_communicator(command)
860

Chayenne's avatar
Chayenne committed
861
862
863
864
    async def update_weights_from_disk(
        self,
        obj: UpdateWeightFromDiskReqInput,
        request: Optional[fastapi.Request] = None,
    ) -> Tuple[bool, str, Any]:
        """Update the model weights from disk while holding the writer lock.

        If ``obj.load_format`` is ``None`` it is filled in from
        ``self.server_args.load_format`` before the request is forwarded.
        ``request`` is accepted but not used by this method.

        Returns:
            The tuple produced by ``_wait_for_model_update_from_disk``:
            ``(success, message, num_paused_requests)``.  The annotation used
            to advertise a two-element tuple, which did not match the value
            actually returned.
        """
        self.auto_create_handle_loop()

        # default the load format to the server_args
        if obj.load_format is None:
            obj.load_format = self.server_args.load_format
        logger.info("Start update_weights. Load format=%s", obj.load_format)

        if True:  # Keep this redundant check to simplify some internal code sync
            # Hold the lock if it is not async. This means that weight sync
            # cannot run while requests are in progress.
            async with self.model_update_lock.writer_lock:
                return await self._wait_for_model_update_from_disk(obj)
878

879
880
    async def _wait_for_model_update_from_disk(
        self, obj: UpdateWeightFromDiskReqInput
    ) -> Tuple[bool, str, Any]:
        """Send ``obj`` to the scheduler and wait on ``self.model_update_result``.

        A fresh ``asyncio.Future`` is stored in ``self.model_update_result``
        right after the request is sent; this method only awaits it and does
        not resolve it itself.

        Returns:
            ``(success, message, num_paused_requests)``.  With ``dp_size == 1``
            the three values come from the single result.  Otherwise
            ``success`` is true only if every result succeeded, ``message`` is
            the individual messages joined by ``" | "``, and the third element
            is the list of per-result ``num_paused_requests``.
        """
        self.send_to_scheduler.send_pyobj(obj)
        self.model_update_result = asyncio.Future()

        if self.server_args.dp_size == 1:
            result = await self.model_update_result
            if result.success:
                self.served_model_name = obj.model_path
                self.server_args.model_path = obj.model_path
                self.server_args.load_format = obj.load_format
                self.model_path = obj.model_path
            return result.success, result.message, result.num_paused_requests

        # self.server_args.dp_size > 1
        self.model_update_tmp = []
        results = await self.model_update_result

        all_success = all(r.success for r in results)
        if all_success:
            # NOTE(review): unlike the dp_size == 1 branch, served_model_name
            # is not updated here; confirm that this is intentional.
            self.server_args.model_path = obj.model_path
            self.server_args.load_format = obj.load_format
            self.model_path = obj.model_path
        all_message = " | ".join(r.message for r in results)
        all_paused_requests = [r.num_paused_requests for r in results]
        return all_success, all_message, all_paused_requests
905

906
907
908
909
    async def init_weights_update_group(
        self,
        obj: InitWeightsUpdateGroupReqInput,
        request: Optional[fastapi.Request] = None,
    ) -> Tuple[bool, str]:
        """Forward ``obj`` to the init-weights-update-group communicator.

        Asserts that ``dp_size`` is 1.  ``request`` is accepted but not used.

        Returns:
            ``(success, message)`` of the first element of the response.
        """
        self.auto_create_handle_loop()
        assert (
            self.server_args.dp_size == 1
        ), "dp_size must be 1 for init parameter update group"

        responses = await self.init_weights_update_group_communicator(obj)
        first = responses[0]
        return first.success, first.message

    async def update_weights_from_distributed(
        self,
        obj: UpdateWeightsFromDistributedReqInput,
        request: Optional[fastapi.Request] = None,
    ) -> Tuple[bool, str]:
        """Forward ``obj`` to the update-weights-from-distributed communicator.

        Asserts that ``dp_size`` is 1 or that dp attention is enabled.
        ``request`` is accepted but not used.

        Returns:
            ``(success, message)`` of the first element of the response.
        """
        self.auto_create_handle_loop()
        assert (
            self.server_args.dp_size == 1 or self.server_args.enable_dp_attention
        ), "dp_size must be 1 or dp attention must be enabled for update weights from distributed"

        # This means that weight sync
        # cannot run while requests are in progress.
        async with self.model_update_lock.writer_lock:
            responses = await self.update_weights_from_distributed_communicator(obj)
            first = responses[0]
            return first.success, first.message
933

934
935
936
937
938
939
940
    async def update_weights_from_tensor(
        self,
        obj: UpdateWeightsFromTensorReqInput,
        request: Optional[fastapi.Request] = None,
    ) -> Tuple[bool, str]:
        """Forward ``obj`` to the update-weights-from-tensor communicator.

        Asserts that ``dp_size`` is 1 or that dp attention is enabled.
        ``request`` is accepted but not used.

        Returns:
            ``(success, message)`` of the first element of the response.
        """
        self.auto_create_handle_loop()
        assert (
            self.server_args.dp_size == 1 or self.server_args.enable_dp_attention
        ), "dp_size must be 1 or dp attention must be enabled for update weights from tensor"

        # This means that weight sync
        # cannot run while requests are in progress.
        async with self.model_update_lock.writer_lock:
            responses = await self.update_weights_from_tensor_communicator(obj)
            first = responses[0]
            return first.success, first.message

950
951
952
    async def get_weights_by_name(
        self, obj: GetWeightsByNameReqInput, request: Optional[fastapi.Request] = None
    ):
        """Fetch the ``parameter`` of every response to ``obj``.

        Returns:
            The single parameter when ``dp_size`` is 1, otherwise the list of
            parameters, one per response.  ``request`` is not used.
        """
        self.auto_create_handle_loop()
        responses = await self.get_weights_by_name_communicator(obj)
        parameters = [resp.parameter for resp in responses]
        if self.server_args.dp_size == 1:
            return parameters[0]
        return parameters

961
962
963
964
965
966
967
968
969
970
971
972
973
974
975
976
    async def release_memory_occupation(
        self,
        obj: ReleaseMemoryOccupationReqInput,
        request: Optional[fastapi.Request] = None,
    ):
        """Forward ``obj`` to the release-memory-occupation communicator.

        The communicator's response is awaited and discarded; ``request`` is
        accepted but not used by this method.
        """
        self.auto_create_handle_loop()
        await self.release_memory_occupation_communicator(obj)

    async def resume_memory_occupation(
        self,
        obj: ResumeMemoryOccupationReqInput,
        request: Optional[fastapi.Request] = None,
    ):
        """Forward ``obj`` to the resume-memory-occupation communicator.

        The communicator's response is awaited and discarded; ``request`` is
        accepted but not used by this method.
        """
        self.auto_create_handle_loop()
        await self.resume_memory_occupation_communicator(obj)

977
978
979
980
981
982
983
984
    async def slow_down(
        self,
        obj: SlowDownReqInput,
        request: Optional[fastapi.Request] = None,
    ):
        """Forward ``obj`` to the slow-down communicator.

        The communicator's response is awaited and discarded; ``request`` is
        accepted but not used by this method.
        """
        self.auto_create_handle_loop()
        await self.slow_down_communicator(obj)

985
986
987
    async def open_session(
        self, obj: OpenSessionReqInput, request: Optional[fastapi.Request] = None
    ):
        """Open a session on the scheduler and wait for the result.

        A random hex id is assigned when ``obj.session_id`` is ``None``.  If
        the given id already has an entry in ``self.session_futures``, nothing
        is sent and ``None`` is returned.

        Returns:
            The value the session's future resolves to, or ``None`` as
            described above.  ``request`` is not used.
        """
        self.auto_create_handle_loop()

        if obj.session_id is None:
            obj.session_id = uuid.uuid4().hex
        elif obj.session_id in self.session_futures:
            return None

        self.send_to_scheduler.send_pyobj(obj)

        # Register the future that this method then waits on.
        pending = asyncio.Future()
        self.session_futures[obj.session_id] = pending
        session_id = await pending
        del self.session_futures[obj.session_id]
        return session_id

    async def close_session(
        self, obj: CloseSessionReqInput, request: Optional[fastapi.Request] = None
    ):
        """Send ``obj`` to the scheduler; ``request`` is not used.

        NOTE(review): this is the only visible call site that awaits
        ``send_to_scheduler.send_pyobj``; the other methods call it without
        ``await``.  Confirm which form the socket type expects.
        """
        await self.send_to_scheduler.send_pyobj(obj)

1007
    async def get_internal_state(self) -> List[Dict[Any, Any]]:
        """Query the internal state through the communicator.

        Returns:
            The ``internal_state`` of every response, in response order
            (the original comment notes there may be many DP ranks).
        """
        responses = await self.get_internal_state_communicator(GetInternalStateReq())
        # Many DP ranks
        return [resp.internal_state for resp in responses]
1014

Liangsheng Yin's avatar
Liangsheng Yin committed
1015
1016
1017
1018
1019
1020
1021
1022
    async def get_load(self) -> dict:
        """Return the current load as ``{"load": value}``.

        The value is refreshed from the first entry of ``get_internal_state()``
        only when ``current_load_lock`` is not already held; otherwise the
        previously stored ``self.current_load`` is returned as-is.
        """
        # TODO(lsyin): fake load report server
        load_lock = self.current_load_lock
        if not load_lock.locked():
            async with load_lock:
                states = await self.get_internal_state()
                self.current_load = states[0]["load"]
        return {"load": self.current_load}

1023
1024
1025
1026
1027
1028
1029
1030
    async def set_internal_state(
        self, obj: SetInternalStateReq
    ) -> List[Any]:
        """Send ``obj`` through the set-internal-state communicator.

        Returns:
            A list with the ``internal_state`` attribute of every response
            (a list, not a single ``SetInternalStateReqOutput``).

        NOTE(review): confirm that ``SetInternalStateReqOutput`` really exposes
        an ``internal_state`` attribute; its definition is not visible here.
        """
        responses: List[SetInternalStateReqOutput] = (
            await self.set_internal_state_communicator(obj)
        )
        return [res.internal_state for res in responses]

1031
1032
1033
1034
1035
1036
1037
1038
1039
1040
1041
1042
1043
1044
1045
1046
1047
1048
1049
1050
1051
1052
1053
1054
1055
1056
1057
1058
1059
1060
1061
1062
1063
    def get_log_request_metadata(self):
        """Derive the request-logging settings from the current log level.

        Returns:
            ``(max_length, skip_names, out_skip_names)``.  All three are
            ``None`` when request logging is off.  Level 0 uses a ``1 << 30``
            length limit plus two sets of field names to skip, level 1 uses a
            limit of 2048, and level 2 uses ``1 << 30``; levels 1 and 2 skip
            nothing.

        Raises:
            ValueError: if logging is on and the level is not 0, 1 or 2.
        """
        if not self.log_requests:
            return None, None, None

        level = self.log_requests_level
        if level == 0:
            skip_names = {
                "text",
                "input_ids",
                "input_embeds",
                "image_data",
                "audio_data",
                "lora_path",
            }
            out_skip_names = {"text", "output_ids"}
            return 1 << 30, skip_names, out_skip_names
        if level == 1:
            return 2048, None, None
        if level == 2:
            return 1 << 30, None, None
        raise ValueError(
            f"Invalid --log-requests-level: {self.log_requests_level=}"
        )

1064
    def configure_logging(self, obj: ConfigureLoggingReq):
        """Apply the non-``None`` fields of ``obj`` to the logging settings.

        The fields handled are ``log_requests``, ``log_requests_level``,
        ``dump_requests_folder`` and ``dump_requests_threshold``.  Afterwards
        ``self.log_request_metadata`` is recomputed so it reflects the new
        level.
        """
        if obj.log_requests is not None:
            self.log_requests = obj.log_requests
        if obj.log_requests_level is not None:
            self.log_requests_level = obj.log_requests_level
        if obj.dump_requests_folder is not None:
            self.dump_requests_folder = obj.dump_requests_folder
        if obj.dump_requests_threshold is not None:
            self.dump_requests_threshold = obj.dump_requests_threshold
        # Use the module logger like the rest of this file, not the root
        # ``logging`` functions.
        logger.info(f"Config logging: {obj=}")
        self.log_request_metadata = self.get_log_request_metadata()
1075

Lianmin Zheng's avatar
Lianmin Zheng committed
1076
    def create_abort_task(self, obj: GenerateReqInput):
        """Build a ``BackgroundTasks`` object that aborts ``obj``'s request(s).

        The queued task sleeps for two seconds and then calls
        ``self.abort_request`` for ``obj.rid`` (single request) or for every
        rid in ``obj.rid`` (batch).
        """

        # Abort the request if the client is disconnected.
        async def abort_request():
            await asyncio.sleep(2)
            rids = [obj.rid] if obj.is_single else obj.rid
            for rid in rids:
                self.abort_request(rid)

        tasks = BackgroundTasks()
        tasks.add_task(abort_request)
        return tasks

1090
    def auto_create_handle_loop(self):
        """Start the background tasks once; later calls return immediately.

        On the first call this schedules ``handle_loop`` and
        ``sigterm_watchdog`` on the current event loop (each wrapped in
        ``print_exception_wrapper``), stores the loop in ``self.event_loop``
        and, when running in the main thread, installs SIGTERM and SIGQUIT
        handlers on the loop.
        """
        if self.no_create_loop:
            return
        self.no_create_loop = True

        loop = asyncio.get_event_loop()
        handle_task = loop.create_task(print_exception_wrapper(self.handle_loop))
        self.asyncio_tasks.add(handle_task)

        self.event_loop = loop

        # We cannot add signal handler when the tokenizer manager is not in
        # the main thread due to the CPython limitation.
        if threading.current_thread() is not threading.main_thread():
            logger.warning(
                "Signal handler is not added because the tokenizer manager is "
                "not in the main thread. This disables graceful shutdown of the "
                "tokenizer manager when SIGTERM is received."
            )
        else:
            signal_handler = SignalHandler(self)
            loop.add_signal_handler(signal.SIGTERM, signal_handler.sigterm_handler)
            # Update the signal handler for the process. It overrides the sigquit handler in the launch phase.
            loop.add_signal_handler(
                signal.SIGQUIT, signal_handler.running_phase_sigquit_handler
            )

        watchdog_task = loop.create_task(
            print_exception_wrapper(self.sigterm_watchdog)
        )
        self.asyncio_tasks.add(watchdog_task)
1120
1121
1122

    async def sigterm_watchdog(self):
        """Wait for ``gracefully_exit``, drain requests, then kill the process tree.

        ``self.gracefully_exit`` is polled every five seconds.  Once set, the
        method keeps sleeping in five-second steps until ``rid_to_state`` is
        empty, or stops waiting at once if ``health_check_failed`` is set.  It
        then calls ``kill_process_tree`` on this process and ``sys.exit(0)``.
        """
        while not self.gracefully_exit:
            await asyncio.sleep(5)

        # Drain requests
        while True:
            pending_count = len(self.rid_to_state)

            if self.health_check_failed:
                # if health check failed, we should exit immediately
                logger.error(
                    "Signal SIGTERM received while health check failed. Exiting... remaining number of requests: %d",
                    pending_count,
                )
                break

            logger.info(
                f"Gracefully exiting... remaining number of requests {pending_count}"
            )
            if pending_count <= 0:
                break
            await asyncio.sleep(5)

        kill_process_tree(os.getpid(), include_parent=True)
        sys.exit(0)
1147

Lianmin Zheng's avatar
Lianmin Zheng committed
1148
    async def handle_loop(self):
        """Receive objects from the detokenizer socket forever and dispatch them.

        After each object is dispatched, ``self.last_receive_tstamp`` is set
        to the current time.
        """
        while True:
            message = await self.recv_from_detokenizer.recv_pyobj()
            self._result_dispatcher(message)
            self.last_receive_tstamp = time.time()
Lianmin Zheng's avatar
Lianmin Zheng committed
1155

1156
    def _handle_batch_output(
        self,
        recv_obj: Union[
            BatchStrOut, BatchEmbeddingOut, BatchMultimodalOut, BatchTokenIDOut
        ],
    ):
        """Turn one batched output message into per-request results.

        For every rid in ``recv_obj.rids`` that still has a state, this builds
        ``meta_info`` and an output dict whose payload key depends on the
        message type (``text``, ``output_ids``, ``outputs`` or ``embedding``),
        appends it to ``state.out_list`` and sets ``state.event``.  A request
        whose finish reason is not ``None`` is marked finished and removed
        from ``self.rid_to_state``.  Metrics collection and request dumping
        run last, when enabled.
        """
        for i, rid in enumerate(recv_obj.rids):
            state = self.rid_to_state.get(rid)
            if state is None:
                logger.error(
                    f"Received output for {rid=} but the state was deleted in TokenizerManager."
                )
                continue

            finish_reason = recv_obj.finished_reasons[i]

            # Build meta_info and return value
            meta_info = {
                "id": rid,
                "finish_reason": finish_reason,
                "prompt_tokens": recv_obj.prompt_tokens[i],
            }

            if getattr(state.obj, "return_logprob", False):
                decode_logprob_text = (
                    state.obj.return_text_in_logprobs
                    and not self.server_args.skip_tokenizer_init
                )
                self.convert_logprob_style(
                    meta_info,
                    state,
                    state.obj.top_logprobs_num,
                    state.obj.token_ids_logprob,
                    decode_logprob_text,
                    recv_obj,
                    i,
                )

            if not isinstance(recv_obj, BatchEmbeddingOut):
                meta_info["completion_tokens"] = recv_obj.completion_tokens[i]
                meta_info["cached_tokens"] = recv_obj.cached_tokens[i]

            if getattr(recv_obj, "output_hidden_states", None):
                meta_info["hidden_states"] = recv_obj.output_hidden_states[i]

            if isinstance(recv_obj, BatchStrOut):
                state.text += recv_obj.output_strs[i]
                out_dict = {"text": state.text, "meta_info": meta_info}
            elif isinstance(recv_obj, BatchTokenIDOut):
                incremental = self.server_args.stream_output and state.obj.stream
                state.output_ids.extend(recv_obj.output_ids[i])
                if incremental:
                    # Only hand out the ids added since the previous chunk.
                    output_token_ids = state.output_ids[state.last_output_offset :]
                    state.last_output_offset = len(state.output_ids)
                else:
                    output_token_ids = state.output_ids
                out_dict = {"output_ids": output_token_ids, "meta_info": meta_info}
            elif isinstance(recv_obj, BatchMultimodalOut):
                payload = recv_obj.outputs[i]
                if isinstance(payload, str):
                    out_dict = {"text": recv_obj.outputs[i], "meta_info": meta_info}
                else:
                    out_dict = {
                        "outputs": json.dumps(recv_obj.outputs[i]),
                        "meta_info": meta_info,
                    }
            else:
                assert isinstance(recv_obj, BatchEmbeddingOut)
                out_dict = {
                    "embedding": recv_obj.embeddings[i],
                    "meta_info": meta_info,
                }

            state.finished = recv_obj.finished_reasons[i] is not None
            if state.finished:
                if self.server_args.speculative_algorithm:
                    meta_info["spec_verify_ct"] = recv_obj.spec_verify_ct[i]
                state.finished_time = time.time()
                meta_info["e2e_latency"] = state.finished_time - state.created_time
                del self.rid_to_state[rid]

            state.out_list.append(out_dict)
            state.event.set()

            # Log metrics and dump
            wants_metrics = state.obj.log_metrics if self.enable_metrics else False
            if wants_metrics:
                self.collect_metrics(state, recv_obj, i)
            if self.dump_requests_folder and state.finished and state.obj.log_metrics:
                self.dump_requests(state, out_dict)

    def convert_logprob_style(
        self,
        meta_info: dict,
        state: ReqState,
        top_logprobs_num: int,
        token_ids_logprob: List[int],
        return_text_in_logprobs: bool,
        recv_obj: BatchStrOut,
        recv_obj_index: int,
    ):
        """Accumulate logprobs from ``recv_obj`` into ``state`` and ``meta_info``.

        The entries at ``recv_obj_index`` are appended to the matching lists on
        ``state``; the accumulated lists are then detokenized and written to
        ``meta_info``.  Input-side lists are only extended when the message
        carries any.  Top logprobs are handled when ``top_logprobs_num > 0``
        and token-id logprobs when ``token_ids_logprob`` is not ``None``.
        """
        k = recv_obj_index
        as_text = return_text_in_logprobs

        # Per-token logprobs.
        if len(recv_obj.input_token_logprobs_val) > 0:
            state.input_token_logprobs_val.extend(recv_obj.input_token_logprobs_val[k])
            state.input_token_logprobs_idx.extend(recv_obj.input_token_logprobs_idx[k])
        state.output_token_logprobs_val.extend(recv_obj.output_token_logprobs_val[k])
        state.output_token_logprobs_idx.extend(recv_obj.output_token_logprobs_idx[k])
        meta_info["input_token_logprobs"] = self.detokenize_logprob_tokens(
            state.input_token_logprobs_val,
            state.input_token_logprobs_idx,
            as_text,
        )
        meta_info["output_token_logprobs"] = self.detokenize_logprob_tokens(
            state.output_token_logprobs_val,
            state.output_token_logprobs_idx,
            as_text,
        )

        # Top logprobs per position.
        if top_logprobs_num > 0:
            if len(recv_obj.input_top_logprobs_val) > 0:
                state.input_top_logprobs_val.extend(recv_obj.input_top_logprobs_val[k])
                state.input_top_logprobs_idx.extend(recv_obj.input_top_logprobs_idx[k])
            state.output_top_logprobs_val.extend(recv_obj.output_top_logprobs_val[k])
            state.output_top_logprobs_idx.extend(recv_obj.output_top_logprobs_idx[k])
            meta_info["input_top_logprobs"] = self.detokenize_top_logprobs_tokens(
                state.input_top_logprobs_val,
                state.input_top_logprobs_idx,
                as_text,
            )
            meta_info["output_top_logprobs"] = self.detokenize_top_logprobs_tokens(
                state.output_top_logprobs_val,
                state.output_top_logprobs_idx,
                as_text,
            )

        # Logprobs of explicitly requested token ids.
        if token_ids_logprob is not None:
            if len(recv_obj.input_token_ids_logprobs_val) > 0:
                state.input_token_ids_logprobs_val.extend(
                    recv_obj.input_token_ids_logprobs_val[k]
                )
                state.input_token_ids_logprobs_idx.extend(
                    recv_obj.input_token_ids_logprobs_idx[k]
                )
            state.output_token_ids_logprobs_val.extend(
                recv_obj.output_token_ids_logprobs_val[k]
            )
            state.output_token_ids_logprobs_idx.extend(
                recv_obj.output_token_ids_logprobs_idx[k]
            )
            meta_info["input_token_ids_logprobs"] = self.detokenize_top_logprobs_tokens(
                state.input_token_ids_logprobs_val,
                state.input_token_ids_logprobs_idx,
                as_text,
            )
            meta_info["output_token_ids_logprobs"] = self.detokenize_top_logprobs_tokens(
                state.output_token_ids_logprobs_val,
                state.output_token_ids_logprobs_idx,
                as_text,
            )

1340
1341
1342
1343
1344
1345
1346
1347
1348
1349
1350
1351
1352
1353
1354
1355
1356
1357
1358
1359
1360
1361
1362
1363
1364
1365
1366
1367
1368
1369
1370
1371
1372
1373
1374
1375
1376
1377
1378
1379
1380
1381
1382
    def detokenize_logprob_tokens(
        self,
        token_logprobs_val: List[float],
        token_logprobs_idx: List[int],
        decode_to_text: bool,
    ):
        """Pair each logprob with its token id and, optionally, its text.

        Returns:
            A list of ``(logprob, token_id, text)`` tuples.  ``text`` is
            ``None`` unless ``decode_to_text`` is true, in which case it comes
            from ``self.tokenizer.batch_decode`` on the token ids.
        """
        if decode_to_text:
            assert self.tokenizer is not None
            decoded = self.tokenizer.batch_decode(token_logprobs_idx)
            return list(zip(token_logprobs_val, token_logprobs_idx, decoded))

        pairs = zip(token_logprobs_val, token_logprobs_idx)
        return [(logprob, token_id, None) for logprob, token_id in pairs]

    def detokenize_top_logprobs_tokens(
        self,
        token_logprobs_val: List[float],
        token_logprobs_idx: List[int],
        decode_to_text: bool,
    ):
        """Detokenize the per-position logprob lists one position at a time.

        Returns:
            One entry per position: the result of ``detokenize_logprob_tokens``
            for that position, or ``None`` where the position's value list is
            falsy (``None`` or empty).
        """
        # TODO: The current implementation only batches the detokenization for top-k tokens per single position.
        # We should batch all top-k tokens in all positions.
        return [
            (
                self.detokenize_logprob_tokens(
                    position_vals, token_logprobs_idx[pos], decode_to_text
                )
                if position_vals
                else None
            )
            for pos, position_vals in enumerate(token_logprobs_val)
        ]

    def collect_metrics(self, state: ReqState, recv_obj: BatchStrOut, i: int):
        """Report latency and completion metrics for request ``i`` of ``recv_obj``.

        The first call for a request (``first_token_time == 0.0``) records the
        time to first token.  Later calls record inter-token latency whenever
        the completion-token count changed.  When the request is finished, a
        finished-request observation is recorded as well.
        """
        if getattr(recv_obj, "completion_tokens", None):
            completion_tokens = recv_obj.completion_tokens[i]
        else:
            completion_tokens = 0

        if state.first_token_time == 0.0:
            now = time.time()
            state.first_token_time = now
            state.last_time = now
            state.last_completion_tokens = completion_tokens
            self.metrics_collector.observe_time_to_first_token(
                state.first_token_time - state.created_time
            )
        else:
            num_new_tokens = completion_tokens - state.last_completion_tokens
            if num_new_tokens:
                now = time.time()
                self.metrics_collector.observe_inter_token_latency(
                    now - state.last_time,
                    num_new_tokens,
                )
                state.last_time = now
                state.last_completion_tokens = completion_tokens

        if not state.finished:
            return

        # First truthy grammar setting, or the last looked-up value if none is
        # set (same result as chaining the lookups with ``or``).
        has_grammar = None
        for grammar_key in ("json_schema", "regex", "ebnf", "structural_tag"):
            has_grammar = state.obj.sampling_params.get(grammar_key, None)
            if has_grammar:
                break

        self.metrics_collector.observe_one_finished_request(
            recv_obj.prompt_tokens[i],
            completion_tokens,
            recv_obj.cached_tokens[i],
            state.finished_time - state.created_time,
            has_grammar,
        )

    def dump_requests(self, state: ReqState, out_dict: dict):
        """Buffer one request record and flush the buffer to disk once it is full.

        A ``(state.obj, out_dict, state.created_time, now)`` tuple is appended to
        ``self.dump_request_list``. When the list reaches
        ``self.dump_requests_threshold`` entries it is swapped for a fresh list and
        the old one is pickled to a timestamped ``.pkl`` file under
        ``self.dump_requests_folder`` from a worker thread.
        """
        record = (state.obj, out_dict, state.created_time, time.time())
        self.dump_request_list.append(record)

        # Below the threshold there is nothing to flush yet.
        if len(self.dump_request_list) < self.dump_requests_threshold:
            return

        timestamp = datetime.now().strftime("%Y-%m-%d_%H-%M-%S")
        filename = os.path.join(self.dump_requests_folder, timestamp + ".pkl")
        logger.info(f"Dump {len(self.dump_request_list)} requests to {filename}")

        # Detach the full buffer so new records go to a fresh list while the
        # old one is being written.
        to_dump, self.dump_request_list = self.dump_request_list, []

        def background_task():
            os.makedirs(self.dump_requests_folder, exist_ok=True)
            with open(filename, "wb") as f:
                pickle.dump(to_dump, f)

        # Schedule the task to run in the background without awaiting it
        asyncio.create_task(asyncio.to_thread(background_task))

Lianmin Zheng's avatar
Lianmin Zheng committed
1439
    def _handle_abort_req(self, recv_obj):
1440
        self.rid_to_state.pop(recv_obj.rid, None)
Lianmin Zheng's avatar
Lianmin Zheng committed
1441

1442
1443
1444
1445
1446
1447
1448
1449
1450
1451
    def _handle_open_session_req_output(self, recv_obj):
        self.session_futures[recv_obj.session_id].set_result(
            recv_obj.session_id if recv_obj.success else None
        )

    def _handle_update_weights_from_disk_req_output(self, recv_obj):
        if self.server_args.dp_size == 1:
            self.model_update_result.set_result(recv_obj)
        else:  # self.server_args.dp_size > 1
            self.model_update_tmp.append(recv_obj)
1452
            # set future if the all results are received
1453
1454
1455
            if len(self.model_update_tmp) == self.server_args.dp_size:
                self.model_update_result.set_result(self.model_update_tmp)

1456
1457
1458
1459
1460
1461
1462
1463
1464
1465
1466
1467
1468
1469
1470
1471
1472
1473
1474
1475
1476
1477
1478
1479
1480
1481
1482
1483
1484
1485
1486
1487
1488
1489
1490
1491
1492
1493
1494
1495
1496
1497
1498
1499
1500
1501
1502
1503
1504
1505
1506
1507
1508
1509
1510
1511
1512
1513
1514
1515
1516
1517
1518
1519
1520
1521
1522
1523
1524
1525
1526
1527
1528
1529
1530
1531
1532
1533
1534
1535
1536
1537
1538
1539
1540
1541
1542
1543
1544
1545
1546
1547
1548
1549
    async def score_request(
        self,
        query: Optional[Union[str, List[int]]] = None,
        items: Optional[Union[str, List[str], List[List[int]]]] = None,
        label_token_ids: Optional[List[int]] = None,
        apply_softmax: bool = False,
        item_first: bool = False,
        request: Optional[Any] = None,
    ) -> List[List[float]]:
        """
        Score every item against the query using the logprobs of ``label_token_ids``.

        See Engine.score() for more details.

        Args:
            query: Query as text, or as a list of token ids.
            items: Items as one string / a list of strings (with a text query), or
                as a list of token-id lists (with a token-id query).
            label_token_ids: Token ids whose logprobs are turned into scores.
            apply_softmax: If true, scores are the softmax over the label logprobs;
                otherwise each score is ``exp(logprob)`` (``0.0`` for a label whose
                logprob was not returned).
            item_first: If true, each item is placed before the query instead of
                after it.
            request: Passed through to ``generate_request``.

        Returns:
            One list per item, holding one score per entry of ``label_token_ids``
            in the same order.

        Raises:
            ValueError: If ``label_token_ids`` is missing, a label token id lies
                outside ``[0, vocab_size)``, or the query/items types do not match.
            RuntimeError: If a result carries no ``output_token_ids_logprobs``.
        """
        if label_token_ids is None:
            raise ValueError("label_token_ids must be provided")

        if self.tokenizer is not None:
            vocab_size = self.tokenizer.vocab_size
            for token_id in label_token_ids:
                # Negative ids are never valid vocabulary entries, so reject them
                # together with ids beyond the end of the vocabulary.
                if not 0 <= token_id < vocab_size:
                    raise ValueError(
                        f"Token ID {token_id} is out of vocabulary (vocab size: {vocab_size})"
                    )

        # Handle string or tokenized query/items
        items_are_text = isinstance(items, str) or (
            isinstance(items, list) and (not items or isinstance(items[0], str))
        )
        if isinstance(query, str) and items_are_text:
            # Both query and items are text
            items_list = [items] if isinstance(items, str) else items
            if item_first:
                prompts = [f"{item}{query}" for item in items_list]
            else:
                prompts = [f"{query}{item}" for item in items_list]
            request_inputs = {"text": prompts}
        elif (
            isinstance(query, list)
            and isinstance(items, list)
            and items
            and isinstance(items[0], list)
        ):
            # Both query and items are token IDs
            if item_first:
                input_ids_list = [item + query for item in items]
            else:
                input_ids_list = [query + item for item in items]
            request_inputs = {"input_ids": input_ids_list}
        else:
            raise ValueError(
                "Invalid combination of query/items types for score_request."
            )

        # A single generated token is requested; only its logprobs over
        # ``label_token_ids`` are used below.
        batch_request = GenerateReqInput(
            **request_inputs,
            return_logprob=True,
            token_ids_logprob=label_token_ids,
            stream=False,
            sampling_params={"max_new_tokens": 1},
        )

        results = await self.generate_request(batch_request, request).__anext__()

        label_token_id_set = set(label_token_ids)
        scores = []

        for result in results:
            position_logprobs = result["meta_info"].get("output_token_ids_logprobs")
            if not position_logprobs:
                # Fail with a clear message instead of an IndexError/TypeError
                # from indexing a missing or empty entry.
                raise RuntimeError(
                    "output_token_ids_logprobs is missing or empty in the result "
                    "of a scoring request."
                )

            # Get logprobs for each label token at the first output position
            logprobs = {
                token_id: logprob
                for logprob, token_id, _ in position_logprobs[0]
                if token_id in label_token_id_set
            }

            # Get scores in order of label_token_ids
            score_list = [
                logprobs.get(token_id, float("-inf")) for token_id in label_token_ids
            ]

            # Apply softmax to logprobs if needed
            if apply_softmax:
                score_list = torch.softmax(torch.tensor(score_list), dim=0).tolist()
            else:
                # Convert logprobs to probabilities if not using softmax
                score_list = [
                    math.exp(x) if x != float("-inf") else 0.0 for x in score_list
                ]

            scores.append(score_list)

        return scores

1550

1551
1552
1553
1554
1555
1556
1557
1558
1559
async def print_exception_wrapper(func):
    """
    Sometimes an asyncio function does not print exception.
    We do another wrapper to handle the exception.

    ``func`` is called without arguments and its result is awaited. If it raises
    an ``Exception``, the traceback is logged, ``kill_process_tree`` is invoked for
    the current pid with ``include_parent=True``, and the process exits with
    status 1.
    """
    try:
        await func()
    except Exception:
        traceback = get_exception_traceback()
        logger.error(f"TokenizerManager hit an exception: {traceback}")
        kill_process_tree(os.getpid(), include_parent=True)
        sys.exit(1)


1565
class SignalHandler:
    """Signal callbacks bound to a TokenizerManager instance."""

    def __init__(self, tokenizer_manager: TokenizerManager):
        self.tokenizer_manager = tokenizer_manager

    def sigterm_handler(self, signum=None, frame=None):
        """Log the SIGTERM and set ``gracefully_exit`` on the tokenizer manager."""
        logger.warning(
            f"SIGTERM received. {signum=} {frame=}. Draining requests and shutting down..."
        )
        self.tokenizer_manager.gracefully_exit = True

    def running_phase_sigquit_handler(self, signum=None, frame=None):
        """Log the SIGQUIT and call ``kill_process_tree`` for the current pid."""
        logger.error(
            "Received sigquit from a child process. It usually means the child failed."
        )
        kill_process_tree(os.getpid())

1581
1582
1583
1584
1585

T = TypeVar("T")


class _Communicator(Generic[T]):
1586
1587
    """Note: The communicator now only run up to 1 in-flight request at any time."""

1588
1589
1590
    def __init__(self, sender, fan_out: int):
        self._sender = sender
        self._fan_out = fan_out
1591
        self._result_event: Optional[asyncio.Event] = None
1592
        self._result_values: Optional[List[T]] = None
1593
        self._ready_queue: Deque[asyncio.Future] = deque()
1594
1595

    async def __call__(self, obj):
1596
1597
1598
1599
1600
1601
1602
1603
1604
1605
1606
        ready_event = asyncio.Event()
        if self._result_event is not None or len(self._ready_queue) > 0:
            self._ready_queue.append(ready_event)
            await ready_event.wait()
            assert self._result_event is None
            assert self._result_values is None

        if obj:
            self._sender.send_pyobj(obj)

        self._result_event = asyncio.Event()
1607
        self._result_values = []
1608
        await self._result_event.wait()
1609
        result_values = self._result_values
1610
1611
1612
1613
1614
        self._result_event = self._result_values = None

        if len(self._ready_queue) > 0:
            self._ready_queue.popleft().set()

1615
1616
1617
1618
1619
        return result_values

    def handle_recv(self, recv_obj: T):
        self._result_values.append(recv_obj)
        if len(self._result_values) == self._fan_out:
1620
            self._result_event.set()
Lianmin Zheng's avatar
Lianmin Zheng committed
1621
1622
1623
1624
1625
1626
1627
1628
1629
1630
1631
1632


# Note: request abort handling logic
# We should handle all of the following cases correctly.
#
# | entrypoint | is_streaming | status          | abort engine    | cancel asyncio task   | rid_to_state                |
# | ---------- | ------------ | --------------- | --------------- | --------------------- | --------------------------- |
# | http       | yes          | waiting queue   | background task | fast api              | del in _handle_abort_req    |
# | http       | yes          | running         | background task | fast api              | del in _handle_batch_output |
# | http       | no           | waiting queue   | type 1          | type 1 exception      | del in _handle_abort_req    |
# | http       | no           | running         | type 3          | type 3 exception      | del in _handle_batch_output |
#