scheduler.py 89.6 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
13
# Copyright 2023-2024 SGLang Team
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
14
15
"""A scheduler that manages a tensor parallel GPU worker."""

16
import faulthandler
17
import logging
18
import os
19
import signal
20
import sys
Lianmin Zheng's avatar
Lianmin Zheng committed
21
import threading
22
23
import time
import warnings
24
from collections import defaultdict, deque
Lianmin Zheng's avatar
Lianmin Zheng committed
25
from concurrent import futures
26
from dataclasses import dataclass
27
from http import HTTPStatus
28
from types import SimpleNamespace
29
from typing import Dict, List, Optional, Tuple, Union
30

31
import psutil
32
import setproctitle
33
import torch
34
35
import zmq

36
from sglang.global_config import global_config
Lianmin Zheng's avatar
Lianmin Zheng committed
37
from sglang.srt.configs.model_config import ModelConfig
38
from sglang.srt.constrained.base_grammar_backend import create_grammar_backend
39
from sglang.srt.hf_transformers_utils import get_processor, get_tokenizer
40
from sglang.srt.layers.dp_attention import compute_dp_attention_world_info
41
42
43
44
45
from sglang.srt.layers.logits_processor import LogitsProcessorOutput
from sglang.srt.managers.io_struct import (
    AbortReq,
    BatchEmbeddingOut,
    BatchTokenIDOut,
46
    CloseSessionReqInput,
47
    FlushCacheReq,
48
49
    GetInternalStateReq,
    GetInternalStateReqOutput,
50
51
    GetWeightsByNameReqInput,
    GetWeightsByNameReqOutput,
52
    HealthCheckOutput,
53
54
    InitWeightsUpdateGroupReqInput,
    InitWeightsUpdateGroupReqOutput,
55
56
    OpenSessionReqInput,
    OpenSessionReqOutput,
57
    ProfileReq,
58
59
    ProfileReqOutput,
    ProfileReqType,
60
61
62
63
    ReleaseMemoryOccupationReqInput,
    ReleaseMemoryOccupationReqOutput,
    ResumeMemoryOccupationReqInput,
    ResumeMemoryOccupationReqOutput,
64
65
    SetInternalStateReq,
    SetInternalStateReqOutput,
66
67
    TokenizedEmbeddingReqInput,
    TokenizedGenerateReqInput,
Chayenne's avatar
Chayenne committed
68
69
    UpdateWeightFromDiskReqInput,
    UpdateWeightFromDiskReqOutput,
70
71
    UpdateWeightsFromDistributedReqInput,
    UpdateWeightsFromDistributedReqOutput,
72
73
    UpdateWeightsFromTensorReqInput,
    UpdateWeightsFromTensorReqOutput,
74
75
76
77
78
79
80
)
from sglang.srt.managers.schedule_batch import (
    FINISH_ABORT,
    BaseFinishReason,
    ImageInputs,
    Req,
    ScheduleBatch,
81
    global_server_args_dict,
82
)
83
84
85
86
87
from sglang.srt.managers.schedule_policy import (
    AddReqResult,
    PrefillAdder,
    SchedulePolicy,
)
88
from sglang.srt.managers.session_controller import Session
89
from sglang.srt.managers.tp_worker import TpModelWorker
90
from sglang.srt.managers.tp_worker_overlap_thread import TpModelWorkerClient
91
from sglang.srt.managers.utils import validate_input_length
92
from sglang.srt.mem_cache.chunk_cache import ChunkCache
93
from sglang.srt.mem_cache.hiradix_cache import HiRadixCache
94
from sglang.srt.mem_cache.radix_cache import RadixCache
95
from sglang.srt.metrics.collector import SchedulerMetricsCollector, SchedulerStats
96
from sglang.srt.model_executor.forward_batch_info import ForwardMode
97
from sglang.srt.server_args import PortArgs, ServerArgs
98
from sglang.srt.speculative.spec_info import SpeculativeAlgorithm
99
from sglang.srt.torch_memory_saver_adapter import TorchMemorySaverAdapter
100
101
102
from sglang.srt.utils import (
    broadcast_pyobj,
    configure_logger,
103
    crash_on_warnings,
104
    get_bool_env_var,
105
    get_zmq_socket,
106
    pyspy_dump_schedulers,
107
    set_gpu_proc_affinity,
108
109
110
    set_random_seed,
    suppress_other_loggers,
)
111
from sglang.utils import TypeBasedDispatcher, get_exception_traceback
112
113
114

logger = logging.getLogger(__name__)

# Boolean switches read once from the environment when this module is imported.
# SGLANG_TEST_RETRACT: test retract decode for debugging purposes.
TEST_RETRACT = get_bool_env_var("SGLANG_TEST_RETRACT")
# SGLANG_RECORD_STEP_TIME: presumably enables per-step timing records
# (see step_time_dict in Scheduler.__init__) -- confirm at the use sites.
RECORD_STEP_TIME = get_bool_env_var("SGLANG_RECORD_STEP_TIME")
118

119

120
121
122
123
@dataclass
class GenerationBatchResult:
    """Container for the outputs of running a generation batch.

    Field meanings below are inferred from their names; confirm against the
    code that produces and consumes this object.
    """

    # Logits-processor output for the batch.
    logits_output: LogitsProcessorOutput
    # Token ids selected for the next step (presumably one per request).
    next_token_ids: List[int]
    # Per-request extend (prefill) input lengths.
    extend_input_len_per_req: List[int]
    # Per-request start positions for logprob computation in the extend part.
    extend_logprob_start_len_per_req: List[int]
    # Identifier of the batch that produced this result.
    bid: int


@dataclass
class EmbeddingBatchResult:
    """Container for the outputs of running an embedding batch."""

    # Embedding tensor produced for the batch.
    embeddings: torch.Tensor
    # Identifier of the batch that produced this result.
    bid: int


135
136
137
138
139
140
141
142
143
class Scheduler:
    """A scheduler that manages a tensor parallel GPU worker."""

    def __init__(
        self,
        server_args: ServerArgs,
        port_args: PortArgs,
        gpu_id: int,
        tp_rank: int,
        dp_rank: Optional[int],
    ):
        """Initialize the scheduler for one tensor-parallel rank.

        Sets up IPC sockets (real ones only on attention-TP rank 0), the model
        config and tokenizer/processor, the TP worker (plus an EAGLE draft
        worker when that speculative algorithm is selected), the KV/radix
        cache, scheduling state, the watchdog thread, profiler/metrics state,
        and the request dispatcher.

        Args:
            server_args: Global server configuration.
            port_args: IPC endpoint names and the NCCL port.
            gpu_id: GPU index forwarded to the model workers.
            tp_rank: Tensor-parallel rank of this scheduler.
            dp_rank: Data-parallel rank forwarded to the workers. Note that
                ``self.dp_rank`` is recomputed by
                ``compute_dp_attention_world_info`` and may differ from it.
        """
        # Parse args
        self.server_args = server_args
        self.tp_rank = tp_rank
        self.tp_size = server_args.tp_size
        self.schedule_policy = server_args.schedule_policy
        self.lora_paths = server_args.lora_paths
        self.max_loras_per_batch = server_args.max_loras_per_batch
        self.enable_overlap = not server_args.disable_overlap_schedule
        self.skip_tokenizer_init = server_args.skip_tokenizer_init
        self.enable_metrics = server_args.enable_metrics
        self.stream_interval = server_args.stream_interval
        self.spec_algorithm = SpeculativeAlgorithm.from_string(
            server_args.speculative_algorithm
        )
        self.gpu_id = gpu_id
        self.enable_hierarchical_cache = server_args.enable_hierarchical_cache
        # Extra decode-time memory multiplier when speculative decoding is on.
        self.decode_mem_cache_buf_multiplier = (
            (
                self.server_args.speculative_num_draft_tokens
                + (
                    self.server_args.speculative_eagle_topk
                    * self.server_args.speculative_num_draft_tokens
                )
            )
            if not self.spec_algorithm.is_none()
            else 1
        )

        # Distributed rank info
        self.dp_size = server_args.dp_size
        self.attn_tp_rank, self.attn_tp_size, self.dp_rank = (
            compute_dp_attention_world_info(
                server_args.enable_dp_attention,
                self.tp_rank,
                self.tp_size,
                self.dp_size,
            )
        )

        # Init inter-process communication
        context = zmq.Context(2)
        if self.attn_tp_rank == 0:
            self.recv_from_tokenizer = get_zmq_socket(
                context, zmq.PULL, port_args.scheduler_input_ipc_name, False
            )
            self.send_to_tokenizer = get_zmq_socket(
                context, zmq.PUSH, port_args.tokenizer_ipc_name, False
            )

            if server_args.skip_tokenizer_init:
                # Directly send to the TokenizerManager
                self.send_to_detokenizer = get_zmq_socket(
                    context, zmq.PUSH, port_args.tokenizer_ipc_name, False
                )
            else:
                # Send to the DetokenizerManager
                self.send_to_detokenizer = get_zmq_socket(
                    context, zmq.PUSH, port_args.detokenizer_ipc_name, False
                )
        else:
            # Non-leader ranks get requests via broadcast and never send
            # outputs, so the send endpoints are no-op stand-ins.
            self.recv_from_tokenizer = None
            self.send_to_tokenizer = SimpleNamespace(send_pyobj=lambda x: None)
            self.send_to_detokenizer = SimpleNamespace(send_pyobj=lambda x: None)

        # Init tokenizer
        self.model_config = ModelConfig(
            server_args.model_path,
            trust_remote_code=server_args.trust_remote_code,
            revision=server_args.revision,
            context_length=server_args.context_length,
            model_override_args=server_args.json_model_override_args,
            is_embedding=server_args.is_embedding,
            dtype=server_args.dtype,
            quantization=server_args.quantization,
        )
        self.is_generation = self.model_config.is_generation

        if server_args.skip_tokenizer_init:
            self.tokenizer = self.processor = None
        else:
            if self.model_config.is_multimodal:
                self.processor = get_processor(
                    server_args.tokenizer_path,
                    tokenizer_mode=server_args.tokenizer_mode,
                    trust_remote_code=server_args.trust_remote_code,
                    revision=server_args.revision,
                )
                self.tokenizer = self.processor.tokenizer
            else:
                self.tokenizer = get_tokenizer(
                    server_args.tokenizer_path,
                    tokenizer_mode=server_args.tokenizer_mode,
                    trust_remote_code=server_args.trust_remote_code,
                    revision=server_args.revision,
                )

        # Check whether overlap can be enabled
        if not self.is_generation:
            self.enable_overlap = False
            logger.info("Overlap scheduler is disabled for embedding models.")

        if self.model_config.is_multimodal:
            self.enable_overlap = False
            logger.info("Overlap scheduler is disabled for multimodal models.")

        # Launch a tensor parallel worker
        if self.enable_overlap:
            TpWorkerClass = TpModelWorkerClient
        else:
            TpWorkerClass = TpModelWorker

        self.tp_worker = TpWorkerClass(
            server_args=server_args,
            gpu_id=gpu_id,
            tp_rank=tp_rank,
            dp_rank=dp_rank,
            nccl_port=port_args.nccl_port,
        )

        # Launch a draft worker for speculative decoding
        if self.spec_algorithm.is_eagle():
            from sglang.srt.speculative.eagle_worker import EAGLEWorker

            self.draft_worker = EAGLEWorker(
                gpu_id=gpu_id,
                tp_rank=tp_rank,
                server_args=server_args,
                nccl_port=port_args.nccl_port,
                target_worker=self.tp_worker,
                dp_rank=dp_rank,
            )
            self.prefill_only_one_req = True
        else:
            self.draft_worker = None
            self.prefill_only_one_req = False

        # Get token and memory info from the model worker
        (
            self.max_total_num_tokens,
            self.max_prefill_tokens,
            self.max_running_requests,
            self.max_req_len,
            self.max_req_input_len,
            self.random_seed,
            self.device,
            worker_global_server_args_dict,
            _,
            _,
            _,
        ) = self.tp_worker.get_worker_info()
        self.tp_cpu_group = self.tp_worker.get_tp_cpu_group()
        self.attn_tp_cpu_group = self.tp_worker.get_attention_tp_cpu_group()
        self.pad_input_ids_func = self.tp_worker.get_pad_input_ids_func()
        global_server_args_dict.update(worker_global_server_args_dict)
        set_random_seed(self.random_seed)

        # Print debug info
        logger.info(
            f"max_total_num_tokens={self.max_total_num_tokens}, "
            f"chunked_prefill_size={server_args.chunked_prefill_size}, "
            f"max_prefill_tokens={self.max_prefill_tokens}, "
            f"max_running_requests={self.max_running_requests}, "
            f"context_len={self.model_config.context_len}"
        )

        # Init memory pool and cache
        self.req_to_token_pool, self.token_to_kv_pool_allocator = (
            self.tp_worker.get_memory_pool()
        )

        if (
            server_args.chunked_prefill_size is not None
            and server_args.disable_radix_cache
        ):
            self.tree_cache = ChunkCache(
                req_to_token_pool=self.req_to_token_pool,
                token_to_kv_pool_allocator=self.token_to_kv_pool_allocator,
            )
        else:
            if self.enable_hierarchical_cache:
                self.tree_cache = HiRadixCache(
                    req_to_token_pool=self.req_to_token_pool,
                    token_to_kv_pool_allocator=self.token_to_kv_pool_allocator,
                )
            else:
                self.tree_cache = RadixCache(
                    req_to_token_pool=self.req_to_token_pool,
                    token_to_kv_pool_allocator=self.token_to_kv_pool_allocator,
                    disable=server_args.disable_radix_cache,
                )

        self.policy = SchedulePolicy(self.schedule_policy, self.tree_cache)

        # Init running status
        self.waiting_queue: List[Req] = []
        self.staging_reqs = {}
        # The running decoding batch for continuous batching
        self.running_batch: Optional[ScheduleBatch] = None
        # The current forward batch
        self.cur_batch: Optional[ScheduleBatch] = None
        # The forward batch from the previous event-loop iteration
        self.last_batch: Optional[ScheduleBatch] = None
        self.forward_ct = 0
        self.forward_ct_decode = 0
        self.num_generated_tokens = 0
        self.spec_num_total_accepted_tokens = 0
        self.spec_num_total_forward_ct = 0
        self.cum_spec_accept_length = 0
        self.cum_spec_accept_count = 0
        self.last_decode_stats_tic = time.time()
        self.return_health_check_ct = 0
        self.current_stream = torch.get_device_module(self.device).current_stream()
        if self.device == "cpu":
            self.current_stream.synchronize = lambda: None  # No-op for CPU

        # For metrics only.
        # The largest prefill length of a single request
        self._largest_prefill_len: int = 0
        # The largest context length (prefill + generation) of a single request
        self._largest_prefill_decode_len: int = 0
        self.last_gen_throughput: float = 0.0
        self.step_time_dict = defaultdict(list)  # Dict[batch size -> step time]

        # Session info
        self.sessions: Dict[str, Session] = {}

        # Init chunked prefill
        self.chunked_prefill_size = server_args.chunked_prefill_size
        # A non-positive size (e.g. -1) means disabled. None is also accepted
        # (the cache selection above already tolerates it) and must not be
        # compared with `<=`, which would raise TypeError.
        if self.chunked_prefill_size is not None and self.chunked_prefill_size <= 0:
            self.chunked_prefill_size = None
        self.chunked_req = None
        self.is_mixed_chunk = (
            self.chunked_prefill_size is not None and server_args.enable_mixed_chunk
        )

        # Init the grammar backend for constrained generation
        self.grammar_queue: List[Req] = []
        if not server_args.skip_tokenizer_init:
            self.grammar_backend = create_grammar_backend(
                server_args, self.tokenizer, self.model_config.vocab_size
            )
        else:
            self.grammar_backend = None

        # Init new token estimation
        assert (
            server_args.schedule_conservativeness >= 0
        ), "Invalid schedule_conservativeness"

        self.init_new_token_ratio = min(
            global_config.default_init_new_token_ratio
            * server_args.schedule_conservativeness,
            1.0,
        )
        self.min_new_token_ratio = min(
            self.init_new_token_ratio
            * global_config.default_min_new_token_ratio_factor,
            1.0,
        )
        # Per-step decrement that moves the ratio from init to min over
        # `default_new_token_ratio_decay_steps` steps.
        self.new_token_ratio_decay = (
            self.init_new_token_ratio - self.min_new_token_ratio
        ) / global_config.default_new_token_ratio_decay_steps
        self.new_token_ratio = self.init_new_token_ratio

        # Tell whether the current running batch is full so that we can skip
        # the check of whether to prefill new requests.
        # This is an optimization to reduce the overhead of the prefill check.
        self.batch_is_full = False

        # Init watchdog thread.
        # parent_process is resolved before the thread starts so that every
        # attribute the watchdog may touch already exists when it runs.
        self.watchdog_timeout = server_args.watchdog_timeout
        self.parent_process = psutil.Process().parent()
        t = threading.Thread(target=self.watchdog_thread, daemon=True)
        t.start()

        # Init memory saver
        self.memory_saver_adapter = TorchMemorySaverAdapter.create(
            enable=server_args.enable_memory_saver
        )

        # Init profiler
        self.torch_profiler = None
        self.torch_profiler_output_dir: Optional[str] = None
        self.torch_profiler_activities: Optional[List[str]] = None
        self.profiler_target_forward_ct: Optional[int] = None

        # Init metrics stats
        self.stats = SchedulerStats()
        if self.enable_metrics:
            self.metrics_collector = SchedulerMetricsCollector(
                labels={
                    "model_name": self.server_args.served_model_name,
                    # TODO: Add lora name/path in the future,
                },
            )

        # Init request dispatcher: maps each request type to its handler.
        self._request_dispatcher = TypeBasedDispatcher(
            [
                (TokenizedGenerateReqInput, self.handle_generate_request),
                (TokenizedEmbeddingReqInput, self.handle_embedding_request),
                (FlushCacheReq, self.flush_cache_wrapped),
                (AbortReq, self.abort_request),
                (OpenSessionReqInput, self.open_session),
                (CloseSessionReqInput, self.close_session),
                (UpdateWeightFromDiskReqInput, self.update_weights_from_disk),
                (InitWeightsUpdateGroupReqInput, self.init_weights_update_group),
                (
                    UpdateWeightsFromDistributedReqInput,
                    self.update_weights_from_distributed,
                ),
                (UpdateWeightsFromTensorReqInput, self.update_weights_from_tensor),
                (GetWeightsByNameReqInput, self.get_weights_by_name),
                (ReleaseMemoryOccupationReqInput, self.release_memory_occupation),
                (ResumeMemoryOccupationReqInput, self.resume_memory_occupation),
                (ProfileReq, self.profile),
                (GetInternalStateReq, self.get_internal_state),
            ]
        )

