scheduler.py 73.3 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
13
# Copyright 2023-2024 SGLang Team
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
14
15
"""A scheduler that manages a tensor parallel GPU worker."""

16
import faulthandler
17
import logging
18
import os
19
import signal
Lianmin Zheng's avatar
Lianmin Zheng committed
20
import threading
21
22
import time
import warnings
Lianmin Zheng's avatar
Lianmin Zheng committed
23
from collections import deque
Lianmin Zheng's avatar
Lianmin Zheng committed
24
from concurrent import futures
25
from dataclasses import dataclass
26
from http import HTTPStatus
27
from types import SimpleNamespace
28
from typing import Dict, List, Optional, Tuple, Union
29

30
import psutil
31
import setproctitle
32
import torch
33
34
import zmq

35
from sglang.global_config import global_config
Lianmin Zheng's avatar
Lianmin Zheng committed
36
from sglang.srt.configs.model_config import ModelConfig
37
from sglang.srt.constrained.base_grammar_backend import create_grammar_backend
38
from sglang.srt.hf_transformers_utils import get_processor, get_tokenizer
39
from sglang.srt.layers.dp_attention import compute_dp_attention_world_info
40
41
42
43
44
from sglang.srt.layers.logits_processor import LogitsProcessorOutput
from sglang.srt.managers.io_struct import (
    AbortReq,
    BatchEmbeddingOut,
    BatchTokenIDOut,
45
    CloseSessionReqInput,
46
    FlushCacheReq,
47
48
    GetWeightsByNameReqInput,
    GetWeightsByNameReqOutput,
49
50
    InitWeightsUpdateGroupReqInput,
    InitWeightsUpdateGroupReqOutput,
51
52
    OpenSessionReqInput,
    OpenSessionReqOutput,
53
    ProfileReq,
54
55
56
57
    ReleaseMemoryOccupationReqInput,
    ReleaseMemoryOccupationReqOutput,
    ResumeMemoryOccupationReqInput,
    ResumeMemoryOccupationReqOutput,
58
59
    TokenizedEmbeddingReqInput,
    TokenizedGenerateReqInput,
Chayenne's avatar
Chayenne committed
60
61
    UpdateWeightFromDiskReqInput,
    UpdateWeightFromDiskReqOutput,
62
63
    UpdateWeightsFromDistributedReqInput,
    UpdateWeightsFromDistributedReqOutput,
64
65
    UpdateWeightsFromTensorReqInput,
    UpdateWeightsFromTensorReqOutput,
66
67
68
69
70
71
72
)
from sglang.srt.managers.schedule_batch import (
    FINISH_ABORT,
    BaseFinishReason,
    ImageInputs,
    Req,
    ScheduleBatch,
73
    global_server_args_dict,
74
)
75
76
77
78
79
from sglang.srt.managers.schedule_policy import (
    AddReqResult,
    PrefillAdder,
    SchedulePolicy,
)
80
from sglang.srt.managers.session_controller import Session
81
from sglang.srt.managers.tp_worker import TpModelWorker
82
from sglang.srt.managers.tp_worker_overlap_thread import TpModelWorkerClient
83
from sglang.srt.managers.utils import validate_input_length
84
from sglang.srt.mem_cache.chunk_cache import ChunkCache
85
from sglang.srt.mem_cache.hiradix_cache import HiRadixCache
86
from sglang.srt.mem_cache.radix_cache import RadixCache
87
from sglang.srt.metrics.collector import SchedulerMetricsCollector, SchedulerStats
88
from sglang.srt.model_executor.forward_batch_info import ForwardMode
89
from sglang.srt.server_args import PortArgs, ServerArgs
90
from sglang.srt.speculative.spec_info import SpeculativeAlgorithm
91
from sglang.srt.torch_memory_saver_adapter import TorchMemorySaverAdapter
92
93
94
from sglang.srt.utils import (
    broadcast_pyobj,
    configure_logger,
95
    crash_on_warnings,
96
    get_bool_env_var,
97
    get_zmq_socket,
98
    set_gpu_proc_affinity,
99
100
101
    set_random_seed,
    suppress_other_loggers,
)
102
from sglang.utils import TypeBasedDispatcher, get_exception_traceback
103
104
105

logger = logging.getLogger(name=__name__)

# Debug switch: when the SGLANG_TEST_RETRACT environment variable is truthy,
# the retract-decode path is exercised for testing/debugging.
test_retract = get_bool_env_var("SGLANG_TEST_RETRACT")
108

109

110
111
112
113
114
115
116
117
118
119
120
121
122
@dataclass
class GenerationBatchResult:
    """Result container for one generation batch.

    Bundles the logits processor output, the next token ids and ``bid``
    (presumably the batch id — confirm against callers).
    """

    # Output of the logits processor for this batch.
    logits_output: LogitsProcessorOutput
    # Token ids chosen for the next step.
    next_token_ids: List[int]
    # Batch identifier.
    bid: int


@dataclass
class EmbeddingBatchResult:
    """Result container for one embedding batch: the embeddings tensor and ``bid``."""

    # Embeddings produced for the batch.
    embeddings: torch.Tensor
    # Batch identifier (presumably the batch id — confirm against callers).
    bid: int


