# Copyright 2023-2024 SGLang Team
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""A scheduler that manages a tensor parallel GPU worker."""

import logging
import os
import signal
import threading
import time
import warnings
from collections import deque
from concurrent import futures
from types import SimpleNamespace
from typing import Callable, Dict, List, Optional, Tuple

import psutil
import setproctitle
import torch
import zmq

from sglang.global_config import global_config
from sglang.srt.configs.model_config import ModelConfig
from sglang.srt.hf_transformers_utils import get_processor, get_tokenizer
from sglang.srt.layers.logits_processor import LogitsProcessorOutput
from sglang.srt.managers.io_struct import (
    AbortReq,
    BatchEmbeddingOut,
    BatchTokenIDOut,
    CloseSessionReqInput,
    FlushCacheReq,
    GetWeightsByNameReqInput,
    GetWeightsByNameReqOutput,
    InitWeightsUpdateGroupReqInput,
    InitWeightsUpdateGroupReqOutput,
    OpenSessionReqInput,
    OpenSessionReqOutput,
    ProfileReq,
    TokenizedEmbeddingReqInput,
    TokenizedGenerateReqInput,
    UpdateWeightFromDiskReqInput,
    UpdateWeightFromDiskReqOutput,
    UpdateWeightsFromDistributedReqInput,
    UpdateWeightsFromDistributedReqOutput,
)
from sglang.srt.managers.schedule_batch import (
    FINISH_ABORT,
    BaseFinishReason,
    ImageInputs,
    Req,
    ScheduleBatch,
    global_server_args_dict,
)
from sglang.srt.managers.schedule_policy import (
    AddReqResult,
    PrefillAdder,
    SchedulePolicy,
)
from sglang.srt.managers.session_controller import Session
from sglang.srt.managers.tp_worker import TpModelWorker
from sglang.srt.managers.tp_worker_overlap_thread import TpModelWorkerClient
from sglang.srt.mem_cache.chunk_cache import ChunkCache
from sglang.srt.mem_cache.radix_cache import RadixCache
from sglang.srt.metrics.collector import SchedulerMetricsCollector, SchedulerStats
from sglang.srt.model_executor.forward_batch_info import ForwardMode
from sglang.srt.server_args import PortArgs, ServerArgs
from sglang.srt.utils import (
    broadcast_pyobj,
    configure_logger,
    crash_on_warnings,
    get_bool_env_var,
    get_zmq_socket,
    set_gpu_proc_affinity,
    set_random_seed,
    suppress_other_loggers,
)
from sglang.utils import get_exception_traceback

logger = logging.getLogger(__name__)

# Test retract decode
test_retract = get_bool_env_var("SGLANG_TEST_RETRACT")


class Scheduler:
    """A scheduler that manages a tensor parallel GPU worker."""

    def __init__(
        self,
        server_args: ServerArgs,
        port_args: PortArgs,
        gpu_id: int,
        tp_rank: int,
        dp_rank: Optional[int],
    ):
        # Parse args
        self.server_args = server_args
        self.tp_rank = tp_rank
        self.tp_size = server_args.tp_size
        self.schedule_policy = server_args.schedule_policy
        self.disable_jump_forward = server_args.disable_jump_forward
        self.lora_paths = server_args.lora_paths
        self.max_loras_per_batch = server_args.max_loras_per_batch
        self.enable_overlap = not server_args.disable_overlap_schedule
        self.skip_tokenizer_init = server_args.skip_tokenizer_init
        self.enable_metrics = server_args.enable_metrics

        # Init inter-process communication
        context = zmq.Context(2)

        if self.tp_rank == 0 or self.server_args.enable_dp_attention:
            self.recv_from_tokenizer = get_zmq_socket(
                context, zmq.PULL, port_args.scheduler_input_ipc_name
            )
            self.send_to_tokenizer = get_zmq_socket(
                context, zmq.PUSH, port_args.tokenizer_ipc_name
            )

            if server_args.skip_tokenizer_init:
                # Directly send to the tokenizer/api
                self.send_to_detokenizer = get_zmq_socket(
                    context, zmq.PUSH, port_args.tokenizer_ipc_name
                )
            else:
                # Send to the detokenizer
                self.send_to_detokenizer = get_zmq_socket(
                    context, zmq.PUSH, port_args.detokenizer_ipc_name
                )
        else:
            self.recv_from_tokenizer = None
            self.send_to_tokenizer = SimpleNamespace(send_pyobj=lambda x: None)
            self.send_to_detokenizer = SimpleNamespace(send_pyobj=lambda x: None)

        # Init tokenizer
        self.model_config = ModelConfig(
            server_args.model_path,
            trust_remote_code=server_args.trust_remote_code,
            revision=server_args.revision,
            context_length=server_args.context_length,
            model_override_args=server_args.json_model_override_args,
            is_embedding=server_args.is_embedding,
            dtype=server_args.dtype,
            quantization=server_args.quantization,
        )
        self.is_generation = self.model_config.is_generation

        if server_args.skip_tokenizer_init:
            self.tokenizer = self.processor = None
        else:
            if self.model_config.is_multimodal:
                self.processor = get_processor(
                    server_args.tokenizer_path,
                    tokenizer_mode=server_args.tokenizer_mode,
                    trust_remote_code=server_args.trust_remote_code,
                )
                self.tokenizer = self.processor.tokenizer
            else:
                self.tokenizer = get_tokenizer(
                    server_args.tokenizer_path,
                    tokenizer_mode=server_args.tokenizer_mode,
                    trust_remote_code=server_args.trust_remote_code,
                )

        # Check whether overlap can be enabled
        if not self.is_generation:
            self.enable_overlap = False
            logger.info("Overlap scheduler is disabled for embedding models.")

        if self.model_config.is_multimodal:
            self.enable_overlap = False
            logger.info("Overlap scheduler is disabled for multimodal models.")

        if self.enable_overlap:
            self.disable_jump_forward = True

        # Launch a tensor parallel worker
        if self.enable_overlap:
            TpWorkerClass = TpModelWorkerClient
        else:
            TpWorkerClass = TpModelWorker

        self.tp_worker = TpWorkerClass(
            server_args=server_args,
            gpu_id=gpu_id,
            tp_rank=tp_rank,
            dp_rank=dp_rank,
            nccl_port=port_args.nccl_port,
        )

        # Get token and memory info from the model worker
        (
            self.max_total_num_tokens,
            self.max_prefill_tokens,
            self.max_running_requests,
            self.max_req_len,
            self.max_req_input_len,
            self.random_seed,
            self.device,
            worker_global_server_args_dict,
            _,
            _,
            _,
        ) = self.tp_worker.get_worker_info()
        self.tp_cpu_group = self.tp_worker.get_tp_cpu_group()
        self.pad_input_ids_func = self.tp_worker.get_pad_input_ids_func()
        global_server_args_dict.update(worker_global_server_args_dict)
        set_random_seed(self.random_seed)

        # Print debug info
        logger.info(
            f"max_total_num_tokens={self.max_total_num_tokens}, "
            f"max_prefill_tokens={self.max_prefill_tokens}, "
            f"max_running_requests={self.max_running_requests}, "
            f"context_len={self.model_config.context_len}"
        )

        # Init memory pool and cache
        self.req_to_token_pool, self.token_to_kv_pool = self.tp_worker.get_memory_pool()

        if (
            server_args.chunked_prefill_size is not None
            and server_args.disable_radix_cache
        ):
            self.tree_cache = ChunkCache(
                req_to_token_pool=self.req_to_token_pool,
                token_to_kv_pool=self.token_to_kv_pool,
            )
        else:
            self.tree_cache = RadixCache(
                req_to_token_pool=self.req_to_token_pool,
                token_to_kv_pool=self.token_to_kv_pool,
                disable=server_args.disable_radix_cache,
            )
        self.tree_cache_metrics = {"total": 0, "hit": 0}
        self.policy = SchedulePolicy(self.schedule_policy, self.tree_cache)

        # Init running status
        self.waiting_queue: List[Req] = []
        # The running decoding batch for continuous batching
        self.running_batch: Optional[ScheduleBatch] = None
        # The current forward batch
        self.cur_batch: Optional[ScheduleBatch] = None
        # The last forward batch
        self.last_batch: Optional[ScheduleBatch] = None
        self.forward_ct = 0
        self.forward_ct_decode = 0
        self.num_generated_tokens = 0
        self.last_decode_stats_tic = time.time()
        self.stream_interval = server_args.stream_interval
        self.current_stream = torch.get_device_module(self.device).current_stream()

        # Session info
        self.sessions: Dict[str, Session] = {}

        # Init chunked prefill
        self.chunked_prefill_size = server_args.chunked_prefill_size
        if self.chunked_prefill_size <= 0:  # -1 means disable
            self.chunked_prefill_size = None
        self.being_chunked_req = None
        self.is_mixed_chunk = (
            self.chunked_prefill_size is not None and server_args.enable_mixed_chunk
        )

        # Init the grammar backend for constrained generation
        self.grammar_queue: List[Req] = []
        if not server_args.skip_tokenizer_init:
            if server_args.grammar_backend == "outlines":
                from sglang.srt.constrained.outlines_backend import (
                    OutlinesGrammarBackend,
                )

                self.grammar_backend = OutlinesGrammarBackend(
                    self.tokenizer,
                    whitespace_pattern=server_args.constrained_json_whitespace_pattern,
                    allow_jump_forward=not server_args.disable_jump_forward,
                )
            elif server_args.grammar_backend == "xgrammar":
                from sglang.srt.constrained.xgrammar_backend import (
                    XGrammarGrammarBackend,
                )

                self.grammar_backend = XGrammarGrammarBackend(
                    self.tokenizer, vocab_size=self.model_config.vocab_size
                )
            else:
                raise ValueError(
                    f"Invalid grammar backend: {server_args.grammar_backend}"
                )
        else:
            self.grammar_backend = None

        # Init new token estimation
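        # The new-token ratio estimates how many decode tokens the running
        # requests will still need; it decays from init_new_token_ratio toward
        # min_new_token_ratio and controls how conservatively new prefills are admitted.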
        assert (
            server_args.schedule_conservativeness >= 0
        ), "Invalid schedule_conservativeness"

        self.init_new_token_ratio = min(
            global_config.default_init_new_token_ratio
            * server_args.schedule_conservativeness,
            1.0,
        )
        self.min_new_token_ratio = min(
            self.init_new_token_ratio
            * global_config.default_min_new_token_ratio_factor,
            1.0,
        )
        self.new_token_ratio_decay = (
            self.init_new_token_ratio - self.min_new_token_ratio
        ) / global_config.default_new_token_ratio_decay_steps
        self.new_token_ratio = self.init_new_token_ratio

        # Tells whether the current running batch is full so that we can skip
        # the check of whether to prefill new requests.
        # This is an optimization to reduce the overhead of the prefill check.
        self.batch_is_full = False

        # Init watchdog thread
        self.watchdog_timeout = server_args.watchdog_timeout
        t = threading.Thread(target=self.watchdog_thread, daemon=True)
        t.start()
        self.parent_process = psutil.Process().parent()

        # Init profiler
        if os.getenv("SGLANG_TORCH_PROFILER_DIR", "") == "":
            self.profiler = None
        else:
            self.torch_profiler_trace_dir = os.getenv("SGLANG_TORCH_PROFILER_DIR")
            logger.info(
                "Profiling enabled. Traces will be saved to: %s",
                self.torch_profiler_trace_dir,
            )
            self.profiler = torch.profiler.profile(
                activities=[
                    torch.profiler.ProfilerActivity.CPU,
                    torch.profiler.ProfilerActivity.CUDA,
                ],
                with_stack=True,
            )

        # Init metrics stats
        self.stats = SchedulerStats()
        if self.enable_metrics:
            self.metrics_collector = SchedulerMetricsCollector(
                labels={
                    "model_name": self.server_args.served_model_name,
                    # TODO: Add lora name/path in the future,
                },
            )

    def watchdog_thread(self):
        """A watchdog thread that will try to kill the server itself if one batch takes too long."""
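        # Track the last observed forward counter; if it does not advance within
        # watchdog_timeout seconds while a batch is in flight, send SIGQUIT to the parent.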
        self.watchdog_last_forward_ct = 0
        self.watchdog_last_time = time.time()

        while True:
            if self.cur_batch is not None:
                if self.watchdog_last_forward_ct == self.forward_ct:
                    if time.time() > self.watchdog_last_time + self.watchdog_timeout:
                        logger.error(f"Watchdog timeout ({self.watchdog_timeout=})")
                        break
                else:
                    self.watchdog_last_forward_ct = self.forward_ct
                    self.watchdog_last_time = time.time()
            time.sleep(self.watchdog_timeout / 2)

        self.parent_process.send_signal(signal.SIGQUIT)

    @torch.no_grad()
    def event_loop_normal(self):
        """A normal scheduler loop."""
        while True:
            recv_reqs = self.recv_requests()
            self.process_input_requests(recv_reqs)

            batch = self.get_next_batch_to_run()
            if self.server_args.enable_dp_attention:
                batch = self.prepare_dp_attn_batch(batch)

            self.cur_batch = batch

            if batch:
                result = self.run_batch(batch)
                self.process_batch_result(batch, result)
            else:
                # Self-check and re-init some states when the server is idle
                self.check_memory()
                self.new_token_ratio = self.init_new_token_ratio

            self.last_batch = batch

    @torch.no_grad()
    def event_loop_overlap(self):
        """A scheduler loop that overlaps the CPU processing and GPU computation."""
        result_queue = deque()

        while True:
            recv_reqs = self.recv_requests()
            self.process_input_requests(recv_reqs)

            batch = self.get_next_batch_to_run()
            self.cur_batch = batch
            if batch:
                result = self.run_batch(batch)
                result_queue.append((batch.copy(), result))

                if self.last_batch is None:
                    # A dummy first batch to start the pipeline for overlap scheduler.
                    # It is now used for triggering the sampling_info_done event.
                    tmp_batch = ScheduleBatch(
                        reqs=None,
                        forward_mode=ForwardMode.DUMMY_FIRST,
                        next_batch_sampling_info=self.tp_worker.cur_sampling_info,
                    )
                    self.process_batch_result(tmp_batch, None)

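            # Process the result of the previous batch here; its GPU execution
            # overlaps with the CPU work of this iteration.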
            if self.last_batch:
                tmp_batch, tmp_result = result_queue.popleft()
                tmp_batch.next_batch_sampling_info = (
                    self.tp_worker.cur_sampling_info if batch else None
                )
                self.process_batch_result(tmp_batch, tmp_result)
            elif batch is None:
                # Self-check and re-init some states when the server is idle
                self.check_memory()
                self.new_token_ratio = self.init_new_token_ratio

            self.last_batch = batch

    def recv_requests(self):
        if self.tp_rank == 0 or self.server_args.enable_dp_attention:
            recv_reqs = []

            while True:
                try:
                    recv_req = self.recv_from_tokenizer.recv_pyobj(zmq.NOBLOCK)
                except zmq.ZMQError:
                    break
                recv_reqs.append(recv_req)
        else:
            recv_reqs = None

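        # Broadcast the requests from rank 0 so that all tensor parallel ranks
        # process the same inputs in the same order.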
        if self.tp_size != 1 and not self.server_args.enable_dp_attention:
            recv_reqs = broadcast_pyobj(recv_reqs, self.tp_rank, self.tp_cpu_group)
        return recv_reqs

    def process_input_requests(self, recv_reqs: List):
        for recv_req in recv_reqs:
            if isinstance(recv_req, TokenizedGenerateReqInput):
                self.handle_generate_request(recv_req)
            elif isinstance(recv_req, TokenizedEmbeddingReqInput):
                self.handle_embedding_request(recv_req)
            elif isinstance(recv_req, FlushCacheReq):
                self.flush_cache()
            elif isinstance(recv_req, AbortReq):
                self.abort_request(recv_req)
            elif isinstance(recv_req, UpdateWeightFromDiskReqInput):
                success, message = self.update_weights_from_disk(recv_req)
                self.send_to_tokenizer.send_pyobj(
                    UpdateWeightFromDiskReqOutput(success, message)
                )
            elif isinstance(recv_req, InitWeightsUpdateGroupReqInput):
                success, message = self.init_weights_update_group(recv_req)
                self.send_to_tokenizer.send_pyobj(
                    InitWeightsUpdateGroupReqOutput(success, message)
                )
            elif isinstance(recv_req, UpdateWeightsFromDistributedReqInput):
                success, message = self.update_weights_from_distributed(recv_req)
                self.send_to_tokenizer.send_pyobj(
                    UpdateWeightsFromDistributedReqOutput(success, message)
                )
            elif isinstance(recv_req, GetWeightsByNameReqInput):
                parameter = self.get_weights_by_name(recv_req)
                self.send_to_tokenizer.send_pyobj(GetWeightsByNameReqOutput(parameter))
            elif isinstance(recv_req, ProfileReq):
                if recv_req == ProfileReq.START_PROFILE:
                    self.start_profile()
                else:
                    self.stop_profile()
            elif isinstance(recv_req, OpenSessionReqInput):
                session_id = self.open_session(recv_req)
                self.send_to_tokenizer.send_pyobj(OpenSessionReqOutput(session_id))
            elif isinstance(recv_req, CloseSessionReqInput):
                self.close_session(recv_req)
            else:
                raise ValueError(f"Invalid request: {recv_req}")

    def handle_generate_request(
        self,
        recv_req: TokenizedGenerateReqInput,
    ):
        # Create a new request
        if recv_req.session_id is None or recv_req.session_id not in self.sessions:

            if recv_req.input_embeds is not None:
                # Generate fake input_ids based on the length of input_embeds
                seq_length = len(recv_req.input_embeds)
                fake_input_ids = [1] * seq_length
                recv_req.input_ids = fake_input_ids

            req = Req(
                recv_req.rid,
                recv_req.input_text,
                recv_req.input_ids,
                recv_req.sampling_params,
                return_logprob=recv_req.return_logprob,
                top_logprobs_num=recv_req.top_logprobs_num,
                stream=recv_req.stream,
                lora_path=recv_req.lora_path,
                input_embeds=recv_req.input_embeds,
                eos_token_ids=self.model_config.get_hf_eos_token_id(),
            )
            req.tokenizer = self.tokenizer

            if recv_req.session_id is not None:
                req.finished_reason = FINISH_ABORT(
                    f"Invalid request: session id {recv_req.session_id} does not exist"
                )
                self.waiting_queue.append(req)
                return
        else:
            # Create a new request from a previous session
            session = self.sessions[recv_req.session_id]
            req = session.create_req(recv_req, self.tokenizer)
            if isinstance(req.finished_reason, FINISH_ABORT):
                self.waiting_queue.append(req)
                return

        # Handle image inputs
        if recv_req.image_inputs is not None:
            image_inputs = ImageInputs.from_dict(recv_req.image_inputs)
            # Expand a single image token into multiple dummy tokens for receiving image embeddings
            req.origin_input_ids = self.pad_input_ids_func(
                req.origin_input_ids, image_inputs
            )
            req.extend_image_inputs(image_inputs)

            if len(req.origin_input_ids) >= self.max_req_input_len:
                logger.error(
                    "Multimodal prompt is too long after expanding multimodal tokens. "
                    f"After expanding {len(req.origin_input_ids_unpadded)=} => {len(req.origin_input_ids)} >= {self.max_req_input_len}. "
                )
                req.origin_input_ids = [0]
                req.image_inputs = None
                req.sampling_params.max_new_tokens = 0
                req.finished_reason = FINISH_ABORT(
                    "Multimodal prompt is too long. Check server logs for details."
                )
                self.waiting_queue.append(req)
                return

        # Copy more attributes
        req.logprob_start_len = recv_req.logprob_start_len

        if req.logprob_start_len == -1:
            # By default, only return the logprobs for output tokens
            req.logprob_start_len = len(req.origin_input_ids) - 1

        # Truncate prompts that are too long
        if len(req.origin_input_ids) > self.max_req_input_len:
            logger.warning(
                "Request length is longer than the KV cache pool size or "
                "the max context length. Truncated!!!"
            )
            req.origin_input_ids = req.origin_input_ids[: self.max_req_input_len]

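        # Cap max_new_tokens so that the total sequence length stays within max_req_len.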
        req.sampling_params.max_new_tokens = min(
            (
                req.sampling_params.max_new_tokens
                if req.sampling_params.max_new_tokens is not None
                else 1 << 30
            ),
            self.max_req_len - len(req.origin_input_ids) - 1,
        )

        # Init grammar cache for this request
        add_to_grammar_queue = False
        if (
            req.sampling_params.json_schema is not None
            or req.sampling_params.regex is not None
            or req.sampling_params.ebnf is not None
        ):
            assert self.grammar_backend is not None
            if req.sampling_params.json_schema is not None:
                key = ("json", req.sampling_params.json_schema)
            elif req.sampling_params.regex is not None:
                key = ("regex", req.sampling_params.regex)
            elif req.sampling_params.ebnf is not None:
                key = ("ebnf", req.sampling_params.ebnf)

            req.grammar = self.grammar_backend.get_cached_value(key)
            if not req.grammar:
                req.grammar = self.grammar_backend.get_future_value(key)
                add_to_grammar_queue = True

        if add_to_grammar_queue:
            self.grammar_queue.append(req)
        else:
            self.waiting_queue.append(req)

    def handle_embedding_request(
        self,
        recv_req: TokenizedEmbeddingReqInput,
    ):
        req = Req(
            recv_req.rid,
            recv_req.input_text,
            recv_req.input_ids,
            recv_req.sampling_params,
        )
        req.tokenizer = self.tokenizer

        # Truncate prompts that are too long
        if len(req.origin_input_ids) >= self.max_req_input_len:
            logger.warning(
                "Request length is longer than the KV cache pool size or "
                "the max context length. Truncated!!!"
            )
            req.origin_input_ids = req.origin_input_ids[: self.max_req_input_len]

        self.waiting_queue.append(req)

    def log_prefill_stats(self, adder, can_run_list, running_bs, has_being_chunked):
        self.tree_cache_metrics["total"] += (
            adder.log_input_tokens + adder.log_hit_tokens
        ) / 10**9
        self.tree_cache_metrics["hit"] += (adder.log_hit_tokens) / 10**9
        tree_cache_hit_rate = (
            self.tree_cache_metrics["hit"] / self.tree_cache_metrics["total"]
        )

        num_used = self.max_total_num_tokens - (
            self.token_to_kv_pool.available_size() + self.tree_cache.evictable_size()
        )

        logger.info(
            f"Prefill batch. "
            f"#new-seq: {len(can_run_list)}, "
            f"#new-token: {adder.log_input_tokens}, "
            f"#cached-token: {adder.log_hit_tokens}, "
            f"cache hit rate: {100.0 * tree_cache_hit_rate:.2f}%, "
            f"token usage: {num_used / self.max_total_num_tokens:.2f}, "
            f"#running-req: {running_bs}, "
            f"#queue-req: {len(self.waiting_queue) + has_being_chunked}"
        )

        if self.enable_metrics:
            self.stats.num_running_reqs = running_bs
            self.stats.num_used_tokens = num_used
            self.stats.token_usage = round(num_used / self.max_total_num_tokens, 2)
            self.stats.num_queue_reqs = len(self.waiting_queue) + has_being_chunked
            self.stats.cache_hit_rate = tree_cache_hit_rate
            self.metrics_collector.log_stats(self.stats)

    def log_decode_stats(self):
        num_used = self.max_total_num_tokens - (
            self.token_to_kv_pool.available_size() + self.tree_cache.evictable_size()
        )
        gen_throughput = self.num_generated_tokens / (
            time.time() - self.last_decode_stats_tic
        )
        self.num_generated_tokens = 0
        self.last_decode_stats_tic = time.time()
        num_running_reqs = len(self.running_batch.reqs) if self.running_batch else 0
        logger.info(
            f"Decode batch. "
            f"#running-req: {num_running_reqs}, "
            f"#token: {num_used}, "
            f"token usage: {num_used / self.max_total_num_tokens:.2f}, "
            f"gen throughput (token/s): {gen_throughput:.2f}, "
            f"#queue-req: {len(self.waiting_queue)}"
        )

        if self.enable_metrics:
            self.stats.num_running_reqs = num_running_reqs
            self.stats.num_used_tokens = num_used
            self.stats.token_usage = num_used / self.max_total_num_tokens
            self.stats.gen_throughput = gen_throughput
            self.stats.num_queue_reqs = len(self.waiting_queue)
            self.metrics_collector.log_stats(self.stats)

    def check_memory(self):
        available_size = (
            self.token_to_kv_pool.available_size() + self.tree_cache.evictable_size()
        )
        if available_size != self.max_total_num_tokens:
            msg = (
                "KV cache pool leak detected! "
                f"{available_size=}, {self.max_total_num_tokens=}\n"
            )
            warnings.warn(msg)
            if crash_on_warnings():
                raise ValueError(msg)

        if len(self.req_to_token_pool.free_slots) != self.req_to_token_pool.size:
            msg = (
                "Memory pool leak detected! "
                f"available_size={len(self.req_to_token_pool.free_slots)}, "
                f"total_size={self.req_to_token_pool.size}\n"
            )
            warnings.warn(msg)
            if crash_on_warnings():
                raise ValueError(msg)

    def get_next_batch_to_run(self) -> Optional[ScheduleBatch]:
        # Merge the prefill batch into the running batch
        if self.last_batch and self.last_batch.forward_mode.is_extend():
            if self.being_chunked_req:
                # Move the chunked request out of the batch
                self.last_batch.filter_batch(being_chunked_req=self.being_chunked_req)
                self.tree_cache.cache_unfinished_req(self.being_chunked_req)
                # The being-chunked request keeps its rid but will get a new req_pool_idx
                self.req_to_token_pool.free(self.being_chunked_req.req_pool_idx)
                self.batch_is_full = False

            if not self.last_batch.is_empty():
                if self.running_batch is None:
                    self.running_batch = self.last_batch
                else:
                    self.running_batch.merge_batch(self.last_batch)

        # Run prefill first if possible
        new_batch = self.get_new_batch_prefill()
        if new_batch is not None:
            return new_batch

        # Run decode
        if self.running_batch is None:
            return None
        self.running_batch = self.update_running_batch(self.running_batch)
        return self.running_batch

    def get_new_batch_prefill(self) -> Optional[ScheduleBatch]:
        # Check if the grammar is ready in the grammar queue
        if self.grammar_queue:
            self.move_ready_grammar_requests()

        # Handle the cases where prefill is not allowed
        if (
            self.batch_is_full or len(self.waiting_queue) == 0
        ) and self.being_chunked_req is None:
            return None

        running_bs = len(self.running_batch.reqs) if self.running_batch else 0
        if running_bs >= self.max_running_requests:
            self.batch_is_full = True
            return None

        # Get priority queue
        prefix_computed = self.policy.calc_priority(self.waiting_queue)

        # Prefill policy
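        # The adder decides which waiting requests fit, given the remaining KV
        # cache capacity, the prefill token budget, and the chunked prefill size.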
        adder = PrefillAdder(
            self.tree_cache,
            self.running_batch,
            self.new_token_ratio,
            self.token_to_kv_pool.available_size() + self.tree_cache.evictable_size(),
            self.max_prefill_tokens,
            self.chunked_prefill_size,
            running_bs if self.is_mixed_chunk else 0,
        )

        has_being_chunked = self.being_chunked_req is not None
        if has_being_chunked:
            self.being_chunked_req.init_next_round_input()
            self.being_chunked_req = adder.add_being_chunked_req(self.being_chunked_req)

        if self.lora_paths:
            lora_set = (
                set([req.lora_path for req in self.running_batch.reqs])
                if self.running_batch is not None
                else set([])
            )

        # Get requests from the waiting queue to a new prefill batch
        for req in self.waiting_queue:
            if (
                self.lora_paths
                and len(
                    lora_set
                    | set([req.lora_path for req in adder.can_run_list])
                    | set([req.lora_path])
                )
                > self.max_loras_per_batch
            ):
                self.batch_is_full = True
                break

            if running_bs + len(adder.can_run_list) >= self.max_running_requests:
                self.batch_is_full = True
                break

            req.init_next_round_input(None if prefix_computed else self.tree_cache)
            res = adder.add_one_req(req)
            if res != AddReqResult.CONTINUE:
                if res == AddReqResult.NO_TOKEN:
                    self.batch_is_full = True
                break

        # Update waiting queue
        can_run_list = adder.can_run_list
        if len(can_run_list) == 0:
            return None
        self.waiting_queue = [
            x for x in self.waiting_queue if x not in set(can_run_list)
        ]

        if adder.new_being_chunked_req is not None:
            assert self.being_chunked_req is None
            self.being_chunked_req = adder.new_being_chunked_req

        if self.being_chunked_req:
            self.being_chunked_req.is_being_chunked += 1

        # Print stats
        if self.tp_rank == 0:
            self.log_prefill_stats(adder, can_run_list, running_bs, has_being_chunked)

        # Create a new batch
        new_batch = ScheduleBatch.init_new(
            can_run_list,
            self.req_to_token_pool,
            self.token_to_kv_pool,
            self.tree_cache,
            self.model_config,
            self.enable_overlap,
        )
        new_batch.prepare_for_extend()

        # Mixed-style chunked prefill
        if (
            self.is_mixed_chunk
            and self.running_batch is not None
            and not (new_batch.return_logprob or self.running_batch.return_logprob)
        ):
            # TODO (lianmin): support return_logprob + mixed chunked prefill
            self.running_batch.filter_batch()
            if not self.running_batch.is_empty():
                self.running_batch.prepare_for_decode()
                new_batch.mix_with_running(self.running_batch)
                new_batch.decoding_reqs = self.running_batch.reqs
            self.running_batch = None
        else:
            new_batch.decoding_reqs = None

        return new_batch

    def update_running_batch(self, batch: ScheduleBatch) -> Optional[ScheduleBatch]:
        """Update the current running decoding batch."""
        global test_retract

        initial_bs = batch.batch_size()

        batch.filter_batch()
        if batch.is_empty():
            self.batch_is_full = False
            return None

        # Check if decode out of memory
        if not batch.check_decode_mem() or (test_retract and batch.batch_size() > 10):
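            # Retract some running requests back to the waiting queue to free
            # KV cache space for the requests that remain.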
            old_ratio = self.new_token_ratio

            retracted_reqs, new_token_ratio = batch.retract_decode()
            self.new_token_ratio = new_token_ratio

            logger.info(
                "Decode out of memory happened. "
                f"#retracted_reqs: {len(retracted_reqs)}, "
                f"#new_token_ratio: {old_ratio:.4f} -> {self.new_token_ratio:.4f}"
            )
            self.waiting_queue.extend(retracted_reqs)
        else:
            self.new_token_ratio = max(
                self.new_token_ratio - self.new_token_ratio_decay,
                self.min_new_token_ratio,
            )

        # Check for jump-forward
        if not self.disable_jump_forward:
            jump_forward_reqs = batch.check_for_jump_forward(self.pad_input_ids_func)
            self.waiting_queue.extend(jump_forward_reqs)
            if batch.is_empty():
                self.batch_is_full = False
                return None

        if batch.batch_size() < initial_bs:
            self.batch_is_full = False

        # Update batch tensors
        batch.prepare_for_decode()
        return batch

    def run_batch(self, batch: ScheduleBatch):
        """Run a batch."""
        self.forward_ct += 1

        if self.is_generation:
            model_worker_batch = batch.get_model_worker_batch()
            if batch.forward_mode.is_decode() or batch.extend_num_tokens != 0:
                logits_output, next_token_ids = self.tp_worker.forward_batch_generation(
                    model_worker_batch
                )
            elif batch.forward_mode.is_idle():
                model_worker_batch = batch.get_model_worker_batch()
                self.tp_worker.forward_batch_idle(model_worker_batch)
                return
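            # No forward pass is needed here (an extend batch with no new tokens
            # to compute); fill placeholder next-token ids instead.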
            else:
                logits_output = None
                if self.skip_tokenizer_init:
                    next_token_ids = torch.full(
                        (batch.batch_size(),), self.tokenizer.eos_token_id
                    )
                else:
                    next_token_ids = torch.full((batch.batch_size(),), 0)
            batch.output_ids = next_token_ids
            ret = logits_output, next_token_ids, model_worker_batch.bid
        else:  # embedding or reward model
            assert batch.extend_num_tokens != 0
            model_worker_batch = batch.get_model_worker_batch()
            embeddings = self.tp_worker.forward_batch_embedding(model_worker_batch)
            ret = embeddings, model_worker_batch.bid
        return ret

    def process_batch_result(self, batch: ScheduleBatch, result):
        if batch.forward_mode.is_decode():
            self.process_batch_result_decode(batch, result)
            if batch.is_empty():
                self.running_batch = None
        elif batch.forward_mode.is_extend():
            self.process_batch_result_prefill(batch, result)
        elif batch.forward_mode.is_dummy_first():
            batch.next_batch_sampling_info.update_regex_vocab_mask()
            self.current_stream.synchronize()
            batch.next_batch_sampling_info.sampling_info_done.set()

    def process_batch_result_prefill(self, batch: ScheduleBatch, result):
        skip_stream_req = None

        if self.is_generation:
            logits_output, next_token_ids, bid = result

            if self.enable_overlap:
                logits_output, next_token_ids = self.tp_worker.resolve_batch_result(bid)
            else:
                # Move next_token_ids and logprobs to cpu
                if batch.return_logprob:
                    logits_output.next_token_logprobs = (
                        logits_output.next_token_logprobs[
                            torch.arange(len(next_token_ids), device=self.device),
                            next_token_ids,
                        ].tolist()
                    )
                    logits_output.input_token_logprobs = (
                        logits_output.input_token_logprobs.tolist()
                    )
                    logits_output.normalized_prompt_logprobs = (
                        logits_output.normalized_prompt_logprobs.tolist()
                    )
                next_token_ids = next_token_ids.tolist()

            # Check finish conditions
            logprob_pt = 0
            for i, (req, next_token_id) in enumerate(zip(batch.reqs, next_token_ids)):
                if req.is_retracted:
                    continue

                if self.is_mixed_chunk and self.enable_overlap and req.finished():
                    # Free the one delayed token for the mixed decode batch
                    j = len(batch.out_cache_loc) - len(batch.reqs) + i
                    self.token_to_kv_pool.free(batch.out_cache_loc[j : j + 1])
                    continue

                if req.is_being_chunked <= 0:
                    req.output_ids.append(next_token_id)
                    req.check_finished()

                    if req.finished():
                        self.tree_cache.cache_finished_req(req)
                    elif not batch.decoding_reqs or req not in batch.decoding_reqs:
                        self.tree_cache.cache_unfinished_req(req)

                    if req.return_logprob:
                        logprob_pt += self.add_logprob_return_values(
                            i, req, logprob_pt, next_token_ids, logits_output
                        )

                    if req.grammar is not None:
                        req.grammar.accept_token(next_token_id)
                        req.grammar.finished = req.finished()
                else:
                    # The prefill of this being-chunked request is not finished yet
                    req.is_being_chunked -= 1
                    # There is only at most one request being currently chunked.
                    # Because this request does not finish prefill,
                    # we don't want to stream the request currently being chunked.
                    skip_stream_req = req

            if batch.next_batch_sampling_info:
                batch.next_batch_sampling_info.update_regex_vocab_mask()
                self.current_stream.synchronize()
                batch.next_batch_sampling_info.sampling_info_done.set()

        else:  # embedding or reward model
            embeddings, bid = result
            embeddings = embeddings.tolist()

            # Check finish conditions
            for i, req in enumerate(batch.reqs):
                if req.is_retracted:
                    continue

                req.embedding = embeddings[i]
                if req.is_being_chunked <= 0:
                    # Dummy output token for embedding models
                    req.output_ids.append(0)
                    req.check_finished()

                    if req.finished():
                        self.tree_cache.cache_finished_req(req)
                    else:
                        self.tree_cache.cache_unfinished_req(req)
                else:
                    # The prefill of this being-chunked request is not finished yet
                    req.is_being_chunked -= 1

        self.stream_output(batch.reqs, batch.return_logprob, skip_stream_req)

    def process_batch_result_decode(self, batch: ScheduleBatch, result):
        logits_output, next_token_ids, bid = result
        self.num_generated_tokens += len(batch.reqs)

        if self.enable_overlap:
            logits_output, next_token_ids = self.tp_worker.resolve_batch_result(bid)
            next_token_logprobs = logits_output.next_token_logprobs
        else:
            # Move next_token_ids and logprobs to cpu
            if batch.return_logprob:
                next_token_logprobs = logits_output.next_token_logprobs[
                    torch.arange(len(next_token_ids), device=self.device),
                    next_token_ids,
                ].tolist()
            next_token_ids = next_token_ids.tolist()

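        # Collect the KV cache frees issued in the loop below into one group
        # (released together at free_group_end).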
        self.token_to_kv_pool.free_group_begin()

        # Check finish condition
        for i, (req, next_token_id) in enumerate(zip(batch.reqs, next_token_ids)):
            if req.is_retracted:
                continue

            if self.enable_overlap and req.finished():
                # Free the one delayed token
                self.token_to_kv_pool.free(batch.out_cache_loc[i : i + 1])
                continue

            req.output_ids.append(next_token_id)
            req.check_finished()

            if req.finished():
                self.tree_cache.cache_finished_req(req)

            if req.return_logprob:
                req.output_token_logprobs_val.append(next_token_logprobs[i])
                req.output_token_logprobs_idx.append(next_token_id)
                if req.top_logprobs_num > 0:
                    req.output_top_logprobs_val.append(
                        logits_output.output_top_logprobs_val[i]
                    )
                    req.output_top_logprobs_idx.append(
                        logits_output.output_top_logprobs_idx[i]
                    )

            if req.grammar is not None:
                req.grammar.accept_token(next_token_id)
                req.grammar.finished = req.finished()

        if batch.next_batch_sampling_info:
            batch.next_batch_sampling_info.update_regex_vocab_mask()
            self.current_stream.synchronize()
            batch.next_batch_sampling_info.sampling_info_done.set()

        self.stream_output(batch.reqs, batch.return_logprob)

        self.token_to_kv_pool.free_group_end()

        self.forward_ct_decode = (self.forward_ct_decode + 1) % (1 << 30)
        if (
            self.tp_rank == 0
            and self.forward_ct_decode % self.server_args.decode_log_interval == 0
        ):
            self.log_decode_stats()

    def add_logprob_return_values(
        self,
        i: int,
        req: Req,
        pt: int,
        next_token_ids: List[int],
        output: LogitsProcessorOutput,
    ):
        """Attach logprobs to the return values."""
        req.output_token_logprobs_val.append(output.next_token_logprobs[i])
        req.output_token_logprobs_idx.append(next_token_ids[i])

        # If logprob_start_len > 0, then first logprob_start_len prompt tokens will be ignored.
        num_input_logprobs = req.extend_input_len - req.extend_logprob_start_len

        if req.normalized_prompt_logprob is None:
            req.normalized_prompt_logprob = output.normalized_prompt_logprobs[i]

        if req.input_token_logprobs_val is None:
            input_token_logprobs_val = output.input_token_logprobs[
                pt : pt + num_input_logprobs - 1 - req.last_update_decode_tokens
            ]

            input_token_logprobs_idx = req.fill_ids[
                len(req.fill_ids)
                - num_input_logprobs
                + 1 : len(req.fill_ids)
                - req.last_update_decode_tokens
            ]
            # Clip the padded hash values from image tokens.
            # Otherwise, it will lead to detokenization errors.
            input_token_logprobs_idx = [
                x if x < self.model_config.vocab_size - 1 else 0
                for x in input_token_logprobs_idx
            ]

            if (
                req.logprob_start_len == 0
            ):  # The first token does not have logprob, pad it.
                input_token_logprobs_val = [None] + input_token_logprobs_val
                input_token_logprobs_idx = [req.fill_ids[0]] + input_token_logprobs_idx

            req.input_token_logprobs_val = input_token_logprobs_val
            req.input_token_logprobs_idx = input_token_logprobs_idx

        if req.last_update_decode_tokens != 0:
            # Some decode tokens are re-computed in an extend batch
            req.output_token_logprobs_val.extend(
                output.input_token_logprobs[
                    pt
                    + num_input_logprobs
                    - 1
                    - req.last_update_decode_tokens : pt
                    + num_input_logprobs
                    - 1
                ],
            )
            req.output_token_logprobs_idx.extend(
                req.fill_ids[
                    len(req.fill_ids)
                    - req.last_update_decode_tokens : len(req.fill_ids)
                ]
1164
1165
1166
            )

        if req.top_logprobs_num > 0:
Lianmin Zheng's avatar
Lianmin Zheng committed
1167
1168
1169
            if req.input_top_logprobs_val is None:
                req.input_top_logprobs_val = output.input_top_logprobs_val[i]
                req.input_top_logprobs_idx = output.input_top_logprobs_idx[i]
1170
                if req.logprob_start_len == 0:
Lianmin Zheng's avatar
Lianmin Zheng committed
1171
1172
                    req.input_top_logprobs_val = [None] + req.input_top_logprobs_val
                    req.input_top_logprobs_idx = [None] + req.input_top_logprobs_idx
1173
1174

            if req.last_update_decode_tokens != 0:
Lianmin Zheng's avatar
Lianmin Zheng committed
1175
1176
                req.output_top_logprobs_val.extend(
                    output.input_top_logprobs_val[i][-req.last_update_decode_tokens :]
1177
                )
Lianmin Zheng's avatar
Lianmin Zheng committed
1178
1179
1180
1181
1182
                req.output_top_logprobs_idx.extend(
                    output.input_top_logprobs_idx[i][-req.last_update_decode_tokens :]
                )
            req.output_top_logprobs_val.append(output.output_top_logprobs_val[i])
            req.output_top_logprobs_idx.append(output.output_top_logprobs_idx[i])
1183
1184
1185

        return num_input_logprobs

    def stream_output(
        self, reqs: List[Req], return_logprob: bool, skip_req: Optional[Req] = None
    ):
        """Stream the output to detokenizer."""
        rids = []
        finished_reasons: List[BaseFinishReason] = []

        if self.is_generation:
            vids = []
            decoded_texts = []
            decode_ids_list = []
            read_offsets = []
            output_ids = []

            skip_special_tokens = []
            spaces_between_special_tokens = []
            no_stop_trim = []
            prompt_tokens = []
            completion_tokens = []
            cached_tokens = []

            if return_logprob:
                input_token_logprobs_val = []
                input_token_logprobs_idx = []
                output_token_logprobs_val = []
                output_token_logprobs_idx = []
                input_top_logprobs_val = []
                input_top_logprobs_idx = []
                output_top_logprobs_val = []
                output_top_logprobs_idx = []
                normalized_prompt_logprob = []
            else:
                input_token_logprobs_val = input_token_logprobs_idx = (
                    output_token_logprobs_val
                ) = output_token_logprobs_idx = input_top_logprobs_val = (
                    input_top_logprobs_idx
                ) = output_top_logprobs_val = output_top_logprobs_idx = (
                    normalized_prompt_logprob
                ) = None
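            # Without logprobs, keep every logprob column as None so the positional
            # arguments of BatchTokenIDOut below still line up.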

            for req in reqs:
                if req is skip_req:
                    continue

                # TODO(lianmin): revisit this for overlap + retract + stream
                if (
                    req.finished()
                    # If stream, follow the given stream_interval
                    or (req.stream and len(req.output_ids) % self.stream_interval == 0)
                    # If not stream, we still want to output some tokens to get the benefit of incremental decoding.
                    or (not req.stream and len(req.output_ids) % 50 == 0)
                ):
                    rids.append(req.rid)
                    finished_reasons.append(
                        req.finished_reason.to_json() if req.finished_reason else None
                    )
                    vids.append(req.vid)
                    decoded_texts.append(req.decoded_text)
                    decode_ids, read_offset = req.init_incremental_detokenize()
                    decode_ids_list.append(decode_ids)
                    read_offsets.append(read_offset)
                    if self.skip_tokenizer_init:
                        output_ids.append(req.output_ids)
                    skip_special_tokens.append(req.sampling_params.skip_special_tokens)
                    spaces_between_special_tokens.append(
                        req.sampling_params.spaces_between_special_tokens
                    )
                    no_stop_trim.append(req.sampling_params.no_stop_trim)

                    prompt_tokens.append(len(req.origin_input_ids))
                    completion_tokens.append(len(req.output_ids))
                    cached_tokens.append(req.cached_tokens)

                    if return_logprob:
                        input_token_logprobs_val.append(req.input_token_logprobs_val)
                        input_token_logprobs_idx.append(req.input_token_logprobs_idx)
                        output_token_logprobs_val.append(req.output_token_logprobs_val)
                        output_token_logprobs_idx.append(req.output_token_logprobs_idx)
                        input_top_logprobs_val.append(req.input_top_logprobs_val)
                        input_top_logprobs_idx.append(req.input_top_logprobs_idx)
                        output_top_logprobs_val.append(req.output_top_logprobs_val)
                        output_top_logprobs_idx.append(req.output_top_logprobs_idx)
                        normalized_prompt_logprob.append(req.normalized_prompt_logprob)

            # Send to detokenizer
            if rids:
                self.send_to_detokenizer.send_pyobj(
                    BatchTokenIDOut(
                        rids,
                        finished_reasons,
                        vids,
                        decoded_texts,
                        decode_ids_list,
                        read_offsets,
                        output_ids,
                        skip_special_tokens,
                        spaces_between_special_tokens,
                        no_stop_trim,
                        prompt_tokens,
                        completion_tokens,
                        cached_tokens,
                        input_token_logprobs_val,
                        input_token_logprobs_idx,
                        output_token_logprobs_val,
                        output_token_logprobs_idx,
                        input_top_logprobs_val,
                        input_top_logprobs_idx,
                        output_top_logprobs_val,
                        output_top_logprobs_idx,
                        normalized_prompt_logprob,
                    )
                )
        else:  # embedding or reward model
            embeddings = []
            prompt_tokens = []
            for req in reqs:
                assert req.finished()
                rids.append(req.rid)
                finished_reasons.append(req.finished_reason.to_json())
                embeddings.append(req.embedding)
                prompt_tokens.append(len(req.origin_input_ids))
            self.send_to_detokenizer.send_pyobj(
                BatchEmbeddingOut(rids, finished_reasons, embeddings, prompt_tokens)
            )

    def prepare_dp_attn_batch(self, local_batch: ScheduleBatch):
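        # DP attention requires all data-parallel workers to step together: gather
        # every rank's token count, give idle ranks an empty (idle) batch so the
        # collectives stay in lockstep, and agree on CUDA-graph eligibility.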
        # Check if other DP workers have running batches
        if local_batch is None:
            num_tokens = 0
        elif local_batch.forward_mode.is_decode():
            num_tokens = local_batch.batch_size()
        else:
            num_tokens = local_batch.extend_num_tokens

        local_num_tokens = torch.tensor([num_tokens], dtype=torch.int64)
        global_num_tokens = torch.empty(self.tp_size, dtype=torch.int64)
        torch.distributed.all_gather_into_tensor(
            global_num_tokens,
            local_num_tokens,
            group=self.tp_cpu_group,
        )

        if local_batch is None and global_num_tokens.max().item() > 0:
            local_batch = self.get_idle_batch()

        if local_batch is not None:
            local_batch.global_num_tokens = global_num_tokens.tolist()

            # Check forward mode for cuda graph
            if not self.server_args.disable_cuda_graph:
                forward_mode_state = torch.tensor(
                    (
                        1
                        if local_batch.forward_mode.is_decode()
                        or local_batch.forward_mode.is_idle()
                        else 0
                    ),
                    dtype=torch.int32,
                )
                torch.distributed.all_reduce(
                    forward_mode_state,
                    op=torch.distributed.ReduceOp.MIN,
                    group=self.tp_cpu_group,
                )
                local_batch.can_run_dp_cuda_graph = forward_mode_state.item() == 1

        return local_batch

    def get_idle_batch(self):
        idle_batch = ScheduleBatch.init_new(
            [],
            self.req_to_token_pool,
            self.token_to_kv_pool,
            self.tree_cache,
            self.model_config,
            self.enable_overlap,
        )
        idle_batch.prepare_for_idle()
        return idle_batch

    def move_ready_grammar_requests(self):
        """Move requests whose grammar objects are ready from grammar_queue to waiting_queue."""
        num_ready_reqs = 0
        for req in self.grammar_queue:
            try:
                req.grammar = req.grammar.result(timeout=0.05)
                num_ready_reqs += 1
            except futures.TimeoutError:
                break

        if self.tp_size > 1:
            # Sync across TP ranks to make sure they have the same number of ready requests
            tensor = torch.tensor(num_ready_reqs, dtype=torch.int32)
            torch.distributed.all_reduce(
                tensor, op=torch.distributed.ReduceOp.MAX, group=self.tp_cpu_group
            )
            num_ready_reqs_max = tensor.item()
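            # Lagging ranks block on result() without a timeout so every rank moves
            # the same number of requests and their schedules stay identical.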
            for i in range(num_ready_reqs, num_ready_reqs_max):
                self.grammar_queue[i].grammar = self.grammar_queue[i].grammar.result()
            num_ready_reqs = num_ready_reqs_max

        self.waiting_queue.extend(self.grammar_queue[:num_ready_reqs])
        self.grammar_queue = self.grammar_queue[num_ready_reqs:]

    def flush_cache(self):
        """Flush the memory pool and cache."""
        if len(self.waiting_queue) == 0 and (
            self.running_batch is None or len(self.running_batch.reqs) == 0
        ):
            self.tree_cache.reset()
            self.tree_cache_metrics = {"total": 0, "hit": 0}
            if self.grammar_backend:
                self.grammar_backend.reset()
            self.req_to_token_pool.clear()
            self.token_to_kv_pool.clear()
            torch.cuda.empty_cache()
            logger.info("Cache flushed successfully!")
            if_success = True
        else:
            logger.warning(
                f"Cache not flushed because there are pending requests. "
                f"#queue-req: {len(self.waiting_queue)}, "
                f"#running-req: {0 if self.running_batch is None else len(self.running_batch.reqs)}"
            )
            if_success = False
        return if_success

    def abort_request(self, recv_req: AbortReq):
        # Delete requests in the waiting queue
        to_del = None
        for i, req in enumerate(self.waiting_queue):
            if req.rid == recv_req.rid:
                to_del = i
                break

        if to_del is not None:
            del self.waiting_queue[to_del]
            logger.debug(f"Abort queued request. {req.rid=}")
            return

        # Delete requests in the running batch
        if self.running_batch:
            for req in self.running_batch.reqs:
                if req.rid == recv_req.rid and not req.finished():
                    logger.debug(f"Abort running request. {req.rid=}")
                    req.to_abort = True
                    break

    def update_weights_from_disk(self, recv_req: UpdateWeightFromDiskReqInput):
        """In-place update of the weights from disk."""
        success, message = self.tp_worker.update_weights_from_disk(recv_req)
        if success:
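            # The KV/radix cache was computed with the old weights, so it must be
            # flushed before serving with the new ones.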
            flush_cache_success = self.flush_cache()
            assert flush_cache_success, "Cache flush failed after updating weights"
        else:
            logger.error(message)
        return success, message

    def init_weights_update_group(self, recv_req: InitWeightsUpdateGroupReqInput):
        """Initialize the online model parameter update group."""
        success, message = self.tp_worker.init_weights_update_group(recv_req)
        return success, message

    def update_weights_from_distributed(
        self, recv_req: UpdateWeightsFromDistributedReqInput
    ):
        """Update the online model parameter."""
        success, message = self.tp_worker.update_weights_from_distributed(recv_req)
        if success:
            flush_cache_success = self.flush_cache()
            assert flush_cache_success, "Cache flush failed after updating weights"
        else:
            logger.error(message)
        return success, message

    def get_weights_by_name(self, recv_req: GetWeightsByNameReqInput):
        parameter = self.tp_worker.get_weights_by_name(recv_req)
        return parameter

    def start_profile(self) -> None:
        if self.profiler is None:
            raise RuntimeError("Profiler is not enabled.")
        self.profiler.start()

    def stop_profile(self) -> None:
        if self.profiler is None:
            raise RuntimeError("Profiler is not enabled.")
        self.profiler.stop()
        self.profiler.export_chrome_trace(
            self.torch_profiler_trace_dir + "/" + str(time.time()) + ".trace.json.gz"
        )
        logger.info("Profiler is done")

    def open_session(self, recv_req: OpenSessionReqInput) -> str:
        # handle error
        session_id = recv_req.session_id
        if session_id in self.sessions:
            logger.warning(f"session id {session_id} already exist, cannot open.")
        else:
            self.sessions[session_id] = Session(
                recv_req.capacity_of_str_len, session_id
            )
        return session_id

    def close_session(self, recv_req: CloseSessionReqInput):
        # handle error
        session_id = recv_req.session_id
        if session_id not in self.sessions:
            logger.warning(f"session id {session_id} does not exist, cannot delete.")
        else:
            del self.sessions[session_id]


def run_scheduler_process(
    server_args: ServerArgs,
    port_args: PortArgs,
    gpu_id: int,
    tp_rank: int,
    dp_rank: Optional[int],
    pipe_writer,
):
    setproctitle.setproctitle("sglang::scheduler")

    # [For Router] if the env var "SGLANG_DP_RANK" exists, set dp_rank to its value
    if dp_rank is None and "SGLANG_DP_RANK" in os.environ:
        dp_rank = int(os.environ["SGLANG_DP_RANK"])

    if dp_rank is None:
        configure_logger(server_args, prefix=f" TP{tp_rank}")
    else:
        configure_logger(server_args, prefix=f" DP{dp_rank} TP{tp_rank}")

    # Set CPU affinity for this GPU process
    if get_bool_env_var("SGLANG_SET_CPU_AFFINITY"):
        set_gpu_proc_affinity(server_args.tp_size, server_args.nnodes, gpu_id)

    suppress_other_loggers()
    parent_process = psutil.Process().parent()

    try:
        scheduler = Scheduler(server_args, port_args, gpu_id, tp_rank, dp_rank)
        pipe_writer.send(
            {"status": "ready", "max_total_num_tokens": scheduler.max_total_num_tokens}
        )
        if scheduler.enable_overlap:
            scheduler.event_loop_overlap()
        else:
            scheduler.event_loop_normal()
    except Exception:
        traceback = get_exception_traceback()
        logger.error(f"Scheduler hit an exception: {traceback}")
        parent_process.send_signal(signal.SIGQUIT)
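        # Signal the parent process so the whole server can shut down instead of
        # continuing with a dead scheduler.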