"...lm-evaluation-harness.git" did not exist on "d23321491b5ccbdf500d445f9824837b5a6d50af"
scheduler.py 77.6 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
13
# Copyright 2023-2024 SGLang Team
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
14
15
"""A scheduler that manages a tensor parallel GPU worker."""

16
import faulthandler
17
import logging
18
import os
19
import signal
20
import sys
Lianmin Zheng's avatar
Lianmin Zheng committed
21
import threading
22
23
import time
import warnings
24
from collections import defaultdict, deque
Lianmin Zheng's avatar
Lianmin Zheng committed
25
from concurrent import futures
26
from dataclasses import dataclass
27
from http import HTTPStatus
28
from types import SimpleNamespace
29
from typing import Dict, List, Optional, Tuple, Union
30

31
import psutil
32
import setproctitle
33
import torch
34
import zmq
35
from torch.distributed import barrier
36

37
from sglang.global_config import global_config
Lianmin Zheng's avatar
Lianmin Zheng committed
38
from sglang.srt.configs.model_config import ModelConfig
39
from sglang.srt.constrained.base_grammar_backend import create_grammar_backend
Byron Hsu's avatar
Byron Hsu committed
40
41
42
43
44
45
46
47
48
49
50
51
52
from sglang.srt.disaggregation.decode import (
    DecodePreallocQueue,
    DecodeTransferQueue,
    SchedulerDisaggregationDecodeMixin,
)
from sglang.srt.disaggregation.prefill import (
    PrefillBootstrapQueue,
    SchedulerDisaggregationPrefillMixin,
)
from sglang.srt.disaggregation.utils import (
    DisaggregationMode,
    ReqToMetadataIdxAllocator,
)
53
from sglang.srt.hf_transformers_utils import get_processor, get_tokenizer
54
from sglang.srt.layers.dp_attention import compute_dp_attention_world_info
55
56
57
from sglang.srt.layers.logits_processor import LogitsProcessorOutput
from sglang.srt.managers.io_struct import (
    AbortReq,
58
    CloseSessionReqInput,
59
    ExpertDistributionReq,
60
    ExpertDistributionReqOutput,
61
    FlushCacheReq,
62
63
    GetInternalStateReq,
    GetInternalStateReqOutput,
64
65
    GetWeightsByNameReqInput,
    GetWeightsByNameReqOutput,
66
    HealthCheckOutput,
67
68
    InitWeightsUpdateGroupReqInput,
    InitWeightsUpdateGroupReqOutput,
69
70
    OpenSessionReqInput,
    OpenSessionReqOutput,
71
    ProfileReq,
72
73
    ProfileReqOutput,
    ProfileReqType,
74
75
76
77
    ReleaseMemoryOccupationReqInput,
    ReleaseMemoryOccupationReqOutput,
    ResumeMemoryOccupationReqInput,
    ResumeMemoryOccupationReqOutput,
78
79
    RpcReqInput,
    RpcReqOutput,
80
81
    SetInternalStateReq,
    SetInternalStateReqOutput,
82
83
    TokenizedEmbeddingReqInput,
    TokenizedGenerateReqInput,
Chayenne's avatar
Chayenne committed
84
85
    UpdateWeightFromDiskReqInput,
    UpdateWeightFromDiskReqOutput,
86
87
    UpdateWeightsFromDistributedReqInput,
    UpdateWeightsFromDistributedReqOutput,
88
89
    UpdateWeightsFromTensorReqInput,
    UpdateWeightsFromTensorReqOutput,
90
91
92
)
from sglang.srt.managers.schedule_batch import (
    FINISH_ABORT,
Mick's avatar
Mick committed
93
    MultimodalInputs,
94
95
    Req,
    ScheduleBatch,
96
    global_server_args_dict,
97
)
98
99
100
101
102
from sglang.srt.managers.schedule_policy import (
    AddReqResult,
    PrefillAdder,
    SchedulePolicy,
)
103
104
105
from sglang.srt.managers.scheduler_output_processor_mixin import (
    SchedulerOutputProcessorMixin,
)
106
from sglang.srt.managers.session_controller import Session
107
from sglang.srt.managers.tp_worker import TpModelWorker
108
from sglang.srt.managers.tp_worker_overlap_thread import TpModelWorkerClient
109
from sglang.srt.managers.utils import ExpertDistributionRecorder, validate_input_length
110
from sglang.srt.mem_cache.chunk_cache import ChunkCache
111
from sglang.srt.mem_cache.hiradix_cache import HiRadixCache
112
from sglang.srt.mem_cache.radix_cache import RadixCache
113
from sglang.srt.metrics.collector import SchedulerMetricsCollector, SchedulerStats
Lianmin Zheng's avatar
Lianmin Zheng committed
114
from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode
115
from sglang.srt.server_args import PortArgs, ServerArgs
116
from sglang.srt.speculative.spec_info import SpeculativeAlgorithm
117
from sglang.srt.torch_memory_saver_adapter import TorchMemorySaverAdapter
118
from sglang.srt.utils import (
119
    DynamicGradMode,
120
121
    broadcast_pyobj,
    configure_logger,
122
    crash_on_warnings,
123
    get_bool_env_var,
124
    get_zmq_socket,
Lianmin Zheng's avatar
Lianmin Zheng committed
125
    kill_itself_when_parent_died,
126
    pyspy_dump_schedulers,
127
    set_gpu_proc_affinity,
128
129
130
    set_random_seed,
    suppress_other_loggers,
)
131
from sglang.utils import TypeBasedDispatcher, get_exception_traceback
132

133
134
# Process-wide recorder instance shared by every scheduler in this module.
expert_distribution_recorder = ExpertDistributionRecorder()

logger = logging.getLogger(__name__)

# Debug switches, read from the environment once when the module is imported.
#   SGLANG_TEST_RETRACT     -> test retract decode for debugging purposes
#   SGLANG_RECORD_STEP_TIME -> RECORD_STEP_TIME flag
TEST_RETRACT = get_bool_env_var("SGLANG_TEST_RETRACT")
RECORD_STEP_TIME = get_bool_env_var("SGLANG_RECORD_STEP_TIME")


@dataclass
class GenerationBatchResult:
    """Result container for one generation batch.

    Fields:
        logits_output: the ``LogitsProcessorOutput`` for the batch.
        next_token_ids: next token id(s), one entry per position produced.
        extend_input_len_per_req: per-request extend input lengths.
        extend_logprob_start_len_per_req: per-request logprob start offsets
            within the extend input.
        bid: integer id of the batch (presumably the batch id; confirm
            against the producer).
    """

    logits_output: LogitsProcessorOutput
    next_token_ids: List[int]
    extend_input_len_per_req: List[int]
    extend_logprob_start_len_per_req: List[int]
    bid: int


@dataclass
class EmbeddingBatchResult:
    """Result container for one embedding batch.

    Fields:
        embeddings: tensor holding the batch's embeddings.
        bid: integer id of the batch (presumably the batch id).
    """

    embeddings: torch.Tensor
    bid: int