123
124
125
126
127
128
129
130
131
class Scheduler:
    """A scheduler that manages a tensor parallel GPU worker."""

    def __init__(
        self,
        server_args: ServerArgs,
        port_args: PortArgs,
        gpu_id: int,
        tp_rank: int,
        dp_rank: Optional[int],
    ):
        """Build the scheduler: IPC sockets, tokenizer, model worker(s), caches and state.

        Args:
            server_args: Global server configuration.
            port_args: IPC endpoint names and the NCCL port used by the workers.
            gpu_id: GPU index handed to the model worker(s).
            tp_rank: Tensor-parallel rank of this process.
            dp_rank: Data-parallel rank handed to the model worker(s). Note that
                ``self.dp_rank`` is recomputed below by
                ``compute_dp_attention_world_info``.
        """
        # Parse args
        self.server_args = server_args
        self.tp_rank = tp_rank
        self.tp_size = server_args.tp_size
        self.schedule_policy = server_args.schedule_policy
        self.disable_jump_forward = server_args.disable_jump_forward
        self.lora_paths = server_args.lora_paths
        self.max_loras_per_batch = server_args.max_loras_per_batch
        self.enable_overlap = not server_args.disable_overlap_schedule
        self.skip_tokenizer_init = server_args.skip_tokenizer_init
        self.enable_metrics = server_args.enable_metrics
        self.spec_algorithm = SpeculativeAlgorithm.from_string(
            server_args.speculative_algorithm
        )
        self.decode_mem_cache_buf_multiplier = (
            self.server_args.speculative_num_draft_tokens
            if not self.spec_algorithm.is_none()
            else 1
        )
        self.enable_hierarchical_cache = server_args.enable_hierarchical_cache

        # Distributed rank info
        self.dp_size = server_args.dp_size
        self.attn_tp_rank, self.attn_tp_size, self.dp_rank = (
            compute_dp_attention_world_info(
                server_args.enable_dp_attention,
                self.tp_rank,
                self.tp_size,
                self.dp_size,
            )
        )

        # Init inter-process communication. Only attention-TP rank 0 talks to
        # the tokenizer/detokenizer; other ranks get no-op senders.
        context = zmq.Context(2)
        if self.attn_tp_rank == 0:
            self.recv_from_tokenizer = get_zmq_socket(
                context, zmq.PULL, port_args.scheduler_input_ipc_name, False
            )
            self.send_to_tokenizer = get_zmq_socket(
                context, zmq.PUSH, port_args.tokenizer_ipc_name, False
            )

            if server_args.skip_tokenizer_init:
                # Directly send to the TokenizerManager
                self.send_to_detokenizer = get_zmq_socket(
                    context, zmq.PUSH, port_args.tokenizer_ipc_name, False
                )
            else:
                # Send to the DetokenizerManager
                self.send_to_detokenizer = get_zmq_socket(
                    context, zmq.PUSH, port_args.detokenizer_ipc_name, False
                )
        else:
            self.recv_from_tokenizer = None
            self.send_to_tokenizer = SimpleNamespace(send_pyobj=lambda x: None)
            self.send_to_detokenizer = SimpleNamespace(send_pyobj=lambda x: None)

        # Init tokenizer
        self.model_config = ModelConfig(
            server_args.model_path,
            trust_remote_code=server_args.trust_remote_code,
            revision=server_args.revision,
            context_length=server_args.context_length,
            model_override_args=server_args.json_model_override_args,
            is_embedding=server_args.is_embedding,
            dtype=server_args.dtype,
            quantization=server_args.quantization,
        )
        self.is_generation = self.model_config.is_generation

        if server_args.skip_tokenizer_init:
            self.tokenizer = self.processor = None
        else:
            if self.model_config.is_multimodal:
                self.processor = get_processor(
                    server_args.tokenizer_path,
                    tokenizer_mode=server_args.tokenizer_mode,
                    trust_remote_code=server_args.trust_remote_code,
                    revision=server_args.revision,
                )
                self.tokenizer = self.processor.tokenizer
            else:
                self.tokenizer = get_tokenizer(
                    server_args.tokenizer_path,
                    tokenizer_mode=server_args.tokenizer_mode,
                    trust_remote_code=server_args.trust_remote_code,
                    revision=server_args.revision,
                )

        # Check whether overlap can be enabled
        if not self.is_generation:
            self.enable_overlap = False
            logger.info("Overlap scheduler is disabled for embedding models.")

        if self.model_config.is_multimodal:
            self.enable_overlap = False
            logger.info("Overlap scheduler is disabled for multimodal models.")

        if self.enable_overlap:
            self.disable_jump_forward = True

        # Launch a tensor parallel worker
        if self.enable_overlap:
            TpWorkerClass = TpModelWorkerClient
        else:
            TpWorkerClass = TpModelWorker

        self.tp_worker = TpWorkerClass(
            server_args=server_args,
            gpu_id=gpu_id,
            tp_rank=tp_rank,
            dp_rank=dp_rank,
            nccl_port=port_args.nccl_port,
        )

        # Launch a worker for speculative decoding if needed
        if self.spec_algorithm.is_eagle():
            from sglang.srt.speculative.eagle_worker import EAGLEWorker

            self.draft_worker = EAGLEWorker(
                gpu_id=gpu_id,
                tp_rank=tp_rank,
                server_args=server_args,
                nccl_port=port_args.nccl_port,
                target_worker=self.tp_worker,
                dp_rank=dp_rank,
            )
        else:
            self.draft_worker = None

        # Get token and memory info from the model worker
        (
            self.max_total_num_tokens,
            self.max_prefill_tokens,
            self.max_running_requests,
            self.max_req_len,
            self.max_req_input_len,
            self.random_seed,
            self.device,
            worker_global_server_args_dict,
            _,
            _,
            _,
        ) = self.tp_worker.get_worker_info()
        self.tp_cpu_group = self.tp_worker.get_tp_cpu_group()
        self.attn_tp_cpu_group = self.tp_worker.get_attention_tp_cpu_group()
        self.pad_input_ids_func = self.tp_worker.get_pad_input_ids_func()
        global_server_args_dict.update(worker_global_server_args_dict)
        set_random_seed(self.random_seed)
        # Print debug info
        logger.info(
            f"max_total_num_tokens={self.max_total_num_tokens}, "
            f"chunked_prefill_size={server_args.chunked_prefill_size}, "
            f"max_prefill_tokens={self.max_prefill_tokens}, "
            f"max_running_requests={self.max_running_requests}, "
            f"context_len={self.model_config.context_len}"
        )

        # Init memory pool and cache
        self.req_to_token_pool, self.token_to_kv_pool = self.tp_worker.get_memory_pool()

        if (
            server_args.chunked_prefill_size is not None
            and server_args.disable_radix_cache
        ):
            self.tree_cache = ChunkCache(
                req_to_token_pool=self.req_to_token_pool,
                token_to_kv_pool=self.token_to_kv_pool,
            )
        else:
            self.tree_cache = (
                HiRadixCache(
                    req_to_token_pool=self.req_to_token_pool,
                    token_to_kv_pool=self.token_to_kv_pool,
                )
                if self.enable_hierarchical_cache
                else RadixCache(
                    req_to_token_pool=self.req_to_token_pool,
                    token_to_kv_pool=self.token_to_kv_pool,
                    disable=server_args.disable_radix_cache,
                )
            )
        self.tree_cache_metrics = {"total": 0, "hit": 0}
        self.policy = SchedulePolicy(self.schedule_policy, self.tree_cache)

        # Init running status
        self.waiting_queue: List[Req] = []
        self.staging_reqs = {}
        # The running decoding batch for continuous batching
        self.running_batch: Optional[ScheduleBatch] = None
        # The current forward batch
        self.cur_batch: Optional[ScheduleBatch] = None
        # The forward batch from the previous scheduler iteration
        self.last_batch: Optional[ScheduleBatch] = None
        self.forward_ct = 0
        self.forward_ct_decode = 0
        self.num_generated_tokens = 0
        self.spec_num_total_accepted_tokens = 0
        self.spec_num_total_forward_ct = 0
        self.last_decode_stats_tic = time.time()
        self.stream_interval = server_args.stream_interval
        self.current_stream = torch.get_device_module(self.device).current_stream()
        if self.device == "cpu":
            self.current_stream.synchronize = lambda: None  # No-op for CPU

        # Session info
        self.sessions: Dict[str, Session] = {}

        # Init chunked prefill
        self.chunked_prefill_size = server_args.chunked_prefill_size
        # A non-positive size (e.g. -1) disables chunked prefill. `None` is
        # treated the same way (the cache setup above also allows `None`),
        # instead of raising TypeError on the comparison.
        if self.chunked_prefill_size is not None and self.chunked_prefill_size <= 0:
            self.chunked_prefill_size = None
        self.being_chunked_req = None
        self.is_mixed_chunk = (
            self.chunked_prefill_size is not None and server_args.enable_mixed_chunk
        )

        # Init the grammar backend for constrained generation
        self.grammar_queue: List[Req] = []
        if not server_args.skip_tokenizer_init:
            self.grammar_backend = create_grammar_backend(
                server_args, self.tokenizer, self.model_config.vocab_size
            )
        else:
            self.grammar_backend = None

        # Init new token estimation
        assert (
            server_args.schedule_conservativeness >= 0
        ), "Invalid schedule_conservativeness"

        self.init_new_token_ratio = min(
            global_config.default_init_new_token_ratio
            * server_args.schedule_conservativeness,
            1.0,
        )
        self.min_new_token_ratio = min(
            self.init_new_token_ratio
            * global_config.default_min_new_token_ratio_factor,
            1.0,
        )
        self.new_token_ratio_decay = (
            self.init_new_token_ratio - self.min_new_token_ratio
        ) / global_config.default_new_token_ratio_decay_steps
        self.new_token_ratio = self.init_new_token_ratio

        # Tells whether the current running batch is full so that we can skip
        # the check of whether to prefill new requests.
        # This is an optimization to reduce the overhead of the prefill check.
        self.batch_is_full = False

        # Init watchdog thread.
        # Capture the parent process BEFORE starting the thread:
        # watchdog_thread() reads self.parent_process, so it must already exist
        # when the thread starts running.
        self.watchdog_timeout = server_args.watchdog_timeout
        self.parent_process = psutil.Process().parent()
        t = threading.Thread(target=self.watchdog_thread, daemon=True)
        t.start()

        self.memory_saver_adapter = TorchMemorySaverAdapter.create(
            enable=server_args.enable_memory_saver
        )

        # Init profiler (only when SGLANG_TORCH_PROFILER_DIR is set and non-empty)
        if os.getenv("SGLANG_TORCH_PROFILER_DIR", "") == "":
            self.profiler = None
        else:
            self.torch_profiler_trace_dir = os.getenv("SGLANG_TORCH_PROFILER_DIR")
            logger.info(
                "Profiling enabled. Traces will be saved to: %s",
                self.torch_profiler_trace_dir,
            )
            self.profiler = torch.profiler.profile(
                activities=[
                    torch.profiler.ProfilerActivity.CPU,
                    torch.profiler.ProfilerActivity.CUDA,
                ],
                with_stack=True,
            )

        # Init metrics stats
        self.stats = SchedulerStats()
        if self.enable_metrics:
            self.metrics_collector = SchedulerMetricsCollector(
                labels={
                    "model_name": self.server_args.served_model_name,
                    # TODO: Add lora name/path in the future,
                },
            )

        # The largest prefill length of a single request
        self._largest_prefill_len: int = 0
        # The largest context length (prefill + generation) of a single request
        self._largest_prefill_decode_len: int = 0

        # Init request dispatcher: maps each incoming request type to its handler
        self._request_dispatcher = TypeBasedDispatcher(
            [
                (TokenizedGenerateReqInput, self.handle_generate_request),
                (TokenizedEmbeddingReqInput, self.handle_embedding_request),
                (FlushCacheReq, self.flush_cache_wrapped),
                (AbortReq, self.abort_request),
                (UpdateWeightFromDiskReqInput, self.update_weights_from_disk),
                (InitWeightsUpdateGroupReqInput, self.init_weights_update_group),
                (
                    UpdateWeightsFromDistributedReqInput,
                    self.update_weights_from_distributed,
                ),
                (UpdateWeightsFromTensorReqInput, self.update_weights_from_tensor),
                (GetWeightsByNameReqInput, self.get_weights_by_name),
                (ProfileReq, self.profile),
                (OpenSessionReqInput, self.open_session),
                (CloseSessionReqInput, self.close_session),
                (
                    ReleaseMemoryOccupationReqInput,
                    lambda _: self.release_memory_occupation(),
                ),
                (
                    ResumeMemoryOccupationReqInput,
                    lambda _: self.resume_memory_occupation(),
                ),
            ]
        )

Lianmin Zheng's avatar
Lianmin Zheng committed
456
    def watchdog_thread(self):
        """A watch dog thread that will try to kill the server itself if one batch takes too long.

        Every ``watchdog_timeout / 2`` seconds it checks whether a batch is in
        flight (``self.cur_batch`` is set). If ``self.forward_ct`` has not
        advanced for more than ``watchdog_timeout`` seconds while a batch is in
        flight, it logs an error, waits 5 seconds and sends SIGQUIT to
        ``self.parent_process``.
        """
        self.watchdog_last_forward_ct = 0
        self.watchdog_last_time = time.time()

        while True:
            current = time.time()
            if self.cur_batch is not None:
                if self.watchdog_last_forward_ct == self.forward_ct:
                    if current > self.watchdog_last_time + self.watchdog_timeout:
                        logger.error(f"Watchdog timeout ({self.watchdog_timeout=})")
                        break
                else:
                    self.watchdog_last_forward_ct = self.forward_ct
                    self.watchdog_last_time = current
            else:
                # Idle: restart the clock so time spent waiting for requests is
                # not counted against the next batch that becomes current.
                self.watchdog_last_time = current
            # True division: with `//`, a timeout below 2 seconds would floor
            # to a zero-length sleep and turn this loop into a busy spin.
            time.sleep(self.watchdog_timeout / 2)

        # Wait a while so that the parent process can print the error.
        time.sleep(5)
        self.parent_process.send_signal(signal.SIGQUIT)
Lianmin Zheng's avatar
Lianmin Zheng committed
475

476
    @torch.no_grad()
    def event_loop_normal(self):
        """Run the scheduler loop without CPU/GPU overlap.

        Each iteration drains incoming requests, picks the next batch and, if
        there is one, runs it and processes its result synchronously. When
        nothing is scheduled, the idle iteration is used for a memory
        self-check and to reset ``new_token_ratio`` to its initial value.
        """
        while True:
            self.process_input_requests(self.recv_requests())

            next_batch = self.get_next_batch_to_run()
            self.cur_batch = next_batch

            if not next_batch:
                # Idle: self-check and reset the new-token estimate.
                self.check_memory()
                self.new_token_ratio = self.init_new_token_ratio
            else:
                self.process_batch_result(next_batch, self.run_batch(next_batch))

            self.last_batch = next_batch
495