Lianmin Zheng's avatar
Lianmin Zheng committed
466
    def watchdog_thread(self):
        """A watch dog thread that will try to kill the server itself if one forward batch takes too long.

        Polls every ``watchdog_timeout / 2`` seconds. While a batch is in
        flight (``self.cur_batch`` is not None), it checks whether
        ``self.forward_ct`` has advanced since the last observed progress.
        If the counter stays unchanged for more than ``watchdog_timeout``
        seconds, the method does the following, then returns:

        * logs batch and memory-pool diagnostics;
        * calls ``pyspy_dump_schedulers()``;
        * flushes stdio and waits 5 s;
        * sends SIGQUIT to ``self.parent_process``.
        """
        self.watchdog_last_forward_ct = 0
        self.watchdog_last_time = time.time()

        while True:
            current = time.time()
            if self.cur_batch is None:
                # An idle scheduler cannot be stuck in a forward pass. Restart
                # the clock so the idle period is not charged to the next
                # batch (otherwise a poll landing between `cur_batch` being set
                # and `forward_ct` being bumped could fire a false timeout).
                self.watchdog_last_time = current
            elif self.watchdog_last_forward_ct != self.forward_ct:
                # Progress was made since the previous observation.
                self.watchdog_last_forward_ct = self.forward_ct
                self.watchdog_last_time = current
            elif current > self.watchdog_last_time + self.watchdog_timeout:
                logger.error(f"Watchdog timeout ({self.watchdog_timeout=})")
                break
            # True division: `//` would yield 0 for timeouts below 2 s and turn
            # this loop into a busy spin.
            time.sleep(self.watchdog_timeout / 2)

        # Print batch size and memory pool info to check whether there are de-sync issues.
        # The pool is stored as `token_to_kv_pool_allocator` (see __init__);
        # there is no `token_to_kv_pool` attribute.
        logger.error(
            f"{self.cur_batch.batch_size()=}, "
            f"{self.cur_batch.reqs=}, "
            f"{self.token_to_kv_pool_allocator.available_size()=}, "
            f"{self.tree_cache.evictable_size()=}, "
        )
        pyspy_dump_schedulers()
        print(file=sys.stderr, flush=True)
        print(file=sys.stdout, flush=True)
        # Wait for some time so that the parent process can print the error.
        time.sleep(5)
        self.parent_process.send_signal(signal.SIGQUIT)
Lianmin Zheng's avatar
Lianmin Zheng committed
496

497
    @torch.no_grad()
498
    def event_loop_normal(self):
499
        """A normal scheduler loop."""
500
        while True:
Lianmin Zheng's avatar
Lianmin Zheng committed
501
502
            recv_reqs = self.recv_requests()
            self.process_input_requests(recv_reqs)
503

504
            batch = self.get_next_batch_to_run()
Lianmin Zheng's avatar
Lianmin Zheng committed
505
            self.cur_batch = batch
506
507
508
509

            if batch:
                result = self.run_batch(batch)
                self.process_batch_result(batch, result)
Lianmin Zheng's avatar
Lianmin Zheng committed
510
            else:
511
                # When the server is idle, so self-check and re-init some states
Lianmin Zheng's avatar
Lianmin Zheng committed
512
                self.check_memory()
513
                self.new_token_ratio = self.init_new_token_ratio
514
515

            self.last_batch = batch
516

517
    @torch.no_grad()
Lianmin Zheng's avatar
Lianmin Zheng committed
518
    def event_loop_overlap(self):
519
        """A scheduler loop that overlaps the CPU processing and GPU computation."""
520
        self.result_queue = deque()
Lianmin Zheng's avatar
Lianmin Zheng committed
521
522
523
524
525
526
527

        while True:
            recv_reqs = self.recv_requests()
            self.process_input_requests(recv_reqs)

            batch = self.get_next_batch_to_run()
            self.cur_batch = batch
528

Lianmin Zheng's avatar
Lianmin Zheng committed
529
530
            if batch:
                result = self.run_batch(batch)
531
                self.result_queue.append((batch.copy(), result))
Lianmin Zheng's avatar
Lianmin Zheng committed
532

533
                if self.last_batch is None:
534
                    # Create a dummy first batch to start the pipeline for overlap schedule.
535
536
537
538
539
540
541
542
                    # It is now used for triggering the sampling_info_done event.
                    tmp_batch = ScheduleBatch(
                        reqs=None,
                        forward_mode=ForwardMode.DUMMY_FIRST,
                        next_batch_sampling_info=self.tp_worker.cur_sampling_info,
                    )
                    self.process_batch_result(tmp_batch, None)

Lianmin Zheng's avatar
Lianmin Zheng committed
543
            if self.last_batch:
544
                # Process the results of the last batch
545
                tmp_batch, tmp_result = self.result_queue.popleft()
546
547
548
                tmp_batch.next_batch_sampling_info = (
                    self.tp_worker.cur_sampling_info if batch else None
                )
Lianmin Zheng's avatar
Lianmin Zheng committed
549
550
                self.process_batch_result(tmp_batch, tmp_result)
            elif batch is None:
551
                # When the server is idle, so self-check and re-init some states
Lianmin Zheng's avatar
Lianmin Zheng committed
552
                self.check_memory()
553
                self.new_token_ratio = self.init_new_token_ratio
Lianmin Zheng's avatar
Lianmin Zheng committed
554
555
556

            self.last_batch = batch

557
558
    def recv_requests(self) -> List[Req]:
        """Receive results at tp_rank = 0 and broadcast it to all other TP ranks."""
559
        if self.attn_tp_rank == 0:
Lianmin Zheng's avatar
Lianmin Zheng committed
560
561
            recv_reqs = []

562
563
564
565
566
            while True:
                try:
                    recv_req = self.recv_from_tokenizer.recv_pyobj(zmq.NOBLOCK)
                except zmq.ZMQError:
                    break
Lianmin Zheng's avatar
Lianmin Zheng committed
567
                recv_reqs.append(recv_req)
Lianmin Zheng's avatar
Lianmin Zheng committed
568
569
        else:
            recv_reqs = None
570

