scheduler.py 89.9 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
13
# Copyright 2023-2024 SGLang Team
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
14
15
"""A scheduler that manages a tensor parallel GPU worker."""

16
import faulthandler
17
import logging
18
import os
19
import signal
20
import sys
Lianmin Zheng's avatar
Lianmin Zheng committed
21
import threading
22
23
import time
import warnings
24
from collections import defaultdict, deque
Lianmin Zheng's avatar
Lianmin Zheng committed
25
from concurrent import futures
26
from dataclasses import dataclass
27
from http import HTTPStatus
28
from types import SimpleNamespace
29
from typing import Dict, List, Optional, Tuple, Union
30

31
import psutil
32
import setproctitle
33
import torch
34
35
import zmq

36
from sglang.global_config import global_config
Lianmin Zheng's avatar
Lianmin Zheng committed
37
from sglang.srt.configs.model_config import ModelConfig
38
from sglang.srt.constrained.base_grammar_backend import create_grammar_backend
39
from sglang.srt.hf_transformers_utils import get_processor, get_tokenizer
40
from sglang.srt.layers.dp_attention import compute_dp_attention_world_info
41
42
43
44
from sglang.srt.layers.logits_processor import LogitsProcessorOutput
from sglang.srt.managers.io_struct import (
    AbortReq,
    BatchEmbeddingOut,
45
    BatchMultimodalDecodeReq,
46
    BatchTokenIDOut,
47
    CloseSessionReqInput,
48
    FlushCacheReq,
49
50
    GetInternalStateReq,
    GetInternalStateReqOutput,
51
52
    GetWeightsByNameReqInput,
    GetWeightsByNameReqOutput,
53
    HealthCheckOutput,
54
55
    InitWeightsUpdateGroupReqInput,
    InitWeightsUpdateGroupReqOutput,
56
57
    OpenSessionReqInput,
    OpenSessionReqOutput,
58
    ProfileReq,
59
60
    ProfileReqOutput,
    ProfileReqType,
61
62
63
64
    ReleaseMemoryOccupationReqInput,
    ReleaseMemoryOccupationReqOutput,
    ResumeMemoryOccupationReqInput,
    ResumeMemoryOccupationReqOutput,
65
66
    SetInternalStateReq,
    SetInternalStateReqOutput,
67
68
    TokenizedEmbeddingReqInput,
    TokenizedGenerateReqInput,
Chayenne's avatar
Chayenne committed
69
70
    UpdateWeightFromDiskReqInput,
    UpdateWeightFromDiskReqOutput,
71
72
    UpdateWeightsFromDistributedReqInput,
    UpdateWeightsFromDistributedReqOutput,
73
74
    UpdateWeightsFromTensorReqInput,
    UpdateWeightsFromTensorReqOutput,
75
76
77
78
79
80
81
)
from sglang.srt.managers.schedule_batch import (
    FINISH_ABORT,
    BaseFinishReason,
    ImageInputs,
    Req,
    ScheduleBatch,
82
    global_server_args_dict,
83
)
84
85
86
87
88
from sglang.srt.managers.schedule_policy import (
    AddReqResult,
    PrefillAdder,
    SchedulePolicy,
)
89
from sglang.srt.managers.session_controller import Session
90
from sglang.srt.managers.tp_worker import TpModelWorker
91
from sglang.srt.managers.tp_worker_overlap_thread import TpModelWorkerClient
92
from sglang.srt.managers.utils import validate_input_length
93
from sglang.srt.mem_cache.chunk_cache import ChunkCache
94
from sglang.srt.mem_cache.hiradix_cache import HiRadixCache
95
from sglang.srt.mem_cache.radix_cache import RadixCache
96
from sglang.srt.metrics.collector import SchedulerMetricsCollector, SchedulerStats
97
from sglang.srt.model_executor.forward_batch_info import ForwardMode
98
from sglang.srt.server_args import PortArgs, ServerArgs
99
from sglang.srt.speculative.spec_info import SpeculativeAlgorithm
100
from sglang.srt.torch_memory_saver_adapter import TorchMemorySaverAdapter
101
102
103
from sglang.srt.utils import (
    broadcast_pyobj,
    configure_logger,
104
    crash_on_warnings,
105
    get_bool_env_var,
106
    get_zmq_socket,
107
108
    kill_itself_when_parent_died,
    pyspy_dump_schedulers,
109
    set_gpu_proc_affinity,
110
111
112
    set_random_seed,
    suppress_other_loggers,
)
113
from sglang.utils import TypeBasedDispatcher, get_exception_traceback
114
115
116

logger = logging.getLogger(__name__)

117
# Test retract decode for debugging purposes
118
119
TEST_RETRACT = get_bool_env_var("SGLANG_TEST_RETRACT")
# Flag read from the SGLANG_RECORD_STEP_TIME environment variable.
# NOTE(review): presumably gates recording of per-step timings; its use is outside this view.
RECORD_STEP_TIME = get_bool_env_var("SGLANG_RECORD_STEP_TIME")
120

121

122
123
124
125
@dataclass
class GenerationBatchResult:
    """Result fields carried for one generation batch."""

    # Output object of the logits processor for this batch.
    logits_output: LogitsProcessorOutput
    # Token ids selected as the next tokens.
    next_token_ids: List[int]
    # NOTE(review): presumably one extend-input length per request -- confirm with producer.
    extend_input_len_per_req: List[int]
    # NOTE(review): presumably one logprob start offset per request -- confirm with producer.
    extend_logprob_start_len_per_req: List[int]
    # NOTE(review): presumably the batch id this result belongs to -- confirm.
    bid: int


@dataclass
class EmbeddingBatchResult:
    """Result fields carried for one embedding batch."""

    # Embeddings tensor for the batch (shape/layout is not established in this view).
    embeddings: torch.Tensor
    # NOTE(review): presumably the batch id this result belongs to -- confirm.
    bid: int


137
138
139
140
141
142
143
144
145
class Scheduler:
    """A scheduler that manages a tensor parallel GPU worker."""

    def __init__(
        self,
        server_args: ServerArgs,
        port_args: PortArgs,
        gpu_id: int,
        tp_rank: int,
        dp_rank: Optional[int],
    ):
        """Build the scheduler: IPC sockets, tokenizer, workers, caches and state.

        Args:
            server_args: Server configuration; most attributes below are copied
                or derived from it.
            port_args: IPC names and the NCCL port used to create the ZMQ sockets
                and the model workers.
            gpu_id: GPU index forwarded to the model worker(s).
            tp_rank: Tensor-parallel rank of this scheduler.
            dp_rank: Data-parallel rank forwarded to the model worker(s). Note that
                ``self.dp_rank`` is taken from ``compute_dp_attention_world_info``,
                not from this argument.
        """
        # Parse args
        self.server_args = server_args
        self.tp_rank = tp_rank
        self.tp_size = server_args.tp_size
        self.schedule_policy = server_args.schedule_policy
        self.lora_paths = server_args.lora_paths
        self.max_loras_per_batch = server_args.max_loras_per_batch
        self.enable_overlap = not server_args.disable_overlap_schedule
        self.skip_tokenizer_init = server_args.skip_tokenizer_init
        self.enable_metrics = server_args.enable_metrics
        self.stream_interval = server_args.stream_interval
        self.spec_algorithm = SpeculativeAlgorithm.from_string(
            server_args.speculative_algorithm
        )
        self.gpu_id = gpu_id
        self.enable_hierarchical_cache = server_args.enable_hierarchical_cache
        self.decode_mem_cache_buf_multiplier = (
            (
                self.server_args.speculative_num_draft_tokens
                + (
                    self.server_args.speculative_eagle_topk
                    * self.server_args.speculative_num_steps
                )
            )
            if not self.spec_algorithm.is_none()
            else 1
        )

        # Distributed rank info
        self.dp_size = server_args.dp_size
        self.attn_tp_rank, self.attn_tp_size, self.dp_rank = (
            compute_dp_attention_world_info(
                server_args.enable_dp_attention,
                self.tp_rank,
                self.tp_size,
                self.dp_size,
            )
        )

        # Init inter-process communication
        context = zmq.Context(2)
        if self.attn_tp_rank == 0:
            self.recv_from_tokenizer = get_zmq_socket(
                context, zmq.PULL, port_args.scheduler_input_ipc_name, False
            )
            self.send_to_tokenizer = get_zmq_socket(
                context, zmq.PUSH, port_args.tokenizer_ipc_name, False
            )

            if server_args.skip_tokenizer_init:
                # Directly send to the TokenizerManager
                self.send_to_detokenizer = get_zmq_socket(
                    context, zmq.PUSH, port_args.tokenizer_ipc_name, False
                )
            else:
                # Send to the DetokenizerManager
                self.send_to_detokenizer = get_zmq_socket(
                    context, zmq.PUSH, port_args.detokenizer_ipc_name, False
                )
        else:
            # Ranks other than attention-TP rank 0 own no sockets; sends become no-ops.
            self.recv_from_tokenizer = None
            self.send_to_tokenizer = SimpleNamespace(send_pyobj=lambda x: None)
            self.send_to_detokenizer = SimpleNamespace(send_pyobj=lambda x: None)

        # Init tokenizer
        self.model_config = ModelConfig(
            server_args.model_path,
            trust_remote_code=server_args.trust_remote_code,
            revision=server_args.revision,
            context_length=server_args.context_length,
            model_override_args=server_args.json_model_override_args,
            is_embedding=server_args.is_embedding,
            dtype=server_args.dtype,
            quantization=server_args.quantization,
        )
        self.is_generation = self.model_config.is_generation

        if server_args.skip_tokenizer_init:
            self.tokenizer = self.processor = None
        else:
            if self.model_config.is_multimodal:
                self.processor = get_processor(
                    server_args.tokenizer_path,
                    tokenizer_mode=server_args.tokenizer_mode,
                    trust_remote_code=server_args.trust_remote_code,
                    revision=server_args.revision,
                )
                self.tokenizer = self.processor.tokenizer
            else:
                # Keep the attribute defined on every path, like the other branches.
                self.processor = None
                self.tokenizer = get_tokenizer(
                    server_args.tokenizer_path,
                    tokenizer_mode=server_args.tokenizer_mode,
                    trust_remote_code=server_args.trust_remote_code,
                    revision=server_args.revision,
                )

        # Check whether overlap can be enabled
        if not self.is_generation:
            self.enable_overlap = False
            logger.info("Overlap scheduler is disabled for embedding models.")

        if self.model_config.is_multimodal:
            self.enable_overlap = False
            logger.info("Overlap scheduler is disabled for multimodal models.")

        # Launch a tensor parallel worker
        if self.enable_overlap:
            TpWorkerClass = TpModelWorkerClient
        else:
            TpWorkerClass = TpModelWorker

        self.tp_worker = TpWorkerClass(
            server_args=server_args,
            gpu_id=gpu_id,
            tp_rank=tp_rank,
            dp_rank=dp_rank,
            nccl_port=port_args.nccl_port,
        )

        # Launch a draft worker for speculative decoding
        if self.spec_algorithm.is_eagle():
            from sglang.srt.speculative.eagle_worker import EAGLEWorker

            self.draft_worker = EAGLEWorker(
                gpu_id=gpu_id,
                tp_rank=tp_rank,
                server_args=server_args,
                nccl_port=port_args.nccl_port,
                target_worker=self.tp_worker,
                dp_rank=dp_rank,
            )
            self.prefill_only_one_req = True
        else:
            self.draft_worker = None
            self.prefill_only_one_req = False

        # Get token and memory info from the model worker
        (
            self.max_total_num_tokens,
            self.max_prefill_tokens,
            self.max_running_requests,
            self.max_req_len,
            self.max_req_input_len,
            self.random_seed,
            self.device,
            worker_global_server_args_dict,
            _,
            _,
            _,
        ) = self.tp_worker.get_worker_info()
        self.tp_cpu_group = self.tp_worker.get_tp_cpu_group()
        self.attn_tp_cpu_group = self.tp_worker.get_attention_tp_cpu_group()
        self.pad_input_ids_func = self.tp_worker.get_pad_input_ids_func()
        global_server_args_dict.update(worker_global_server_args_dict)
        set_random_seed(self.random_seed)

        # Print debug info
        logger.info(
            f"max_total_num_tokens={self.max_total_num_tokens}, "
            f"chunked_prefill_size={server_args.chunked_prefill_size}, "
            f"max_prefill_tokens={self.max_prefill_tokens}, "
            f"max_running_requests={self.max_running_requests}, "
            f"context_len={self.model_config.context_len}"
        )

        # Init memory pool and cache
        self.req_to_token_pool, self.token_to_kv_pool = self.tp_worker.get_memory_pool()

        if (
            server_args.chunked_prefill_size is not None
            and server_args.disable_radix_cache
        ):
            self.tree_cache = ChunkCache(
                req_to_token_pool=self.req_to_token_pool,
                token_to_kv_pool=self.token_to_kv_pool,
            )
        else:
            if self.enable_hierarchical_cache:
                self.tree_cache = HiRadixCache(
                    req_to_token_pool=self.req_to_token_pool,
                    token_to_kv_pool=self.token_to_kv_pool,
                )
            else:
                self.tree_cache = RadixCache(
                    req_to_token_pool=self.req_to_token_pool,
                    token_to_kv_pool=self.token_to_kv_pool,
                    disable=server_args.disable_radix_cache,
                )

        self.policy = SchedulePolicy(self.schedule_policy, self.tree_cache)

        # Init running status
        self.waiting_queue: List[Req] = []
        self.staging_reqs = {}
        # The running decoding batch for continuous batching
        self.running_batch: Optional[ScheduleBatch] = None
        # The current forward batch
        self.cur_batch: Optional[ScheduleBatch] = None
        # The batch handled in the previous loop iteration
        self.last_batch: Optional[ScheduleBatch] = None
        self.forward_ct = 0
        self.forward_ct_decode = 0
        self.num_generated_tokens = 0
        self.spec_num_total_accepted_tokens = 0
        self.spec_num_total_forward_ct = 0
        self.cum_spec_accept_length = 0
        self.cum_spec_accept_count = 0
        self.last_decode_stats_tic = time.time()
        self.return_health_check_ct = 0
        self.current_stream = torch.get_device_module(self.device).current_stream()
        if self.device == "cpu":
            self.current_stream.synchronize = lambda: None  # No-op for CPU

        # For metrics only.
        # The largest prefill length of a single request
        self._largest_prefill_len: int = 0
        # The largest context length (prefill + generation) of a single request
        self._largest_prefill_decode_len: int = 0
        self.last_gen_throughput: float = 0.0
        self.step_time_dict = defaultdict(list)  # Dict[batch size -> step time]

        # Session info
        self.sessions: Dict[str, Session] = {}

        # Init chunked prefill
        self.chunked_prefill_size = server_args.chunked_prefill_size
        # A non-positive value (e.g. -1) means disabled. None is tolerated here
        # because the cache selection above already treats it as a possible value.
        if self.chunked_prefill_size is not None and self.chunked_prefill_size <= 0:
            self.chunked_prefill_size = None
        self.chunked_req = None
        self.is_mixed_chunk = (
            self.chunked_prefill_size is not None and server_args.enable_mixed_chunk
        )

        # Init the grammar backend for constrained generation
        self.grammar_queue: List[Req] = []
        if not server_args.skip_tokenizer_init:
            self.grammar_backend = create_grammar_backend(
                server_args, self.tokenizer, self.model_config.vocab_size
            )
        else:
            self.grammar_backend = None

        # Init new token estimation
        assert (
            server_args.schedule_conservativeness >= 0
        ), "Invalid schedule_conservativeness"

        self.init_new_token_ratio = min(
            global_config.default_init_new_token_ratio
            * server_args.schedule_conservativeness,
            1.0,
        )
        self.min_new_token_ratio = min(
            self.init_new_token_ratio
            * global_config.default_min_new_token_ratio_factor,
            1.0,
        )
        self.new_token_ratio_decay = (
            self.init_new_token_ratio - self.min_new_token_ratio
        ) / global_config.default_new_token_ratio_decay_steps
        self.new_token_ratio = self.init_new_token_ratio

        # Tell whether the current running batch is full so that we can skip
        # the check of whether to prefill new requests.
        # This is an optimization to reduce the overhead of the prefill check.
        self.batch_is_full = False

        # Init watchdog thread.
        # Everything the thread reads (including parent_process) is assigned
        # before the thread is started.
        self.watchdog_timeout = server_args.watchdog_timeout
        self.parent_process = psutil.Process().parent()
        t = threading.Thread(target=self.watchdog_thread, daemon=True)
        t.start()

        # Init memory saver
        self.memory_saver_adapter = TorchMemorySaverAdapter.create(
            enable=server_args.enable_memory_saver
        )

        # Init profiler
        self.torch_profiler = None
        self.torch_profiler_output_dir: Optional[str] = None
        self.torch_profiler_activities: Optional[List[str]] = None
        self.profiler_target_forward_ct: Optional[int] = None

        # Init metrics stats
        self.stats = SchedulerStats()
        if self.enable_metrics:
            self.metrics_collector = SchedulerMetricsCollector(
                labels={
                    "model_name": self.server_args.served_model_name,
                    # TODO: Add lora name/path in the future,
                },
            )

        # Init request dispatcher: maps each request type to its handler method.
        self._request_dispatcher = TypeBasedDispatcher(
            [
                (TokenizedGenerateReqInput, self.handle_generate_request),
                (TokenizedEmbeddingReqInput, self.handle_embedding_request),
                (FlushCacheReq, self.flush_cache_wrapped),
                (AbortReq, self.abort_request),
                (OpenSessionReqInput, self.open_session),
                (CloseSessionReqInput, self.close_session),
                (UpdateWeightFromDiskReqInput, self.update_weights_from_disk),
                (InitWeightsUpdateGroupReqInput, self.init_weights_update_group),
                (
                    UpdateWeightsFromDistributedReqInput,
                    self.update_weights_from_distributed,
                ),
                (UpdateWeightsFromTensorReqInput, self.update_weights_from_tensor),
                (GetWeightsByNameReqInput, self.get_weights_by_name),
                (ReleaseMemoryOccupationReqInput, self.release_memory_occupation),
                (ResumeMemoryOccupationReqInput, self.resume_memory_occupation),
                (ProfileReq, self.profile),
                (GetInternalStateReq, self.get_internal_state),
                (SetInternalStateReq, self.set_internal_state),
            ]
        )