496
    @torch.no_grad()
    def event_loop_overlap(self):
        """Run the scheduler loop, overlapping CPU processing with GPU computation.

        A batch is launched in one iteration and its result is processed in the
        next one, so the CPU-side handling of the previous batch happens after
        the current batch has been launched. Launched batches (copied) and
        their results are buffered in ``self.result_queue`` until processed.
        """
        self.result_queue = deque()

        while True:
            self.process_input_requests(self.recv_requests())

            batch = self.get_next_batch_to_run()
            self.cur_batch = batch

            if batch:
                launched = self.run_batch(batch)
                self.result_queue.append((batch.copy(), launched))

                if self.last_batch is None:
                    # Nothing was launched in the previous iteration: push a
                    # dummy first batch through result processing to start the
                    # overlap pipeline (it triggers the sampling_info_done event).
                    primer = ScheduleBatch(
                        reqs=None,
                        forward_mode=ForwardMode.DUMMY_FIRST,
                        next_batch_sampling_info=self.tp_worker.cur_sampling_info,
                    )
                    self.process_batch_result(primer, None)

            if self.last_batch:
                # Finish the batch that was launched in the previous iteration.
                prev_batch, prev_result = self.result_queue.popleft()
                next_info = None
                if batch:
                    next_info = self.tp_worker.cur_sampling_info
                prev_batch.next_batch_sampling_info = next_info
                self.process_batch_result(prev_batch, prev_result)
            elif batch is None:
                # Idle: self-check and reset the new-token estimate.
                self.check_memory()
                self.new_token_ratio = self.init_new_token_ratio

            self.last_batch = batch

536
537
    def recv_requests(self) -> List[Req]:
        """Receive requests on attention-TP rank 0 and broadcast them to the other ranks.

        Attention-TP rank 0 drains the tokenizer socket without blocking; other
        ranks start with ``None``. With DP attention enabled, generate/embedding
        ("work") requests are broadcast within the attention-TP group and all
        other ("control") requests across the full TP group; the result is the
        work requests followed by the control requests. Without DP attention,
        everything is broadcast across the TP group when ``tp_size > 1``.
        """
        if self.attn_tp_rank != 0:
            recv_reqs = None
        else:
            recv_reqs = []
            while True:
                try:
                    incoming = self.recv_from_tokenizer.recv_pyobj(zmq.NOBLOCK)
                except zmq.ZMQError:
                    break
                recv_reqs.append(incoming)

        if not self.server_args.enable_dp_attention:
            if self.tp_size != 1:
                recv_reqs = broadcast_pyobj(recv_reqs, self.tp_rank, self.tp_cpu_group)
            return recv_reqs

        work_reqs = None
        control_reqs = None
        if self.attn_tp_rank == 0:
            # Single pass that keeps the original relative order in each list.
            work_reqs = []
            control_reqs = []
            for req in recv_reqs:
                if isinstance(
                    req, (TokenizedGenerateReqInput, TokenizedEmbeddingReqInput)
                ):
                    work_reqs.append(req)
                else:
                    control_reqs.append(req)

        if self.attn_tp_size != 1:
            # Global rank of attention-TP rank 0 in this DP group.
            group_src = self.dp_rank * self.attn_tp_size
            work_reqs = broadcast_pyobj(
                work_reqs,
                self.attn_tp_rank,
                self.attn_tp_cpu_group,
                src=group_src,
            )
        if self.tp_size != 1:
            control_reqs = broadcast_pyobj(
                control_reqs, self.tp_rank, self.tp_cpu_group
            )
        return work_reqs + control_reqs

Lianmin Zheng's avatar
Lianmin Zheng committed
587
    def process_input_requests(self, recv_reqs: List):
        """Dispatch each received request to its handler by type.

        A handler's non-None return value is sent back to the tokenizer
        manager; handlers that return None produce no reply.
        """
        for request in recv_reqs:
            reply = self._request_dispatcher(request)
            if reply is None:
                continue
            self.send_to_tokenizer.send_pyobj(reply)
592
593
594
595
596

    def handle_generate_request(
        self,
        recv_req: TokenizedGenerateReqInput,
    ):
        """Turn a tokenized generate request into a ``Req`` and enqueue it.

        The ``Req`` is created fresh, or derived from an existing session when
        ``recv_req.session_params.id`` names a known session. Requests that
        fail a check (unknown session id, multimodal prompt too long after
        expansion, prompt-length validation, constrained decoding without a
        grammar backend) are still appended to ``self.waiting_queue``, marked
        with a ``FINISH_ABORT`` reason and/or ``max_new_tokens = 0``, rather
        than being dropped. Requests with a structured-output constraint whose
        grammar is not cached yet go to ``self.grammar_queue``. Everything else
        goes to ``self.waiting_queue``.
        """
        # Create a new request
        if (
            recv_req.session_params is None
            or recv_req.session_params.id is None
            or recv_req.session_params.id not in self.sessions
        ):
            if recv_req.input_embeds is not None:
                # Generate fake input_ids based on the length of input_embeds
                seq_length = len(recv_req.input_embeds)
                fake_input_ids = [1] * seq_length
                recv_req.input_ids = fake_input_ids

            # Handle custom logit processor passed to the request
            custom_logit_processor = recv_req.custom_logit_processor
            if (
                not self.server_args.enable_custom_logit_processor
                and custom_logit_processor is not None
            ):
                # Note the trailing spaces: the literals are concatenated.
                logger.warning(
                    "The SGLang server is not configured to enable custom logit processor. "
                    "The custom logit processor passed in will be ignored. "
                    "Please set --enable-custom-logit-processor to enable this feature."
                )
                custom_logit_processor = None

            req = Req(
                recv_req.rid,
                recv_req.input_text,
                recv_req.input_ids,
                recv_req.sampling_params,
                return_logprob=recv_req.return_logprob,
                top_logprobs_num=recv_req.top_logprobs_num,
                stream=recv_req.stream,
                lora_path=recv_req.lora_path,
                input_embeds=recv_req.input_embeds,
                custom_logit_processor=custom_logit_processor,
                eos_token_ids=self.model_config.hf_eos_token_id,
            )
            req.tokenizer = self.tokenizer

            if (
                recv_req.session_params is not None
                and recv_req.session_params.id is not None
            ):
                # A session id was given but is not in self.sessions.
                req.finished_reason = FINISH_ABORT(
                    f"Invalid request: session id {recv_req.session_params.id} does not exist"
                )
                self.waiting_queue.append(req)
                return
        else:
            # Create a new request from a previous session
            session = self.sessions[recv_req.session_params.id]
            req = session.create_req(recv_req, self.tokenizer)
            if isinstance(req.finished_reason, FINISH_ABORT):
                self.waiting_queue.append(req)
                return

        # Handle multimodal inputs
        if recv_req.image_inputs is not None:
            image_inputs = ImageInputs.from_dict(recv_req.image_inputs)
            # Expand a single image token into multiple dummy tokens for receiving image embeddings
            req.origin_input_ids = self.pad_input_ids_func(
                req.origin_input_ids, image_inputs
            )
            req.extend_image_inputs(image_inputs)

            if len(req.origin_input_ids) >= self.max_req_input_len:
                error_msg = (
                    "Multimodal prompt is too long after expanding multimodal tokens. "
                    f"After expanding {len(req.origin_input_ids_unpadded)=} => {len(req.origin_input_ids)} >= {self.max_req_input_len}."
                )
                logger.error(error_msg)
                req.origin_input_ids = [0]
                req.image_inputs = None
                req.sampling_params.max_new_tokens = 0
                req.finished_reason = FINISH_ABORT(
                    error_msg, HTTPStatus.BAD_REQUEST, "BadRequestError"
                )
                self.waiting_queue.append(req)
                return

        # Validate prompts length
        error_msg = validate_input_length(
            req,
            self.max_req_input_len,
            self.server_args.allow_auto_truncate,
        )
        if error_msg:
            req.origin_input_ids = [0]
            req.sampling_params.max_new_tokens = 0
            self.waiting_queue.append(req)
            return

        # Copy more attributes
        if recv_req.logprob_start_len == -1:
            # By default, only return the logprobs for output tokens
            req.logprob_start_len = len(req.origin_input_ids) - 1
        else:
            req.logprob_start_len = recv_req.logprob_start_len

        # Clamp max_new_tokens so prompt + generation fits within max_req_len.
        req.sampling_params.max_new_tokens = min(
            (
                req.sampling_params.max_new_tokens
                if req.sampling_params.max_new_tokens is not None
                else 1 << 30
            ),
            self.max_req_len - len(req.origin_input_ids) - 1,
        )

        # Init grammar cache for this request
        add_to_grammar_queue = False
        if (
            req.sampling_params.json_schema is not None
            or req.sampling_params.regex is not None
            or req.sampling_params.ebnf is not None
            or req.sampling_params.structural_tag is not None
        ):
            if self.grammar_backend is None:
                # Abort this request instead of crashing the scheduler with an
                # AssertionError triggered by client input.
                error_msg = (
                    "Grammar-based generation (json_schema, regex, ebnf, "
                    "structural_tag) is not supported when the server is "
                    "launched with --skip-tokenizer-init."
                )
                logger.error(error_msg)
                req.finished_reason = FINISH_ABORT(
                    error_msg, HTTPStatus.BAD_REQUEST, "BadRequestError"
                )
                self.waiting_queue.append(req)
                return

            if req.sampling_params.json_schema is not None:
                key = ("json", req.sampling_params.json_schema)
            elif req.sampling_params.regex is not None:
                key = ("regex", req.sampling_params.regex)
            elif req.sampling_params.ebnf is not None:
                key = ("ebnf", req.sampling_params.ebnf)
            else:
                # The enclosing condition guarantees structural_tag is not None.
                # A truthiness test here would leave `key` unbound for an empty
                # tag and crash with UnboundLocalError.
                key = ("structural_tag", req.sampling_params.structural_tag)

            req.grammar = self.grammar_backend.get_cached_value(key)
            if not req.grammar:
                # Not cached: compile asynchronously and park the request in the
                # grammar queue until the future resolves.
                req.grammar = self.grammar_backend.get_future_value(key)
                add_to_grammar_queue = True

        if add_to_grammar_queue:
            self.grammar_queue.append(req)
        else:
            self.waiting_queue.append(req)
734
735
736

    def handle_embedding_request(
        self,
        recv_req: TokenizedEmbeddingReqInput,
    ):
        """Wrap a tokenized embedding request in a ``Req`` and put it on the waiting queue.

        The request is always enqueued. When ``validate_input_length`` reports an
        error, the remaining attribute setup (``logprob_start_len``) is skipped.
        """
        new_req = Req(
            recv_req.rid,
            recv_req.input_text,
            recv_req.input_ids,
            recv_req.sampling_params,
        )
        new_req.tokenizer = self.tokenizer

        # Validate the prompt length (may truncate, depending on server args).
        length_error = validate_input_length(
            new_req,
            self.max_req_input_len,
            self.server_args.allow_auto_truncate,
        )
        if not length_error:
            # Same default as the generate path: start at the last prompt token.
            new_req.logprob_start_len = len(new_req.origin_input_ids) - 1
        # NOTE(review): unlike handle_generate_request, a rejected embedding request
        # keeps its full origin_input_ids here instead of being reduced to [0];
        # confirm this is intended.
        self.waiting_queue.append(new_req)