571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
        if self.server_args.enable_dp_attention:
            if self.attn_tp_rank == 0:
                work_reqs = [
                    req
                    for req in recv_reqs
                    if isinstance(
                        req, (TokenizedGenerateReqInput, TokenizedEmbeddingReqInput)
                    )
                ]
                control_reqs = [
                    req
                    for req in recv_reqs
                    if not isinstance(
                        req, (TokenizedGenerateReqInput, TokenizedEmbeddingReqInput)
                    )
                ]
            else:
                work_reqs = None
                control_reqs = None

            if self.attn_tp_size != 1:
                attn_tp_rank_0 = self.dp_rank * self.attn_tp_size
                work_reqs = broadcast_pyobj(
                    work_reqs,
                    self.attn_tp_rank,
                    self.attn_tp_cpu_group,
                    src=attn_tp_rank_0,
                )
            if self.tp_size != 1:
                control_reqs = broadcast_pyobj(
                    control_reqs, self.tp_rank, self.tp_cpu_group
                )
            recv_reqs = work_reqs + control_reqs
        elif self.tp_size != 1:
605
            recv_reqs = broadcast_pyobj(recv_reqs, self.tp_rank, self.tp_cpu_group)
606
607
        return recv_reqs

Lianmin Zheng's avatar
Lianmin Zheng committed
608
    def process_input_requests(self, recv_reqs: List):
609
        for recv_req in recv_reqs:
610
611
612
613
614
615
616
            # If it is a health check generation request and there are running requests, ignore it.
            if is_health_check_generate_req(recv_req) and (
                self.chunked_req is not None or self.running_batch is not None
            ):
                self.return_health_check_ct += 1
                continue

617
            output = self._request_dispatcher(recv_req)
618
619
            if output is not None:
                self.send_to_tokenizer.send_pyobj(output)
620
621
622
623
624

    def handle_generate_request(
        self,
        recv_req: TokenizedGenerateReqInput,
    ):
625
        # Create a new request
626
627
628
629
630
        if (
            recv_req.session_params is None
            or recv_req.session_params.id is None
            or recv_req.session_params.id not in self.sessions
        ):
Rin Intachuen's avatar
Rin Intachuen committed
631
632
633
634
635
636
            if recv_req.input_embeds is not None:
                # Generate fake input_ids based on the length of input_embeds
                seq_length = len(recv_req.input_embeds)
                fake_input_ids = [1] * seq_length
                recv_req.input_ids = fake_input_ids

637
638
639
640
641
642
643
644
645
646
647
648
649
            # Handle custom logit processor passed to the request
            custom_logit_processor = recv_req.custom_logit_processor
            if (
                not self.server_args.enable_custom_logit_processor
                and custom_logit_processor is not None
            ):
                logger.warning(
                    "The SGLang server is not configured to enable custom logit processor."
                    "The custom logit processor passed in will be ignored."
                    "Please set --enable-custom-logits-processor to enable this feature."
                )
                custom_logit_processor = None

650
651
652
653
654
            req = Req(
                recv_req.rid,
                recv_req.input_text,
                recv_req.input_ids,
                recv_req.sampling_params,
Lianmin Zheng's avatar
Lianmin Zheng committed
655
656
                return_logprob=recv_req.return_logprob,
                top_logprobs_num=recv_req.top_logprobs_num,
657
                token_ids_logprob=recv_req.token_ids_logprob,
Lianmin Zheng's avatar
Lianmin Zheng committed
658
                stream=recv_req.stream,
659
                lora_path=recv_req.lora_path,
Rin Intachuen's avatar
Rin Intachuen committed
660
                input_embeds=recv_req.input_embeds,
661
                custom_logit_processor=custom_logit_processor,
662
                return_hidden_states=recv_req.return_hidden_states,
663
                eos_token_ids=self.model_config.hf_eos_token_id,
664
665
            )
            req.tokenizer = self.tokenizer
Lianmin Zheng's avatar
Lianmin Zheng committed
666

667
668
669
670
            if (
                recv_req.session_params is not None
                and recv_req.session_params.id is not None
            ):
671
                req.finished_reason = FINISH_ABORT(
672
                    f"Invalid request: session id {recv_req.session_params.id} does not exist"
673
                )
674
                self._add_request_to_queue(req)
675
676
                return
        else:
677
678
            # Create a new request from a previous session
            session = self.sessions[recv_req.session_params.id]
679
            req = session.create_req(recv_req, self.tokenizer)
680
            if isinstance(req.finished_reason, FINISH_ABORT):
681
                self._add_request_to_queue(req)
682
                return
683

684
        # Handle multimodal inputs
685
        if recv_req.image_inputs is not None:
686
687
            image_inputs = ImageInputs.from_dict(recv_req.image_inputs)
            # Expand a single image token into multiple dummy tokens for receiving image embeddings
688
            req.origin_input_ids = self.pad_input_ids_func(
689
                req.origin_input_ids, image_inputs
690
            )
691
            req.extend_image_inputs(image_inputs)
692

693
            if len(req.origin_input_ids) >= self.max_req_input_len:
694
                error_msg = (
695
                    "Multimodal prompt is too long after expanding multimodal tokens. "
696
                    f"After expanding {len(req.origin_input_ids_unpadded)=} => {len(req.origin_input_ids)} >= {self.max_req_input_len}."
697
                )
698
                logger.error(error_msg)
699
                req.origin_input_ids = [0]
700
                req.image_inputs = None
701
                req.sampling_params.max_new_tokens = 0
702
                req.finished_reason = FINISH_ABORT(
703
                    error_msg, HTTPStatus.BAD_REQUEST, "BadRequestError"
704
                )
705
                self._add_request_to_queue(req)
706
707
                return

708
709
710
711
712
713
714
        # Validate prompts length
        error_msg = validate_input_length(
            req,
            self.max_req_input_len,
            self.server_args.allow_auto_truncate,
        )
        if error_msg:
715
716
            req.origin_input_ids = [0]
            req.sampling_params.max_new_tokens = 0
717
            self._add_request_to_queue(req)
718
            return
719

720
        # Copy more attributes
721
        if recv_req.logprob_start_len == -1 or not recv_req.return_logprob:
722
723
724
725
726
            # By default, only return the logprobs for output tokens
            req.logprob_start_len = len(req.origin_input_ids) - 1
        else:
            req.logprob_start_len = recv_req.logprob_start_len

727
728
729
730
731
732
733
734
735
736
        if req.logprob_start_len >= len(req.origin_input_ids):
            req.finished_reason = FINISH_ABORT(
                f"logprob_start_len, ({req.logprob_start_len}) is higher than the number of input tokens ({len(req.origin_input_ids)}). Request with a lower logprob_start_len.",
                HTTPStatus.BAD_REQUEST,
                "BadRequestError",
            )
            req.logprob_start_len = len(req.origin_input_ids) - 1
            self._add_request_to_queue(req)
            return

737
738
739
740
741
742
        req.sampling_params.max_new_tokens = min(
            (
                req.sampling_params.max_new_tokens
                if req.sampling_params.max_new_tokens is not None
                else 1 << 30
            ),
743
            self.max_req_len - len(req.origin_input_ids) - 1,
744
745
        )

746
747
748
749
750
        # Init grammar cache for this request
        add_to_grammar_queue = False
        if (
            req.sampling_params.json_schema is not None
            or req.sampling_params.regex is not None
751
            or req.sampling_params.ebnf is not None
752
            or req.sampling_params.structural_tag is not None
753
754
755
756
757
758
        ):
            assert self.grammar_backend is not None
            if req.sampling_params.json_schema is not None:
                key = ("json", req.sampling_params.json_schema)
            elif req.sampling_params.regex is not None:
                key = ("regex", req.sampling_params.regex)
759
760
            elif req.sampling_params.ebnf is not None:
                key = ("ebnf", req.sampling_params.ebnf)
761
762
            elif req.sampling_params.structural_tag:
                key = ("structural_tag", req.sampling_params.structural_tag)
763
764
765
766
767
768
769

            req.grammar = self.grammar_backend.get_cached_value(key)
            if not req.grammar:
                req.grammar = self.grammar_backend.get_future_value(key)
                add_to_grammar_queue = True

        if add_to_grammar_queue:
770
771
            self.grammar_queue.append(req)
        else:
772
773
774
775
776
777
778
            self._add_request_to_queue(req)

    def _add_request_to_queue(self, req: Req):
        """Append a single request to the tail of the waiting queue."""
        queue = self.waiting_queue
        queue.append(req)

    def _extend_requests_to_queue(self, reqs: List[Req]):
        """Append several requests, in their given order, to the waiting queue."""
        queue = self.waiting_queue
        queue.extend(reqs)
779
780
781

    def handle_embedding_request(
        self,
        recv_req: TokenizedEmbeddingReqInput,
    ):
        """Wrap a tokenized embedding request in a `Req` and enqueue it.

        The request is always appended to the waiting queue. When
        `validate_input_length` reports an error, it is enqueued without
        `logprob_start_len` being set. Presumably the helper has already
        marked it as aborted or truncated; that helper is not visible here.
        """
        new_req = Req(
            recv_req.rid,
            recv_req.input_text,
            recv_req.input_ids,
            recv_req.sampling_params,
        )
        new_req.tokenizer = self.tokenizer

        # Validate prompt length against the per-request input budget.
        length_error = validate_input_length(
            new_req,
            self.max_req_input_len,
            self.server_args.allow_auto_truncate,
        )
        if not length_error:
            # Start logprob tracking at the last input position.
            new_req.logprob_start_len = len(new_req.origin_input_ids) - 1
        self._add_request_to_queue(new_req)
805

806
807
808
809
    def log_prefill_stats(
        self,
        adder: PrefillAdder,
        can_run_list: List[Req],
        running_bs: int,
    ):
        """Log a one-line summary of a newly scheduled prefill batch.

        The method also updates the largest prefill size seen so far. When
        metrics are enabled, it pushes a stats snapshot to the metrics
        collector.

        Args:
            adder: The PrefillAdder that assembled the batch. Its
                `log_input_tokens` / `log_hit_tokens` counters feed the report.
            can_run_list: Requests admitted into the new prefill batch.
            running_bs: Number of requests in the current running batch.
        """
        free_tokens = (
            self.token_to_kv_pool_allocator.available_size()
            + self.tree_cache.evictable_size()
        )
        num_used = self.max_total_num_tokens - free_tokens
        new_tokens = adder.log_input_tokens
        hit_tokens = adder.log_hit_tokens

        if new_tokens > self._largest_prefill_len:
            self._largest_prefill_len = new_tokens

        usage = num_used / self.max_total_num_tokens
        logger.info(
            f"Prefill batch. "
            f"#new-seq: {len(can_run_list)}, "
            f"#new-token: {new_tokens}, "
            f"#cached-token: {hit_tokens}, "
            f"token usage: {usage:.2f}, "
            f"#running-req: {running_bs}, "
            f"#queue-req: {len(self.waiting_queue)}, "
        )

        if not self.enable_metrics:
            return

        # Fraction of this batch's prompt tokens served from the prefix cache.
        hit_rate = hit_tokens / (new_tokens + hit_tokens)
        stats = self.stats
        stats.num_running_reqs = running_bs
        stats.num_used_tokens = num_used
        stats.token_usage = round(usage, 2)
        stats.num_queue_reqs = len(self.waiting_queue)
        stats.cache_hit_rate = hit_rate
        self.metrics_collector.log_stats(stats)

    def log_decode_stats(self):
        """Log decode progress since the previous call and reset interval counters.

        The method performs these steps:
        - Computes generation throughput over the elapsed interval.
        - When RECORD_STEP_TIME is set, records an averaged per-step latency.
          This assumes the call happens every `decode_log_interval` decode
          steps; confirm at the call site.
        - When speculative decoding is active, folds the interval's
          accepted-token counters into the cumulative totals and resets them.
        - When metrics are enabled, pushes a stats snapshot.
        """
        elapsed = time.time() - self.last_decode_stats_tic
        self.last_decode_stats_tic = time.time()
        self.last_gen_throughput = self.num_generated_tokens / elapsed
        self.num_generated_tokens = 0

        running_batch = self.running_batch
        num_running_reqs = len(running_batch.reqs) if running_batch else 0
        free_tokens = (
            self.token_to_kv_pool_allocator.available_size()
            + self.tree_cache.evictable_size()
        )
        num_used = self.max_total_num_tokens - free_tokens

        if RECORD_STEP_TIME:
            self.step_time_dict[num_running_reqs].append(
                elapsed / self.server_args.decode_log_interval
            )

        if self.spec_algorithm.is_none():
            spec_accept_length = 0
            accept_segment = ""
        else:
            accepted = self.spec_num_total_accepted_tokens
            forwards = self.spec_num_total_forward_ct
            spec_accept_length = accepted / forwards
            self.cum_spec_accept_length += accepted
            self.cum_spec_accept_count += forwards
            self.spec_num_total_accepted_tokens = 0
            self.spec_num_total_forward_ct = 0
            accept_segment = f"accept len: {spec_accept_length:.2f}, "

        token_usage = num_used / self.max_total_num_tokens
        logger.info(
            f"Decode batch. "
            f"#running-req: {num_running_reqs}, "
            f"#token: {num_used}, "
            f"token usage: {token_usage:.2f}, "
            f"{accept_segment}"
            f"gen throughput (token/s): {self.last_gen_throughput:.2f}, "
            f"largest-len: {self._largest_prefill_decode_len}, "
            f"#queue-req: {len(self.waiting_queue)}, "
        )

        if self.enable_metrics:
            stats = self.stats
            stats.num_running_reqs = num_running_reqs
            stats.num_used_tokens = num_used
            stats.token_usage = token_usage
            stats.cache_hit_rate = 0.0
            stats.gen_throughput = self.last_gen_throughput
            stats.num_queue_reqs = len(self.waiting_queue)
            stats.spec_accept_length = spec_accept_length
            self.metrics_collector.log_stats(stats)

