scheduler.py 67.4 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
13
# Copyright 2023-2024 SGLang Team
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
14
15
"""A scheduler that manages a tensor parallel GPU worker."""

16
import faulthandler
17
import logging
18
import os
19
import signal
Lianmin Zheng's avatar
Lianmin Zheng committed
20
import threading
21
22
import time
import warnings
Lianmin Zheng's avatar
Lianmin Zheng committed
23
from collections import deque
Lianmin Zheng's avatar
Lianmin Zheng committed
24
from concurrent import futures
25
from dataclasses import dataclass
26
from http import HTTPStatus
27
from types import SimpleNamespace
28
from typing import Dict, List, Optional, Tuple, Union
29

30
import psutil
31
import setproctitle
32
import torch
33
34
import zmq

35
from sglang.global_config import global_config
Lianmin Zheng's avatar
Lianmin Zheng committed
36
from sglang.srt.configs.model_config import ModelConfig
37
from sglang.srt.constrained.base_grammar_backend import create_grammar_backend
38
from sglang.srt.hf_transformers_utils import get_processor, get_tokenizer
39
from sglang.srt.layers.dp_attention import compute_dp_attention_world_info
40
41
42
43
44
from sglang.srt.layers.logits_processor import LogitsProcessorOutput
from sglang.srt.managers.io_struct import (
    AbortReq,
    BatchEmbeddingOut,
    BatchTokenIDOut,
45
    CloseSessionReqInput,
46
    FlushCacheReq,
47
48
    GetWeightsByNameReqInput,
    GetWeightsByNameReqOutput,
49
50
    InitWeightsUpdateGroupReqInput,
    InitWeightsUpdateGroupReqOutput,
51
52
    OpenSessionReqInput,
    OpenSessionReqOutput,
53
    ProfileReq,
54
55
56
57
    ReleaseMemoryOccupationReqInput,
    ReleaseMemoryOccupationReqOutput,
    ResumeMemoryOccupationReqInput,
    ResumeMemoryOccupationReqOutput,
58
59
    TokenizedEmbeddingReqInput,
    TokenizedGenerateReqInput,
Chayenne's avatar
Chayenne committed
60
61
    UpdateWeightFromDiskReqInput,
    UpdateWeightFromDiskReqOutput,
62
63
    UpdateWeightsFromDistributedReqInput,
    UpdateWeightsFromDistributedReqOutput,
64
65
    UpdateWeightsFromTensorReqInput,
    UpdateWeightsFromTensorReqOutput,
66
67
68
69
70
71
72
)
from sglang.srt.managers.schedule_batch import (
    FINISH_ABORT,
    BaseFinishReason,
    ImageInputs,
    Req,
    ScheduleBatch,
73
    global_server_args_dict,
74
)
75
76
77
78
79
from sglang.srt.managers.schedule_policy import (
    AddReqResult,
    PrefillAdder,
    SchedulePolicy,
)
80
from sglang.srt.managers.session_controller import Session
81
from sglang.srt.managers.tp_worker import TpModelWorker
82
from sglang.srt.managers.tp_worker_overlap_thread import TpModelWorkerClient
83
from sglang.srt.managers.utils import validate_input_length
84
85
from sglang.srt.mem_cache.chunk_cache import ChunkCache
from sglang.srt.mem_cache.radix_cache import RadixCache
86
from sglang.srt.metrics.collector import SchedulerMetricsCollector, SchedulerStats
87
from sglang.srt.model_executor.forward_batch_info import ForwardMode
88
from sglang.srt.server_args import PortArgs, ServerArgs
89
from sglang.srt.speculative.spec_info import SpeculativeAlgorithm
90
from sglang.srt.torch_memory_saver_adapter import TorchMemorySaverAdapter
91
92
93
from sglang.srt.utils import (
    broadcast_pyobj,
    configure_logger,
94
    crash_on_warnings,
95
    get_bool_env_var,
96
    get_zmq_socket,
97
    set_gpu_proc_affinity,
98
99
100
    set_random_seed,
    suppress_other_loggers,
)
101
from sglang.utils import TypeBasedDispatcher, get_exception_traceback
102
103
104

logger = logging.getLogger(__name__)

# Test retract decode for debugging purposes
test_retract = get_bool_env_var("SGLANG_TEST_RETRACT")


@dataclass
class GenerationBatchResult:
    """Output of running one generation batch on the model worker."""

    # Logits-processor output of the forward pass.
    logits_output: LogitsProcessorOutput
    # Next token ids produced for the batch.
    next_token_ids: List[int]
    # presumably the id of the batch this result belongs to — TODO confirm
    bid: int


@dataclass
class EmbeddingBatchResult:
    """Output of running one embedding batch on the model worker."""

    # Embeddings computed for the batch.
    embeddings: torch.Tensor
    # presumably the id of the batch this result belongs to — TODO confirm
    bid: int


