scheduler.py 91.3 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
13
# Copyright 2023-2024 SGLang Team
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
14
15
"""A scheduler that manages a tensor parallel GPU worker."""

16
import faulthandler
17
import logging
18
import os
19
import signal
20
import sys
Lianmin Zheng's avatar
Lianmin Zheng committed
21
import threading
22
23
import time
import warnings
24
from collections import defaultdict, deque
Lianmin Zheng's avatar
Lianmin Zheng committed
25
from concurrent import futures
26
from dataclasses import dataclass
27
from http import HTTPStatus
28
from types import SimpleNamespace
29
from typing import Dict, List, Optional, Tuple, Union
30

31
import psutil
32
import setproctitle
33
import torch
34
35
import zmq

36
from sglang.global_config import global_config
Lianmin Zheng's avatar
Lianmin Zheng committed
37
from sglang.srt.configs.model_config import ModelConfig
38
from sglang.srt.constrained.base_grammar_backend import create_grammar_backend
39
from sglang.srt.hf_transformers_utils import get_processor, get_tokenizer
40
from sglang.srt.layers.dp_attention import compute_dp_attention_world_info
41
42
43
44
45
from sglang.srt.layers.logits_processor import LogitsProcessorOutput
from sglang.srt.managers.io_struct import (
    AbortReq,
    BatchEmbeddingOut,
    BatchTokenIDOut,
46
    CloseSessionReqInput,
47
    FlushCacheReq,
48
49
    GetInternalStateReq,
    GetInternalStateReqOutput,
50
51
    GetWeightsByNameReqInput,
    GetWeightsByNameReqOutput,
52
    HealthCheckOutput,
53
54
    InitWeightsUpdateGroupReqInput,
    InitWeightsUpdateGroupReqOutput,
55
56
    OpenSessionReqInput,
    OpenSessionReqOutput,
57
    ProfileReq,
58
59
    ProfileReqOutput,
    ProfileReqType,
60
61
62
63
    ReleaseMemoryOccupationReqInput,
    ReleaseMemoryOccupationReqOutput,
    ResumeMemoryOccupationReqInput,
    ResumeMemoryOccupationReqOutput,
64
65
    SetInternalStateReq,
    SetInternalStateReqOutput,
66
67
    TokenizedEmbeddingReqInput,
    TokenizedGenerateReqInput,
Chayenne's avatar
Chayenne committed
68
69
    UpdateWeightFromDiskReqInput,
    UpdateWeightFromDiskReqOutput,
70
71
    UpdateWeightsFromDistributedReqInput,
    UpdateWeightsFromDistributedReqOutput,
72
73
    UpdateWeightsFromTensorReqInput,
    UpdateWeightsFromTensorReqOutput,
74
75
76
77
78
79
80
)
from sglang.srt.managers.schedule_batch import (
    FINISH_ABORT,
    BaseFinishReason,
    ImageInputs,
    Req,
    ScheduleBatch,
81
    global_server_args_dict,
82
)
83
84
85
86
87
from sglang.srt.managers.schedule_policy import (
    AddReqResult,
    PrefillAdder,
    SchedulePolicy,
)
88
from sglang.srt.managers.session_controller import Session
89
from sglang.srt.managers.tp_worker import TpModelWorker
90
from sglang.srt.managers.tp_worker_overlap_thread import TpModelWorkerClient
91
from sglang.srt.managers.utils import validate_input_length
92
from sglang.srt.mem_cache.chunk_cache import ChunkCache
93
from sglang.srt.mem_cache.hiradix_cache import HiRadixCache
94
from sglang.srt.mem_cache.radix_cache import RadixCache
95
from sglang.srt.metrics.collector import SchedulerMetricsCollector, SchedulerStats
96
from sglang.srt.model_executor.forward_batch_info import ForwardMode
97
from sglang.srt.server_args import PortArgs, ServerArgs
98
from sglang.srt.speculative.spec_info import SpeculativeAlgorithm
99
from sglang.srt.torch_memory_saver_adapter import TorchMemorySaverAdapter
100
101
102
from sglang.srt.utils import (
    broadcast_pyobj,
    configure_logger,
103
    crash_on_warnings,
104
    get_bool_env_var,
105
    get_zmq_socket,
106
    pyspy_dump_schedulers,
107
    set_gpu_proc_affinity,
108
109
110
    set_random_seed,
    suppress_other_loggers,
)
111
from sglang.utils import TypeBasedDispatcher, get_exception_traceback
112
113
114

logger = logging.getLogger(__name__)

115
# Test retract decode for debugging purposes
116
117
TEST_RETRACT = get_bool_env_var("SGLANG_TEST_RETRACT")
RECORD_STEP_TIME = get_bool_env_var("SGLANG_RECORD_STEP_TIME")
118

119

120
121
122
123
@dataclass
class GenerationBatchResult:
    """Container for the outputs of one generation batch.

    NOTE(review): the producer/consumer of this object is not visible in this
    chunk; field meanings below are inferred from names — confirm.
    """

    # Output of the logits processor for the batch.
    logits_output: LogitsProcessorOutput
    # Sampled next token ids (annotated as List[int]; verify the runtime type).
    next_token_ids: List[int]
    # Per-request extend (prefill) input lengths.
    extend_input_len_per_req: List[int]
    # Per-request positions from which input logprobs start.
    extend_logprob_start_len_per_req: List[int]
    # Presumably the id of the batch that produced this result.
    bid: int


@dataclass
class EmbeddingBatchResult:
    """Container for the outputs of one embedding batch."""

    # Embedding tensor for the batch (shape not established here — confirm).
    embeddings: torch.Tensor
    # Presumably the id of the batch that produced this result.
    bid: int


135
136
137
138
139
140
141
142
143
class Scheduler:
    """A scheduler that manages a tensor parallel GPU worker."""

    def __init__(
        self,
        server_args: ServerArgs,
        port_args: PortArgs,
        gpu_id: int,
        tp_rank: int,
        dp_rank: Optional[int],
    ):
        """Build the scheduler: IPC sockets, model workers, caches and scheduling state.

        Args:
            server_args: Global server configuration.
            port_args: IPC endpoint names and the NCCL port passed to the workers.
            gpu_id: GPU index passed to the model worker(s).
            tp_rank: Tensor-parallel rank of this scheduler.
            dp_rank: Data-parallel rank forwarded unchanged to the TP/draft
                workers. Note that ``self.dp_rank`` is recomputed below by
                ``compute_dp_attention_world_info`` and may differ.
        """
        # Parse args
        self.server_args = server_args
        self.tp_rank = tp_rank
        self.tp_size = server_args.tp_size
        self.schedule_policy = server_args.schedule_policy
        self.lora_paths = server_args.lora_paths
        self.max_loras_per_batch = server_args.max_loras_per_batch
        self.enable_overlap = not server_args.disable_overlap_schedule
        self.skip_tokenizer_init = server_args.skip_tokenizer_init
        self.enable_metrics = server_args.enable_metrics
        self.stream_interval = server_args.stream_interval
        self.spec_algorithm = SpeculativeAlgorithm.from_string(
            server_args.speculative_algorithm
        )
        self.gpu_id = gpu_id
        self.enable_hierarchical_cache = server_args.enable_hierarchical_cache

        # Distributed rank info
        self.dp_size = server_args.dp_size
        self.attn_tp_rank, self.attn_tp_size, self.dp_rank = (
            compute_dp_attention_world_info(
                server_args.enable_dp_attention,
                self.tp_rank,
                self.tp_size,
                self.dp_size,
            )
        )

        # Init inter-process communication
        context = zmq.Context(2)
        if self.attn_tp_rank == 0:
            self.recv_from_tokenizer = get_zmq_socket(
                context, zmq.PULL, port_args.scheduler_input_ipc_name, False
            )
            self.send_to_tokenizer = get_zmq_socket(
                context, zmq.PUSH, port_args.tokenizer_ipc_name, False
            )

            if server_args.skip_tokenizer_init:
                # Directly send to the TokenizerManager
                self.send_to_detokenizer = get_zmq_socket(
                    context, zmq.PUSH, port_args.tokenizer_ipc_name, False
                )
            else:
                # Send to the DetokenizerManager
                self.send_to_detokenizer = get_zmq_socket(
                    context, zmq.PUSH, port_args.detokenizer_ipc_name, False
                )
        else:
            # Only attention-TP rank 0 owns sockets; other ranks get no receiver
            # and no-op senders so callers can send unconditionally.
            self.recv_from_tokenizer = None
            self.send_to_tokenizer = SimpleNamespace(send_pyobj=lambda x: None)
            self.send_to_detokenizer = SimpleNamespace(send_pyobj=lambda x: None)

        # Init tokenizer (also sets self.model_config and self.is_generation,
        # which the overlap check below depends on)
        self.init_tokenizer()

        # Check whether overlap can be enabled
        if not self.is_generation:
            self.enable_overlap = False
            logger.info("Overlap scheduler is disabled for embedding models.")

        if self.model_config.is_multimodal:
            self.enable_overlap = False
            logger.info("Overlap scheduler is disabled for multimodal models.")

        # Launch a tensor parallel worker
        if self.enable_overlap:
            TpWorkerClass = TpModelWorkerClient
        else:
            TpWorkerClass = TpModelWorker

        self.tp_worker = TpWorkerClass(
            server_args=server_args,
            gpu_id=gpu_id,
            tp_rank=tp_rank,
            dp_rank=dp_rank,
            nccl_port=port_args.nccl_port,
        )

        # Launch a draft worker for speculative decoding
        if self.spec_algorithm.is_eagle():
            from sglang.srt.speculative.eagle_worker import EAGLEWorker

            self.draft_worker = EAGLEWorker(
                gpu_id=gpu_id,
                tp_rank=tp_rank,
                server_args=server_args,
                nccl_port=port_args.nccl_port,
                target_worker=self.tp_worker,
                dp_rank=dp_rank,
            )
        else:
            self.draft_worker = None

        # Get token and memory info from the model worker
        (
            self.max_total_num_tokens,
            self.max_prefill_tokens,
            self.max_running_requests,
            self.max_req_len,
            self.max_req_input_len,
            self.random_seed,
            self.device,
            worker_global_server_args_dict,
            _,
            _,
            _,
        ) = self.tp_worker.get_worker_info()
        self.tp_cpu_group = self.tp_worker.get_tp_cpu_group()
        self.attn_tp_cpu_group = self.tp_worker.get_attention_tp_cpu_group()
        self.pad_input_ids_func = self.tp_worker.get_pad_input_ids_func()
        global_server_args_dict.update(worker_global_server_args_dict)
        set_random_seed(self.random_seed)

        # Print debug info
        logger.info(
            f"max_total_num_tokens={self.max_total_num_tokens}, "
            f"chunked_prefill_size={server_args.chunked_prefill_size}, "
            f"max_prefill_tokens={self.max_prefill_tokens}, "
            f"max_running_requests={self.max_running_requests}, "
            f"context_len={self.model_config.context_len}"
        )

        # Init memory pool and cache
        self.init_memory_pool_and_cache()

        # Init running status
        self.waiting_queue: List[Req] = []
        self.staging_reqs = {}
        # The running decoding batch for continuous batching
        self.running_batch: Optional[ScheduleBatch] = None
        # The current forward batch
        self.cur_batch: Optional[ScheduleBatch] = None
        # The forward batch from the previous scheduler-loop iteration
        self.last_batch: Optional[ScheduleBatch] = None
        self.forward_ct = 0
        self.forward_ct_decode = 0
        self.num_generated_tokens = 0
        self.last_decode_stats_tic = time.time()
        self.return_health_check_ct = 0
        self.current_stream = torch.get_device_module(self.device).current_stream()
        if self.device == "cpu":
            self.current_stream.synchronize = lambda: None  # No-op for CPU

        # Init session info
        self.sessions: Dict[str, Session] = {}

        # Init chunked prefill
        self.chunked_prefill_size = server_args.chunked_prefill_size
        if self.chunked_prefill_size <= 0:  # -1 means disable
            self.chunked_prefill_size = None
        self.chunked_req = None
        self.is_mixed_chunk = (
            self.chunked_prefill_size is not None and server_args.enable_mixed_chunk
        )

        # Init the grammar backend for constrained generation
        # (no backend at all when tokenizer init is skipped)
        self.grammar_queue: List[Req] = []
        if not server_args.skip_tokenizer_init:
            self.grammar_backend = create_grammar_backend(
                server_args, self.tokenizer, self.model_config.vocab_size
            )
        else:
            self.grammar_backend = None

        # Init schedule policy and new token estimation
        self.policy = SchedulePolicy(self.schedule_policy, self.tree_cache)
        assert (
            server_args.schedule_conservativeness >= 0
        ), "Invalid schedule_conservativeness"

        # new_token_ratio starts at init_new_token_ratio (capped at 1.0) and has a
        # floor of min_new_token_ratio; new_token_ratio_decay spreads the gap over
        # default_new_token_ratio_decay_steps.
        self.init_new_token_ratio = min(
            global_config.default_init_new_token_ratio
            * server_args.schedule_conservativeness,
            1.0,
        )
        self.min_new_token_ratio = min(
            self.init_new_token_ratio
            * global_config.default_min_new_token_ratio_factor,
            1.0,
        )
        self.new_token_ratio_decay = (
            self.init_new_token_ratio - self.min_new_token_ratio
        ) / global_config.default_new_token_ratio_decay_steps
        self.new_token_ratio = self.init_new_token_ratio

        # Tell whether the current running batch is full so that we can skip
        # the check of whether to prefill new requests.
        # This is an optimization to reduce the overhead of the prefill check.
        self.batch_is_full = False

        # Init watchdog thread (runs self.watchdog_thread as a daemon thread)
        self.watchdog_timeout = server_args.watchdog_timeout
        t = threading.Thread(target=self.watchdog_thread, daemon=True)
        t.start()
        self.parent_process = psutil.Process().parent()

        # Init memory saver
        self.memory_saver_adapter = TorchMemorySaverAdapter.create(
            enable=server_args.enable_memory_saver
        )

        # Init profiler
        self.torch_profiler = None
        self.torch_profiler_output_dir: Optional[str] = None
        self.torch_profiler_activities: Optional[List[str]] = None
        self.profiler_target_forward_ct: Optional[int] = None

        # Init metrics stats
        self.init_metrics()

        # Init request dispatcher: maps each incoming request type to its handler
        self._request_dispatcher = TypeBasedDispatcher(
            [
                (TokenizedGenerateReqInput, self.handle_generate_request),
                (TokenizedEmbeddingReqInput, self.handle_embedding_request),
                (FlushCacheReq, self.flush_cache_wrapped),
                (AbortReq, self.abort_request),
                (OpenSessionReqInput, self.open_session),
                (CloseSessionReqInput, self.close_session),
                (UpdateWeightFromDiskReqInput, self.update_weights_from_disk),
                (InitWeightsUpdateGroupReqInput, self.init_weights_update_group),
                (
                    UpdateWeightsFromDistributedReqInput,
                    self.update_weights_from_distributed,
                ),
                (UpdateWeightsFromTensorReqInput, self.update_weights_from_tensor),
                (GetWeightsByNameReqInput, self.get_weights_by_name),
                (ReleaseMemoryOccupationReqInput, self.release_memory_occupation),
                (ResumeMemoryOccupationReqInput, self.resume_memory_occupation),
                (ProfileReq, self.profile),
                (GetInternalStateReq, self.get_internal_state),
                (SetInternalStateReq, self.set_internal_state),
            ]
        )