761
762
763
764
765
766
767
    def log_prefill_stats(
        self,
        adder: PrefillAdder,
        can_run_list: List[Req],
        running_bs: int,
        has_being_chunked: bool,
    ):
        """Log (and, if enabled, export) statistics for a newly formed prefill batch.

        Args:
            adder: The PrefillAdder that built the batch; provides the counts of new
                input tokens and prefix-cache hit tokens.
            can_run_list: Requests admitted into the new prefill batch.
            running_bs: Number of requests in the current running batch. The caller
                passes an int count, not a ScheduleBatch.
            has_being_chunked: Whether a chunked request is pending. It is counted as
                one extra queued request.
        """
        new_tokens = adder.log_input_tokens
        hit_tokens = adder.log_hit_tokens

        # Cumulative totals are kept in units of 10**9 tokens; only their ratio is used.
        self.tree_cache_metrics["total"] += (new_tokens + hit_tokens) / 10**9
        self.tree_cache_metrics["hit"] += hit_tokens / 10**9
        total = self.tree_cache_metrics["total"]
        # Guard against a zero total so logging can never raise ZeroDivisionError.
        tree_cache_hit_rate = (
            self.tree_cache_metrics["hit"] / total if total > 0 else 0.0
        )

        # Tokens neither free nor evictable are considered in use.
        num_used = self.max_total_num_tokens - (
            self.token_to_kv_pool.available_size() + self.tree_cache.evictable_size()
        )
        num_queue_reqs = len(self.waiting_queue) + has_being_chunked

        logger.info(
            f"Prefill batch. "
            f"#new-seq: {len(can_run_list)}, "
            f"#new-token: {new_tokens}, "
            f"#cached-token: {hit_tokens}, "
            f"cache hit rate: {100.0 * tree_cache_hit_rate:.2f}%, "
            f"token usage: {num_used / self.max_total_num_tokens:.2f}, "
            f"#running-req: {running_bs}, "
            f"#queue-req: {num_queue_reqs}"
        )

        if self.enable_metrics:
            self.stats.num_running_reqs = running_bs
            self.stats.num_used_tokens = num_used
            self.stats.token_usage = round(num_used / self.max_total_num_tokens, 2)
            self.stats.num_queue_reqs = num_queue_reqs
            self.stats.cache_hit_rate = tree_cache_hit_rate
            self.metrics_collector.log_stats(self.stats)

    def log_decode_stats(self):
        """Log (and, if enabled, export) periodic decode-loop statistics.

        Generation throughput is measured over the window since the previous call.
        The generated-token counter and window start time are then reset. When
        speculative decoding is active, the window's average accept length is also
        reported, and its counters are reset.
        """
        num_used = self.max_total_num_tokens - (
            self.token_to_kv_pool.available_size() + self.tree_cache.evictable_size()
        )

        window = time.time() - self.last_decode_stats_tic
        gen_throughput = self.num_generated_tokens / window
        self.num_generated_tokens = 0
        self.last_decode_stats_tic = time.time()

        num_running_reqs = len(self.running_batch.reqs) if self.running_batch else 0
        num_queue_reqs = len(self.waiting_queue)

        # Fields shared by the plain and speculative variants of the log line.
        pieces = [
            "Decode batch. ",
            f"#running-req: {num_running_reqs}, ",
            f"#token: {num_used}, ",
            f"token usage: {num_used / self.max_total_num_tokens:.2f}, ",
        ]
        if self.spec_algorithm.is_none():
            spec_accept_length = 0
        else:
            spec_accept_length = (
                self.spec_num_total_accepted_tokens / self.spec_num_total_forward_ct
            )
            self.spec_num_total_accepted_tokens = self.spec_num_total_forward_ct = 0
            pieces.append(f"accept len: {spec_accept_length:.2f}, ")
        pieces.append(f"gen throughput (token/s): {gen_throughput:.2f}, ")
        pieces.append(f"#queue-req: {num_queue_reqs}")

        logger.info("".join(pieces))

        if self.enable_metrics:
            stats = self.stats
            stats.num_running_reqs = num_running_reqs
            stats.num_used_tokens = num_used
            stats.token_usage = num_used / self.max_total_num_tokens
            stats.gen_throughput = gen_throughput
            stats.num_queue_reqs = num_queue_reqs
            stats.spec_accept_length = spec_accept_length
            self.metrics_collector.log_stats(stats)

Lianmin Zheng's avatar
Lianmin Zheng committed
845
846
847
848
    def check_memory(self):
        """Warn (or raise, if ``crash_on_warnings()``) when pool memory appears leaked.

        Two invariants are checked:
          * KV cache: free + evictable tokens must equal ``max_total_num_tokens``.
            With hierarchical cache enabled, the tree cache's protected size is
            subtracted from that total.
          * Request slots: every slot in ``req_to_token_pool`` must be free.

        NOTE(review): these invariants only hold when no request is in flight;
        presumably this is called while the scheduler is idle — confirm at call sites.

        Raises:
            ValueError: on a detected leak, when ``crash_on_warnings()`` is true.
        """
        available_size = (
            self.token_to_kv_pool.available_size() + self.tree_cache.evictable_size()
        )
        protected_size = self.tree_cache.protected_size()
        expected_size = (
            self.max_total_num_tokens - protected_size
            if self.enable_hierarchical_cache
            else self.max_total_num_tokens
        )
        if available_size != expected_size:
            # A separating space keeps the header from running into the values.
            msg = (
                "KV cache pool leak detected! "
                f"{available_size=}, {protected_size=}, {self.max_total_num_tokens=}\n"
            )
            warnings.warn(msg)
            if crash_on_warnings():
                raise ValueError(msg)

        num_free_slots = len(self.req_to_token_pool.free_slots)
        if num_free_slots != self.req_to_token_pool.size:
            msg = (
                "Memory pool leak detected! "
                f"available_size={num_free_slots}, "
                f"total_size={self.req_to_token_pool.size}\n"
            )
            warnings.warn(msg)
            if crash_on_warnings():
                raise ValueError(msg)
Lianmin Zheng's avatar
Lianmin Zheng committed
873

874
    def get_next_batch_to_run(self) -> Optional[ScheduleBatch]:
        """Choose the batch for the next forward pass.

        If the previous batch was an extend (prefill) batch, it is first folded into
        the running batch, after detaching any request that is still being chunked.
        A new prefill batch takes priority. Otherwise the running batch is updated
        for another decode step. With DP attention enabled, the chosen batch
        (possibly None) goes through ``prepare_dp_attn_batch``.
        """
        previous = self.last_batch
        if previous and previous.forward_mode.is_extend():
            chunked_req = self.being_chunked_req
            if chunked_req:
                # Move the chunked request out of the batch. It keeps its rid but
                # will get a new req_pool_idx next time it is scheduled.
                previous.filter_batch(being_chunked_req=chunked_req)
                self.tree_cache.cache_unfinished_req(chunked_req)
                self.req_to_token_pool.free(chunked_req.req_pool_idx)
                self.batch_is_full = False

            if not previous.is_empty():
                if self.running_batch is None:
                    self.running_batch = previous
                else:
                    self.running_batch.merge_batch(previous)

        prefill_batch = self.get_new_batch_prefill()
        if prefill_batch is not None:
            # Prefill runs first whenever one can be formed.
            chosen = prefill_batch
        elif self.running_batch is None:
            chosen = None
        else:
            self.running_batch = self.update_running_batch(self.running_batch)
            chosen = self.running_batch

        if self.server_args.enable_dp_attention:
            chosen = self.prepare_dp_attn_batch(chosen)

        return chosen
908