122
123
124
125
126
127
128
129
130
class Scheduler:
    """A scheduler that manages a tensor parallel GPU worker."""

    def __init__(
        self,
        server_args: ServerArgs,
        port_args: PortArgs,
        gpu_id: int,
        tp_rank: int,
        dp_rank: Optional[int],
    ):
        """Initialize the scheduler and the model worker(s) it drives.

        Args:
            server_args: Server configuration.
            port_args: IPC socket names and the NCCL port for the workers.
            gpu_id: GPU index passed to the model worker(s).
            tp_rank: Tensor-parallel rank of this scheduler.
            dp_rank: Data-parallel rank passed to the model worker(s).
        """
        # Parse args
        self.server_args = server_args
        self.tp_rank = tp_rank
        self.tp_size = server_args.tp_size
        self.schedule_policy = server_args.schedule_policy
        self.disable_jump_forward = server_args.disable_jump_forward
        self.lora_paths = server_args.lora_paths
        self.max_loras_per_batch = server_args.max_loras_per_batch
        self.enable_overlap = not server_args.disable_overlap_schedule
        self.skip_tokenizer_init = server_args.skip_tokenizer_init
        self.enable_metrics = server_args.enable_metrics
        self.spec_algorithm = SpeculativeAlgorithm.from_string(
            server_args.speculative_algorithm
        )
        # Number of draft tokens when speculative decoding is enabled, else 1.
        self.decode_mem_cache_buf_multiplier = (
            self.server_args.speculative_num_draft_tokens
            if not self.spec_algorithm.is_none()
            else 1
        )

        # Distributed rank info
        self.dp_size = server_args.dp_size
        self.attn_tp_rank, self.attn_tp_size, self.dp_rank = (
            compute_dp_attention_world_info(
                server_args.enable_dp_attention,
                self.tp_rank,
                self.tp_size,
                self.dp_size,
            )
        )

        # Init inter-process communication. Only attention-TP rank 0 owns real
        # sockets; the other ranks get no-op senders and no receiver.
        context = zmq.Context(2)
        if self.attn_tp_rank == 0:
            self.recv_from_tokenizer = get_zmq_socket(
                context, zmq.PULL, port_args.scheduler_input_ipc_name, False
            )
            self.send_to_tokenizer = get_zmq_socket(
                context, zmq.PUSH, port_args.tokenizer_ipc_name, False
            )

            if server_args.skip_tokenizer_init:
                # Directly send to the TokenizerManager
                self.send_to_detokenizer = get_zmq_socket(
                    context, zmq.PUSH, port_args.tokenizer_ipc_name, False
                )
            else:
                # Send to the DetokenizerManager
                self.send_to_detokenizer = get_zmq_socket(
                    context, zmq.PUSH, port_args.detokenizer_ipc_name, False
                )
        else:
            self.recv_from_tokenizer = None
            self.send_to_tokenizer = SimpleNamespace(send_pyobj=lambda x: None)
            self.send_to_detokenizer = SimpleNamespace(send_pyobj=lambda x: None)

        # Init tokenizer
        self.model_config = ModelConfig(
            server_args.model_path,
            trust_remote_code=server_args.trust_remote_code,
            revision=server_args.revision,
            context_length=server_args.context_length,
            model_override_args=server_args.json_model_override_args,
            is_embedding=server_args.is_embedding,
            dtype=server_args.dtype,
            quantization=server_args.quantization,
        )
        self.is_generation = self.model_config.is_generation

        if server_args.skip_tokenizer_init:
            self.tokenizer = self.processor = None
        else:
            if self.model_config.is_multimodal:
                self.processor = get_processor(
                    server_args.tokenizer_path,
                    tokenizer_mode=server_args.tokenizer_mode,
                    trust_remote_code=server_args.trust_remote_code,
                    revision=server_args.revision,
                )
                self.tokenizer = self.processor.tokenizer
            else:
                self.tokenizer = get_tokenizer(
                    server_args.tokenizer_path,
                    tokenizer_mode=server_args.tokenizer_mode,
                    trust_remote_code=server_args.trust_remote_code,
                    revision=server_args.revision,
                )

        # Check whether overlap can be enabled
        if not self.is_generation:
            self.enable_overlap = False
            logger.info("Overlap scheduler is disabled for embedding models.")

        if self.model_config.is_multimodal:
            self.enable_overlap = False
            logger.info("Overlap scheduler is disabled for multimodal models.")

        if self.enable_overlap:
            self.disable_jump_forward = True

        # Launch a tensor parallel worker
        if self.enable_overlap:
            TpWorkerClass = TpModelWorkerClient
        else:
            TpWorkerClass = TpModelWorker

        self.tp_worker = TpWorkerClass(
            server_args=server_args,
            gpu_id=gpu_id,
            tp_rank=tp_rank,
            dp_rank=dp_rank,
            nccl_port=port_args.nccl_port,
        )

        # Launch a worker for speculative decoding if needed
        if self.spec_algorithm.is_eagle():
            from sglang.srt.speculative.eagle_worker import EAGLEWorker

            self.draft_worker = EAGLEWorker(
                gpu_id=gpu_id,
                tp_rank=tp_rank,
                server_args=server_args,
                nccl_port=port_args.nccl_port,
                target_worker=self.tp_worker,
                dp_rank=dp_rank,
            )
        else:
            self.draft_worker = None

        # Get token and memory info from the model worker
        (
            self.max_total_num_tokens,
            self.max_prefill_tokens,
            self.max_running_requests,
            self.max_req_len,
            self.max_req_input_len,
            self.random_seed,
            self.device,
            worker_global_server_args_dict,
            _,
            _,
            _,
        ) = self.tp_worker.get_worker_info()
        self.tp_cpu_group = self.tp_worker.get_tp_cpu_group()
        self.attn_tp_cpu_group = self.tp_worker.get_attention_tp_cpu_group()
        self.pad_input_ids_func = self.tp_worker.get_pad_input_ids_func()
        global_server_args_dict.update(worker_global_server_args_dict)
        set_random_seed(self.random_seed)

        # Print debug info
        logger.info(
            f"max_total_num_tokens={self.max_total_num_tokens}, "
            f"max_prefill_tokens={self.max_prefill_tokens}, "
            f"max_running_requests={self.max_running_requests}, "
            f"context_len={self.model_config.context_len}"
        )

        # Init memory pool and cache
        self.req_to_token_pool, self.token_to_kv_pool = self.tp_worker.get_memory_pool()

        if (
            server_args.chunked_prefill_size is not None
            and server_args.disable_radix_cache
        ):
            self.tree_cache = ChunkCache(
                req_to_token_pool=self.req_to_token_pool,
                token_to_kv_pool=self.token_to_kv_pool,
            )
        else:
            self.tree_cache = RadixCache(
                req_to_token_pool=self.req_to_token_pool,
                token_to_kv_pool=self.token_to_kv_pool,
                disable=server_args.disable_radix_cache,
            )
        self.tree_cache_metrics = {"total": 0, "hit": 0}
        self.policy = SchedulePolicy(self.schedule_policy, self.tree_cache)

        # Init running status
        self.waiting_queue: List[Req] = []
        # The running decoding batch for continuous batching
        self.running_batch: Optional[ScheduleBatch] = None
        # The current forward batch (also read by the watchdog thread)
        self.cur_batch: Optional[ScheduleBatch] = None
        # The forward batch of the previous event-loop iteration
        self.last_batch: Optional[ScheduleBatch] = None
        self.forward_ct = 0
        self.forward_ct_decode = 0
        self.num_generated_tokens = 0
        self.spec_num_total_accepted_tokens = 0
        self.spec_num_total_forward_ct = 0
        self.last_decode_stats_tic = time.time()
        self.stream_interval = server_args.stream_interval
        self.current_stream = torch.get_device_module(self.device).current_stream()
        if self.device == "cpu":
            self.current_stream.synchronize = lambda: None  # No-op for CPU

        # Session info
        self.sessions: Dict[str, Session] = {}

        # Init chunked prefill
        self.chunked_prefill_size = server_args.chunked_prefill_size
        # A non-positive size (e.g. -1) disables chunked prefill. None is also
        # accepted (see the cache selection above) and must not reach the
        # `<= 0` comparison, which would raise TypeError.
        if self.chunked_prefill_size is not None and self.chunked_prefill_size <= 0:
            self.chunked_prefill_size = None
        self.being_chunked_req = None
        self.is_mixed_chunk = (
            self.chunked_prefill_size is not None and server_args.enable_mixed_chunk
        )

        # Init the grammar backend for constrained generation
        self.grammar_queue: List[Req] = []
        if not server_args.skip_tokenizer_init:
            self.grammar_backend = create_grammar_backend(
                server_args, self.tokenizer, self.model_config.vocab_size
            )
        else:
            self.grammar_backend = None

        # Init new token estimation
        assert (
            server_args.schedule_conservativeness >= 0
        ), "Invalid schedule_conservativeness"

        self.init_new_token_ratio = min(
            global_config.default_init_new_token_ratio
            * server_args.schedule_conservativeness,
            1.0,
        )
        self.min_new_token_ratio = min(
            self.init_new_token_ratio
            * global_config.default_min_new_token_ratio_factor,
            1.0,
        )
        self.new_token_ratio_decay = (
            self.init_new_token_ratio - self.min_new_token_ratio
        ) / global_config.default_new_token_ratio_decay_steps
        self.new_token_ratio = self.init_new_token_ratio

        # Tells whether the current running batch is full so that we can skip
        # the check of whether to prefill new requests.
        # This is an optimization to reduce the overhead of the prefill check.
        self.batch_is_full = False

        # Init watchdog thread. The watchdog signals self.parent_process on
        # timeout, so resolve it before the thread starts running.
        self.watchdog_timeout = server_args.watchdog_timeout
        self.parent_process = psutil.Process().parent()
        t = threading.Thread(target=self.watchdog_thread, daemon=True)
        t.start()

        self.memory_saver_adapter = TorchMemorySaverAdapter.create(
            enable=server_args.enable_memory_saver
        )

        # Init profiler (enabled when SGLANG_TORCH_PROFILER_DIR is non-empty)
        if os.getenv("SGLANG_TORCH_PROFILER_DIR", "") == "":
            self.profiler = None
        else:
            self.torch_profiler_trace_dir = os.getenv("SGLANG_TORCH_PROFILER_DIR")
            logger.info(
                "Profiling enabled. Traces will be saved to: %s",
                self.torch_profiler_trace_dir,
            )
            self.profiler = torch.profiler.profile(
                activities=[
                    torch.profiler.ProfilerActivity.CPU,
                    torch.profiler.ProfilerActivity.CUDA,
                ],
                with_stack=True,
            )

        # Init metrics stats
        self.stats = SchedulerStats()
        if self.enable_metrics:
            self.metrics_collector = SchedulerMetricsCollector(
                labels={
                    "model_name": self.server_args.served_model_name,
                    # TODO: Add lora name/path in the future,
                },
            )

        # Init request dispatcher: maps each incoming request type to its handler
        self._request_dispatcher = TypeBasedDispatcher(
            [
                (TokenizedGenerateReqInput, self.handle_generate_request),
                (TokenizedEmbeddingReqInput, self.handle_embedding_request),
                (FlushCacheReq, self.flush_cache_wrapped),
                (AbortReq, self.abort_request),
                (UpdateWeightFromDiskReqInput, self.update_weights_from_disk),
                (InitWeightsUpdateGroupReqInput, self.init_weights_update_group),
                (
                    UpdateWeightsFromDistributedReqInput,
                    self.update_weights_from_distributed,
                ),
                (UpdateWeightsFromTensorReqInput, self.update_weights_from_tensor),
                (GetWeightsByNameReqInput, self.get_weights_by_name),
                (ProfileReq, self.profile),
                (OpenSessionReqInput, self.open_session),
                (CloseSessionReqInput, self.close_session),
                (
                    ReleaseMemoryOccupationReqInput,
                    lambda _: self.release_memory_occupation(),
                ),
                (
                    ResumeMemoryOccupationReqInput,
                    lambda _: self.resume_memory_occupation(),
                ),
            ]
        )

Lianmin Zheng's avatar
Lianmin Zheng committed
440
    def watchdog_thread(self):
        """A watch dog thread that will try to kill the server itself if one batch takes too long.

        The thread polls every ``watchdog_timeout / 2`` seconds. While a batch
        is in flight (``self.cur_batch`` is not None) and ``self.forward_ct``
        has not changed for more than ``watchdog_timeout`` seconds, it logs an
        error, waits 5 seconds and sends SIGQUIT to the parent process.
        """
        self.watchdog_last_forward_ct = 0
        self.watchdog_last_time = time.time()
        # True division: with floor division a timeout below 2 seconds would
        # yield a zero-length sleep and turn this loop into a busy spin.
        poll_interval = self.watchdog_timeout / 2

        while True:
            current = time.time()
            if self.cur_batch is None:
                # Idle: nothing can time out, so keep the reference time fresh.
                # Otherwise the first check after a long idle period would see an
                # unchanged forward_ct with a stale timestamp and fire spuriously.
                self.watchdog_last_time = current
            elif self.watchdog_last_forward_ct != self.forward_ct:
                # Progress was made since the last check.
                self.watchdog_last_forward_ct = self.forward_ct
                self.watchdog_last_time = current
            elif current > self.watchdog_last_time + self.watchdog_timeout:
                logger.error(f"Watchdog timeout ({self.watchdog_timeout=})")
                break
            time.sleep(poll_interval)

        # Wait sometimes so that the parent process can print the error.
        time.sleep(5)
        self.parent_process.send_signal(signal.SIGQUIT)
Lianmin Zheng's avatar
Lianmin Zheng committed
459

460
    @torch.no_grad()
461
    def event_loop_normal(self):
462
        """A normal scheduler loop."""
463
        while True:
Lianmin Zheng's avatar
Lianmin Zheng committed
464
465
            recv_reqs = self.recv_requests()
            self.process_input_requests(recv_reqs)
466

467
            batch = self.get_next_batch_to_run()
Lianmin Zheng's avatar
Lianmin Zheng committed
468
            self.cur_batch = batch
469
470
471
472

            if batch:
                result = self.run_batch(batch)
                self.process_batch_result(batch, result)
Lianmin Zheng's avatar
Lianmin Zheng committed
473
            else:
474
                # When the server is idle, so self-check and re-init some states
Lianmin Zheng's avatar
Lianmin Zheng committed
475
                self.check_memory()
476
                self.new_token_ratio = self.init_new_token_ratio
477
478

            self.last_batch = batch
479