380
381
    def init_tokenizer(self):
        """Create the model config and, unless disabled, the tokenizer/processor.

        Always sets ``self.model_config`` and ``self.is_generation``. When
        ``skip_tokenizer_init`` is set, ``self.tokenizer`` and ``self.processor``
        are both ``None``. Multimodal models get a processor whose tokenizer is
        reused; other models only get ``self.tokenizer``.
        """
        args = self.server_args

        self.model_config = ModelConfig(
            args.model_path,
            trust_remote_code=args.trust_remote_code,
            revision=args.revision,
            context_length=args.context_length,
            model_override_args=args.json_model_override_args,
            is_embedding=args.is_embedding,
            dtype=args.dtype,
            quantization=args.quantization,
        )
        self.is_generation = self.model_config.is_generation

        if args.skip_tokenizer_init:
            # Tokenizer loading is disabled by the server args.
            self.tokenizer = None
            self.processor = None
            return

        # Options shared by both the processor and the plain tokenizer loaders.
        loader_kwargs = dict(
            tokenizer_mode=args.tokenizer_mode,
            trust_remote_code=args.trust_remote_code,
            revision=args.revision,
        )
        if not self.model_config.is_multimodal:
            self.tokenizer = get_tokenizer(args.tokenizer_path, **loader_kwargs)
            return

        self.processor = get_processor(args.tokenizer_path, **loader_kwargs)
        self.tokenizer = self.processor.tokenizer

    def init_memory_pool_and_cache(self):
        """Fetch memory pools from the TP worker and build the prefix cache.

        Cache selection:
          * ``ChunkCache`` when ``chunked_prefill_size`` is not None and the radix
            cache is disabled;
          * otherwise ``HiRadixCache`` if hierarchical caching is enabled;
          * otherwise ``RadixCache`` (disabled if ``disable_radix_cache`` is set).

        Also sets ``decode_mem_cache_buf_multiplier``: 1 without speculative
        decoding, else ``num_draft_tokens + eagle_topk * num_steps``.
        """
        args = self.server_args

        pools = self.tp_worker.get_memory_pool()
        self.req_to_token_pool, self.token_to_kv_pool_allocator = pools

        pool_kwargs = dict(
            req_to_token_pool=self.req_to_token_pool,
            token_to_kv_pool_allocator=self.token_to_kv_pool_allocator,
        )
        use_chunk_cache = (
            args.chunked_prefill_size is not None and args.disable_radix_cache
        )
        if use_chunk_cache:
            self.tree_cache = ChunkCache(**pool_kwargs)
        elif self.enable_hierarchical_cache:
            self.tree_cache = HiRadixCache(**pool_kwargs)
        else:
            self.tree_cache = RadixCache(
                **pool_kwargs, disable=args.disable_radix_cache
            )

        if self.spec_algorithm.is_none():
            self.decode_mem_cache_buf_multiplier = 1
        else:
            draft_tree_tokens = (
                args.speculative_eagle_topk * args.speculative_num_steps
            )
            self.decode_mem_cache_buf_multiplier = (
                args.speculative_num_draft_tokens + draft_tree_tokens
            )
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473

    def init_metrics(self):
        """Reset scheduler statistics; create the metrics collector if enabled."""
        # Largest prompt (prefill) length observed for a single request.
        self._largest_prefill_len: int = 0
        # Largest prefill + generated length observed for a single request.
        self._largest_prefill_decode_len: int = 0
        self.last_gen_throughput: float = 0.0
        # batch size -> list of recorded step times
        self.step_time_dict = defaultdict(list)

        # Speculative-decoding counters.
        self.spec_num_total_accepted_tokens = 0
        self.spec_num_total_forward_ct = 0
        self.cum_spec_accept_length = 0
        self.cum_spec_accept_count = 0

        self.stats = SchedulerStats()

        if not self.enable_metrics:
            return

        collector_labels = {
            "model_name": self.server_args.served_model_name,
            "engine_type": "unified",
        }
        self.metrics_collector = SchedulerMetricsCollector(labels=collector_labels)
Lianmin Zheng's avatar
Lianmin Zheng committed
474

475
    @torch.no_grad()
    def event_loop_normal(self):
        """Run the scheduler loop without CPU/GPU overlap (never returns).

        Each iteration drains and dispatches incoming requests, picks the next
        batch, and runs it and processes its result synchronously. When no batch
        is available, the loop runs ``check_memory`` and resets
        ``new_token_ratio`` to its initial value.
        """
        while True:
            self.process_input_requests(self.recv_requests())

            batch = self.get_next_batch_to_run()
            self.cur_batch = batch

            if not batch:
                # Idle: self-check and re-init adaptive state.
                self.check_memory()
                self.new_token_ratio = self.init_new_token_ratio
            else:
                self.process_batch_result(batch, self.run_batch(batch))

            self.last_batch = batch
494

495
    @torch.no_grad()
    def event_loop_overlap(self):
        """A scheduler loop that overlaps the CPU processing and GPU computation.

        Each iteration launches the next batch first and only then processes the
        result of the batch launched in the previous iteration, which is kept in
        ``self.result_queue`` as ``(batch copy, result)``. Never returns.
        """
        self.result_queue = deque()

        while True:
            recv_reqs = self.recv_requests()
            self.process_input_requests(recv_reqs)

            batch = self.get_next_batch_to_run()
            self.cur_batch = batch

            if batch:
                result = self.run_batch(batch)
                # Keep a copy of the batch alongside its result; it is processed
                # on the next iteration.
                self.result_queue.append((batch.copy(), result))

                if self.last_batch is None:
                    # Create a dummy first batch to start the pipeline for overlap schedule.
                    # It is now used for triggering the sampling_info_done event.
                    tmp_batch = ScheduleBatch(
                        reqs=None,
                        forward_mode=ForwardMode.DUMMY_FIRST,
                        next_batch_sampling_info=self.tp_worker.cur_sampling_info,
                    )
                    self.process_batch_result(tmp_batch, None)

            if self.last_batch:
                # Process the results of the last batch
                tmp_batch, tmp_result = self.result_queue.popleft()
                tmp_batch.next_batch_sampling_info = (
                    self.tp_worker.cur_sampling_info if batch else None
                )
                self.process_batch_result(tmp_batch, tmp_result)
            elif batch is None:
                # When the server is idle, do a self-check and re-init some states
                self.check_memory()
                self.new_token_ratio = self.init_new_token_ratio

            self.last_batch = batch

535
536
    def recv_requests(self) -> List[Req]:
        """Receive incoming requests and make every rank see the same list.

        Only the rank with ``attn_tp_rank == 0`` reads the tokenizer socket. It
        drains messages non-blockingly until ``recv_pyobj`` raises ``ZMQError``.
        The list is then broadcast:

        * with DP attention, generate/embedding requests are broadcast inside the
          attention-TP group (source rank ``dp_rank * attn_tp_size``), and all
          other (control) requests are broadcast to the whole TP group; the
          result is work requests followed by control requests;
        * otherwise the whole list is broadcast to the TP group when
          ``tp_size != 1``.
        """
        recv_reqs = None
        if self.attn_tp_rank == 0:
            recv_reqs = []
            while True:
                try:
                    message = self.recv_from_tokenizer.recv_pyobj(zmq.NOBLOCK)
                except zmq.ZMQError:
                    break
                recv_reqs.append(message)

        if not self.server_args.enable_dp_attention:
            if self.tp_size != 1:
                recv_reqs = broadcast_pyobj(recv_reqs, self.tp_rank, self.tp_cpu_group)
            return recv_reqs

        work_types = (TokenizedGenerateReqInput, TokenizedEmbeddingReqInput)
        work_reqs = control_reqs = None
        if self.attn_tp_rank == 0:
            work_reqs = [r for r in recv_reqs if isinstance(r, work_types)]
            control_reqs = [r for r in recv_reqs if not isinstance(r, work_types)]

        if self.attn_tp_size != 1:
            work_reqs = broadcast_pyobj(
                work_reqs,
                self.attn_tp_rank,
                self.attn_tp_cpu_group,
                src=self.dp_rank * self.attn_tp_size,
            )
        if self.tp_size != 1:
            control_reqs = broadcast_pyobj(
                control_reqs, self.tp_rank, self.tp_cpu_group
            )
        return work_reqs + control_reqs

Lianmin Zheng's avatar
Lianmin Zheng committed
586
    def process_input_requests(self, recv_reqs: List):
        """Route every received request to its handler via the type dispatcher.

        A health-check generate request is skipped whenever a chunked request
        or a running batch exists; only ``return_health_check_ct`` is
        incremented. Any non-None handler result is sent back through
        ``send_to_tokenizer``.
        """

        def has_ongoing_work():
            return self.chunked_req is not None or self.running_batch is not None

        for incoming in recv_reqs:
            if is_health_check_generate_req(incoming) and has_ongoing_work():
                self.return_health_check_ct += 1
                continue

            reply = self._request_dispatcher(incoming)
            if reply is None:
                continue
            self.send_to_tokenizer.send_pyobj(reply)