Lianmin Zheng's avatar
Lianmin Zheng committed
898
899
    def check_memory(self):
        """Verify that KV-cache tokens and request slots are all free.

        Any KV token or request slot that is not free is reported as a
        leak. The check is therefore only meaningful when no request holds
        memory. Presumably the caller invokes it while idle; confirm at the
        call site.

        Each leak triggers `warnings.warn` and raises ValueError when
        `crash_on_warnings()` is true. When metrics are enabled on attention
        TP rank 0, an idle-time metrics snapshot is also logged at most
        every 30 seconds.
        """
        available_size = (
            self.token_to_kv_pool_allocator.available_size()
            + self.tree_cache.evictable_size()
        )
        protected_size = self.tree_cache.protected_size()
        # With hierarchical cache enabled, the tree cache's protected portion
        # is not expected to be free.
        expected_available = (
            self.max_total_num_tokens
            if not self.enable_hierarchical_cache
            else self.max_total_num_tokens - protected_size
        )
        if available_size != expected_available:
            msg = (
                "KV cache pool leak detected! "
                f"{available_size=}, {protected_size=}, {self.max_total_num_tokens=}\n"
            )
            warnings.warn(msg)
            if crash_on_warnings():
                raise ValueError(msg)

        num_free_slots = len(self.req_to_token_pool.free_slots)
        if num_free_slots != self.req_to_token_pool.size:
            msg = (
                "Memory pool leak detected! "
                f"available_size={num_free_slots}, "
                f"total_size={self.req_to_token_pool.size}\n"
            )
            warnings.warn(msg)
            if crash_on_warnings():
                raise ValueError(msg)

        if (
            self.enable_metrics
            and self.attn_tp_rank == 0
            and time.time() > self.metrics_collector.last_log_time + 30
        ):
            # During idle time, also collect metrics every 30 seconds.
            # Free tokens are read from the allocator, the same source as the
            # leak check above and the prefill/decode stats.
            num_used = self.max_total_num_tokens - (
                self.token_to_kv_pool_allocator.available_size()
                + self.tree_cache.evictable_size()
            )
            num_running_reqs = len(self.running_batch.reqs) if self.running_batch else 0
            self.stats.num_running_reqs = num_running_reqs
            self.stats.num_used_tokens = num_used
            self.stats.token_usage = num_used / self.max_total_num_tokens
            self.stats.gen_throughput = 0
            self.stats.num_queue_reqs = len(self.waiting_queue)
            self.metrics_collector.log_stats(self.stats)

946
    def get_next_batch_to_run(self) -> Optional[ScheduleBatch]:
        """Choose the batch for the next forward pass.

        The previous extend (prefill) batch is first folded into the running
        batch, excluding any request that is still being chunked. A new
        prefill batch is then preferred. If none can be formed, the updated
        running (decode) batch is used instead. With DP attention enabled,
        the choice is passed through `prepare_dp_attn_batch`.
        """
        prev = self.last_batch
        if prev and prev.forward_mode.is_extend():
            chunked = self.chunked_req
            if chunked:
                # Pull the chunked request out so that only finished requests
                # are merged into running_batch.
                prev.filter_batch(chunked_req_to_exclude=chunked)
                self.tree_cache.cache_unfinished_req(chunked)
                # The chunked request keeps its rid but will get a new
                # req_pool_idx.
                self.req_to_token_pool.free(chunked.req_pool_idx)
                self.batch_is_full = False

            prev.filter_batch()
            if not prev.is_empty():
                if self.running_batch is not None:
                    self.running_batch.merge_batch(prev)
                else:
                    self.running_batch = prev

        # Prefill takes priority; fall back to decoding the running batch.
        ret = self.get_new_batch_prefill()
        if ret is None and self.running_batch is not None:
            self.running_batch = self.update_running_batch(self.running_batch)
            ret = self.running_batch

        if self.server_args.enable_dp_attention:
            ret = self.prepare_dp_attn_batch(ret)
        return ret
983

Lianmin Zheng's avatar
Lianmin Zheng committed
984
    def get_new_batch_prefill(self) -> Optional[ScheduleBatch]:
        """Try to assemble a new prefill (extend) batch from the waiting queue.

        The in-flight chunked request, if any, is admitted first. The method
        then walks the waiting queue in priority order and adds requests
        through a PrefillAdder. It stops when the adder reports a stop
        condition or a scheduler limit is reached: the running-request cap,
        the LoRA-per-batch cap, or one-request-only prefill.
        `self.batch_is_full` is updated accordingly.

        Returns:
            A ScheduleBatch prepared for extend, possibly mixed with the
            running decode batch. Returns None when prefill is not allowed
            or no request could be admitted.
        """
        # Check if the grammar is ready in the grammar queue
        if self.grammar_queue:
            self.move_ready_grammar_requests()

        # Handle the cases where prefill is not allowed
        if (
            self.batch_is_full or len(self.waiting_queue) == 0
        ) and self.chunked_req is None:
            return None

        running_bs = len(self.running_batch.reqs) if self.running_batch else 0
        if running_bs >= self.max_running_requests:
            self.batch_is_full = True
            return None

        # Get priority queue. When this reports the prefix as already
        # computed, init_next_round_input below is given no tree cache.
        prefix_computed = self.policy.calc_priority(self.waiting_queue)

        # Prefill policy
        adder = PrefillAdder(
            self.tree_cache,
            self.token_to_kv_pool_allocator,
            self.running_batch,
            self.new_token_ratio,
            self.max_prefill_tokens,
            self.chunked_prefill_size,
            running_bs if self.is_mixed_chunk else 0,
        )

        # A request mid-way through chunked prefill is admitted first.
        if self.chunked_req is not None:
            self.chunked_req.init_next_round_input()
            self.chunked_req = adder.add_chunked_req(self.chunked_req)

        if self.lora_paths:
            lora_set = (
                {req.lora_path for req in self.running_batch.reqs}
                if self.running_batch is not None
                else set()
            )

        # Get requests from the waiting queue to a new prefill batch
        for req in self.waiting_queue:
            if (
                self.lora_paths
                and len(
                    lora_set
                    | {r.lora_path for r in adder.can_run_list}
                    | {req.lora_path}
                )
                > self.max_loras_per_batch
            ):
                self.batch_is_full = True
                break

            if running_bs + len(adder.can_run_list) >= self.max_running_requests:
                self.batch_is_full = True
                break

            req.init_next_round_input(None if prefix_computed else self.tree_cache)

            if self.enable_hierarchical_cache and req.last_node is not None:
                if req.last_node.evicted:
                    # loading KV cache for the request
                    req.last_node, req.prefix_indices = self.tree_cache.init_load_back(
                        req.last_node,
                        req.prefix_indices,
                        adder.rem_total_tokens,
                    )
                    if req.last_node.loading:
                        # to prevent frequent cache invalidation
                        if req.rid in self.staging_reqs:
                            self.tree_cache.dec_lock_ref(self.staging_reqs[req.rid])
                        self.tree_cache.inc_lock_ref(req.last_node)
                        self.staging_reqs[req.rid] = req.last_node
                        continue
                elif req.last_node.loading:
                    if not self.tree_cache.loading_complete(req.last_node):
                        continue

                if req.rid in self.staging_reqs:
                    self.tree_cache.dec_lock_ref(self.staging_reqs[req.rid])
                    del self.staging_reqs[req.rid]

            res = adder.add_one_req(req, self.chunked_req)
            if res != AddReqResult.CONTINUE:
                if res == AddReqResult.NO_TOKEN:
                    if self.enable_hierarchical_cache:
                        # Set batch_is_full after making sure there are requests that can be served
                        self.batch_is_full = len(adder.can_run_list) > 0 or (
                            self.running_batch is not None
                            and not self.running_batch.is_empty()
                        )
                    else:
                        self.batch_is_full = True
                break
            if self.prefill_only_one_req:
                break

        # Update waiting queue
        can_run_list: List[Req] = adder.can_run_list
        if not can_run_list:
            return None
        # Build the membership set once rather than once per waiting request,
        # so the filter stays linear in the queue length.
        can_run_set = set(can_run_list)
        self.waiting_queue = [x for x in self.waiting_queue if x not in can_run_set]

        if adder.new_chunked_req is not None:
            assert self.chunked_req is None
            self.chunked_req = adder.new_chunked_req

        if self.chunked_req:
            self.chunked_req.is_chunked += 1

        # Print stats
        if self.attn_tp_rank == 0:
            self.log_prefill_stats(adder, can_run_list, running_bs)

        # Create a new batch
        new_batch = ScheduleBatch.init_new(
            can_run_list,
            self.req_to_token_pool,
            self.token_to_kv_pool_allocator,
            self.tree_cache,
            self.model_config,
            self.enable_overlap,
            self.spec_algorithm,
            self.server_args.enable_custom_logit_processor,
        )
        new_batch.prepare_for_extend()

        # Mixed-style chunked prefill
        if (
            self.is_mixed_chunk
            and self.running_batch is not None
            and not (new_batch.return_logprob or self.running_batch.return_logprob)
        ):
            # TODO (lianmin): support return_logprob + mixed chunked prefill
            self.running_batch.filter_batch()
            if not self.running_batch.is_empty():
                self.running_batch.prepare_for_decode()
                new_batch.mix_with_running(self.running_batch)
                new_batch.decoding_reqs = self.running_batch.reqs
            self.running_batch = None
        else:
            new_batch.decoding_reqs = None

        return new_batch

Lianmin Zheng's avatar
Lianmin Zheng committed
1133
    def update_running_batch(self, batch: ScheduleBatch) -> Optional[ScheduleBatch]:
        """Refresh the running decode batch before its next decode step.

        The method drops finished requests from the batch and returns None if
        nothing remains. If the decode memory check fails (or TEST_RETRACT
        forces it for batches larger than 10), some requests are retracted
        back to the waiting queue and `new_token_ratio` is reset. Otherwise
        `new_token_ratio` is decayed toward its minimum. `batch_is_full` is
        cleared whenever the batch shrank. The batch is prepared for decode
        and returned.
        """
        size_before = batch.batch_size()

        batch.filter_batch()
        if batch.is_empty():
            self.batch_is_full = False
            return None

        # Check if decode out of memory
        needs_retract = not batch.check_decode_mem(
            self.decode_mem_cache_buf_multiplier
        ) or (TEST_RETRACT and batch.batch_size() > 10)

        if needs_retract:
            previous_ratio = self.new_token_ratio
            retracted_reqs, self.new_token_ratio = batch.retract_decode(
                self.server_args
            )
            logger.info(
                "Decode out of memory happened. "
                f"#retracted_reqs: {len(retracted_reqs)}, "
                f"#new_token_ratio: {previous_ratio:.4f} -> {self.new_token_ratio:.4f}"
            )
            self._extend_requests_to_queue(retracted_reqs)
        else:
            decayed_ratio = self.new_token_ratio - self.new_token_ratio_decay
            self.new_token_ratio = max(decayed_ratio, self.min_new_token_ratio)

        if batch.batch_size() < size_before:
            self.batch_is_full = False

        # Update batch tensors
        batch.prepare_for_decode()
        return batch
Lianmin Zheng's avatar
Lianmin Zheng committed
1169