480
    @torch.no_grad()
    def event_loop_overlap(self):
        """A scheduler loop that overlaps the CPU processing and GPU computation.

        The result of the batch launched in one iteration is processed in the
        next iteration, after the following batch has been launched.
        ``result_queue`` holds the (batch copy, result) pairs still waiting to
        be processed.
        """
        result_queue = deque()

        while True:
            recv_reqs = self.recv_requests()
            self.process_input_requests(recv_reqs)

            batch = self.get_next_batch_to_run()
            # Also read by the watchdog thread.
            self.cur_batch = batch

            if batch:
                result = self.run_batch(batch)
                # Keep a copy of the batch; its result is processed next iteration.
                result_queue.append((batch.copy(), result))

                if self.last_batch is None:
                    # Create a dummy first batch to start the pipeline for overlap schedule.
                    # It is now used for triggering the sampling_info_done event.
                    tmp_batch = ScheduleBatch(
                        reqs=None,
                        forward_mode=ForwardMode.DUMMY_FIRST,
                        next_batch_sampling_info=self.tp_worker.cur_sampling_info,
                    )
                    self.process_batch_result(tmp_batch, None)

            if self.last_batch:
                # Process the results of the last batch
                tmp_batch, tmp_result = result_queue.popleft()
                tmp_batch.next_batch_sampling_info = (
                    self.tp_worker.cur_sampling_info if batch else None
                )
                self.process_batch_result(tmp_batch, tmp_result)
            elif batch is None:
                # The server is idle: do a self-check and re-init some states
                self.check_memory()
                self.new_token_ratio = self.init_new_token_ratio

            self.last_batch = batch

520
521
    def recv_requests(self) -> List[Req]:
        """Receive requests at attention-TP rank 0 and broadcast them to the other ranks.

        Attention-TP rank 0 drains the tokenizer socket without blocking.
        With DP attention enabled, generate/embedding ("work") requests are
        broadcast within the attention-TP group, and all other ("control")
        requests are broadcast over the full TP group. In that case the
        returned list holds the work requests first, then the control requests.
        """
        if self.attn_tp_rank == 0:
            recv_reqs = []
            while True:
                try:
                    recv_req = self.recv_from_tokenizer.recv_pyobj(zmq.NOBLOCK)
                except zmq.ZMQError:
                    # With NOBLOCK this is raised (e.g. EAGAIN) once nothing is pending.
                    break
                recv_reqs.append(recv_req)
        else:
            recv_reqs = None

        if self.server_args.enable_dp_attention:
            if self.attn_tp_rank == 0:
                # Single pass partition, preserving arrival order within each group.
                work_reqs = []
                control_reqs = []
                for req in recv_reqs:
                    if isinstance(
                        req, (TokenizedGenerateReqInput, TokenizedEmbeddingReqInput)
                    ):
                        work_reqs.append(req)
                    else:
                        control_reqs.append(req)
            else:
                work_reqs = None
                control_reqs = None

            if self.attn_tp_size != 1:
                attn_tp_rank_0 = self.dp_rank * self.attn_tp_size
                work_reqs = broadcast_pyobj(
                    work_reqs,
                    self.attn_tp_rank,
                    self.attn_tp_cpu_group,
                    src=attn_tp_rank_0,
                )
            if self.tp_size != 1:
                control_reqs = broadcast_pyobj(
                    control_reqs, self.tp_rank, self.tp_cpu_group
                )
            recv_reqs = work_reqs + control_reqs
        elif self.tp_size != 1:
            recv_reqs = broadcast_pyobj(recv_reqs, self.tp_rank, self.tp_cpu_group)

        return recv_reqs

Lianmin Zheng's avatar
Lianmin Zheng committed
571
    def process_input_requests(self, recv_reqs: List):
572
        for recv_req in recv_reqs:
573
            output = self._request_dispatcher(recv_req)
574
575
            if output is not None:
                self.send_to_tokenizer.send_pyobj(output)
576
577
578
579
580

    def handle_generate_request(
        self,
        recv_req: TokenizedGenerateReqInput,
    ):
        """Turn a tokenized generate request into a ``Req`` and enqueue it.

        Requests rejected here (unknown session id, a session-created request
        that is already aborted, a multimodal prompt that is too long after
        expansion, or an input-length validation error) are still appended to
        ``waiting_queue`` and the method returns early. Requests with a grammar
        constraint whose grammar is not cached yet go to ``grammar_queue``
        instead of ``waiting_queue``.
        """
        # Create a new request
        if (
            recv_req.session_params is None
            or recv_req.session_params.id is None
            or recv_req.session_params.id not in self.sessions
        ):
            if recv_req.input_embeds is not None:
                # Generate fake input_ids based on the length of input_embeds
                seq_length = len(recv_req.input_embeds)
                fake_input_ids = [1] * seq_length
                recv_req.input_ids = fake_input_ids

            # Handle custom logit processor passed to the request
            custom_logit_processor = recv_req.custom_logit_processor
            if (
                not self.server_args.enable_custom_logit_processor
                and custom_logit_processor is not None
            ):
                logger.warning(
                    "The SGLang server is not configured to enable custom logit processor. "
                    "The custom logit processor passed in will be ignored. "
                    "Please set --enable-custom-logit-processor to enable this feature."
                )
                custom_logit_processor = None

            req = Req(
                recv_req.rid,
                recv_req.input_text,
                recv_req.input_ids,
                recv_req.sampling_params,
                return_logprob=recv_req.return_logprob,
                top_logprobs_num=recv_req.top_logprobs_num,
                stream=recv_req.stream,
                lora_path=recv_req.lora_path,
                input_embeds=recv_req.input_embeds,
                custom_logit_processor=custom_logit_processor,
                eos_token_ids=self.model_config.hf_eos_token_id,
            )
            req.tokenizer = self.tokenizer

            if (
                recv_req.session_params is not None
                and recv_req.session_params.id is not None
            ):
                # A session id was given but no such session is open.
                req.finished_reason = FINISH_ABORT(
                    f"Invalid request: session id {recv_req.session_params.id} does not exist"
                )
                self.waiting_queue.append(req)
                return
        else:
            # Create a new request from a previous session
            session = self.sessions[recv_req.session_params.id]
            req = session.create_req(recv_req, self.tokenizer)
            if isinstance(req.finished_reason, FINISH_ABORT):
                self.waiting_queue.append(req)
                return

        # Handle image inputs
        if recv_req.image_inputs is not None:
            image_inputs = ImageInputs.from_dict(recv_req.image_inputs)
            # Expand a single image token into multiple dummy tokens for receiving image embeddings
            req.origin_input_ids = self.pad_input_ids_func(
                req.origin_input_ids, image_inputs
            )
            req.extend_image_inputs(image_inputs)

            if len(req.origin_input_ids) >= self.max_req_input_len:
                error_msg = (
                    "Multimodal prompt is too long after expanding multimodal tokens. "
                    f"After expanding {len(req.origin_input_ids_unpadded)=} => {len(req.origin_input_ids)} >= {self.max_req_input_len}."
                )
                logger.error(error_msg)
                # Replace the prompt with a single dummy token, drop the images
                # and allow no new tokens; the request is marked aborted below.
                req.origin_input_ids = [0]
                req.image_inputs = None
                req.sampling_params.max_new_tokens = 0
                req.finished_reason = FINISH_ABORT(
                    error_msg, HTTPStatus.BAD_REQUEST, "BadRequestError"
                )
                self.waiting_queue.append(req)
                return

        # Copy more attributes
        req.logprob_start_len = recv_req.logprob_start_len

        if req.logprob_start_len == -1:
            # By default, only return the logprobs for output tokens
            req.logprob_start_len = len(req.origin_input_ids) - 1

        # Validate prompts length
        error_msg = validate_input_length(
            req,
            self.max_req_input_len,
            self.server_args.allow_auto_truncate,
        )
        if error_msg:
            self.waiting_queue.append(req)
            return

        # Cap max_new_tokens so that prompt + output stays below max_req_len
        # (an unset value is treated as effectively unlimited).
        req.sampling_params.max_new_tokens = min(
            (
                req.sampling_params.max_new_tokens
                if req.sampling_params.max_new_tokens is not None
                else 1 << 30
            ),
            self.max_req_len - len(req.origin_input_ids) - 1,
        )

        # Init grammar cache for this request
        add_to_grammar_queue = False
        if (
            req.sampling_params.json_schema is not None
            or req.sampling_params.regex is not None
            or req.sampling_params.ebnf is not None
        ):
            assert self.grammar_backend is not None
            if req.sampling_params.json_schema is not None:
                key = ("json", req.sampling_params.json_schema)
            elif req.sampling_params.regex is not None:
                key = ("regex", req.sampling_params.regex)
            else:
                # The enclosing condition guarantees ebnf is set here.
                key = ("ebnf", req.sampling_params.ebnf)

            req.grammar = self.grammar_backend.get_cached_value(key)
            if not req.grammar:
                # Not cached: hold a future value and park the request in
                # grammar_queue instead of waiting_queue.
                req.grammar = self.grammar_backend.get_future_value(key)
                add_to_grammar_queue = True

        if add_to_grammar_queue:
            self.grammar_queue.append(req)
        else:
            self.waiting_queue.append(req)
714
715
716

    def handle_embedding_request(
        self,
        recv_req: TokenizedEmbeddingReqInput,
    ):
        """Wrap a tokenized embedding request in a ``Req`` and queue it.

        The prompt length is checked with ``validate_input_length`` (honouring
        ``server_args.allow_auto_truncate``). Its return value is not inspected
        here, so the request is appended to ``self.waiting_queue`` in every
        case.
        """
        embed_req = Req(
            recv_req.rid,
            recv_req.input_text,
            recv_req.input_ids,
            recv_req.sampling_params,
        )
        embed_req.tokenizer = self.tokenizer

        # Validate prompts length
        allow_truncate = self.server_args.allow_auto_truncate
        validate_input_length(embed_req, self.max_req_input_len, allow_truncate)

        self.waiting_queue.append(embed_req)