598
599
600
601
602

    def handle_generate_request(
        self,
        recv_req: TokenizedGenerateReqInput,
    ):
        """Convert a tokenized generate request into a ``Req`` and enqueue it.

        Requests naming an open session are created by that session; all others
        are built here. Validation failures do not raise. The request is marked
        with ``FINISH_ABORT`` and still placed on the waiting queue. This covers:

        * an unknown session id;
        * a prompt that is too long after multimodal expansion;
        * an invalid ``logprob_start_len``;
        * constrained decoding without a grammar backend.

        For the generic prompt-length check, the abort reason is presumably
        recorded by ``validate_input_length``, which is not visible here.

        Requests whose grammar is not cached yet go to ``self.grammar_queue``;
        all others go to the waiting queue.
        """
        # Create a new request
        if (
            recv_req.session_params is None
            or recv_req.session_params.id is None
            or recv_req.session_params.id not in self.sessions
        ):
            if recv_req.input_embeds is not None:
                # Generate fake input_ids based on the length of input_embeds
                seq_length = len(recv_req.input_embeds)
                fake_input_ids = [1] * seq_length
                recv_req.input_ids = fake_input_ids

            # Handle custom logit processor passed to the request
            custom_logit_processor = recv_req.custom_logit_processor
            if (
                not self.server_args.enable_custom_logit_processor
                and custom_logit_processor is not None
            ):
                # Adjacent string literals are concatenated, so each sentence
                # carries its own trailing space.
                logger.warning(
                    "The SGLang server is not configured to enable custom logit processor. "
                    "The custom logit processor passed in will be ignored. "
                    "Please set --enable-custom-logit-processor to enable this feature."
                )
                custom_logit_processor = None

            req = Req(
                recv_req.rid,
                recv_req.input_text,
                recv_req.input_ids,
                recv_req.sampling_params,
                return_logprob=recv_req.return_logprob,
                top_logprobs_num=recv_req.top_logprobs_num,
                token_ids_logprob=recv_req.token_ids_logprob,
                stream=recv_req.stream,
                lora_path=recv_req.lora_path,
                input_embeds=recv_req.input_embeds,
                custom_logit_processor=custom_logit_processor,
                return_hidden_states=recv_req.return_hidden_states,
                eos_token_ids=self.model_config.hf_eos_token_id,
            )
            req.tokenizer = self.tokenizer

            if (
                recv_req.session_params is not None
                and recv_req.session_params.id is not None
            ):
                # A session id was given but no such session is open.
                req.finished_reason = FINISH_ABORT(
                    f"Invalid request: session id {recv_req.session_params.id} does not exist"
                )
                self._add_request_to_queue(req)
                return
        else:
            # Create a new request from a previous session
            session = self.sessions[recv_req.session_params.id]
            req = session.create_req(recv_req, self.tokenizer)
            if isinstance(req.finished_reason, FINISH_ABORT):
                self._add_request_to_queue(req)
                return

        # Handle multimodal inputs
        if recv_req.image_inputs is not None:
            image_inputs = ImageInputs.from_dict(recv_req.image_inputs)
            # Expand a single image token into multiple dummy tokens for receiving image embeddings
            req.origin_input_ids = self.pad_input_ids_func(
                req.origin_input_ids, image_inputs
            )
            req.extend_image_inputs(image_inputs)

            if len(req.origin_input_ids) >= self.max_req_input_len:
                error_msg = (
                    "Multimodal prompt is too long after expanding multimodal tokens. "
                    f"After expanding {len(req.origin_input_ids_unpadded)=} => {len(req.origin_input_ids)} >= {self.max_req_input_len}."
                )
                logger.error(error_msg)
                req.origin_input_ids = [0]
                req.image_inputs = None
                req.sampling_params.max_new_tokens = 0
                req.finished_reason = FINISH_ABORT(
                    error_msg, HTTPStatus.BAD_REQUEST, "BadRequestError"
                )
                self._add_request_to_queue(req)
                return

        # Validate prompts length
        error_msg = validate_input_length(
            req,
            self.max_req_input_len,
            self.server_args.allow_auto_truncate,
        )
        if error_msg:
            req.origin_input_ids = [0]
            req.sampling_params.max_new_tokens = 0
            self._add_request_to_queue(req)
            return

        # Copy more attributes
        if recv_req.logprob_start_len == -1 or not recv_req.return_logprob:
            # By default, only return the logprobs for output tokens
            req.logprob_start_len = len(req.origin_input_ids) - 1
        else:
            req.logprob_start_len = recv_req.logprob_start_len

        if req.logprob_start_len >= len(req.origin_input_ids):
            req.finished_reason = FINISH_ABORT(
                f"logprob_start_len, ({req.logprob_start_len}) is higher than the number of input tokens ({len(req.origin_input_ids)}). Request with a lower logprob_start_len.",
                HTTPStatus.BAD_REQUEST,
                "BadRequestError",
            )
            req.logprob_start_len = len(req.origin_input_ids) - 1
            self._add_request_to_queue(req)
            return

        # Cap generation so that prompt + output stays below max_req_len;
        # a missing max_new_tokens is treated as effectively unlimited.
        req.sampling_params.max_new_tokens = min(
            (
                req.sampling_params.max_new_tokens
                if req.sampling_params.max_new_tokens is not None
                else 1 << 30
            ),
            self.max_req_len - len(req.origin_input_ids) - 1,
        )

        # Init grammar cache for this request
        add_to_grammar_queue = False
        if (
            req.sampling_params.json_schema is not None
            or req.sampling_params.regex is not None
            or req.sampling_params.ebnf is not None
            or req.sampling_params.structural_tag is not None
        ):
            if self.grammar_backend is None:
                # No grammar backend exists (it is not created when
                # skip_tokenizer_init is set). Reject this request instead of
                # letting it crash the whole scheduler loop.
                req.finished_reason = FINISH_ABORT(
                    "Constrained decoding (json_schema / regex / ebnf / structural_tag) "
                    "is not available because no grammar backend is initialized.",
                    HTTPStatus.BAD_REQUEST,
                    "BadRequestError",
                )
                self._add_request_to_queue(req)
                return

            if req.sampling_params.json_schema is not None:
                key = ("json", req.sampling_params.json_schema)
            elif req.sampling_params.regex is not None:
                key = ("regex", req.sampling_params.regex)
            elif req.sampling_params.ebnf is not None:
                key = ("ebnf", req.sampling_params.ebnf)
            else:
                # The enclosing condition guarantees structural_tag is not None
                # here, so `key` is always bound (even for a falsy tag).
                key = ("structural_tag", req.sampling_params.structural_tag)

            req.grammar = self.grammar_backend.get_cached_value(key)
            if not req.grammar:
                req.grammar = self.grammar_backend.get_future_value(key)
                add_to_grammar_queue = True

        if add_to_grammar_queue:
            self.grammar_queue.append(req)
        else:
            self._add_request_to_queue(req)

    def _add_request_to_queue(self, req: Req):
        """Put a single request at the tail of the waiting queue."""
        queue = self.waiting_queue
        queue.append(req)

    def _extend_requests_to_queue(self, reqs: List[Req]):
        """Append several requests, preserving their order, to the waiting queue."""
        self.waiting_queue.extend(reqs)

    def handle_embedding_request(
        self,
        recv_req: TokenizedEmbeddingReqInput,
    ):
        """Build a ``Req`` from a tokenized embedding request and enqueue it.

        Requests that fail validation are still enqueued, so the abort is
        reported through the normal output path. Two checks can fail: the
        multimodal prompt is too long after image-token expansion, or
        ``validate_input_length`` reports an error. A failed request's prompt
        is collapsed to a single placeholder token (``[0]``), mirroring the
        generate path, so no over-length prefill is ever attempted for it.
        """
        req = Req(
            recv_req.rid,
            recv_req.input_text,
            recv_req.input_ids,
            recv_req.sampling_params,
        )
        req.tokenizer = self.tokenizer

        # Handle multimodal inputs
        if recv_req.image_inputs is not None:
            image_inputs = ImageInputs.from_dict(recv_req.image_inputs)
            # Expand a single image token into multiple dummy tokens for receiving image embeddings
            req.origin_input_ids = self.pad_input_ids_func(
                req.origin_input_ids, image_inputs
            )
            req.extend_image_inputs(image_inputs)

            if len(req.origin_input_ids) >= self.max_req_input_len:
                error_msg = (
                    "Multimodal prompt is too long after expanding multimodal tokens. "
                    f"After expanding {len(req.origin_input_ids_unpadded)=} => {len(req.origin_input_ids)} >= {self.max_req_input_len}."
                )
                logger.error(error_msg)
                req.origin_input_ids = [0]
                req.image_inputs = None
                req.sampling_params.max_new_tokens = 0
                req.finished_reason = FINISH_ABORT(
                    error_msg, HTTPStatus.BAD_REQUEST, "BadRequestError"
                )
                # Go through the shared enqueue helper like every other path.
                self._add_request_to_queue(req)
                return

        # Validate prompts length
        error_msg = validate_input_length(
            req,
            self.max_req_input_len,
            self.server_args.allow_auto_truncate,
        )
        if error_msg:
            # NOTE(review): validate_input_length is assumed to have recorded the
            # abort reason on req (the generate path relies on the same thing).
            # Collapse the prompt as the generate path does, so the rejected,
            # over-length input is never actually prefilled.
            req.origin_input_ids = [0]
            self._add_request_to_queue(req)
            return

        # Copy more attributes
        req.logprob_start_len = len(req.origin_input_ids) - 1
        self._add_request_to_queue(req)
    def log_prefill_stats(
        self,
        adder: PrefillAdder,
        can_run_list: List[Req],
        running_bs: int,
    ):
        """Log a summary of a newly built prefill batch and optionally export metrics.

        Args:
            adder: The ``PrefillAdder`` that built the batch. Its
                ``log_input_tokens`` and ``log_hit_tokens`` counters are read.
            can_run_list: Requests admitted into the new prefill batch.
            running_bs: Number of requests in the current running batch.
        """
        # Tokens in use = pool capacity minus the allocator's available size
        # and the tree cache's evictable size.
        num_used = self.max_total_num_tokens - (
            self.token_to_kv_pool_allocator.available_size()
            + self.tree_cache.evictable_size()
        )

        self._largest_prefill_len = max(
            self._largest_prefill_len, adder.log_input_tokens
        )

        msg = (
            f"Prefill batch. "
            f"#new-seq: {len(can_run_list)}, "
            f"#new-token: {adder.log_input_tokens}, "
            f"#cached-token: {adder.log_hit_tokens}, "
            f"token usage: {num_used / self.max_total_num_tokens:.2f}, "
            f"#running-req: {running_bs}, "
            f"#queue-req: {len(self.waiting_queue)}, "
        )
        logger.info(msg)

        if self.enable_metrics:
            total_tokens = adder.log_input_tokens + adder.log_hit_tokens
            # Guard the ratio so a batch reporting zero tokens cannot crash logging.
            cache_hit_rate = (
                adder.log_hit_tokens / total_tokens if total_tokens > 0 else 0.0
            )
            self.stats.num_running_reqs = running_bs
            self.stats.num_used_tokens = num_used
            self.stats.token_usage = round(num_used / self.max_total_num_tokens, 2)
            self.stats.num_queue_reqs = len(self.waiting_queue)
            self.stats.cache_hit_rate = cache_hit_rate
            self.metrics_collector.log_stats(self.stats)

    def log_decode_stats(self):
        """Log decode throughput and KV usage since the previous call.

        Resets ``num_generated_tokens``. With speculative decoding, it also
        folds the per-interval acceptance counters into the cumulative
        totals and then resets them. When metrics are enabled, the same
        numbers are published to the metrics collector.
        """
        # Use one timestamp so the measured gap and the next interval's start agree.
        now = time.time()
        gap_latency = now - self.last_decode_stats_tic
        self.last_decode_stats_tic = now
        # time.time() may return an identical value on coarse clocks; avoid a
        # ZeroDivisionError in that case.
        self.last_gen_throughput = (
            self.num_generated_tokens / gap_latency if gap_latency > 0 else 0.0
        )
        self.num_generated_tokens = 0
        num_running_reqs = len(self.running_batch.reqs) if self.running_batch else 0
        num_used = self.max_total_num_tokens - (
            self.token_to_kv_pool_allocator.available_size()
            + self.tree_cache.evictable_size()
        )

        if RECORD_STEP_TIME:
            self.step_time_dict[num_running_reqs].append(
                gap_latency / self.server_args.decode_log_interval
            )

        if self.spec_algorithm.is_none():
            msg = (
                f"Decode batch. "
                f"#running-req: {num_running_reqs}, "
                f"#token: {num_used}, "
                f"token usage: {num_used / self.max_total_num_tokens:.2f}, "
                f"gen throughput (token/s): {self.last_gen_throughput:.2f}, "
                f"largest-len: {self._largest_prefill_decode_len}, "
                f"#queue-req: {len(self.waiting_queue)}, "
            )
            spec_accept_length = 0
        else:
            # Average accepted length over the interval; guarded in case no
            # speculative forward was counted since the last reset.
            spec_accept_length = (
                self.spec_num_total_accepted_tokens / self.spec_num_total_forward_ct
                if self.spec_num_total_forward_ct > 0
                else 0.0
            )
            self.cum_spec_accept_length += self.spec_num_total_accepted_tokens
            self.cum_spec_accept_count += self.spec_num_total_forward_ct
            self.spec_num_total_accepted_tokens = self.spec_num_total_forward_ct = 0
            msg = (
                f"Decode batch. "
                f"#running-req: {num_running_reqs}, "
                f"#token: {num_used}, "
                f"token usage: {num_used / self.max_total_num_tokens:.2f}, "
                f"accept len: {spec_accept_length:.2f}, "
                f"gen throughput (token/s): {self.last_gen_throughput:.2f}, "
                f"largest-len: {self._largest_prefill_decode_len}, "
                f"#queue-req: {len(self.waiting_queue)}, "
            )

        logger.info(msg)
        if self.enable_metrics:
            self.stats.num_running_reqs = num_running_reqs
            self.stats.num_used_tokens = num_used
            self.stats.token_usage = num_used / self.max_total_num_tokens
            self.stats.cache_hit_rate = 0.0
            self.stats.gen_throughput = self.last_gen_throughput
            self.stats.num_queue_reqs = len(self.waiting_queue)
            self.stats.spec_accept_length = spec_accept_length
            self.metrics_collector.log_stats(self.stats)
    def check_memory(self):
        """Detect leaked KV-cache tokens or request-pool slots.

        Compares the KV pool's reclaimable size with its capacity. With the
        hierarchical cache enabled, the tree cache's protected size is
        subtracted from capacity. It also checks that every req_to_token
        slot is free. A mismatch emits a warning, and raises ``ValueError``
        when ``crash_on_warnings()`` is true.

        Additionally, on attention-TP rank 0 with metrics enabled, it
        publishes usage stats if more than 30 s have passed since the
        collector last logged.
        """
        available_size = (
            self.token_to_kv_pool_allocator.available_size()
            + self.tree_cache.evictable_size()
        )
        protected_size = self.tree_cache.protected_size()
        memory_leak = available_size != (
            self.max_total_num_tokens
            if not self.enable_hierarchical_cache
            else self.max_total_num_tokens - protected_size
        )
        if memory_leak:
            # Trailing space separates the headline from the field dump.
            msg = (
                "KV cache pool leak detected! "
                f"{available_size=}, {protected_size=}, {self.max_total_num_tokens=}\n"
            )
            warnings.warn(msg)
            if crash_on_warnings():
                raise ValueError(msg)

        if len(self.req_to_token_pool.free_slots) != self.req_to_token_pool.size:
            msg = (
                "Memory pool leak detected! "
                f"available_size={len(self.req_to_token_pool.free_slots)}, "
                f"total_size={self.req_to_token_pool.size}\n"
            )
            warnings.warn(msg)
            if crash_on_warnings():
                raise ValueError(msg)

        if (
            self.enable_metrics
            and self.attn_tp_rank == 0
            and time.time() > self.metrics_collector.last_log_time + 30
        ):
            # During idle time, also collect metrics every 30 seconds.
            num_used = self.max_total_num_tokens - (
                self.token_to_kv_pool_allocator.available_size()
                + self.tree_cache.evictable_size()
            )
            num_running_reqs = len(self.running_batch.reqs) if self.running_batch else 0
            self.stats.num_running_reqs = num_running_reqs
            self.stats.num_used_tokens = num_used
            self.stats.token_usage = num_used / self.max_total_num_tokens
            self.stats.gen_throughput = 0
            self.stats.num_queue_reqs = len(self.waiting_queue)
            self.metrics_collector.log_stats(self.stats)
    def get_next_batch_to_run(self) -> Optional[ScheduleBatch]:
        """Pick the batch to execute next.

        If the last batch was an extend (prefill) batch, first merge its
        unfinished requests into ``running_batch``. The current chunked
        request is excluded: it is cached and its req_pool slot is freed.
        Then prefer a new prefill batch; otherwise fall back to the
        (updated) running decode batch. With DP attention enabled, the
        choice is passed through ``prepare_dp_attn_batch``.

        Returns:
            The batch to run, or ``None`` if there is nothing to do.
        """
        # Merge the prefill batch into the running batch
        if self.last_batch and self.last_batch.forward_mode.is_extend():
            if self.chunked_req:
                # Move the chunked request out of the batch so that we can merge
                # only finished requests to running_batch.
                self.last_batch.filter_batch(chunked_req_to_exclude=self.chunked_req)
                self.tree_cache.cache_unfinished_req(self.chunked_req)
                # chunked request keeps its rid but will get a new req_pool_idx
                self.req_to_token_pool.free(self.chunked_req.req_pool_idx)
                self.batch_is_full = False

            last_bs = self.last_batch.batch_size()
            self.last_batch.filter_batch()
            # Any request removed by filtering frees capacity for new work.
            if self.last_batch.batch_size() < last_bs:
                self.batch_is_full = False

            if not self.last_batch.is_empty():
                if self.running_batch is None:
                    self.running_batch = self.last_batch
                else:
                    # merge running_batch with prefill batch
                    self.running_batch.merge_batch(self.last_batch)

        new_batch = self.get_new_batch_prefill()
        if new_batch is not None:
            # Run prefill first if possible
            ret = new_batch
        else:
            # Run decode
            if self.running_batch is None:
                ret = None
            else:
                self.running_batch = self.update_running_batch(self.running_batch)
                ret = self.running_batch

        # Handle DP attention
        if self.server_args.enable_dp_attention:
            ret = self.prepare_dp_attn_batch(ret)

        return ret
    def get_new_batch_prefill(self) -> Optional[ScheduleBatch]:
        """Assemble a new prefill (extend) batch from the waiting queue.

        Returns ``None`` in any of these cases:
        - the batch is marked full or the queue is empty, and there is no
          chunked request to continue;
        - the running batch already holds ``max_running_requests``;
        - no waiting request could be admitted.

        Side effects:
        - removes the admitted requests from ``self.waiting_queue``;
        - updates ``self.chunked_req`` and ``self.batch_is_full``;
        - logs prefill stats on attention-TP rank 0;
        - for mixed chunked prefill, folds the running decode batch into the
          new batch and clears ``self.running_batch``.
        """
        # Check if the grammar is ready in the grammar queue
        if self.grammar_queue:
            self.move_ready_grammar_requests()

        # Handle the cases where prefill is not allowed
        if (
            self.batch_is_full or len(self.waiting_queue) == 0
        ) and self.chunked_req is None:
            return None

        running_bs = len(self.running_batch.reqs) if self.running_batch else 0
        if running_bs >= self.max_running_requests:
            self.batch_is_full = True
            return None

        # Get priority queue
        prefix_computed = self.policy.calc_priority(self.waiting_queue)

        # Prefill policy
        adder = PrefillAdder(
            self.tree_cache,
            self.token_to_kv_pool_allocator,
            self.running_batch,
            self.new_token_ratio,
            self.max_prefill_tokens,
            self.chunked_prefill_size,
            running_bs if self.is_mixed_chunk else 0,
        )

        # Continue the in-flight chunked request before admitting new ones.
        if self.chunked_req is not None:
            self.chunked_req.init_next_round_input()
            self.chunked_req = adder.add_chunked_req(self.chunked_req)

        if self.lora_paths:
            # LoRA adapters already used by the running batch.
            lora_set = (
                {r.lora_path for r in self.running_batch.reqs}
                if self.running_batch is not None
                else set()
            )

        # Get requests from the waiting queue to a new prefill batch
        for req in self.waiting_queue:
            # Stop once admitting this request would exceed the LoRA adapter budget.
            if (
                self.lora_paths
                and len(
                    lora_set
                    | {r.lora_path for r in adder.can_run_list}
                    | {req.lora_path}
                )
                > self.max_loras_per_batch
            ):
                self.batch_is_full = True
                break

            if running_bs + len(adder.can_run_list) >= self.max_running_requests:
                self.batch_is_full = True
                break

            req.init_next_round_input(None if prefix_computed else self.tree_cache)

            if self.enable_hierarchical_cache and req.last_node is not None:
                if req.last_node.evicted:
                    # loading KV cache for the request
                    req.last_node, req.prefix_indices = self.tree_cache.init_load_back(
                        req.last_node,
                        req.prefix_indices,
                        adder.rem_total_tokens,
                    )
                    if req.last_node.loading:
                        # to prevent frequent cache invalidation
                        if req.rid in self.staging_reqs:
                            self.tree_cache.dec_lock_ref(self.staging_reqs[req.rid])
                        self.tree_cache.inc_lock_ref(req.last_node)
                        self.staging_reqs[req.rid] = req.last_node
                        continue
                elif req.last_node.loading:
                    if not self.tree_cache.loading_complete(req.last_node):
                        continue

                if req.rid in self.staging_reqs:
                    self.tree_cache.dec_lock_ref(self.staging_reqs[req.rid])
                    del self.staging_reqs[req.rid]

            res = adder.add_one_req(req, self.chunked_req)
            if res != AddReqResult.CONTINUE:
                if res == AddReqResult.NO_TOKEN:
                    if self.enable_hierarchical_cache:
                        # Set batch_is_full after making sure there are requests that can be served
                        self.batch_is_full = len(adder.can_run_list) > 0 or (
                            self.running_batch is not None
                            and not self.running_batch.is_empty()
                        )
                    else:
                        self.batch_is_full = True
                break

        # Update waiting queue
        can_run_list: List[Req] = adder.can_run_list
        if len(can_run_list) == 0:
            return None
        # Build the membership set once; rebuilding it per element was O(n*m).
        scheduled = set(can_run_list)
        self.waiting_queue = [x for x in self.waiting_queue if x not in scheduled]

        if adder.new_chunked_req is not None:
            assert self.chunked_req is None
            self.chunked_req = adder.new_chunked_req

        if self.chunked_req:
            self.chunked_req.is_chunked += 1

        # Print stats
        if self.attn_tp_rank == 0:
            self.log_prefill_stats(adder, can_run_list, running_bs)

        # Create a new batch
        new_batch = ScheduleBatch.init_new(
            can_run_list,
            self.req_to_token_pool,
            self.token_to_kv_pool_allocator,
            self.tree_cache,
            self.model_config,
            self.enable_overlap,
            self.spec_algorithm,
            self.server_args.enable_custom_logit_processor,
        )
        new_batch.prepare_for_extend()

        # Mixed-style chunked prefill
        if (
            self.is_mixed_chunk
            and self.running_batch is not None
            and not (new_batch.return_logprob or self.running_batch.return_logprob)
        ):
            # TODO (lianmin): support return_logprob + mixed chunked prefill
            self.running_batch.filter_batch()
            if not self.running_batch.is_empty():
                self.running_batch.prepare_for_decode()
                new_batch.mix_with_running(self.running_batch)
                new_batch.decoding_reqs = self.running_batch.reqs
            self.running_batch = None
        else:
            new_batch.decoding_reqs = None

        return new_batch
    def update_running_batch(self, batch: ScheduleBatch) -> Optional[ScheduleBatch]:
        """Update the current running decoding batch.

        Steps:
        1. Filter finished requests out of ``batch``. Return ``None`` if it
           becomes empty.
        2. If ``check_decode_mem`` fails (or ``TEST_RETRACT`` forces it for
           batches larger than 10), retract requests back onto the waiting
           queue and adopt the ratio returned by ``retract_decode``.
        3. Otherwise, decay ``new_token_ratio`` toward ``min_new_token_ratio``.
        4. Clear ``batch_is_full`` when the batch shrank, then prepare the
           batch for the next decode step.
        """
        initial_bs = batch.batch_size()

        batch.filter_batch()
        if batch.is_empty():
            self.batch_is_full = False
            return None

        # Check if decode out of memory
        if not batch.check_decode_mem(self.decode_mem_cache_buf_multiplier) or (
            TEST_RETRACT and batch.batch_size() > 10
        ):
            old_ratio = self.new_token_ratio

            retracted_reqs, new_token_ratio = batch.retract_decode(self.server_args)
            self.new_token_ratio = new_token_ratio

            logger.info(
                "Decode out of memory happened. "
                f"#retracted_reqs: {len(retracted_reqs)}, "
                f"#new_token_ratio: {old_ratio:.4f} -> {self.new_token_ratio:.4f}"
            )
            self._extend_requests_to_queue(retracted_reqs)
        else:
            self.new_token_ratio = max(
                self.new_token_ratio - self.new_token_ratio_decay,
                self.min_new_token_ratio,
            )

        if batch.batch_size() < initial_bs:
            self.batch_is_full = False

        # Update batch tensors
        batch.prepare_for_decode()
        return batch
    def run_batch(
        self, batch: ScheduleBatch
    ) -> Union[GenerationBatchResult, EmbeddingBatchResult]:
        """Run one forward pass for ``batch`` and wrap its output.

        Increments ``forward_ct`` and stops the profiler once
        ``profiler_target_forward_ct`` is reached. Generation models run
        through the TP worker, or through the draft worker when a
        speculative algorithm is configured. Embedding/reward models run
        through ``forward_batch_embedding``.
        """
        self.forward_ct += 1

        # Check profiler
        if (
            self.profiler_target_forward_ct
            and self.profiler_target_forward_ct <= self.forward_ct
        ):
            self.stop_profile()

        # Run forward
        if self.is_generation:
            if self.spec_algorithm.is_none():
                model_worker_batch = batch.get_model_worker_batch()
                logits_output, next_token_ids = self.tp_worker.forward_batch_generation(
                    model_worker_batch
                )
                bid = model_worker_batch.bid
            else:
                (
                    logits_output,
                    next_token_ids,
                    bid,
                    num_accepted_tokens,
                ) = self.draft_worker.forward_batch_speculative_generation(batch)
                # Accepted-token tally counts num_accepted_tokens plus one per
                # request; the forward tally counts one per request.
                self.spec_num_total_accepted_tokens += (
                    num_accepted_tokens + batch.batch_size()
                )
                self.spec_num_total_forward_ct += batch.batch_size()
                self.num_generated_tokens += num_accepted_tokens
            batch.output_ids = next_token_ids

            # These 2 values are needed for processing the output, but the values can be
            # modified by overlap schedule. So we have to copy them here so that
            # we can use the correct values in output processing.
            if batch.return_logprob:
                extend_input_len_per_req = [req.extend_input_len for req in batch.reqs]
                extend_logprob_start_len_per_req = [
                    req.extend_logprob_start_len for req in batch.reqs
                ]
            else:
                extend_input_len_per_req = None
                extend_logprob_start_len_per_req = None

            ret = GenerationBatchResult(
                logits_output=logits_output,
                next_token_ids=next_token_ids,
                extend_input_len_per_req=extend_input_len_per_req,
                extend_logprob_start_len_per_req=extend_logprob_start_len_per_req,
                bid=bid,
            )
        else:  # embedding or reward model
            model_worker_batch = batch.get_model_worker_batch()
            embeddings = self.tp_worker.forward_batch_embedding(model_worker_batch)
            ret = EmbeddingBatchResult(
                embeddings=embeddings, bid=model_worker_batch.bid
            )
        return ret
    def process_batch_result(
        self,
        batch: ScheduleBatch,
        result: Union[GenerationBatchResult, EmbeddingBatchResult],
    ):
        """Post-process ``result`` according to ``batch``'s forward mode.

        - Decode batches go to ``process_batch_result_decode``, which clears
          ``running_batch`` when the batch ends up empty.
        - Extend batches go to ``process_batch_result_prefill``.
        - Idle batches (with overlap) and dummy-first batches only need the
          next batch's sampling info finalized.

        Afterwards, one outstanding health-check signal (if any) is sent.
        """
        if batch.forward_mode.is_decode():
            self.process_batch_result_decode(batch, result)
            if batch.is_empty():
                self.running_batch = None
        elif batch.forward_mode.is_extend():
            self.process_batch_result_prefill(batch, result)
        elif batch.forward_mode.is_idle():
            if self.enable_overlap:
                self.tp_worker.resolve_batch_result(result.bid)
                if batch.next_batch_sampling_info:
                    batch.next_batch_sampling_info.update_regex_vocab_mask()
                    self.current_stream.synchronize()
                    batch.next_batch_sampling_info.sampling_info_done.set()
        elif batch.forward_mode.is_dummy_first():
            batch.next_batch_sampling_info.update_regex_vocab_mask()
            self.current_stream.synchronize()
            batch.next_batch_sampling_info.sampling_info_done.set()

        if self.return_health_check_ct:
            # Return some signal for the health check.
            # This is used to prevent the health check signal being blocked by long context prefill.
            # However, one minor issue is that this code path does not check the status of detokenizer manager.
            self.return_health_check_ct -= 1
            self.send_to_tokenizer.send_pyobj(HealthCheckOutput())
    def process_batch_result_prefill(
        self,
        batch: ScheduleBatch,
        result: Union[GenerationBatchResult, EmbeddingBatchResult],
    ):