1170
1171
1172
    def run_batch(
        self, batch: ScheduleBatch
    ) -> Union[GenerationBatchResult, EmbeddingBatchResult]:
        """Execute one forward pass for `batch` and wrap its outputs.

        The method increments the forward counter and stops the profiler once
        the target count is reached. It then dispatches to the embedding
        path, the regular generation path, or the speculative draft worker.
        """
        self.forward_ct += 1

        # Check profiler
        target_ct = self.profiler_target_forward_ct
        if target_ct and target_ct <= self.forward_ct:
            self.stop_profile()

        if not self.is_generation:
            # Embedding or reward model.
            worker_batch = batch.get_model_worker_batch()
            return EmbeddingBatchResult(
                embeddings=self.tp_worker.forward_batch_embedding(worker_batch),
                bid=worker_batch.bid,
            )

        if self.spec_algorithm.is_none():
            worker_batch = batch.get_model_worker_batch()
            logits_output, next_token_ids = self.tp_worker.forward_batch_generation(
                worker_batch
            )
            bid = worker_batch.bid
        else:
            (
                logits_output,
                next_token_ids,
                bid,
                num_accepted_tokens,
            ) = self.draft_worker.forward_batch_speculative_generation(batch)
            # The accepted-token tally also counts one token per request.
            bs = batch.batch_size()
            self.spec_num_total_accepted_tokens += num_accepted_tokens + bs
            self.spec_num_total_forward_ct += bs
            self.num_generated_tokens += num_accepted_tokens
        batch.output_ids = next_token_ids

        # The overlap scheduler may modify these per-request values before
        # output processing, so snapshot them now.
        if batch.return_logprob:
            reqs = batch.reqs
            extend_input_len_per_req = [r.extend_input_len for r in reqs]
            extend_logprob_start_len_per_req = [
                r.extend_logprob_start_len for r in reqs
            ]
        else:
            extend_input_len_per_req = None
            extend_logprob_start_len_per_req = None

        return GenerationBatchResult(
            logits_output=logits_output,
            next_token_ids=next_token_ids,
            extend_input_len_per_req=extend_input_len_per_req,
            extend_logprob_start_len_per_req=extend_logprob_start_len_per_req,
            bid=bid,
        )
Chayenne's avatar
Chayenne committed
1229

1230
1231
1232
1233
1234
    def process_batch_result(
        self,
        batch: ScheduleBatch,
        result: Union[GenerationBatchResult, EmbeddingBatchResult],
    ):
        """Route a finished forward pass to the handler for its forward mode.

        - Decode results go to `process_batch_result_decode`. `running_batch`
          is cleared if the batch became empty.
        - Extend results go to `process_batch_result_prefill`.
        - Idle batches (overlap mode only) and dummy-first batches only
          finish publishing the next batch's sampling info.

        Afterwards, if a health-check count is pending, it is decremented and
        a HealthCheckOutput is sent to the tokenizer.
        """
        mode = batch.forward_mode

        def _publish_next_sampling_info():
            # Apply the grammar vocab mask, wait for the current stream, then
            # signal that the next batch's sampling info is ready.
            info = batch.next_batch_sampling_info
            info.update_regex_vocab_mask()
            self.current_stream.synchronize()
            info.sampling_info_done.set()

        if mode.is_decode():
            assert isinstance(result, GenerationBatchResult)
            self.process_batch_result_decode(batch, result)
            if batch.is_empty():
                self.running_batch = None
        elif mode.is_extend():
            self.process_batch_result_prefill(batch, result)
        elif mode.is_idle():
            if self.enable_overlap:
                self.tp_worker.resolve_batch_result(result.bid)
                if batch.next_batch_sampling_info:
                    _publish_next_sampling_info()
        elif mode.is_dummy_first():
            _publish_next_sampling_info()

        if self.return_health_check_ct:
            # Emit a health-check signal from here so that a long-context
            # prefill cannot block it. This path does not check the
            # detokenizer manager's status.
            self.return_health_check_ct -= 1
            self.send_to_tokenizer.send_pyobj(HealthCheckOutput())

1261
1262
1263
1264
1265
    def process_batch_result_prefill(
        self,
        batch: ScheduleBatch,
        result: Union[GenerationBatchResult, EmbeddingBatchResult],
    ):
1266
        skip_stream_req = None
Lianmin Zheng's avatar
Lianmin Zheng committed
1267

Lianmin Zheng's avatar
Lianmin Zheng committed
1268
        if self.is_generation:
1269
1270
1271
            (
                logits_output,
                next_token_ids,
1272
1273
                extend_input_len_per_req,
                extend_logprob_start_len_per_req,
1274
1275
1276
1277
                bid,
            ) = (
                result.logits_output,
                result.next_token_ids,
1278
1279
                result.extend_input_len_per_req,
                result.extend_logprob_start_len_per_req,
1280
1281
                result.bid,
            )
1282
1283

            if self.enable_overlap:
1284
                logits_output, next_token_ids = self.tp_worker.resolve_batch_result(bid)
1285
1286
            else:
                # Move next_token_ids and logprobs to cpu
1287
                next_token_ids = next_token_ids.tolist()
1288
                if batch.return_logprob:
1289
1290
1291
1292
1293
1294
1295
1296
                    if logits_output.next_token_logprobs is not None:
                        logits_output.next_token_logprobs = (
                            logits_output.next_token_logprobs.tolist()
                        )
                    if logits_output.input_token_logprobs is not None:
                        logits_output.input_token_logprobs = tuple(
                            logits_output.input_token_logprobs.tolist()
                        )
1297

1298
1299
            hidden_state_offset = 0

1300
1301
            # Check finish conditions
            logprob_pt = 0
1302
            for i, (req, next_token_id) in enumerate(zip(batch.reqs, next_token_ids)):
1303
1304
1305
                if req.is_retracted:
                    continue

Lianmin Zheng's avatar
Lianmin Zheng committed
1306
                if self.is_mixed_chunk and self.enable_overlap and req.finished():
1307
1308
                    # Free the one delayed token for the mixed decode batch
                    j = len(batch.out_cache_loc) - len(batch.reqs) + i
1309
                    self.token_to_kv_pool_allocator.free(batch.out_cache_loc[j : j + 1])
1310
                    continue
Lianmin Zheng's avatar
Lianmin Zheng committed
1311

1312
1313
                if req.is_chunked <= 0:
                    # req output_ids are set here
1314
                    req.output_ids.append(next_token_id)
1315
1316
                    req.check_finished()

1317
                    if req.finished():
1318
                        self.tree_cache.cache_finished_req(req)
1319
                    elif not batch.decoding_reqs or req not in batch.decoding_reqs:
1320
                        # This updates radix so others can match
1321
1322
                        self.tree_cache.cache_unfinished_req(req)

1323
                    if req.return_logprob:
1324
1325
1326
1327
1328
1329
1330
1331
1332
1333
1334
1335
                        assert extend_logprob_start_len_per_req is not None
                        assert extend_input_len_per_req is not None
                        extend_logprob_start_len = extend_logprob_start_len_per_req[i]
                        extend_input_len = extend_input_len_per_req[i]
                        num_input_logprobs = extend_input_len - extend_logprob_start_len
                        self.add_logprob_return_values(
                            i,
                            req,
                            logprob_pt,
                            next_token_ids,
                            num_input_logprobs,
                            logits_output,
1336
                        )
1337
1338
                        logprob_pt += num_input_logprobs

1339
                    if (
1340
                        req.return_hidden_states
1341
1342
1343
1344
1345
1346
1347
1348
1349
1350
1351
1352
1353
                        and logits_output.hidden_states is not None
                    ):
                        req.hidden_states.append(
                            logits_output.hidden_states[
                                hidden_state_offset : (
                                    hidden_state_offset := hidden_state_offset
                                    + len(req.origin_input_ids)
                                )
                            ]
                            .cpu()
                            .clone()
                        )

Lianmin Zheng's avatar
Lianmin Zheng committed
1354
1355
                    if req.grammar is not None:
                        req.grammar.accept_token(next_token_id)
1356
                        req.grammar.finished = req.finished()
1357
                else:
1358
                    # being chunked reqs' prefill is not finished
1359
                    req.is_chunked -= 1
1360
1361
1362
1363
                    # There is only at most one request being currently chunked.
                    # Because this request does not finish prefill,
                    # we don't want to stream the request currently being chunked.
                    skip_stream_req = req
1364

1365
1366
1367
1368
1369
1370
1371
1372
1373
1374
1375
1376
1377
1378
1379
1380
1381
1382
1383
                    # Incrementally update input logprobs.
                    if req.return_logprob:
                        extend_logprob_start_len = extend_logprob_start_len_per_req[i]
                        extend_input_len = extend_input_len_per_req[i]
                        if extend_logprob_start_len < extend_input_len:
                            # Update input logprobs.
                            num_input_logprobs = (
                                extend_input_len - extend_logprob_start_len
                            )
                            self.add_input_logprob_return_values(
                                i,
                                req,
                                logits_output,
                                logprob_pt,
                                num_input_logprobs,
                                last_prefill_chunk=False,
                            )
                            logprob_pt += num_input_logprobs

1384
1385
            if batch.next_batch_sampling_info:
                batch.next_batch_sampling_info.update_regex_vocab_mask()
1386
                self.current_stream.synchronize()
1387
1388
                batch.next_batch_sampling_info.sampling_info_done.set()

Lianmin Zheng's avatar
Lianmin Zheng committed
1389
        else:  # embedding or reward model
1390
            embeddings, bid = result.embeddings, result.bid
1391
            embeddings = embeddings.tolist()
1392
1393
1394

            # Check finish conditions
            for i, req in enumerate(batch.reqs):
1395
1396
1397
                if req.is_retracted:
                    continue

1398
                req.embedding = embeddings[i]
1399
                if req.is_chunked <= 0:
Lianmin Zheng's avatar
Lianmin Zheng committed
1400
                    # Dummy output token for embedding models
1401
1402
1403
                    req.output_ids.append(0)
                    req.check_finished()

Lianmin Zheng's avatar
Lianmin Zheng committed
1404
1405
1406
1407
                    if req.finished():
                        self.tree_cache.cache_finished_req(req)
                    else:
                        self.tree_cache.cache_unfinished_req(req)
1408
                else:
1409
                    # being chunked reqs' prefill is not finished
1410
                    req.is_chunked -= 1
1411

Lianmin Zheng's avatar
Lianmin Zheng committed
1412
        self.stream_output(batch.reqs, batch.return_logprob, skip_stream_req)
1413

1414
1415
1416
1417
1418
1419
1420
1421
1422
1423
    def process_batch_result_decode(
        self,
        batch: ScheduleBatch,
        result: GenerationBatchResult,
    ):
        """Apply the result of one decode forward pass to the batch's requests.

        For each non-retracted request this appends the sampled token (only
        when speculative decoding is off), re-checks the finish condition,
        caches finished requests in ``tree_cache``, records output logprobs
        and hidden states when requested, and advances the request grammar.
        The per-request work and the subsequent streaming are bracketed by
        ``free_group_begin()`` / ``free_group_end()`` on the KV allocator.
        Finally the decode counter is advanced and, on attention-TP rank 0,
        decode stats are logged every ``decode_log_interval`` steps.
        """
        logits_output = result.logits_output
        next_token_ids = result.next_token_ids
        bid = result.bid
        no_spec = batch.spec_algorithm.is_none()

        self.num_generated_tokens += len(batch.reqs)

        if self.enable_overlap:
            assert no_spec
            # With overlap scheduling the real outputs are resolved lazily by bid.
            logits_output, next_token_ids = self.tp_worker.resolve_batch_result(bid)
            token_logprobs = logits_output.next_token_logprobs
        elif no_spec:
            # Spec decoding handles output logprobs inside its verify process.
            next_token_ids = next_token_ids.tolist()
            if batch.return_logprob:
                token_logprobs = logits_output.next_token_logprobs.tolist()

        kv_allocator = self.token_to_kv_pool_allocator
        kv_allocator.free_group_begin()

        # NOTE: the length of reqs and next_token_ids don't match if it is spec decoding.
        # We should ignore using next_token_ids for spec decoding cases.
        for i, (req, token_id) in enumerate(zip(batch.reqs, next_token_ids)):
            if req.is_retracted:
                continue

            if self.enable_overlap and req.finished():
                # Free the one delayed token.
                kv_allocator.free(batch.out_cache_loc[i : i + 1])
                continue

            if no_spec:
                # The speculative worker fills output_ids itself.
                req.output_ids.append(token_id)

            req.check_finished()
            if req.finished():
                self.tree_cache.cache_finished_req(req)

            # The speculative worker records its own logprobs.
            if no_spec and req.return_logprob:
                req.output_token_logprobs_val.append(token_logprobs[i])
                req.output_token_logprobs_idx.append(token_id)
                if req.top_logprobs_num > 0:
                    req.output_top_logprobs_val.append(
                        logits_output.next_token_top_logprobs_val[i]
                    )
                    req.output_top_logprobs_idx.append(
                        logits_output.next_token_top_logprobs_idx[i]
                    )
                if req.token_ids_logprob is not None:
                    req.output_token_ids_logprobs_val.append(
                        logits_output.next_token_token_ids_logprobs_val[i]
                    )
                    req.output_token_ids_logprobs_idx.append(
                        logits_output.next_token_token_ids_logprobs_idx[i]
                    )

            if req.return_hidden_states and logits_output.hidden_states is not None:
                req.hidden_states.append(logits_output.hidden_states[i].cpu().clone())

            if no_spec and req.grammar is not None:
                req.grammar.accept_token(token_id)
                req.grammar.finished = req.finished()

        sampling_info = batch.next_batch_sampling_info
        if sampling_info:
            sampling_info.update_regex_vocab_mask()
            self.current_stream.synchronize()
            sampling_info.sampling_info_done.set()

        self.stream_output(batch.reqs, batch.return_logprob)

        kv_allocator.free_group_end()

        # The counter wraps at 2**30 to keep it bounded.
        self.forward_ct_decode = (self.forward_ct_decode + 1) % (1 << 30)
        if (
            self.attn_tp_rank == 0
            and self.forward_ct_decode % self.server_args.decode_log_interval == 0
        ):
            self.log_decode_stats()