736
    def log_prefill_stats(self, adder, can_run_list, running_bs, has_being_chunked):
        """Log a one-line summary of a newly formed prefill batch.

        Also accumulates the cumulative radix-cache counters in
        ``self.tree_cache_metrics`` (stored in units of 1e9 tokens). When
        ``self.enable_metrics`` is set, it pushes the same figures to
        ``self.metrics_collector``.

        Args:
            adder: The ``PrefillAdder`` that built the batch; its
                ``log_input_tokens`` / ``log_hit_tokens`` are reported.
            can_run_list: Requests admitted into the new batch.
            running_bs: Number of requests in the running decode batch.
            has_being_chunked: Whether a chunked request is pending; added to
                the queue length (the caller in this file passes a bool).
        """
        scale = 10**9
        cache_metrics = self.tree_cache_metrics
        cache_metrics["total"] += (adder.log_input_tokens + adder.log_hit_tokens) / scale
        cache_metrics["hit"] += adder.log_hit_tokens / scale
        hit_rate = cache_metrics["hit"] / cache_metrics["total"]

        free_tokens = (
            self.token_to_kv_pool.available_size() + self.tree_cache.evictable_size()
        )
        num_used = self.max_total_num_tokens - free_tokens
        num_queued = len(self.waiting_queue) + has_being_chunked

        logger.info(
            f"Prefill batch. "
            f"#new-seq: {len(can_run_list)}, "
            f"#new-token: {adder.log_input_tokens}, "
            f"#cached-token: {adder.log_hit_tokens}, "
            f"cache hit rate: {100.0 * hit_rate:.2f}%, "
            f"token usage: {num_used / self.max_total_num_tokens:.2f}, "
            f"#running-req: {running_bs}, "
            f"#queue-req: {num_queued}"
        )

        if not self.enable_metrics:
            return

        stats = self.stats
        stats.num_running_reqs = running_bs
        stats.num_used_tokens = num_used
        stats.token_usage = round(num_used / self.max_total_num_tokens, 2)
        stats.num_queue_reqs = num_queued
        stats.cache_hit_rate = hit_rate
        self.metrics_collector.log_stats(stats)

    def log_decode_stats(self):
        """Log decode throughput and KV usage, then reset the interval counters.

        ``gen throughput`` is ``self.num_generated_tokens`` divided by the
        wall-clock time since ``self.last_decode_stats_tic``. Both are reset
        afterwards. With a speculative algorithm active, the average accept
        length over the interval is reported too and its counters are reset.
        When ``self.enable_metrics`` is set, the same numbers are pushed to
        ``self.metrics_collector``.
        """
        kv_free = (
            self.token_to_kv_pool.available_size() + self.tree_cache.evictable_size()
        )
        num_used = self.max_total_num_tokens - kv_free

        elapsed = time.time() - self.last_decode_stats_tic
        gen_throughput = self.num_generated_tokens / elapsed
        self.num_generated_tokens = 0
        self.last_decode_stats_tic = time.time()

        num_running_reqs = len(self.running_batch.reqs) if self.running_batch else 0
        num_queue_reqs = len(self.waiting_queue)
        token_usage = num_used / self.max_total_num_tokens

        fields = [
            f"#running-req: {num_running_reqs}",
            f"#token: {num_used}",
            f"token usage: {token_usage:.2f}",
        ]
        if self.spec_algorithm.is_none():
            spec_accept_length = 0
        else:
            spec_accept_length = (
                self.spec_num_total_accepted_tokens / self.spec_num_total_forward_ct
            )
            self.spec_num_total_accepted_tokens = self.spec_num_total_forward_ct = 0
            fields.append(f"accept len: {spec_accept_length:.2f}")
        fields.append(f"gen throughput (token/s): {gen_throughput:.2f}")
        fields.append(f"#queue-req: {num_queue_reqs}")

        logger.info("Decode batch. " + ", ".join(fields))

        if self.enable_metrics:
            stats = self.stats
            stats.num_running_reqs = num_running_reqs
            stats.num_used_tokens = num_used
            stats.token_usage = token_usage
            stats.gen_throughput = gen_throughput
            stats.num_queue_reqs = num_queue_reqs
            stats.spec_accept_length = spec_accept_length
            self.metrics_collector.log_stats(stats)

Lianmin Zheng's avatar
Lianmin Zheng committed
814
815
816
817
818
    def check_memory(self):
        """Check that no KV-cache token slots or request slots have leaked.

        The check expects every token slot to be either free in
        ``token_to_kv_pool`` or evictable from ``tree_cache``. It also expects
        every ``req_to_token_pool`` slot to be free. It is presumably meant to
        run only when no requests are in flight — confirm at the call site.
        On a mismatch a warning is emitted, and ``ValueError`` is raised when
        ``crash_on_warnings()`` is true.

        Raises:
            ValueError: on a detected leak when ``crash_on_warnings()`` is true.
        """
        available_size = (
            self.token_to_kv_pool.available_size() + self.tree_cache.evictable_size()
        )
        if available_size != self.max_total_num_tokens:
            # Adjacent literals concatenate, so the headline needs its own
            # trailing space to stay readable.
            msg = (
                "KV cache pool leak detected! "
                f"{available_size=}, {self.max_total_num_tokens=}\n"
            )
            warnings.warn(msg)
            if crash_on_warnings():
                raise ValueError(msg)

        num_free_slots = len(self.req_to_token_pool.free_slots)
        if num_free_slots != self.req_to_token_pool.size:
            msg = (
                "Memory pool leak detected! "
                f"available_size={num_free_slots}, "
                f"total_size={self.req_to_token_pool.size}\n"
            )
            warnings.warn(msg)
            if crash_on_warnings():
                raise ValueError(msg)
Lianmin Zheng's avatar
Lianmin Zheng committed
836

837
    def get_next_batch_to_run(self) -> Optional[ScheduleBatch]:
        """Choose the batch the scheduler executes next.

        If the previous batch was an extend (prefill) batch, it is first folded
        into ``self.running_batch``. Any request still being chunked is
        detached from it before the merge. A new prefill batch is preferred
        when one can be formed. Otherwise the running decode batch is refreshed
        via ``update_running_batch`` and returned, or ``None`` when there is
        nothing to run. With DP attention enabled, the choice is passed through
        ``prepare_dp_attn_batch``.
        """
        prev = self.last_batch
        if prev and prev.forward_mode.is_extend():
            chunked = self.being_chunked_req
            if chunked:
                # Detach the chunked request. It keeps its rid but releases
                # its req_pool_idx and will get a new one next round.
                prev.filter_batch(being_chunked_req=chunked)
                self.tree_cache.cache_unfinished_req(chunked)
                self.req_to_token_pool.free(chunked.req_pool_idx)
                self.batch_is_full = False

            if not prev.is_empty():
                if self.running_batch is not None:
                    self.running_batch.merge_batch(prev)
                else:
                    self.running_batch = prev

        # Run prefill first if possible; otherwise fall back to decode.
        chosen = self.get_new_batch_prefill()
        if chosen is None and self.running_batch is not None:
            self.running_batch = self.update_running_batch(self.running_batch)
            chosen = self.running_batch

        # Handle DP attention
        if self.server_args.enable_dp_attention:
            chosen = self.prepare_dp_attn_batch(chosen)
        return chosen
871

Lianmin Zheng's avatar
Lianmin Zheng committed
872
    def get_new_batch_prefill(self) -> Optional[ScheduleBatch]:
        """Try to build a new prefill (extend) batch from the waiting queue.

        Requests are taken from ``self.waiting_queue`` after
        ``self.policy.calc_priority`` has been applied to it. Admission stops
        at the first of: the LoRA-per-batch limit, the running-request limit,
        a non-CONTINUE result from the ``PrefillAdder``, or
        ``prefill_only_one_req``. Admitted requests are removed from the
        waiting queue. With mixed chunked prefill (and no logprob requests),
        the current decode batch is merged into the new batch and
        ``self.running_batch`` is reset to ``None``.

        Returns:
            The prepared ``ScheduleBatch``, or ``None`` when no request could
            be admitted. ``self.batch_is_full`` may be updated either way.
        """
        # Check if the grammar is ready in the grammar queue
        if self.grammar_queue:
            self.move_ready_grammar_requests()

        # Handle the cases where prefill is not allowed
        if (
            self.batch_is_full or not self.waiting_queue
        ) and self.being_chunked_req is None:
            return None

        running_bs = len(self.running_batch.reqs) if self.running_batch else 0
        if running_bs >= self.max_running_requests:
            self.batch_is_full = True
            return None

        # Get priority queue
        prefix_computed = self.policy.calc_priority(self.waiting_queue)

        # Prefill policy
        adder = PrefillAdder(
            self.tree_cache,
            self.token_to_kv_pool,
            self.running_batch,
            self.new_token_ratio,
            self.max_prefill_tokens,
            self.chunked_prefill_size,
            running_bs if self.is_mixed_chunk else 0,
        )

        has_being_chunked = self.being_chunked_req is not None
        if has_being_chunked:
            self.being_chunked_req.init_next_round_input()
            self.being_chunked_req = adder.add_being_chunked_req(self.being_chunked_req)

        # LoRA adapters already used by the running batch. Only consulted when
        # LoRA is enabled, but always bound so the loop below cannot hit an
        # unbound local.
        lora_set = set()
        if self.lora_paths and self.running_batch is not None:
            lora_set = {r.lora_path for r in self.running_batch.reqs}

        # Get requests from the waiting queue to a new prefill batch
        for req in self.waiting_queue:
            if self.lora_paths:
                batch_loras = lora_set | {r.lora_path for r in adder.can_run_list}
                batch_loras.add(req.lora_path)
                if len(batch_loras) > self.max_loras_per_batch:
                    self.batch_is_full = True
                    break

            if running_bs + len(adder.can_run_list) >= self.max_running_requests:
                self.batch_is_full = True
                break

            req.init_next_round_input(None if prefix_computed else self.tree_cache)
            res = adder.add_one_req(req)
            if res != AddReqResult.CONTINUE:
                if res == AddReqResult.NO_TOKEN:
                    self.batch_is_full = True
                break
            if self.server_args.prefill_only_one_req:
                break

        # Update waiting queue
        can_run_list = adder.can_run_list
        if not can_run_list:
            return None
        # Build the membership set once so the filter stays linear in the
        # queue length instead of rebuilding the set for every element.
        admitted = set(can_run_list)
        self.waiting_queue = [x for x in self.waiting_queue if x not in admitted]

        if adder.new_being_chunked_req is not None:
            # At most one request may be mid-chunk at a time.
            assert self.being_chunked_req is None
            self.being_chunked_req = adder.new_being_chunked_req

        if self.being_chunked_req:
            self.being_chunked_req.is_being_chunked += 1

        # Print stats
        if self.attn_tp_rank == 0:
            self.log_prefill_stats(adder, can_run_list, running_bs, has_being_chunked)

        # Create a new batch
        new_batch = ScheduleBatch.init_new(
            can_run_list,
            self.req_to_token_pool,
            self.token_to_kv_pool,
            self.tree_cache,
            self.model_config,
            self.enable_overlap,
            self.spec_algorithm,
        )
        new_batch.prepare_for_extend()

        # Mixed-style chunked prefill
        if (
            self.is_mixed_chunk
            and self.running_batch is not None
            and not (new_batch.return_logprob or self.running_batch.return_logprob)
        ):
            # TODO (lianmin): support return_logprob + mixed chunked prefill
            self.running_batch.filter_batch()
            if not self.running_batch.is_empty():
                self.running_batch.prepare_for_decode()
                new_batch.mix_with_running(self.running_batch)
                new_batch.decoding_reqs = self.running_batch.reqs
            self.running_batch = None
        else:
            new_batch.decoding_reqs = None

        return new_batch