1271
        skip_stream_req = None
Lianmin Zheng's avatar
Lianmin Zheng committed
1272

Lianmin Zheng's avatar
Lianmin Zheng committed
1273
        if self.is_generation:
1274
1275
1276
            (
                logits_output,
                next_token_ids,
1277
1278
                extend_input_len_per_req,
                extend_logprob_start_len_per_req,
1279
1280
1281
1282
                bid,
            ) = (
                result.logits_output,
                result.next_token_ids,
1283
1284
                result.extend_input_len_per_req,
                result.extend_logprob_start_len_per_req,
1285
1286
                result.bid,
            )
1287
1288

            if self.enable_overlap:
1289
                logits_output, next_token_ids = self.tp_worker.resolve_batch_result(bid)
1290
1291
            else:
                # Move next_token_ids and logprobs to cpu
1292
                next_token_ids = next_token_ids.tolist()
1293
                if batch.return_logprob:
1294
1295
1296
1297
1298
1299
1300
1301
                    if logits_output.next_token_logprobs is not None:
                        logits_output.next_token_logprobs = (
                            logits_output.next_token_logprobs.tolist()
                        )
                    if logits_output.input_token_logprobs is not None:
                        logits_output.input_token_logprobs = tuple(
                            logits_output.input_token_logprobs.tolist()
                        )
