# Copyright 2023-2024 SGLang Team
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""A scheduler that manages a tensor parallel GPU worker."""

import faulthandler
import logging
import os
import signal
import sys
import threading
import time
import warnings
from collections import defaultdict, deque
from concurrent import futures
from dataclasses import dataclass
from http import HTTPStatus
from types import SimpleNamespace
from typing import Dict, List, Optional, Tuple, Union

import psutil
import setproctitle
import torch
import zmq

from sglang.global_config import global_config
from sglang.srt.configs.model_config import ModelConfig
from sglang.srt.constrained.base_grammar_backend import create_grammar_backend
from sglang.srt.hf_transformers_utils import get_processor, get_tokenizer
from sglang.srt.layers.dp_attention import compute_dp_attention_world_info
from sglang.srt.layers.logits_processor import LogitsProcessorOutput
from sglang.srt.managers.io_struct import (
    AbortReq,
    BatchEmbeddingOut,
    BatchTokenIDOut,
    CloseSessionReqInput,
    FlushCacheReq,
    GetInternalStateReq,
    GetInternalStateReqOutput,
    GetWeightsByNameReqInput,
    GetWeightsByNameReqOutput,
    HealthCheckOutput,
    InitWeightsUpdateGroupReqInput,
    InitWeightsUpdateGroupReqOutput,
    OpenSessionReqInput,
    OpenSessionReqOutput,
    ProfileReq,
    ProfileReqOutput,
    ProfileReqType,
    ReleaseMemoryOccupationReqInput,
    ReleaseMemoryOccupationReqOutput,
    ResumeMemoryOccupationReqInput,
    ResumeMemoryOccupationReqOutput,
    SetInternalStateReq,
    SetInternalStateReqOutput,
    TokenizedEmbeddingReqInput,
    TokenizedGenerateReqInput,
    UpdateWeightFromDiskReqInput,
    UpdateWeightFromDiskReqOutput,
    UpdateWeightsFromDistributedReqInput,
    UpdateWeightsFromDistributedReqOutput,
    UpdateWeightsFromTensorReqInput,
    UpdateWeightsFromTensorReqOutput,
)
from sglang.srt.managers.schedule_batch import (
    FINISH_ABORT,
    BaseFinishReason,
    ImageInputs,
    Req,
    ScheduleBatch,
    global_server_args_dict,
)
from sglang.srt.managers.schedule_policy import (
    AddReqResult,
    PrefillAdder,
    SchedulePolicy,
)
from sglang.srt.managers.session_controller import Session
from sglang.srt.managers.tp_worker import TpModelWorker
from sglang.srt.managers.tp_worker_overlap_thread import TpModelWorkerClient
from sglang.srt.managers.utils import validate_input_length
from sglang.srt.mem_cache.chunk_cache import ChunkCache
from sglang.srt.mem_cache.hiradix_cache import HiRadixCache
from sglang.srt.mem_cache.radix_cache import RadixCache
from sglang.srt.metrics.collector import SchedulerMetricsCollector, SchedulerStats
from sglang.srt.model_executor.forward_batch_info import ForwardMode
from sglang.srt.server_args import PortArgs, ServerArgs
from sglang.srt.speculative.spec_info import SpeculativeAlgorithm
from sglang.srt.torch_memory_saver_adapter import TorchMemorySaverAdapter
from sglang.srt.utils import (
    broadcast_pyobj,
    configure_logger,
    crash_on_warnings,
    get_bool_env_var,
    get_zmq_socket,
    pyspy_dump_schedulers,
    set_gpu_proc_affinity,
    set_random_seed,
    suppress_other_loggers,
)
from sglang.utils import TypeBasedDispatcher, get_exception_traceback

logger = logging.getLogger(__name__)

# Test retract decode for debugging purposes
TEST_RETRACT = get_bool_env_var("SGLANG_TEST_RETRACT")
RECORD_STEP_TIME = get_bool_env_var("SGLANG_RECORD_STEP_TIME")


@dataclass
class GenerationBatchResult:
    logits_output: LogitsProcessorOutput
    next_token_ids: List[int]
    extend_input_len_per_req: List[int]
    extend_logprob_start_len_per_req: List[int]
    bid: int


@dataclass
class EmbeddingBatchResult:
    embeddings: torch.Tensor
    bid: int


class Scheduler:
    """A scheduler that manages a tensor parallel GPU worker."""

    def __init__(
        self,
        server_args: ServerArgs,
        port_args: PortArgs,
        gpu_id: int,
        tp_rank: int,
        dp_rank: Optional[int],
    ):
        # Parse args
        self.server_args = server_args
        self.tp_rank = tp_rank
        self.tp_size = server_args.tp_size
        self.schedule_policy = server_args.schedule_policy
        self.lora_paths = server_args.lora_paths
        self.max_loras_per_batch = server_args.max_loras_per_batch
        self.enable_overlap = not server_args.disable_overlap_schedule
        self.skip_tokenizer_init = server_args.skip_tokenizer_init
        self.enable_metrics = server_args.enable_metrics
        self.stream_interval = server_args.stream_interval
        self.spec_algorithm = SpeculativeAlgorithm.from_string(
            server_args.speculative_algorithm
        )
        self.gpu_id = gpu_id
        self.enable_hierarchical_cache = server_args.enable_hierarchical_cache
        self.decode_mem_cache_buf_multiplier = (
            (
                self.server_args.speculative_num_draft_tokens
                + (
                    self.server_args.speculative_eagle_topk
                    * self.server_args.speculative_num_steps
                )
            )
            if not self.spec_algorithm.is_none()
            else 1
        )
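        # Illustrative note (hypothetical values, derived only from the formula above):
        # with speculative decoding, each decode step must reserve KV-cache room for the
        # whole draft tree, e.g. speculative_num_draft_tokens=4, speculative_eagle_topk=8,
        # speculative_num_steps=5 gives a multiplier of 4 + 8 * 5 = 44 token slots per
        # request instead of 1; check_decode_mem() later scales its reservation by this.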

        # Distributed rank info
        self.dp_size = server_args.dp_size
        self.attn_tp_rank, self.attn_tp_size, self.dp_rank = (
            compute_dp_attention_world_info(
                server_args.enable_dp_attention,
                self.tp_rank,
                self.tp_size,
                self.dp_size,
            )
        )
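        # Illustrative sketch (not authoritative): with enable_dp_attention and, say,
        # tp_size=4, dp_size=2, the TP group is split into dp_size attention groups of
        # roughly tp_size // dp_size = 2 ranks each, so tp_rank 3 would map to dp_rank=1,
        # attn_tp_rank=1; without DP attention the whole TP group acts as one attention group.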

        # Init inter-process communication
        context = zmq.Context(2)
        if self.attn_tp_rank == 0:
            self.recv_from_tokenizer = get_zmq_socket(
                context, zmq.PULL, port_args.scheduler_input_ipc_name, False
            )
            self.send_to_tokenizer = get_zmq_socket(
                context, zmq.PUSH, port_args.tokenizer_ipc_name, False
            )

            if server_args.skip_tokenizer_init:
                # Directly send to the TokenizerManager
                self.send_to_detokenizer = get_zmq_socket(
                    context, zmq.PUSH, port_args.tokenizer_ipc_name, False
                )
            else:
                # Send to the DetokenizerManager
                self.send_to_detokenizer = get_zmq_socket(
                    context, zmq.PUSH, port_args.detokenizer_ipc_name, False
                )
        else:
            self.recv_from_tokenizer = None
            self.send_to_tokenizer = SimpleNamespace(send_pyobj=lambda x: None)
            self.send_to_detokenizer = SimpleNamespace(send_pyobj=lambda x: None)
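            # Only attn_tp_rank 0 owns real ZMQ sockets; other ranks get these no-op
            # stubs so the rest of the scheduler can call send_pyobj() unconditionally.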

        # Init tokenizer
        self.model_config = ModelConfig(
            server_args.model_path,
            trust_remote_code=server_args.trust_remote_code,
            revision=server_args.revision,
            context_length=server_args.context_length,
            model_override_args=server_args.json_model_override_args,
            is_embedding=server_args.is_embedding,
            dtype=server_args.dtype,
            quantization=server_args.quantization,
        )
        self.is_generation = self.model_config.is_generation

        if server_args.skip_tokenizer_init:
            self.tokenizer = self.processor = None
        else:
            if self.model_config.is_multimodal:
                self.processor = get_processor(
                    server_args.tokenizer_path,
                    tokenizer_mode=server_args.tokenizer_mode,
                    trust_remote_code=server_args.trust_remote_code,
                    revision=server_args.revision,
                )
                self.tokenizer = self.processor.tokenizer
            else:
                self.tokenizer = get_tokenizer(
                    server_args.tokenizer_path,
                    tokenizer_mode=server_args.tokenizer_mode,
                    trust_remote_code=server_args.trust_remote_code,
                    revision=server_args.revision,
                )

        # Check whether overlap can be enabled
        if not self.is_generation:
            self.enable_overlap = False
            logger.info("Overlap scheduler is disabled for embedding models.")

        if self.model_config.is_multimodal:
            self.enable_overlap = False
            logger.info("Overlap scheduler is disabled for multimodal models.")

        # Launch a tensor parallel worker
        if self.enable_overlap:
            TpWorkerClass = TpModelWorkerClient
        else:
            TpWorkerClass = TpModelWorker

        self.tp_worker = TpWorkerClass(
            server_args=server_args,
            gpu_id=gpu_id,
            tp_rank=tp_rank,
            dp_rank=dp_rank,
            nccl_port=port_args.nccl_port,
        )

        # Launch a draft worker for speculative decoding
        if self.spec_algorithm.is_eagle():
            from sglang.srt.speculative.eagle_worker import EAGLEWorker

            self.draft_worker = EAGLEWorker(
                gpu_id=gpu_id,
                tp_rank=tp_rank,
                server_args=server_args,
                nccl_port=port_args.nccl_port,
                target_worker=self.tp_worker,
                dp_rank=dp_rank,
            )
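            # The EAGLE draft worker wraps the target tp_worker and drives the
            # draft/verify loop itself (see forward_batch_speculative_generation in
            # run_batch); prefill is limited to one request per batch in this mode.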
            self.prefill_only_one_req = True
        else:
            self.draft_worker = None
            self.prefill_only_one_req = False

        # Get token and memory info from the model worker
        (
            self.max_total_num_tokens,
            self.max_prefill_tokens,
            self.max_running_requests,
            self.max_req_len,
            self.max_req_input_len,
            self.random_seed,
            self.device,
            worker_global_server_args_dict,
            _,
            _,
            _,
        ) = self.tp_worker.get_worker_info()
        self.tp_cpu_group = self.tp_worker.get_tp_cpu_group()
        self.attn_tp_cpu_group = self.tp_worker.get_attention_tp_cpu_group()
        self.pad_input_ids_func = self.tp_worker.get_pad_input_ids_func()
        global_server_args_dict.update(worker_global_server_args_dict)
        set_random_seed(self.random_seed)

        # Print debug info
        logger.info(
            f"max_total_num_tokens={self.max_total_num_tokens}, "
            f"chunked_prefill_size={server_args.chunked_prefill_size}, "
            f"max_prefill_tokens={self.max_prefill_tokens}, "
            f"max_running_requests={self.max_running_requests}, "
            f"context_len={self.model_config.context_len}"
        )

        # Init memory pool and cache
        self.req_to_token_pool, self.token_to_kv_pool = self.tp_worker.get_memory_pool()

        if (
            server_args.chunked_prefill_size is not None
            and server_args.disable_radix_cache
        ):
            self.tree_cache = ChunkCache(
                req_to_token_pool=self.req_to_token_pool,
                token_to_kv_pool=self.token_to_kv_pool,
            )
        else:
            if self.enable_hierarchical_cache:
                self.tree_cache = HiRadixCache(
                    req_to_token_pool=self.req_to_token_pool,
                    token_to_kv_pool=self.token_to_kv_pool,
                )
            else:
                self.tree_cache = RadixCache(
                    req_to_token_pool=self.req_to_token_pool,
                    token_to_kv_pool=self.token_to_kv_pool,
                    disable=server_args.disable_radix_cache,
                )

        self.policy = SchedulePolicy(self.schedule_policy, self.tree_cache)

        # Init running status
        self.waiting_queue: List[Req] = []
        self.staging_reqs = {}
        # The running decoding batch for continuous batching
        self.running_batch: Optional[ScheduleBatch] = None
        # The current forward batch
        self.cur_batch: Optional[ScheduleBatch] = None
        # The last forward batch
        self.last_batch: Optional[ScheduleBatch] = None
        self.forward_ct = 0
        self.forward_ct_decode = 0
        self.num_generated_tokens = 0
        self.spec_num_total_accepted_tokens = 0
        self.spec_num_total_forward_ct = 0
        self.cum_spec_accept_length = 0
        self.cum_spec_accept_count = 0
        self.last_decode_stats_tic = time.time()
        self.return_health_check_ct = 0
        self.current_stream = torch.get_device_module(self.device).current_stream()
        if self.device == "cpu":
            self.current_stream.synchronize = lambda: None  # No-op for CPU

        # For metrics only.
        # The largest prefill length of a single request
        self._largest_prefill_len: int = 0
        # The largest context length (prefill + generation) of a single request
        self._largest_prefill_decode_len: int = 0
        self.last_gen_throughput: float = 0.0
        self.step_time_dict = defaultdict(list)  # Dict[batch size -> step time]

        # Session info
        self.sessions: Dict[str, Session] = {}

        # Init chunked prefill
        self.chunked_prefill_size = server_args.chunked_prefill_size
        if self.chunked_prefill_size <= 0:  # -1 means disable
            self.chunked_prefill_size = None
        self.chunked_req = None
        self.is_mixed_chunk = (
            self.chunked_prefill_size is not None and server_args.enable_mixed_chunk
        )

        # Init the grammar backend for constrained generation
        self.grammar_queue: List[Req] = []
        if not server_args.skip_tokenizer_init:
            self.grammar_backend = create_grammar_backend(
                server_args, self.tokenizer, self.model_config.vocab_size
            )
        else:
            self.grammar_backend = None

        # Init new token estimation
        assert (
            server_args.schedule_conservativeness >= 0
        ), "Invalid schedule_conservativeness"

        self.init_new_token_ratio = min(
            global_config.default_init_new_token_ratio
            * server_args.schedule_conservativeness,
            1.0,
        )
        self.min_new_token_ratio = min(
            self.init_new_token_ratio
            * global_config.default_min_new_token_ratio_factor,
            1.0,
        )
        self.new_token_ratio_decay = (
            self.init_new_token_ratio - self.min_new_token_ratio
        ) / global_config.default_new_token_ratio_decay_steps
        self.new_token_ratio = self.init_new_token_ratio
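        # Rough sketch of the schedule: new_token_ratio starts at init_new_token_ratio,
        # decays linearly by new_token_ratio_decay per retract-free decode step down to
        # min_new_token_ratio, and is pushed back up when a decode batch has to retract
        # requests (see update_running_batch) or reset to the initial value when idle.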

        # Tell whether the current running batch is full so that we can skip
        # the check of whether to prefill new requests.
        # This is an optimization to reduce the overhead of the prefill check.
        self.batch_is_full = False

        # Init watchdog thread
        self.watchdog_timeout = server_args.watchdog_timeout
        t = threading.Thread(target=self.watchdog_thread, daemon=True)
        t.start()
        self.parent_process = psutil.Process().parent()

        # Init memory saver
        self.memory_saver_adapter = TorchMemorySaverAdapter.create(
            enable=server_args.enable_memory_saver
        )

        # Init profiler
        self.torch_profiler = None
        self.torch_profiler_output_dir: Optional[str] = None
        self.torch_profiler_activities: Optional[List[str]] = None
        self.profiler_target_forward_ct: Optional[int] = None

        # Init metrics stats
        self.stats = SchedulerStats()
        if self.enable_metrics:
            self.metrics_collector = SchedulerMetricsCollector(
                labels={
                    "model_name": self.server_args.served_model_name,
                    # TODO: Add lora name/path in the future,
                },
            )

        # Init request dispatcher
        self._request_dispatcher = TypeBasedDispatcher(
            [
                (TokenizedGenerateReqInput, self.handle_generate_request),
                (TokenizedEmbeddingReqInput, self.handle_embedding_request),
                (FlushCacheReq, self.flush_cache_wrapped),
                (AbortReq, self.abort_request),
                (OpenSessionReqInput, self.open_session),
                (CloseSessionReqInput, self.close_session),
                (UpdateWeightFromDiskReqInput, self.update_weights_from_disk),
                (InitWeightsUpdateGroupReqInput, self.init_weights_update_group),
                (
                    UpdateWeightsFromDistributedReqInput,
                    self.update_weights_from_distributed,
                ),
                (UpdateWeightsFromTensorReqInput, self.update_weights_from_tensor),
                (GetWeightsByNameReqInput, self.get_weights_by_name),
                (ReleaseMemoryOccupationReqInput, self.release_memory_occupation),
                (ResumeMemoryOccupationReqInput, self.resume_memory_occupation),
                (ProfileReq, self.profile),
                (GetInternalStateReq, self.get_internal_state),
                (SetInternalStateReq, self.set_internal_state),
            ]
        )

    def watchdog_thread(self):
        """A watchdog thread that kills the server if one forward batch takes too long."""
        self.watchdog_last_forward_ct = 0
        self.watchdog_last_time = time.time()

        while True:
            current = time.time()
            if self.cur_batch is not None:
                if self.watchdog_last_forward_ct == self.forward_ct:
                    if current > self.watchdog_last_time + self.watchdog_timeout:
                        logger.error(f"Watchdog timeout ({self.watchdog_timeout=})")
                        break
                else:
                    self.watchdog_last_forward_ct = self.forward_ct
                    self.watchdog_last_time = current
            time.sleep(self.watchdog_timeout // 2)

        # Print batch size and memory pool info to check whether there are de-sync issues.
        logger.error(
            f"{self.cur_batch.batch_size()=}, "
            f"{self.cur_batch.reqs=}, "
            f"{self.token_to_kv_pool.available_size()=}, "
            f"{self.tree_cache.evictable_size()=}, "
        )
        # Wait for some time so that the parent process can print the error.
        pyspy_dump_schedulers()
        print(file=sys.stderr, flush=True)
        print(file=sys.stdout, flush=True)
        time.sleep(5)
        self.parent_process.send_signal(signal.SIGQUIT)

    @torch.no_grad()
    def event_loop_normal(self):
        """A normal scheduler loop."""
        while True:
            recv_reqs = self.recv_requests()
            self.process_input_requests(recv_reqs)

            batch = self.get_next_batch_to_run()
            self.cur_batch = batch

            if batch:
                result = self.run_batch(batch)
                self.process_batch_result(batch, result)
            else:
                # When the server is idle, run self-checks and re-initialize some states
                self.check_memory()
                self.new_token_ratio = self.init_new_token_ratio

            self.last_batch = batch

    @torch.no_grad()
    def event_loop_overlap(self):
        """A scheduler loop that overlaps the CPU processing and GPU computation."""
        self.result_queue = deque()
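        # Pipeline sketch of the loop below: the batch launched in iteration k is pushed
        # onto result_queue, and its CPU-side result processing happens in iteration k+1
        # while the GPU is already running batch k+1, hence the one-step delay.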

        while True:
            recv_reqs = self.recv_requests()
            self.process_input_requests(recv_reqs)

            batch = self.get_next_batch_to_run()
            self.cur_batch = batch

            if batch:
                result = self.run_batch(batch)
                self.result_queue.append((batch.copy(), result))

                if self.last_batch is None:
                    # Create a dummy first batch to start the pipeline for overlap schedule.
                    # It is now used for triggering the sampling_info_done event.
                    tmp_batch = ScheduleBatch(
                        reqs=None,
                        forward_mode=ForwardMode.DUMMY_FIRST,
                        next_batch_sampling_info=self.tp_worker.cur_sampling_info,
                    )
                    self.process_batch_result(tmp_batch, None)

            if self.last_batch:
                # Process the results of the last batch
                tmp_batch, tmp_result = self.result_queue.popleft()
                tmp_batch.next_batch_sampling_info = (
                    self.tp_worker.cur_sampling_info if batch else None
                )
                self.process_batch_result(tmp_batch, tmp_result)
            elif batch is None:
                # When the server is idle, run self-checks and re-initialize some states
                self.check_memory()
                self.new_token_ratio = self.init_new_token_ratio

            self.last_batch = batch

    def recv_requests(self) -> List[Req]:
        """Receive requests at attn_tp_rank = 0 and broadcast them to all other TP ranks."""
        if self.attn_tp_rank == 0:
            recv_reqs = []

            while True:
                try:
                    recv_req = self.recv_from_tokenizer.recv_pyobj(zmq.NOBLOCK)
                except zmq.ZMQError:
                    break
                recv_reqs.append(recv_req)
        else:
            recv_reqs = None
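        # Non-zero attn-TP ranks do not read from ZMQ; they receive the requests through
        # the broadcast_pyobj calls below.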

        if self.server_args.enable_dp_attention:
            if self.attn_tp_rank == 0:
                work_reqs = [
                    req
                    for req in recv_reqs
                    if isinstance(
                        req, (TokenizedGenerateReqInput, TokenizedEmbeddingReqInput)
                    )
                ]
                control_reqs = [
                    req
                    for req in recv_reqs
                    if not isinstance(
                        req, (TokenizedGenerateReqInput, TokenizedEmbeddingReqInput)
                    )
                ]
            else:
                work_reqs = None
                control_reqs = None

            if self.attn_tp_size != 1:
                attn_tp_rank_0 = self.dp_rank * self.attn_tp_size
                work_reqs = broadcast_pyobj(
                    work_reqs,
                    self.attn_tp_rank,
                    self.attn_tp_cpu_group,
                    src=attn_tp_rank_0,
                )
            if self.tp_size != 1:
                control_reqs = broadcast_pyobj(
                    control_reqs, self.tp_rank, self.tp_cpu_group
                )
            recv_reqs = work_reqs + control_reqs
        elif self.tp_size != 1:
            recv_reqs = broadcast_pyobj(recv_reqs, self.tp_rank, self.tp_cpu_group)
        return recv_reqs

    def process_input_requests(self, recv_reqs: List):
        for recv_req in recv_reqs:
            # If it is a health check generation request and there are running requests, ignore it.
            if is_health_check_generate_req(recv_req) and (
                self.chunked_req is not None or self.running_batch is not None
            ):
                self.return_health_check_ct += 1
                continue

            output = self._request_dispatcher(recv_req)
            if output is not None:
                self.send_to_tokenizer.send_pyobj(output)

    def handle_generate_request(
        self,
        recv_req: TokenizedGenerateReqInput,
    ):
        # Create a new request
        if (
            recv_req.session_params is None
            or recv_req.session_params.id is None
            or recv_req.session_params.id not in self.sessions
        ):
            if recv_req.input_embeds is not None:
                # Generate fake input_ids based on the length of input_embeds
                seq_length = len(recv_req.input_embeds)
                fake_input_ids = [1] * seq_length
                recv_req.input_ids = fake_input_ids

            # Handle custom logit processor passed to the request
            custom_logit_processor = recv_req.custom_logit_processor
            if (
                not self.server_args.enable_custom_logit_processor
                and custom_logit_processor is not None
            ):
                logger.warning(
                    "The SGLang server is not configured to enable custom logit processor."
                    "The custom logit processor passed in will be ignored."
                    "Please set --enable-custom-logits-processor to enable this feature."
                )
                custom_logit_processor = None

            req = Req(
                recv_req.rid,
                recv_req.input_text,
                recv_req.input_ids,
                recv_req.sampling_params,
                return_logprob=recv_req.return_logprob,
                top_logprobs_num=recv_req.top_logprobs_num,
                token_ids_logprob=recv_req.token_ids_logprob,
                stream=recv_req.stream,
                lora_path=recv_req.lora_path,
                input_embeds=recv_req.input_embeds,
                custom_logit_processor=custom_logit_processor,
                return_hidden_states=recv_req.return_hidden_states,
                eos_token_ids=self.model_config.hf_eos_token_id,
            )
            req.tokenizer = self.tokenizer

            if (
                recv_req.session_params is not None
                and recv_req.session_params.id is not None
            ):
                req.finished_reason = FINISH_ABORT(
                    f"Invalid request: session id {recv_req.session_params.id} does not exist"
                )
                self._add_request_to_queue(req)
                return
        else:
            # Create a new request from a previous session
            session = self.sessions[recv_req.session_params.id]
            req = session.create_req(recv_req, self.tokenizer)
            if isinstance(req.finished_reason, FINISH_ABORT):
                self._add_request_to_queue(req)
                return

        # Handle multimodal inputs
        if recv_req.image_inputs is not None:
            image_inputs = ImageInputs.from_dict(recv_req.image_inputs)
            # Expand a single image token into multiple dummy tokens for receiving image embeddings
            req.origin_input_ids = self.pad_input_ids_func(
                req.origin_input_ids, image_inputs
            )
            req.extend_image_inputs(image_inputs)

            if len(req.origin_input_ids) >= self.max_req_input_len:
                error_msg = (
                    "Multimodal prompt is too long after expanding multimodal tokens. "
                    f"After expanding {len(req.origin_input_ids_unpadded)=} => {len(req.origin_input_ids)} >= {self.max_req_input_len}."
                )
                logger.error(error_msg)
                req.origin_input_ids = [0]
                req.image_inputs = None
                req.sampling_params.max_new_tokens = 0
                req.finished_reason = FINISH_ABORT(
                    error_msg, HTTPStatus.BAD_REQUEST, "BadRequestError"
                )
                self._add_request_to_queue(req)
                return

        # Validate prompt length
        error_msg = validate_input_length(
            req,
            self.max_req_input_len,
            self.server_args.allow_auto_truncate,
        )
        if error_msg:
            req.origin_input_ids = [0]
            req.sampling_params.max_new_tokens = 0
            self._add_request_to_queue(req)
            return

        # Copy more attributes
        if recv_req.logprob_start_len == -1 or not recv_req.return_logprob:
            # By default, only return the logprobs for output tokens
            req.logprob_start_len = len(req.origin_input_ids) - 1
        else:
            req.logprob_start_len = recv_req.logprob_start_len

        if req.logprob_start_len >= len(req.origin_input_ids):
            req.finished_reason = FINISH_ABORT(
                f"logprob_start_len, ({req.logprob_start_len}) is higher than the number of input tokens ({len(req.origin_input_ids)}). Request with a lower logprob_start_len.",
                HTTPStatus.BAD_REQUEST,
                "BadRequestError",
            )
            req.logprob_start_len = len(req.origin_input_ids) - 1
            self._add_request_to_queue(req)
            return

        req.sampling_params.max_new_tokens = min(
            (
                req.sampling_params.max_new_tokens
                if req.sampling_params.max_new_tokens is not None
                else 1 << 30
            ),
            self.max_req_len - len(req.origin_input_ids) - 1,
        )

        # Init grammar cache for this request
        add_to_grammar_queue = False
        if (
            req.sampling_params.json_schema is not None
            or req.sampling_params.regex is not None
            or req.sampling_params.ebnf is not None
            or req.sampling_params.structural_tag is not None
        ):
            assert self.grammar_backend is not None
            if req.sampling_params.json_schema is not None:
                key = ("json", req.sampling_params.json_schema)
            elif req.sampling_params.regex is not None:
                key = ("regex", req.sampling_params.regex)
            elif req.sampling_params.ebnf is not None:
                key = ("ebnf", req.sampling_params.ebnf)
            elif req.sampling_params.structural_tag:
                key = ("structural_tag", req.sampling_params.structural_tag)

            req.grammar = self.grammar_backend.get_cached_value(key)
            if not req.grammar:
                req.grammar = self.grammar_backend.get_future_value(key)
                add_to_grammar_queue = True
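        # If the grammar object is not cached yet, the request holds a future value and
        # waits in grammar_queue; get_new_batch_prefill() calls
        # move_ready_grammar_requests() to move it on once the grammar is ready.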

        if add_to_grammar_queue:
            self.grammar_queue.append(req)
        else:
            self._add_request_to_queue(req)

    def _add_request_to_queue(self, req: Req):
        self.waiting_queue.append(req)

    def _extend_requests_to_queue(self, reqs: List[Req]):
        self.waiting_queue.extend(reqs)

    def handle_embedding_request(
        self,
        recv_req: TokenizedEmbeddingReqInput,
    ):
        req = Req(
            recv_req.rid,
            recv_req.input_text,
            recv_req.input_ids,
            recv_req.sampling_params,
        )
        req.tokenizer = self.tokenizer

        # Validate prompt length
        error_msg = validate_input_length(
            req,
            self.max_req_input_len,
            self.server_args.allow_auto_truncate,
        )
        if error_msg:
            self._add_request_to_queue(req)
            return

        # Copy more attributes
        req.logprob_start_len = len(req.origin_input_ids) - 1
        self._add_request_to_queue(req)

    def log_prefill_stats(
        self,
        adder: PrefillAdder,
        can_run_list: List[Req],
        running_bs: int,
    ):
        num_used = self.max_total_num_tokens - (
            self.token_to_kv_pool.available_size() + self.tree_cache.evictable_size()
        )
        self._largest_prefill_len = max(
            self._largest_prefill_len, adder.log_input_tokens
        )

        f = (
            f"Prefill batch. "
            f"#new-seq: {len(can_run_list)}, "
            f"#new-token: {adder.log_input_tokens}, "
            f"#cached-token: {adder.log_hit_tokens}, "
            f"token usage: {num_used / self.max_total_num_tokens:.2f}, "
            f"#running-req: {running_bs}, "
            f"#queue-req: {len(self.waiting_queue)}, "
        )
        logger.info(f)

        if self.enable_metrics:
            cache_hit_rate = adder.log_hit_tokens / (
                adder.log_input_tokens + adder.log_hit_tokens
            )
            self.stats.num_running_reqs = running_bs
            self.stats.num_used_tokens = num_used
            self.stats.token_usage = round(num_used / self.max_total_num_tokens, 2)
            self.stats.num_queue_reqs = len(self.waiting_queue)
            self.stats.cache_hit_rate = cache_hit_rate
            self.metrics_collector.log_stats(self.stats)

    def log_decode_stats(self):
        gap_latency = time.time() - self.last_decode_stats_tic
        self.last_decode_stats_tic = time.time()
        self.last_gen_throughput = self.num_generated_tokens / gap_latency
        self.num_generated_tokens = 0
        num_running_reqs = len(self.running_batch.reqs) if self.running_batch else 0
        num_used = self.max_total_num_tokens - (
            self.token_to_kv_pool.available_size() + self.tree_cache.evictable_size()
        )

        if RECORD_STEP_TIME:
            self.step_time_dict[num_running_reqs].append(
                gap_latency / self.server_args.decode_log_interval
            )

        if self.spec_algorithm.is_none():
            msg = (
                f"Decode batch. "
                f"#running-req: {num_running_reqs}, "
                f"#token: {num_used}, "
                f"token usage: {num_used / self.max_total_num_tokens:.2f}, "
                f"gen throughput (token/s): {self.last_gen_throughput:.2f}, "
                f"largest-len: {self._largest_prefill_decode_len}, "
                f"#queue-req: {len(self.waiting_queue)}, "
            )
            spec_accept_length = 0
        else:
            spec_accept_length = (
                self.spec_num_total_accepted_tokens / self.spec_num_total_forward_ct
            )
            self.cum_spec_accept_length += self.spec_num_total_accepted_tokens
            self.cum_spec_accept_count += self.spec_num_total_forward_ct
            self.spec_num_total_accepted_tokens = self.spec_num_total_forward_ct = 0
            msg = (
                f"Decode batch. "
                f"#running-req: {num_running_reqs}, "
                f"#token: {num_used}, "
                f"token usage: {num_used / self.max_total_num_tokens:.2f}, "
                f"accept len: {spec_accept_length:.2f}, "
                f"gen throughput (token/s): {self.last_gen_throughput:.2f}, "
                f"largest-len: {self._largest_prefill_decode_len}, "
                f"#queue-req: {len(self.waiting_queue)}, "
            )

        logger.info(msg)
        if self.enable_metrics:
            self.stats.num_running_reqs = num_running_reqs
            self.stats.num_used_tokens = num_used
            self.stats.token_usage = num_used / self.max_total_num_tokens
            self.stats.cache_hit_rate = 0.0
            self.stats.gen_throughput = self.last_gen_throughput
            self.stats.num_queue_reqs = len(self.waiting_queue)
            self.stats.spec_accept_length = spec_accept_length
            self.metrics_collector.log_stats(self.stats)

    def check_memory(self):
        available_size = (
            self.token_to_kv_pool.available_size() + self.tree_cache.evictable_size()
        )
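        # Sanity check, run only when the server is idle: every KV slot should then be
        # either free in the pool or evictable in the radix cache (minus the slots the
        # hierarchical cache keeps protected); anything else indicates a leak.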
        protected_size = self.tree_cache.protected_size()
        memory_leak = available_size != (
            self.max_total_num_tokens
            if not self.enable_hierarchical_cache
            else self.max_total_num_tokens - protected_size
        )
        if memory_leak:
            msg = (
                "KV cache pool leak detected! "
                f"{available_size=}, {protected_size=}, {self.max_total_num_tokens=}\n"
            )
            warnings.warn(msg)
            if crash_on_warnings():
                raise ValueError(msg)

        if len(self.req_to_token_pool.free_slots) != self.req_to_token_pool.size:
            msg = (
                "Memory pool leak detected! "
                f"available_size={len(self.req_to_token_pool.free_slots)}, "
                f"total_size={self.req_to_token_pool.size}\n"
            )
            warnings.warn(msg)
            if crash_on_warnings():
                raise ValueError(msg)

        if (
            self.enable_metrics
            and self.attn_tp_rank == 0
            and time.time() > self.metrics_collector.last_log_time + 30
        ):
            # During idle time, also collect metrics every 30 seconds.
            num_used = self.max_total_num_tokens - (
                self.token_to_kv_pool.available_size()
                + self.tree_cache.evictable_size()
            )
            num_running_reqs = len(self.running_batch.reqs) if self.running_batch else 0
            self.stats.num_running_reqs = num_running_reqs
            self.stats.num_used_tokens = num_used
            self.stats.token_usage = num_used / self.max_total_num_tokens
            self.stats.gen_throughput = 0
            self.stats.num_queue_reqs = len(self.waiting_queue)
            self.metrics_collector.log_stats(self.stats)

    def get_next_batch_to_run(self) -> Optional[ScheduleBatch]:
        # Merge the prefill batch into the running batch
        if self.last_batch and self.last_batch.forward_mode.is_extend():
            if self.chunked_req:
                # Move the chunked request out of the batch so that we can merge
                # only finished requests to running_batch.
                self.last_batch.filter_batch(chunked_req_to_exclude=self.chunked_req)
                self.tree_cache.cache_unfinished_req(self.chunked_req)
                # chunked request keeps its rid but will get a new req_pool_idx
                self.req_to_token_pool.free(self.chunked_req.req_pool_idx)
                self.batch_is_full = False

            self.last_batch.filter_batch()
            if not self.last_batch.is_empty():
                if self.running_batch is None:
                    self.running_batch = self.last_batch
                else:
                    # merge running_batch with prefill batch
                    self.running_batch.merge_batch(self.last_batch)

        new_batch = self.get_new_batch_prefill()
        if new_batch is not None:
            # Run prefill first if possible
            ret = new_batch
        else:
            # Run decode
            if self.running_batch is None:
                ret = None
            else:
                self.running_batch = self.update_running_batch(self.running_batch)
                ret = self.running_batch
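        # Scheduling priority sketch: a new prefill batch, if one can be formed, runs
        # before the ongoing decode batch; otherwise the running decode batch is
        # updated and reused.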

        # Handle DP attention
        if self.server_args.enable_dp_attention:
            ret = self.prepare_dp_attn_batch(ret)

        return ret

    def get_new_batch_prefill(self) -> Optional[ScheduleBatch]:
        # Check if the grammar is ready in the grammar queue
        if self.grammar_queue:
            self.move_ready_grammar_requests()

        # Handle the cases where prefill is not allowed
        if (
            self.batch_is_full or len(self.waiting_queue) == 0
        ) and self.chunked_req is None:
            return None

        running_bs = len(self.running_batch.reqs) if self.running_batch else 0
        if running_bs >= self.max_running_requests:
            self.batch_is_full = True
            return None

        # Get priority queue
        prefix_computed = self.policy.calc_priority(self.waiting_queue)

        # Prefill policy
        adder = PrefillAdder(
            self.tree_cache,
            self.token_to_kv_pool,
            self.running_batch,
            self.new_token_ratio,
            self.max_prefill_tokens,
            self.chunked_prefill_size,
            running_bs if self.is_mixed_chunk else 0,
        )

        is_chunked = self.chunked_req is not None
        if is_chunked:
            self.chunked_req.init_next_round_input()
            self.chunked_req = adder.add_chunked_req(self.chunked_req)

        if self.lora_paths:
            lora_set = (
                set([req.lora_path for req in self.running_batch.reqs])
                if self.running_batch is not None
                else set([])
            )
        # Get requests from the waiting queue to a new prefill batch
        for req in self.waiting_queue:
            if (
                self.lora_paths
                and len(
                    lora_set
                    | set([req.lora_path for req in adder.can_run_list])
                    | set([req.lora_path])
                )
                > self.max_loras_per_batch
            ):
                self.batch_is_full = True
                break

            if running_bs + len(adder.can_run_list) >= self.max_running_requests:
                self.batch_is_full = True
                break

            req.init_next_round_input(None if prefix_computed else self.tree_cache)

            if self.enable_hierarchical_cache and req.last_node is not None:
                if req.last_node.evicted:
                    # loading KV cache for the request
                    req.last_node, req.prefix_indices = self.tree_cache.init_load_back(
                        req.last_node,
                        req.prefix_indices,
                        adder.rem_total_tokens,
                    )
                    if req.last_node.loading:
                        # to prevent frequent cache invalidation
                        if req.rid in self.staging_reqs:
                            self.tree_cache.dec_lock_ref(self.staging_reqs[req.rid])
                        self.tree_cache.inc_lock_ref(req.last_node)
                        self.staging_reqs[req.rid] = req.last_node
                        continue
                elif req.last_node.loading:
                    if not self.tree_cache.loading_complete(req.last_node):
                        continue

                if req.rid in self.staging_reqs:
                    self.tree_cache.dec_lock_ref(self.staging_reqs[req.rid])
                    del self.staging_reqs[req.rid]

            res = adder.add_one_req(req, self.chunked_req)
            if res != AddReqResult.CONTINUE:
                if res == AddReqResult.NO_TOKEN:
                    if self.enable_hierarchical_cache:
                        # Set batch_is_full after making sure there are requests that can be served
                        self.batch_is_full = len(adder.can_run_list) > 0 or (
                            self.running_batch is not None
                            and not self.running_batch.is_empty()
                        )
                    else:
                        self.batch_is_full = True
                break
            if self.prefill_only_one_req:
                break

        # Update waiting queue
        can_run_list: List[Req] = adder.can_run_list
        if len(can_run_list) == 0:
            return None
        self.waiting_queue = [
            x for x in self.waiting_queue if x not in set(can_run_list)
        ]

        if adder.new_chunked_req is not None:
            assert self.chunked_req is None
            self.chunked_req = adder.new_chunked_req

        if self.chunked_req:
            self.chunked_req.is_chunked += 1

        # Print stats
        if self.attn_tp_rank == 0:
            self.log_prefill_stats(adder, can_run_list, running_bs)

        # Create a new batch
        new_batch = ScheduleBatch.init_new(
            can_run_list,
            self.req_to_token_pool,
            self.token_to_kv_pool,
            self.tree_cache,
            self.model_config,
            self.enable_overlap,
            self.spec_algorithm,
            self.server_args.enable_custom_logit_processor,
        )
        new_batch.prepare_for_extend()

        # Mixed-style chunked prefill
        if (
            self.is_mixed_chunk
            and self.running_batch is not None
            and not (new_batch.return_logprob or self.running_batch.return_logprob)
        ):
            # TODO (lianmin): support return_logprob + mixed chunked prefill
            self.running_batch.filter_batch()
            if not self.running_batch.is_empty():
                self.running_batch.prepare_for_decode()
                new_batch.mix_with_running(self.running_batch)
                new_batch.decoding_reqs = self.running_batch.reqs
            self.running_batch = None
        else:
            new_batch.decoding_reqs = None

        return new_batch

    def update_running_batch(self, batch: ScheduleBatch) -> Optional[ScheduleBatch]:
        """Update the current running decoding batch."""
        initial_bs = batch.batch_size()

        batch.filter_batch()
        if batch.is_empty():
            self.batch_is_full = False
            return None

        # Check if decode out of memory
        if not batch.check_decode_mem(self.decode_mem_cache_buf_multiplier) or (
            TEST_RETRACT and batch.batch_size() > 10
        ):
            old_ratio = self.new_token_ratio

            retracted_reqs, new_token_ratio = batch.retract_decode(self.server_args)
            self.new_token_ratio = new_token_ratio
            if self.draft_worker:
                self.draft_worker.finish_request(retracted_reqs)

            logger.info(
                "Decode out of memory happened. "
                f"#retracted_reqs: {len(retracted_reqs)}, "
                f"#new_token_ratio: {old_ratio:.4f} -> {self.new_token_ratio:.4f}"
            )
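            # Retracted requests go back to the waiting queue and will be scheduled
            # again later; new_token_ratio is adjusted so the scheduler reserves more
            # headroom before admitting new decode work.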
            self._extend_requests_to_queue(retracted_reqs)
        else:
            self.new_token_ratio = max(
                self.new_token_ratio - self.new_token_ratio_decay,
                self.min_new_token_ratio,
            )

        if batch.batch_size() < initial_bs:
            self.batch_is_full = False

        # Update batch tensors
        batch.prepare_for_decode()
        return batch

    def run_batch(
        self, batch: ScheduleBatch
    ) -> Union[GenerationBatchResult, EmbeddingBatchResult]:
        """Run a batch."""
        self.forward_ct += 1

        # Check profiler
        if (
            self.profiler_target_forward_ct
            and self.profiler_target_forward_ct <= self.forward_ct
        ):
            self.stop_profile()

        if self.is_generation:
            if self.spec_algorithm.is_none():
                model_worker_batch = batch.get_model_worker_batch()
                logits_output, next_token_ids = self.tp_worker.forward_batch_generation(
                    model_worker_batch
                )
            else:
                (
                    logits_output,
                    next_token_ids,
                    model_worker_batch,
                    num_accepted_tokens,
                ) = self.draft_worker.forward_batch_speculative_generation(batch)
                self.spec_num_total_accepted_tokens += (
                    num_accepted_tokens + batch.batch_size()
                )
                self.spec_num_total_forward_ct += batch.batch_size()
                self.num_generated_tokens += num_accepted_tokens
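                # Accounting sketch: each forward emits one token per request plus the
                # accepted draft tokens, so the average accept length reported by
                # log_decode_stats() is spec_num_total_accepted_tokens / spec_num_total_forward_ct.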
            batch.output_ids = next_token_ids
            # These 2 values are needed for processing the output, but the values can be
            # modified by overlap schedule. So we have to copy them here so that
            # we can use the correct values in output processing.
            if batch.return_logprob:
                extend_input_len_per_req = [req.extend_input_len for req in batch.reqs]
                extend_logprob_start_len_per_req = [
                    req.extend_logprob_start_len for req in batch.reqs
                ]
            else:
                extend_input_len_per_req = None
                extend_logprob_start_len_per_req = None

            ret = GenerationBatchResult(
                logits_output=logits_output,
                next_token_ids=next_token_ids,
                extend_input_len_per_req=extend_input_len_per_req,
                extend_logprob_start_len_per_req=extend_logprob_start_len_per_req,
                bid=model_worker_batch.bid,
            )
        else:  # embedding or reward model
            model_worker_batch = batch.get_model_worker_batch()
            embeddings = self.tp_worker.forward_batch_embedding(model_worker_batch)
            ret = EmbeddingBatchResult(
                embeddings=embeddings, bid=model_worker_batch.bid
            )
        return ret

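    # `process_batch_result` routes worker output back into per-request state,
    # keyed on the batch's forward mode (decode, extend/prefill, idle, or
    # dummy-first), and then answers any pending health-check probes.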
    def process_batch_result(
        self,
        batch: ScheduleBatch,
        result: Union[GenerationBatchResult, EmbeddingBatchResult],
    ):
        if batch.forward_mode.is_decode():
            self.process_batch_result_decode(batch, result)
            if batch.is_empty():
                self.running_batch = None
        elif batch.forward_mode.is_extend():
            self.process_batch_result_prefill(batch, result)
        elif batch.forward_mode.is_idle():
            if self.enable_overlap:
                self.tp_worker.resolve_batch_result(result.bid)
                if batch.next_batch_sampling_info:
                    batch.next_batch_sampling_info.update_regex_vocab_mask()
                    self.current_stream.synchronize()
                    batch.next_batch_sampling_info.sampling_info_done.set()
        elif batch.forward_mode.is_dummy_first():
            batch.next_batch_sampling_info.update_regex_vocab_mask()
            self.current_stream.synchronize()
            batch.next_batch_sampling_info.sampling_info_done.set()

        if self.return_health_check_ct:
            # Return a signal for the health check so that it is not blocked by a
            # long-context prefill. Note that this code path does not check the
            # status of the detokenizer manager.
            self.return_health_check_ct -= 1
            self.send_to_tokenizer.send_pyobj(HealthCheckOutput())

    def process_batch_result_prefill(
        self,
        batch: ScheduleBatch,
        result: Union[GenerationBatchResult, EmbeddingBatchResult],
    ):
        skip_stream_req = None

        if self.is_generation:
            (
                logits_output,
                next_token_ids,
                extend_input_len_per_req,
                extend_logprob_start_len_per_req,
                bid,
            ) = (
                result.logits_output,
                result.next_token_ids,
                result.extend_input_len_per_req,
                result.extend_logprob_start_len_per_req,
                result.bid,
            )

            if self.enable_overlap:
                logits_output, next_token_ids = self.tp_worker.resolve_batch_result(bid)
            else:
                # Move next_token_ids and logprobs to cpu
                next_token_ids = next_token_ids.tolist()
                if batch.return_logprob:
                    if logits_output.next_token_logprobs is not None:
                        logits_output.next_token_logprobs = (
                            logits_output.next_token_logprobs.tolist()
                        )
                    if logits_output.input_token_logprobs is not None:
                        logits_output.input_token_logprobs = tuple(
                            logits_output.input_token_logprobs.tolist()
                        )

            hidden_state_offset = 0

            # Check finish conditions
            logprob_pt = 0
            for i, (req, next_token_id) in enumerate(zip(batch.reqs, next_token_ids)):
                if req.is_retracted:
                    continue

                if self.is_mixed_chunk and self.enable_overlap and req.finished():
                    # Free the one delayed token for the mixed decode batch
                    j = len(batch.out_cache_loc) - len(batch.reqs) + i
                    self.token_to_kv_pool.free(batch.out_cache_loc[j : j + 1])
                    continue

                if req.is_chunked <= 0:
                    # The request's output_ids are appended here
                    req.output_ids.append(next_token_id)
                    req.check_finished()

                    if req.finished():
                        self.tree_cache.cache_finished_req(req)
                    elif not batch.decoding_reqs or req not in batch.decoding_reqs:
                        # This updates the radix cache so that other requests can match it
                        self.tree_cache.cache_unfinished_req(req)

                    if req.return_logprob:
                        assert extend_logprob_start_len_per_req is not None
                        assert extend_input_len_per_req is not None
                        extend_logprob_start_len = extend_logprob_start_len_per_req[i]
                        extend_input_len = extend_input_len_per_req[i]
                        num_input_logprobs = extend_input_len - extend_logprob_start_len
                        self.add_logprob_return_values(
                            i,
                            req,
                            logprob_pt,
                            next_token_ids,
                            num_input_logprobs,
                            logits_output,
                        )
                        logprob_pt += num_input_logprobs

                    if (
                        req.return_hidden_states
                        and logits_output.hidden_states is not None
                    ):
                        req.hidden_states.append(
                            logits_output.hidden_states[
                                hidden_state_offset : (
                                    hidden_state_offset := hidden_state_offset
                                    + len(req.origin_input_ids)
                                )
                            ]
                            .cpu()
                            .clone()
                        )

                    if req.grammar is not None:
                        req.grammar.accept_token(next_token_id)
                        req.grammar.finished = req.finished()
                else:
                    # The prefill of this chunked request is not finished yet.
                    req.is_chunked -= 1
                    # At most one request is currently being chunked. Since its
                    # prefill has not finished, we do not want to stream it yet.
                    skip_stream_req = req

                    # Incrementally update input logprobs.
                    if req.return_logprob:
                        extend_logprob_start_len = extend_logprob_start_len_per_req[i]
                        extend_input_len = extend_input_len_per_req[i]
                        if extend_logprob_start_len < extend_input_len:
                            # Update input logprobs.
                            num_input_logprobs = (
                                extend_input_len - extend_logprob_start_len
                            )
                            self.add_input_logprob_return_values(
                                i,
                                req,
                                logits_output,
                                logprob_pt,
                                num_input_logprobs,
                                last_prefill_chunk=False,
                            )
                            logprob_pt += num_input_logprobs

            if batch.next_batch_sampling_info:
                batch.next_batch_sampling_info.update_regex_vocab_mask()
                self.current_stream.synchronize()
                batch.next_batch_sampling_info.sampling_info_done.set()

        else:  # embedding or reward model
            embeddings, bid = result.embeddings, result.bid
            embeddings = embeddings.tolist()

            # Check finish conditions
            for i, req in enumerate(batch.reqs):
                if req.is_retracted:
                    continue

                req.embedding = embeddings[i]
                if req.is_chunked <= 0:
                    # Dummy output token for embedding models
                    req.output_ids.append(0)
                    req.check_finished()

                    if req.finished():
                        self.tree_cache.cache_finished_req(req)
                    else:
                        self.tree_cache.cache_unfinished_req(req)
                else:
                    # The prefill of this chunked request is not finished yet.
                    req.is_chunked -= 1

        self.stream_output(batch.reqs, batch.return_logprob, skip_stream_req)

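    # Decode-path result processing: appends the newly sampled token to each
    # request, frees KV cache entries for finished requests, records logprobs
    # and hidden states when requested, and periodically logs decode statistics.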
    def process_batch_result_decode(
        self,
        batch: ScheduleBatch,
        result: GenerationBatchResult,
    ):
        logits_output, next_token_ids, bid = (
            result.logits_output,
            result.next_token_ids,
            result.bid,
        )
        self.num_generated_tokens += len(batch.reqs)

        if self.enable_overlap:
            logits_output, next_token_ids = self.tp_worker.resolve_batch_result(bid)
            next_token_logprobs = logits_output.next_token_logprobs
        else:
            next_token_ids = next_token_ids.tolist()
            if batch.return_logprob:
                next_token_logprobs = logits_output.next_token_logprobs.tolist()

        self.token_to_kv_pool.free_group_begin()

        # Check finish condition
        for i, (req, next_token_id) in enumerate(zip(batch.reqs, next_token_ids)):
            if req.is_retracted:
                continue

            if self.enable_overlap and req.finished():
                # Free the one delayed token
                self.token_to_kv_pool.free(batch.out_cache_loc[i : i + 1])
                continue

            if batch.spec_algorithm.is_none():
                # The speculative worker fills output_ids itself during speculative decoding.
                req.output_ids.append(next_token_id)

            req.check_finished()
            if req.finished():
                self.tree_cache.cache_finished_req(req)

            if req.return_logprob and batch.spec_algorithm.is_none():
                # speculative worker handles logprob in speculative decoding
                req.output_token_logprobs_val.append(next_token_logprobs[i])
                req.output_token_logprobs_idx.append(next_token_id)
                if req.top_logprobs_num > 0:
                    req.output_top_logprobs_val.append(
                        logits_output.next_token_top_logprobs_val[i]
                    )
                    req.output_top_logprobs_idx.append(
                        logits_output.next_token_top_logprobs_idx[i]
                    )
                if req.token_ids_logprob is not None:
                    req.output_token_ids_logprobs_val.append(
                        logits_output.next_token_token_ids_logprobs_val[i]
                    )
                    req.output_token_ids_logprobs_idx.append(
                        logits_output.next_token_token_ids_logprobs_idx[i]
                    )

            if req.return_hidden_states and logits_output.hidden_states is not None:
                req.hidden_states.append(logits_output.hidden_states[i].cpu().clone())

            if req.grammar is not None and batch.spec_algorithm.is_none():
                req.grammar.accept_token(next_token_id)
                req.grammar.finished = req.finished()

        if batch.next_batch_sampling_info:
            batch.next_batch_sampling_info.update_regex_vocab_mask()
            self.current_stream.synchronize()
            batch.next_batch_sampling_info.sampling_info_done.set()
        self.stream_output(batch.reqs, batch.return_logprob)

        self.token_to_kv_pool.free_group_end()

        self.forward_ct_decode = (self.forward_ct_decode + 1) % (1 << 30)
        if (
            self.attn_tp_rank == 0
            and self.forward_ct_decode % self.server_args.decode_log_interval == 0
        ):
            self.log_decode_stats()

    def add_input_logprob_return_values(
        self,
        i: int,
        req: Req,
        output: LogitsProcessorOutput,
        logprob_pt: int,
        num_input_logprobs: int,
        last_prefill_chunk: bool,  # If True, it means prefill is finished.
    ):
        """Incrementally add input logprobs to `req`.

        Args:
            i: The request index in the batch.
            req: The request. Its input logprob fields are modified in place.
            output: Logits processor output used to compute input logprobs.
            logprob_pt: Offset into `output.input_token_logprobs` for this request.
            num_input_logprobs: Number of input logprobs to consume for this request.
            last_prefill_chunk: True if this is the last prefill chunk (when the
                request is chunked). Some input logprob operations should only
                happen on the last chunk (e.g., computing input token logprobs).
        """
        assert output.input_token_logprobs is not None
        if req.input_token_logprobs is None:
            req.input_token_logprobs = []
        if req.temp_input_top_logprobs_val is None:
            req.temp_input_top_logprobs_val = []
        if req.temp_input_top_logprobs_idx is None:
            req.temp_input_top_logprobs_idx = []
        if req.temp_input_token_ids_logprobs_val is None:
            req.temp_input_token_ids_logprobs_val = []
        if req.temp_input_token_ids_logprobs_idx is None:
            req.temp_input_token_ids_logprobs_idx = []

        if req.input_token_logprobs_val is not None:
            # The input logprobs have already been computed. This only happens
            # upon retraction.
            if req.top_logprobs_num > 0:
                assert req.input_top_logprobs_val is not None
            return

        # Important for performance.
        assert isinstance(output.input_token_logprobs, tuple)
        input_token_logprobs: Tuple[int] = output.input_token_logprobs
        input_token_logprobs = input_token_logprobs[
            logprob_pt : logprob_pt + num_input_logprobs
        ]
        req.input_token_logprobs.extend(input_token_logprobs)

        if req.top_logprobs_num > 0:
            req.temp_input_top_logprobs_val.append(output.input_top_logprobs_val[i])
            req.temp_input_top_logprobs_idx.append(output.input_top_logprobs_idx[i])

        if req.token_ids_logprob is not None:
            req.temp_input_token_ids_logprobs_val.append(
                output.input_token_ids_logprobs_val[i]
            )
            req.temp_input_token_ids_logprobs_idx.append(
                output.input_token_ids_logprobs_idx[i]
            )

        if last_prefill_chunk:
            input_token_logprobs = req.input_token_logprobs
            req.input_token_logprobs = None
            assert req.input_token_logprobs_val is None
            assert req.input_token_logprobs_idx is None
            assert req.input_top_logprobs_val is None
            assert req.input_top_logprobs_idx is None

            # Compute input_token_logprobs_val
            # Always pad the first one with None.
            req.input_token_logprobs_val = [None]
            req.input_token_logprobs_val.extend(input_token_logprobs)
            # The last input logprob is for sampling, so just pop it out.
            req.input_token_logprobs_val.pop()

            # Compute input_token_logprobs_idx
            input_token_logprobs_idx = req.origin_input_ids[req.logprob_start_len :]
            # Clip the padded hash values from image tokens.
            # Otherwise, it will lead to detokenization errors.
            input_token_logprobs_idx = [
                x if x < self.model_config.vocab_size - 1 else 0
                for x in input_token_logprobs_idx
            ]
            req.input_token_logprobs_idx = input_token_logprobs_idx

1576
1577
1578
            if req.top_logprobs_num > 0:
                req.input_top_logprobs_val = [None]
                req.input_top_logprobs_idx = [None]
                assert len(req.temp_input_token_ids_logprobs_val) == len(
                    req.temp_input_token_ids_logprobs_idx
                )
                for val, idx in zip(
                    req.temp_input_top_logprobs_val, req.temp_input_top_logprobs_idx
                ):
                    req.input_top_logprobs_val.extend(val)
                    req.input_top_logprobs_idx.extend(idx)

                # Last token is a sample token.
                req.input_top_logprobs_val.pop()
                req.input_top_logprobs_idx.pop()
                req.temp_input_top_logprobs_idx = None
                req.temp_input_top_logprobs_val = None

            if req.token_ids_logprob is not None:
                req.input_token_ids_logprobs_val = [None]
                req.input_token_ids_logprobs_idx = [None]

                for val, idx in zip(
                    req.temp_input_token_ids_logprobs_val,
                    req.temp_input_token_ids_logprobs_idx,
                    strict=True,
                ):
                    req.input_token_ids_logprobs_val.extend(val)
                    req.input_token_ids_logprobs_idx.extend(idx)

                # Last token is a sample token.
                req.input_token_ids_logprobs_val.pop()
                req.input_token_ids_logprobs_idx.pop()
                req.temp_input_token_ids_logprobs_idx = None
                req.temp_input_token_ids_logprobs_val = None

            if req.return_logprob:
                relevant_tokens_len = len(req.origin_input_ids) - req.logprob_start_len
                assert len(req.input_token_logprobs_val) == relevant_tokens_len
                assert len(req.input_token_logprobs_idx) == relevant_tokens_len
                if req.top_logprobs_num > 0:
                    assert len(req.input_top_logprobs_val) == relevant_tokens_len
                    assert len(req.input_top_logprobs_idx) == relevant_tokens_len
                if req.token_ids_logprob is not None:
                    assert len(req.input_token_ids_logprobs_val) == relevant_tokens_len
                    assert len(req.input_token_ids_logprobs_idx) == relevant_tokens_len

    def add_logprob_return_values(
        self,
        i: int,
        req: Req,
        pt: int,
        next_token_ids: List[int],
        num_input_logprobs: int,
        output: LogitsProcessorOutput,
    ):
        """Attach logprobs to the return values."""
        req.output_token_logprobs_val.append(output.next_token_logprobs[i])
        req.output_token_logprobs_idx.append(next_token_ids[i])

        self.add_input_logprob_return_values(
            i, req, output, pt, num_input_logprobs, last_prefill_chunk=True
        )

        if req.top_logprobs_num > 0:
            req.output_top_logprobs_val.append(output.next_token_top_logprobs_val[i])
            req.output_top_logprobs_idx.append(output.next_token_top_logprobs_idx[i])

        if req.token_ids_logprob is not None:
            req.output_token_ids_logprobs_val.append(
                output.next_token_token_ids_logprobs_val[i]
            )
            req.output_token_ids_logprobs_idx.append(
                output.next_token_token_ids_logprobs_idx[i]
            )

        return num_input_logprobs

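    # `stream_output` batches per-request outputs (token ids, detokenizer state,
    # optional logprobs and hidden states) into a single BatchTokenIDOut /
    # BatchEmbeddingOut message for the detokenizer process.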
    def stream_output(
        self, reqs: List[Req], return_logprob: bool, skip_req: Optional[Req] = None
    ):
        """Stream the output to detokenizer."""
        rids = []
        finished_reasons: List[BaseFinishReason] = []

        if self.is_generation:
            decoded_texts = []
            decode_ids_list = []
            read_offsets = []
            output_ids = []

            skip_special_tokens = []
            spaces_between_special_tokens = []
            no_stop_trim = []
            prompt_tokens = []
            completion_tokens = []
            cached_tokens = []
            spec_verify_ct = []
            output_hidden_states = None

            if return_logprob:
                input_token_logprobs_val = []
                input_token_logprobs_idx = []
                output_token_logprobs_val = []
                output_token_logprobs_idx = []
                input_top_logprobs_val = []
                input_top_logprobs_idx = []
                output_top_logprobs_val = []
                output_top_logprobs_idx = []
                input_token_ids_logprobs_val = []
                input_token_ids_logprobs_idx = []
                output_token_ids_logprobs_val = []
                output_token_ids_logprobs_idx = []
            else:
                input_token_logprobs_val = input_token_logprobs_idx = (
                    output_token_logprobs_val
                ) = output_token_logprobs_idx = input_top_logprobs_val = (
                    input_top_logprobs_idx
                ) = output_top_logprobs_val = output_top_logprobs_idx = (
                    input_token_ids_logprobs_val
                ) = input_token_ids_logprobs_idx = output_token_ids_logprobs_val = (
                    output_token_ids_logprobs_idx
                ) = None

            for req in reqs:
                if req is skip_req:
                    continue

                # Multimodal partial stream chunks break the detokenizer, so drop aborted requests here.
                if self.model_config.is_multimodal_gen and req.to_abort:
                    continue

                if (
                    req.finished()
                    # If stream, follow the given stream_interval
                    or (req.stream and len(req.output_ids) % self.stream_interval == 0)
                    # If not stream, we still want to output some tokens to get the benefit of incremental decoding.
                    # TODO(lianmin): this is wrong for speculative decoding because len(req.output_ids) does not
                    # always increase one-by-one.
                    or (
                        not req.stream
                        and len(req.output_ids) % 50 == 0
                        and not self.model_config.is_multimodal_gen
                    )
                ):
                    if self.draft_worker and req.finished():
                        self.draft_worker.finish_request(req)

                    rids.append(req.rid)
                    finished_reasons.append(
                        req.finished_reason.to_json() if req.finished_reason else None
                    )
                    decoded_texts.append(req.decoded_text)
                    decode_ids, read_offset = req.init_incremental_detokenize()
                    decode_ids_list.append(decode_ids)
                    read_offsets.append(read_offset)
                    if self.skip_tokenizer_init:
                        output_ids.append(req.output_ids)
                    skip_special_tokens.append(req.sampling_params.skip_special_tokens)
                    spaces_between_special_tokens.append(
                        req.sampling_params.spaces_between_special_tokens
                    )
                    no_stop_trim.append(req.sampling_params.no_stop_trim)

                    prompt_tokens.append(len(req.origin_input_ids))
                    completion_tokens.append(len(req.output_ids))
                    cached_tokens.append(req.cached_tokens)

                    if not self.spec_algorithm.is_none():
                        spec_verify_ct.append(req.spec_verify_ct)

                    if return_logprob:
                        input_token_logprobs_val.append(req.input_token_logprobs_val)
                        input_token_logprobs_idx.append(req.input_token_logprobs_idx)
                        output_token_logprobs_val.append(req.output_token_logprobs_val)
                        output_token_logprobs_idx.append(req.output_token_logprobs_idx)
                        input_top_logprobs_val.append(req.input_top_logprobs_val)
                        input_top_logprobs_idx.append(req.input_top_logprobs_idx)
                        output_top_logprobs_val.append(req.output_top_logprobs_val)
                        output_top_logprobs_idx.append(req.output_top_logprobs_idx)
                        input_token_ids_logprobs_val.append(
                            req.input_token_ids_logprobs_val
                        )
                        input_token_ids_logprobs_idx.append(
                            req.input_token_ids_logprobs_idx
                        )
                        output_token_ids_logprobs_val.append(
                            req.output_token_ids_logprobs_val
                        )
                        output_token_ids_logprobs_idx.append(
                            req.output_token_ids_logprobs_idx
                        )

                    if req.return_hidden_states:
                        if output_hidden_states is None:
                            output_hidden_states = []
                        output_hidden_states.append(req.hidden_states)

            # Send to detokenizer
            if rids:
                if self.model_config.is_multimodal_gen:
                    raise NotImplementedError()
                self.send_to_detokenizer.send_pyobj(
                    BatchTokenIDOut(
                        rids,
                        finished_reasons,
                        decoded_texts,
                        decode_ids_list,
                        read_offsets,
                        output_ids,
                        skip_special_tokens,
                        spaces_between_special_tokens,
                        no_stop_trim,
                        prompt_tokens,
                        completion_tokens,
                        cached_tokens,
                        spec_verify_ct,
                        input_token_logprobs_val,
                        input_token_logprobs_idx,
                        output_token_logprobs_val,
                        output_token_logprobs_idx,
                        input_top_logprobs_val,
                        input_top_logprobs_idx,
                        output_top_logprobs_val,
                        output_top_logprobs_idx,
                        input_token_ids_logprobs_val,
                        input_token_ids_logprobs_idx,
                        output_token_ids_logprobs_val,
                        output_token_ids_logprobs_idx,
                        output_hidden_states,
                    )
                )
        else:  # embedding or reward model
            embeddings = []
            prompt_tokens = []
            for req in reqs:
                if req.finished():
                    rids.append(req.rid)
                    finished_reasons.append(req.finished_reason.to_json())
                    embeddings.append(req.embedding)
                    prompt_tokens.append(len(req.origin_input_ids))
            self.send_to_detokenizer.send_pyobj(
                BatchEmbeddingOut(rids, finished_reasons, embeddings, prompt_tokens)
            )

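    # DP-attention note: each rank all-gathers its local token count so ranks
    # with no work can substitute an idle batch, and an all-reduce(MIN) over a
    # decode/idle flag decides whether the whole group can take the DP
    # CUDA-graph path together.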
    def prepare_dp_attn_batch(self, local_batch: ScheduleBatch):
        # Check if other DP workers have running batches
        if local_batch is None:
            num_tokens = 0
        elif local_batch.forward_mode.is_decode():
            num_tokens = local_batch.batch_size()
        else:
            num_tokens = local_batch.extend_num_tokens

        local_num_tokens = torch.tensor([num_tokens], dtype=torch.int64)
        global_num_tokens = torch.empty(self.tp_size, dtype=torch.int64)
        torch.distributed.all_gather_into_tensor(
            global_num_tokens,
            local_num_tokens,
            group=self.tp_cpu_group,
        )

        if local_batch is None and global_num_tokens.max().item() > 0:
            local_batch = self.get_idle_batch()

        if local_batch is not None:
            local_batch.global_num_tokens = global_num_tokens.tolist()

            # Check forward mode for cuda graph
            if not self.server_args.disable_cuda_graph:
                forward_mode_state = torch.tensor(
                    (1 if local_batch.forward_mode.is_decode_or_idle() else 0),
                    dtype=torch.int32,
                )
                torch.distributed.all_reduce(
                    forward_mode_state,
                    op=torch.distributed.ReduceOp.MIN,
                    group=self.tp_cpu_group,
                )
                local_batch.can_run_dp_cuda_graph = forward_mode_state.item() == 1

        return local_batch

    def get_idle_batch(self):
        idle_batch = ScheduleBatch.init_new(
            [],
            self.req_to_token_pool,
            self.token_to_kv_pool,
            self.tree_cache,
            self.model_config,
            self.enable_overlap,
            self.spec_algorithm,
            self.server_args.enable_custom_logit_processor,
        )
        idle_batch.prepare_for_idle()
        return idle_batch

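    # Grammar queue note: TP ranks may finish compiling grammars at different
    # times, so an all-reduce(MAX) below picks the largest ready count and the
    # slower ranks block on `.result()` to catch up, keeping the set of
    # dispatched requests identical across ranks.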
    def move_ready_grammar_requests(self):
        """Move requests whose grammar objects are ready from grammar_queue to waiting_queue."""
        num_ready_reqs = 0
        for req in self.grammar_queue:
            try:
                req.grammar = req.grammar.result(timeout=0.05)
                num_ready_reqs += 1
            except futures._base.TimeoutError:
                break

        if self.server_args.enable_dp_attention:
            tp_size = self.attn_tp_size
            tp_group = self.attn_tp_cpu_group
        else:
            tp_size = self.tp_size
            tp_group = self.tp_cpu_group

        if tp_size > 1:
            # Sync across TP ranks to make sure they have the same number of ready requests
            tensor = torch.tensor(num_ready_reqs, dtype=torch.int32)
            torch.distributed.all_reduce(
                tensor, op=torch.distributed.ReduceOp.MAX, group=tp_group
            )
            num_ready_reqs_max = tensor.item()
            for i in range(num_ready_reqs, num_ready_reqs_max):
                self.grammar_queue[i].grammar = self.grammar_queue[i].grammar.result()
            num_ready_reqs = num_ready_reqs_max

        self._extend_requests_to_queue(self.grammar_queue[:num_ready_reqs])
        self.grammar_queue = self.grammar_queue[num_ready_reqs:]

    def flush_cache_wrapped(self, recv_req: FlushCacheReq):
        self.flush_cache()

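    # Cache flushing only happens when there are no queued or running requests;
    # otherwise `flush_cache` logs a warning and returns False.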
    def flush_cache(self):
        """Flush the memory pool and cache."""
        if len(self.waiting_queue) == 0 and (
            self.running_batch is None or len(self.running_batch.reqs) == 0
        ):
            self.cur_batch = None
            self.last_batch = None
            self.tree_cache.reset()
            self.tree_cache_metrics = {"total": 0, "hit": 0}
            if self.grammar_backend:
                self.grammar_backend.reset()
            self.req_to_token_pool.clear()
            self.token_to_kv_pool.clear()

            if not self.spec_algorithm.is_none():
                self.draft_worker.model_runner.req_to_token_pool.clear()
                self.draft_worker.model_runner.token_to_kv_pool.clear()

            self.num_generated_tokens = 0
            self.forward_ct_decode = 0
            self.spec_num_total_accepted_tokens = 0
            self.spec_num_total_forward_ct = 0
            self.cum_spec_accept_length = 0
            self.cum_spec_accept_count = 0
            torch.cuda.empty_cache()
            logger.info("Cache flushed successfully!")
            if_success = True
        else:
            logger.warning(
                f"Cache not flushed because there are pending requests. "
                f"#queue-req: {len(self.waiting_queue)}, "
                f"#running-req: {0 if self.running_batch is None else len(self.running_batch.reqs)}"
            )
            if_success = False
        return if_success

    def get_internal_state(self, recv_req: GetInternalStateReq):
        ret = dict(global_server_args_dict)
        ret["last_gen_throughput"] = self.last_gen_throughput
        if not self.spec_algorithm.is_none() and self.cum_spec_accept_count > 0:
            ret["avg_spec_accept_length"] = (
                self.cum_spec_accept_length / self.cum_spec_accept_count
            )

        if RECORD_STEP_TIME:
            ret["step_time_dict"] = self.step_time_dict
        return GetInternalStateReqOutput(
            internal_state=ret,
        )

    def set_internal_state(self, recv_req: SetInternalStateReq):
        server_args_dict = recv_req.server_args
        args_allow_update = set(
            [
                "speculative_accept_threshold_single",
                "speculative_accept_threshold_acc",
            ]
        )
        if_success = True
        for k, v in server_args_dict.items():
            if k not in args_allow_update:
                logger.warning(f"Updating {k} is not supported.")
                if_success = False
                break
        if if_success:
            if not self.spec_algorithm.is_none() and self.cum_spec_accept_count > 0:
                avg_spec_accept_length = (
                    self.cum_spec_accept_length / self.cum_spec_accept_count
                )
                logger.info(f"{avg_spec_accept_length=}")
            self.cum_spec_accept_length = self.cum_spec_accept_count = 0
            for k, v in server_args_dict.items():
                global_server_args_dict[k] = v
            logger.info(f"Global server args updated! " f"{global_server_args_dict=}")
        return SetInternalStateReqOutput(
            updated=if_success,
            server_args=global_server_args_dict,
        )

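    # Abort handling: a queued request is simply dropped from the waiting queue;
    # a running request is only flagged with `to_abort` and gets cleaned up by
    # the normal batch-processing path.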
    def abort_request(self, recv_req: AbortReq):
        # Delete requests in the waiting queue
        to_del = None
        for i, req in enumerate(self.waiting_queue):
            if req.rid == recv_req.rid:
                to_del = i
                break

        if to_del is not None:
            del self.waiting_queue[to_del]
            logger.debug(f"Abort queued request. {req.rid=}")
            return

        # Delete requests in the running batch
        if self.running_batch:
            for req in self.running_batch.reqs:
                if req.rid == recv_req.rid and not req.finished():
                    logger.debug(f"Abort running request. {req.rid=}")
                    req.to_abort = True
                    break

    def update_weights_from_disk(self, recv_req: UpdateWeightFromDiskReqInput):
        """In-place update of the weights from disk."""
        success, message = self.tp_worker.update_weights_from_disk(recv_req)
        if success:
            flush_cache_success = self.flush_cache()
            assert flush_cache_success, "Cache flush failed after updating weights"
        else:
            logger.error(message)
        return UpdateWeightFromDiskReqOutput(success, message, 0)

    def init_weights_update_group(self, recv_req: InitWeightsUpdateGroupReqInput):
        """Initialize the online model parameter update group."""
        success, message = self.tp_worker.init_weights_update_group(recv_req)
        return InitWeightsUpdateGroupReqOutput(success, message)

    def update_weights_from_distributed(
        self,
        recv_req: UpdateWeightsFromDistributedReqInput,
    ) -> UpdateWeightsFromDistributedReqOutput:
        """Update the online model parameter."""
        success, message = self.tp_worker.update_weights_from_distributed(recv_req)
        if success:
            flush_cache_success = self.flush_cache()
            assert flush_cache_success, "Cache flush failed after updating weights"
        else:
            logger.error(message)
        return UpdateWeightsFromDistributedReqOutput(success, message)

    def update_weights_from_tensor(self, recv_req: UpdateWeightsFromTensorReqInput):
        """Update the online model parameter from tensors."""
        success, message = self.tp_worker.update_weights_from_tensor(recv_req)
        # TODO extract common code b/t update_weights_from_distributed and update_weights_from_tensor later
        if success:
            if recv_req.flush_cache:
                flush_cache_success = self.flush_cache()
                assert flush_cache_success, "Cache flush failed after updating weights"
        else:
            logger.error(message)
        return UpdateWeightsFromTensorReqOutput(success, message)

    def get_weights_by_name(self, recv_req: GetWeightsByNameReqInput):
        parameter = self.tp_worker.get_weights_by_name(recv_req)
        return GetWeightsByNameReqOutput(parameter)

    def release_memory_occupation(self, recv_req: ReleaseMemoryOccupationReqInput):
        self.stashed_model_static_state = _export_static_state(
            self.tp_worker.worker.model_runner.model
        )
        self.memory_saver_adapter.pause()
        self.flush_cache()
        return ReleaseMemoryOccupationReqOutput()

    def resume_memory_occupation(self, recv_req: ResumeMemoryOccupationReqInput):
        self.memory_saver_adapter.resume()
        _import_static_state(
            self.tp_worker.worker.model_runner.model, self.stashed_model_static_state
        )
        del self.stashed_model_static_state
        return ResumeMemoryOccupationReqOutput()

    def profile(self, recv_req: ProfileReq):
        if recv_req.type == ProfileReqType.START_PROFILE:
            return self.start_profile(
                recv_req.output_dir, recv_req.num_steps, recv_req.activities
            )
        else:
            return self.stop_profile()

    def start_profile(
        self,
        output_dir: Optional[str],
        num_steps: Optional[int],
        activities: Optional[List[str]],
    ) -> Optional[ProfileReqOutput]:
        if self.torch_profiler_activities:
            return ProfileReqOutput(
                success=False,
                message="Profiling is already in progress. Call /stop_profile first.",
            )

        if output_dir is None:
            output_dir = os.getenv("SGLANG_TORCH_PROFILER_DIR", "/tmp")
        if activities is None:
            activities = ["CPU", "GPU"]

        self.torch_profiler_output_dir = output_dir
        self.torch_profiler_activities = activities
        logger.info(
            "Profiling starts. Traces will be saved to: %s",
            self.torch_profiler_output_dir,
        )

        activity_map = {
            "CPU": torch.profiler.ProfilerActivity.CPU,
            "GPU": torch.profiler.ProfilerActivity.CUDA,
        }
        torchprof_activities = [
            activity_map[a] for a in activities if a in activity_map
        ]

        if torchprof_activities:
            self.torch_profiler = torch.profiler.profile(
                activities=torchprof_activities,
                with_stack=True,
            )
            self.torch_profiler.start()

        if "MEM" in activities:
            torch.cuda.memory._record_memory_history(max_entries=100000)
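        # "MEM" is handled outside the torch profiler: it starts a CUDA memory
        # history recording here, and stop_profile() dumps it to a pickle snapshot.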

        if num_steps:
            self.profiler_target_forward_ct = self.forward_ct + num_steps
            # The caller will be notified when reaching profiler_target_forward_ct
        else:
            self.profiler_target_forward_ct = None
            return ProfileReqOutput(success=True, message="Succeeded")

    def stop_profile(self) -> None:
        if self.torch_profiler_activities is None:
            return

        logger.info("Stop profiling...")
        if self.torch_profiler is not None:
            self.torch_profiler.stop()
            self.torch_profiler.export_chrome_trace(
                os.path.join(
                    self.torch_profiler_output_dir,
                    str(time.time()) + f"-TP-{self.tp_rank}" + ".trace.json.gz",
                )
            )

        if "MEM" in self.torch_profiler_activities:
            memory_profile_path = os.path.join(
                self.torch_profiler_output_dir,
                str(time.time()) + f"-TP-{self.tp_rank}-memory" + ".pickle",
            )
            torch.cuda.memory._dump_snapshot(memory_profile_path)
            torch.cuda.memory._record_memory_history(enabled=None)

        logger.info(
            "Profiling done. Traces are saved to: %s",
            self.torch_profiler_output_dir,
        )
        self.torch_profiler = None
        self.torch_profiler_output_dir = None
        self.torch_profiler_activities = None

        if self.profiler_target_forward_ct:
            self.send_to_tokenizer.send_pyobj(
                ProfileReqOutput(success=True, message="Succeeded.")
            )

    def open_session(self, recv_req: OpenSessionReqInput):
        # Handle errors
        session_id = recv_req.session_id
        if session_id in self.sessions:
            logger.warning(f"session id {session_id} already exists, cannot open.")
            return OpenSessionReqOutput(session_id, False)
        elif session_id is None:
            logger.warning("session id is None, cannot open.")
            return OpenSessionReqOutput(session_id, False)
        else:
            self.sessions[session_id] = Session(
                recv_req.capacity_of_str_len, session_id
            )
            return OpenSessionReqOutput(session_id, True)

    def close_session(self, recv_req: CloseSessionReqInput):
        # Handle errors
        session_id = recv_req.session_id
        if session_id not in self.sessions:
            logger.warning(f"session id {session_id} does not exist, cannot delete.")
        else:
            del self.sessions[session_id]


def is_health_check_generate_req(recv_req):
    return getattr(recv_req, "rid", "").startswith("HEALTH_CHECK")


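# Note: release/resume memory occupation above only stashes and restores the
# model's named buffers; see _export_static_state/_import_static_state below.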
def _export_static_state(model):
    return dict(
        buffers=[
            (name, buffer.detach().clone()) for name, buffer in model.named_buffers()
        ]
    )


def _import_static_state(model, static_params):
    self_named_buffers = dict(model.named_buffers())
    for name, tensor in static_params["buffers"]:
        self_named_buffers[name][...] = tensor


def run_scheduler_process(
    server_args: ServerArgs,
    port_args: PortArgs,
    gpu_id: int,
    tp_rank: int,
    dp_rank: Optional[int],
    pipe_writer,
):
    # Configure the process
    # kill_itself_when_parent_died()  # This is disabled because it does not work for `--dp 2`
    setproctitle.setproctitle(f"sglang::scheduler_{dp_rank}")
    faulthandler.enable()
    parent_process = psutil.Process().parent()

    # [For Router] If the env var "SGLANG_DP_RANK" exists, set dp_rank to its value.
    if dp_rank is None and "SGLANG_DP_RANK" in os.environ:
        dp_rank = int(os.environ["SGLANG_DP_RANK"])

    # Configure the logger
    if dp_rank is None:
        prefix = f" TP{tp_rank}"
    else:
        prefix = f" DP{dp_rank} TP{tp_rank}"
    configure_logger(server_args, prefix=prefix)
    suppress_other_loggers()

    # Set CPU affinity for this GPU process
    if get_bool_env_var("SGLANG_SET_CPU_AFFINITY"):
        set_gpu_proc_affinity(server_args.tp_size, server_args.nnodes, gpu_id)

    # Create a scheduler and run the event loop
    try:
        scheduler = Scheduler(server_args, port_args, gpu_id, tp_rank, dp_rank)
        pipe_writer.send(
            {
                "status": "ready",
                "max_total_num_tokens": scheduler.max_total_num_tokens,
                "max_req_input_len": scheduler.max_req_input_len,
            }
        )
        if scheduler.enable_overlap:
            scheduler.event_loop_overlap()
        else:
            scheduler.event_loop_normal()
    except Exception:
        traceback = get_exception_traceback()
        logger.error(f"Scheduler hit an exception: {traceback}")
        parent_process.send_signal(signal.SIGQUIT)