class Scheduler(
    SchedulerOutputProcessorMixin,
    SchedulerDisaggregationDecodeMixin,
    SchedulerDisaggregationPrefillMixin,
):
162
163
164
165
166
167
168
169
    """A scheduler that manages a tensor parallel GPU worker."""

    def __init__(
        self,
        server_args: ServerArgs,
        port_args: PortArgs,
        gpu_id: int,
        tp_rank: int,
        dp_rank: Optional[int],
    ):
        """Build a scheduler bound to one tensor-parallel GPU worker.

        Steps, in order:
        - parse server args;
        - open the ZMQ sockets (only on attention-TP rank 0);
        - load the tokenizer;
        - launch the TP worker and, for EAGLE, a draft worker;
        - create the memory pool and prefix cache;
        - initialise scheduling, watchdog, profiler, metrics and the request
          dispatcher;
        - set up disaggregation queues.

        Args:
            server_args: global server configuration.
            port_args: IPC names and the NCCL port used by the workers.
            gpu_id: GPU this scheduler's worker runs on.
            tp_rank: tensor-parallel rank of this scheduler.
            dp_rank: data-parallel rank forwarded to the workers (may be None).
        """
        # Parse args
        self.server_args = server_args
        self.tp_rank = tp_rank
        self.tp_size = server_args.tp_size
        self.schedule_policy = server_args.schedule_policy
        self.lora_paths = server_args.lora_paths
        self.max_loras_per_batch = server_args.max_loras_per_batch
        self.enable_overlap = not server_args.disable_overlap_schedule
        self.skip_tokenizer_init = server_args.skip_tokenizer_init
        self.enable_metrics = server_args.enable_metrics
        self.stream_interval = server_args.stream_interval
        self.spec_algorithm = SpeculativeAlgorithm.from_string(
            server_args.speculative_algorithm
        )
        self.gpu_id = gpu_id
        self.enable_hierarchical_cache = server_args.enable_hierarchical_cache
        self.page_size = server_args.page_size

        # Distributed rank info (note: self.dp_rank comes from this helper,
        # not from the dp_rank argument).
        self.dp_size = server_args.dp_size
        self.attn_tp_rank, self.attn_tp_size, self.dp_rank = (
            compute_dp_attention_world_info(
                server_args.enable_dp_attention,
                self.tp_rank,
                self.tp_size,
                self.dp_size,
            )
        )

        # Init inter-process communication (ZMQ context with 2 I/O threads).
        context = zmq.Context(2)
        if self.attn_tp_rank == 0:
            self.recv_from_tokenizer = get_zmq_socket(
                context, zmq.PULL, port_args.scheduler_input_ipc_name, False
            )
            self.send_to_tokenizer = get_zmq_socket(
                context, zmq.PUSH, port_args.tokenizer_ipc_name, False
            )

            if server_args.skip_tokenizer_init:
                # Directly send to the TokenizerManager
                self.send_to_detokenizer = get_zmq_socket(
                    context, zmq.PUSH, port_args.tokenizer_ipc_name, False
                )
            else:
                # Send to the DetokenizerManager
                self.send_to_detokenizer = get_zmq_socket(
                    context, zmq.PUSH, port_args.detokenizer_ipc_name, False
                )

            self.recv_from_rpc = get_zmq_socket(
                context, zmq.DEALER, port_args.rpc_ipc_name, False
            )
        else:
            # Non-leader ranks own no sockets: receiving is done on rank 0
            # (see recv_requests), and sends become no-ops.
            self.recv_from_tokenizer = None
            self.recv_from_rpc = None
            self.send_to_tokenizer = SimpleNamespace(send_pyobj=lambda x: None)
            self.send_to_detokenizer = SimpleNamespace(send_pyobj=lambda x: None)

        # Init tokenizer (also sets self.model_config and self.is_generation)
        self.init_tokenizer()

        # Check whether overlap can be enabled
        if not self.is_generation:
            self.enable_overlap = False
            logger.info("Overlap scheduler is disabled for embedding models.")

        if self.model_config.is_multimodal:
            self.enable_overlap = False
            logger.info("Overlap scheduler is disabled for multimodal models.")

        # Launch a tensor parallel worker
        if self.enable_overlap:
            TpWorkerClass = TpModelWorkerClient
        else:
            TpWorkerClass = TpModelWorker

        self.tp_worker = TpWorkerClass(
            server_args=server_args,
            gpu_id=gpu_id,
            tp_rank=tp_rank,
            dp_rank=dp_rank,
            nccl_port=port_args.nccl_port,
        )

        # Launch a draft worker for speculative decoding
        if self.spec_algorithm.is_eagle():
            from sglang.srt.speculative.eagle_worker import EAGLEWorker

            self.draft_worker = EAGLEWorker(
                gpu_id=gpu_id,
                tp_rank=tp_rank,
                server_args=server_args,
                nccl_port=port_args.nccl_port,
                target_worker=self.tp_worker,
                dp_rank=dp_rank,
            )
        else:
            self.draft_worker = None

        # Get token and memory info from the model worker
        (
            self.max_total_num_tokens,
            self.max_prefill_tokens,
            self.max_running_requests,
            self.max_req_len,
            self.max_req_input_len,
            self.random_seed,
            self.device,
            worker_global_server_args_dict,
            _,
            _,
            _,
        ) = self.tp_worker.get_worker_info()
        self.tp_cpu_group = self.tp_worker.get_tp_cpu_group()
        self.attn_tp_cpu_group = self.tp_worker.get_attention_tp_cpu_group()
        self.pad_input_ids_func = self.tp_worker.get_pad_input_ids_func()
        global_server_args_dict.update(worker_global_server_args_dict)
        set_random_seed(self.random_seed)

        # Print debug info
        logger.info(
            f"max_total_num_tokens={self.max_total_num_tokens}, "
            f"chunked_prefill_size={server_args.chunked_prefill_size}, "
            f"max_prefill_tokens={self.max_prefill_tokens}, "
            f"max_running_requests={self.max_running_requests}, "
            f"context_len={self.model_config.context_len}"
        )

        # Init memory pool and cache
        self.init_memory_pool_and_cache()

        # Init running status
        self.waiting_queue: List[Req] = []
        # The running decoding batch for continuous batching
        self.running_batch: ScheduleBatch = ScheduleBatch(reqs=[], batch_is_full=False)
        # The current forward batch
        self.cur_batch: Optional[ScheduleBatch] = None
        # The last forward batch
        self.last_batch: Optional[ScheduleBatch] = None
        self.forward_ct = 0
        self.forward_ct_decode = 0
        self.num_generated_tokens = 0
        self.num_prefill_tokens = 0
        self.last_decode_stats_tic = time.time()
        self.last_prefill_stats_tic = time.time()
        self.return_health_check_ct = 0
        self.current_stream = torch.get_device_module(self.device).current_stream()
        if self.device == "cpu":
            self.current_stream.synchronize = lambda: None  # No-op for CPU

        # Init session info
        self.sessions: Dict[str, Session] = {}

        # Init chunked prefill. A non-positive size (-1 means disable) or None
        # disables it; None must be checked first because `None <= 0` raises
        # TypeError (init_memory_pool_and_cache also treats None as valid).
        self.chunked_prefill_size = server_args.chunked_prefill_size
        if self.chunked_prefill_size is not None and self.chunked_prefill_size <= 0:
            self.chunked_prefill_size = None
        self.chunked_req = None
        self.is_mixed_chunk = (
            self.chunked_prefill_size is not None and server_args.enable_mixed_chunk
        )

        # Init the grammar backend for constrained generation
        self.grammar_queue: List[Req] = []
        if not server_args.skip_tokenizer_init:
            self.grammar_backend = create_grammar_backend(
                server_args, self.tokenizer, self.model_config.vocab_size
            )
        else:
            self.grammar_backend = None

        # Init schedule policy and new token estimation
        self.policy = SchedulePolicy(
            self.schedule_policy,
            self.tree_cache,
            self.enable_hierarchical_cache,
        )
        assert (
            server_args.schedule_conservativeness >= 0
        ), "Invalid schedule_conservativeness"
        # Both ratios are capped at 1.0; new_token_ratio starts at the initial
        # value and new_token_ratio_decay is the per-step decrement that spans
        # (init - min) over default_new_token_ratio_decay_steps.
        self.init_new_token_ratio = min(
            global_config.default_init_new_token_ratio
            * server_args.schedule_conservativeness,
            1.0,
        )
        self.min_new_token_ratio = min(
            self.init_new_token_ratio
            * global_config.default_min_new_token_ratio_factor,
            1.0,
        )
        self.new_token_ratio_decay = (
            self.init_new_token_ratio - self.min_new_token_ratio
        ) / global_config.default_new_token_ratio_decay_steps
        self.new_token_ratio = self.init_new_token_ratio

        # Init watchdog thread. Resolve the parent process handle *before*
        # starting the thread so the attribute exists by the time the thread
        # runs. NOTE(review): watchdog_thread is defined outside this view;
        # presumably it uses parent_process — confirm.
        self.watchdog_timeout = server_args.watchdog_timeout
        self.parent_process = psutil.Process().parent()
        t = threading.Thread(target=self.watchdog_thread, daemon=True)
        t.start()

        # Init memory saver
        self.memory_saver_adapter = TorchMemorySaverAdapter.create(
            enable=server_args.enable_memory_saver
        )

        # Init profiler
        self.torch_profiler = None
        self.torch_profiler_output_dir: Optional[str] = None
        self.torch_profiler_activities: Optional[List[str]] = None
        self.profiler_target_forward_ct: Optional[int] = None

        # Init metrics stats
        self.init_metrics()

        # Init request dispatcher: maps each incoming request type to its handler.
        self._request_dispatcher = TypeBasedDispatcher(
            [
                (TokenizedGenerateReqInput, self.handle_generate_request),
                (TokenizedEmbeddingReqInput, self.handle_embedding_request),
                (FlushCacheReq, self.flush_cache_wrapped),
                (AbortReq, self.abort_request),
                (OpenSessionReqInput, self.open_session),
                (CloseSessionReqInput, self.close_session),
                (UpdateWeightFromDiskReqInput, self.update_weights_from_disk),
                (InitWeightsUpdateGroupReqInput, self.init_weights_update_group),
                (
                    UpdateWeightsFromDistributedReqInput,
                    self.update_weights_from_distributed,
                ),
                (UpdateWeightsFromTensorReqInput, self.update_weights_from_tensor),
                (GetWeightsByNameReqInput, self.get_weights_by_name),
                (ReleaseMemoryOccupationReqInput, self.release_memory_occupation),
                (ResumeMemoryOccupationReqInput, self.resume_memory_occupation),
                (ProfileReq, self.profile),
                (GetInternalStateReq, self.get_internal_state),
                (SetInternalStateReq, self.set_internal_state),
                (RpcReqInput, self.handle_rpc_request),
                (ExpertDistributionReq, self.expert_distribution_handle),
            ]
        )

        # Init prefill/decode disaggregation (needs the memory pool and cache).
        self.disaggregation_mode = DisaggregationMode(
            self.server_args.disaggregation_mode
        )
        self.init_disaggregation()

    def init_tokenizer(self):
        """Load the model config and, unless disabled, the tokenizer/processor.

        Sets ``self.model_config``, ``self.is_generation``, ``self.tokenizer``
        and ``self.processor``:
        - ``skip_tokenizer_init``: tokenizer and processor are both ``None``;
        - multimodal model: a processor is loaded and its tokenizer reused;
        - otherwise: a plain tokenizer is loaded and ``processor`` is ``None``.
        """
        server_args = self.server_args

        self.model_config = ModelConfig(
            server_args.model_path,
            trust_remote_code=server_args.trust_remote_code,
            revision=server_args.revision,
            context_length=server_args.context_length,
            model_override_args=server_args.json_model_override_args,
            is_embedding=server_args.is_embedding,
            dtype=server_args.dtype,
            quantization=server_args.quantization,
        )
        self.is_generation = self.model_config.is_generation

        if server_args.skip_tokenizer_init:
            self.tokenizer = self.processor = None
        elif self.model_config.is_multimodal:
            self.processor = get_processor(
                server_args.tokenizer_path,
                tokenizer_mode=server_args.tokenizer_mode,
                trust_remote_code=server_args.trust_remote_code,
                revision=server_args.revision,
            )
            self.tokenizer = self.processor.tokenizer
        else:
            # Text-only models have no processor. Define the attribute anyway so
            # it exists on every path instead of raising AttributeError on access.
            self.processor = None
            self.tokenizer = get_tokenizer(
                server_args.tokenizer_path,
                tokenizer_mode=server_args.tokenizer_mode,
                trust_remote_code=server_args.trust_remote_code,
                revision=server_args.revision,
            )

    def init_memory_pool_and_cache(self):
        """Fetch memory pools from the TP worker and build the prefix cache.

        Cache selection:
        - ``ChunkCache``: chunked prefill is configured and the radix cache is
          disabled;
        - ``HiRadixCache``: hierarchical caching is enabled;
        - ``RadixCache``: otherwise (receives ``disable_radix_cache``).

        Also sets ``decode_mem_cache_buf_multiplier``. It is 1 without
        speculative decoding, otherwise
        ``speculative_num_draft_tokens + speculative_eagle_topk * speculative_num_steps``.
        """
        args = self.server_args

        self.req_to_token_pool, self.token_to_kv_pool_allocator = (
            self.tp_worker.get_memory_pool()
        )
        req_pool = self.req_to_token_pool
        kv_allocator = self.token_to_kv_pool_allocator

        if args.chunked_prefill_size is not None and args.disable_radix_cache:
            self.tree_cache = ChunkCache(
                req_to_token_pool=req_pool,
                token_to_kv_pool_allocator=kv_allocator,
            )
        elif self.enable_hierarchical_cache:
            self.tree_cache = HiRadixCache(
                req_to_token_pool=req_pool,
                token_to_kv_pool_allocator=kv_allocator,
                tp_cache_group=self.tp_worker.get_tp_cpu_group(),
                page_size=self.page_size,
                hicache_ratio=args.hicache_ratio,
            )
        else:
            self.tree_cache = RadixCache(
                req_to_token_pool=req_pool,
                token_to_kv_pool_allocator=kv_allocator,
                page_size=self.page_size,
                disable=args.disable_radix_cache,
            )

        if self.spec_algorithm.is_none():
            multiplier = 1
        else:
            multiplier = args.speculative_num_draft_tokens + (
                args.speculative_eagle_topk * args.speculative_num_steps
            )
        self.decode_mem_cache_buf_multiplier = multiplier

    def init_metrics(self):
        """Reset scheduler statistics; create the metrics collector if enabled.

        ``self.metrics_collector`` is only created when ``self.enable_metrics``
        is true. It is labelled with the served model name and the engine type
        ``"unified"``.
        """
        # Largest prefill length observed for a single request.
        self._largest_prefill_len: int = 0
        # Largest context length (prefill + generation) observed for a single request.
        self._largest_prefill_decode_len: int = 0
        # Throughput values, both starting at zero.
        self.last_gen_throughput: float = 0.0
        self.last_input_throughput: float = 0.0
        # batch size -> list of step times
        self.step_time_dict = defaultdict(list)
        # spec_* / cum_spec_* counters, all starting at zero.
        self.spec_num_total_accepted_tokens = 0
        self.spec_num_total_forward_ct = 0
        self.cum_spec_accept_length = 0
        self.cum_spec_accept_count = 0
        self.stats = SchedulerStats()

        if not self.enable_metrics:
            return

        self.metrics_collector = SchedulerMetricsCollector(
            labels={
                "model_name": self.server_args.served_model_name,
                "engine_type": "unified",
            },
        )

    def init_disaggregation(self):
        """Create the queues used by prefill/decode disaggregation.

        Only acts when ``self.disaggregation_mode`` is ``DECODE`` or
        ``PREFILL``; any other mode leaves the scheduler untouched.

        DECODE builds two queues, with metadata slots sized at twice
        ``req_to_token_pool.size``:
        - ``disagg_decode_transfer_queue``: requests polling the KV cache;
        - ``disagg_decode_prealloc_queue``: requests pending pre-allocation.

        PREFILL builds the following, with slots sized at twice
        ``max_running_requests``:
        - ``disagg_prefill_pending_queue``: a ``PrefillBootstrapQueue``;
        - ``disagg_prefill_infight_queue``: a list of requests in the middle
          of KV sending.
        """
        aux_dtype = torch.int32

        def build_metadata(num_slots):
            # Returns an index allocator over `num_slots` rows plus a one-element
            # list holding a CPU output-id buffer of shape (num_slots, 16).
            # Rows are padded because row_width * dtype.itemsize must be large
            # enough (>64 bytes per the original note) to work with RDMA.
            # NOTE(review): 16 * 4 bytes is exactly 64 — confirm the bound.
            allocator = ReqToMetadataIdxAllocator(num_slots)
            output_id_buffer = torch.zeros(
                (num_slots, 16), dtype=aux_dtype, device="cpu"
            )
            return allocator, [output_id_buffer]

        mode = self.disaggregation_mode

        if mode == DisaggregationMode.DECODE:
            # *2 for the headroom.
            allocator, buffers = build_metadata(self.req_to_token_pool.size * 2)

            # The decode requests polling kv cache
            self.disagg_decode_transfer_queue = DecodeTransferQueue(
                gloo_group=self.tp_worker.get_attention_tp_cpu_group(),
                req_to_metadata_buffer_idx_allocator=allocator,
                metadata_buffers=buffers,
            )

            # The decode requests pending for pre-allocation
            self.disagg_decode_prealloc_queue = DecodePreallocQueue(
                req_to_token_pool=self.req_to_token_pool,
                token_to_kv_pool_allocator=self.token_to_kv_pool_allocator,
                req_to_metadata_buffer_idx_allocator=allocator,
                metadata_buffers=buffers,
                aux_dtype=aux_dtype,
                scheduler=self,
                transfer_queue=self.disagg_decode_transfer_queue,
                tree_cache=self.tree_cache,
                gloo_group=self.tp_worker.get_attention_tp_cpu_group(),
                tp_rank=self.tp_rank,
                tp_size=self.tp_size,
                bootstrap_port=self.server_args.disaggregation_bootstrap_port,
            )
        elif mode == DisaggregationMode.PREFILL:
            # *2 for the headroom.
            allocator, buffers = build_metadata(self.max_running_requests * 2)

            self.disagg_prefill_pending_queue = PrefillBootstrapQueue(
                token_to_kv_pool=self.token_to_kv_pool_allocator.get_kvcache(),
                req_to_metadata_buffer_idx_allocator=allocator,
                metadata_buffers=buffers,
                aux_dtype=aux_dtype,
                tp_rank=self.tp_rank,
                tp_size=self.tp_size,
                bootstrap_port=self.server_args.disaggregation_bootstrap_port,
                gloo_group=self.tp_worker.get_attention_tp_cpu_group(),
            )
            # The prefill requests that are in the middle of kv sending
            self.disagg_prefill_infight_queue: List[Req] = []

    @DynamicGradMode()
    def event_loop_normal(self):
        """Run the non-overlapped scheduler loop forever.

        Each iteration ingests new requests, picks the next batch, and either
        runs it and processes its result immediately, or (when there is
        nothing to run) performs the idle-time self-check.
        """
        while True:
            self.process_input_requests(self.recv_requests())

            next_batch = self.get_next_batch_to_run()
            self.cur_batch = next_batch

            if not next_batch:
                # Idle: check memory and reset the new-token-ratio estimate.
                self.check_memory()
                self.new_token_ratio = self.init_new_token_ratio
            else:
                output = self.run_batch(next_batch)
                self.process_batch_result(next_batch, output)

            self.last_batch = next_batch

    @DynamicGradMode()
    def event_loop_overlap(self):
        """Run the scheduler loop that overlaps CPU processing with GPU computation.

        Each iteration launches the next batch first and only afterwards
        processes the result of the batch launched in the previous iteration.
        That earlier result is held in ``self.result_queue``.
        """
        self.result_queue = deque()

        while True:
            incoming = self.recv_requests()
            self.process_input_requests(incoming)

            batch = self.get_next_batch_to_run()
            self.cur_batch = batch

            if batch:
                launched = self.run_batch(batch)
                # Remember a copy of this batch with its result for the next iteration.
                self.result_queue.append((batch.copy(), launched))

                if self.last_batch is None:
                    # Nothing was launched last iteration: push a dummy first
                    # batch through process_batch_result to start the overlap
                    # pipeline. It is now used for triggering the
                    # sampling_info_done event.
                    dummy_first = ScheduleBatch(
                        reqs=None,
                        forward_mode=ForwardMode.DUMMY_FIRST,
                        next_batch_sampling_info=self.tp_worker.cur_sampling_info,
                    )
                    self.process_batch_result(dummy_first, None)

            if self.last_batch:
                # Finish the batch launched in the previous iteration.
                prev_batch, prev_result = self.result_queue.popleft()
                prev_batch.next_batch_sampling_info = (
                    self.tp_worker.cur_sampling_info if batch else None
                )
                self.process_batch_result(prev_batch, prev_result)
            elif batch is None:
                # Fully idle: check memory and reset the new-token-ratio estimate.
                self.check_memory()
                self.new_token_ratio = self.init_new_token_ratio

            self.last_batch = batch

    @torch.no_grad()
    def event_loop_normal_disagg_prefill(self):
        """Run the prefill-worker loop used in disaggregation mode, forever.

        Bootstrapped requests are moved into the waiting queue, prefill batches
        are run, and requests whose KV cache is still being sent are polled via
        ``disagg_prefill_infight_queue``.
        """
        while True:
            self.process_input_requests(self.recv_requests())

            bootstrapped = self.disagg_prefill_pending_queue.pop_bootstrapped()
            self.waiting_queue.extend(bootstrapped)
            self.process_prefill_chunk()

            batch = self.get_new_batch_prefill()
            self.cur_batch = batch

            if batch:
                output = self.run_batch(batch)
                self.process_batch_result_disagg_prefill(batch, output)

            if self.disagg_prefill_infight_queue:
                self.process_disagg_prefill_infight_queue()

            # Re-checked after processing: the in-flight queue may have drained.
            if batch is None and not self.disagg_prefill_infight_queue:
                self.check_memory()
                self.new_token_ratio = self.init_new_token_ratio

            self.last_batch = batch
            # HACK (byronhsu): update_running_batch, which normally resets
            # batch_is_full, is never entered on this path, so reset it here.
            # Otherwise, it hangs under high concurrency.
            self.running_batch.batch_is_full = False

    @torch.no_grad()
    def event_loop_normal_disagg_decode(self):
        """Run the decode-worker loop used in disaggregation mode, forever.

        Extend batches are not executed here; they only stream a fake extend
        output, because logprobs are handled on the prefill engine. Other
        batches are run and processed normally.
        """
        while True:
            self.process_input_requests(self.recv_requests())
            # Poll and allocate KV cache for requests arriving from prefill.
            self.process_decode_queue()

            batch = self.get_next_disagg_decode_batch_to_run()
            self.cur_batch = batch

            if batch and batch.forward_mode.is_extend():
                # Generate fake extend output.
                # Note: Logprobs should be handled on the prefill engine.
                self.stream_output(batch.reqs, [False] * len(batch.reqs))
            elif batch:
                output = self.run_batch(batch)
                self.process_batch_result(batch, output)

            transfer_pending = self.disagg_decode_transfer_queue.queue
            prealloc_pending = self.disagg_decode_prealloc_queue.queue
            if batch is None and not (transfer_pending or prealloc_pending):
                # When the server is idle, do self-check and re-init some states
                self.check_memory()
                self.new_token_ratio = self.init_new_token_ratio

            self.last_batch = batch

    def recv_requests(self) -> List[Req]:
        """Collect new requests on attention-TP rank 0 and share them with peer ranks.

        Rank 0 drains, without blocking, every message queued on the tokenizer
        socket and then on the RPC socket. Other ranks start with ``None`` and
        receive the list through ``broadcast_pyobj``.

        With DP attention enabled, generate/embedding work is broadcast only
        inside the attention-TP group (source rank ``dp_rank * attn_tp_size``).
        All other (control) messages are broadcast across the full TP group.
        The result is the work requests followed by the control requests.
        """
        if self.attn_tp_rank != 0:
            recv_reqs = None
        else:
            recv_reqs = []
            # Drain the tokenizer socket until it would block.
            while True:
                try:
                    recv_reqs.append(self.recv_from_tokenizer.recv_pyobj(zmq.NOBLOCK))
                except zmq.ZMQError:
                    break
            # Then drain the RPC socket the same way.
            while True:
                try:
                    recv_reqs.append(self.recv_from_rpc.recv_pyobj(zmq.NOBLOCK))
                except zmq.ZMQError:
                    break

        if not self.server_args.enable_dp_attention:
            if self.tp_size != 1:
                recv_reqs = broadcast_pyobj(recv_reqs, self.tp_rank, self.tp_cpu_group)
            return recv_reqs

        if self.attn_tp_rank == 0:
            work_types = (TokenizedGenerateReqInput, TokenizedEmbeddingReqInput)
            work_reqs, control_reqs = [], []
            for item in recv_reqs:
                bucket = work_reqs if isinstance(item, work_types) else control_reqs
                bucket.append(item)
        else:
            work_reqs = control_reqs = None

        if self.attn_tp_size != 1:
            work_reqs = broadcast_pyobj(
                work_reqs,
                self.attn_tp_rank,
                self.attn_tp_cpu_group,
                src=self.dp_rank * self.attn_tp_size,
            )
        if self.tp_size != 1:
            control_reqs = broadcast_pyobj(
                control_reqs, self.tp_rank, self.tp_cpu_group
            )
        return work_reqs + control_reqs
    def process_input_requests(self, recv_reqs: List):
        """Dispatch each received request and send any reply back to its origin.

        A health-check generate request is counted and skipped, without being
        dispatched, while a chunked request exists or the running batch is
        non-empty. Replies of type ``RpcReqOutput`` go to the RPC socket when it
        exists; all other replies go to the tokenizer.
        """
        for incoming in recv_reqs:
            if is_health_check_generate_req(incoming):
                server_busy = (
                    self.chunked_req is not None or not self.running_batch.is_empty()
                )
                if server_busy:
                    self.return_health_check_ct += 1
                    continue

            reply = self._request_dispatcher(incoming)
            if reply is None:
                continue

            if not isinstance(reply, RpcReqOutput):
                self.send_to_tokenizer.send_pyobj(reply)
            elif self.recv_from_rpc is not None:
                self.recv_from_rpc.send_pyobj(reply)

    def handle_generate_request(
        self,
        recv_req: TokenizedGenerateReqInput,
    ):
        """Build a ``Req`` from a tokenized generate request and enqueue it.

        The request is either created fresh or, when ``session_params.id``
        names a known session, derived from that session. It is then validated
        (multimodal expansion length, prompt length, ``logprob_start_len``).
        Invalid requests are marked ``FINISH_ABORT`` and still enqueued so the
        abort is reported back.

        Requests with a structured-output constraint (JSON schema, regex, EBNF
        or structural tag) wait in ``grammar_queue`` when their grammar is not
        cached yet. All other requests go to ``_add_request_to_queue``.
        """
        # Create a new request
        if (
            recv_req.session_params is None
            or recv_req.session_params.id is None
            or recv_req.session_params.id not in self.sessions
        ):
            if recv_req.input_embeds is not None:
                # Generate fake input_ids based on the length of input_embeds
                seq_length = len(recv_req.input_embeds)
                fake_input_ids = [1] * seq_length
                recv_req.input_ids = fake_input_ids

            # Handle custom logit processor passed to the request
            custom_logit_processor = recv_req.custom_logit_processor
            if (
                not self.server_args.enable_custom_logit_processor
                and custom_logit_processor is not None
            ):
                # Adjacent literals are concatenated, so each sentence needs its
                # own trailing space; the flag name matches the CLI option.
                logger.warning(
                    "The SGLang server is not configured to enable custom logit processor. "
                    "The custom logit processor passed in will be ignored. "
                    "Please set --enable-custom-logit-processor to enable this feature."
                )
                custom_logit_processor = None

            req = Req(
                recv_req.rid,
                recv_req.input_text,
                recv_req.input_ids,
                recv_req.sampling_params,
                return_logprob=recv_req.return_logprob,
                top_logprobs_num=recv_req.top_logprobs_num,
                token_ids_logprob=recv_req.token_ids_logprob,
                stream=recv_req.stream,
                lora_path=recv_req.lora_path,
                input_embeds=recv_req.input_embeds,
                custom_logit_processor=custom_logit_processor,
                return_hidden_states=recv_req.return_hidden_states,
                eos_token_ids=self.model_config.hf_eos_token_id,
            )
            req.tokenizer = self.tokenizer

            # A session id was supplied but is unknown: abort the request.
            if (
                recv_req.session_params is not None
                and recv_req.session_params.id is not None
            ):
                req.finished_reason = FINISH_ABORT(
                    f"Invalid request: session id {recv_req.session_params.id} does not exist"
                )
                self._add_request_to_queue(req)
                return
        else:
            # Create a new request from a previous session
            session = self.sessions[recv_req.session_params.id]
            req = session.create_req(recv_req, self.tokenizer)
            if isinstance(req.finished_reason, FINISH_ABORT):
                self._add_request_to_queue(req)
                return

        # Handle multimodal inputs
        if recv_req.mm_inputs is not None:
            image_inputs = MultimodalInputs.from_dict(recv_req.mm_inputs)
            # Expand a single image token into multiple dummy tokens for receiving image embeddings
            req.origin_input_ids = self.pad_input_ids_func(
                req.origin_input_ids, image_inputs
            )
            req.extend_image_inputs(image_inputs)

            if len(req.origin_input_ids) >= self.max_req_input_len:
                error_msg = (
                    "Multimodal prompt is too long after expanding multimodal tokens. "
                    f"After expanding {len(req.origin_input_ids_unpadded)=} => {len(req.origin_input_ids)} >= {self.max_req_input_len}."
                )
                logger.error(error_msg)
                req.origin_input_ids = [0]
                req.multimodal_inputs = None
                req.sampling_params.max_new_tokens = 0
                req.finished_reason = FINISH_ABORT(
                    error_msg, HTTPStatus.BAD_REQUEST, "BadRequestError"
                )
                self._add_request_to_queue(req)
                return

        # Validate prompts length
        error_msg = validate_input_length(
            req,
            self.max_req_input_len,
            self.server_args.allow_auto_truncate,
        )
        if error_msg:
            req.origin_input_ids = [0]
            req.sampling_params.max_new_tokens = 0
            self._add_request_to_queue(req)
            return

        # Copy more attributes
        if recv_req.logprob_start_len == -1 or not recv_req.return_logprob:
            # By default, only return the logprobs for output tokens
            req.logprob_start_len = len(req.origin_input_ids) - 1
        else:
            req.logprob_start_len = recv_req.logprob_start_len

        if req.logprob_start_len >= len(req.origin_input_ids):
            req.finished_reason = FINISH_ABORT(
                f"logprob_start_len, ({req.logprob_start_len}) is higher than the number of input tokens ({len(req.origin_input_ids)}). Request with a lower logprob_start_len.",
                HTTPStatus.BAD_REQUEST,
                "BadRequestError",
            )
            req.logprob_start_len = len(req.origin_input_ids) - 1
            self._add_request_to_queue(req)
            return

        # Cap generation so prompt + output stays within max_req_len.
        req.sampling_params.max_new_tokens = min(
            (
                req.sampling_params.max_new_tokens
                if req.sampling_params.max_new_tokens is not None
                else 1 << 30
            ),
            self.max_req_len - len(req.origin_input_ids) - 1,
        )

        # Init grammar cache for this request
        add_to_grammar_queue = False
        if (
            req.sampling_params.json_schema is not None
            or req.sampling_params.regex is not None
            or req.sampling_params.ebnf is not None
            or req.sampling_params.structural_tag is not None
        ):
            assert self.grammar_backend is not None
            if req.sampling_params.json_schema is not None:
                key = ("json", req.sampling_params.json_schema)
            elif req.sampling_params.regex is not None:
                key = ("regex", req.sampling_params.regex)
            elif req.sampling_params.ebnf is not None:
                key = ("ebnf", req.sampling_params.ebnf)
            else:
                # The enclosing condition guarantees structural_tag is not None
                # here; a plain `else` keeps `key` bound even for falsy values.
                key = ("structural_tag", req.sampling_params.structural_tag)

            req.grammar = self.grammar_backend.get_cached_value(key)
            if not req.grammar:
                # Not cached yet: park the request until its grammar is ready.
                req.grammar = self.grammar_backend.get_future_value(key)
                add_to_grammar_queue = True

        if add_to_grammar_queue:
            self.grammar_queue.append(req)
        else:
            self._add_request_to_queue(req)

    def _add_request_to_queue(self, req: Req):
        """Enqueue ``req`` on the queue that matches this worker's disaggregation mode.

        Prefill workers use the pending (bootstrap) queue, decode workers use the
        prealloc queue, and non-disaggregated workers use ``waiting_queue``.
        """
        mode = self.disaggregation_mode
        if mode == DisaggregationMode.PREFILL:
            self.disagg_prefill_pending_queue.add(req)
            return
        if mode == DisaggregationMode.DECODE:
            self.disagg_decode_prealloc_queue.add(req)
            return
        # Not disaggregated: schedule the request locally.
        self.waiting_queue.append(req)

    def _extend_requests_to_queue(self, reqs: List[Req], is_retracted: bool = False):
        """Enqueue several requests at once, e.g. those retracted during decode.

        Decode workers push into the prealloc queue; every other mode extends
        ``waiting_queue``. ``is_retracted`` is accepted but not used here.
        """
        destination = (
            self.disagg_decode_prealloc_queue
            if self.disaggregation_mode == DisaggregationMode.DECODE
            else self.waiting_queue
        )
        destination.extend(reqs)

    def handle_embedding_request(
        self,
        recv_req: TokenizedEmbeddingReqInput,
    ):
        """Build a ``Req`` from a tokenized embedding request and enqueue it.

        Multimodal inputs are expanded into dummy tokens first. A request that
        becomes too long after expansion, or fails ``validate_input_length``, is
        still enqueued so the abort is reported back. Every path goes through
        ``_add_request_to_queue`` so the disaggregation mode is honoured.
        """
        req = Req(
            recv_req.rid,
            recv_req.input_text,
            recv_req.input_ids,
            recv_req.sampling_params,
        )
        req.tokenizer = self.tokenizer

        # Handle multimodal inputs
        if recv_req.image_inputs is not None:
            image_inputs = MultimodalInputs.from_dict(recv_req.image_inputs)
            # Expand a single image token into multiple dummy tokens for receiving image embeddings
            req.origin_input_ids = self.pad_input_ids_func(
                req.origin_input_ids, image_inputs
            )
            req.extend_image_inputs(image_inputs)

            if len(req.origin_input_ids) >= self.max_req_input_len:
                error_msg = (
                    "Multimodal prompt is too long after expanding multimodal tokens. "
                    f"After expanding {len(req.origin_input_ids_unpadded)=} => {len(req.origin_input_ids)} >= {self.max_req_input_len}."
                )
                logger.error(error_msg)
                req.origin_input_ids = [0]
                req.multimodal_inputs = None
                req.sampling_params.max_new_tokens = 0
                req.finished_reason = FINISH_ABORT(
                    error_msg, HTTPStatus.BAD_REQUEST, "BadRequestError"
                )
                # Use the mode-aware dispatcher (previously appended to
                # waiting_queue directly, bypassing disaggregation queues).
                self._add_request_to_queue(req)
                return

        # Validate prompts length
        error_msg = validate_input_length(
            req,
            self.max_req_input_len,
            self.server_args.allow_auto_truncate,
        )
        if error_msg:
            self._add_request_to_queue(req)
            return

        # Copy more attributes
        req.logprob_start_len = len(req.origin_input_ids) - 1
        self._add_request_to_queue(req)
    def log_prefill_stats(
        self,
        adder: PrefillAdder,
        can_run_list: List[Req],
        running_bs: int,
    ):
        """Log a one-line summary of a freshly built prefill batch.

        Updates the input-throughput estimate since the previous call, resets
        the prefill-token counter and tracks the largest prefill length seen.
        When metrics are enabled, it also pushes usage, queue and cache-hit
        statistics to the metrics collector.
        """
        elapsed = time.time() - self.last_prefill_stats_tic
        self.last_prefill_stats_tic = time.time()
        self.last_input_throughput = self.num_prefill_tokens / elapsed
        self.num_prefill_tokens = 0

        free_tokens = (
            self.token_to_kv_pool_allocator.available_size()
            + self.tree_cache.evictable_size()
        )
        num_used = self.max_total_num_tokens - free_tokens
        usage = num_used / self.max_total_num_tokens
        self._largest_prefill_len = max(
            self._largest_prefill_len, adder.log_input_tokens
        )

        summary = (
            f"Prefill batch. "
            f"#new-seq: {len(can_run_list)}, "
            f"#new-token: {adder.log_input_tokens}, "
            f"#cached-token: {adder.log_hit_tokens}, "
            f"token usage: {usage:.2f}, "
            f"#running-req: {running_bs}, "
            f"#queue-req: {len(self.waiting_queue)}, "
        )
        logger.info(summary)

        if not self.enable_metrics:
            return

        hit_ratio = adder.log_hit_tokens / (
            adder.log_input_tokens + adder.log_hit_tokens
        )
        stats = self.stats
        stats.num_running_reqs = running_bs
        stats.num_used_tokens = num_used
        stats.token_usage = round(usage, 2)
        stats.num_queue_reqs = len(self.waiting_queue)
        stats.cache_hit_rate = hit_ratio
        self.metrics_collector.log_stats(stats)

    def log_decode_stats(self):
        """Log a one-line summary of decode progress since the last call.

        Updates the generation-throughput estimate and resets its counter. It
        optionally records per-step latency keyed by running-request count.
        With speculative decoding, it also reports and accumulates the mean
        accepted length and then resets the per-interval speculative counters.
        When metrics are enabled, the same numbers are pushed to the collector.
        """
        elapsed = time.time() - self.last_decode_stats_tic
        self.last_decode_stats_tic = time.time()
        self.last_gen_throughput = self.num_generated_tokens / elapsed
        self.num_generated_tokens = 0

        num_running_reqs = len(self.running_batch.reqs)
        num_used = self.max_total_num_tokens - (
            self.token_to_kv_pool_allocator.available_size()
            + self.tree_cache.evictable_size()
        )
        usage = num_used / self.max_total_num_tokens

        if RECORD_STEP_TIME:
            self.step_time_dict[num_running_reqs].append(
                elapsed / self.server_args.decode_log_interval
            )

        if self.spec_algorithm.is_none():
            spec_accept_length = 0
            accept_fragment = ""
        else:
            spec_accept_length = (
                self.spec_num_total_accepted_tokens / self.spec_num_total_forward_ct
            )
            self.cum_spec_accept_length += self.spec_num_total_accepted_tokens
            self.cum_spec_accept_count += self.spec_num_total_forward_ct
            self.spec_num_total_accepted_tokens = self.spec_num_total_forward_ct = 0
            accept_fragment = f"accept len: {spec_accept_length:.2f}, "

        msg = (
            f"Decode batch. "
            f"#running-req: {num_running_reqs}, "
            f"#token: {num_used}, "
            f"token usage: {usage:.2f}, "
            f"{accept_fragment}"
            f"gen throughput (token/s): {self.last_gen_throughput:.2f}, "
            f"#queue-req: {len(self.waiting_queue)}, "
        )
        logger.info(msg)

        if self.enable_metrics:
            stats = self.stats
            stats.num_running_reqs = num_running_reqs
            stats.num_used_tokens = num_used
            stats.token_usage = usage
            stats.cache_hit_rate = 0.0
            stats.gen_throughput = self.last_gen_throughput
            stats.num_queue_reqs = len(self.waiting_queue)
            stats.spec_accept_length = spec_accept_length
            self.metrics_collector.log_stats(stats)
    def check_memory(self):
        """Sanity-check the KV-cache and request pools while the scheduler is idle.

        Warns, or raises ``ValueError`` when ``crash_on_warnings()`` is true, if:
          * the free + evictable KV tokens differ from the pool capacity (minus
            the protected size when hierarchical cache is enabled), or
          * the request-to-token pool does not have all of its slots free.

        When metrics are enabled on attention-TP rank 0 and more than 30 seconds
        have passed since the collector's last log, it also pushes idle-time
        stats.

        Raises:
            ValueError: on a detected leak when ``crash_on_warnings()`` is true.
        """
        available_size = (
            self.token_to_kv_pool_allocator.available_size()
            + self.tree_cache.evictable_size()
        )
        protected_size = self.tree_cache.protected_size()
        expected_size = (
            self.max_total_num_tokens - protected_size
            if self.enable_hierarchical_cache
            else self.max_total_num_tokens
        )
        if available_size != expected_size:
            msg = (
                "KV cache pool leak detected! "
                f"{available_size=}, {protected_size=}, {self.max_total_num_tokens=}\n"
                f"{self.token_to_kv_pool_allocator.available_size()=}\n"
                f"{self.tree_cache.evictable_size()=}\n"
            )
            warnings.warn(msg)
            if crash_on_warnings():
                raise ValueError(msg)

        if len(self.req_to_token_pool.free_slots) != self.req_to_token_pool.size:
            # Trailing space keeps the header separated from the first field
            # (the concatenated literals previously ran together).
            msg = (
                "Memory pool leak detected! "
                f"available_size={len(self.req_to_token_pool.free_slots)}, "
                f"total_size={self.req_to_token_pool.size}\n"
            )
            warnings.warn(msg)
            if crash_on_warnings():
                raise ValueError(msg)

        if (
            self.enable_metrics
            and self.attn_tp_rank == 0
            and time.time() > self.metrics_collector.last_log_time + 30
        ):
            # During idle time, also collect metrics every 30 seconds.
            num_used = self.max_total_num_tokens - (
                self.token_to_kv_pool_allocator.available_size()
                + self.tree_cache.evictable_size()
            )
            num_running_reqs = len(self.running_batch.reqs)
            self.stats.num_running_reqs = num_running_reqs
            self.stats.num_used_tokens = num_used
            self.stats.token_usage = num_used / self.max_total_num_tokens
            self.stats.gen_throughput = 0
            self.stats.num_queue_reqs = len(self.waiting_queue)
            self.metrics_collector.log_stats(self.stats)
    def get_next_batch_to_run(self) -> Optional[ScheduleBatch]:
        """Choose the batch for the next forward pass.

        If the previous batch was an extend (prefill) batch, it is first folded
        into ``running_batch``. Any chunked request is excluded, cached and has
        its req-pool slot freed. Finished requests are filtered out.

        Then a new prefill batch is preferred; otherwise the running batch is
        advanced for decoding. Returns ``None`` when there is nothing to run,
        after the optional DP-attention preparation.
        """
        previous = self.last_batch
        if previous and previous.forward_mode.is_extend():
            chunked = self.chunked_req
            if chunked:
                # Move the chunked request out of the batch so that only
                # finished requests are merged into running_batch.
                previous.filter_batch(chunked_req_to_exclude=chunked)
                self.tree_cache.cache_unfinished_req(chunked)
                # The chunked request keeps its rid but will get a new req_pool_idx.
                self.req_to_token_pool.free(chunked.req_pool_idx)
                self.running_batch.batch_is_full = False

            size_before = previous.batch_size()
            previous.filter_batch()
            if previous.batch_size() < size_before:
                self.running_batch.batch_is_full = False

            if not previous.is_empty():
                if self.running_batch.is_empty():
                    self.running_batch = previous
                else:
                    # Merge running_batch with the prefill batch.
                    self.running_batch.merge_batch(previous)

        prefill_batch = self.get_new_batch_prefill()
        if prefill_batch is not None:
            # Prefill takes priority over decode.
            chosen = prefill_batch
        elif self.running_batch.is_empty():
            chosen = None
        else:
            self.running_batch = self.update_running_batch(self.running_batch)
            chosen = None if self.running_batch.is_empty() else self.running_batch

        if self.server_args.enable_dp_attention:
            chosen, _ = self.prepare_dp_attn_batch(chosen)

        return chosen
    def get_new_batch_prefill(self) -> Optional[ScheduleBatch]:
        """Assemble a new prefill (extend) batch from the waiting queue.

        Returns ``None`` when prefill is not allowed (running batch full or
        nothing waiting, with no chunked request pending), when the running
        request cap is reached, or when no request could be admitted.

        Side effects: may set ``running_batch.batch_is_full`` and removes the
        admitted requests from ``waiting_queue``. It may also replace
        ``chunked_req``, and in mixed-chunk mode it moves the running decode
        requests into the returned batch.
        """
        # Check if the grammar is ready in the grammar queue
        if self.grammar_queue:
            self.move_ready_grammar_requests()

        # Handle the cases where prefill is not allowed
        if (
            self.running_batch.batch_is_full or len(self.waiting_queue) == 0
        ) and self.chunked_req is None:
            return None

        running_bs = len(self.running_batch.reqs)
        if running_bs >= self.max_running_requests:
            self.running_batch.batch_is_full = True
            return None

        if self.enable_hierarchical_cache:
            # check for completion of hierarchical cache activities to release memory
            self.tree_cache.writing_check()
            self.tree_cache.loading_check()

        # Get priority queue
        prefix_computed = self.policy.calc_priority(self.waiting_queue)

        # Prefill policy
        adder = PrefillAdder(
            self.tree_cache,
            self.token_to_kv_pool_allocator,
            self.running_batch,
            self.new_token_ratio,
            self.max_prefill_tokens,
            self.chunked_prefill_size,
            running_bs if self.is_mixed_chunk else 0,
        )

        if self.chunked_req is not None:
            self.chunked_req.init_next_round_input()
            self.chunked_req = adder.add_chunked_req(self.chunked_req)

        if self.lora_paths:
            lora_set = {r.lora_path for r in self.running_batch.reqs}

        # Get requests from the waiting queue to a new prefill batch
        for req in self.waiting_queue:
            # Stop once admitting this request would exceed the LoRA adapter limit.
            if (
                self.lora_paths
                and len(
                    lora_set
                    | {r.lora_path for r in adder.can_run_list}
                    | {req.lora_path}
                )
                > self.max_loras_per_batch
            ):
                self.running_batch.batch_is_full = True
                break

            if running_bs + len(adder.can_run_list) >= self.max_running_requests:
                self.running_batch.batch_is_full = True
                break

            req.init_next_round_input(
                None if prefix_computed else self.tree_cache,
                self.enable_hierarchical_cache,
            )

            res = adder.add_one_req(
                req, self.chunked_req, self.enable_hierarchical_cache
            )
            if res != AddReqResult.CONTINUE:
                if res == AddReqResult.NO_TOKEN:
                    if self.enable_hierarchical_cache:
                        # Set batch_is_full after making sure there are requests that can be served
                        self.running_batch.batch_is_full = (
                            len(adder.can_run_list) > 0
                            or not self.running_batch.is_empty()
                        )
                    else:
                        self.running_batch.batch_is_full = True
                break

        # Update waiting queue
        can_run_list: List[Req] = adder.can_run_list
        if len(can_run_list) == 0:
            return None
        # Build the membership set once; rebuilding it for every waiting
        # request made this filter O(len(waiting) * len(admitted)).
        admitted = set(can_run_list)
        self.waiting_queue = [x for x in self.waiting_queue if x not in admitted]

        if self.enable_hierarchical_cache:
            self.tree_cache.read_to_load_cache()

        if adder.new_chunked_req is not None:
            assert self.chunked_req is None
            self.chunked_req = adder.new_chunked_req

        if self.chunked_req:
            self.chunked_req.is_chunked += 1

        # Print stats
        if self.attn_tp_rank == 0:
            self.log_prefill_stats(adder, can_run_list, running_bs)

        # Create a new batch
        new_batch = ScheduleBatch.init_new(
            can_run_list,
            self.req_to_token_pool,
            self.token_to_kv_pool_allocator,
            self.tree_cache,
            self.model_config,
            self.enable_overlap,
            self.spec_algorithm,
            self.server_args.enable_custom_logit_processor,
        )
        new_batch.prepare_for_extend()

        # Mixed-style chunked prefill
        if (
            self.is_mixed_chunk
            and not self.running_batch.is_empty()
            and not (new_batch.return_logprob or self.running_batch.return_logprob)
        ):
            # TODO (lianmin): support return_logprob + mixed chunked prefill
            self.running_batch.filter_batch()
            if not self.running_batch.is_empty():
                self.running_batch.prepare_for_decode()
                new_batch.mix_with_running(self.running_batch)
                new_batch.decoding_reqs = self.running_batch.reqs
            self.running_batch = ScheduleBatch(
                reqs=[], batch_is_full=self.running_batch.batch_is_full
            )
        else:
            new_batch.decoding_reqs = None

        return new_batch
    def update_running_batch(self, batch: ScheduleBatch) -> Optional[ScheduleBatch]:
        """Prepare the running decode batch for its next step.

        Finished requests are filtered out first. If ``check_decode_mem`` fails
        (or the TEST_RETRACT debug switch fires on a batch larger than 10),
        requests are retracted and re-queued and ``new_token_ratio`` takes the
        value from ``retract_decode``. Otherwise the ratio decays towards
        ``min_new_token_ratio``. ``batch_is_full`` is cleared whenever the
        batch shrank. Returns the same (mutated) batch object.
        """
        size_before = batch.batch_size()

        batch.filter_batch()
        if batch.is_empty():
            batch.batch_is_full = False
            return batch

        needs_retract = (
            not batch.check_decode_mem(self.decode_mem_cache_buf_multiplier)
        ) or (TEST_RETRACT and batch.batch_size() > 10)

        if needs_retract:
            previous_ratio = self.new_token_ratio
            retracted, self.new_token_ratio = batch.retract_decode(self.server_args)
            logger.info(
                "Decode out of memory happened. "
                f"#retracted_reqs: {len(retracted)}, "
                f"#new_token_ratio: {previous_ratio:.4f} -> {self.new_token_ratio:.4f}"
            )
            self._extend_requests_to_queue(retracted)
        else:
            decayed = self.new_token_ratio - self.new_token_ratio_decay
            self.new_token_ratio = max(decayed, self.min_new_token_ratio)

        if batch.batch_size() < size_before:
            batch.batch_is_full = False

        # Update batch tensors
        batch.prepare_for_decode()
        return batch
    def run_batch(
        self, batch: ScheduleBatch
    ) -> Union[GenerationBatchResult, EmbeddingBatchResult]:
        """Run one forward pass for ``batch`` and wrap the outputs.

        Increments ``self.forward_ct`` and calls ``stop_profile`` once a
        step-limited profiling target is reached. Generation models run through
        the TP worker, or the draft worker when a speculative algorithm is set;
        embedding/reward models run through ``forward_batch_embedding``.
        """
        self.forward_ct += 1

        target_ct = self.profiler_target_forward_ct
        if target_ct and target_ct <= self.forward_ct:
            self.stop_profile()

        if not self.is_generation:
            # Embedding or reward model.
            worker_batch = batch.get_model_worker_batch()
            embeddings = self.tp_worker.forward_batch_embedding(worker_batch)
            return EmbeddingBatchResult(embeddings=embeddings, bid=worker_batch.bid)

        if self.spec_algorithm.is_none():
            worker_batch = batch.get_model_worker_batch()
            logits_output, next_token_ids = self.tp_worker.forward_batch_generation(
                worker_batch
            )
            bid = worker_batch.bid
        else:
            (
                logits_output,
                next_token_ids,
                bid,
                num_accepted,
            ) = self.draft_worker.forward_batch_speculative_generation(batch)
            bs = batch.batch_size()
            self.spec_num_total_accepted_tokens += num_accepted + bs
            self.spec_num_total_forward_ct += bs
            self.num_generated_tokens += num_accepted
        batch.output_ids = next_token_ids

        # Copy these per-request values now: the overlap schedule may modify
        # them before the result is processed.
        if batch.return_logprob:
            input_lens = [req.extend_input_len for req in batch.reqs]
            logprob_starts = [req.extend_logprob_start_len for req in batch.reqs]
        else:
            input_lens = logprob_starts = None

        return GenerationBatchResult(
            logits_output=logits_output,
            next_token_ids=next_token_ids,
            extend_input_len_per_req=input_lens,
            extend_logprob_start_len_per_req=logprob_starts,
            bid=bid,
        )