Lianmin Zheng's avatar
Lianmin Zheng committed
467
    def watchdog_thread(self):
        """Kill the server if a single forward batch appears to be stuck.

        The loop polls every ``self.watchdog_timeout // 2`` seconds. While a
        batch is in flight (``self.cur_batch`` is not None) and
        ``self.forward_ct`` has not advanced for more than
        ``self.watchdog_timeout`` seconds, the loop exits, debug information is
        logged, and SIGQUIT is sent to the parent process.
        """
        self.watchdog_last_forward_ct = 0
        self.watchdog_last_time = time.time()

        while True:
            current = time.time()
            if self.cur_batch is not None:
                if self.watchdog_last_forward_ct == self.forward_ct:
                    if current > self.watchdog_last_time + self.watchdog_timeout:
                        logger.error(f"Watchdog timeout ({self.watchdog_timeout=})")
                        break
                else:
                    self.watchdog_last_forward_ct = self.forward_ct
                    self.watchdog_last_time = current
            else:
                # No batch is in flight, so idle time must not count toward the
                # timeout of the next batch.
                self.watchdog_last_time = current
            time.sleep(self.watchdog_timeout // 2)

        # Print batch size and memory pool info to check whether there are de-sync issues.
        # Read cur_batch once: the scheduler loop may reset it to None concurrently,
        # and a failure here would prevent the kill signal below from being sent.
        stuck_batch = self.cur_batch
        if stuck_batch is not None:
            batch_info = (
                f"self.cur_batch.batch_size()={stuck_batch.batch_size()!r}, "
                f"self.cur_batch.reqs={stuck_batch.reqs!r}, "
            )
        else:
            batch_info = "self.cur_batch=None, "
        logger.error(
            f"{batch_info}"
            f"{self.token_to_kv_pool.available_size()=}, "
            f"{self.tree_cache.evictable_size()=}, "
        )
        pyspy_dump_schedulers()
        print(file=sys.stderr, flush=True)
        print(file=sys.stdout, flush=True)
        # Wait for some time so that the parent process can print the error.
        time.sleep(5)
        self.parent_process.send_signal(signal.SIGQUIT)
Lianmin Zheng's avatar
Lianmin Zheng committed
497

498
    @torch.no_grad()
    def event_loop_normal(self):
        """Run the plain scheduling loop forever.

        Each iteration receives and dispatches incoming requests, picks the next
        batch, and either runs it and processes its result, or, when there is
        nothing to run, performs the idle-time checks.
        """
        while True:
            self.process_input_requests(self.recv_requests())

            next_batch = self.get_next_batch_to_run()
            self.cur_batch = next_batch

            if not next_batch:
                # The server is idle: self-check and re-init some states
                self.check_memory()
                self.new_token_ratio = self.init_new_token_ratio
            else:
                batch_result = self.run_batch(next_batch)
                self.process_batch_result(next_batch, batch_result)

            self.last_batch = next_batch
517

518
    @torch.no_grad()
    def event_loop_overlap(self):
        """Run the scheduling loop that overlaps CPU processing with GPU computation.

        A batch launched in one iteration is queued together with its result and
        only processed in the following iteration, after the next batch has been
        launched.
        """
        self.result_queue = deque()

        while True:
            self.process_input_requests(self.recv_requests())

            batch = self.get_next_batch_to_run()
            self.cur_batch = batch

            if batch:
                launched_result = self.run_batch(batch)
                batch_snapshot = batch.copy()
                self.result_queue.append((batch_snapshot, launched_result))

                if self.last_batch is None:
                    # Create a dummy first batch to start the pipeline for overlap schedule.
                    # It is now used for triggering the sampling_info_done event.
                    dummy_batch = ScheduleBatch(
                        reqs=None,
                        forward_mode=ForwardMode.DUMMY_FIRST,
                        next_batch_sampling_info=self.tp_worker.cur_sampling_info,
                    )
                    self.process_batch_result(dummy_batch, None)

            if self.last_batch:
                # Process the results of the last batch
                prev_batch, prev_result = self.result_queue.popleft()
                if batch:
                    prev_batch.next_batch_sampling_info = (
                        self.tp_worker.cur_sampling_info
                    )
                else:
                    prev_batch.next_batch_sampling_info = None
                self.process_batch_result(prev_batch, prev_result)
            elif batch is None:
                # The server is idle: self-check and re-init some states
                self.check_memory()
                self.new_token_ratio = self.init_new_token_ratio

            self.last_batch = batch

558
559
    def recv_requests(self) -> List[Req]:
        """Receive pending requests on attention-TP rank 0 and broadcast them.

        Rank 0 drains the tokenizer socket without blocking. With DP attention
        enabled, generate/embedding requests are broadcast over the attention-TP
        group and all other requests over the full TP group; otherwise everything
        is broadcast over the TP group when ``tp_size`` is not 1.
        """
        is_source_rank = self.attn_tp_rank == 0

        recv_reqs = None
        if is_source_rank:
            recv_reqs = []
            # Drain the socket until a non-blocking receive fails.
            while True:
                try:
                    incoming = self.recv_from_tokenizer.recv_pyobj(zmq.NOBLOCK)
                except zmq.ZMQError:
                    break
                else:
                    recv_reqs.append(incoming)

        if not self.server_args.enable_dp_attention:
            if self.tp_size != 1:
                recv_reqs = broadcast_pyobj(recv_reqs, self.tp_rank, self.tp_cpu_group)
            return recv_reqs

        # Split into work requests (generate / embedding) and control requests.
        work_reqs = None
        control_reqs = None
        if is_source_rank:
            work_types = (TokenizedGenerateReqInput, TokenizedEmbeddingReqInput)
            work_reqs, control_reqs = [], []
            for req in recv_reqs:
                if isinstance(req, work_types):
                    work_reqs.append(req)
                else:
                    control_reqs.append(req)

        if self.attn_tp_size != 1:
            attn_tp_rank_0 = self.dp_rank * self.attn_tp_size
            work_reqs = broadcast_pyobj(
                work_reqs,
                self.attn_tp_rank,
                self.attn_tp_cpu_group,
                src=attn_tp_rank_0,
            )
        if self.tp_size != 1:
            control_reqs = broadcast_pyobj(
                control_reqs, self.tp_rank, self.tp_cpu_group
            )
        return work_reqs + control_reqs

Lianmin Zheng's avatar
Lianmin Zheng committed
609
    def process_input_requests(self, recv_reqs: List):
        """Dispatch each received request and forward any reply to the tokenizer.

        Health-check generation requests are counted and dropped while a chunked
        request or a running batch exists.
        """
        for incoming in recv_reqs:
            if is_health_check_generate_req(incoming):
                has_running_work = not (
                    self.chunked_req is None and self.running_batch is None
                )
                if has_running_work:
                    # There are running requests, so the health check is ignored.
                    self.return_health_check_ct += 1
                    continue

            reply = self._request_dispatcher(incoming)
            if reply is None:
                continue
            self.send_to_tokenizer.send_pyobj(reply)
621
622
623
624
625

    def handle_generate_request(
        self,
        recv_req: TokenizedGenerateReqInput,
    ):
626
        # Create a new request
627
628
629
630
631
        if (
            recv_req.session_params is None
            or recv_req.session_params.id is None
            or recv_req.session_params.id not in self.sessions
        ):
Rin Intachuen's avatar
Rin Intachuen committed
632
633
634
635
636
637
            if recv_req.input_embeds is not None:
                # Generate fake input_ids based on the length of input_embeds
                seq_length = len(recv_req.input_embeds)
                fake_input_ids = [1] * seq_length
                recv_req.input_ids = fake_input_ids

638
639
640
641
642
643
644
645
646
647
648
649
650
            # Handle custom logit processor passed to the request
            custom_logit_processor = recv_req.custom_logit_processor
            if (
                not self.server_args.enable_custom_logit_processor
                and custom_logit_processor is not None
            ):
                logger.warning(
                    "The SGLang server is not configured to enable custom logit processor."
                    "The custom logit processor passed in will be ignored."
                    "Please set --enable-custom-logits-processor to enable this feature."
                )
                custom_logit_processor = None

651
652
653
654
655
            req = Req(
                recv_req.rid,
                recv_req.input_text,
                recv_req.input_ids,
                recv_req.sampling_params,
Lianmin Zheng's avatar
Lianmin Zheng committed
656
657
                return_logprob=recv_req.return_logprob,
                top_logprobs_num=recv_req.top_logprobs_num,
658
                token_ids_logprob=recv_req.token_ids_logprob,
Lianmin Zheng's avatar
Lianmin Zheng committed
659
                stream=recv_req.stream,
660
                lora_path=recv_req.lora_path,
Rin Intachuen's avatar
Rin Intachuen committed
661
                input_embeds=recv_req.input_embeds,
662
                custom_logit_processor=custom_logit_processor,
663
                return_hidden_states=recv_req.return_hidden_states,
664
                eos_token_ids=self.model_config.hf_eos_token_id,
665
666
            )
            req.tokenizer = self.tokenizer
Lianmin Zheng's avatar
Lianmin Zheng committed
667

668
669
670
671
            if (
                recv_req.session_params is not None
                and recv_req.session_params.id is not None
            ):
672
                req.finished_reason = FINISH_ABORT(
673
                    f"Invalid request: session id {recv_req.session_params.id} does not exist"
674
                )
675
                self._add_request_to_queue(req)
676
677
                return
        else:
678
679
            # Create a new request from a previous session
            session = self.sessions[recv_req.session_params.id]
680
            req = session.create_req(recv_req, self.tokenizer)
681
            if isinstance(req.finished_reason, FINISH_ABORT):
682
                self._add_request_to_queue(req)
683
                return
684

685
        # Handle multimodal inputs
686
        if recv_req.image_inputs is not None:
687
688
            image_inputs = ImageInputs.from_dict(recv_req.image_inputs)
            # Expand a single image token into multiple dummy tokens for receiving image embeddings
689
            req.origin_input_ids = self.pad_input_ids_func(
690
                req.origin_input_ids, image_inputs
691
            )
692
            req.extend_image_inputs(image_inputs)
693

694
            if len(req.origin_input_ids) >= self.max_req_input_len:
695
                error_msg = (
696
                    "Multimodal prompt is too long after expanding multimodal tokens. "
697
                    f"After expanding {len(req.origin_input_ids_unpadded)=} => {len(req.origin_input_ids)} >= {self.max_req_input_len}."
698
                )
699
                logger.error(error_msg)
700
                req.origin_input_ids = [0]
701
                req.image_inputs = None
702
                req.sampling_params.max_new_tokens = 0
703
                req.finished_reason = FINISH_ABORT(
704
                    error_msg, HTTPStatus.BAD_REQUEST, "BadRequestError"
705
                )
706
                self._add_request_to_queue(req)
707
708
                return

709
710
711
712
713
714
715
        # Validate prompts length
        error_msg = validate_input_length(
            req,
            self.max_req_input_len,
            self.server_args.allow_auto_truncate,
        )
        if error_msg:
716
717
            req.origin_input_ids = [0]
            req.sampling_params.max_new_tokens = 0
718
            self._add_request_to_queue(req)
719
            return
720

721
        # Copy more attributes
722
        if recv_req.logprob_start_len == -1 or not recv_req.return_logprob:
723
724
725
726
727
            # By default, only return the logprobs for output tokens
            req.logprob_start_len = len(req.origin_input_ids) - 1
        else:
            req.logprob_start_len = recv_req.logprob_start_len

728
729
730
731
732
733
734
735
736
737
        if req.logprob_start_len >= len(req.origin_input_ids):
            req.finished_reason = FINISH_ABORT(
                f"logprob_start_len, ({req.logprob_start_len}) is higher than the number of input tokens ({len(req.origin_input_ids)}). Request with a lower logprob_start_len.",
                HTTPStatus.BAD_REQUEST,
                "BadRequestError",
            )
            req.logprob_start_len = len(req.origin_input_ids) - 1
            self._add_request_to_queue(req)
            return

738
739
740
741
742
743
        req.sampling_params.max_new_tokens = min(
            (
                req.sampling_params.max_new_tokens
                if req.sampling_params.max_new_tokens is not None
                else 1 << 30
            ),
744
            self.max_req_len - len(req.origin_input_ids) - 1,
745
746
        )

747
748
749
750
751
        # Init grammar cache for this request
        add_to_grammar_queue = False
        if (
            req.sampling_params.json_schema is not None
            or req.sampling_params.regex is not None
752
            or req.sampling_params.ebnf is not None
753
            or req.sampling_params.structural_tag is not None
754
755
756
757
758
759
        ):
            assert self.grammar_backend is not None
            if req.sampling_params.json_schema is not None:
                key = ("json", req.sampling_params.json_schema)
            elif req.sampling_params.regex is not None:
                key = ("regex", req.sampling_params.regex)
760
761
            elif req.sampling_params.ebnf is not None:
                key = ("ebnf", req.sampling_params.ebnf)
762
763
            elif req.sampling_params.structural_tag:
                key = ("structural_tag", req.sampling_params.structural_tag)
764
765
766
767
768
769
770

            req.grammar = self.grammar_backend.get_cached_value(key)
            if not req.grammar:
                req.grammar = self.grammar_backend.get_future_value(key)
                add_to_grammar_queue = True

        if add_to_grammar_queue:
771
772
            self.grammar_queue.append(req)
        else:
773
774
775
776
777
778
779
            self._add_request_to_queue(req)

    def _add_request_to_queue(self, req: Req):
        """Append a single request to the tail of the waiting queue."""
        pending = self.waiting_queue
        pending.append(req)

    def _extend_requests_to_queue(self, reqs: List[Req]):
        """Append every request in ``reqs`` to the waiting queue, keeping their order."""
        pending = self.waiting_queue
        pending.extend(reqs)
780
781
782

    def handle_embedding_request(
        self,
        recv_req: TokenizedEmbeddingReqInput,
    ):
        """Build a ``Req`` from a tokenized embedding request and enqueue it.

        The request is appended to the waiting queue in every case.
        ``logprob_start_len`` is only assigned when ``validate_input_length``
        returns no error message.
        """
        req = Req(
            recv_req.rid,
            recv_req.input_text,
            recv_req.input_ids,
            recv_req.sampling_params,
        )
        req.tokenizer = self.tokenizer

        # Validate prompts length
        length_error = validate_input_length(
            req,
            self.max_req_input_len,
            self.server_args.allow_auto_truncate,
        )

        if not length_error:
            # Copy more attributes
            req.logprob_start_len = len(req.origin_input_ids) - 1

        self._add_request_to_queue(req)
806

807
808
809
810
    def log_prefill_stats(
        self,
        adder: PrefillAdder,
        can_run_list: List[Req],
        running_bs: int,
    ):
        """Log a one-line summary of a new prefill batch and publish metrics.

        Args:
            adder: Supplies the ``log_input_tokens`` and ``log_hit_tokens``
                counters that are logged and used for the cache hit rate.
            can_run_list: Requests admitted to the batch; only its length is
                reported.
            running_bs: Number of running requests to report.
        """
        # Tokens in use = total capacity minus what is free or evictable.
        num_used = self.max_total_num_tokens - (
            self.token_to_kv_pool.available_size() + self.tree_cache.evictable_size()
        )
        self._largest_prefill_len = max(
            self._largest_prefill_len, adder.log_input_tokens
        )

        msg = (
            f"Prefill batch. "
            f"#new-seq: {len(can_run_list)}, "
            f"#new-token: {adder.log_input_tokens}, "
            f"#cached-token: {adder.log_hit_tokens}, "
            f"token usage: {num_used / self.max_total_num_tokens:.2f}, "
            f"#running-req: {running_bs}, "
            f"#queue-req: {len(self.waiting_queue)}, "
        )
        logger.info(msg)

        if self.enable_metrics:
            # Guard the ratio: with no new and no cached tokens the hit rate
            # is reported as 0.0 instead of raising ZeroDivisionError.
            total_tokens = adder.log_input_tokens + adder.log_hit_tokens
            cache_hit_rate = (
                adder.log_hit_tokens / total_tokens if total_tokens > 0 else 0.0
            )
            self.stats.num_running_reqs = running_bs
            self.stats.num_used_tokens = num_used
            self.stats.token_usage = round(num_used / self.max_total_num_tokens, 2)
            self.stats.num_queue_reqs = len(self.waiting_queue)
            self.stats.cache_hit_rate = cache_hit_rate
            self.metrics_collector.log_stats(self.stats)

    def log_decode_stats(self):
        """Log decode throughput and KV-cache usage, then publish metrics.

        Resets ``num_generated_tokens`` and, when a speculative algorithm is
        active, the per-interval speculative counters after folding them into
        the cumulative ones.
        """
        # Read the clock once so that no time is lost between computing the
        # gap and restarting the interval.
        now = time.time()
        gap_latency = now - self.last_decode_stats_tic
        self.last_decode_stats_tic = now
        self.last_gen_throughput = self.num_generated_tokens / gap_latency
        self.num_generated_tokens = 0
        num_running_reqs = len(self.running_batch.reqs) if self.running_batch else 0
        # Tokens in use = total capacity minus what is free or evictable.
        num_used = self.max_total_num_tokens - (
            self.token_to_kv_pool.available_size() + self.tree_cache.evictable_size()
        )

        if RECORD_STEP_TIME:
            self.step_time_dict[num_running_reqs].append(
                gap_latency / self.server_args.decode_log_interval
            )

        if self.spec_algorithm.is_none():
            spec_accept_length = 0
            accept_part = ""
        else:
            # The counters are zeroed below, so a call with no speculative
            # forward in between would otherwise divide by zero.
            if self.spec_num_total_forward_ct:
                spec_accept_length = (
                    self.spec_num_total_accepted_tokens / self.spec_num_total_forward_ct
                )
            else:
                spec_accept_length = 0
            self.cum_spec_accept_length += self.spec_num_total_accepted_tokens
            self.cum_spec_accept_count += self.spec_num_total_forward_ct
            self.spec_num_total_accepted_tokens = self.spec_num_total_forward_ct = 0
            accept_part = f"accept len: {spec_accept_length:.2f}, "

        # The "accept len" field only appears when speculative decoding is on.
        msg = (
            "Decode batch. "
            f"#running-req: {num_running_reqs}, "
            f"#token: {num_used}, "
            f"token usage: {num_used / self.max_total_num_tokens:.2f}, "
            f"{accept_part}"
            f"gen throughput (token/s): {self.last_gen_throughput:.2f}, "
            f"largest-len: {self._largest_prefill_decode_len}, "
            f"#queue-req: {len(self.waiting_queue)}, "
        )

        logger.info(msg)
        if self.enable_metrics:
            self.stats.num_running_reqs = num_running_reqs
            self.stats.num_used_tokens = num_used
            self.stats.token_usage = num_used / self.max_total_num_tokens
            self.stats.cache_hit_rate = 0.0
            self.stats.gen_throughput = self.last_gen_throughput
            self.stats.num_queue_reqs = len(self.waiting_queue)
            self.stats.spec_accept_length = spec_accept_length
            self.metrics_collector.log_stats(self.stats)

Lianmin Zheng's avatar
Lianmin Zheng committed
897
898
899
900
    def check_memory(self):
        """Check the KV cache pool and the request slot pool for leaks.

        Each mismatch is reported through ``warnings.warn``; when
        ``crash_on_warnings()`` is truthy a ``ValueError`` with the same
        message is raised as well. Afterwards, if metrics are enabled on
        attention-TP rank 0 and more than 30 seconds have passed since the
        collector last logged, usage stats are published.

        Raises:
            ValueError: A leak was detected and ``crash_on_warnings()`` is truthy.
        """
        available_size = (
            self.token_to_kv_pool.available_size() + self.tree_cache.evictable_size()
        )
        protected_size = self.tree_cache.protected_size()
        # With the hierarchical cache enabled, protected tokens are subtracted
        # from the expected size.
        memory_leak = available_size != (
            self.max_total_num_tokens
            if not self.enable_hierarchical_cache
            else self.max_total_num_tokens - protected_size
        )
        if memory_leak:
            msg = (
                "KV cache pool leak detected! "
                f"{available_size=}, {protected_size=}, {self.max_total_num_tokens=}\n"
            )
            warnings.warn(msg)
            if crash_on_warnings():
                raise ValueError(msg)

        if len(self.req_to_token_pool.free_slots) != self.req_to_token_pool.size:
            msg = (
                "Memory pool leak detected! "
                f"available_size={len(self.req_to_token_pool.free_slots)}, "
                f"total_size={self.req_to_token_pool.size}\n"
            )
            warnings.warn(msg)
            if crash_on_warnings():
                raise ValueError(msg)

        if (
            self.enable_metrics
            and self.attn_tp_rank == 0
            and time.time() > self.metrics_collector.last_log_time + 30
        ):
            # During idle time, also collect metrics every 30 seconds.
            num_used = self.max_total_num_tokens - (
                self.token_to_kv_pool.available_size()
                + self.tree_cache.evictable_size()
            )
            num_running_reqs = len(self.running_batch.reqs) if self.running_batch else 0
            self.stats.num_running_reqs = num_running_reqs
            self.stats.num_used_tokens = num_used
            self.stats.token_usage = num_used / self.max_total_num_tokens
            self.stats.gen_throughput = 0
            self.stats.num_queue_reqs = len(self.waiting_queue)
            self.metrics_collector.log_stats(self.stats)

944
    def get_next_batch_to_run(self) -> Optional[ScheduleBatch]:
        """Select the batch for the next forward pass.

        A previous extend-mode batch is first merged into the running batch.
        A new prefill batch is preferred when one can be built; otherwise the
        running batch (if any) is updated and used for decoding. When DP
        attention is enabled the chosen batch is passed through
        ``prepare_dp_attn_batch`` before being returned.
        """
        prev_batch = self.last_batch

        # Merge the prefill batch into the running batch
        if prev_batch and prev_batch.forward_mode.is_extend():
            if self.chunked_req:
                # Move the chunked request out of the batch so that we can merge
                # only finished requests to running_batch.
                prev_batch.filter_batch(chunked_req_to_exclude=self.chunked_req)
                self.tree_cache.cache_unfinished_req(self.chunked_req)
                # chunked request keeps its rid but will get a new req_pool_idx
                self.req_to_token_pool.free(self.chunked_req.req_pool_idx)
                self.batch_is_full = False

            prev_batch.filter_batch()
            if not prev_batch.is_empty():
                if self.running_batch is None:
                    self.running_batch = prev_batch
                else:
                    # merge running_batch with prefill batch
                    self.running_batch.merge_batch(prev_batch)

        prefill_batch = self.get_new_batch_prefill()
        if prefill_batch is not None:
            # Run prefill first if possible
            next_batch = prefill_batch
        elif self.running_batch is None:
            # Nothing to prefill and nothing to decode
            next_batch = None
        else:
            # Run decode
            self.running_batch = self.update_running_batch(self.running_batch)
            next_batch = self.running_batch

        # Handle DP attention
        if self.server_args.enable_dp_attention:
            return self.prepare_dp_attn_batch(next_batch)
        return next_batch
981

Lianmin Zheng's avatar
Lianmin Zheng committed
982
    def get_new_batch_prefill(self) -> Optional[ScheduleBatch]:
        """Try to build a new prefill (extend) batch from the waiting queue.

        Requests are taken from the waiting queue in the order left by
        ``self.policy.calc_priority`` until a limit is hit (LoRA adapters per
        batch, running-request cap, or the result of
        ``PrefillAdder.add_one_req``). Admitted requests are removed from the
        waiting queue.

        Returns:
            The prepared batch, or ``None`` when prefill is currently not
            allowed or no request could be admitted.
        """
        # Check if the grammar is ready in the grammar queue
        if self.grammar_queue:
            self.move_ready_grammar_requests()

        # Handle the cases where prefill is not allowed
        if (
            self.batch_is_full or len(self.waiting_queue) == 0
        ) and self.chunked_req is None:
            return None

        running_bs = len(self.running_batch.reqs) if self.running_batch else 0
        if running_bs >= self.max_running_requests:
            self.batch_is_full = True
            return None

        # Get priority queue
        prefix_computed = self.policy.calc_priority(self.waiting_queue)

        # Prefill policy
        adder = PrefillAdder(
            self.tree_cache,
            self.token_to_kv_pool,
            self.running_batch,
            self.new_token_ratio,
            self.max_prefill_tokens,
            self.chunked_prefill_size,
            running_bs if self.is_mixed_chunk else 0,
        )

        # Continue the request that is currently being chunked, if any.
        if self.chunked_req is not None:
            self.chunked_req.init_next_round_input()
            self.chunked_req = adder.add_chunked_req(self.chunked_req)

        if self.lora_paths:
            # LoRA paths already used by the running batch.
            lora_set = (
                {r.lora_path for r in self.running_batch.reqs}
                if self.running_batch is not None
                else set()
            )

        # Get requests from the waiting queue to a new prefill batch
        for req in self.waiting_queue:
            if (
                self.lora_paths
                and len(
                    lora_set
                    | {r.lora_path for r in adder.can_run_list}
                    | {req.lora_path}
                )
                > self.max_loras_per_batch
            ):
                self.batch_is_full = True
                break

            if running_bs + len(adder.can_run_list) >= self.max_running_requests:
                self.batch_is_full = True
                break

            req.init_next_round_input(None if prefix_computed else self.tree_cache)

            if self.enable_hierarchical_cache and req.last_node is not None:
                if req.last_node.evicted:
                    # loading KV cache for the request
                    req.last_node, req.prefix_indices = self.tree_cache.init_load_back(
                        req.last_node,
                        req.prefix_indices,
                        adder.rem_total_tokens,
                    )
                    if req.last_node.loading:
                        # to prevent frequent cache invalidation
                        if req.rid in self.staging_reqs:
                            self.tree_cache.dec_lock_ref(self.staging_reqs[req.rid])
                        self.tree_cache.inc_lock_ref(req.last_node)
                        self.staging_reqs[req.rid] = req.last_node
                        continue
                elif req.last_node.loading:
                    if not self.tree_cache.loading_complete(req.last_node):
                        continue

                if req.rid in self.staging_reqs:
                    self.tree_cache.dec_lock_ref(self.staging_reqs[req.rid])
                    del self.staging_reqs[req.rid]

            res = adder.add_one_req(req, self.chunked_req)
            if res != AddReqResult.CONTINUE:
                if res == AddReqResult.NO_TOKEN:
                    if self.enable_hierarchical_cache:
                        # Set batch_is_full after making sure there are requests that can be served
                        self.batch_is_full = len(adder.can_run_list) > 0 or (
                            self.running_batch is not None
                            and not self.running_batch.is_empty()
                        )
                    else:
                        self.batch_is_full = True
                break
            if self.prefill_only_one_req:
                break

        # Update waiting queue
        can_run_list: List[Req] = adder.can_run_list
        if len(can_run_list) == 0:
            return None
        # Build the membership set once instead of once per queued request.
        admitted = set(can_run_list)
        self.waiting_queue = [x for x in self.waiting_queue if x not in admitted]

        if adder.new_chunked_req is not None:
            assert self.chunked_req is None
            self.chunked_req = adder.new_chunked_req

        if self.chunked_req:
            self.chunked_req.is_chunked += 1

        # Print stats
        if self.attn_tp_rank == 0:
            self.log_prefill_stats(adder, can_run_list, running_bs)

        # Create a new batch
        new_batch = ScheduleBatch.init_new(
            can_run_list,
            self.req_to_token_pool,
            self.token_to_kv_pool,
            self.tree_cache,
            self.model_config,
            self.enable_overlap,
            self.spec_algorithm,
            self.server_args.enable_custom_logit_processor,
        )
        new_batch.prepare_for_extend()

        # Mixed-style chunked prefill
        if (
            self.is_mixed_chunk
            and self.running_batch is not None
            and not (new_batch.return_logprob or self.running_batch.return_logprob)
        ):
            # TODO (lianmin): support return_logprob + mixed chunked prefill
            self.running_batch.filter_batch()
            if not self.running_batch.is_empty():
                self.running_batch.prepare_for_decode()
                new_batch.mix_with_running(self.running_batch)
                new_batch.decoding_reqs = self.running_batch.reqs
            self.running_batch = None
        else:
            new_batch.decoding_reqs = None

        return new_batch

Lianmin Zheng's avatar
Lianmin Zheng committed
1131
    def update_running_batch(self, batch: ScheduleBatch) -> Optional[ScheduleBatch]:
        """Filter the running decode batch and prepare it for the next step.

        Returns ``None`` when the batch is empty after filtering. Otherwise
        requests are retracted to the waiting queue if the decode memory check
        fails, or the new-token ratio is decayed if it passes, and the batch is
        returned after ``prepare_for_decode()``.
        """
        bs_before = batch.batch_size()

        batch.filter_batch()
        if batch.is_empty():
            self.batch_is_full = False
            return None

        # Check if decode out of memory
        must_retract = not batch.check_decode_mem(
            self.decode_mem_cache_buf_multiplier
        ) or (TEST_RETRACT and batch.batch_size() > 10)

        if not must_retract:
            # Memory is fine: decay the ratio, bounded below by the minimum.
            self.new_token_ratio = max(
                self.new_token_ratio - self.new_token_ratio_decay,
                self.min_new_token_ratio,
            )
        else:
            old_ratio = self.new_token_ratio

            retracted_reqs, new_token_ratio = batch.retract_decode(self.server_args)
            self.new_token_ratio = new_token_ratio
            if self.draft_worker:
                self.draft_worker.finish_request(retracted_reqs)

            logger.info(
                "Decode out of memory happened. "
                f"#retracted_reqs: {len(retracted_reqs)}, "
                f"#new_token_ratio: {old_ratio:.4f} -> {self.new_token_ratio:.4f}"
            )
            self._extend_requests_to_queue(retracted_reqs)

        # The batch shrank, so it can accept new requests again.
        if batch.batch_size() < bs_before:
            self.batch_is_full = False

        # Update batch tensors
        batch.prepare_for_decode()
        return batch
Lianmin Zheng's avatar
Lianmin Zheng committed
1169

1170
1171
1172
    def run_batch(
        self, batch: ScheduleBatch
    ) -> Union[GenerationBatchResult, EmbeddingBatchResult]:
        """Run one forward pass for ``batch`` and wrap the outputs.

        Increments ``forward_ct`` and stops the profiler once its target
        forward count is reached. Returns an ``EmbeddingBatchResult`` when the
        scheduler is not in generation mode, otherwise a
        ``GenerationBatchResult``.
        """
        self.forward_ct += 1

        # Check profiler
        target_ct = self.profiler_target_forward_ct
        if target_ct and target_ct <= self.forward_ct:
            self.stop_profile()

        if not self.is_generation:
            # embedding or reward model
            model_worker_batch = batch.get_model_worker_batch()
            embeddings = self.tp_worker.forward_batch_embedding(model_worker_batch)
            return EmbeddingBatchResult(
                embeddings=embeddings, bid=model_worker_batch.bid
            )

        if self.spec_algorithm.is_none():
            model_worker_batch = batch.get_model_worker_batch()
            logits_output, next_token_ids = self.tp_worker.forward_batch_generation(
                model_worker_batch
            )
        else:
            (
                logits_output,
                next_token_ids,
                model_worker_batch,
                num_accepted_tokens,
            ) = self.draft_worker.forward_batch_speculative_generation(batch)
            # Counters consumed by the decode stats logging.
            self.spec_num_total_accepted_tokens += (
                num_accepted_tokens + batch.batch_size()
            )
            self.spec_num_total_forward_ct += batch.batch_size()
            self.num_generated_tokens += num_accepted_tokens
        batch.output_ids = next_token_ids

        # These 2 values are needed for processing the output, but the values can be
        # modified by overlap schedule. So we have to copy them here so that
        # we can use the correct values in output processing.
        extend_input_len_per_req = None
        extend_logprob_start_len_per_req = None
        if batch.return_logprob:
            extend_input_len_per_req = [r.extend_input_len for r in batch.reqs]
            extend_logprob_start_len_per_req = [
                r.extend_logprob_start_len for r in batch.reqs
            ]

        return GenerationBatchResult(
            logits_output=logits_output,
            next_token_ids=next_token_ids,
            extend_input_len_per_req=extend_input_len_per_req,
            extend_logprob_start_len_per_req=extend_logprob_start_len_per_req,
            bid=model_worker_batch.bid,
        )
Chayenne's avatar
Chayenne committed
1229

1230
1231
1232
1233
1234
    def process_batch_result(
        self,
        batch: ScheduleBatch,
        result: Union[GenerationBatchResult, EmbeddingBatchResult],
    ):
        """Dispatch a finished batch to the handler for its forward mode.

        Decode and extend batches go to their dedicated handlers. Idle batches
        (with overlap enabled) and dummy-first batches only hand the next
        batch's sampling info over. Finally, one pending health-check reply is
        sent if any is owed.
        """

        def _hand_over_sampling_info(info):
            # Refresh the vocab mask, wait for the stream, then mark it done.
            info.update_regex_vocab_mask()
            self.current_stream.synchronize()
            info.sampling_info_done.set()

        mode = batch.forward_mode
        if mode.is_decode():
            self.process_batch_result_decode(batch, result)
            if batch.is_empty():
                self.running_batch = None
        elif mode.is_extend():
            self.process_batch_result_prefill(batch, result)
        elif mode.is_idle():
            if self.enable_overlap:
                self.tp_worker.resolve_batch_result(result.bid)
                if batch.next_batch_sampling_info:
                    _hand_over_sampling_info(batch.next_batch_sampling_info)
        elif mode.is_dummy_first():
            _hand_over_sampling_info(batch.next_batch_sampling_info)

        if self.return_health_check_ct:
            # Return some signal for the health check.
            # This is used to prevent the health check signal being blocked by long context prefill.
            # However, one minor issue is that this code path does not check the status of detokenizer manager.
            self.return_health_check_ct -= 1
            self.send_to_tokenizer.send_pyobj(HealthCheckOutput())

1260
1261
1262
1263
1264
    def process_batch_result_prefill(
        self,
        batch: ScheduleBatch,
        result: Union[GenerationBatchResult, EmbeddingBatchResult],
    ):
1265
        skip_stream_req = None
Lianmin Zheng's avatar
Lianmin Zheng committed
1266

Lianmin Zheng's avatar
Lianmin Zheng committed
1267
        if self.is_generation:
1268
1269
1270
            (
                logits_output,
                next_token_ids,
1271
1272
                extend_input_len_per_req,
                extend_logprob_start_len_per_req,
1273
1274
1275
1276
                bid,
            ) = (
                result.logits_output,
                result.next_token_ids,
1277
1278
                result.extend_input_len_per_req,
                result.extend_logprob_start_len_per_req,
1279
1280
                result.bid,
            )
1281
1282

            if self.enable_overlap:
1283
                logits_output, next_token_ids = self.tp_worker.resolve_batch_result(bid)
1284
1285
            else:
                # Move next_token_ids and logprobs to cpu
1286
                next_token_ids = next_token_ids.tolist()
1287
                if batch.return_logprob:
1288
1289
1290
1291
1292
1293
1294
1295
                    if logits_output.next_token_logprobs is not None:
                        logits_output.next_token_logprobs = (
                            logits_output.next_token_logprobs.tolist()
                        )
                    if logits_output.input_token_logprobs is not None:
                        logits_output.input_token_logprobs = tuple(
                            logits_output.input_token_logprobs.tolist()
                        )
1296

1297
1298
            hidden_state_offset = 0

1299
1300
            # Check finish conditions
            logprob_pt = 0
1301
            for i, (req, next_token_id) in enumerate(zip(batch.reqs, next_token_ids)):
1302
1303
1304
                if req.is_retracted:
                    continue

Lianmin Zheng's avatar
Lianmin Zheng committed
1305
                if self.is_mixed_chunk and self.enable_overlap and req.finished():
1306
1307
1308
1309
                    # Free the one delayed token for the mixed decode batch
                    j = len(batch.out_cache_loc) - len(batch.reqs) + i
                    self.token_to_kv_pool.free(batch.out_cache_loc[j : j + 1])
                    continue
Lianmin Zheng's avatar
Lianmin Zheng committed
1310

1311
1312
                if req.is_chunked <= 0:
                    # req output_ids are set here
1313
                    req.output_ids.append(next_token_id)
1314
1315
                    req.check_finished()

1316
                    if req.finished():
1317
                        self.tree_cache.cache_finished_req(req)
1318
                    elif not batch.decoding_reqs or req not in batch.decoding_reqs:
1319
                        # This updates radix so others can match
1320
1321
                        self.tree_cache.cache_unfinished_req(req)

1322
                    if req.return_logprob:
1323
1324
1325
1326
1327
1328
1329
1330
1331
1332
1333
1334
                        assert extend_logprob_start_len_per_req is not None
                        assert extend_input_len_per_req is not None
                        extend_logprob_start_len = extend_logprob_start_len_per_req[i]
                        extend_input_len = extend_input_len_per_req[i]
                        num_input_logprobs = extend_input_len - extend_logprob_start_len
                        self.add_logprob_return_values(
                            i,
                            req,
                            logprob_pt,
                            next_token_ids,
                            num_input_logprobs,
                            logits_output,
1335
                        )
1336
1337
                        logprob_pt += num_input_logprobs

1338
                    if (
1339
                        req.return_hidden_states
1340
1341
1342
1343
1344
1345
1346
1347
1348
1349
1350
1351
1352
                        and logits_output.hidden_states is not None
                    ):
                        req.hidden_states.append(
                            logits_output.hidden_states[
                                hidden_state_offset : (
                                    hidden_state_offset := hidden_state_offset
                                    + len(req.origin_input_ids)
                                )
                            ]
                            .cpu()
                            .clone()
                        )

Lianmin Zheng's avatar
Lianmin Zheng committed
1353
1354
                    if req.grammar is not None:
                        req.grammar.accept_token(next_token_id)
1355
                        req.grammar.finished = req.finished()
1356
                else:
1357
                    # being chunked reqs' prefill is not finished
1358
                    req.is_chunked -= 1
1359
1360
1361
1362
                    # There is only at most one request being currently chunked.
                    # Because this request does not finish prefill,
                    # we don't want to stream the request currently being chunked.
                    skip_stream_req = req
1363

1364
1365
1366
1367
1368
1369
1370
1371
1372
1373
1374
1375
1376
1377
1378
1379
1380
1381
1382
                    # Incrementally update input logprobs.
                    if req.return_logprob:
                        extend_logprob_start_len = extend_logprob_start_len_per_req[i]
                        extend_input_len = extend_input_len_per_req[i]
                        if extend_logprob_start_len < extend_input_len:
                            # Update input logprobs.
                            num_input_logprobs = (
                                extend_input_len - extend_logprob_start_len
                            )
                            self.add_input_logprob_return_values(
                                i,
                                req,
                                logits_output,
                                logprob_pt,
                                num_input_logprobs,
                                last_prefill_chunk=False,
                            )
                            logprob_pt += num_input_logprobs

1383
1384
            if batch.next_batch_sampling_info:
                batch.next_batch_sampling_info.update_regex_vocab_mask()
1385
                self.current_stream.synchronize()
1386
1387
                batch.next_batch_sampling_info.sampling_info_done.set()

Lianmin Zheng's avatar
Lianmin Zheng committed
1388
        else:  # embedding or reward model
1389
            embeddings, bid = result.embeddings, result.bid
1390
            embeddings = embeddings.tolist()
1391
1392
1393

            # Check finish conditions
            for i, req in enumerate(batch.reqs):
1394
1395
1396
                if req.is_retracted:
                    continue

1397
                req.embedding = embeddings[i]
1398
                if req.is_chunked <= 0:
Lianmin Zheng's avatar
Lianmin Zheng committed
1399
                    # Dummy output token for embedding models
1400
1401
1402
                    req.output_ids.append(0)
                    req.check_finished()

Lianmin Zheng's avatar
Lianmin Zheng committed
1403
1404
1405
1406
                    if req.finished():
                        self.tree_cache.cache_finished_req(req)
                    else:
                        self.tree_cache.cache_unfinished_req(req)
1407
                else:
1408
                    # being chunked reqs' prefill is not finished
1409
                    req.is_chunked -= 1
1410

Lianmin Zheng's avatar
Lianmin Zheng committed
1411
        self.stream_output(batch.reqs, batch.return_logprob, skip_stream_req)
1412

1413
1414
1415
1416
1417
1418
1419
1420
1421
1422
    def process_batch_result_decode(
        self,
        batch: ScheduleBatch,
        result: GenerationBatchResult,
    ):
        """Apply the result of one decode step to every request in `batch`.

        For each non-retracted request this appends the newly sampled token,
        re-checks the finish condition, caches finished requests in the tree
        cache and records the requested logprobs / hidden states. Afterwards
        the outputs are streamed and the decode counters are advanced.

        Args:
            batch: The batch that was just run.
            result: The forward result; supplies `logits_output`,
                `next_token_ids` and `bid`.
        """
        logits_output = result.logits_output
        next_token_ids = result.next_token_ids
        bid = result.bid

        self.num_generated_tokens += len(batch.reqs)

        if not self.enable_overlap:
            next_token_ids = next_token_ids.tolist()
            if batch.return_logprob:
                next_token_logprobs = logits_output.next_token_logprobs.tolist()
        else:
            # In overlap mode the real outputs are fetched from the worker by batch id.
            logits_output, next_token_ids = self.tp_worker.resolve_batch_result(bid)
            next_token_logprobs = logits_output.next_token_logprobs

        self.token_to_kv_pool.free_group_begin()

        # Walk the requests together with their sampled tokens and check finish conditions.
        for idx, (req, token_id) in enumerate(zip(batch.reqs, next_token_ids)):
            if req.is_retracted:
                continue

            if self.enable_overlap and req.finished():
                # Free the one delayed token
                self.token_to_kv_pool.free(batch.out_cache_loc[idx : idx + 1])
                continue

            # In speculative decoding the speculative worker fills in output_ids itself.
            if batch.spec_algorithm.is_none():
                req.output_ids.append(token_id)

            req.check_finished()
            if req.finished():
                self.tree_cache.cache_finished_req(req)

            # In speculative decoding the speculative worker handles logprobs itself.
            if req.return_logprob and batch.spec_algorithm.is_none():
                req.output_token_logprobs_val.append(next_token_logprobs[idx])
                req.output_token_logprobs_idx.append(token_id)

                if req.top_logprobs_num > 0:
                    top_val = logits_output.next_token_top_logprobs_val[idx]
                    req.output_top_logprobs_val.append(top_val)
                    top_idx = logits_output.next_token_top_logprobs_idx[idx]
                    req.output_top_logprobs_idx.append(top_idx)

                if req.token_ids_logprob is not None:
                    ids_val = logits_output.next_token_token_ids_logprobs_val[idx]
                    req.output_token_ids_logprobs_val.append(ids_val)
                    ids_idx = logits_output.next_token_token_ids_logprobs_idx[idx]
                    req.output_token_ids_logprobs_idx.append(ids_idx)

            if req.return_hidden_states and logits_output.hidden_states is not None:
                req.hidden_states.append(logits_output.hidden_states[idx].cpu().clone())

            if req.grammar is not None and batch.spec_algorithm.is_none():
                req.grammar.accept_token(token_id)
                req.grammar.finished = req.finished()

        sampling_info = batch.next_batch_sampling_info
        if sampling_info:
            sampling_info.update_regex_vocab_mask()
            self.current_stream.synchronize()
            sampling_info.sampling_info_done.set()

        self.stream_output(batch.reqs, batch.return_logprob)

        self.token_to_kv_pool.free_group_end()

        # Wrap the decode counter at 2**30 and log stats on the configured interval.
        self.forward_ct_decode = (self.forward_ct_decode + 1) % (1 << 30)
        if self.attn_tp_rank == 0:
            if self.forward_ct_decode % self.server_args.decode_log_interval == 0:
                self.log_decode_stats()
1494

1495
    def add_input_logprob_return_values(
        self,
        i: int,
        req: Req,
        output: LogitsProcessorOutput,
        logprob_pt: int,
        num_input_logprobs: int,
        last_prefill_chunk: bool,  # If True, it means prefill is finished.
    ):
        """Incrementally add input logprobs to `req`.

        Args:
            i: The request index in a batch.
            req: The request. Input logprobs inside req are modified as a
                consequence of the API.
            output: Logit processor output that's used to compute input logprobs.
            logprob_pt: Start offset of this request's slice inside
                `output.input_token_logprobs`.
            num_input_logprobs: Length of this request's slice inside
                `output.input_token_logprobs`.
            last_prefill_chunk: True if it is the last prefill (when chunked).
                Some of input logprob operation should only happen at the last
                prefill (e.g., computing input token logprobs).
        """
        assert output.input_token_logprobs is not None

        # Lazily create the per-request accumulators used across prefill chunks.
        if req.input_token_logprobs is None:
            req.input_token_logprobs = []
        if req.temp_input_top_logprobs_val is None:
            req.temp_input_top_logprobs_val = []
        if req.temp_input_top_logprobs_idx is None:
            req.temp_input_top_logprobs_idx = []
        if req.temp_input_token_ids_logprobs_val is None:
            req.temp_input_token_ids_logprobs_val = []
        if req.temp_input_token_ids_logprobs_idx is None:
            req.temp_input_token_ids_logprobs_idx = []

        if req.input_token_logprobs_val is not None:
            # The input logprob has been already computed. It only happens
            # upon retract. Nothing more to do.
            return

        # Important for the performance.
        assert isinstance(output.input_token_logprobs, tuple)
        chunk_logprobs = output.input_token_logprobs[
            logprob_pt : logprob_pt + num_input_logprobs
        ]
        req.input_token_logprobs.extend(chunk_logprobs)

        if req.top_logprobs_num > 0:
            req.temp_input_top_logprobs_val.append(output.input_top_logprobs_val[i])
            req.temp_input_top_logprobs_idx.append(output.input_top_logprobs_idx[i])

        if req.token_ids_logprob is not None:
            req.temp_input_token_ids_logprobs_val.append(
                output.input_token_ids_logprobs_val[i]
            )
            req.temp_input_token_ids_logprobs_idx.append(
                output.input_token_ids_logprobs_idx[i]
            )

        if not last_prefill_chunk:
            # More prefill chunks will follow; keep accumulating.
            return

        # Prefill is finished: turn the accumulators into the final fields.
        input_token_logprobs = req.input_token_logprobs
        req.input_token_logprobs = None
        assert req.input_token_logprobs_val is None
        assert req.input_token_logprobs_idx is None
        assert req.input_top_logprobs_val is None
        assert req.input_top_logprobs_idx is None

        # Compute input_token_logprobs_val
        # Always pad the first one with None.
        req.input_token_logprobs_val = [None]
        req.input_token_logprobs_val.extend(input_token_logprobs)
        # The last input logprob is for sampling, so just pop it out.
        req.input_token_logprobs_val.pop()

        # Compute input_token_logprobs_idx
        # Clip the padded hash values from image tokens.
        # Otherwise, it will lead to detokenization errors.
        # NOTE(review): the bound is `vocab_size - 1`, so an id equal to
        # vocab_size - 1 is also mapped to 0 -- confirm this is intended.
        clip_bound = self.model_config.vocab_size - 1
        req.input_token_logprobs_idx = [
            x if x < clip_bound else 0
            for x in req.origin_input_ids[req.logprob_start_len :]
        ]

        if req.top_logprobs_num > 0:
            req.input_top_logprobs_val = [None]
            req.input_top_logprobs_idx = [None]

            for val, idx in zip(
                req.temp_input_top_logprobs_val,
                req.temp_input_top_logprobs_idx,
                strict=True,
            ):
                req.input_top_logprobs_val.extend(val)
                req.input_top_logprobs_idx.extend(idx)

            # Last token is a sample token.
            req.input_top_logprobs_val.pop()
            req.input_top_logprobs_idx.pop()
            req.temp_input_top_logprobs_idx = None
            req.temp_input_top_logprobs_val = None

        if req.token_ids_logprob is not None:
            req.input_token_ids_logprobs_val = [None]
            req.input_token_ids_logprobs_idx = [None]

            for val, idx in zip(
                req.temp_input_token_ids_logprobs_val,
                req.temp_input_token_ids_logprobs_idx,
                strict=True,
            ):
                req.input_token_ids_logprobs_val.extend(val)
                req.input_token_ids_logprobs_idx.extend(idx)

            # Last token is a sample token.
            req.input_token_ids_logprobs_val.pop()
            req.input_token_ids_logprobs_idx.pop()
            req.temp_input_token_ids_logprobs_idx = None
            req.temp_input_token_ids_logprobs_val = None

        if req.return_logprob:
            # Every finalized list must cover exactly the tokens from logprob_start_len on.
            relevant_tokens_len = len(req.origin_input_ids) - req.logprob_start_len
            assert len(req.input_token_logprobs_val) == relevant_tokens_len
            assert len(req.input_token_logprobs_idx) == relevant_tokens_len
            if req.top_logprobs_num > 0:
                assert len(req.input_top_logprobs_val) == relevant_tokens_len
                assert len(req.input_top_logprobs_idx) == relevant_tokens_len
            if req.token_ids_logprob is not None:
                assert len(req.input_token_ids_logprobs_val) == relevant_tokens_len
                assert len(req.input_token_ids_logprobs_idx) == relevant_tokens_len

    def add_logprob_return_values(
        self,
        i: int,
        req: Req,
        pt: int,
        next_token_ids: List[int],
        num_input_logprobs: int,
        output: LogitsProcessorOutput,
    ):
        """Attach logprobs to the return values.

        Records the logprob of the sampled token for request `i`, finalizes the
        request's input logprobs (this call is treated as the last prefill
        chunk), and appends the top-k / per-token-id output logprobs when the
        request asked for them.

        Returns:
            `num_input_logprobs`, unchanged.
        """
        sampled_logprob = output.next_token_logprobs[i]
        req.output_token_logprobs_val.append(sampled_logprob)
        sampled_token = next_token_ids[i]
        req.output_token_logprobs_idx.append(sampled_token)

        # Prefill is complete at this point, so the input logprobs are finalized here.
        self.add_input_logprob_return_values(
            i,
            req,
            output,
            pt,
            num_input_logprobs,
            last_prefill_chunk=True,
        )

        wants_top_logprobs = req.top_logprobs_num > 0
        if wants_top_logprobs:
            top_val = output.next_token_top_logprobs_val[i]
            req.output_top_logprobs_val.append(top_val)
            top_idx = output.next_token_top_logprobs_idx[i]
            req.output_top_logprobs_idx.append(top_idx)

        wants_token_ids_logprobs = req.token_ids_logprob is not None
        if wants_token_ids_logprobs:
            ids_val = output.next_token_token_ids_logprobs_val[i]
            req.output_token_ids_logprobs_val.append(ids_val)
            ids_idx = output.next_token_token_ids_logprobs_idx[i]
            req.output_token_ids_logprobs_idx.append(ids_idx)

        return num_input_logprobs

Lianmin Zheng's avatar
Lianmin Zheng committed
1658
1659
1660
    def stream_output(
        self, reqs: List[Req], return_logprob: bool, skip_req: Optional[Req] = None
    ):
        """Stream the output to detokenizer.

        Args:
            reqs: Requests of the batch that was just processed.
            return_logprob: Whether logprob fields are collected and sent.
            skip_req: A request that must not be streamed in this call
                (generation models only).
        """
        rids = []
        finished_reasons = []

        if not self.is_generation:
            # embedding or reward model: only finished requests are reported.
            embeddings = []
            prompt_tokens = []
            for req in reqs:
                if not req.finished():
                    continue
                rids.append(req.rid)
                finished_reasons.append(req.finished_reason.to_json())
                embeddings.append(req.embedding)
                prompt_tokens.append(len(req.origin_input_ids))

            self.send_to_detokenizer.send_pyobj(
                BatchEmbeddingOut(rids, finished_reasons, embeddings, prompt_tokens)
            )
            return

        decoded_texts = []
        decode_ids_list = []
        read_offsets = []
        output_ids = []

        skip_special_tokens = []
        spaces_between_special_tokens = []
        no_stop_trim = []
        prompt_tokens = []
        completion_tokens = []
        cached_tokens = []
        spec_verify_ct = []
        # Stays None unless at least one emitted request returns hidden states.
        output_hidden_states = None

        if return_logprob:
            input_token_logprobs_val = []
            input_token_logprobs_idx = []
            output_token_logprobs_val = []
            output_token_logprobs_idx = []
            input_top_logprobs_val = []
            input_top_logprobs_idx = []
            output_top_logprobs_val = []
            output_top_logprobs_idx = []
            input_token_ids_logprobs_val = []
            input_token_ids_logprobs_idx = []
            output_token_ids_logprobs_val = []
            output_token_ids_logprobs_idx = []
        else:
            input_token_logprobs_val = None
            input_token_logprobs_idx = None
            output_token_logprobs_val = None
            output_token_logprobs_idx = None
            input_top_logprobs_val = None
            input_top_logprobs_idx = None
            output_top_logprobs_val = None
            output_top_logprobs_idx = None
            input_token_ids_logprobs_val = None
            input_token_ids_logprobs_idx = None
            output_token_ids_logprobs_val = None
            output_token_ids_logprobs_idx = None

        for req in reqs:
            if req is skip_req:
                continue

            # Multimodal partial stream chunks break the detokenizer, so drop aborted requests here.
            if self.model_config.is_multimodal_gen and req.to_abort:
                continue

            # Decide whether this request emits a chunk in this step.
            if req.finished():
                should_emit = True
            elif req.stream:
                # If stream, follow the given stream_interval
                should_emit = len(req.output_ids) % self.stream_interval == 0
            else:
                # If not stream, we still want to output some tokens to get the benefit of incremental decoding.
                # TODO(lianmin): this is wrong for speculative decoding because len(req.output_ids) does not
                # always increase one-by-one.
                should_emit = (
                    len(req.output_ids) % 50 == 0
                    and not self.model_config.is_multimodal_gen
                )
            if not should_emit:
                continue

            if self.draft_worker and req.finished():
                self.draft_worker.finish_request(req)

            rids.append(req.rid)
            if req.finished_reason:
                finished_reasons.append(req.finished_reason.to_json())
            else:
                finished_reasons.append(None)
            decoded_texts.append(req.decoded_text)

            decode_ids, read_offset = req.init_incremental_detokenize()
            decode_ids_list.append(decode_ids)
            read_offsets.append(read_offset)
            if self.skip_tokenizer_init:
                output_ids.append(req.output_ids)

            sampling_params = req.sampling_params
            skip_special_tokens.append(sampling_params.skip_special_tokens)
            spaces_between_special_tokens.append(
                sampling_params.spaces_between_special_tokens
            )
            no_stop_trim.append(sampling_params.no_stop_trim)

            prompt_tokens.append(len(req.origin_input_ids))
            completion_tokens.append(len(req.output_ids))
            cached_tokens.append(req.cached_tokens)

            if not self.spec_algorithm.is_none():
                spec_verify_ct.append(req.spec_verify_ct)

            if return_logprob:
                input_token_logprobs_val.append(req.input_token_logprobs_val)
                input_token_logprobs_idx.append(req.input_token_logprobs_idx)
                output_token_logprobs_val.append(req.output_token_logprobs_val)
                output_token_logprobs_idx.append(req.output_token_logprobs_idx)
                input_top_logprobs_val.append(req.input_top_logprobs_val)
                input_top_logprobs_idx.append(req.input_top_logprobs_idx)
                output_top_logprobs_val.append(req.output_top_logprobs_val)
                output_top_logprobs_idx.append(req.output_top_logprobs_idx)
                input_token_ids_logprobs_val.append(req.input_token_ids_logprobs_val)
                input_token_ids_logprobs_idx.append(req.input_token_ids_logprobs_idx)
                output_token_ids_logprobs_val.append(req.output_token_ids_logprobs_val)
                output_token_ids_logprobs_idx.append(req.output_token_ids_logprobs_idx)

            if req.return_hidden_states:
                if output_hidden_states is None:
                    output_hidden_states = []
                output_hidden_states.append(req.hidden_states)

        # Send to detokenizer
        if not rids:
            return

        if self.model_config.is_multimodal_gen:
            raise NotImplementedError()

        # Fields are passed positionally; keep this order in sync with BatchTokenIDOut.
        payload = BatchTokenIDOut(
            rids,
            finished_reasons,
            decoded_texts,
            decode_ids_list,
            read_offsets,
            output_ids,
            skip_special_tokens,
            spaces_between_special_tokens,
            no_stop_trim,
            prompt_tokens,
            completion_tokens,
            cached_tokens,
            spec_verify_ct,
            input_token_logprobs_val,
            input_token_logprobs_idx,
            output_token_logprobs_val,
            output_token_logprobs_idx,
            input_top_logprobs_val,
            input_top_logprobs_idx,
            output_top_logprobs_val,
            output_top_logprobs_idx,
            input_token_ids_logprobs_val,
            input_token_ids_logprobs_idx,
            output_token_ids_logprobs_val,
            output_token_ids_logprobs_idx,
            output_hidden_states,
        )
        self.send_to_detokenizer.send_pyobj(payload)
1825

1826
1827
1828
1829
1830
1831
1832
1833
1834
1835
1836
1837
1838
1839
1840
1841
1842
1843
1844
1845
1846
1847
1848
1849
1850
1851
    def prepare_dp_attn_batch(self, local_batch: ScheduleBatch):
        """Exchange per-rank token counts and make sure this rank has a batch if any rank does.

        Args:
            local_batch: This rank's batch, or None if it has nothing to run.

        Returns:
            `local_batch` with `global_num_tokens` (and, unless CUDA graphs are
            disabled, `can_run_dp_cuda_graph`) filled in; an idle batch if this
            rank had none but another rank has tokens; or None if no rank has
            any tokens.
        """
        # Check if other DP workers have running batches
        if local_batch is not None:
            if local_batch.forward_mode.is_decode():
                num_tokens = local_batch.batch_size()
            else:
                num_tokens = local_batch.extend_num_tokens
        else:
            num_tokens = 0

        local_num_tokens = torch.tensor([num_tokens], dtype=torch.int64)
        global_num_tokens = torch.empty(self.tp_size, dtype=torch.int64)
        torch.distributed.all_gather_into_tensor(
            global_num_tokens, local_num_tokens, group=self.tp_cpu_group
        )

        if local_batch is None:
            if global_num_tokens.max().item() > 0:
                # Some other rank has work, so this rank runs an idle batch.
                local_batch = self.get_idle_batch()

        if local_batch is None:
            return local_batch

        local_batch.global_num_tokens = global_num_tokens.tolist()

        # Check forward mode for cuda graph
        if self.server_args.disable_cuda_graph:
            return local_batch

        decode_or_idle_flag = 1 if local_batch.forward_mode.is_decode_or_idle() else 0
        forward_mode_state = torch.tensor(decode_or_idle_flag, dtype=torch.int32)
        # MIN over ranks: the result is 1 only if every rank reported 1.
        torch.distributed.all_reduce(
            forward_mode_state,
            op=torch.distributed.ReduceOp.MIN,
            group=self.tp_cpu_group,
        )
        local_batch.can_run_dp_cuda_graph = forward_mode_state.item() == 1

        return local_batch

    def get_idle_batch(self):
        """Create a batch with no requests and prepare it via `prepare_for_idle()`."""
        no_reqs = []
        empty_batch = ScheduleBatch.init_new(
            no_reqs,
            self.req_to_token_pool,
            self.token_to_kv_pool,
            self.tree_cache,
            self.model_config,
            self.enable_overlap,
            self.spec_algorithm,
            self.server_args.enable_custom_logit_processor,
        )
        empty_batch.prepare_for_idle()
        return empty_batch

1878
1879
1880
1881
1882
1883
1884
1885
1886
1887
    def move_ready_grammar_requests(self):
        """Move requests whose grammar objects are ready from grammar_queue to waiting_queue.

        The queue is scanned in order and the scan stops at the first request
        whose grammar future does not resolve within 50 ms. When several
        (attention-)TP ranks are in use, the ranks agree on the maximum ready
        count and each rank blocks on the futures it has not resolved yet, so
        that all ranks move the same number of requests.
        """
        num_ready_reqs = 0
        for req in self.grammar_queue:
            try:
                req.grammar = req.grammar.result(timeout=0.05)
            except futures.TimeoutError:
                # Public alias of the timeout raised by Future.result().
                break
            num_ready_reqs += 1

        # Pick the process group to synchronize with; the group attribute is
        # only read when a synchronization is actually needed.
        if self.server_args.enable_dp_attention:
            needs_sync = self.attn_tp_size > 1
            sync_group = self.attn_tp_cpu_group if needs_sync else None
        else:
            needs_sync = self.tp_size > 1
            sync_group = self.tp_cpu_group if needs_sync else None

        if needs_sync:
            # Sync across ranks to make sure they have the same number of ready requests
            tensor = torch.tensor(num_ready_reqs, dtype=torch.int32)
            torch.distributed.all_reduce(
                tensor, op=torch.distributed.ReduceOp.MAX, group=sync_group
            )
            num_ready_reqs_max = tensor.item()
            # Block on the grammars another rank already had ready.
            for i in range(num_ready_reqs, num_ready_reqs_max):
                pending_req = self.grammar_queue[i]
                pending_req.grammar = pending_req.grammar.result()
            num_ready_reqs = num_ready_reqs_max

        self._extend_requests_to_queue(self.grammar_queue[:num_ready_reqs])
        self.grammar_queue = self.grammar_queue[num_ready_reqs:]

1920
1921
1922
    def flush_cache_wrapped(self, recv_req: FlushCacheReq):
        """Request-handler form of `flush_cache`; `recv_req` is unused and nothing is returned."""
        self.flush_cache()

1923
    def flush_cache(self):
        """Flush the memory pool and cache.

        The flush is only performed when no request is queued or running.

        Returns:
            True if everything was flushed, False if the flush was skipped
            because requests are still pending.
        """
        if self.running_batch is None:
            num_running_reqs = 0
        else:
            num_running_reqs = len(self.running_batch.reqs)

        if len(self.waiting_queue) > 0 or num_running_reqs > 0:
            # Use the module logger, as the success path below does.
            logger.warning(
                f"Cache not flushed because there are pending requests. "
                f"#queue-req: {len(self.waiting_queue)}, "
                f"#running-req: {num_running_reqs}"
            )
            return False

        self.cur_batch = None
        self.last_batch = None
        self.tree_cache.reset()
        self.tree_cache_metrics = {"total": 0, "hit": 0}
        if self.grammar_backend:
            self.grammar_backend.reset()
        self.req_to_token_pool.clear()
        self.token_to_kv_pool.clear()

        # With speculative decoding the draft worker's pools are cleared as well.
        if not self.spec_algorithm.is_none():
            self.draft_worker.model_runner.req_to_token_pool.clear()
            self.draft_worker.model_runner.token_to_kv_pool.clear()

        # Reset the running statistics.
        self.num_generated_tokens = 0
        self.forward_ct_decode = 0
        self.spec_num_total_accepted_tokens = 0
        self.spec_num_total_forward_ct = 0
        self.cum_spec_accept_length = 0
        self.cum_spec_accept_count = 0

        torch.cuda.empty_cache()
        logger.info("Cache flushed successfully!")
        return True

1959
1960
1961
1962
1963
1964
1965
1966
1967
1968
1969
1970
1971
1972
1973
1974
1975
1976
1977
1978
1979
1980
1981
1982
1983
1984
1985
1986
1987
1988
1989
1990
1991
1992
1993
1994
1995
1996
1997
1998
1999
2000
2001
    def get_internal_state(self, recv_req: GetInternalStateReq):
        """Return a snapshot of the global server args plus runtime statistics.

        The snapshot is a copy of `global_server_args_dict` extended with
        `last_gen_throughput`, with `avg_spec_accept_length` when speculative
        decoding is active and has recorded at least one accept, and with
        `step_time_dict` when `RECORD_STEP_TIME` is set. `recv_req` is unused.
        """
        state = dict(global_server_args_dict)
        state["last_gen_throughput"] = self.last_gen_throughput

        has_spec_stats = (
            not self.spec_algorithm.is_none() and self.cum_spec_accept_count > 0
        )
        if has_spec_stats:
            avg_accept_length = self.cum_spec_accept_length / self.cum_spec_accept_count
            state["avg_spec_accept_length"] = avg_accept_length

        if RECORD_STEP_TIME:
            state["step_time_dict"] = self.step_time_dict

        return GetInternalStateReqOutput(internal_state=state)

    def set_internal_state(self, recv_req: SetInternalStateReq):
        """Apply a whitelisted set of server-arg updates to the global args.

        Only the speculative accept thresholds may be changed. If any
        requested key is outside the whitelist, nothing is updated.

        On a successful update the running speculative-accept statistics are
        logged (when available) and then reset.

        Args:
            recv_req: Request whose ``server_args`` maps arg names to their
                new values.

        Returns:
            A ``SetInternalStateReqOutput`` whose ``updated`` flag reflects
            whether the update was actually applied, together with the
            current ``global_server_args_dict``.
        """
        server_args_dict = recv_req.server_args
        args_allow_update = {
            "speculative_accept_threshold_single",
            "speculative_accept_threshold_acc",
        }

        # Reject the whole request on the first unsupported key.
        if_success = True
        for k in server_args_dict:
            if k not in args_allow_update:
                logger.warning(f"Updating {k} is not supported.")
                if_success = False
                break

        if if_success:
            if not self.spec_algorithm.is_none() and self.cum_spec_accept_count > 0:
                avg_spec_accept_length = (
                    self.cum_spec_accept_length / self.cum_spec_accept_count
                )
                logger.info(f"{avg_spec_accept_length=}")
            # Restart the accept statistics under the new thresholds.
            self.cum_spec_accept_length = self.cum_spec_accept_count = 0
            for k, v in server_args_dict.items():
                global_server_args_dict[k] = v
            logger.info(f"Global server args updated! " f"{global_server_args_dict=}")

        # Report the real outcome; previously this was hard-coded to True
        # even when the update had been rejected above.
        return SetInternalStateReqOutput(
            updated=if_success,
            server_args=global_server_args_dict,
        )

2002
2003
2004
2005
2006
2007
2008
2009
2010
2011
    def abort_request(self, recv_req: AbortReq):
        """Abort the request whose rid matches ``recv_req.rid``.

        A request still sitting in the waiting queue is removed outright.
        Otherwise, the first matching unfinished request in the running
        batch is flagged with ``to_abort = True``.
        """
        # First look in the waiting queue; a hit there is removed and we are done.
        for position, req in enumerate(self.waiting_queue):
            if req.rid == recv_req.rid:
                del self.waiting_queue[position]
                logger.debug(f"Abort queued request. {req.rid=}")
                return

        # Otherwise look in the running batch, if there is one.
        if not self.running_batch:
            return

        for req in self.running_batch.reqs:
            if req.rid != recv_req.rid or req.finished():
                continue
            logger.debug(f"Abort running request. {req.rid=}")
            req.to_abort = True
            break

Chayenne's avatar
Chayenne committed
2023
2024
2025
    def update_weights_from_disk(self, recv_req: UpdateWeightFromDiskReqInput):
        """Reload model weights from disk in place via the TP worker.

        After a successful reload the cache is flushed, and the flush is
        asserted to succeed; on failure the worker's message is logged.

        Returns:
            An ``UpdateWeightFromDiskReqOutput`` built from the worker's
            success flag and message, with ``0`` as the third field.
        """
        ok, msg = self.tp_worker.update_weights_from_disk(recv_req)

        if not ok:
            logger.error(msg)
        else:
            flushed = self.flush_cache()
            assert flushed, "Cache flush failed after updating weights"

        return UpdateWeightFromDiskReqOutput(ok, msg, 0)
2032

2033
2034
2035
    def init_weights_update_group(self, recv_req: InitWeightsUpdateGroupReqInput):
        """Ask the TP worker to set up the online weight-update group.

        Returns:
            An ``InitWeightsUpdateGroupReqOutput`` carrying the worker's
            success flag and message.
        """
        ok, msg = self.tp_worker.init_weights_update_group(recv_req)
        return InitWeightsUpdateGroupReqOutput(ok, msg)
2037
2038

    def update_weights_from_distributed(
        self,
        recv_req: UpdateWeightsFromDistributedReqInput,
    ) -> UpdateWeightsFromDistributedReqOutput:
        """Update the online model parameter.

        Delegates the update to the TP worker. After a successful update the
        cache is flushed, and the flush is asserted to succeed; on failure the
        worker's message is logged.

        Returns:
            An ``UpdateWeightsFromDistributedReqOutput`` carrying the worker's
            success flag and message. (The annotation previously claimed
            ``Tuple[bool, str]``, which did not match the returned object.)
        """
        success, message = self.tp_worker.update_weights_from_distributed(recv_req)
        if success:
            flash_cache_success = self.flush_cache()
            assert flash_cache_success, "Cache flush failed after updating weights"
        else:
            logger.error(message)
        return UpdateWeightsFromDistributedReqOutput(success, message)
2050

2051
2052
2053
2054
2055
    def update_weights_from_tensor(self, recv_req: UpdateWeightsFromTensorReqInput):
        """Update the online model parameters from in-memory tensors.

        On success the cache is flushed only when ``recv_req.flush_cache`` is
        set, and that flush is asserted to succeed. On failure the worker's
        message is logged.

        Returns:
            An ``UpdateWeightsFromTensorReqOutput`` carrying the worker's
            success flag and message.
        """
        ok, msg = self.tp_worker.update_weights_from_tensor(recv_req)

        # TODO extract common code b/t update_weights_from_distributed and update_weights_from_tensor later
        if not ok:
            logger.error(msg)
        elif recv_req.flush_cache:
            flushed = self.flush_cache()
            assert flushed, "Cache flush failed after updating weights"

        return UpdateWeightsFromTensorReqOutput(ok, msg)
2062

2063
2064
    def get_weights_by_name(self, recv_req: GetWeightsByNameReqInput):
        """Fetch a parameter from the TP worker and wrap it in the reply type."""
        return GetWeightsByNameReqOutput(self.tp_worker.get_weights_by_name(recv_req))
2066

2067
    def release_memory_occupation(self, recv_req: ReleaseMemoryOccupationReqInput):
        """Stash the model's buffers, pause the memory saver, and flush the cache.

        The stashed buffers are kept on ``self.stashed_model_static_state``,
        which ``resume_memory_occupation`` reads back.

        Returns:
            An empty ``ReleaseMemoryOccupationReqOutput``.
        """
        model = self.tp_worker.worker.model_runner.model
        # Snapshot the buffers before pausing the memory saver.
        self.stashed_model_static_state = _export_static_state(model)

        self.memory_saver_adapter.pause()
        self.flush_cache()

        return ReleaseMemoryOccupationReqOutput()
2074

2075
    def resume_memory_occupation(self, recv_req: ResumeMemoryOccupationReqInput):
        """Resume the memory saver and restore the stashed model buffers.

        The stash stored on ``self.stashed_model_static_state`` is dropped
        once its contents have been copied back into the model.

        Returns:
            An empty ``ResumeMemoryOccupationReqOutput``.
        """
        self.memory_saver_adapter.resume()

        model = self.tp_worker.worker.model_runner.model
        stashed = self.stashed_model_static_state
        _import_static_state(model, stashed)

        # The snapshot has been written back; drop the attribute.
        del self.stashed_model_static_state

        return ResumeMemoryOccupationReqOutput()

    def profile(self, recv_req: ProfileReq):
        """Dispatch a profiling request to ``start_profile`` or ``stop_profile``.

        A request of type ``ProfileReqType.START_PROFILE`` starts profiling
        with the output directory, step count, and activities it carries;
        any other type stops profiling.
        """
        if recv_req.type == ProfileReqType.START_PROFILE:
            return self.start_profile(
                recv_req.output_dir, recv_req.num_steps, recv_req.activities
            )
        return self.stop_profile()

    def start_profile(
        self,
        output_dir: Optional[str],
        num_steps: Optional[int],
        activities: Optional[List[str]],
    ) -> Optional[ProfileReqOutput]:
        """Start a profiling session.

        Args:
            output_dir: Directory for the traces. When ``None``, falls back to
                the ``SGLANG_TORCH_PROFILER_DIR`` environment variable, then
                to ``/tmp``.
            num_steps: When truthy, the forward-step target is recorded as
                ``self.forward_ct + num_steps`` and no reply is returned here.
                Otherwise the target is cleared and a reply is returned
                immediately.
            activities: Any of ``"CPU"``, ``"GPU"``, ``"MEM"``. Defaults to
                ``["CPU", "GPU"]`` when ``None``.

        Returns:
            A failed ``ProfileReqOutput`` if a session is already in progress,
            a successful one when no ``num_steps`` target was requested, and
            ``None`` when ``num_steps`` is set. (The annotation previously
            claimed ``None`` for every path.)
        """
        if self.torch_profiler_activities:
            return ProfileReqOutput(
                success=False,
                message="Profiling is already in progress. Call /stop_profile first.",
            )

        if output_dir is None:
            output_dir = os.getenv("SGLANG_TORCH_PROFILER_DIR", "/tmp")
        if activities is None:
            activities = ["CPU", "GPU"]

        self.torch_profiler_output_dir = output_dir
        self.torch_profiler_activities = activities
        logger.info(
            "Profiling starts. Traces will be saved to: %s",
            self.torch_profiler_output_dir,
        )

        # Only "CPU" and "GPU" map to torch profiler activities; "MEM" is
        # handled separately below through CUDA memory-history recording.
        activity_map = {
            "CPU": torch.profiler.ProfilerActivity.CPU,
            "GPU": torch.profiler.ProfilerActivity.CUDA,
        }
        torchprof_activities = [
            activity_map[a] for a in activities if a in activity_map
        ]

        if torchprof_activities:
            self.torch_profiler = torch.profiler.profile(
                activities=torchprof_activities,
                with_stack=True,
            )
            self.torch_profiler.start()

        if "MEM" in activities:
            torch.cuda.memory._record_memory_history(max_entries=100000)

        if num_steps:
            self.profiler_target_forward_ct = self.forward_ct + num_steps
            # The caller will be notified when reaching profiler_target_forward_ct
            return None

        self.profiler_target_forward_ct = None
        return ProfileReqOutput(success=True, message="Succeeded")
2139
2140

    def stop_profile(self) -> None:
        """Stop the active profiling session and write its traces to disk.

        Returns immediately when no session is active (the recorded
        activities are ``None``). Otherwise it stops the torch profiler (if
        one was started) and exports a chrome trace, dumps a CUDA memory
        snapshot when ``"MEM"`` was among the activities, and clears the
        profiler state. If a forward-step target was set, a success
        ``ProfileReqOutput`` is sent through ``self.send_to_tokenizer``.
        """
        if self.torch_profiler_activities is None:
            return

        logger.info("Stop profiling...")
        if self.torch_profiler is not None:
            self.torch_profiler.stop()
            self.torch_profiler.export_chrome_trace(
                os.path.join(
                    self.torch_profiler_output_dir,
                    str(time.time()) + f"-TP-{self.tp_rank}" + ".trace.json.gz",
                )
            )

        if "MEM" in self.torch_profiler_activities:
            # Write the snapshot next to the chrome trace. This used to read
            # `self.torch_profiler_trace_dir`, which start_profile never sets;
            # the directory is stored in `torch_profiler_output_dir`.
            memory_profile_path = os.path.join(
                self.torch_profiler_output_dir,
                str(time.time()) + f"-TP-{self.tp_rank}-memory" + ".pickle",
            )
            torch.cuda.memory._dump_snapshot(memory_profile_path)
            torch.cuda.memory._record_memory_history(enabled=None)

        logger.info(
            "Profiling done. Traces are saved to: %s",
            self.torch_profiler_output_dir,
        )
        self.torch_profiler = None
        self.torch_profiler_output_dir = None
        self.torch_profiler_activities = None

        if self.profiler_target_forward_ct:
            self.send_to_tokenizer.send_pyobj(
                ProfileReqOutput(success=True, message="Succeeded.")
            )
2174

2175
    def open_session(self, recv_req: OpenSessionReqInput):
        """Register a new session under ``recv_req.session_id``.

        Returns:
            An ``OpenSessionReqOutput`` with the session id and a success
            flag. The flag is ``False`` when the id is already in use or is
            ``None``; otherwise a ``Session`` is created and it is ``True``.
        """
        session_id = recv_req.session_id

        # Reject duplicates first, then a missing id.
        if session_id in self.sessions:
            logger.warning(f"session id {session_id} already exist, cannot open.")
            return OpenSessionReqOutput(session_id, False)

        if session_id is None:
            logger.warning("session id is None, cannot open.")
            return OpenSessionReqOutput(session_id, False)

        new_session = Session(recv_req.capacity_of_str_len, session_id)
        self.sessions[session_id] = new_session
        return OpenSessionReqOutput(session_id, True)
2189
2190
2191
2192
2193
2194
2195
2196
2197

    def close_session(self, recv_req: CloseSessionReqInput):
        """Remove the session named by ``recv_req.session_id``.

        An unknown session id is logged as a warning and otherwise ignored.
        """
        session_id = recv_req.session_id
        if session_id in self.sessions:
            del self.sessions[session_id]
            return
        logger.warning(f"session id {session_id} does not exist, cannot delete.")

2198

2199
2200
2201
2202
def is_health_check_generate_req(recv_req):
    """Return True when ``recv_req`` has a string rid starting with "HEALTH_CHECK".

    Requests without a ``rid`` attribute, or whose ``rid`` is not a string
    (for example ``None`` or a list), are reported as not being health
    checks instead of raising ``AttributeError``.
    """
    rid = getattr(recv_req, "rid", None)
    return isinstance(rid, str) and rid.startswith("HEALTH_CHECK")


2203
2204
2205
2206
2207
2208
2209
2210
2211
2212
2213
2214
2215
2216
def _export_static_state(model):
    return dict(
        buffers=[
            (name, buffer.detach().clone()) for name, buffer in model.named_buffers()
        ]
    )


def _import_static_state(model, static_params):
    self_named_buffers = dict(model.named_buffers())
    for name, tensor in static_params["buffers"]:
        self_named_buffers[name][...] = tensor


2217
2218
2219
2220
2221
def run_scheduler_process(
    server_args: ServerArgs,
    port_args: PortArgs,
    gpu_id: int,
    tp_rank: int,
    dp_rank: Optional[int],
    pipe_writer,
):
    """Entry point of a scheduler process: set it up and run the event loop.

    Configures the process title, fault handler, logger, and (optionally) CPU
    affinity, then builds a ``Scheduler``, reports readiness through
    ``pipe_writer``, and runs the overlap or normal event loop. Any exception
    raised while creating or running the scheduler is logged and the parent
    process is sent ``SIGQUIT``.
    """
    # Config the process
    # kill_itself_when_parent_died()  # This is disabled because it does not work for `--dp 2`
    setproctitle.setproctitle(f"sglang::scheduler_{dp_rank}")
    faulthandler.enable()
    parent_process = psutil.Process().parent()

    # [For Router] when no dp_rank was passed in, fall back to the
    # "SGLANG_DP_RANK" environment variable if it is set.
    if dp_rank is None:
        env_dp_rank = os.environ.get("SGLANG_DP_RANK")
        if env_dp_rank is not None:
            dp_rank = int(env_dp_rank)

    # Configure the logger with a rank prefix.
    prefix = f" TP{tp_rank}" if dp_rank is None else f" DP{dp_rank} TP{tp_rank}"
    configure_logger(server_args, prefix=prefix)
    suppress_other_loggers()

    # Set cpu affinity to this gpu process
    if get_bool_env_var("SGLANG_SET_CPU_AFFINITY"):
        set_gpu_proc_affinity(server_args.tp_size, server_args.nnodes, gpu_id)

    # Create a scheduler and run the event loop
    try:
        scheduler = Scheduler(server_args, port_args, gpu_id, tp_rank, dp_rank)

        # Report readiness together with the scheduler's limits.
        ready_message = {
            "status": "ready",
            "max_total_num_tokens": scheduler.max_total_num_tokens,
            "max_req_input_len": scheduler.max_req_input_len,
        }
        pipe_writer.send(ready_message)

        run_event_loop = (
            scheduler.event_loop_overlap
            if scheduler.enable_overlap
            else scheduler.event_loop_normal
        )
        run_event_loop()
    except Exception:
        error_trace = get_exception_traceback()
        logger.error(f"Scheduler hit an exception: {error_trace}")
        parent_process.send_signal(signal.SIGQUIT)