Lianmin Zheng's avatar
Lianmin Zheng committed
909
    def get_new_batch_prefill(self) -> Optional[ScheduleBatch]:
        """Build a new prefill (extend) batch from the waiting queue, if possible.

        Returns None in three cases:
          * prefill is not allowed (batch full, or nothing waiting with no chunked
            request pending);
          * the running batch already holds ``max_running_requests``;
          * no request could be admitted by the PrefillAdder.

        Side effects: may set ``self.batch_is_full``. It removes admitted requests
        from ``self.waiting_queue`` and updates ``self.being_chunked_req``. In
        mixed-chunk mode it folds the running batch into the new batch and clears
        ``self.running_batch``.
        """
        # Check if the grammar is ready in the grammar queue
        if self.grammar_queue:
            self.move_ready_grammar_requests()

        # Handle the cases where prefill is not allowed
        if (
            self.batch_is_full or len(self.waiting_queue) == 0
        ) and self.being_chunked_req is None:
            return None

        running_bs = len(self.running_batch.reqs) if self.running_batch else 0
        if running_bs >= self.max_running_requests:
            self.batch_is_full = True
            return None

        # Get priority queue
        prefix_computed = self.policy.calc_priority(self.waiting_queue)

        # Prefill policy
        adder = PrefillAdder(
            self.tree_cache,
            self.token_to_kv_pool,
            self.running_batch,
            self.new_token_ratio,
            self.max_prefill_tokens,
            self.chunked_prefill_size,
            running_bs if self.is_mixed_chunk else 0,
        )

        has_being_chunked = self.being_chunked_req is not None
        if has_being_chunked:
            self.being_chunked_req.init_next_round_input()
            self.being_chunked_req = adder.add_being_chunked_req(self.being_chunked_req)

        if self.lora_paths:
            # LoRA adapters already used by the running batch.
            lora_set = (
                {r.lora_path for r in self.running_batch.reqs}
                if self.running_batch is not None
                else set()
            )

        # Get requests from the waiting queue to a new prefill batch
        for req in self.waiting_queue:
            # Stop once admitting this request would exceed the per-batch LoRA limit.
            if (
                self.lora_paths
                and len(
                    lora_set
                    | {r.lora_path for r in adder.can_run_list}
                    | {req.lora_path}
                )
                > self.max_loras_per_batch
            ):
                self.batch_is_full = True
                break

            if running_bs + len(adder.can_run_list) >= self.max_running_requests:
                self.batch_is_full = True
                break

            req.init_next_round_input(None if prefix_computed else self.tree_cache)

            if self.enable_hierarchical_cache and req.last_node is not None:
                if req.last_node.evicted:
                    # loading KV cache for the request
                    req.last_node, req.prefix_indices = self.tree_cache.init_load_back(
                        req.last_node,
                        req.prefix_indices,
                        adder.rem_total_tokens,
                    )
                    if req.last_node.loading:
                        # to prevent frequent cache invalidation
                        if req.rid in self.staging_reqs:
                            self.tree_cache.dec_lock_ref(self.staging_reqs[req.rid])
                        self.tree_cache.inc_lock_ref(req.last_node)
                        self.staging_reqs[req.rid] = req.last_node
                        continue
                elif req.last_node.loading:
                    if not self.tree_cache.loading_complete(req.last_node):
                        continue

                if req.rid in self.staging_reqs:
                    self.tree_cache.dec_lock_ref(self.staging_reqs[req.rid])
                    del self.staging_reqs[req.rid]

            res = adder.add_one_req(req)
            if res != AddReqResult.CONTINUE:
                if res == AddReqResult.NO_TOKEN:
                    if self.enable_hierarchical_cache:
                        # Set batch_is_full after making sure there are requests that can be served
                        self.batch_is_full = len(adder.can_run_list) > 0 or (
                            self.running_batch is not None
                            and not self.running_batch.is_empty()
                        )
                    else:
                        self.batch_is_full = True
                break

            if self.server_args.prefill_only_one_req:
                break

        # Update waiting queue
        can_run_list = adder.can_run_list
        if not can_run_list:
            return None
        # Build the membership set once. The previous form rebuilt it for every
        # waiting request, making the filter O(len(waiting) * len(can_run)).
        admitted = set(can_run_list)
        self.waiting_queue = [x for x in self.waiting_queue if x not in admitted]

        if adder.new_being_chunked_req is not None:
            assert self.being_chunked_req is None
            self.being_chunked_req = adder.new_being_chunked_req

        if self.being_chunked_req:
            self.being_chunked_req.is_being_chunked += 1

        # Print stats
        if self.attn_tp_rank == 0:
            self.log_prefill_stats(adder, can_run_list, running_bs, has_being_chunked)

        # Create a new batch
        new_batch = ScheduleBatch.init_new(
            can_run_list,
            self.req_to_token_pool,
            self.token_to_kv_pool,
            self.tree_cache,
            self.model_config,
            self.enable_overlap,
            self.spec_algorithm,
            self.server_args.enable_custom_logit_processor,
        )
        new_batch.prepare_for_extend()

        # Mixed-style chunked prefill
        if (
            self.is_mixed_chunk
            and self.running_batch is not None
            and not (new_batch.return_logprob or self.running_batch.return_logprob)
        ):
            # TODO (lianmin): support return_logprob + mixed chunked prefill
            self.running_batch.filter_batch()
            if not self.running_batch.is_empty():
                self.running_batch.prepare_for_decode()
                new_batch.mix_with_running(self.running_batch)
                new_batch.decoding_reqs = self.running_batch.reqs
            self.running_batch = None
        else:
            new_batch.decoding_reqs = None

        return new_batch

Lianmin Zheng's avatar
Lianmin Zheng committed
1059
    def update_running_batch(self, batch: ScheduleBatch) -> Optional[ScheduleBatch]:
        """Update the current running decoding batch.

        The update proceeds in four steps:
          1. Filter the batch with ``filter_batch``.
          2. If decode memory is insufficient, retract requests back to the waiting
             queue. The module-level ``test_retract`` flag forces this for batches
             larger than 10.
          3. Apply jump-forward for grammar-constrained requests.
          4. Prepare the batch for the next decode step.

        Returns the updated batch, or None if it ended up empty.
        """
        global test_retract  # module-level flag; only read here

        size_before = batch.batch_size()

        batch.filter_batch()
        if batch.is_empty():
            self.batch_is_full = False
            return None

        # Check if decode out of memory
        out_of_memory = not batch.check_decode_mem(self.decode_mem_cache_buf_multiplier)
        if out_of_memory or (test_retract and batch.batch_size() > 10):
            previous_ratio = self.new_token_ratio
            retracted_reqs, self.new_token_ratio = batch.retract_decode()

            if self.draft_worker:
                self.draft_worker.finish_request(retracted_reqs)

            logger.info(
                "Decode out of memory happened. "
                f"#retracted_reqs: {len(retracted_reqs)}, "
                f"#new_token_ratio: {previous_ratio:.4f} -> {self.new_token_ratio:.4f}"
            )
            self.waiting_queue.extend(retracted_reqs)
        else:
            # Decay the ratio by a fixed step, never below min_new_token_ratio.
            self.new_token_ratio = max(
                self.new_token_ratio - self.new_token_ratio_decay,
                self.min_new_token_ratio,
            )

        # Check for jump-forward
        if not self.disable_jump_forward and batch.has_grammar:
            self.waiting_queue.extend(
                batch.check_for_jump_forward(self.pad_input_ids_func)
            )
            if batch.is_empty():
                self.batch_is_full = False
                return None

        # The batch shrank, so there may be room for new requests again.
        if batch.batch_size() < size_before:
            self.batch_is_full = False

        # Update batch tensors
        batch.prepare_for_decode()
        return batch
Lianmin Zheng's avatar
Lianmin Zheng committed
1107

1108
1109
1110
    def run_batch(
        self, batch: ScheduleBatch
    ) -> Union[GenerationBatchResult, EmbeddingBatchResult]:
        """Run one forward pass for ``batch`` and wrap its outputs in a result object.

        Generation models go through either the TP worker or, when speculative
        decoding is enabled, the draft worker. The draft path also updates the
        speculative and generated-token counters. Embedding/reward models return
        an EmbeddingBatchResult.
        """
        self.forward_ct += 1

        if not self.is_generation:
            # Embedding or reward model.
            worker_batch = batch.get_model_worker_batch()
            embeddings = self.tp_worker.forward_batch_embedding(worker_batch)
            return EmbeddingBatchResult(embeddings=embeddings, bid=worker_batch.bid)

        if self.spec_algorithm.is_none():
            worker_batch = batch.get_model_worker_batch()
            logits_output, next_token_ids = self.tp_worker.forward_batch_generation(
                worker_batch
            )
        else:
            (
                logits_output,
                next_token_ids,
                worker_batch,
                num_accepted_tokens,
            ) = self.draft_worker.forward_batch_speculative_generation(batch)
            num_reqs = batch.batch_size()
            self.spec_num_total_accepted_tokens += num_accepted_tokens + num_reqs
            self.spec_num_total_forward_ct += num_reqs
            self.num_generated_tokens += num_accepted_tokens

        batch.output_ids = next_token_ids
        return GenerationBatchResult(
            logits_output=logits_output,
            next_token_ids=next_token_ids,
            bid=worker_batch.bid,
        )
Chayenne's avatar
Chayenne committed
1146

1147
1148
1149
1150
1151
    def process_batch_result(
        self,
        batch: ScheduleBatch,
        result: Union[GenerationBatchResult, EmbeddingBatchResult],
    ):
        """Route a forward result to the handler for the batch's forward mode.

        Decode and extend results go to their dedicated handlers. Idle batches
        (overlap mode only) and dummy-first batches only finalize the next batch's
        sampling info: they update the regex vocab mask, synchronize the current
        stream, and signal ``sampling_info_done``.
        """
        mode = batch.forward_mode

        if mode.is_decode():
            self.process_batch_result_decode(batch, result)
            if batch.is_empty():
                self.running_batch = None
            return

        if mode.is_extend():
            self.process_batch_result_prefill(batch, result)
            return

        if mode.is_idle():
            if not self.enable_overlap:
                return
            self.tp_worker.resolve_batch_result(result.bid)
            sampling_info = batch.next_batch_sampling_info
            if sampling_info:
                sampling_info.update_regex_vocab_mask()
                self.current_stream.synchronize()
                sampling_info.sampling_info_done.set()
            return

        if mode.is_dummy_first():
            sampling_info = batch.next_batch_sampling_info
            sampling_info.update_regex_vocab_mask()
            self.current_stream.synchronize()
            sampling_info.sampling_info_done.set()
Lianmin Zheng's avatar
Lianmin Zheng committed
1169

1170
1171
1172
1173
1174
    def process_batch_result_prefill(
        self,
        batch: ScheduleBatch,
        result: Union[GenerationBatchResult, EmbeddingBatchResult],
    ):
        """Post-process the output of a prefill (extend) forward pass, then stream.

        Generation models: each request whose prefill is complete gets these updates:
          * the sampled token is appended and finish conditions are checked;
          * the request is cached in the tree cache as finished or unfinished;
          * logprobs, hidden states and grammar state are updated.

        Embedding/reward models: each request's embedding is stored. Requests whose
        prefill is complete also get a dummy output token and are cached.

        A request still being chunked only has its chunk counter decremented. In
        the generation path it is also excluded from streaming.
        """
        skip_stream_req = None

        if self.is_generation:
            logits_output = result.logits_output
            next_token_ids = result.next_token_ids

            if self.enable_overlap:
                logits_output, next_token_ids = self.tp_worker.resolve_batch_result(
                    result.bid
                )
            else:
                # Move next_token_ids and logprobs to cpu
                next_token_ids = next_token_ids.tolist()
                if batch.return_logprob:
                    logits_output.next_token_logprobs = (
                        logits_output.next_token_logprobs.tolist()
                    )
                    logits_output.input_token_logprobs = (
                        logits_output.input_token_logprobs.tolist()
                    )

            hidden_state_offset = 0
            logprob_pt = 0
            for i, (req, next_token_id) in enumerate(zip(batch.reqs, next_token_ids)):
                if req.is_retracted:
                    continue

                if self.is_mixed_chunk and self.enable_overlap and req.finished():
                    # Free the one delayed token for the mixed decode batch
                    j = len(batch.out_cache_loc) - len(batch.reqs) + i
                    self.token_to_kv_pool.free(batch.out_cache_loc[j : j + 1])
                    continue

                if req.is_being_chunked > 0:
                    # This request's prefill is not finished yet. At most one request
                    # is chunked at a time, and it must not be streamed.
                    # NOTE(review): logprob_pt is not advanced here. If a chunked
                    # request returns logprobs and precedes others in the batch,
                    # later logprob slices may be misaligned — confirm.
                    req.is_being_chunked -= 1
                    skip_stream_req = req
                    continue

                req.output_ids.append(next_token_id)
                req.check_finished()

                if req.finished():
                    self.tree_cache.cache_finished_req(req)
                elif not batch.decoding_reqs or req not in batch.decoding_reqs:
                    self.tree_cache.cache_unfinished_req(req)

                if req.return_logprob:
                    logprob_pt += self.add_logprob_return_values(
                        i, req, logprob_pt, next_token_ids, logits_output
                    )

                if (
                    req.sampling_params.return_hidden_states
                    and logits_output.hidden_states is not None
                ):
                    # Hidden states of consecutive requests are sliced back to back,
                    # len(origin_input_ids) rows per request.
                    start = hidden_state_offset
                    hidden_state_offset = start + len(req.origin_input_ids)
                    req.hidden_states.append(
                        logits_output.hidden_states[start:hidden_state_offset]
                        .cpu()
                        .clone()
                    )

                if req.grammar is not None:
                    req.grammar.accept_token(next_token_id)
                    req.grammar.finished = req.finished()

            sampling_info = batch.next_batch_sampling_info
            if sampling_info:
                sampling_info.update_regex_vocab_mask()
                self.current_stream.synchronize()
                sampling_info.sampling_info_done.set()

        else:  # embedding or reward model
            embeddings = result.embeddings.tolist()

            for i, req in enumerate(batch.reqs):
                if req.is_retracted:
                    continue

                req.embedding = embeddings[i]

                if req.is_being_chunked > 0:
                    # being chunked reqs' prefill is not finished
                    req.is_being_chunked -= 1
                    continue

                # Dummy output token for embedding models
                req.output_ids.append(0)
                req.check_finished()
                if req.finished():
                    self.tree_cache.cache_finished_req(req)
                else:
                    self.tree_cache.cache_unfinished_req(req)

        self.stream_output(batch.reqs, batch.return_logprob, skip_stream_req)