Chayenne's avatar
Chayenne committed
1428

1429
1430
1431
1432
1433
    def process_batch_result(
        self,
        batch: ScheduleBatch,
        result: Union[GenerationBatchResult, EmbeddingBatchResult],
    ):
        """Post-process the output of a forward pass according to its forward mode.

        Decode and extend results go to their dedicated handlers. Idle (with
        overlap enabled) and dummy-first batches only finalize the next batch's
        sampling info. Afterwards, one pending health-check reply is sent.
        """
        mode = batch.forward_mode
        if mode.is_decode():
            self.process_batch_result_decode(batch, result)
        elif mode.is_extend():
            self.process_batch_result_prefill(batch, result)
        elif mode.is_idle():
            if self.enable_overlap:
                self.tp_worker.resolve_batch_result(result.bid)
                next_info = batch.next_batch_sampling_info
                if next_info:
                    next_info.update_regex_vocab_mask()
                    self.current_stream.synchronize()
                    next_info.sampling_info_done.set()
        elif mode.is_dummy_first():
            next_info = batch.next_batch_sampling_info
            next_info.update_regex_vocab_mask()
            self.current_stream.synchronize()
            next_info.sampling_info_done.set()

        if not self.return_health_check_ct:
            return
        # Answer a pending health check from here so that it is not blocked by
        # a long-context prefill. This path does not check the detokenizer
        # manager's status.
        self.return_health_check_ct -= 1
        self.send_to_tokenizer.send_pyobj(HealthCheckOutput())