1302

1303
1304
            hidden_state_offset = 0

1305
1306
            # Check finish conditions
            logprob_pt = 0
1307
            for i, (req, next_token_id) in enumerate(zip(batch.reqs, next_token_ids)):
1308
1309
1310
                if req.is_retracted:
                    continue

Lianmin Zheng's avatar
Lianmin Zheng committed
1311
                if self.is_mixed_chunk and self.enable_overlap and req.finished():
1312
1313
                    # Free the one delayed token for the mixed decode batch
                    j = len(batch.out_cache_loc) - len(batch.reqs) + i
1314
                    self.token_to_kv_pool_allocator.free(batch.out_cache_loc[j : j + 1])
1315
                    continue
Lianmin Zheng's avatar
Lianmin Zheng committed
1316

1317
1318
                if req.is_chunked <= 0:
                    # req output_ids are set here
1319
                    req.output_ids.append(next_token_id)
1320
1321
                    req.check_finished()

1322
                    if req.finished():
1323
                        self.tree_cache.cache_finished_req(req)
1324
                    elif not batch.decoding_reqs or req not in batch.decoding_reqs:
1325
                        # This updates radix so others can match
1326
1327
                        self.tree_cache.cache_unfinished_req(req)

1328
                    if req.return_logprob:
1329
1330
1331
1332
1333
1334
1335
1336
1337
1338
1339
1340
                        assert extend_logprob_start_len_per_req is not None
                        assert extend_input_len_per_req is not None
                        extend_logprob_start_len = extend_logprob_start_len_per_req[i]
                        extend_input_len = extend_input_len_per_req[i]
                        num_input_logprobs = extend_input_len - extend_logprob_start_len
                        self.add_logprob_return_values(
                            i,
                            req,
                            logprob_pt,
                            next_token_ids,
                            num_input_logprobs,
                            logits_output,
1341
                        )
1342
1343
                        logprob_pt += num_input_logprobs

1344
                    if (
1345
                        req.return_hidden_states
1346
1347
1348
1349
1350
1351
1352
1353
1354
1355
1356
1357
1358
                        and logits_output.hidden_states is not None
                    ):
                        req.hidden_states.append(
                            logits_output.hidden_states[
                                hidden_state_offset : (
                                    hidden_state_offset := hidden_state_offset
                                    + len(req.origin_input_ids)
                                )
                            ]
                            .cpu()
                            .clone()
                        )

Lianmin Zheng's avatar
Lianmin Zheng committed
1359
1360
                    if req.grammar is not None:
                        req.grammar.accept_token(next_token_id)
1361
                        req.grammar.finished = req.finished()
1362
                else:
1363
                    # being chunked reqs' prefill is not finished
1364
                    req.is_chunked -= 1
1365
1366
1367
1368
                    # There is only at most one request being currently chunked.
                    # Because this request does not finish prefill,
                    # we don't want to stream the request currently being chunked.
                    skip_stream_req = req
1369

1370
1371
1372
1373
1374
1375
1376
1377
1378
1379
1380
1381
1382
1383
1384
1385
1386
1387
1388
                    # Incrementally update input logprobs.
                    if req.return_logprob:
                        extend_logprob_start_len = extend_logprob_start_len_per_req[i]
                        extend_input_len = extend_input_len_per_req[i]
                        if extend_logprob_start_len < extend_input_len:
                            # Update input logprobs.
                            num_input_logprobs = (
                                extend_input_len - extend_logprob_start_len
                            )
                            self.add_input_logprob_return_values(
                                i,
                                req,
                                logits_output,
                                logprob_pt,
                                num_input_logprobs,
                                last_prefill_chunk=False,
                            )
                            logprob_pt += num_input_logprobs

1389
1390
            if batch.next_batch_sampling_info:
                batch.next_batch_sampling_info.update_regex_vocab_mask()
1391
                self.current_stream.synchronize()
1392
1393
                batch.next_batch_sampling_info.sampling_info_done.set()

Lianmin Zheng's avatar
Lianmin Zheng committed
1394
        else:  # embedding or reward model
1395
            embeddings, bid = result.embeddings, result.bid
1396
            embeddings = embeddings.tolist()
1397
1398
1399

            # Check finish conditions
            for i, req in enumerate(batch.reqs):
1400
1401
1402
                if req.is_retracted:
                    continue

1403
                req.embedding = embeddings[i]
1404
                if req.is_chunked <= 0:
Lianmin Zheng's avatar
Lianmin Zheng committed
1405
                    # Dummy output token for embedding models
1406
1407
1408
                    req.output_ids.append(0)
                    req.check_finished()

Lianmin Zheng's avatar
Lianmin Zheng committed
1409
1410
1411
1412
                    if req.finished():
                        self.tree_cache.cache_finished_req(req)
                    else:
                        self.tree_cache.cache_unfinished_req(req)
1413
                else:
1414
                    # being chunked reqs' prefill is not finished
1415
                    req.is_chunked -= 1
1416

Lianmin Zheng's avatar
Lianmin Zheng committed
1417
        self.stream_output(batch.reqs, batch.return_logprob, skip_stream_req)
1418

1419
1420
1421
1422
1423
1424
1425
1426
1427
1428
    def process_batch_result_decode(
        self,
        batch: ScheduleBatch,
        result: GenerationBatchResult,
    ):
        """Apply the outcome of one decode forward pass to the requests of `batch`.

        For every non-retracted request this:
        - appends the sampled token (non-speculative mode only);
        - re-checks the finish condition and caches finished requests in the
          tree cache;
        - records output logprobs and hidden states when requested;
        - advances the grammar.

        Afterwards it signals the next batch's sampling info, streams outputs,
        closes the KV free group and periodically logs decode statistics.
        """
        logits_output = result.logits_output
        next_token_ids = result.next_token_ids
        bid = result.bid

        self.num_generated_tokens += len(batch.reqs)

        is_plain_decode = batch.spec_algorithm.is_none()
        if self.enable_overlap:
            assert is_plain_decode
            # The overlap worker hands back the real results only on request.
            logits_output, next_token_ids = self.tp_worker.resolve_batch_result(bid)
            next_token_logprobs = logits_output.next_token_logprobs
        elif is_plain_decode:
            # Speculative decoding handles output logprobs inside its verify
            # step, so this conversion only applies to ordinary decoding.
            next_token_ids = next_token_ids.tolist()
            if batch.return_logprob:
                next_token_logprobs = logits_output.next_token_logprobs.tolist()

        self.token_to_kv_pool_allocator.free_group_begin()

        # NOTE: with speculative decoding len(batch.reqs) != len(next_token_ids),
        # so `token_id` must not be relied on in that mode.
        for idx, (req, token_id) in enumerate(zip(batch.reqs, next_token_ids)):
            if req.is_retracted:
                continue

            if self.enable_overlap and req.finished():
                # Release the single KV slot whose free was delayed by overlap.
                self.token_to_kv_pool_allocator.free(
                    batch.out_cache_loc[idx : idx + 1]
                )
                continue

            if is_plain_decode:
                # The speculative worker maintains output_ids on its own.
                req.output_ids.append(token_id)

            req.check_finished()
            if req.finished():
                self.tree_cache.cache_finished_req(req)

            if is_plain_decode and req.return_logprob:
                # The speculative worker records logprobs on its own.
                req.output_token_logprobs_val.append(next_token_logprobs[idx])
                req.output_token_logprobs_idx.append(token_id)
                if req.top_logprobs_num > 0:
                    req.output_top_logprobs_val.append(
                        logits_output.next_token_top_logprobs_val[idx]
                    )
                    req.output_top_logprobs_idx.append(
                        logits_output.next_token_top_logprobs_idx[idx]
                    )
                if req.token_ids_logprob is not None:
                    req.output_token_ids_logprobs_val.append(
                        logits_output.next_token_token_ids_logprobs_val[idx]
                    )
                    req.output_token_ids_logprobs_idx.append(
                        logits_output.next_token_token_ids_logprobs_idx[idx]
                    )

            if req.return_hidden_states and logits_output.hidden_states is not None:
                req.hidden_states.append(
                    logits_output.hidden_states[idx].cpu().clone()
                )

            if is_plain_decode and req.grammar is not None:
                req.grammar.accept_token(token_id)
                req.grammar.finished = req.finished()

        sampling_info = batch.next_batch_sampling_info
        if sampling_info:
            sampling_info.update_regex_vocab_mask()
            self.current_stream.synchronize()
            sampling_info.sampling_info_done.set()

        self.stream_output(batch.reqs, batch.return_logprob)

        self.token_to_kv_pool_allocator.free_group_end()

        self.forward_ct_decode = (self.forward_ct_decode + 1) % (1 << 30)
        should_log = (
            self.attn_tp_rank == 0
            and self.forward_ct_decode % self.server_args.decode_log_interval == 0
        )
        if should_log:
            self.log_decode_stats()