1283

1284
1285
1286
1287
1288
1289
1290
1291
1292
1293
    def process_batch_result_decode(
        self,
        batch: ScheduleBatch,
        result: GenerationBatchResult,
    ):
        """Post-process the output of one decode step, stream it, and maybe log stats.

        For every live request in the batch:
          * append the sampled token (non-speculative only; the speculative worker
            handles output_ids itself);
          * check finish conditions and cache finished requests;
          * record logprobs and hidden states when requested;
          * advance grammar state.

        In overlap mode, requests that had already finished only have their one
        delayed KV token freed. Decode stats are logged every
        ``decode_log_interval`` steps on attention-TP rank 0.
        """
        logits_output = result.logits_output
        next_token_ids = result.next_token_ids
        bid = result.bid

        self.num_generated_tokens += len(batch.reqs)

        if self.enable_overlap:
            logits_output, next_token_ids = self.tp_worker.resolve_batch_result(bid)
            next_token_logprobs = logits_output.next_token_logprobs
        else:
            next_token_ids = next_token_ids.tolist()
            if batch.return_logprob:
                next_token_logprobs = logits_output.next_token_logprobs.tolist()

        kv_pool = self.token_to_kv_pool
        # The frees issued in the loop below are bracketed by
        # free_group_begin/free_group_end (presumably batched into one release).
        kv_pool.free_group_begin()

        # Check finish condition
        for i, (req, next_token_id) in enumerate(zip(batch.reqs, next_token_ids)):
            if req.is_retracted:
                continue

            if self.enable_overlap and req.finished():
                # Free the one delayed token
                kv_pool.free(batch.out_cache_loc[i : i + 1])
                continue

            if batch.spec_algorithm.is_none():
                # speculative worker will solve the output_ids in speculative decoding
                req.output_ids.append(next_token_id)

            req.check_finished()
            if req.finished():
                self.tree_cache.cache_finished_req(req)

            if req.return_logprob:
                req.output_token_logprobs_val.append(next_token_logprobs[i])
                req.output_token_logprobs_idx.append(next_token_id)
                if req.top_logprobs_num > 0:
                    req.output_top_logprobs_val.append(
                        logits_output.next_token_top_logprobs_val[i]
                    )
                    req.output_top_logprobs_idx.append(
                        logits_output.next_token_top_logprobs_idx[i]
                    )

            wants_hidden = req.sampling_params.return_hidden_states
            if wants_hidden and logits_output.hidden_states is not None:
                req.hidden_states.append(logits_output.hidden_states[i].cpu().clone())

            if req.grammar is not None:
                req.grammar.accept_token(next_token_id)
                req.grammar.finished = req.finished()

        sampling_info = batch.next_batch_sampling_info
        if sampling_info:
            sampling_info.update_regex_vocab_mask()
            self.current_stream.synchronize()
            sampling_info.sampling_info_done.set()

        self.stream_output(batch.reqs, batch.return_logprob)

        kv_pool.free_group_end()

        # The counter wraps at 2**30 so it stays bounded.
        self.forward_ct_decode = (self.forward_ct_decode + 1) % (1 << 30)
        should_log = (
            self.attn_tp_rank == 0
            and self.forward_ct_decode % self.server_args.decode_log_interval == 0
        )
        if should_log:
            self.log_decode_stats()
1361