1457
1458
1459
1460
    def prepare_dp_attn_batch(self, local_batch: ScheduleBatch):
        """Exchange batch metadata across data-parallel attention ranks.

        Each rank contributes four int64 values via an all-gather over
        ``self.tp_cpu_group``:
          - its token count;
          - whether it can use a CUDA graph;
          - its token count for logprob computation;
          - whether its batch is an extend batch.
        The gathered values are attached to the local batch. A rank with no
        batch gets an idle batch if any rank reports tokens.

        Returns:
            ``(local_batch, any_extend)``. ``local_batch`` may be None when no
            rank has tokens. ``any_extend`` is True if any rank is extending.
        """
        speculative = not self.spec_algorithm.is_none()

        # Token counts for this rank.
        if local_batch is None:
            local_tokens = 0
            local_logprob_tokens = 0
        elif local_batch.forward_mode.is_decode():
            local_tokens = local_batch.batch_size()
            if speculative and self.spec_algorithm.is_eagle():
                # With EAGLE, every request contributes speculative_num_draft_tokens tokens.
                local_tokens *= self.server_args.speculative_num_draft_tokens
            local_logprob_tokens = local_tokens
        else:
            local_tokens = local_batch.extend_num_tokens
            # We should have at least 1 token for sample in every case.
            local_logprob_tokens = sum(
                max(extend_len - start_len, 1)
                for start_len, extend_len in zip(
                    local_batch.extend_logprob_start_lens, local_batch.extend_lens
                )
            )

        graph_capable = (
            local_batch is None or local_batch.forward_mode.is_decode_or_idle()
        )
        if speculative and (local_batch is None or local_batch.forward_mode.is_idle()):
            # TODO(sang): Support cuda graph when idle batch is there.
            graph_capable = False

        extending = local_batch.forward_mode.is_extend() if local_batch else False

        local_info = torch.tensor(
            [local_tokens, int(graph_capable), local_logprob_tokens, extending],
            dtype=torch.int64,
        )
        gathered = torch.empty(
            (self.server_args.dp_size, self.attn_tp_size, 4),
            dtype=torch.int64,
        )
        torch.distributed.all_gather_into_tensor(
            gathered.flatten(),
            local_info,
            group=self.tp_cpu_group,
        )

        # Only the first attention-TP rank of each DP rank is read.
        # This assumes the values agree within the group.
        per_dp = gathered[:, 0]
        global_num_tokens = per_dp[:, 0].tolist()
        can_cuda_graph = min(per_dp[:, 1].tolist())
        global_num_tokens_for_logprob = per_dp[:, 2].tolist()
        any_extend = any(per_dp[:, 3].tolist())

        if local_batch is None and max(global_num_tokens) > 0:
            # Some rank has tokens, so this rank needs an idle batch.
            local_batch = self.get_idle_batch()

        if local_batch is not None:
            local_batch.global_num_tokens = global_num_tokens
            local_batch.global_num_tokens_for_logprob = global_num_tokens_for_logprob
            if not self.server_args.disable_cuda_graph:
                local_batch.can_run_dp_cuda_graph = can_cuda_graph

        return local_batch, any_extend