1498

1499
    def add_input_logprob_return_values(
        self,
        i: int,
        req: Req,
        output: LogitsProcessorOutput,
        logprob_pt: int,
        num_input_logprobs: int,
        last_prefill_chunk: bool,  # If True, it means prefill is finished.
    ):
        """Incrementally add input logprobs to `req`.

        Every prefill chunk appends its slice of input logprobs (and, when
        requested, the per-request top-k / token-id logprob entries) to
        temporary accumulators on ``req``. On the last chunk these are
        assembled into ``req.input_*`` lists that are padded with ``None`` at
        position 0 and have their last (sampled-token) entry removed.

        Args:
            i: The request index in a batch.
            req: The request. Input logprobs inside req are modified as a
                consequence of the API.
            output: Logit processor output that's used to compute input logprobs.
            logprob_pt: Offset of this request's slice inside the flattened
                ``output.input_token_logprobs`` tuple.
            num_input_logprobs: Number of input logprobs that belong to this
                request in the current chunk.
            last_prefill_chunk: True if it is the last prefill (when chunked).
                Some of input logprob operation should only happen at the last
                prefill (e.g., computing input token logprobs).
        """
        assert output.input_token_logprobs is not None
        if req.input_token_logprobs is None:
            req.input_token_logprobs = []
        if req.temp_input_top_logprobs_val is None:
            req.temp_input_top_logprobs_val = []
        if req.temp_input_top_logprobs_idx is None:
            req.temp_input_top_logprobs_idx = []
        if req.temp_input_token_ids_logprobs_val is None:
            req.temp_input_token_ids_logprobs_val = []
        if req.temp_input_token_ids_logprobs_idx is None:
            req.temp_input_token_ids_logprobs_idx = []

        if req.input_token_logprobs_val is not None:
            # The input logprob has been already computed. It only happens
            # upon retract.
            return

        # Important for the performance.
        assert isinstance(output.input_token_logprobs, tuple)
        input_token_logprobs: Tuple[float, ...] = output.input_token_logprobs[
            logprob_pt : logprob_pt + num_input_logprobs
        ]
        req.input_token_logprobs.extend(input_token_logprobs)

        if req.top_logprobs_num > 0:
            req.temp_input_top_logprobs_val.append(output.input_top_logprobs_val[i])
            req.temp_input_top_logprobs_idx.append(output.input_top_logprobs_idx[i])

        if req.token_ids_logprob is not None:
            req.temp_input_token_ids_logprobs_val.append(
                output.input_token_ids_logprobs_val[i]
            )
            req.temp_input_token_ids_logprobs_idx.append(
                output.input_token_ids_logprobs_idx[i]
            )

        if not last_prefill_chunk:
            return

        input_token_logprobs = req.input_token_logprobs
        req.input_token_logprobs = None
        assert req.input_token_logprobs_val is None
        assert req.input_token_logprobs_idx is None
        assert req.input_top_logprobs_val is None
        assert req.input_top_logprobs_idx is None

        # Compute input_token_logprobs_val
        # Always pad the first one with None.
        req.input_token_logprobs_val = [None]
        req.input_token_logprobs_val.extend(input_token_logprobs)
        # The last input logprob is for sampling, so just pop it out.
        req.input_token_logprobs_val.pop()

        # Compute input_token_logprobs_idx
        input_token_logprobs_idx = req.origin_input_ids[req.logprob_start_len :]
        # Clip the padded hash values from image tokens.
        # Otherwise, it will lead to detokenization errors.
        # NOTE(review): the bound `vocab_size - 1` also maps the id
        # `vocab_size - 1` itself to 0; confirm whether `< vocab_size` was meant.
        req.input_token_logprobs_idx = [
            x if x < self.model_config.vocab_size - 1 else 0
            for x in input_token_logprobs_idx
        ]

        if req.top_logprobs_num > 0:
            # strict=True: a val/idx length mismatch must fail loudly rather
            # than silently truncating the merged top-logprob lists.
            (
                req.input_top_logprobs_val,
                req.input_top_logprobs_idx,
            ) = self._assemble_input_logprob_chunks(
                req.temp_input_top_logprobs_val,
                req.temp_input_top_logprobs_idx,
            )
            req.temp_input_top_logprobs_idx = None
            req.temp_input_top_logprobs_val = None

        if req.token_ids_logprob is not None:
            (
                req.input_token_ids_logprobs_val,
                req.input_token_ids_logprobs_idx,
            ) = self._assemble_input_logprob_chunks(
                req.temp_input_token_ids_logprobs_val,
                req.temp_input_token_ids_logprobs_idx,
            )
            req.temp_input_token_ids_logprobs_idx = None
            req.temp_input_token_ids_logprobs_val = None

        if req.return_logprob:
            relevant_tokens_len = len(req.origin_input_ids) - req.logprob_start_len
            assert len(req.input_token_logprobs_val) == relevant_tokens_len
            assert len(req.input_token_logprobs_idx) == relevant_tokens_len
            if req.top_logprobs_num > 0:
                assert len(req.input_top_logprobs_val) == relevant_tokens_len
                assert len(req.input_top_logprobs_idx) == relevant_tokens_len
            if req.token_ids_logprob is not None:
                assert len(req.input_token_ids_logprobs_val) == relevant_tokens_len
                assert len(req.input_token_ids_logprobs_idx) == relevant_tokens_len

    @staticmethod
    def _assemble_input_logprob_chunks(val_chunks, idx_chunks):
        """Concatenate per-chunk logprob entries into final per-position lists.

        Both outputs start with a ``None`` pad and have their last entry
        removed, because the last position corresponds to the sampled token.

        Raises:
            ValueError: if ``val_chunks`` and ``idx_chunks`` differ in length.
        """
        vals = [None]
        idxs = [None]
        for val, idx in zip(val_chunks, idx_chunks, strict=True):
            vals.extend(val)
            idxs.extend(idx)
        # Last token is a sample token.
        vals.pop()
        idxs.pop()
        return vals, idxs

    def add_logprob_return_values(
        self,
        i: int,
        req: Req,
        pt: int,
        next_token_ids: List[int],
        num_input_logprobs: int,
        output: LogitsProcessorOutput,
    ):
        """Attach logprobs to the return values.

        Records the logprob of the token sampled right after prefill, finalizes
        the request's input logprobs via ``add_input_logprob_return_values``
        (as the last prefill chunk), and then appends the optional top-k and
        requested-token-id output logprobs.

        Returns:
            ``num_input_logprobs``, unchanged.
        """
        sampled_logprob = output.next_token_logprobs[i]
        sampled_token = next_token_ids[i]
        req.output_token_logprobs_val.append(sampled_logprob)
        req.output_token_logprobs_idx.append(sampled_token)

        self.add_input_logprob_return_values(
            i,
            req,
            output,
            pt,
            num_input_logprobs,
            last_prefill_chunk=True,
        )

        wants_top_k = req.top_logprobs_num > 0
        if wants_top_k:
            req.output_top_logprobs_val.append(output.next_token_top_logprobs_val[i])
            req.output_top_logprobs_idx.append(output.next_token_top_logprobs_idx[i])

        wants_token_ids = req.token_ids_logprob is not None
        if wants_token_ids:
            req.output_token_ids_logprobs_val.append(
                output.next_token_token_ids_logprobs_val[i]
            )
            req.output_token_ids_logprobs_idx.append(
                output.next_token_token_ids_logprobs_idx[i]
            )

        return num_input_logprobs

Lianmin Zheng's avatar
Lianmin Zheng committed
1662
1663
1664
    def stream_output(
        self, reqs: List[Req], return_logprob: bool, skip_req: Optional[Req] = None
    ):
        """Stream the output to detokenizer.

        Generation models: each request in ``reqs`` is reported when it is
        finished, when a streaming request reaches its ``stream_interval``,
        or (non-streaming, non-multimodal-generation) every 50 output tokens.
        ``skip_req`` and, for multimodal-generation models, requests marked
        ``to_abort`` are never reported. All reported requests are packed into
        one ``BatchTokenIDOut`` that is only sent when at least one qualifies.

        Embedding / reward models: finished requests are packed into a
        ``BatchEmbeddingOut``, which is sent even when it is empty.

        Args:
            reqs: Requests of the batch that was just processed.
            return_logprob: Whether per-request logprob fields are collected;
                when False every logprob field is sent as None.
            skip_req: A request to leave out (the prefill path passes the
                request whose chunked prefill is still in progress).
        """
        rids = []
        finished_reasons: List[BaseFinishReason] = []

        if not self.is_generation:
            # Embedding or reward model: only finished requests are reported.
            embeddings = []
            prompt_tokens = []
            for req in reqs:
                if not req.finished():
                    continue
                rids.append(req.rid)
                finished_reasons.append(req.finished_reason.to_json())
                embeddings.append(req.embedding)
                prompt_tokens.append(len(req.origin_input_ids))
            self.send_to_detokenizer.send_pyobj(
                BatchEmbeddingOut(rids, finished_reasons, embeddings, prompt_tokens)
            )
            return

        # Request attributes forwarded as logprob columns, in the positional
        # order BatchTokenIDOut expects them.
        logprob_attrs = (
            "input_token_logprobs_val",
            "input_token_logprobs_idx",
            "output_token_logprobs_val",
            "output_token_logprobs_idx",
            "input_top_logprobs_val",
            "input_top_logprobs_idx",
            "output_top_logprobs_val",
            "output_top_logprobs_idx",
            "input_token_ids_logprobs_val",
            "input_token_ids_logprobs_idx",
            "output_token_ids_logprobs_val",
            "output_token_ids_logprobs_idx",
        )
        if return_logprob:
            logprob_columns = [[] for _ in logprob_attrs]
        else:
            logprob_columns = [None] * len(logprob_attrs)

        decoded_texts = []
        decode_ids_list = []
        read_offsets = []
        output_ids = []
        skip_special_tokens = []
        spaces_between_special_tokens = []
        no_stop_trim = []
        prompt_tokens = []
        completion_tokens = []
        cached_tokens = []
        spec_verify_ct = []
        output_hidden_states = None

        is_multimodal_gen = self.model_config.is_multimodal_gen

        for req in reqs:
            if req is skip_req:
                continue

            # Multimodal partial stream chunks break the detokenizer, so drop aborted requests here.
            if is_multimodal_gen and req.to_abort:
                continue

            # TODO(lianmin): the non-stream rule is wrong for speculative decoding because
            # len(req.output_ids) does not always increase one-by-one.
            should_emit = (
                req.finished()
                # If stream, follow the given stream_interval.
                or (req.stream and len(req.output_ids) % self.stream_interval == 0)
                # If not stream, still emit periodically to benefit from incremental decoding.
                or (
                    not req.stream
                    and len(req.output_ids) % 50 == 0
                    and not is_multimodal_gen
                )
            )
            if not should_emit:
                continue

            rids.append(req.rid)
            finished_reasons.append(
                req.finished_reason.to_json() if req.finished_reason else None
            )
            decoded_texts.append(req.decoded_text)
            decode_ids, read_offset = req.init_incremental_detokenize()
            decode_ids_list.append(decode_ids)
            read_offsets.append(read_offset)
            if self.skip_tokenizer_init:
                output_ids.append(req.output_ids)

            params = req.sampling_params
            skip_special_tokens.append(params.skip_special_tokens)
            spaces_between_special_tokens.append(params.spaces_between_special_tokens)
            no_stop_trim.append(params.no_stop_trim)

            prompt_tokens.append(len(req.origin_input_ids))
            completion_tokens.append(len(req.output_ids))
            cached_tokens.append(req.cached_tokens)

            if not self.spec_algorithm.is_none():
                spec_verify_ct.append(req.spec_verify_ct)

            if return_logprob:
                for column, attr in zip(logprob_columns, logprob_attrs):
                    column.append(getattr(req, attr))

            if req.return_hidden_states:
                if output_hidden_states is None:
                    output_hidden_states = []
                output_hidden_states.append(req.hidden_states)

        # Send to detokenizer only when something qualified.
        if not rids:
            return
        if is_multimodal_gen:
            raise NotImplementedError()
        self.send_to_detokenizer.send_pyobj(
            BatchTokenIDOut(
                rids,
                finished_reasons,
                decoded_texts,
                decode_ids_list,
                read_offsets,
                output_ids,
                skip_special_tokens,
                spaces_between_special_tokens,
                no_stop_trim,
                prompt_tokens,
                completion_tokens,
                cached_tokens,
                spec_verify_ct,
                *logprob_columns,
                output_hidden_states,
            )
        )