Lianmin Zheng's avatar
Lianmin Zheng committed
990
    def update_running_batch(self, batch: ScheduleBatch) -> Optional[ScheduleBatch]:
        """Update the current running decoding batch.

        Steps, in order:
        1. ``filter_batch``.
        2. If ``check_decode_mem`` fails (or the module-level ``test_retract``
           flag is set and the batch has more than 10 requests), retract
           requests back to the waiting queue and adopt the ratio returned by
           ``retract_decode``. Otherwise, decay ``new_token_ratio`` by
           ``new_token_ratio_decay``, never going below ``min_new_token_ratio``.
        3. Unless jump-forward is disabled, move jump-forward requests to the
           waiting queue.
        4. ``prepare_for_decode``.

        Returns ``None`` (clearing ``self.batch_is_full``) when the batch ends
        up empty. ``self.batch_is_full`` is also cleared whenever the batch
        shrank.
        """
        global test_retract

        size_at_entry = batch.batch_size()

        batch.filter_batch()
        if batch.is_empty():
            self.batch_is_full = False
            return None

        # Check if decode out of memory
        has_decode_mem = batch.check_decode_mem(self.decode_mem_cache_buf_multiplier)
        if has_decode_mem and not (test_retract and batch.batch_size() > 10):
            self.new_token_ratio = max(
                self.new_token_ratio - self.new_token_ratio_decay,
                self.min_new_token_ratio,
            )
        else:
            previous_ratio = self.new_token_ratio
            retracted_reqs, self.new_token_ratio = batch.retract_decode()
            if self.draft_worker:
                self.draft_worker.finish_request(retracted_reqs)

            logger.info(
                "Decode out of memory happened. "
                f"#retracted_reqs: {len(retracted_reqs)}, "
                f"#new_token_ratio: {previous_ratio:.4f} -> {self.new_token_ratio:.4f}"
            )
            self.waiting_queue.extend(retracted_reqs)

        # Check for jump-forward
        if not self.disable_jump_forward:
            self.waiting_queue.extend(
                batch.check_for_jump_forward(self.pad_input_ids_func)
            )
            if batch.is_empty():
                self.batch_is_full = False
                return None

        if batch.batch_size() < size_at_entry:
            self.batch_is_full = False

        # Update batch tensors
        batch.prepare_for_decode()
        return batch
Lianmin Zheng's avatar
Lianmin Zheng committed
1038

1039
1040
1041
    def run_batch(
        self, batch: ScheduleBatch
    ) -> Union[GenerationBatchResult, EmbeddingBatchResult]:
        """Run a batch.

        Generation models run through ``self.tp_worker``. When a speculative
        algorithm is active they run through ``self.draft_worker`` instead,
        which also updates the speculative accept counters. The sampled ids are
        stored on ``batch.output_ids``. Embedding / reward models run through
        ``self.tp_worker.forward_batch_embedding``.

        Raises:
            AssertionError: if a non-decode/idle batch carries no extend
                tokens. Raised explicitly so the check survives ``python -O``.
        """
        self.forward_ct += 1

        if self.is_generation:
            if not (
                batch.forward_mode.is_decode_or_idle() or batch.extend_num_tokens != 0
            ):
                # ``assert False`` would be stripped under ``-O`` and leave
                # ``next_token_ids`` unbound below.
                raise AssertionError(
                    "batch.extend_num_tokens == 0, this is unexpected!"
                )

            if self.spec_algorithm.is_none():
                model_worker_batch = batch.get_model_worker_batch()
                logits_output, next_token_ids = (
                    self.tp_worker.forward_batch_generation(model_worker_batch)
                )
            else:
                (
                    logits_output,
                    next_token_ids,
                    model_worker_batch,
                    num_accepted_tokens,
                ) = self.draft_worker.forward_batch_speculative_generation(batch)
                self.spec_num_total_accepted_tokens += (
                    num_accepted_tokens + batch.batch_size()
                )
                self.spec_num_total_forward_ct += batch.batch_size()
                self.num_generated_tokens += num_accepted_tokens
            batch.output_ids = next_token_ids

            ret = GenerationBatchResult(
                logits_output=logits_output,
                next_token_ids=next_token_ids,
                bid=model_worker_batch.bid,
            )
        else:  # embedding or reward model
            if batch.extend_num_tokens == 0:
                raise AssertionError(
                    "batch.extend_num_tokens == 0, this is unexpected!"
                )
            model_worker_batch = batch.get_model_worker_batch()
            embeddings = self.tp_worker.forward_batch_embedding(model_worker_batch)
            ret = EmbeddingBatchResult(
                embeddings=embeddings, bid=model_worker_batch.bid
            )
        return ret
Chayenne's avatar
Chayenne committed
1081

1082
1083
1084
1085
1086
    def process_batch_result(
        self,
        batch: ScheduleBatch,
        result: Union[GenerationBatchResult, EmbeddingBatchResult],
    ):
        """Hand the output of ``run_batch`` to the handler for its forward mode.

        * decode: ``process_batch_result_decode``; ``self.running_batch`` is
          dropped if the batch became empty.
        * extend: ``process_batch_result_prefill``.
        * idle: with ``enable_overlap``, resolve the pending result by ``bid``.
        * dummy-first: update the next batch's regex vocab mask, synchronize
          ``self.current_stream`` and set ``sampling_info_done``.
        """
        mode = batch.forward_mode

        if mode.is_decode():
            self.process_batch_result_decode(batch, result)
            if batch.is_empty():
                self.running_batch = None
            return

        if mode.is_extend():
            self.process_batch_result_prefill(batch, result)
            return

        if mode.is_idle():
            if self.enable_overlap:
                self.tp_worker.resolve_batch_result(result.bid)
            return

        if mode.is_dummy_first():
            next_info = batch.next_batch_sampling_info
            next_info.update_regex_vocab_mask()
            self.current_stream.synchronize()
            next_info.sampling_info_done.set()
Lianmin Zheng's avatar
Lianmin Zheng committed
1100

1101
1102
1103
1104
1105
    def process_batch_result_prefill(
        self,
        batch: ScheduleBatch,
        result: Union[GenerationBatchResult, EmbeddingBatchResult],
    ):
        """Consume the result of a prefill (extend) pass and stream the batch.

        Generation and embedding/reward models are handled by separate helpers.
        The request still being chunked (generation only) is excluded from
        streaming because its prefill is not finished.
        """
        if self.is_generation:
            skip_stream_req = self._consume_prefill_generation(batch, result)
        else:  # embedding or reward model
            self._consume_prefill_embedding(batch, result)
            skip_stream_req = None
        self.stream_output(batch.reqs, batch.return_logprob, skip_stream_req)

    def _consume_prefill_generation(self, batch, result):
        """Apply the sampled tokens of a generation prefill batch.

        Returns the request that is still being chunked (at most one), or
        ``None``.
        """
        chunked_req = None
        logits_output = result.logits_output
        next_token_ids = result.next_token_ids

        if self.enable_overlap:
            logits_output, next_token_ids = self.tp_worker.resolve_batch_result(
                result.bid
            )
        else:
            # Move next_token_ids and logprobs to cpu
            next_token_ids = next_token_ids.tolist()
            if batch.return_logprob:
                logits_output.next_token_logprobs = (
                    logits_output.next_token_logprobs.tolist()
                )
                logits_output.input_token_logprobs = (
                    logits_output.input_token_logprobs.tolist()
                )

        # Check finish conditions
        logprob_pt = 0
        for i, (req, next_token_id) in enumerate(zip(batch.reqs, next_token_ids)):
            if req.is_retracted:
                continue

            if self.is_mixed_chunk and self.enable_overlap and req.finished():
                # Free the one delayed token for the mixed decode batch; its
                # slot is among the last len(batch.reqs) entries of out_cache_loc.
                j = len(batch.out_cache_loc) - len(batch.reqs) + i
                self.token_to_kv_pool.free(batch.out_cache_loc[j : j + 1])
                continue

            if req.is_being_chunked > 0:
                # Prefill of this request is not finished. At most one request
                # is chunked at a time, and it must not be streamed yet.
                req.is_being_chunked -= 1
                chunked_req = req
                continue

            req.output_ids.append(next_token_id)
            req.check_finished()

            if req.finished():
                self.tree_cache.cache_finished_req(req)
            elif not batch.decoding_reqs or req not in batch.decoding_reqs:
                self.tree_cache.cache_unfinished_req(req)

            if req.return_logprob:
                logprob_pt += self.add_logprob_return_values(
                    i, req, logprob_pt, next_token_ids, logits_output
                )

            if req.grammar is not None:
                req.grammar.accept_token(next_token_id)
                req.grammar.finished = req.finished()

        sampling_info = batch.next_batch_sampling_info
        if sampling_info:
            sampling_info.update_regex_vocab_mask()
            self.current_stream.synchronize()
            sampling_info.sampling_info_done.set()

        return chunked_req

    def _consume_prefill_embedding(self, batch, result):
        """Attach the embeddings of an embedding/reward prefill batch to its requests."""
        embeddings = result.embeddings.tolist()

        # Check finish conditions
        for i, req in enumerate(batch.reqs):
            if req.is_retracted:
                continue

            req.embedding = embeddings[i]
            if req.is_being_chunked > 0:
                # Prefill of this request is not finished yet.
                req.is_being_chunked -= 1
                continue

            # Dummy output token for embedding models
            req.output_ids.append(0)
            req.check_finished()
            if req.finished():
                self.tree_cache.cache_finished_req(req)
            else:
                self.tree_cache.cache_unfinished_req(req)