1527
1528
1529
1530
1531

    def get_idle_batch(self):
        """Create an empty ScheduleBatch prepared with ``prepare_for_idle``.

        prepare_dp_attn_batch requests one when this rank has no local batch
        but some data-parallel rank reports tokens.
        """
        no_requests = []
        idle = ScheduleBatch.init_new(
            no_requests,
            self.req_to_token_pool,
            self.token_to_kv_pool_allocator,
            self.tree_cache,
            self.model_config,
            self.enable_overlap,
            self.spec_algorithm,
            self.server_args.enable_custom_logit_processor,
        )
        idle.prepare_for_idle()
        return idle

1542
1543
1544
1545
1546
1547
1548
1549
1550
1551
    def move_ready_grammar_requests(self):
        """Move requests whose grammar objects are ready from grammar_queue to waiting_queue.

        Each queued request holds a future in ``req.grammar``. Futures are
        resolved in queue order with a 0.05 s timeout. The scan stops at the
        first one that is not ready, so only a ready prefix of the queue is
        moved.

        With more than one TP rank, the ranks all-reduce (MAX) their ready
        counts. Every rank then moves the same number of requests, blocking on
        any futures it has not resolved yet.
        """
        num_ready_reqs = 0
        for req in self.grammar_queue:
            try:
                req.grammar = req.grammar.result(timeout=0.05)
            except futures.TimeoutError:
                # Public alias of the exception raised by Future.result on timeout.
                break
            num_ready_reqs += 1

        if self.server_args.enable_dp_attention:
            tp_size = self.attn_tp_size
            tp_group = self.attn_tp_cpu_group
        else:
            tp_size = self.tp_size
            tp_group = self.tp_cpu_group

        if tp_size > 1:
            # Sync across TP ranks to make sure they have the same number of ready requests
            tensor = torch.tensor(num_ready_reqs, dtype=torch.int32)
            torch.distributed.all_reduce(
                tensor, op=torch.distributed.ReduceOp.MAX, group=tp_group
            )
            num_ready_reqs_max = tensor.item()
            for i in range(num_ready_reqs, num_ready_reqs_max):
                # This rank is behind: block until the extra grammars are ready.
                self.grammar_queue[i].grammar = self.grammar_queue[i].grammar.result()
            num_ready_reqs = num_ready_reqs_max

        self._extend_requests_to_queue(self.grammar_queue[:num_ready_reqs])
        self.grammar_queue = self.grammar_queue[num_ready_reqs:]