1825

1826
1827
1828
1829
1830
1831
1832
1833
1834
1835
1836
1837
1838
1839
1840
1841
1842
1843
1844
1845
1846
1847
1848
1849
1850
1851
    def prepare_dp_attn_batch(self, local_batch: ScheduleBatch):
        """Synchronize per-worker batch metadata for DP attention.

        All-gathers every worker's token count over ``tp_cpu_group``. If this
        worker has no batch but some worker has tokens, an idle batch is
        created so this worker still participates. The gathered counts are
        stored on the batch as ``global_num_tokens``. Unless CUDA graphs are
        disabled, a MIN all-reduce sets ``can_run_dp_cuda_graph`` to True only
        when every participating worker is in decode or idle mode.

        Returns:
            The (possibly newly created idle) batch, or None when no worker
            has any tokens.
        """
        # Tokens this worker contributes to the coming forward pass.
        if local_batch is None:
            my_tokens = 0
        else:
            mode = local_batch.forward_mode
            my_tokens = (
                local_batch.batch_size()
                if mode.is_decode()
                else local_batch.extend_num_tokens
            )

        mine = torch.tensor([my_tokens], dtype=torch.int64)
        gathered = torch.empty(self.tp_size, dtype=torch.int64)
        torch.distributed.all_gather_into_tensor(
            gathered,
            mine,
            group=self.tp_cpu_group,
        )

        if local_batch is None:
            if gathered.max().item() <= 0:
                # Nobody has work: nothing to run on this step.
                return None
            local_batch = self.get_idle_batch()

        local_batch.global_num_tokens = gathered.tolist()

        if self.server_args.disable_cuda_graph:
            return local_batch

        # Check forward mode for cuda graph: MIN over workers is 1 only if all agree.
        graph_flag = torch.tensor(
            (1 if local_batch.forward_mode.is_decode_or_idle() else 0),
            dtype=torch.int32,
        )
        torch.distributed.all_reduce(
            graph_flag,
            op=torch.distributed.ReduceOp.MIN,
            group=self.tp_cpu_group,
        )
        local_batch.can_run_dp_cuda_graph = graph_flag.item() == 1
        return local_batch

    def get_idle_batch(self):
        """Build a request-less ScheduleBatch prepared for an idle forward pass.

        The batch shares this scheduler's memory pools, tree cache, model
        config, overlap flag and speculative algorithm.
        """
        batch = ScheduleBatch.init_new(
            [],  # no requests
            self.req_to_token_pool,
            self.token_to_kv_pool_allocator,
            self.tree_cache,
            self.model_config,
            self.enable_overlap,
            self.spec_algorithm,
            self.server_args.enable_custom_logit_processor,
        )
        batch.prepare_for_idle()
        return batch

1878
1879
1880
1881
1882
1883
1884
1885
1886
1887
    def move_ready_grammar_requests(self):
        """Move requests whose grammar objects are ready from grammar_queue to waiting_queue.

        Grammar futures are resolved in queue order, waiting up to 0.05 s per
        request; scanning stops at the first one that is not ready. With more
        than one TP rank, the ranks agree (MAX all-reduce) on how many
        requests to move, blocking on any futures this rank has not resolved
        yet, so every rank moves the same prefix of the queue.
        """
        ready = 0
        for req in self.grammar_queue:
            try:
                req.grammar = req.grammar.result(timeout=0.05)
            except futures.TimeoutError:
                break
            ready += 1

        if self.server_args.enable_dp_attention:
            group_size, group = self.attn_tp_size, self.attn_tp_cpu_group
        else:
            group_size, group = self.tp_size, self.tp_cpu_group

        if group_size > 1:
            # Sync across TP ranks to make sure they have the same number of ready requests.
            counter = torch.tensor(ready, dtype=torch.int32)
            torch.distributed.all_reduce(
                counter, op=torch.distributed.ReduceOp.MAX, group=group
            )
            agreed = counter.item()
            for pos in range(ready, agreed):
                pending = self.grammar_queue[pos]
                pending.grammar = pending.grammar.result()
            ready = agreed

        self._extend_requests_to_queue(self.grammar_queue[:ready])
        self.grammar_queue = self.grammar_queue[ready:]

1909
1910
1911
    def flush_cache_wrapped(self, recv_req: FlushCacheReq):
        """Handle a FlushCacheReq message by delegating to ``flush_cache``.

        The request carries no parameters that are used here, and the
        success flag returned by ``flush_cache`` is discarded.
        """
        self.flush_cache()

1912
    def flush_cache(self):
        """Flush the memory pool and cache.

        The flush is only performed when there are no waiting and no running
        requests. It resets the prefix tree cache, the grammar backend (if
        any), the request/KV memory pools (including the draft worker's pools
        under speculative decoding), and the decode/speculative counters.

        Returns:
            True if the caches were flushed, False if pending requests
            prevented it (a warning is logged in that case).
        """
        num_queued = len(self.waiting_queue)
        num_running = 0 if self.running_batch is None else len(self.running_batch.reqs)
        if num_queued > 0 or num_running > 0:
            # Use the module logger (not the root `logging` module) so this
            # warning follows the same configuration as the rest of the file.
            logger.warning(
                f"Cache not flushed because there are pending requests. "
                f"#queue-req: {num_queued}, "
                f"#running-req: {num_running}"
            )
            return False

        self.cur_batch = None
        self.last_batch = None
        self.tree_cache.reset()
        self.tree_cache_metrics = {"total": 0, "hit": 0}
        if self.grammar_backend:
            self.grammar_backend.reset()
        self.req_to_token_pool.clear()
        self.token_to_kv_pool_allocator.clear()

        if not self.spec_algorithm.is_none():
            # The draft model owns separate pools that must be cleared as well.
            self.draft_worker.model_runner.req_to_token_pool.clear()
            self.draft_worker.model_runner.token_to_kv_pool_allocator.clear()

        # Reset decode and speculative-decoding statistics.
        self.num_generated_tokens = 0
        self.forward_ct_decode = 0
        self.spec_num_total_accepted_tokens = 0
        self.spec_num_total_forward_ct = 0
        self.cum_spec_accept_length = 0
        self.cum_spec_accept_count = 0

        torch.cuda.empty_cache()
        logger.info("Cache flushed successfully!")
        return True