1198

1199
1200
1201
1202
1203
1204
1205
1206
1207
1208
    def process_batch_result_decode(
        self,
        batch: ScheduleBatch,
        result: GenerationBatchResult,
    ):
        """Apply the tokens sampled by one decode step to the requests in ``batch``.

        For each non-retracted request:
        * append the token, unless a speculative algorithm owns ``output_ids``;
        * check finish conditions and cache finished requests;
        * record logprobs;
        * advance the grammar.

        With overlap enabled, already-finished requests instead free their one
        delayed KV slot. KV frees are bracketed by ``free_group_begin`` /
        ``free_group_end``. The batch is then streamed. Every
        ``server_args.decode_log_interval`` decode steps, decode stats are
        logged when ``self.attn_tp_rank == 0``.
        """
        logits_output = result.logits_output
        next_token_ids = result.next_token_ids
        bid = result.bid

        self.num_generated_tokens += len(batch.reqs)

        if self.enable_overlap:
            logits_output, next_token_ids = self.tp_worker.resolve_batch_result(bid)
            next_token_logprobs = logits_output.next_token_logprobs
        else:
            next_token_ids = next_token_ids.tolist()
            if batch.return_logprob:
                next_token_logprobs = logits_output.next_token_logprobs.tolist()

        kv_pool = self.token_to_kv_pool
        kv_pool.free_group_begin()

        # Check finish condition
        for i, (req, token_id) in enumerate(zip(batch.reqs, next_token_ids)):
            if req.is_retracted:
                continue

            if self.enable_overlap and req.finished():
                # Free the one delayed token
                kv_pool.free(batch.out_cache_loc[i : i + 1])
                continue

            if batch.spec_algorithm.is_none():
                # speculative worker will solve the output_ids in speculative decoding
                req.output_ids.append(token_id)

            req.check_finished()
            if req.finished():
                self.tree_cache.cache_finished_req(req)

            if req.return_logprob:
                req.output_token_logprobs_val.append(next_token_logprobs[i])
                req.output_token_logprobs_idx.append(token_id)
                if req.top_logprobs_num > 0:
                    req.output_top_logprobs_val.append(
                        logits_output.next_token_top_logprobs_val[i]
                    )
                    req.output_top_logprobs_idx.append(
                        logits_output.next_token_top_logprobs_idx[i]
                    )

            grammar = req.grammar
            if grammar is not None:
                grammar.accept_token(token_id)
                grammar.finished = req.finished()

        sampling_info = batch.next_batch_sampling_info
        if sampling_info:
            sampling_info.update_regex_vocab_mask()
            self.current_stream.synchronize()
            sampling_info.sampling_info_done.set()

        self.stream_output(batch.reqs, batch.return_logprob)

        kv_pool.free_group_end()

        # Wrap the step counter so it stays bounded.
        self.forward_ct_decode = (self.forward_ct_decode + 1) % (1 << 30)
        should_log = (
            self.attn_tp_rank == 0
            and self.forward_ct_decode % self.server_args.decode_log_interval == 0
        )
        if should_log:
            self.log_decode_stats()
1270

1271
1272
1273
1274
1275
1276
1277
1278
1279
    def add_logprob_return_values(
        self,
        i: int,
        req: Req,
        pt: int,
        next_token_ids: List[int],
        output: LogitsProcessorOutput,
    ):
        """Attach logprobs to the return values.

        Args:
            i: Index of ``req`` in the batch; used to index the per-request
                fields of ``output`` and ``next_token_ids``.
            req: Request whose logprob lists are updated in place.
            pt: Offset of this request's first entry in the flattened
                ``output.input_token_logprobs``.
            next_token_ids: Sampled next-token ids for the whole batch.
            output: Logits-processor output of the prefill pass.

        Returns:
            The number of input-logprob entries this request occupies in
            ``output.input_token_logprobs``, so the caller can advance ``pt``.
        """
        req.output_token_logprobs_val.append(output.next_token_logprobs[i])
        req.output_token_logprobs_idx.append(next_token_ids[i])

        # If logprob_start_len > 0, then first logprob_start_len prompt tokens will be ignored.
        num_input_logprobs = req.extend_input_len - req.extend_logprob_start_len

        if req.input_token_logprobs_val is None:
            input_token_logprobs_val = output.input_token_logprobs[
                pt : pt + num_input_logprobs - 1 - req.last_update_decode_tokens
            ]

            input_token_logprobs_idx = req.fill_ids[
                len(req.fill_ids)
                - num_input_logprobs
                + 1 : len(req.fill_ids)
                - req.last_update_decode_tokens
            ]
            # Clip the padded hash values from image tokens; otherwise they
            # lead to detokenization errors. Ids in [0, vocab_size) are kept,
            # including the last valid id vocab_size - 1.
            vocab_size = self.model_config.vocab_size
            input_token_logprobs_idx = [
                x if x < vocab_size else 0 for x in input_token_logprobs_idx
            ]

            if req.logprob_start_len == 0:
                # The first token does not have logprob, pad it.
                input_token_logprobs_val = [None] + input_token_logprobs_val
                input_token_logprobs_idx = [req.fill_ids[0]] + input_token_logprobs_idx

            req.input_token_logprobs_val = input_token_logprobs_val
            req.input_token_logprobs_idx = input_token_logprobs_idx

        if req.last_update_decode_tokens != 0:
            # Some decode tokens are re-computed in an extend batch
            req.output_token_logprobs_val.extend(
                output.input_token_logprobs[
                    pt
                    + num_input_logprobs
                    - 1
                    - req.last_update_decode_tokens : pt
                    + num_input_logprobs
                    - 1
                ],
            )
            req.output_token_logprobs_idx.extend(
                req.fill_ids[
                    len(req.fill_ids)
                    - req.last_update_decode_tokens : len(req.fill_ids)
                ]
            )

        if req.top_logprobs_num > 0:
            if req.input_top_logprobs_val is None:
                req.input_top_logprobs_val = output.input_top_logprobs_val[i]
                req.input_top_logprobs_idx = output.input_top_logprobs_idx[i]
                if req.logprob_start_len == 0:
                    req.input_top_logprobs_val = [None] + req.input_top_logprobs_val
                    req.input_top_logprobs_idx = [None] + req.input_top_logprobs_idx

            # Guarded: a slice ``[-0:]`` would take the whole list.
            if req.last_update_decode_tokens != 0:
                req.output_top_logprobs_val.extend(
                    output.input_top_logprobs_val[i][-req.last_update_decode_tokens :]
                )
                req.output_top_logprobs_idx.extend(
                    output.input_top_logprobs_idx[i][-req.last_update_decode_tokens :]
                )

            req.output_top_logprobs_val.append(output.next_token_top_logprobs_val[i])
            req.output_top_logprobs_idx.append(output.next_token_top_logprobs_idx[i])

        return num_input_logprobs