1573
1574
1575
1576
1577
1578
1579
1580
1581
1582
1583
1584
1585
1586
1587
1588
1589
1590
1591
1592
1593
1594
1595
1596
1597
1598
1599
1600
1601
1602
1603
    def watchdog_thread(self):
        """A watch dog thread that will try to kill the server itself if one forward batch takes too long.

        While a batch is active, the thread polls every ``watchdog_timeout / 2``
        seconds. If ``self.forward_ct`` has not advanced for more than
        ``self.watchdog_timeout`` seconds, it does the following:
          - logs batch and memory-pool info;
          - dumps the scheduler stacks;
          - flushes stdio;
          - waits 5 s;
          - sends SIGQUIT to the parent process.
        """
        self.watchdog_last_forward_ct = 0
        self.watchdog_last_time = time.time()
        # True division: floor division gives a 0-second sleep (a busy loop)
        # when a float timeout is below 2 seconds.
        poll_interval = self.watchdog_timeout / 2

        while True:
            current = time.time()
            # Snapshot once; the scheduler thread may reset cur_batch to None at
            # any time, and the diagnostics below must use the batch we saw.
            stuck_batch = self.cur_batch
            if stuck_batch is not None:
                if self.watchdog_last_forward_ct == self.forward_ct:
                    if current > self.watchdog_last_time + self.watchdog_timeout:
                        logger.error(f"Watchdog timeout ({self.watchdog_timeout=})")
                        break
                else:
                    self.watchdog_last_forward_ct = self.forward_ct
                    self.watchdog_last_time = current
            time.sleep(poll_interval)

        # Print batch size and memory pool info to check whether there are de-sync issues.
        logger.error(
            f"self.cur_batch.batch_size()={stuck_batch.batch_size()!r}, "
            f"self.cur_batch.reqs={stuck_batch.reqs!r}, "
            f"{self.token_to_kv_pool_allocator.available_size()=}, "
            f"{self.tree_cache.evictable_size()=}, "
        )
        # Wait for some time so that the parent process can print the error.
        pyspy_dump_schedulers()
        print(file=sys.stderr, flush=True)
        print(file=sys.stdout, flush=True)
        time.sleep(5)
        self.parent_process.send_signal(signal.SIGQUIT)