1948
1949
1950
1951
1952
1953
1954
1955
1956
1957
1958
1959
1960
1961
1962
1963
1964
1965
1966
1967
1968
1969
1970
1971
1972
1973
1974
1975
1976
1977
1978
1979
1980
1981
1982
1983
1984
1985
1986
1987
1988
1989
1990
    def get_internal_state(self, recv_req: GetInternalStateReq):
        """Return a snapshot of scheduler state for introspection.

        The snapshot is a copy of ``global_server_args_dict`` extended with
        ``last_gen_throughput``. It also includes ``avg_spec_accept_length``
        when speculative decoding is enabled and at least one acceptance was
        recorded, and ``step_time_dict`` when ``RECORD_STEP_TIME`` is set.
        """
        state = {
            **global_server_args_dict,
            "last_gen_throughput": self.last_gen_throughput,
        }

        spec_enabled = not self.spec_algorithm.is_none()
        if spec_enabled and self.cum_spec_accept_count > 0:
            state["avg_spec_accept_length"] = (
                self.cum_spec_accept_length / self.cum_spec_accept_count
            )

        if RECORD_STEP_TIME:
            state["step_time_dict"] = self.step_time_dict

        return GetInternalStateReqOutput(internal_state=state)

    def set_internal_state(self, recv_req: SetInternalStateReq):
        """Update a whitelisted subset of the global server arguments at runtime.

        Only the keys in ``args_allow_update`` may be changed. If any key in
        ``recv_req.server_args`` is not allowed, a warning is logged for each
        offending key and nothing is updated.

        On success, the current average speculative accept length is logged
        (when speculative decoding is enabled and data exists). The cumulative
        accept counters are then reset, so later averages only use data
        collected after this update. Finally the new values are written into
        ``global_server_args_dict``.

        Returns:
            SetInternalStateReqOutput whose ``updated`` field reports whether
            the update was actually applied, together with the (possibly
            updated) global server args.
        """
        server_args_dict = recv_req.server_args
        args_allow_update = {
            "speculative_accept_threshold_single",
            "speculative_accept_threshold_acc",
        }

        rejected = [k for k in server_args_dict if k not in args_allow_update]
        for k in rejected:
            logger.warning(f"Updating {k} is not supported.")
        if_success = not rejected

        if if_success:
            if not self.spec_algorithm.is_none() and self.cum_spec_accept_count > 0:
                avg_spec_accept_length = (
                    self.cum_spec_accept_length / self.cum_spec_accept_count
                )
                logger.info(f"{avg_spec_accept_length=}")
            self.cum_spec_accept_length = self.cum_spec_accept_count = 0
            for k, v in server_args_dict.items():
                global_server_args_dict[k] = v
            logger.info(f"Global server args updated! " f"{global_server_args_dict=}")

        # Previously this always reported updated=True, even when the request
        # was rejected and nothing changed.
        return SetInternalStateReqOutput(
            updated=if_success,
            server_args=global_server_args_dict,
        )

    def abort_request(self, recv_req: AbortReq):
        """Abort the request whose ``rid`` equals ``recv_req.rid``.

        A matching request in the waiting queue is removed from the queue.
        Otherwise, the first unfinished matching request in the running batch
        is only flagged with ``to_abort = True`` here. An unknown rid is
        silently ignored.
        """
        target_rid = recv_req.rid

        # Queued requests can simply be dropped from the queue.
        queue_idx = next(
            (i for i, queued in enumerate(self.waiting_queue) if queued.rid == target_rid),
            None,
        )
        if queue_idx is not None:
            req = self.waiting_queue[queue_idx]
            del self.waiting_queue[queue_idx]
            logger.debug(f"Abort queued request. {req.rid=}")
            return

        if not self.running_batch:
            return

        # Running requests are only flagged; finished ones are left alone.
        for req in self.running_batch.reqs:
            if req.rid == target_rid and not req.finished():
                logger.debug(f"Abort running request. {req.rid=}")
                req.to_abort = True
                break

    def update_weights_from_disk(self, recv_req: UpdateWeightFromDiskReqInput):
        """In-place update of the weights from disk.

        Delegates the load to the TP worker. On success the cache is flushed,
        and the flush is asserted to succeed. On failure the worker's message
        is logged. The worker's success flag and message are returned either
        way. The trailing ``0`` is passed through unchanged; its meaning is
        defined by ``UpdateWeightFromDiskReqOutput`` (not visible here).
        """
        success, message = self.tp_worker.update_weights_from_disk(recv_req)
        if not success:
            logger.error(message)
        else:
            # Presumably cached entries were produced with the old weights,
            # so they are discarded after a successful reload.
            cache_flushed = self.flush_cache()
            assert cache_flushed, "Cache flush failed after updating weights"
        return UpdateWeightFromDiskReqOutput(success, message, 0)

    def init_weights_update_group(self, recv_req: InitWeightsUpdateGroupReqInput):
        """Initialize the online model parameter update group.

        The TP worker does the work; its (success, message) pair is forwarded
        unchanged in an ``InitWeightsUpdateGroupReqOutput``.
        """
        ok, msg = self.tp_worker.init_weights_update_group(recv_req)
        reply = InitWeightsUpdateGroupReqOutput(ok, msg)
        return reply

    def update_weights_from_distributed(
        self,
        recv_req: UpdateWeightsFromDistributedReqInput,
    ) -> "UpdateWeightsFromDistributedReqOutput":
        """Update the online model parameters via the distributed update group.

        The TP worker performs the update. On success the cache is flushed,
        and the flush is asserted to succeed. On failure the worker's message
        is logged.

        Returns:
            UpdateWeightsFromDistributedReqOutput carrying the worker's success
            flag and message. (The return annotation previously claimed
            ``Tuple[bool, str]``, which this method never returns.)
        """
        success, message = self.tp_worker.update_weights_from_distributed(recv_req)
        if success:
            flash_cache_success = self.flush_cache()
            assert flash_cache_success, "Cache flush failed after updating weights"
        else:
            logger.error(message)
        return UpdateWeightsFromDistributedReqOutput(success, message)

    def update_weights_from_tensor(self, recv_req: UpdateWeightsFromTensorReqInput):
        """Update the online model parameters from tensors.

        Unlike the disk/distributed variants, the cache is flushed only when
        the update succeeded AND ``recv_req.flush_cache`` is set; that flush is
        asserted to succeed. Failures are logged. The worker's success flag
        and message are returned.
        """
        # TODO extract common code b/t update_weights_from_distributed and update_weights_from_tensor later
        success, message = self.tp_worker.update_weights_from_tensor(recv_req)
        if not success:
            logger.error(message)
        elif recv_req.flush_cache:
            cache_flushed = self.flush_cache()
            assert cache_flushed, "Cache flush failed after updating weights"
        return UpdateWeightsFromTensorReqOutput(success, message)

    def get_weights_by_name(self, recv_req: GetWeightsByNameReqInput):
        """Fetch a parameter from the TP worker and wrap it in the reply type."""
        return GetWeightsByNameReqOutput(
            self.tp_worker.get_weights_by_name(recv_req)
        )

    def release_memory_occupation(self, recv_req: ReleaseMemoryOccupationReqInput):
        """Snapshot model buffers, pause the memory saver, and flush the cache.

        The buffer snapshot is stored on ``self.stashed_model_static_state``,
        which ``resume_memory_occupation`` later restores and deletes. It is
        taken before pausing the adapter. The result of ``flush_cache`` is not
        checked here.
        """
        model = self.tp_worker.worker.model_runner.model
        self.stashed_model_static_state = _export_static_state(model)
        self.memory_saver_adapter.pause()
        self.flush_cache()
        return ReleaseMemoryOccupationReqOutput()

    def resume_memory_occupation(self, recv_req: ResumeMemoryOccupationReqInput):
        """Reverse ``release_memory_occupation``.

        Resumes the memory saver adapter, then writes the stashed buffer
        snapshot back into the model, and finally drops the snapshot attribute.
        """
        self.memory_saver_adapter.resume()
        model = self.tp_worker.worker.model_runner.model
        _import_static_state(model, self.stashed_model_static_state)
        del self.stashed_model_static_state
        return ResumeMemoryOccupationReqOutput()

    def profile(self, recv_req: ProfileReq):
        """Route a profiling request to ``start_profile`` or ``stop_profile``.

        Any request type other than ``START_PROFILE`` is treated as a stop.
        """
        if recv_req.type != ProfileReqType.START_PROFILE:
            return self.stop_profile()
        return self.start_profile(
            recv_req.output_dir, recv_req.num_steps, recv_req.activities
        )

    def start_profile(
        self,
        output_dir: Optional[str],
        num_steps: Optional[int],
        activities: Optional[List[str]],
    ) -> "Optional[ProfileReqOutput]":
        """Start a profiling session.

        Args:
            output_dir: Directory for the outputs. Defaults to the
                ``SGLANG_TORCH_PROFILER_DIR`` env var, or "/tmp" when unset.
            num_steps: If truthy, the session targets
                ``self.forward_ct + num_steps`` and the caller is answered
                later, when that target is reached.
            activities: Any of "CPU"/"GPU" (torch profiler) and "MEM" (CUDA
                memory history recording). Defaults to ["CPU", "GPU"].
                Unrecognized names are ignored.

        Returns:
            A failed ``ProfileReqOutput`` if a session is already active; a
            successful one when no ``num_steps`` was given; otherwise ``None``.
            (The return annotation was previously ``None``, which did not
            match the first two cases.)
        """
        if self.torch_profiler_activities:
            return ProfileReqOutput(
                success=False,
                message="Profiling is already in progress. Call /stop_profile first.",
            )

        if output_dir is None:
            output_dir = os.getenv("SGLANG_TORCH_PROFILER_DIR", "/tmp")
        if activities is None:
            activities = ["CPU", "GPU"]

        self.torch_profiler_output_dir = output_dir
        self.torch_profiler_activities = activities
        logger.info(
            "Profiling starts. Traces will be saved to: %s",
            self.torch_profiler_output_dir,
        )

        # Only "CPU"/"GPU" map onto torch profiler activities; "MEM" is handled
        # separately below and any other name is dropped here.
        activity_map = {
            "CPU": torch.profiler.ProfilerActivity.CPU,
            "GPU": torch.profiler.ProfilerActivity.CUDA,
        }
        torchprof_activities = [
            activity_map[a] for a in activities if a in activity_map
        ]

        if torchprof_activities:
            self.torch_profiler = torch.profiler.profile(
                activities=torchprof_activities,
                with_stack=True,
            )
            self.torch_profiler.start()

        if "MEM" in activities:
            torch.cuda.memory._record_memory_history(max_entries=100000)

        if num_steps:
            self.profiler_target_forward_ct = self.forward_ct + num_steps
            # The caller will be notified when reaching profiler_target_forward_ct
            # (stop_profile sends the ProfileReqOutput), so nothing is returned now.
            return None

        self.profiler_target_forward_ct = None
        return ProfileReqOutput(success=True, message="Succeeded")

    def stop_profile(self) -> None:
        """Stop the active profiling session and write its outputs.

        Does nothing when no session is active (``torch_profiler_activities``
        is None). Otherwise it:
          * stops the torch profiler (if one was started) and exports a Chrome
            trace;
          * when "MEM" was requested, dumps a CUDA memory snapshot and turns
            memory-history recording off;
          * clears the session state.
        Both files go into the directory recorded by ``start_profile``. If the
        session had a forward-step target, a success ``ProfileReqOutput`` is
        sent to the tokenizer.
        """
        if self.torch_profiler_activities is None:
            return

        logger.info("Stop profiling...")
        if self.torch_profiler is not None:
            self.torch_profiler.stop()
            self.torch_profiler.export_chrome_trace(
                os.path.join(
                    self.torch_profiler_output_dir,
                    str(time.time()) + f"-TP-{self.tp_rank}" + ".trace.json.gz",
                )
            )

        if "MEM" in self.torch_profiler_activities:
            # Use the directory start_profile stored for this session. This
            # previously read the unset attribute `torch_profiler_trace_dir`
            # and raised AttributeError.
            memory_profile_path = os.path.join(
                self.torch_profiler_output_dir,
                str(time.time()) + f"-TP-{self.tp_rank}-memory" + ".pickle",
            )
            torch.cuda.memory._dump_snapshot(memory_profile_path)
            torch.cuda.memory._record_memory_history(enabled=None)

        logger.info(
            "Profiling done. Traces are saved to: %s",
            self.torch_profiler_output_dir,
        )
        self.torch_profiler = None
        self.torch_profiler_output_dir = None
        self.torch_profiler_activities = None

        if self.profiler_target_forward_ct:
            self.send_to_tokenizer.send_pyobj(
                ProfileReqOutput(success=True, message="Succeeded.")
            )

    def open_session(self, recv_req: OpenSessionReqInput):
        """Register a new session keyed by ``recv_req.session_id``.

        Replies with success=False when the id is already registered or is
        None. Otherwise it stores a new ``Session`` built from
        ``recv_req.capacity_of_str_len`` and the id, and replies with
        success=True.
        """
        session_id = recv_req.session_id

        # handle error: duplicate id (checked first), then missing id
        if session_id in self.sessions:
            logger.warning(f"session id {session_id} already exist, cannot open.")
            return OpenSessionReqOutput(session_id, False)
        if session_id is None:
            logger.warning("session id is None, cannot open.")
            return OpenSessionReqOutput(session_id, False)

        new_session = Session(recv_req.capacity_of_str_len, session_id)
        self.sessions[session_id] = new_session
        return OpenSessionReqOutput(session_id, True)

    def close_session(self, recv_req: CloseSessionReqInput):
        """Drop session ``recv_req.session_id``; log a warning if it is unknown."""
        session_id = recv_req.session_id
        if session_id in self.sessions:
            del self.sessions[session_id]
            return
        # handle error: nothing to close
        logger.warning(f"session id {session_id} does not exist, cannot delete.")


def is_health_check_generate_req(recv_req):
    """Return True if ``recv_req`` is a health-check generation request.

    Health-check requests are recognized by an ``rid`` starting with
    "HEALTH_CHECK". Objects without an ``rid`` attribute, or whose ``rid`` is
    not a string (e.g. ``None``), are not health checks. Previously a
    non-string ``rid`` raised AttributeError on ``.startswith``.
    """
    rid = getattr(recv_req, "rid", None)
    return isinstance(rid, str) and rid.startswith("HEALTH_CHECK")


def _export_static_state(model):
    return dict(
        buffers=[
            (name, buffer.detach().clone()) for name, buffer in model.named_buffers()
        ]
    )


def _import_static_state(model, static_params):
    self_named_buffers = dict(model.named_buffers())
    for name, tensor in static_params["buffers"]:
        self_named_buffers[name][...] = tensor


2206
2207
2208
2209
2210
def run_scheduler_process(
    server_args: ServerArgs,
    port_args: PortArgs,
    gpu_id: int,
    tp_rank: int,
    dp_rank: Optional[int],
    pipe_writer,
):
    """Entry point of a scheduler subprocess.

    Steps:
      1. Resolve the DP rank, falling back to ``$SGLANG_DP_RANK``.
      2. Set the process title and enable ``faulthandler``.
      3. Configure logging and, optionally, CPU affinity.
      4. Build a ``Scheduler``.
      5. Send a "ready" message with its token limits through ``pipe_writer``.
      6. Enter the overlap or normal event loop.

    Any exception is logged with its traceback, and SIGQUIT is sent to the
    parent process.
    """
    # [For Router] if env var "SGLANG_DP_RANK" exist, set dp_rank to the value of the env var
    # Resolved before setting the process title so the title shows the real
    # rank (it previously showed "scheduler_None" for router-launched ranks).
    if dp_rank is None and "SGLANG_DP_RANK" in os.environ:
        dp_rank = int(os.environ["SGLANG_DP_RANK"])

    # Config the process
    # kill_itself_when_parent_died()  # This is disabled because it does not work for `--dp 2`
    setproctitle.setproctitle(f"sglang::scheduler_{dp_rank}")
    faulthandler.enable()
    parent_process = psutil.Process().parent()

    # Configure the logger
    if dp_rank is None:
        prefix = f" TP{tp_rank}"
    else:
        prefix = f" DP{dp_rank} TP{tp_rank}"
    configure_logger(server_args, prefix=prefix)
    suppress_other_loggers()

    # Set cpu affinity to this gpu process
    if get_bool_env_var("SGLANG_SET_CPU_AFFINITY"):
        set_gpu_proc_affinity(server_args.tp_size, server_args.nnodes, gpu_id)

    # Create a scheduler and run the event loop
    try:
        scheduler = Scheduler(server_args, port_args, gpu_id, tp_rank, dp_rank)
        pipe_writer.send(
            {
                "status": "ready",
                "max_total_num_tokens": scheduler.max_total_num_tokens,
                "max_req_input_len": scheduler.max_req_input_len,
            }
        )
        if scheduler.enable_overlap:
            scheduler.event_loop_overlap()
        else:
            scheduler.event_loop_normal()
    except Exception:
        traceback = get_exception_traceback()
        logger.error(f"Scheduler hit an exception: {traceback}")
        parent_process.send_signal(signal.SIGQUIT)