1362
1363
1364
1365
1366
1367
1368
1369
1370
    def add_logprob_return_values(
        self,
        i: int,
        req: Req,
        pt: int,
        next_token_ids: List[int],
        output: LogitsProcessorOutput,
    ):
        """Attach logprobs of a forward pass to ``req``'s return buffers.

        The sampled next token's logprob is always appended to the output
        buffers. Input (prompt) logprobs are filled in only the first time, while
        ``req.input_token_logprobs_val`` is still ``None``. Decode tokens that
        were re-computed in this extend batch (``req.last_update_decode_tokens``)
        have their logprobs appended to the output buffers. Top-k logprobs are
        handled the same way when ``req.top_logprobs_num > 0``.

        Args:
            i: Index of ``req`` in the batch; selects per-request rows of
                ``output`` and ``next_token_ids``.
            req: Request whose logprob buffers are updated in place.
            pt: Offset of this request's first entry in the flattened
                ``output.input_token_logprobs``.
            next_token_ids: Sampled next-token ids for the whole batch.
            output: Output of the logits processor for this batch.

        Returns:
            ``req.extend_input_len - req.extend_logprob_start_len``, the number
            of input-logprob entries of this request (presumably used by the
            caller to advance ``pt`` — confirm at the call site).
        """
        req.output_token_logprobs_val.append(output.next_token_logprobs[i])
        req.output_token_logprobs_idx.append(next_token_ids[i])

        # If logprob_start_len > 0, then first logprob_start_len prompt tokens will be ignored.
        num_input_logprobs = req.extend_input_len - req.extend_logprob_start_len
        # Decode tokens re-computed in this extend batch; they sit at the tail of
        # this request's input logprobs and of fill_ids.
        num_recomputed = req.last_update_decode_tokens
        # Exclusive end of this request's slice in output.input_token_logprobs.
        input_end = pt + num_input_logprobs - 1
        num_fill = len(req.fill_ids)

        if req.input_token_logprobs_val is None:
            input_token_logprobs_val = output.input_token_logprobs[
                pt : input_end - num_recomputed
            ]
            input_token_logprobs_idx = req.fill_ids[
                num_fill - num_input_logprobs + 1 : num_fill - num_recomputed
            ]
            # Clip the padded hash values from image tokens.
            # Otherwise, it will lead to detokenization errors.
            # NOTE(review): this also maps id vocab_size - 1 to 0; confirm whether
            # `<= vocab_size - 1` was intended.
            input_token_logprobs_idx = [
                x if x < self.model_config.vocab_size - 1 else 0
                for x in input_token_logprobs_idx
            ]

            if req.logprob_start_len == 0:
                # The first token does not have logprob, pad it.
                input_token_logprobs_val = [None] + input_token_logprobs_val
                input_token_logprobs_idx = [req.fill_ids[0]] + input_token_logprobs_idx

            req.input_token_logprobs_val = input_token_logprobs_val
            req.input_token_logprobs_idx = input_token_logprobs_idx

        if num_recomputed != 0:
            # Some decode tokens are re-computed in an extend batch
            req.output_token_logprobs_val.extend(
                output.input_token_logprobs[input_end - num_recomputed : input_end],
            )
            req.output_token_logprobs_idx.extend(
                req.fill_ids[num_fill - num_recomputed : num_fill]
            )

        if req.top_logprobs_num > 0:
            if req.input_top_logprobs_val is None:
                req.input_top_logprobs_val = output.input_top_logprobs_val[i]
                req.input_top_logprobs_idx = output.input_top_logprobs_idx[i]
                if req.logprob_start_len == 0:
                    # As for the token logprobs, the first prompt token has none.
                    req.input_top_logprobs_val = [None] + req.input_top_logprobs_val
                    req.input_top_logprobs_idx = [None] + req.input_top_logprobs_idx

            if num_recomputed != 0:
                req.output_top_logprobs_val.extend(
                    output.input_top_logprobs_val[i][-num_recomputed:]
                )
                req.output_top_logprobs_idx.extend(
                    output.input_top_logprobs_idx[i][-num_recomputed:]
                )

            req.output_top_logprobs_val.append(output.next_token_top_logprobs_val[i])
            req.output_top_logprobs_idx.append(output.next_token_top_logprobs_idx[i])

        return num_input_logprobs

    def stream_output(
        self, reqs: List[Req], return_logprob: bool, skip_req: Optional[Req] = None
    ):
        """Stream the output to detokenizer.

        Generation models: a request is emitted when it has finished, when it is
        streaming and ``len(req.output_ids)`` is a multiple of
        ``self.stream_interval``, or when it is not streaming and that length is
        a multiple of 50. All emitted requests are packed into one
        ``BatchTokenIDOut``. ``skip_req`` is left out of this round.

        Embedding / reward models: only finished requests are emitted, packed
        into one ``BatchEmbeddingOut``.

        In both cases nothing is sent when no request qualifies.

        Args:
            reqs: Requests of the batch that was just processed.
            return_logprob: Whether logprob fields are collected for this batch.
            skip_req: Request excluded from this round (generation only).
        """
        rids = []
        finished_reasons: List[BaseFinishReason] = []

        if self.is_generation:
            vids = []
            decoded_texts = []
            decode_ids_list = []
            read_offsets = []
            output_ids = []

            skip_special_tokens = []
            spaces_between_special_tokens = []
            no_stop_trim = []
            prompt_tokens = []
            completion_tokens = []
            cached_tokens = []
            spec_verify_ct = []
            return_hidden_states = any(
                req.sampling_params.return_hidden_states for req in reqs
            )
            output_hidden_states = [] if return_hidden_states else None

            if return_logprob:
                input_token_logprobs_val = []
                input_token_logprobs_idx = []
                output_token_logprobs_val = []
                output_token_logprobs_idx = []
                input_top_logprobs_val = []
                input_top_logprobs_idx = []
                output_top_logprobs_val = []
                output_top_logprobs_idx = []
            else:
                input_token_logprobs_val = input_token_logprobs_idx = None
                output_token_logprobs_val = output_token_logprobs_idx = None
                input_top_logprobs_val = input_top_logprobs_idx = None
                output_top_logprobs_val = output_top_logprobs_idx = None

            for req in reqs:
                if req is skip_req:
                    continue

                # TODO(lianmin): revisit this for overlap + retract + stream
                if (
                    req.finished()
                    # If stream, follow the given stream_interval
                    or (req.stream and len(req.output_ids) % self.stream_interval == 0)
                    # If not stream, we still want to output some tokens to get the benefit of incremental decoding.
                    or (not req.stream and len(req.output_ids) % 50 == 0)
                ):
                    if self.draft_worker and req.finished():
                        self.draft_worker.finish_request(req)

                    rids.append(req.rid)
                    finished_reasons.append(
                        req.finished_reason.to_json() if req.finished_reason else None
                    )
                    vids.append(req.vid)
                    decoded_texts.append(req.decoded_text)
                    decode_ids, read_offset = req.init_incremental_detokenize()
                    decode_ids_list.append(decode_ids)
                    read_offsets.append(read_offset)
                    if self.skip_tokenizer_init:
                        output_ids.append(req.output_ids)
                    skip_special_tokens.append(req.sampling_params.skip_special_tokens)
                    spaces_between_special_tokens.append(
                        req.sampling_params.spaces_between_special_tokens
                    )
                    no_stop_trim.append(req.sampling_params.no_stop_trim)

                    prompt_tokens.append(len(req.origin_input_ids))
                    completion_tokens.append(len(req.output_ids))
                    cached_tokens.append(req.cached_tokens)

                    if not self.spec_algorithm.is_none():
                        spec_verify_ct.append(req.spec_verify_ct)

                    if return_logprob:
                        input_token_logprobs_val.append(req.input_token_logprobs_val)
                        input_token_logprobs_idx.append(req.input_token_logprobs_idx)
                        output_token_logprobs_val.append(req.output_token_logprobs_val)
                        output_token_logprobs_idx.append(req.output_token_logprobs_idx)
                        input_top_logprobs_val.append(req.input_top_logprobs_val)
                        input_top_logprobs_idx.append(req.input_top_logprobs_idx)
                        output_top_logprobs_val.append(req.output_top_logprobs_val)
                        output_top_logprobs_idx.append(req.output_top_logprobs_idx)

                    if return_hidden_states:
                        # Keep output_hidden_states index-aligned with rids: a
                        # request that did not ask for hidden states gets an empty
                        # placeholder instead of being skipped.
                        output_hidden_states.append(
                            req.hidden_states
                            if req.sampling_params.return_hidden_states
                            else []
                        )

            # Send to detokenizer
            if rids:
                self.send_to_detokenizer.send_pyobj(
                    BatchTokenIDOut(
                        rids,
                        finished_reasons,
                        vids,
                        decoded_texts,
                        decode_ids_list,
                        read_offsets,
                        output_ids,
                        skip_special_tokens,
                        spaces_between_special_tokens,
                        no_stop_trim,
                        prompt_tokens,
                        completion_tokens,
                        cached_tokens,
                        spec_verify_ct,
                        input_token_logprobs_val,
                        input_token_logprobs_idx,
                        output_token_logprobs_val,
                        output_token_logprobs_idx,
                        input_top_logprobs_val,
                        input_top_logprobs_idx,
                        output_top_logprobs_val,
                        output_top_logprobs_idx,
                        output_hidden_states,
                    )
                )
        else:  # embedding or reward model
            embeddings = []
            prompt_tokens = []
            for req in reqs:
                if req.finished():
                    rids.append(req.rid)
                    finished_reasons.append(req.finished_reason.to_json())
                    embeddings.append(req.embedding)
                    prompt_tokens.append(len(req.origin_input_ids))
            # Like the generation path, do not send an empty batch.
            if rids:
                self.send_to_detokenizer.send_pyobj(
                    BatchEmbeddingOut(rids, finished_reasons, embeddings, prompt_tokens)
                )

    def prepare_dp_attn_batch(self, local_batch: ScheduleBatch):
        """Share per-rank token counts over ``self.tp_cpu_group`` and fix up the batch.

        The local token count is the batch size for decode batches,
        ``extend_num_tokens`` otherwise, and 0 when there is no batch. It is
        all-gathered into ``global_num_tokens``. If this rank has no batch but
        some rank has tokens, an idle batch is created (presumably so this rank
        still joins the collectives of the forward pass). When CUDA graphs are
        enabled, a MIN all-reduce decides whether every rank is in decode/idle
        mode. The result is stored in ``local_batch.can_run_dp_cuda_graph``.

        Args:
            local_batch: This rank's batch, or ``None`` if it has nothing to run.

        Returns:
            The batch, a newly created idle batch, or ``None`` when no rank has
            any tokens.
        """
        # Check if other DP workers have running batches
        if local_batch is None:
            num_tokens = 0
        elif local_batch.forward_mode.is_decode():
            num_tokens = local_batch.batch_size()
        else:
            num_tokens = local_batch.extend_num_tokens

        local_num_tokens = torch.tensor([num_tokens], dtype=torch.int64)
        global_num_tokens = torch.empty(self.tp_size, dtype=torch.int64)
        torch.distributed.all_gather_into_tensor(
            global_num_tokens,
            local_num_tokens,
            group=self.tp_cpu_group,
        )

        if local_batch is None and global_num_tokens.max().item() > 0:
            local_batch = self.get_idle_batch()

        if local_batch is not None:
            local_batch.global_num_tokens = global_num_tokens.tolist()

            # Check forward mode for cuda graph
            if not self.server_args.disable_cuda_graph:
                forward_mode_state = torch.tensor(
                    (1 if local_batch.forward_mode.is_decode_or_idle() else 0),
                    dtype=torch.int32,
                )
                # MIN over ranks: 1 only if every rank is decode/idle.
                torch.distributed.all_reduce(
                    forward_mode_state,
                    op=torch.distributed.ReduceOp.MIN,
                    group=self.tp_cpu_group,
                )
                local_batch.can_run_dp_cuda_graph = forward_mode_state.item() == 1

        return local_batch

    def get_idle_batch(self):
        """Build an empty ``ScheduleBatch`` and prepare it for an idle forward pass.

        The batch shares this scheduler's memory pools, tree cache, model config,
        overlap flag, speculative algorithm and custom-logit-processor setting.
        """
        idle_batch = ScheduleBatch.init_new(
            [],
            self.req_to_token_pool,
            self.token_to_kv_pool,
            self.tree_cache,
            self.model_config,
            self.enable_overlap,
            self.spec_algorithm,
            self.server_args.enable_custom_logit_processor,
        )
        idle_batch.prepare_for_idle()
        return idle_batch

    def move_ready_grammar_requests(self):
        """Move requests whose grammar objects are ready from grammar_queue to waiting_queue.

        Each ``req.grammar`` in the queue is a future. The queue is scanned in
        order, and the scan stops at the first grammar that is not ready within
        0.05 s. If more than one rank takes part (the attention-TP group under DP
        attention, the TP group otherwise), the ranks agree on the MAX ready
        count. This rank then blocks on any of those grammars it has not finished
        yet, so every rank moves the same prefix of the queue.
        """
        num_ready_reqs = 0
        for req in self.grammar_queue:
            try:
                req.grammar = req.grammar.result(timeout=0.05)
                num_ready_reqs += 1
            except futures.TimeoutError:
                break

        dp_attention = self.server_args.enable_dp_attention
        sync_size = self.attn_tp_size if dp_attention else self.tp_size
        if sync_size > 1:
            # Sync across ranks to make sure they have the same number of ready requests
            sync_group = self.attn_tp_cpu_group if dp_attention else self.tp_cpu_group
            tensor = torch.tensor(num_ready_reqs, dtype=torch.int32)
            torch.distributed.all_reduce(
                tensor, op=torch.distributed.ReduceOp.MAX, group=sync_group
            )
            num_ready_reqs_max = tensor.item()
            # Wait for grammars that another rank has already finished.
            for i in range(num_ready_reqs, num_ready_reqs_max):
                self.grammar_queue[i].grammar = self.grammar_queue[i].grammar.result()
            num_ready_reqs = num_ready_reqs_max

        self.waiting_queue.extend(self.grammar_queue[:num_ready_reqs])
        self.grammar_queue = self.grammar_queue[num_ready_reqs:]

    def flush_cache_wrapped(self, recv_req: FlushCacheReq):
        """Handle a ``FlushCacheReq`` by calling ``flush_cache``; its result is discarded."""
        self.flush_cache()

    def flush_cache(self):
        """Flush the memory pool and cache.

        The flush happens only when no request is waiting or running. It resets
        the tree cache and its metrics, and the grammar backend if there is one.
        It clears the request-to-token and KV pools, including the draft worker's
        pools when speculative decoding is on. It then resets the generation and
        speculation counters and calls ``torch.cuda.empty_cache()``.

        Returns:
            True if the cache was flushed, False if pending requests prevented it.
        """
        num_waiting = len(self.waiting_queue)
        num_running = 0 if self.running_batch is None else len(self.running_batch.reqs)
        if num_waiting > 0 or num_running > 0:
            # Use the module logger (not the root logger) like the success path.
            logger.warning(
                f"Cache not flushed because there are pending requests. "
                f"#queue-req: {num_waiting}, "
                f"#running-req: {num_running}"
            )
            return False

        self.tree_cache.reset()
        self.tree_cache_metrics = {"total": 0, "hit": 0}
        if self.grammar_backend:
            self.grammar_backend.reset()
        self.req_to_token_pool.clear()
        self.token_to_kv_pool.clear()

        if not self.spec_algorithm.is_none():
            self.draft_worker.model_runner.req_to_token_pool.clear()
            self.draft_worker.model_runner.token_to_kv_pool.clear()

        self.num_generated_tokens = 0
        self.forward_ct_decode = 0
        self.spec_num_total_accepted_tokens = 0
        self.spec_num_total_forward_ct = 0
        torch.cuda.empty_cache()
        logger.info("Cache flushed successfully!")
        return True

    def abort_request(self, recv_req: AbortReq):
        """Abort the request whose rid equals ``recv_req.rid``.

        A matching request in the waiting queue is removed immediately. A
        matching unfinished request in the running batch is only flagged with
        ``req.to_abort = True``. The flag is presumably handled later by the
        scheduling loop.
        """
        # Delete requests in the waiting queue
        for i, req in enumerate(self.waiting_queue):
            if req.rid == recv_req.rid:
                # Safe to delete while iterating: we return right away.
                del self.waiting_queue[i]
                logger.debug(f"Abort queued request. {req.rid=}")
                return

        # Flag the matching request in the running batch
        if self.running_batch:
            for req in self.running_batch.reqs:
                if req.rid == recv_req.rid and not req.finished():
                    logger.debug(f"Abort running request. {req.rid=}")
                    req.to_abort = True
                    break

    def update_weights_from_disk(self, recv_req: UpdateWeightFromDiskReqInput):
        """In-place update of the weights from disk.

        On success the cache is flushed (presumably because cached entries were
        computed with the old weights), and the flush is asserted to succeed. On
        failure the worker's message is logged.

        Returns:
            ``UpdateWeightFromDiskReqOutput(success, message)``.
        """
        success, message = self.tp_worker.update_weights_from_disk(recv_req)
        if success:
            flush_cache_success = self.flush_cache()
            assert flush_cache_success, "Cache flush failed after updating weights"
        else:
            logger.error(message)
        return UpdateWeightFromDiskReqOutput(success, message)

    def init_weights_update_group(self, recv_req: InitWeightsUpdateGroupReqInput):
        """Initialize the online model parameter update group.

        Delegates to the TP worker and wraps its ``(success, message)`` result.
        """
        success, message = self.tp_worker.init_weights_update_group(recv_req)
        return InitWeightsUpdateGroupReqOutput(success, message)

    def update_weights_from_distributed(
        self,
        recv_req: UpdateWeightsFromDistributedReqInput,
    ) -> UpdateWeightsFromDistributedReqOutput:
        """Update the online model parameter.

        On success the cache is flushed and the flush is asserted to succeed. On
        failure the worker's message is logged.

        Returns:
            ``UpdateWeightsFromDistributedReqOutput(success, message)``. The
            previous ``Tuple[bool, str]`` annotation did not match this.
        """
        success, message = self.tp_worker.update_weights_from_distributed(recv_req)
        if success:
            flush_cache_success = self.flush_cache()
            assert flush_cache_success, "Cache flush failed after updating weights"
        else:
            logger.error(message)
        return UpdateWeightsFromDistributedReqOutput(success, message)

    def update_weights_from_tensor(self, recv_req: UpdateWeightsFromTensorReqInput):
        """Update the online model parameter from tensors.

        Unlike the other weight-update handlers, the cache is flushed only when
        ``recv_req.flush_cache`` is set. On failure the worker's message is
        logged.

        Returns:
            ``UpdateWeightsFromTensorReqOutput(success, message)``.
        """
        success, message = self.tp_worker.update_weights_from_tensor(recv_req)
        # TODO extract common code b/t update_weights_from_distributed and update_weights_from_tensor later
        if success:
            if recv_req.flush_cache:
                flush_cache_success = self.flush_cache()
                assert flush_cache_success, "Cache flush failed after updating weights"
        else:
            logger.error(message)
        return UpdateWeightsFromTensorReqOutput(success, message)

    def get_weights_by_name(self, recv_req: GetWeightsByNameReqInput):
        """Fetch a parameter from the TP worker and wrap it in ``GetWeightsByNameReqOutput``."""
        parameter = self.tp_worker.get_weights_by_name(recv_req)
        return GetWeightsByNameReqOutput(parameter)

    def release_memory_occupation(self):
        """Snapshot model buffers, pause the memory saver and flush the cache.

        The snapshot is kept in ``self.stashed_model_static_state`` so that
        ``resume_memory_occupation`` can restore it.
        """
        self.stashed_model_static_state = _export_static_state(
            self.tp_worker.worker.model_runner.model
        )
        self.memory_saver_adapter.pause()
        self.flush_cache()
        return ReleaseMemoryOccupationReqOutput()

    def resume_memory_occupation(self):
        """Resume the memory saver and restore the buffers stashed by a release.

        Must follow ``release_memory_occupation``. The stashed snapshot is
        deleted afterwards, so a second resume raises ``AttributeError``.
        """
        self.memory_saver_adapter.resume()
        _import_static_state(
            self.tp_worker.worker.model_runner.model, self.stashed_model_static_state
        )
        del self.stashed_model_static_state
        return ResumeMemoryOccupationReqOutput()

    def profile(self, recv_req: ProfileReq):
        """Start the profiler for ``ProfileReq.START_PROFILE``; stop it for any other value."""
        if recv_req == ProfileReq.START_PROFILE:
            self.start_profile()
        else:
            self.stop_profile()

    def start_profile(self) -> None:
        """Begin recording with the configured profiler.

        Raises:
            RuntimeError: If no profiler was configured.
        """
        active_profiler = self.profiler
        if active_profiler is not None:
            active_profiler.start()
            return
        raise RuntimeError("Profiler is not enabled.")

    def stop_profile(self) -> None:
        """Stop the profiler and export its Chrome trace.

        The trace is written to ``self.torch_profiler_trace_dir`` as
        ``<time.time()>.trace.json.gz``.

        Raises:
            RuntimeError: If no profiler was configured.
        """
        if self.profiler is None:
            raise RuntimeError("Profiler is not enabled.")
        self.profiler.stop()
        self.profiler.export_chrome_trace(
            os.path.join(
                self.torch_profiler_trace_dir, str(time.time()) + ".trace.json.gz"
            )
        )
        logger.info("Profiler is done")

    def open_session(self, recv_req: OpenSessionReqInput):
        """Register a new ``Session`` under ``recv_req.session_id``.

        Returns:
            ``OpenSessionReqOutput(session_id, True)`` on success, or
            ``OpenSessionReqOutput(session_id, False)`` if the id already exists
            or is ``None`` (a warning is logged in both failure cases).
        """
        session_id = recv_req.session_id
        if session_id in self.sessions:
            logger.warning(f"session id {session_id} already exist, cannot open.")
            return OpenSessionReqOutput(session_id, False)
        if session_id is None:
            logger.warning("session id is None, cannot open.")
            return OpenSessionReqOutput(session_id, False)
        self.sessions[session_id] = Session(recv_req.capacity_of_str_len, session_id)
        return OpenSessionReqOutput(session_id, True)

    def close_session(self, recv_req: CloseSessionReqInput):
        """Remove the session ``recv_req.session_id``; log a warning if it does not exist."""
        session_id = recv_req.session_id
        if session_id not in self.sessions:
            logger.warning(f"session id {session_id} does not exist, cannot delete.")
        else:
            del self.sessions[session_id]