1604
1605
1606
    def flush_cache_wrapped(self, recv_req: FlushCacheReq):
        """Handle a FlushCacheReq by calling ``flush_cache``; the request's fields are not used."""
        del recv_req  # payload unused
        self.flush_cache()

1607
    def flush_cache(self):
        """Flush the memory pool and cache.

        The flush only happens when the waiting queue and the running batch
        are both empty. Otherwise nothing is reset and a warning is logged.

        On success it clears:
          - the current/last batch;
          - the tree cache;
          - the grammar backend;
          - the request-to-token and KV pools (plus the draft worker's pools
            when speculative decoding is on);
          - the generation/speculative counters.
        It then empties the CUDA cache.

        Returns:
            True if the cache was flushed, False if it was skipped.
        """
        if len(self.waiting_queue) > 0 or not self.running_batch.is_empty():
            # Use the module logger rather than the root ``logging`` module,
            # consistent with the rest of this file. Root-level logging.warning
            # also implicitly calls basicConfig() when the root has no handlers.
            logger.warning(
                "Cache not flushed because there are pending requests. "
                f"#queue-req: {len(self.waiting_queue)}, "
                f"#running-req: {len(self.running_batch.reqs)}"
            )
            return False

        self.cur_batch = None
        self.last_batch = None
        self.tree_cache.reset()
        if self.grammar_backend:
            self.grammar_backend.reset()
        self.req_to_token_pool.clear()
        self.token_to_kv_pool_allocator.clear()

        if not self.spec_algorithm.is_none():
            self.draft_worker.model_runner.req_to_token_pool.clear()
            self.draft_worker.model_runner.token_to_kv_pool_allocator.clear()

        # Reset generation and speculative-decoding statistics.
        self.num_generated_tokens = 0
        self.forward_ct_decode = 0
        self.spec_num_total_accepted_tokens = 0
        self.spec_num_total_forward_ct = 0
        self.cum_spec_accept_length = 0
        self.cum_spec_accept_count = 0
        torch.cuda.empty_cache()
        logger.info("Cache flushed successfully!")
        return True

1640
1641
1642
1643
1644
1645
1646
1647
1648
1649
1650
1651
1652
1653
1654
1655
1656
1657
1658
1659
1660
1661
1662
1663
1664
1665
1666
1667
1668
1669
1670
1671
1672
1673
1674
1675
1676
1677
1678
1679
1680
1681
1682
    def get_internal_state(self, recv_req: GetInternalStateReq):
        """Return a copy of ``global_server_args_dict`` plus runtime statistics.

        Always includes ``last_gen_throughput``. Adds ``avg_spec_accept_length``
        when speculative decoding has accumulated acceptance counts, and
        ``step_time_dict`` when RECORD_STEP_TIME is set.
        """
        state = {
            **global_server_args_dict,
            "last_gen_throughput": self.last_gen_throughput,
        }
        has_spec_stats = (
            not self.spec_algorithm.is_none() and self.cum_spec_accept_count > 0
        )
        if has_spec_stats:
            state["avg_spec_accept_length"] = (
                self.cum_spec_accept_length / self.cum_spec_accept_count
            )
        if RECORD_STEP_TIME:
            state["step_time_dict"] = self.step_time_dict
        return GetInternalStateReqOutput(internal_state=state)

    def set_internal_state(self, recv_req: SetInternalStateReq):
        """Update a whitelisted subset of ``global_server_args_dict`` at runtime.

        Only ``speculative_accept_threshold_single`` and
        ``speculative_accept_threshold_acc`` may be changed. If the request
        contains any other key, nothing is updated and the reply reports
        ``updated=False``.

        On success, the accumulated speculative acceptance length is logged
        (when available) and reset, then the new values are applied.
        """
        server_args_dict = recv_req.server_args
        args_allow_update = {
            "speculative_accept_threshold_single",
            "speculative_accept_threshold_acc",
        }
        if_success = True
        for k in server_args_dict:
            if k not in args_allow_update:
                # Module logger, consistent with the rest of the scheduler.
                logger.warning(f"Updating {k} is not supported.")
                if_success = False
                break
        if if_success:
            if not self.spec_algorithm.is_none() and self.cum_spec_accept_count > 0:
                avg_spec_accept_length = (
                    self.cum_spec_accept_length / self.cum_spec_accept_count
                )
                logger.info(f"{avg_spec_accept_length=}")
            self.cum_spec_accept_length = self.cum_spec_accept_count = 0
            for k, v in server_args_dict.items():
                global_server_args_dict[k] = v
            logger.info(f"Global server args updated! " f"{global_server_args_dict=}")
        # Report whether the update was actually applied; previously this was
        # always True, even when an unsupported key rejected the request.
        return SetInternalStateReqOutput(
            updated=if_success,
            server_args=global_server_args_dict,
        )

1683
1684
1685
1686
1687
1688
1689
1690
1691
1692
1693
1694
1695
1696
1697
1698
1699
1700
1701
1702
1703
1704
1705
1706
1707
1708
1709
1710
1711
1712
1713
1714
1715
1716
1717
1718
1719
1720
1721
1722
1723
    def handle_rpc_request(self, recv_req: RpcReqInput):
        """Call ``self.<recv_req.method>(recv_req.parameters)`` for an RPC request.

        An exception raised by the target is logged and reported in the reply
        instead of propagating. ``barrier()`` is called before replying on both
        paths.

        Returns:
            RpcReqOutput(success, message). On failure, message is the
            exception text; on success it is "".
        """
        logger.info(
            f"handle_rpc_request: {recv_req.method}, param: {recv_req.parameters}"
        )

        failure = None
        try:
            target = getattr(self, recv_req.method)
            target(recv_req.parameters)
        except Exception as e:
            failure = e
            logger.error(f"Failed to call rpc {recv_req.method}: {str(e)}")

        barrier()
        message = str(failure) if failure else ""
        return RpcReqOutput(failure is None, message)

    def save_remote_model(self, params):
        """Save the model to ``params["url"]`` through the worker's model runner.

        ``TpModelWorkerClient`` keeps the underlying worker in ``.worker``; any
        other tp_worker is used directly.
        """
        url = params["url"]
        worker = (
            self.tp_worker.worker
            if isinstance(self.tp_worker, TpModelWorkerClient)
            else self.tp_worker
        )
        worker.model_runner.save_remote_model(url)

    def save_sharded_model(self, params):
        """Save the model as shards using ``params``' ``path``, ``pattern`` and ``max_size``."""
        worker = self.tp_worker
        if isinstance(worker, TpModelWorkerClient):
            # TpModelWorkerClient keeps the underlying worker in ``.worker``.
            worker = worker.worker

        worker.model_runner.save_sharded_model(
            path=params["path"],
            pattern=params["pattern"],
            max_size=params["max_size"],
        )

1724
1725
    def abort_request(self, recv_req: AbortReq):
        """Abort every request whose rid starts with ``recv_req.rid``.

        Matching requests still in the waiting queue are removed from it.
        Matching, unfinished requests in the running batch are flagged with
        ``to_abort = True``. Both places are always scanned, and every match is
        handled, not only the first one.
        """
        # Delete requests in the waiting queue
        to_del = [
            i
            for i, req in enumerate(self.waiting_queue)
            if req.rid.startswith(recv_req.rid)
        ]

        # Pop from the back so that the remaining indices stay valid.
        for i in reversed(to_del):
            req = self.waiting_queue.pop(i)
            logger.debug(f"Abort queued request. {req.rid=}")

        # Flag requests in the running batch
        for req in self.running_batch.reqs:
            if req.rid.startswith(recv_req.rid) and not req.finished():
                logger.debug(f"Abort running request. {req.rid=}")
                req.to_abort = True
1744

1745
1746
1747
    def _pause_engine(self) -> Tuple[List[Req], int]:
        """Pause the engine; this scheduler does not implement it.

        Raises:
            NotImplementedError: always.
        """
        raise NotImplementedError()

Chayenne's avatar
Chayenne committed
1748
1749
1750
    def update_weights_from_disk(self, recv_req: UpdateWeightFromDiskReqInput):
        """In-place update of the weights from disk.

        On success the cache is flushed, and the flush is asserted to succeed.
        On failure the worker's message is logged.
        """
        ok, msg = self.tp_worker.update_weights_from_disk(recv_req)
        if not ok:
            logger.error(msg)
        else:
            flushed = self.flush_cache()
            assert flushed, "Cache flush failed after updating weights"
        return UpdateWeightFromDiskReqOutput(ok, msg, 0)
1757

1758
1759
1760
    def init_weights_update_group(self, recv_req: InitWeightsUpdateGroupReqInput):
        """Initialize the online model parameter update group via the TP worker."""
        ok, msg = self.tp_worker.init_weights_update_group(recv_req)
        return InitWeightsUpdateGroupReqOutput(ok, msg)
1762
1763

    def update_weights_from_distributed(
        self,
        recv_req: UpdateWeightsFromDistributedReqInput,
    ) -> "UpdateWeightsFromDistributedReqOutput":
        """Update the online model parameter.

        Delegates to the TP worker. On success the cache is flushed, and the
        flush is asserted to succeed. On failure the worker's message is logged.

        Returns:
            UpdateWeightsFromDistributedReqOutput carrying the worker's
            ``(success, message)``. The return annotation reflects this object
            rather than a bare ``(bool, str)`` tuple.
        """
        success, message = self.tp_worker.update_weights_from_distributed(recv_req)
        if success:
            flash_cache_success = self.flush_cache()
            assert flash_cache_success, "Cache flush failed after updating weights"
        else:
            logger.error(message)
        return UpdateWeightsFromDistributedReqOutput(success, message)
1775

1776
1777
1778
1779
1780
    def update_weights_from_tensor(self, recv_req: UpdateWeightsFromTensorReqInput):
        """Update the online model parameter from tensors.

        On success, the cache is flushed only if ``recv_req.flush_cache`` is
        set. On failure the worker's message is logged.
        """
        ok, msg = self.tp_worker.update_weights_from_tensor(recv_req)
        # TODO: share this success/flush/log handling with update_weights_from_distributed.
        if not ok:
            logger.error(msg)
        elif recv_req.flush_cache:
            flushed = self.flush_cache()
            assert flushed, "Cache flush failed after updating weights"
        return UpdateWeightsFromTensorReqOutput(ok, msg)
1787

1788
1789
    def get_weights_by_name(self, recv_req: GetWeightsByNameReqInput):
        """Fetch the requested parameter from the TP worker and wrap it in a reply."""
        return GetWeightsByNameReqOutput(self.tp_worker.get_weights_by_name(recv_req))