1504

1505
    def add_input_logprob_return_values(
        self,
        i: int,
        req: Req,
        output: LogitsProcessorOutput,
        logprob_pt: int,
        num_input_logprobs: int,
        last_prefill_chunk: bool,  # If True, it means prefill is finished.
    ):
        """Incrementally add input logprobs to `req`.

        Each prefill chunk appends its slice of input logprobs to temporary
        buffers on `req`. On the last chunk the buffers are merged into the
        final `req.input_*` fields and their lengths are checked.

        Args:
            i: The request index in a batch.
            req: The request. Input logprobs inside req are modified as a
                consequence of the API.
            output: Logit processor output that's used to compute input logprobs.
            logprob_pt: Offset of this request's first entry inside the
                batch-wide `output.input_token_logprobs` tuple.
            num_input_logprobs: Number of entries of `output.input_token_logprobs`
                that belong to this request in the current chunk.
            last_prefill_chunk: True if it is the last prefill (when chunked).
                Some of input logprob operation should only happen at the last
                prefill (e.g., computing input token logprobs).
        """
        assert output.input_token_logprobs is not None
        if req.input_token_logprobs is None:
            req.input_token_logprobs = []
        if req.temp_input_top_logprobs_val is None:
            req.temp_input_top_logprobs_val = []
        if req.temp_input_top_logprobs_idx is None:
            req.temp_input_top_logprobs_idx = []
        if req.temp_input_token_ids_logprobs_val is None:
            req.temp_input_token_ids_logprobs_val = []
        if req.temp_input_token_ids_logprobs_idx is None:
            req.temp_input_token_ids_logprobs_idx = []

        if req.input_token_logprobs_val is not None:
            # The input logprob has been already computed. It only happens
            # upon retract.
            if req.top_logprobs_num > 0:
                # Top-k input logprobs are finalized on the same last chunk as
                # the token logprobs, so they must be present as well.
                assert req.input_top_logprobs_val is not None
            return

        # Important for the performance.
        assert isinstance(output.input_token_logprobs, tuple)
        input_token_logprobs: Tuple[float, ...] = output.input_token_logprobs
        input_token_logprobs = input_token_logprobs[
            logprob_pt : logprob_pt + num_input_logprobs
        ]
        req.input_token_logprobs.extend(input_token_logprobs)

        if req.top_logprobs_num > 0:
            req.temp_input_top_logprobs_val.append(output.input_top_logprobs_val[i])
            req.temp_input_top_logprobs_idx.append(output.input_top_logprobs_idx[i])

        if req.token_ids_logprob is not None:
            req.temp_input_token_ids_logprobs_val.append(
                output.input_token_ids_logprobs_val[i]
            )
            req.temp_input_token_ids_logprobs_idx.append(
                output.input_token_ids_logprobs_idx[i]
            )

        if last_prefill_chunk:
            input_token_logprobs = req.input_token_logprobs
            req.input_token_logprobs = None
            assert req.input_token_logprobs_val is None
            assert req.input_token_logprobs_idx is None
            assert req.input_top_logprobs_val is None
            assert req.input_top_logprobs_idx is None

            # Compute input_token_logprobs_val
            # Always pad the first one with None.
            req.input_token_logprobs_val = [None]
            req.input_token_logprobs_val.extend(input_token_logprobs)
            # The last input logprob is for sampling, so just pop it out.
            req.input_token_logprobs_val.pop()

            # Compute input_token_logprobs_idx
            input_token_logprobs_idx = req.origin_input_ids[req.logprob_start_len :]
            # Clip the padded hash values from image tokens.
            # Otherwise, it will lead to detokenization errors.
            # NOTE(review): `x < vocab_size - 1` also maps id `vocab_size - 1`
            # to 0; confirm whether `x < vocab_size` was intended.
            input_token_logprobs_idx = [
                x if x < self.model_config.vocab_size - 1 else 0
                for x in input_token_logprobs_idx
            ]
            req.input_token_logprobs_idx = input_token_logprobs_idx

            if req.top_logprobs_num > 0:
                req.input_top_logprobs_val = [None]
                req.input_top_logprobs_idx = [None]
                # strict=True raises if the val/idx buffers ever diverge in length.
                for val, idx in zip(
                    req.temp_input_top_logprobs_val,
                    req.temp_input_top_logprobs_idx,
                    strict=True,
                ):
                    req.input_top_logprobs_val.extend(val)
                    req.input_top_logprobs_idx.extend(idx)

                # Last token is a sample token.
                req.input_top_logprobs_val.pop()
                req.input_top_logprobs_idx.pop()
                req.temp_input_top_logprobs_idx = None
                req.temp_input_top_logprobs_val = None

            if req.token_ids_logprob is not None:
                req.input_token_ids_logprobs_val = [None]
                req.input_token_ids_logprobs_idx = [None]

                for val, idx in zip(
                    req.temp_input_token_ids_logprobs_val,
                    req.temp_input_token_ids_logprobs_idx,
                    strict=True,
                ):
                    req.input_token_ids_logprobs_val.extend(val)
                    req.input_token_ids_logprobs_idx.extend(idx)

                # Last token is a sample token.
                req.input_token_ids_logprobs_val.pop()
                req.input_token_ids_logprobs_idx.pop()
                req.temp_input_token_ids_logprobs_idx = None
                req.temp_input_token_ids_logprobs_val = None

            if req.return_logprob:
                # Every finalized list must cover exactly the prompt tokens from
                # logprob_start_len onwards.
                relevant_tokens_len = len(req.origin_input_ids) - req.logprob_start_len
                assert len(req.input_token_logprobs_val) == relevant_tokens_len
                assert len(req.input_token_logprobs_idx) == relevant_tokens_len
                if req.top_logprobs_num > 0:
                    assert len(req.input_top_logprobs_val) == relevant_tokens_len
                    assert len(req.input_top_logprobs_idx) == relevant_tokens_len
                if req.token_ids_logprob is not None:
                    assert len(req.input_token_ids_logprobs_val) == relevant_tokens_len
                    assert len(req.input_token_ids_logprobs_idx) == relevant_tokens_len

    def add_logprob_return_values(
        self,
        i: int,
        req: Req,
        pt: int,
        next_token_ids: List[int],
        num_input_logprobs: int,
        output: LogitsProcessorOutput,
    ):
        """Attach logprobs to the return values.

        The steps are:
        - record the output logprob of the token sampled for request `i`;
        - finalize the request's input logprobs (as the last prefill chunk);
        - record the optional top-k and requested-token-id output logprobs.

        Args:
            i: The request index in a batch.
            req: The request whose logprob fields are updated in place.
            pt: Offset of this request's input logprobs in the batch output.
            next_token_ids: Sampled token ids, indexed by request position.
            num_input_logprobs: Number of input logprobs owned by this request.
            output: Logits processor output holding the logprobs.

        Returns:
            `num_input_logprobs`, unchanged.
        """
        # Logprob of the token that was just sampled for this request.
        req.output_token_logprobs_val.append(output.next_token_logprobs[i])
        req.output_token_logprobs_idx.append(next_token_ids[i])

        self.add_input_logprob_return_values(
            i,
            req,
            output,
            logprob_pt=pt,
            num_input_logprobs=num_input_logprobs,
            last_prefill_chunk=True,
        )

        wants_top_k = req.top_logprobs_num > 0
        if wants_top_k:
            req.output_top_logprobs_val.append(output.next_token_top_logprobs_val[i])
            req.output_top_logprobs_idx.append(output.next_token_top_logprobs_idx[i])

        wants_token_ids = req.token_ids_logprob is not None
        if wants_token_ids:
            req.output_token_ids_logprobs_val.append(
                output.next_token_token_ids_logprobs_val[i]
            )
            req.output_token_ids_logprobs_idx.append(
                output.next_token_token_ids_logprobs_idx[i]
            )

        return num_input_logprobs

Lianmin Zheng's avatar
Lianmin Zheng committed
1670
1671
1672
    def stream_output(
        self, reqs: List[Req], return_logprob: bool, skip_req: Optional[Req] = None
    ):
        """Stream the output to detokenizer.

        Args:
            reqs: Requests of the batch that just ran.
            return_logprob: Whether logprob fields should be collected; when
                False they are sent as None.
            skip_req: A request to leave out of this round (the caller passes
                the request whose chunked prefill is not finished yet).
        """
        if self.is_generation:
            self._stream_output_generation(reqs, return_logprob, skip_req)
        else:  # embedding or reward model
            self._stream_output_embedding(reqs)

    def _stream_output_generation(
        self, reqs: List[Req], return_logprob: bool, skip_req: Optional[Req]
    ):
        """Collect streamable generation outputs and send them as one BatchTokenIDOut."""
        # Per-request logprob attributes, in the positional order expected by
        # BatchTokenIDOut (right after spec_verify_ct).
        logprob_attrs = (
            "input_token_logprobs_val",
            "input_token_logprobs_idx",
            "output_token_logprobs_val",
            "output_token_logprobs_idx",
            "input_top_logprobs_val",
            "input_top_logprobs_idx",
            "output_top_logprobs_val",
            "output_top_logprobs_idx",
            "input_token_ids_logprobs_val",
            "input_token_ids_logprobs_idx",
            "output_token_ids_logprobs_val",
            "output_token_ids_logprobs_idx",
        )
        if return_logprob:
            logprob_columns = [[] for _ in logprob_attrs]
        else:
            logprob_columns = [None] * len(logprob_attrs)

        rids = []
        # JSON form of each finish reason, or None while unfinished.
        finished_reasons = []
        decoded_texts = []
        decode_ids_list = []
        read_offsets = []
        output_ids = []
        skip_special_tokens = []
        spaces_between_special_tokens = []
        no_stop_trim = []
        prompt_tokens = []
        completion_tokens = []
        cached_tokens = []
        spec_verify_ct = []
        output_hidden_states = None

        for req in reqs:
            if req is skip_req:
                continue

            # Multimodal partial stream chunks break the detokenizer, so drop aborted requests here.
            if self.model_config.is_multimodal_gen and req.to_abort:
                continue

            should_stream = (
                req.finished()
                # If stream, follow the given stream_interval
                or (req.stream and len(req.output_ids) % self.stream_interval == 0)
                # If not stream, we still want to output some tokens to get the benefit of incremental decoding.
                # TODO(lianmin): this is wrong for speculative decoding because len(req.output_ids) does not
                # always increase one-by-one.
                or (
                    not req.stream
                    and len(req.output_ids) % 50 == 0
                    and not self.model_config.is_multimodal_gen
                )
            )
            if not should_stream:
                continue

            rids.append(req.rid)
            finished_reasons.append(
                req.finished_reason.to_json() if req.finished_reason else None
            )
            decoded_texts.append(req.decoded_text)
            decode_ids, read_offset = req.init_incremental_detokenize()
            decode_ids_list.append(decode_ids)
            read_offsets.append(read_offset)
            if self.skip_tokenizer_init:
                output_ids.append(req.output_ids)
            skip_special_tokens.append(req.sampling_params.skip_special_tokens)
            spaces_between_special_tokens.append(
                req.sampling_params.spaces_between_special_tokens
            )
            no_stop_trim.append(req.sampling_params.no_stop_trim)

            prompt_tokens.append(len(req.origin_input_ids))
            completion_tokens.append(len(req.output_ids))
            cached_tokens.append(req.cached_tokens)

            if not self.spec_algorithm.is_none():
                spec_verify_ct.append(req.spec_verify_ct)

            if return_logprob:
                for column, attr in zip(logprob_columns, logprob_attrs):
                    column.append(getattr(req, attr))

            if req.return_hidden_states:
                if output_hidden_states is None:
                    output_hidden_states = []
                output_hidden_states.append(req.hidden_states)

        # Send to detokenizer
        if not rids:
            return
        if self.model_config.is_multimodal_gen:
            raise NotImplementedError()
        self.send_to_detokenizer.send_pyobj(
            BatchTokenIDOut(
                rids,
                finished_reasons,
                decoded_texts,
                decode_ids_list,
                read_offsets,
                output_ids,
                skip_special_tokens,
                spaces_between_special_tokens,
                no_stop_trim,
                prompt_tokens,
                completion_tokens,
                cached_tokens,
                spec_verify_ct,
                *logprob_columns,
                output_hidden_states,
            )
        )

    def _stream_output_embedding(self, reqs: List[Req]):
        """Send finished embedding/reward results as one BatchEmbeddingOut, if any finished."""
        rids = []
        finished_reasons = []  # JSON form of each finish reason
        embeddings = []
        prompt_tokens = []
        cached_tokens = []
        for req in reqs:
            if req.finished():
                rids.append(req.rid)
                finished_reasons.append(req.finished_reason.to_json())
                embeddings.append(req.embedding)
                prompt_tokens.append(len(req.origin_input_ids))
                cached_tokens.append(req.cached_tokens)

        # Like the generation path, skip the message when nothing finished.
        if not rids:
            return
        self.send_to_detokenizer.send_pyobj(
            BatchEmbeddingOut(
                rids, finished_reasons, embeddings, prompt_tokens, cached_tokens
            )
        )