def _export_static_state(model):
    return dict(
        buffers=[
            (name, buffer.detach().clone()) for name, buffer in model.named_buffers()
        ]
    )


def _import_static_state(model, static_params):
    self_named_buffers = dict(model.named_buffers())
    for name, tensor in static_params["buffers"]:
        self_named_buffers[name][...] = tensor


1848
1849
1850
1851
1852
def run_scheduler_process(
    server_args: ServerArgs,
    port_args: PortArgs,
    gpu_id: int,
    tp_rank: int,
    dp_rank: Optional[int],
    pipe_writer,
):
    """Entry point of a scheduler subprocess.

    The function:

    - sets the process title and enables ``faulthandler``;
    - configures logging and, if ``SGLANG_SET_CPU_AFFINITY`` is set, CPU affinity;
    - builds a ``Scheduler`` and sends a ``"ready"`` status dict on
      ``pipe_writer``;
    - runs the overlap or normal event loop.

    Any exception is logged and the parent process is sent ``SIGQUIT``.

    Args:
        server_args: Server configuration.
        port_args: Port configuration.
        gpu_id: GPU used by this process.
        tp_rank: Tensor-parallel rank.
        dp_rank: Data-parallel rank, or ``None``. When ``None``, the
            ``SGLANG_DP_RANK`` env var is used if it is set.
        pipe_writer: Object with a ``send`` method; it receives
            ``{"status": "ready", "max_total_num_tokens": ..., "max_req_input_len": ...}``.
    """
    setproctitle.setproctitle("sglang::scheduler")
    faulthandler.enable()

    # [For Router] if env var "SGLANG_DP_RANK" exist, set dp_rank to the value of the env var
    if dp_rank is None and "SGLANG_DP_RANK" in os.environ:
        dp_rank = int(os.environ["SGLANG_DP_RANK"])

    # Configure the logger
    if dp_rank is None:
        configure_logger(server_args, prefix=f" TP{tp_rank}")
    else:
        configure_logger(server_args, prefix=f" DP{dp_rank} TP{tp_rank}")
    suppress_other_loggers()

    # Set cpu affinity to this gpu process
    if get_bool_env_var("SGLANG_SET_CPU_AFFINITY"):
        set_gpu_proc_affinity(server_args.tp_size, server_args.nnodes, gpu_id)

    parent_process = psutil.Process().parent()

    # Create a scheduler and run the event loop
    try:
        scheduler = Scheduler(server_args, port_args, gpu_id, tp_rank, dp_rank)
        pipe_writer.send(
            {
                "status": "ready",
                "max_total_num_tokens": scheduler.max_total_num_tokens,
                "max_req_input_len": scheduler.max_req_input_len,
            }
        )
        if scheduler.enable_overlap:
            scheduler.event_loop_overlap()
        else:
            scheduler.event_loop_normal()
    except Exception:
        # Named to avoid shadowing the stdlib `traceback` module.
        exc_traceback = get_exception_traceback()
        logger.error(f"Scheduler hit an exception: {exc_traceback}")
        parent_process.send_signal(signal.SIGQUIT)