1791

1792
    def release_memory_occupation(self, recv_req: ReleaseMemoryOccupationReqInput):
        """Pause the memory saver adapter after stashing the model's buffers.

        The buffers are cloned into ``self.stashed_model_static_state`` so that
        ``resume_memory_occupation`` can copy them back. The cache is then
        flushed; its success flag is not checked.
        """
        model = self.tp_worker.worker.model_runner.model
        self.stashed_model_static_state = _export_static_state(model)
        self.memory_saver_adapter.pause()
        self.flush_cache()
        return ReleaseMemoryOccupationReqOutput()
1799

1800
    def resume_memory_occupation(self, recv_req: ResumeMemoryOccupationReqInput):
        """Resume the memory saver adapter and restore the stashed model buffers.

        Expects ``self.stashed_model_static_state`` from a prior
        ``release_memory_occupation`` call; the stash is deleted afterwards.
        """
        self.memory_saver_adapter.resume()
        model = self.tp_worker.worker.model_runner.model
        _import_static_state(model, self.stashed_model_static_state)
        del self.stashed_model_static_state
        return ResumeMemoryOccupationReqOutput()

    def profile(self, recv_req: ProfileReq):
        """Dispatch a profiling request: START_PROFILE starts, anything else stops."""
        if recv_req.type != ProfileReqType.START_PROFILE:
            return self.stop_profile()
        return self.start_profile(
            recv_req.output_dir, recv_req.num_steps, recv_req.activities
        )

    def start_profile(
        self,
        output_dir: Optional[str],
        num_steps: Optional[int],
        activities: Optional[List[str]],
    ) -> "Optional[ProfileReqOutput]":
        """Start torch-profiler and/or CUDA-memory profiling.

        Args:
            output_dir: Directory for traces. Defaults to
                $SGLANG_TORCH_PROFILER_DIR, or "/tmp" when that is unset.
            num_steps: If truthy, profiling is stopped by ``run_batch`` once
                ``forward_ct`` reaches ``forward_ct + num_steps``, and the reply
                is sent from ``stop_profile`` at that point.
            activities: Any of "CPU", "GPU", "MEM"; unknown names are ignored.
                Defaults to ["CPU", "GPU"].

        Returns:
            - A failure ProfileReqOutput if a session is already in progress.
            - A success ProfileReqOutput when ``num_steps`` is not set.
            - None when ``num_steps`` is set, because the reply is deferred.
        """
        if self.torch_profiler_activities:
            return ProfileReqOutput(
                success=False,
                message="Profiling is already in progress. Call /stop_profile first.",
            )

        if output_dir is None:
            output_dir = os.getenv("SGLANG_TORCH_PROFILER_DIR", "/tmp")
        if activities is None:
            activities = ["CPU", "GPU"]

        self.torch_profiler_output_dir = output_dir
        self.torch_profiler_activities = activities
        logger.info(
            "Profiling starts. Traces will be saved to: %s",
            self.torch_profiler_output_dir,
        )

        activity_map = {
            "CPU": torch.profiler.ProfilerActivity.CPU,
            "GPU": torch.profiler.ProfilerActivity.CUDA,
        }
        torchprof_activities = [
            activity_map[a] for a in activities if a in activity_map
        ]

        # The torch profiler is only created when CPU/GPU tracing was requested;
        # "MEM" alone records CUDA memory history without it.
        if torchprof_activities:
            self.torch_profiler = torch.profiler.profile(
                activities=torchprof_activities,
                with_stack=True,
            )
            self.torch_profiler.start()

        if "MEM" in activities:
            torch.cuda.memory._record_memory_history(max_entries=100000)

        if num_steps:
            self.profiler_target_forward_ct = self.forward_ct + num_steps
            # The caller will be notified when reaching profiler_target_forward_ct
            return None

        self.profiler_target_forward_ct = None
        return ProfileReqOutput(success=True, message="Succeeded")
1864
1865

    def stop_profile(self) -> None:
        """Stop the active profiling session and export its traces.

        Does nothing when no session is active. Otherwise it exports, into
        ``self.torch_profiler_output_dir`` (set by ``start_profile``):
          - the torch profiler trace, if one was started;
          - a CUDA memory snapshot, when "MEM" profiling was requested.
        It then clears the profiling state. If the session was started with a
        step limit, the deferred ProfileReqOutput is sent to the tokenizer here
        and the step target is cleared.
        """
        if self.torch_profiler_activities is None:
            return

        logger.info("Stop profiling...")
        if self.torch_profiler is not None:
            self.torch_profiler.stop()
            self.torch_profiler.export_chrome_trace(
                os.path.join(
                    self.torch_profiler_output_dir,
                    str(time.time()) + f"-TP-{self.tp_rank}" + ".trace.json.gz",
                )
            )

        if "MEM" in self.torch_profiler_activities:
            # start_profile records only torch_profiler_output_dir; the old
            # torch_profiler_trace_dir attribute was never assigned there.
            memory_profile_path = os.path.join(
                self.torch_profiler_output_dir,
                str(time.time()) + f"-TP-{self.tp_rank}-memory" + ".pickle",
            )
            torch.cuda.memory._dump_snapshot(memory_profile_path)
            torch.cuda.memory._record_memory_history(enabled=None)

        logger.info(
            "Profiling done. Traces are saved to: %s",
            self.torch_profiler_output_dir,
        )

        self.torch_profiler = None
        self.torch_profiler_output_dir = None
        self.torch_profiler_activities = None

        if self.profiler_target_forward_ct:
            self.send_to_tokenizer.send_pyobj(
                ProfileReqOutput(success=True, message="Succeeded.")
            )
            # Clear the step target so run_batch stops re-invoking stop_profile
            # on every subsequent forward pass.
            self.profiler_target_forward_ct = None
1899

1900
1901
1902
1903
1904
1905
1906
1907
1908
    def expert_distribution_handle(self, recv_req: ExpertDistributionReq):
        """Run the expert-distribution recorder command selected by ``recv_req``.

        Raises:
            ValueError: if ``recv_req`` is not START_RECORD, STOP_RECORD or DUMP_RECORD.
        """
        command_to_method = (
            (ExpertDistributionReq.START_RECORD, "start_record"),
            (ExpertDistributionReq.STOP_RECORD, "stop_record"),
            (ExpertDistributionReq.DUMP_RECORD, "dump_record"),
        )
        for command, method_name in command_to_method:
            if recv_req == command:
                getattr(expert_distribution_recorder, method_name)()
                return ExpertDistributionReqOutput()
        raise ValueError("Unrecognized ExpertDistributionReq value")
1910

1911
    def open_session(self, recv_req: OpenSessionReqInput):
        """Register a new Session under ``recv_req.session_id``.

        The reply reports failure, with a warning logged, when the id is
        already in use or is None.
        """
        session_id = recv_req.session_id
        if session_id in self.sessions:
            logger.warning(f"session id {session_id} already exist, cannot open.")
            return OpenSessionReqOutput(session_id, False)
        if session_id is None:
            logger.warning("session id is None, cannot open.")
            return OpenSessionReqOutput(session_id, False)

        self.sessions[session_id] = Session(recv_req.capacity_of_str_len, session_id)
        return OpenSessionReqOutput(session_id, True)
1925
1926
1927
1928
1929
1930
1931
1932
1933

    def close_session(self, recv_req: CloseSessionReqInput):
        """Remove session ``recv_req.session_id``; log a warning if it is unknown."""
        session_id = recv_req.session_id
        if session_id in self.sessions:
            del self.sessions[session_id]
            return
        logger.warning(f"session id {session_id} does not exist, cannot delete.")

1934

1935
1936
1937
1938
def is_health_check_generate_req(recv_req):
    """Return True if ``recv_req`` is a health-check generate request.

    Health checks are recognized by a string ``rid`` starting with
    "HEALTH_CHECK". An object whose ``rid`` is missing, None, or not a string
    (e.g. a list of rids) is not a health check. Previously the None/list
    cases raised AttributeError.
    """
    rid = getattr(recv_req, "rid", None)
    return isinstance(rid, str) and rid.startswith("HEALTH_CHECK")


1939
1940
1941
1942
1943
1944
1945
1946
1947
1948
1949
1950
1951
1952
def _export_static_state(model):
    return dict(
        buffers=[
            (name, buffer.detach().clone()) for name, buffer in model.named_buffers()
        ]
    )


def _import_static_state(model, static_params):
    self_named_buffers = dict(model.named_buffers())
    for name, tensor in static_params["buffers"]:
        self_named_buffers[name][...] = tensor


1953
1954
1955
1956
1957
def run_scheduler_process(
    server_args: ServerArgs,
    port_args: PortArgs,
    gpu_id: int,
    tp_rank: int,
1958
    dp_rank: Optional[int],
1959
    pipe_writer,
1960
):
1961
1962
1963
1964
1965
1966
1967

    # Generate the prefix
    if dp_rank is None:
        prefix = f" TP{tp_rank}"
    else:
        prefix = f" DP{dp_rank} TP{tp_rank}"

1968
    # Config the process
1969
    kill_itself_when_parent_died()
1970
    setproctitle.setproctitle(f"sglang::scheduler{prefix.replace(' ', '_')}")
1971
    faulthandler.enable()
1972
    parent_process = psutil.Process().parent()
1973

1974
1975
1976
    # [For Router] if env var "SGLANG_DP_RANK" exist, set dp_rank to the value of the env var
    if dp_rank is None and "SGLANG_DP_RANK" in os.environ:
        dp_rank = int(os.environ["SGLANG_DP_RANK"])
1977

Wang Ran (汪然)'s avatar
Wang Ran (汪然) committed
1978
    # Configure the logger
1979
    configure_logger(server_args, prefix=prefix)
1980
    suppress_other_loggers()
1981

1982
    # Set cpu affinity to this gpu process
1983
1984
1985
    if get_bool_env_var("SGLANG_SET_CPU_AFFINITY"):
        set_gpu_proc_affinity(server_args.tp_size, server_args.nnodes, gpu_id)

1986
    # Create a scheduler and run the event loop
1987
    try:
1988
        scheduler = Scheduler(server_args, port_args, gpu_id, tp_rank, dp_rank)
1989
        pipe_writer.send(
Mick's avatar
Mick committed
1990
1991
1992
1993
1994
            {
                "status": "ready",
                "max_total_num_tokens": scheduler.max_total_num_tokens,
                "max_req_input_len": scheduler.max_req_input_len,
            }
1995
        )
Byron Hsu's avatar
Byron Hsu committed
1996
1997
1998
1999
2000
2001
2002
2003
2004
2005
2006
2007
        disaggregation_mode: DisaggregationMode = scheduler.disaggregation_mode

        if disaggregation_mode == DisaggregationMode.NULL:
            if scheduler.enable_overlap:
                scheduler.event_loop_overlap()
            else:
                scheduler.event_loop_normal()
        elif disaggregation_mode == DisaggregationMode.PREFILL:
            scheduler.event_loop_normal_disagg_prefill()
        elif disaggregation_mode == DisaggregationMode.DECODE:
            scheduler.event_loop_normal_disagg_decode()

2008
    except Exception:
2009
2010
2011
        traceback = get_exception_traceback()
        logger.error(f"Scheduler hit an exception: {traceback}")
        parent_process.send_signal(signal.SIGQUIT)