Lianmin Zheng's avatar
Lianmin Zheng committed
1353
1354
1355
    def stream_output(
        self, reqs: List[Req], return_logprob: bool, skip_req: Optional[Req] = None
    ):
        """Stream the output to detokenizer.

        Generation models: every request that is finished, or whose output
        length hits the streaming cadence, contributes its current state to a
        single ``BatchTokenIDOut``. Embedding/reward models: every finished
        request contributes to a single ``BatchEmbeddingOut``. In both cases
        the message is sent only when at least one request contributed.

        Args:
            reqs: Requests of the batch that was just processed.
            return_logprob: If True, per-request logprob lists are included;
                otherwise every logprob field of ``BatchTokenIDOut`` is None.
            skip_req: A request to leave out of this round (consulted only in
                the generation branch).
        """
        rids = []
        finished_reasons: List[BaseFinishReason] = []

        if self.is_generation:
            vids = []
            decoded_texts = []
            decode_ids_list = []
            read_offsets = []
            output_ids = []

            skip_special_tokens = []
            spaces_between_special_tokens = []
            no_stop_trim = []
            prompt_tokens = []
            completion_tokens = []
            cached_tokens = []

            if return_logprob:
                input_token_logprobs_val = []
                input_token_logprobs_idx = []
                output_token_logprobs_val = []
                output_token_logprobs_idx = []
                input_top_logprobs_val = []
                input_top_logprobs_idx = []
                output_top_logprobs_val = []
                output_top_logprobs_idx = []
            else:
                input_token_logprobs_val = input_token_logprobs_idx = (
                    output_token_logprobs_val
                ) = output_token_logprobs_idx = input_top_logprobs_val = (
                    input_top_logprobs_idx
                ) = output_top_logprobs_val = output_top_logprobs_idx = None

            for req in reqs:
                if req is skip_req:
                    continue

                # TODO(lianmin): revisit this for overlap + retract + stream
                if (
                    req.finished()
                    # If stream, follow the given stream_interval
                    or (req.stream and len(req.output_ids) % self.stream_interval == 0)
                    # If not stream, we still want to output some tokens to get the benefit of incremental decoding.
                    or (not req.stream and len(req.output_ids) % 50 == 0)
                ):
                    if self.draft_worker and req.finished():
                        self.draft_worker.finish_request(req)

                    rids.append(req.rid)
                    finished_reasons.append(
                        req.finished_reason.to_json() if req.finished_reason else None
                    )
                    vids.append(req.vid)
                    decoded_texts.append(req.decoded_text)
                    decode_ids, read_offset = req.init_incremental_detokenize()
                    decode_ids_list.append(decode_ids)
                    read_offsets.append(read_offset)
                    # Raw token ids are collected only when tokenizer init is
                    # skipped; otherwise an empty list is sent.
                    if self.skip_tokenizer_init:
                        output_ids.append(req.output_ids)
                    skip_special_tokens.append(req.sampling_params.skip_special_tokens)
                    spaces_between_special_tokens.append(
                        req.sampling_params.spaces_between_special_tokens
                    )
                    no_stop_trim.append(req.sampling_params.no_stop_trim)

                    prompt_tokens.append(len(req.origin_input_ids))
                    completion_tokens.append(len(req.output_ids))
                    cached_tokens.append(req.cached_tokens)

                    if return_logprob:
                        input_token_logprobs_val.append(req.input_token_logprobs_val)
                        input_token_logprobs_idx.append(req.input_token_logprobs_idx)
                        output_token_logprobs_val.append(req.output_token_logprobs_val)
                        output_token_logprobs_idx.append(req.output_token_logprobs_idx)
                        input_top_logprobs_val.append(req.input_top_logprobs_val)
                        input_top_logprobs_idx.append(req.input_top_logprobs_idx)
                        output_top_logprobs_val.append(req.output_top_logprobs_val)
                        output_top_logprobs_idx.append(req.output_top_logprobs_idx)

            # Send to detokenizer
            if rids:
                self.send_to_detokenizer.send_pyobj(
                    BatchTokenIDOut(
                        rids,
                        finished_reasons,
                        vids,
                        decoded_texts,
                        decode_ids_list,
                        read_offsets,
                        output_ids,
                        skip_special_tokens,
                        spaces_between_special_tokens,
                        no_stop_trim,
                        prompt_tokens,
                        completion_tokens,
                        cached_tokens,
                        input_token_logprobs_val,
                        input_token_logprobs_idx,
                        output_token_logprobs_val,
                        output_token_logprobs_idx,
                        input_top_logprobs_val,
                        input_top_logprobs_idx,
                        output_top_logprobs_val,
                        output_top_logprobs_idx,
                    )
                )
        else:  # embedding or reward model
            embeddings = []
            prompt_tokens = []
            for req in reqs:
                if req.finished():
                    rids.append(req.rid)
                    finished_reasons.append(req.finished_reason.to_json())
                    embeddings.append(req.embedding)
                    prompt_tokens.append(len(req.origin_input_ids))
            # Like the generation branch, skip the IPC round-trip when no
            # request finished in this batch.
            if rids:
                self.send_to_detokenizer.send_pyobj(
                    BatchEmbeddingOut(rids, finished_reasons, embeddings, prompt_tokens)
                )

    def prepare_dp_attn_batch(self, local_batch: ScheduleBatch):
        """Synchronize batch information across DP-attention workers.

        Each rank contributes its local token count via an all-gather on
        ``self.tp_cpu_group``. The count is the batch size for decode,
        ``extend_num_tokens`` otherwise, and 0 when there is no local batch.

        A rank without a local batch builds an idle batch whenever any rank
        has tokens. Presumably this keeps the ranks' collective calls in
        lockstep; confirm against the model runner.

        When CUDA graphs are enabled, a MIN all-reduce sets
        ``can_run_dp_cuda_graph`` to True only if every rank is in decode or
        idle mode.

        Returns:
            The local batch (possibly a freshly built idle batch) with
            ``global_num_tokens`` filled in, or ``None`` when no rank has any
            tokens.
        """
        # Check if other DP workers have running batches
        if local_batch is None:
            num_tokens = 0
        elif local_batch.forward_mode.is_decode():
            num_tokens = local_batch.batch_size()
        else:
            num_tokens = local_batch.extend_num_tokens

        local_num_tokens = torch.tensor([num_tokens], dtype=torch.int64)
        global_num_tokens = torch.empty(self.tp_size, dtype=torch.int64)
        torch.distributed.all_gather_into_tensor(
            global_num_tokens,
            local_num_tokens,
            group=self.tp_cpu_group,
        )

        if local_batch is None and global_num_tokens.max().item() > 0:
            local_batch = self.get_idle_batch()

        if local_batch is not None:
            local_batch.global_num_tokens = global_num_tokens.tolist()

            # Check forward mode for cuda graph
            if not self.server_args.disable_cuda_graph:
                forward_mode_state = torch.tensor(
                    (1 if local_batch.forward_mode.is_decode_or_idle() else 0),
                    dtype=torch.int32,
                )
                # MIN across ranks: the result is 1 only if every rank voted 1.
                torch.distributed.all_reduce(
                    forward_mode_state,
                    op=torch.distributed.ReduceOp.MIN,
                    group=self.tp_cpu_group,
                )
                local_batch.can_run_dp_cuda_graph = forward_mode_state.item() == 1

        return local_batch

    def get_idle_batch(self):
        """Build a request-less ``ScheduleBatch`` prepared for an idle forward.

        The batch shares this scheduler's memory pools, tree cache, model
        config, overlap flag and speculative algorithm.
        """
        idle_batch = ScheduleBatch.init_new(
            [],
            self.req_to_token_pool,
            self.token_to_kv_pool,
            self.tree_cache,
            self.model_config,
            self.enable_overlap,
            self.spec_algorithm,
        )
        idle_batch.prepare_for_idle()
        return idle_batch

    def move_ready_grammar_requests(self):
        """Move requests whose grammar objects are ready from grammar_queue to waiting_queue.

        Pending grammar futures are resolved in queue order, waiting up to
        0.05 s on each. The scan stops at the first future that is not done.

        With tensor parallelism (``tp_size > 1``), the ranks agree on the
        maximum ready count. Each rank then blocks on any futures it has not
        resolved yet, so every rank moves the same prefix of the queue.
        """
        num_ready_reqs = 0
        for req in self.grammar_queue:
            try:
                req.grammar = req.grammar.result(timeout=0.05)
                num_ready_reqs += 1
            except futures.TimeoutError:
                # Public alias of the exception Future.result raises on timeout.
                break

        if self.tp_size > 1:
            # Sync across TP ranks to make sure they have the same number of ready requests
            tensor = torch.tensor(num_ready_reqs, dtype=torch.int32)
            torch.distributed.all_reduce(
                tensor, op=torch.distributed.ReduceOp.MAX, group=self.tp_cpu_group
            )
            num_ready_reqs_max = tensor.item()
            for i in range(num_ready_reqs, num_ready_reqs_max):
                # Block until this rank has the grammars other ranks already have.
                self.grammar_queue[i].grammar = self.grammar_queue[i].grammar.result()
            num_ready_reqs = num_ready_reqs_max

        self.waiting_queue.extend(self.grammar_queue[:num_ready_reqs])
        self.grammar_queue = self.grammar_queue[num_ready_reqs:]

    def flush_cache_wrapped(self, recv_req: FlushCacheReq):
        """Handle a ``FlushCacheReq`` by calling ``flush_cache``.

        ``recv_req`` itself is not inspected. The success flag returned by
        ``flush_cache`` is discarded.
        """
        self.flush_cache()

    def flush_cache(self):
        """Flush the memory pool and cache.

        The flush happens only when no request is waiting or running.
        Otherwise a warning is logged and nothing is touched.

        A successful flush:
        - resets the tree cache and its hit metrics,
        - resets the grammar backend (if any),
        - clears both token pools,
        - releases cached CUDA memory.

        Returns:
            True if the caches were flushed, False if the flush was skipped.
        """
        num_running_reqs = (
            0 if self.running_batch is None else len(self.running_batch.reqs)
        )
        if len(self.waiting_queue) != 0 or num_running_reqs != 0:
            # Use the module logger (not the root logger) like the rest of the file.
            logger.warning(
                f"Cache not flushed because there are pending requests. "
                f"#queue-req: {len(self.waiting_queue)}, "
                f"#running-req: {num_running_reqs}"
            )
            return False

        self.tree_cache.reset()
        self.tree_cache_metrics = {"total": 0, "hit": 0}
        if self.grammar_backend:
            self.grammar_backend.reset()
        self.req_to_token_pool.clear()
        self.token_to_kv_pool.clear()
        torch.cuda.empty_cache()
        logger.info("Cache flushed successfully!")
        return True

    def abort_request(self, recv_req: AbortReq):
        """Abort the request whose ``rid`` equals ``recv_req.rid``.

        A matching request still in the waiting queue is removed from it
        immediately. Otherwise, the first unfinished matching request in the
        running batch is flagged with ``to_abort = True``; acting on that flag
        happens outside this method. Other queues (e.g. ``grammar_queue``) are
        not searched.
        """
        # Delete requests in the waiting queue
        for i, req in enumerate(self.waiting_queue):
            if req.rid == recv_req.rid:
                # Safe to delete mid-iteration because we return right away.
                del self.waiting_queue[i]
                logger.debug(f"Abort queued request. {req.rid=}")
                return

        # Flag requests in the running batch
        if self.running_batch:
            for req in self.running_batch.reqs:
                if req.rid == recv_req.rid and not req.finished():
                    logger.debug(f"Abort running request. {req.rid=}")
                    req.to_abort = True
                    break

    def update_weights_from_disk(self, recv_req: UpdateWeightFromDiskReqInput):
        """In-place update of the weights from disk.

        Delegates the load to the TP worker. On success the caches are
        flushed, presumably because cached entries were computed with the old
        weights. On failure the worker's message is logged.

        Returns:
            UpdateWeightFromDiskReqOutput carrying the worker's success flag
            and message.
        """
        success, message = self.tp_worker.update_weights_from_disk(recv_req)
        if success:
            # NOTE(review): assert is stripped under `python -O`.
            flush_cache_success = self.flush_cache()
            assert flush_cache_success, "Cache flush failed after updating weights"
        else:
            logger.error(message)
        return UpdateWeightFromDiskReqOutput(success, message)

    def init_weights_update_group(self, recv_req: InitWeightsUpdateGroupReqInput):
        """Initialize the online model parameter update group.

        Delegates to the TP worker and wraps its ``(success, message)`` result
        in an InitWeightsUpdateGroupReqOutput.
        """
        success, message = self.tp_worker.init_weights_update_group(recv_req)
        return InitWeightsUpdateGroupReqOutput(success, message)

    def update_weights_from_distributed(
        self,
        recv_req: UpdateWeightsFromDistributedReqInput,
    ) -> UpdateWeightsFromDistributedReqOutput:
        """Update the online model parameter.

        Delegates to the TP worker. On success the caches are flushed; on
        failure the worker's message is logged.

        Returns:
            UpdateWeightsFromDistributedReqOutput carrying the worker's success
            flag and message. The previous ``Tuple[bool, str]`` annotation did
            not match what is actually returned.
        """
        success, message = self.tp_worker.update_weights_from_distributed(recv_req)
        if success:
            # NOTE(review): assert is stripped under `python -O`.
            flush_cache_success = self.flush_cache()
            assert flush_cache_success, "Cache flush failed after updating weights"
        else:
            logger.error(message)
        return UpdateWeightsFromDistributedReqOutput(success, message)

    def update_weights_from_tensor(self, recv_req: UpdateWeightsFromTensorReqInput):
        """Update the online model parameter from tensors.

        Delegates to the TP worker. On success the caches are flushed; on
        failure the worker's message is logged.

        Returns:
            UpdateWeightsFromTensorReqOutput carrying the worker's success flag
            and message.
        """
        success, message = self.tp_worker.update_weights_from_tensor(recv_req)
        # TODO extract common code b/t update_weights_from_distributed and update_weights_from_tensor later
        if success:
            # NOTE(review): assert is stripped under `python -O`.
            flush_cache_success = self.flush_cache()
            assert flush_cache_success, "Cache flush failed after updating weights"
        else:
            logger.error(message)
        return UpdateWeightsFromTensorReqOutput(success, message)

    def get_weights_by_name(self, recv_req: GetWeightsByNameReqInput):
        """Fetch a parameter from the TP worker and wrap it in a GetWeightsByNameReqOutput."""
        parameter = self.tp_worker.get_weights_by_name(recv_req)
        return GetWeightsByNameReqOutput(parameter)

    def release_memory_occupation(self):
        """Stash the model's buffers, pause the memory saver, and flush caches.

        A detached clone of every model buffer is kept in
        ``self.stashed_model_static_state`` so ``resume_memory_occupation`` can
        restore it.

        NOTE(review): the result of ``flush_cache`` is ignored. If requests are
        still queued or running, the caches are left untouched.
        """
        self.stashed_model_static_state = _export_static_state(
            self.tp_worker.worker.model_runner.model
        )
        self.memory_saver_adapter.pause()
        self.flush_cache()
        return ReleaseMemoryOccupationReqOutput()

    def resume_memory_occupation(self):
        """Resume the memory saver and restore the buffers stashed on release.

        Expects ``self.stashed_model_static_state`` to have been set by
        ``release_memory_occupation``. The stash is deleted once it has been
        copied back into the model.
        """
        self.memory_saver_adapter.resume()
        _import_static_state(
            self.tp_worker.worker.model_runner.model, self.stashed_model_static_state
        )
        del self.stashed_model_static_state
        return ResumeMemoryOccupationReqOutput()

    def profile(self, recv_req: ProfileReq):
        """Start the profiler on ``ProfileReq.START_PROFILE``; stop it for any other value."""
        if recv_req == ProfileReq.START_PROFILE:
            self.start_profile()
        else:
            self.stop_profile()

    def start_profile(self) -> None:
        """Start the configured profiler.

        Raises:
            RuntimeError: If ``self.profiler`` is None.
        """
        profiler = self.profiler
        if profiler is not None:
            profiler.start()
            return
        raise RuntimeError("Profiler is not enabled.")

    def stop_profile(self) -> None:
        """Stop the profiler and export its trace.

        The trace is exported as a Chrome trace file named
        ``<time.time()>.trace.json.gz`` inside ``self.torch_profiler_trace_dir``.

        Raises:
            RuntimeError: If ``self.profiler`` is None.
        """
        if self.profiler is None:
            raise RuntimeError("Profiler is not enabled.")
        self.profiler.stop()
        # os.path.join avoids a doubled separator when the dir ends with "/".
        trace_path = os.path.join(
            self.torch_profiler_trace_dir, str(time.time()) + ".trace.json.gz"
        )
        self.profiler.export_chrome_trace(trace_path)
        logger.info("Profiler is done")

    def open_session(self, recv_req: OpenSessionReqInput):
        """Open a new session keyed by ``recv_req.session_id``.

        Returns:
            ``OpenSessionReqOutput(session_id, True)`` when a new ``Session`` is
            created. ``OpenSessionReqOutput(session_id, False)`` when the id
            already exists or is None; a warning is logged in both failure
            cases.
        """
        # handle error
        session_id = recv_req.session_id
        if session_id in self.sessions:
            logger.warning(f"session id {session_id} already exist, cannot open.")
            return OpenSessionReqOutput(session_id, False)
        if session_id is None:
            logger.warning("session id is None, cannot open.")
            return OpenSessionReqOutput(session_id, False)

        self.sessions[session_id] = Session(recv_req.capacity_of_str_len, session_id)
        return OpenSessionReqOutput(session_id, True)

    def close_session(self, recv_req: CloseSessionReqInput):
        """Close the session keyed by ``recv_req.session_id``.

        An unknown id is not an error: a warning is logged and nothing changes.
        """
        # handle error
        session_id = recv_req.session_id
        if session_id not in self.sessions:
            logger.warning(f"session id {session_id} does not exist, cannot delete.")
        else:
            del self.sessions[session_id]