1837

1838
1839
1840
1841
1842
1843
1844
1845
1846
1847
1848
1849
1850
1851
1852
1853
1854
1855
1856
1857
1858
1859
1860
1861
1862
1863
    def prepare_dp_attn_batch(self, local_batch: ScheduleBatch):
        """Synchronize batch metadata across DP-attention workers.

        Every rank in the TP CPU group shares its local token count. A rank
        without a batch builds an idle batch when any peer reports tokens.
        When CUDA graphs are enabled, the ranks also agree (MIN all-reduce)
        on whether all of them are in decode/idle mode, stored as
        `can_run_dp_cuda_graph`.

        Returns:
            The local (or newly created idle) batch, or None when no rank has
            any tokens.
        """
        # Check if other DP workers have running batches
        if local_batch is None:
            token_count = 0
        else:
            token_count = (
                local_batch.batch_size()
                if local_batch.forward_mode.is_decode()
                else local_batch.extend_num_tokens
            )

        own_count = torch.tensor([token_count], dtype=torch.int64)
        all_counts = torch.empty(self.tp_size, dtype=torch.int64)
        torch.distributed.all_gather_into_tensor(
            all_counts,
            own_count,
            group=self.tp_cpu_group,
        )

        if local_batch is None:
            if all_counts.max().item() <= 0:
                return None
            local_batch = self.get_idle_batch()

        local_batch.global_num_tokens = all_counts.tolist()

        if self.server_args.disable_cuda_graph:
            return local_batch

        # Check forward mode for cuda graph: the MIN-reduced flag stays 1 only
        # if every rank is in decode or idle mode.
        mode_flag = torch.tensor(
            1 if local_batch.forward_mode.is_decode_or_idle() else 0,
            dtype=torch.int32,
        )
        torch.distributed.all_reduce(
            mode_flag,
            op=torch.distributed.ReduceOp.MIN,
            group=self.tp_cpu_group,
        )
        local_batch.can_run_dp_cuda_graph = mode_flag.item() == 1
        return local_batch

    def get_idle_batch(self):
        """Build a request-less ScheduleBatch and put it into idle mode."""
        no_reqs = []
        batch = ScheduleBatch.init_new(
            no_reqs,
            self.req_to_token_pool,
            self.token_to_kv_pool_allocator,
            self.tree_cache,
            self.model_config,
            self.enable_overlap,
            self.spec_algorithm,
            self.server_args.enable_custom_logit_processor,
        )
        batch.prepare_for_idle()
        return batch

1890
1891
1892
1893
1894
1895
1896
1897
1898
1899
    def move_ready_grammar_requests(self):
        """Move requests whose grammar objects are ready from grammar_queue to waiting_queue.

        Each queued request holds a future in `req.grammar`. Futures are
        resolved in queue order, waiting at most 0.05 s each, and the scan stops
        at the first one that is not ready. With more than one (attention) TP
        rank, the ranks MAX-reduce their ready counts. Each rank then blocks on
        the futures it had not resolved yet, so every rank moves the same prefix
        of the queue.
        """
        num_ready_reqs = 0
        for req in self.grammar_queue:
            try:
                req.grammar = req.grammar.result(timeout=0.05)
            except futures.TimeoutError:
                break
            num_ready_reqs += 1

        if self.server_args.enable_dp_attention:
            tp_size = self.attn_tp_size
            tp_group = self.attn_tp_cpu_group
        else:
            tp_size = self.tp_size
            tp_group = self.tp_cpu_group

        if tp_size > 1:
            # Sync across TP ranks to make sure they have the same number of ready requests
            tensor = torch.tensor(num_ready_reqs, dtype=torch.int32)
            torch.distributed.all_reduce(
                tensor, op=torch.distributed.ReduceOp.MAX, group=tp_group
            )
            num_ready_reqs_max = tensor.item()
            # Block on the futures this rank had not resolved yet.
            for i in range(num_ready_reqs, num_ready_reqs_max):
                self.grammar_queue[i].grammar = self.grammar_queue[i].grammar.result()
            num_ready_reqs = num_ready_reqs_max

        self._extend_requests_to_queue(self.grammar_queue[:num_ready_reqs])
        self.grammar_queue = self.grammar_queue[num_ready_reqs:]

1921
1922
1923
1924
1925
1926
1927
1928
1929
1930
1931
1932
1933
1934
1935
1936
1937
1938
1939
1940
1941
1942
1943
1944
1945
1946
1947
1948
1949
1950
1951
    def watchdog_thread(self):
        """A watch dog thread that will try to kill the server itself if one forward batch takes too long.

        Roughly every `watchdog_timeout / 2` seconds it checks whether
        `forward_ct` has advanced while a batch is in flight. If it has not
        advanced for longer than `watchdog_timeout`, the watchdog:
        - logs diagnostics;
        - dumps the scheduler stacks;
        - sends SIGQUIT to the parent process.
        """
        self.watchdog_last_forward_ct = 0
        self.watchdog_last_time = time.time()

        while True:
            current = time.time()
            # Snapshot once: other code (e.g. flush_cache) can reset cur_batch
            # to None, and the diagnostics below need the batch that stalled.
            stuck_batch = self.cur_batch
            if stuck_batch is not None:
                if self.watchdog_last_forward_ct == self.forward_ct:
                    if current > self.watchdog_last_time + self.watchdog_timeout:
                        logger.error(f"Watchdog timeout ({self.watchdog_timeout=})")
                        break
                else:
                    self.watchdog_last_forward_ct = self.forward_ct
                    self.watchdog_last_time = current
            # True division: floor division yields 0 (a busy loop) for
            # timeouts below 2 seconds.
            time.sleep(self.watchdog_timeout / 2)

        # Print batch size and memory pool info to check whether there are de-sync issues.
        # The labels reproduce the original `{expr=}` output text exactly.
        logger.error(
            f"self.cur_batch.batch_size()={stuck_batch.batch_size()!r}, "
            f"self.cur_batch.reqs={stuck_batch.reqs!r}, "
            f"{self.token_to_kv_pool_allocator.available_size()=}, "
            f"{self.tree_cache.evictable_size()=}, "
        )
        # Wait for some time so that the parent process can print the error.
        pyspy_dump_schedulers()
        print(file=sys.stderr, flush=True)
        print(file=sys.stdout, flush=True)
        time.sleep(5)
        self.parent_process.send_signal(signal.SIGQUIT)

1952
1953
1954
    def flush_cache_wrapped(self, recv_req: FlushCacheReq):
        """Handle a FlushCacheReq by delegating to `flush_cache`.

        The request object is not inspected, and the success flag returned by
        `flush_cache` is discarded.
        """
        self.flush_cache()

1955
    def flush_cache(self):
        """Flush the memory pool and cache.

        The flush only happens when there are no queued requests and no
        running requests. It resets:
        - the tree cache and the grammar backend;
        - the request/KV pools, including the draft worker's pools under
          speculative decoding;
        - the decode and speculative counters.

        Returns:
            True if the cache was flushed, False if pending requests prevented it.
        """
        has_pending = len(self.waiting_queue) > 0 or (
            self.running_batch is not None and len(self.running_batch.reqs) > 0
        )
        if has_pending:
            # Use the module logger, consistent with the success path below.
            logger.warning(
                f"Cache not flushed because there are pending requests. "
                f"#queue-req: {len(self.waiting_queue)}, "
                f"#running-req: {0 if self.running_batch is None else len(self.running_batch.reqs)}"
            )
            return False

        self.cur_batch = None
        self.last_batch = None
        self.tree_cache.reset()
        if self.grammar_backend:
            self.grammar_backend.reset()
        self.req_to_token_pool.clear()
        self.token_to_kv_pool_allocator.clear()

        if not self.spec_algorithm.is_none():
            self.draft_worker.model_runner.req_to_token_pool.clear()
            self.draft_worker.model_runner.token_to_kv_pool_allocator.clear()

        self.num_generated_tokens = 0
        self.forward_ct_decode = 0
        self.spec_num_total_accepted_tokens = 0
        self.spec_num_total_forward_ct = 0
        self.cum_spec_accept_length = 0
        self.cum_spec_accept_count = 0
        torch.cuda.empty_cache()
        logger.info("Cache flushed successfully!")
        return True

1990
1991
1992
1993
1994
1995
1996
1997
1998
1999
2000
2001
2002
2003
2004
2005
2006
2007
2008
2009
2010
2011
2012
2013
2014
2015
2016
2017
2018
2019
2020
2021
2022
2023
2024
2025
2026
2027
2028
2029
2030
2031
2032
    def get_internal_state(self, recv_req: GetInternalStateReq):
        """Return a snapshot of the server args plus runtime statistics.

        The snapshot is a shallow copy of ``global_server_args_dict`` with
        ``last_gen_throughput`` added.
        - ``avg_spec_accept_length`` is added only when a speculative
          algorithm is configured and at least one acceptance was counted.
        - ``step_time_dict`` is added only when ``RECORD_STEP_TIME`` is truthy.
        """
        state = {
            **global_server_args_dict,
            "last_gen_throughput": self.last_gen_throughput,
        }

        spec_enabled = not self.spec_algorithm.is_none()
        if spec_enabled and self.cum_spec_accept_count > 0:
            state["avg_spec_accept_length"] = (
                self.cum_spec_accept_length / self.cum_spec_accept_count
            )

        if RECORD_STEP_TIME:
            state["step_time_dict"] = self.step_time_dict

        return GetInternalStateReqOutput(internal_state=state)

    def set_internal_state(self, recv_req: SetInternalStateReq):
        """Update a whitelisted subset of ``global_server_args_dict`` at runtime.

        Only keys in ``args_allow_update`` may be changed. If any key in
        ``recv_req.server_args`` is not allowed, a warning is logged for it and
        nothing is applied.

        On success:
        - the current average speculative acceptance length is logged, when
          available;
        - the cumulative acceptance counters are reset (presumably so later
          averages reflect the new thresholds);
        - the new values are written into ``global_server_args_dict``.

        Returns:
            SetInternalStateReqOutput. Its ``updated`` flag reports whether the
            update was actually applied. It also carries the current global
            server args.
        """
        server_args_dict = recv_req.server_args
        args_allow_update = {
            "speculative_accept_threshold_single",
            "speculative_accept_threshold_acc",
        }
        unsupported = [k for k in server_args_dict if k not in args_allow_update]
        for k in unsupported:
            logger.warning(f"Updating {k} is not supported.")
        if_success = not unsupported

        if if_success:
            if not self.spec_algorithm.is_none() and self.cum_spec_accept_count > 0:
                avg_spec_accept_length = (
                    self.cum_spec_accept_length / self.cum_spec_accept_count
                )
                logger.info(f"{avg_spec_accept_length=}")
            self.cum_spec_accept_length = self.cum_spec_accept_count = 0
            for k, v in server_args_dict.items():
                global_server_args_dict[k] = v
            logger.info(f"Global server args updated! " f"{global_server_args_dict=}")
        # Previously this always reported updated=True, even when the request
        # was rejected because of an unsupported key.
        return SetInternalStateReqOutput(
            updated=if_success,
            server_args=global_server_args_dict,
        )

2033
2034
2035
2036
2037
2038
2039
2040
2041
2042
    def abort_request(self, recv_req: AbortReq):
        """Abort the request whose ``rid`` matches ``recv_req.rid``.

        A queued request is removed from ``waiting_queue`` immediately. A
        running, not-yet-finished request is only flagged with
        ``to_abort = True``; the actual abort happens elsewhere. At most one
        request is affected.
        """
        # Queued request: drop it from the waiting queue and stop.
        for idx, req in enumerate(self.waiting_queue):
            if req.rid == recv_req.rid:
                del self.waiting_queue[idx]
                logger.debug(f"Abort queued request. {req.rid=}")
                return

        # Running request: mark it; skip requests that already finished.
        if not self.running_batch:
            return
        for req in self.running_batch.reqs:
            if req.rid != recv_req.rid or req.finished():
                continue
            logger.debug(f"Abort running request. {req.rid=}")
            req.to_abort = True
            break

2054
2055
2056
    def _pause_engine(self) -> Tuple[List[Req], int]:
        """Hook for pausing the engine; not implemented by this scheduler.

        Raises:
            NotImplementedError: always.
        """
        raise NotImplementedError