def _export_static_state(model):
    return dict(
        buffers=[
            (name, buffer.detach().clone()) for name, buffer in model.named_buffers()
        ]
    )


def _import_static_state(model, static_params):
    self_named_buffers = dict(model.named_buffers())
    for name, tensor in static_params["buffers"]:
        self_named_buffers[name][...] = tensor


1715
1716
1717
1718
1719
def run_scheduler_process(
    server_args: ServerArgs,
    port_args: PortArgs,
    gpu_id: int,
    tp_rank: int,
    dp_rank: Optional[int],
    pipe_writer,
):
    """Entry point of a scheduler process.

    Setup: sets the process title, enables faulthandler, configures logging
    and optionally sets CPU affinity.

    Then builds a ``Scheduler`` and sends a ready message through
    ``pipe_writer``. The message contains ``max_total_num_tokens`` and
    ``max_req_input_len``. Finally it runs the overlap or normal event loop.
    Any exception is logged with its traceback and SIGQUIT is sent to the
    parent process.

    Args:
        server_args: Server configuration passed to the scheduler and logger.
        port_args: Port configuration passed to the scheduler.
        gpu_id: GPU index for this process (also used for CPU affinity).
        tp_rank: Tensor-parallel rank.
        dp_rank: Data-parallel rank, or None. When None and the
            ``SGLANG_DP_RANK`` env var is set, its value is used instead.
        pipe_writer: Object with a ``send`` method that receives the ready
            dict; presumably a ``multiprocessing`` pipe end (confirm at caller).
    """
    setproctitle.setproctitle("sglang::scheduler")
    faulthandler.enable()

    # [For Router] if env var "SGLANG_DP_RANK" exist, set dp_rank to the value of the env var
    if dp_rank is None and "SGLANG_DP_RANK" in os.environ:
        dp_rank = int(os.environ["SGLANG_DP_RANK"])

    # Configure the logger
    if dp_rank is None:
        configure_logger(server_args, prefix=f" TP{tp_rank}")
    else:
        configure_logger(server_args, prefix=f" DP{dp_rank} TP{tp_rank}")
    suppress_other_loggers()

    # Set cpu affinity to this gpu process
    if get_bool_env_var("SGLANG_SET_CPU_AFFINITY"):
        set_gpu_proc_affinity(server_args.tp_size, server_args.nnodes, gpu_id)

    # Resolved before the try so the except handler can always signal it.
    parent_process = psutil.Process().parent()

    # Create a scheduler and run the event loop
    try:
        scheduler = Scheduler(server_args, port_args, gpu_id, tp_rank, dp_rank)
        pipe_writer.send(
            {
                "status": "ready",
                "max_total_num_tokens": scheduler.max_total_num_tokens,
                "max_req_input_len": scheduler.max_req_input_len,
            }
        )
        if scheduler.enable_overlap:
            scheduler.event_loop_overlap()
        else:
            scheduler.event_loop_normal()
    except Exception:
        exc_traceback = get_exception_traceback()
        logger.error(f"Scheduler hit an exception: {exc_traceback}")
        # Notify the parent process, presumably so it can shut the server down.
        parent_process.send_signal(signal.SIGQUIT)