Chayenne's avatar
Chayenne committed
2057
2058
2059
    def update_weights_from_disk(self, recv_req: UpdateWeightFromDiskReqInput):
        """In-place update of the weights from disk.

        The load is delegated to the TP worker.
        - On success the caches are flushed, since cached entries were
          produced with the old weights. A failed flush trips an assertion.
        - On failure the worker's message is logged.
        """
        success, message = self.tp_worker.update_weights_from_disk(recv_req)
        if not success:
            logger.error(message)
        else:
            flushed = self.flush_cache()
            assert flushed, "Cache flush failed after updating weights"
        return UpdateWeightFromDiskReqOutput(success, message, 0)
2066

2067
2068
2069
    def init_weights_update_group(self, recv_req: InitWeightsUpdateGroupReqInput):
        """Initialize the online model parameter update group.

        This is a pass-through to the TP worker. The worker's
        ``(success, message)`` pair is wrapped in the reply object.
        """
        result = self.tp_worker.init_weights_update_group(recv_req)
        success, message = result
        return InitWeightsUpdateGroupReqOutput(success, message)
2071
2072

    def update_weights_from_distributed(
        self,
        recv_req: UpdateWeightsFromDistributedReqInput,
    ) -> UpdateWeightsFromDistributedReqOutput:
        """Update the online model parameter.

        The update is delegated to the TP worker.
        - On success the caches are flushed, and a failed flush trips an
          assertion.
        - On failure the worker's message is logged.

        Returns:
            UpdateWeightsFromDistributedReqOutput carrying the worker's
            ``(success, message)``. The annotation previously claimed
            ``Tuple[bool, str]``, which did not match the returned object.
        """
        success, message = self.tp_worker.update_weights_from_distributed(recv_req)
        if success:
            flush_cache_success = self.flush_cache()
            assert flush_cache_success, "Cache flush failed after updating weights"
        else:
            logger.error(message)
        return UpdateWeightsFromDistributedReqOutput(success, message)
2084

2085
2086
2087
2088
2089
    def update_weights_from_tensor(self, recv_req: UpdateWeightsFromTensorReqInput):
        """Update the online model parameter from tensors.

        The caches are flushed only when the worker succeeded and the request
        asked for it (``recv_req.flush_cache``). A failed flush trips an
        assertion. A failed update is logged.
        """
        success, message = self.tp_worker.update_weights_from_tensor(recv_req)
        # TODO extract common code b/t update_weights_from_distributed and update_weights_from_tensor later
        if not success:
            logger.error(message)
        elif recv_req.flush_cache:
            flushed = self.flush_cache()
            assert flushed, "Cache flush failed after updating weights"
        return UpdateWeightsFromTensorReqOutput(success, message)
2096

2097
2098
    def get_weights_by_name(self, recv_req: GetWeightsByNameReqInput):
        """Fetch the requested parameter from the TP worker and wrap it in a reply."""
        return GetWeightsByNameReqOutput(self.tp_worker.get_weights_by_name(recv_req))
2100

2101
    def release_memory_occupation(self, recv_req: ReleaseMemoryOccupationReqInput):
        """Stash the model's buffers, pause the memory saver, then flush caches.

        The stashed buffers are restored by ``resume_memory_occupation``. The
        return value of ``flush_cache`` is not checked here.
        """
        model = self.tp_worker.worker.model_runner.model
        # Clone buffers before pausing so they can be written back on resume.
        self.stashed_model_static_state = _export_static_state(model)
        self.memory_saver_adapter.pause()
        self.flush_cache()
        return ReleaseMemoryOccupationReqOutput()
2108

2109
    def resume_memory_occupation(self, recv_req: ResumeMemoryOccupationReqInput):
        """Resume the memory saver and restore the buffers stashed on release.

        The stash attribute is deleted afterwards. Calling this twice without
        a release in between therefore fails on the missing attribute.
        """
        self.memory_saver_adapter.resume()
        model = self.tp_worker.worker.model_runner.model
        _import_static_state(model, self.stashed_model_static_state)
        del self.stashed_model_static_state
        return ResumeMemoryOccupationReqOutput()

    def profile(self, recv_req: ProfileReq):
        """Dispatch a profiling request to ``start_profile`` or ``stop_profile``.

        Any request type other than ``START_PROFILE`` is treated as a stop.
        """
        if recv_req.type != ProfileReqType.START_PROFILE:
            return self.stop_profile()
        return self.start_profile(
            recv_req.output_dir, recv_req.num_steps, recv_req.activities
        )

    def start_profile(
        self,
        output_dir: Optional[str],
        num_steps: Optional[int],
        activities: Optional[List[str]],
    ) -> Optional["ProfileReqOutput"]:
        """Start torch and/or CUDA-memory profiling.

        Args:
            output_dir: Directory for trace files. Defaults to the
                ``SGLANG_TORCH_PROFILER_DIR`` env var, or ``/tmp``.
            num_steps: If truthy, sets ``profiler_target_forward_ct`` to
                ``forward_ct + num_steps``. The caller is then answered later
                from ``stop_profile`` instead of here.
            activities: Any of ``"CPU"``/``"GPU"`` (torch profiler) and
                ``"MEM"`` (CUDA memory history). Defaults to ``["CPU", "GPU"]``.
                Other names are ignored.

        Returns:
            - A failure ProfileReqOutput if profiling is already in progress.
            - A success ProfileReqOutput when ``num_steps`` is falsy.
            - ``None`` when ``num_steps`` is given, because the reply is
              deferred. The annotation previously said ``None`` only.
        """
        if self.torch_profiler_activities:
            return ProfileReqOutput(
                success=False,
                message="Profiling is already in progress. Call /stop_profile first.",
            )

        if output_dir is None:
            output_dir = os.getenv("SGLANG_TORCH_PROFILER_DIR", "/tmp")
        if activities is None:
            activities = ["CPU", "GPU"]

        self.torch_profiler_output_dir = output_dir
        self.torch_profiler_activities = activities
        logger.info(
            "Profiling starts. Traces will be saved to: %s",
            self.torch_profiler_output_dir,
        )

        activity_map = {
            "CPU": torch.profiler.ProfilerActivity.CPU,
            "GPU": torch.profiler.ProfilerActivity.CUDA,
        }
        torchprof_activities = [
            activity_map[a] for a in activities if a in activity_map
        ]

        # Only create a torch profiler when at least one CPU/GPU activity was requested.
        if torchprof_activities:
            self.torch_profiler = torch.profiler.profile(
                activities=torchprof_activities,
                with_stack=True,
            )
            self.torch_profiler.start()

        if "MEM" in activities:
            torch.cuda.memory._record_memory_history(max_entries=100000)

        if num_steps:
            self.profiler_target_forward_ct = self.forward_ct + num_steps
            # The caller will be notified when reaching profiler_target_forward_ct
        else:
            self.profiler_target_forward_ct = None
            return ProfileReqOutput(success=True, message="Succeeded")
2173
2174

    def stop_profile(self) -> None:
        """Stop any active profiling, write the traces, and reset profiling state.

        Does nothing when profiling is not active. Otherwise:
        - writes a gzipped Chrome trace if the torch profiler was started;
        - writes a CUDA memory snapshot if ``"MEM"`` was requested;
        - puts both files under ``torch_profiler_output_dir``.

        If ``profiler_target_forward_ct`` is set (profiling was started with
        ``num_steps``), the deferred success reply is sent to the tokenizer.
        """
        if self.torch_profiler_activities is None:
            return

        logger.info("Stop profiling...")
        # Share one timestamp between all files written by this run.
        timestamp = str(time.time())
        if self.torch_profiler is not None:
            self.torch_profiler.stop()
            self.torch_profiler.export_chrome_trace(
                os.path.join(
                    self.torch_profiler_output_dir,
                    timestamp + f"-TP-{self.tp_rank}" + ".trace.json.gz",
                )
            )

        if "MEM" in self.torch_profiler_activities:
            # Write to the directory set by start_profile. This previously read
            # `torch_profiler_trace_dir`, which start_profile never assigns.
            memory_profile_path = os.path.join(
                self.torch_profiler_output_dir,
                timestamp + f"-TP-{self.tp_rank}-memory" + ".pickle",
            )
            torch.cuda.memory._dump_snapshot(memory_profile_path)
            torch.cuda.memory._record_memory_history(enabled=None)

        logger.info(
            "Profiling done. Traces are saved to: %s",
            self.torch_profiler_output_dir,
        )
        self.torch_profiler = None
        self.torch_profiler_output_dir = None
        self.torch_profiler_activities = None

        if self.profiler_target_forward_ct:
            self.send_to_tokenizer.send_pyobj(
                ProfileReqOutput(success=True, message="Succeeded.")
            )
2208

2209
    def open_session(self, recv_req: OpenSessionReqInput):
        """Register a new Session under ``recv_req.session_id``.

        The reply reports failure when the id is already registered or is
        None, checked in that order. Otherwise the session is created with
        ``recv_req.capacity_of_str_len`` and success is reported.
        """
        session_id = recv_req.session_id
        if session_id in self.sessions:
            logger.warning(f"session id {session_id} already exist, cannot open.")
            return OpenSessionReqOutput(session_id, False)
        if session_id is None:
            logger.warning("session id is None, cannot open.")
            return OpenSessionReqOutput(session_id, False)

        self.sessions[session_id] = Session(recv_req.capacity_of_str_len, session_id)
        return OpenSessionReqOutput(session_id, True)
2223
2224
2225
2226
2227
2228
2229
2230
2231

    def close_session(self, recv_req: CloseSessionReqInput):
        """Remove the session named by ``recv_req.session_id``; warn if unknown."""
        session_id = recv_req.session_id
        if session_id in self.sessions:
            del self.sessions[session_id]
            return
        logger.warning(f"session id {session_id} does not exist, cannot delete.")

2232

2233
2234
2235
2236
def is_health_check_generate_req(recv_req):
    """Return True if ``recv_req`` is a health-check generate request.

    Health-check requests are recognized by an ``rid`` starting with
    ``"HEALTH_CHECK"``. These are not health checks:
    - objects without an ``rid`` attribute;
    - objects whose ``rid`` is not a string (e.g. ``None`` or a list of ids).
    The previous version raised AttributeError on a non-string ``rid``.
    """
    rid = getattr(recv_req, "rid", None)
    return isinstance(rid, str) and rid.startswith("HEALTH_CHECK")


2237
2238
2239
2240
2241
2242
2243
2244
2245
2246
2247
2248
2249
2250
def _export_static_state(model):
    return dict(
        buffers=[
            (name, buffer.detach().clone()) for name, buffer in model.named_buffers()
        ]
    )


def _import_static_state(model, static_params):
    self_named_buffers = dict(model.named_buffers())
    for name, tensor in static_params["buffers"]:
        self_named_buffers[name][...] = tensor


2251
2252
2253
2254
2255
def run_scheduler_process(
    server_args: ServerArgs,
    port_args: PortArgs,
    gpu_id: int,
    tp_rank: int,
    dp_rank: Optional[int],
    pipe_writer,
):
    """Entry point of a scheduler subprocess.

    Steps, in order:
    1. Resolve the DP rank, with ``SGLANG_DP_RANK`` as a fallback.
    2. Configure the process title, logging and optional CPU affinity.
    3. Construct the Scheduler and report readiness plus token limits to the
       parent through ``pipe_writer``.
    4. Run the event loop.

    Any exception is logged and the parent process is sent SIGQUIT.
    """
    # Config the process
    # kill_itself_when_parent_died()  # This is disabled because it does not work for `--dp 2`
    faulthandler.enable()
    parent_process = psutil.Process().parent()

    # [For Router] if env var "SGLANG_DP_RANK" exist, set dp_rank to the value of the env var
    if dp_rank is None and "SGLANG_DP_RANK" in os.environ:
        dp_rank = int(os.environ["SGLANG_DP_RANK"])

    # Set the title only after dp_rank is final. Previously it was set first,
    # so a rank supplied via SGLANG_DP_RANK never appeared in the title.
    setproctitle.setproctitle(f"sglang::scheduler_{dp_rank}")

    # Configure the logger
    prefix = f" TP{tp_rank}" if dp_rank is None else f" DP{dp_rank} TP{tp_rank}"
    configure_logger(server_args, prefix=prefix)
    suppress_other_loggers()

    # Set cpu affinity to this gpu process
    if get_bool_env_var("SGLANG_SET_CPU_AFFINITY"):
        set_gpu_proc_affinity(server_args.tp_size, server_args.nnodes, gpu_id)

    # Create a scheduler and run the event loop
    try:
        scheduler = Scheduler(server_args, port_args, gpu_id, tp_rank, dp_rank)
        pipe_writer.send(
            {
                "status": "ready",
                "max_total_num_tokens": scheduler.max_total_num_tokens,
                "max_req_input_len": scheduler.max_req_input_len,
            }
        )
        if scheduler.enable_overlap:
            scheduler.event_loop_overlap()
        else:
            scheduler.event_loop_normal()
    except Exception:
        error_traceback = get_exception_traceback()
        logger.error(f"Scheduler hit an exception: {error_traceback}")
        parent_process.send_signal(signal.SIGQUIT)