# Copyright 2023-2024 SGLang Team
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""A scheduler that manages a tensor parallel GPU worker."""

import logging
import os
import signal
import threading
import time
import warnings
from collections import deque
from concurrent import futures
from types import SimpleNamespace
from typing import List, Optional

import psutil
import setproctitle
import torch
import zmq

from sglang.global_config import global_config
from sglang.srt.configs.model_config import ModelConfig
from sglang.srt.hf_transformers_utils import get_processor, get_tokenizer
from sglang.srt.layers.logits_processor import LogitsProcessorOutput
from sglang.srt.managers.io_struct import (
    AbortReq,
    BatchEmbeddingOut,
    BatchTokenIDOut,
    CloseSessionReqInput,
    FlushCacheReq,
    GetWeightsByNameReqInput,
    GetWeightsByNameReqOutput,
    InitWeightsUpdateGroupReqInput,
    InitWeightsUpdateGroupReqOutput,
    OpenSessionReqInput,
    OpenSessionReqOutput,
    ProfileReq,
    TokenizedEmbeddingReqInput,
    TokenizedGenerateReqInput,
    UpdateWeightFromDiskReqInput,
    UpdateWeightFromDiskReqOutput,
    UpdateWeightsFromDistributedReqInput,
    UpdateWeightsFromDistributedReqOutput,
)
from sglang.srt.managers.schedule_batch import (
    FINISH_ABORT,
    BaseFinishReason,
    ImageInputs,
    Req,
    ScheduleBatch,
    global_server_args_dict,
)
from sglang.srt.managers.schedule_policy import (
    AddReqResult,
    PrefillAdder,
    SchedulePolicy,
)
from sglang.srt.managers.session_controller import Session
from sglang.srt.managers.tp_worker import TpModelWorker
from sglang.srt.managers.tp_worker_overlap_thread import TpModelWorkerClient
from sglang.srt.mem_cache.chunk_cache import ChunkCache
from sglang.srt.mem_cache.radix_cache import RadixCache
from sglang.srt.metrics.collector import SchedulerMetricsCollector, SchedulerStats
from sglang.srt.model_executor.forward_batch_info import ForwardMode
from sglang.srt.server_args import PortArgs, ServerArgs
from sglang.srt.utils import (
    broadcast_pyobj,
    configure_logger,
    crash_on_warnings,
    get_bool_env_var,
    get_zmq_socket,
    set_gpu_proc_affinity,
    set_random_seed,
    suppress_other_loggers,
)
from sglang.utils import get_exception_traceback

logger = logging.getLogger(__name__)

# Test retract decode
test_retract = get_bool_env_var("SGLANG_TEST_RETRACT")


class Scheduler:
    """A scheduler that manages a tensor parallel GPU worker."""

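    # A rough sketch of how this class is typically driven (assumed wiring; the
    # process launcher lives outside this file): construct one Scheduler per TP rank
    # and run one of its event loops forever, e.g.
    #
    #     scheduler = Scheduler(server_args, port_args, gpu_id=0, tp_rank=0, dp_rank=None)
    #     if scheduler.enable_overlap:
    #         scheduler.event_loop_overlap()
    #     else:
    #         scheduler.event_loop_normal()
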
    def __init__(
        self,
        server_args: ServerArgs,
        port_args: PortArgs,
        gpu_id: int,
        tp_rank: int,
        dp_rank: Optional[int],
    ):
        # Parse args
        self.server_args = server_args
        self.tp_rank = tp_rank
        self.tp_size = server_args.tp_size
        self.schedule_policy = server_args.schedule_policy
        self.disable_jump_forward = server_args.disable_jump_forward
        self.lora_paths = server_args.lora_paths
        self.max_loras_per_batch = server_args.max_loras_per_batch
        self.enable_overlap = not server_args.disable_overlap_schedule
        self.skip_tokenizer_init = server_args.skip_tokenizer_init
        self.enable_metrics = server_args.enable_metrics

        # Init inter-process communication
        context = zmq.Context(2)
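        # Rank 0 pulls tokenized requests from the tokenizer process and pushes
        # results to the tokenizer/detokenizer; other TP ranks get no-op senders
        # and receive requests via broadcast in recv_requests().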

        if self.tp_rank == 0 or self.server_args.enable_dp_attention:
            self.recv_from_tokenizer = get_zmq_socket(
                context, zmq.PULL, port_args.scheduler_input_ipc_name
            )
            self.send_to_tokenizer = get_zmq_socket(
                context, zmq.PUSH, port_args.tokenizer_ipc_name
            )

            if server_args.skip_tokenizer_init:
                # Directly send to the tokenizer/api
                self.send_to_detokenizer = get_zmq_socket(
                    context, zmq.PUSH, port_args.tokenizer_ipc_name
                )
            else:
                # Send to the detokenizer
                self.send_to_detokenizer = get_zmq_socket(
                    context, zmq.PUSH, port_args.detokenizer_ipc_name
                )
        else:
            self.recv_from_tokenizer = None
            self.send_to_tokenizer = SimpleNamespace(send_pyobj=lambda x: None)
            self.send_to_detokenizer = SimpleNamespace(send_pyobj=lambda x: None)

        # Init tokenizer
        self.model_config = ModelConfig(
            server_args.model_path,
            trust_remote_code=server_args.trust_remote_code,
            revision=server_args.revision,
            context_length=server_args.context_length,
            model_override_args=server_args.json_model_override_args,
            is_embedding=server_args.is_embedding,
            dtype=server_args.dtype,
            quantization=server_args.quantization,
        )
        self.is_generation = self.model_config.is_generation

        if server_args.skip_tokenizer_init:
            self.tokenizer = self.processor = None
        else:
            if self.model_config.is_multimodal:
                self.processor = get_processor(
                    server_args.tokenizer_path,
                    tokenizer_mode=server_args.tokenizer_mode,
                    trust_remote_code=server_args.trust_remote_code,
                )
                self.tokenizer = self.processor.tokenizer
            else:
                self.tokenizer = get_tokenizer(
                    server_args.tokenizer_path,
                    tokenizer_mode=server_args.tokenizer_mode,
                    trust_remote_code=server_args.trust_remote_code,
                )

        # Check whether overlap can be enabled
        if not self.is_generation:
            self.enable_overlap = False
            logger.info("Overlap scheduler is disabled for embedding models.")

        if self.model_config.is_multimodal:
            self.enable_overlap = False
            logger.info("Overlap scheduler is disabled for multimodal models.")

        if self.enable_overlap:
            self.disable_jump_forward = True

        # Launch a tensor parallel worker
        if self.enable_overlap:
            TpWorkerClass = TpModelWorkerClient
        else:
            TpWorkerClass = TpModelWorker

        self.tp_worker = TpWorkerClass(
            server_args=server_args,
            gpu_id=gpu_id,
            tp_rank=tp_rank,
            dp_rank=dp_rank,
            nccl_port=port_args.nccl_port,
        )

        # Get token and memory info from the model worker
        (
            self.max_total_num_tokens,
            self.max_prefill_tokens,
            self.max_running_requests,
            self.max_req_len,
            self.max_req_input_len,
            self.random_seed,
            self.device,
            worker_global_server_args_dict,
            _,
            _,
            _,
        ) = self.tp_worker.get_worker_info()
        self.tp_cpu_group = self.tp_worker.get_tp_cpu_group()
        self.pad_input_ids_func = self.tp_worker.get_pad_input_ids_func()
        global_server_args_dict.update(worker_global_server_args_dict)
        set_random_seed(self.random_seed)

        # Print debug info
        logger.info(
            f"max_total_num_tokens={self.max_total_num_tokens}, "
            f"max_prefill_tokens={self.max_prefill_tokens}, "
            f"max_running_requests={self.max_running_requests}, "
            f"context_len={self.model_config.context_len}"
        )

        # Init memory pool and cache
        self.req_to_token_pool, self.token_to_kv_pool = self.tp_worker.get_memory_pool()

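        # Use ChunkCache (no prefix reuse) when chunked prefill is enabled and the
        # radix cache is disabled; otherwise use RadixCache to share KV cache across
        # common prefixes.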
        if (
            server_args.chunked_prefill_size is not None
            and server_args.disable_radix_cache
        ):
            self.tree_cache = ChunkCache(
                req_to_token_pool=self.req_to_token_pool,
                token_to_kv_pool=self.token_to_kv_pool,
            )
        else:
            self.tree_cache = RadixCache(
                req_to_token_pool=self.req_to_token_pool,
                token_to_kv_pool=self.token_to_kv_pool,
                disable=server_args.disable_radix_cache,
            )
        self.tree_cache_metrics = {"total": 0, "hit": 0}
        self.policy = SchedulePolicy(self.schedule_policy, self.tree_cache)

        # Init running status
        self.waiting_queue: List[Req] = []
        # The running decoding batch for continuous batching
        self.running_batch: Optional[ScheduleBatch] = None
        # The current forward batch
        self.cur_batch: Optional[ScheduleBatch] = None
        # The last forward batch
        self.last_batch: Optional[ScheduleBatch] = None
        self.forward_ct = 0
        self.forward_ct_decode = 0
        self.num_generated_tokens = 0
        self.last_decode_stats_tic = time.time()
        self.stream_interval = server_args.stream_interval
        self.current_stream = torch.get_device_module(self.device).current_stream()

        # Session info
        self.sessions = {}

        # Init chunked prefill
        self.chunked_prefill_size = server_args.chunked_prefill_size
        if self.chunked_prefill_size <= 0:  # -1 means disable
            self.chunked_prefill_size = None
        self.being_chunked_req = None
        self.is_mixed_chunk = (
            self.chunked_prefill_size is not None and server_args.enable_mixed_chunk
        )

        # Init the grammar backend for constrained generation
        self.grammar_queue: List[Req] = []
        if not server_args.skip_tokenizer_init:
            if server_args.grammar_backend == "outlines":
                from sglang.srt.constrained.outlines_backend import (
                    OutlinesGrammarBackend,
                )

                self.grammar_backend = OutlinesGrammarBackend(
                    self.tokenizer,
                    whitespace_pattern=server_args.constrained_json_whitespace_pattern,
                    allow_jump_forward=not server_args.disable_jump_forward,
                )
            elif server_args.grammar_backend == "xgrammar":
                from sglang.srt.constrained.xgrammar_backend import (
                    XGrammarGrammarBackend,
                )

                self.grammar_backend = XGrammarGrammarBackend(
                    self.tokenizer, vocab_size=self.model_config.vocab_size
                )
            else:
                raise ValueError(
                    f"Invalid grammar backend: {server_args.grammar_backend}"
                )
        else:
            self.grammar_backend = None

        # Init new token estimation
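        # new_token_ratio is a heuristic estimate of how many more decode tokens each
        # running request will still need (as a fraction of its max_new_tokens); the
        # prefill adder uses it to reserve KV-cache space before admitting new requests.
        # It starts conservative and decays toward a minimum after each decode step.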
        assert (
            server_args.schedule_conservativeness >= 0
        ), "Invalid schedule_conservativeness"

        self.init_new_token_ratio = min(
            global_config.default_init_new_token_ratio
            * server_args.schedule_conservativeness,
            1.0,
        )
        self.min_new_token_ratio = min(
            self.init_new_token_ratio
            * global_config.default_min_new_token_ratio_factor,
            1.0,
        )
        self.new_token_ratio_decay = (
            self.init_new_token_ratio - self.min_new_token_ratio
        ) / global_config.default_new_token_ratio_decay_steps
        self.new_token_ratio = self.init_new_token_ratio

        # Tells whether the current running batch is full so that we can skip
        # the check of whether to prefill new requests.
        # This is an optimization to reduce the overhead of the prefill check.
        self.batch_is_full = False

        # Init watchdog thread
        self.watchdog_timeout = server_args.watchdog_timeout
        t = threading.Thread(target=self.watchdog_thread, daemon=True)
        t.start()
        self.parent_process = psutil.Process().parent()

        # Init profiler
        if os.getenv("SGLANG_TORCH_PROFILER_DIR", "") == "":
            self.profiler = None
        else:
            self.torch_profiler_trace_dir = os.getenv("SGLANG_TORCH_PROFILER_DIR")
            logger.info(
                "Profiling enabled. Traces will be saved to: %s",
                self.torch_profiler_trace_dir,
            )
            self.profiler = torch.profiler.profile(
                activities=[
                    torch.profiler.ProfilerActivity.CPU,
                    torch.profiler.ProfilerActivity.CUDA,
                ],
                with_stack=True,
            )
        # Init metrics stats
        self.stats = SchedulerStats()
        if self.enable_metrics:
            self.metrics_collector = SchedulerMetricsCollector(
                labels={
                    "model_name": self.server_args.served_model_name,
                    # TODO: Add lora name/path in the future,
                },
            )

    def watchdog_thread(self):
        """A watchdog thread that kills the server if one batch takes too long."""
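        # The watchdog polls forward_ct: if a batch is in flight and the counter has
        # not advanced within watchdog_timeout seconds, it logs an error, exits the
        # loop, and sends SIGQUIT to the parent process to bring the server down.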
        self.watchdog_last_forward_ct = 0
        self.watchdog_last_time = time.time()

        while True:
            if self.cur_batch is not None:
                if self.watchdog_last_forward_ct == self.forward_ct:
                    if time.time() > self.watchdog_last_time + self.watchdog_timeout:
                        logger.error(f"Watchdog timeout ({self.watchdog_timeout=})")
                        break
                else:
                    self.watchdog_last_forward_ct = self.forward_ct
                    self.watchdog_last_time = time.time()
            time.sleep(self.watchdog_timeout / 2)

        self.parent_process.send_signal(signal.SIGQUIT)

    @torch.no_grad()
    def event_loop_normal(self):
        """A normal scheduler loop."""
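        # Each iteration: drain incoming requests, pick the next batch (prefill takes
        # priority over decode), run it on the TP worker, then process its results
        # synchronously before the next iteration.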
        while True:
            recv_reqs = self.recv_requests()
            self.process_input_requests(recv_reqs)
            batch = self.get_next_batch_to_run()
            if self.server_args.enable_dp_attention:
                batch = self.prepare_dp_attn_batch(batch)

            self.cur_batch = batch

            if batch:
                result = self.run_batch(batch)
                self.process_batch_result(batch, result)
            else:
                # Self-check and re-init some states when the server is idle
                self.check_memory()
                self.new_token_ratio = self.init_new_token_ratio

            self.last_batch = batch

    @torch.no_grad()
    def event_loop_overlap(self):
        """A scheduler loop that overlaps the CPU processing and GPU computation."""
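        # The result of batch N is kept in result_queue and processed on the CPU while
        # batch N+1 is already running on the GPU, so CPU-side bookkeeping overlaps
        # with GPU computation at the cost of one batch of extra output latency.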
        result_queue = deque()

        while True:
            recv_reqs = self.recv_requests()
            self.process_input_requests(recv_reqs)

            batch = self.get_next_batch_to_run()
            self.cur_batch = batch
            if batch:
                result = self.run_batch(batch)
                result_queue.append((batch.copy(), result))

                if self.last_batch is None:
                    # A dummy first batch to start the pipeline for overlap scheduler.
                    # It is now used for triggering the sampling_info_done event.
                    tmp_batch = ScheduleBatch(
                        reqs=None,
                        forward_mode=ForwardMode.DUMMY_FIRST,
                        next_batch_sampling_info=self.tp_worker.cur_sampling_info,
                    )
                    self.process_batch_result(tmp_batch, None)

            if self.last_batch:
                tmp_batch, tmp_result = result_queue.popleft()
                tmp_batch.next_batch_sampling_info = (
                    self.tp_worker.cur_sampling_info if batch else None
                )
                self.process_batch_result(tmp_batch, tmp_result)
            elif batch is None:
                # Self-check and re-init some states when the server is idle
                self.check_memory()
                self.new_token_ratio = self.init_new_token_ratio

            self.last_batch = batch

    def recv_requests(self):
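        """Receive new requests on TP rank 0 (every rank when DP attention is enabled)
        and broadcast them to the other TP ranks."""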
        if self.tp_rank == 0 or self.server_args.enable_dp_attention:
            recv_reqs = []

            if self.last_batch is None:
                recv_req = self.recv_from_tokenizer.recv_pyobj()
                recv_reqs.append(recv_req)
            else:
                while True:
                    try:
                        recv_req = self.recv_from_tokenizer.recv_pyobj(zmq.NOBLOCK)
                    except zmq.ZMQError:
                        break
                    recv_reqs.append(recv_req)
        else:
            recv_reqs = None

        if self.tp_size != 1 and not self.server_args.enable_dp_attention:
            recv_reqs = broadcast_pyobj(recv_reqs, self.tp_rank, self.tp_cpu_group)
        return recv_reqs

    def process_input_requests(self, recv_reqs: List):
        for recv_req in recv_reqs:
            if isinstance(recv_req, TokenizedGenerateReqInput):
                self.handle_generate_request(recv_req)
            elif isinstance(recv_req, TokenizedEmbeddingReqInput):
                self.handle_embedding_request(recv_req)
            elif isinstance(recv_req, FlushCacheReq):
                self.flush_cache()
            elif isinstance(recv_req, AbortReq):
                self.abort_request(recv_req)
            elif isinstance(recv_req, UpdateWeightFromDiskReqInput):
                success, message = self.update_weights_from_disk(recv_req)
                self.send_to_tokenizer.send_pyobj(
                    UpdateWeightFromDiskReqOutput(success, message)
                )
            elif isinstance(recv_req, GetWeightsByNameReqInput):
                parameter = self.get_weights_by_name(recv_req)
                self.send_to_tokenizer.send_pyobj(GetWeightsByNameReqOutput(parameter))
            elif isinstance(recv_req, InitWeightsUpdateGroupReqInput):
                success, message = self.init_weights_update_group(recv_req)
                self.send_to_tokenizer.send_pyobj(
                    InitWeightsUpdateGroupReqOutput(success, message)
                )
            elif isinstance(recv_req, UpdateWeightsFromDistributedReqInput):
                success, message = self.update_weights_from_distributed(recv_req)
                self.send_to_tokenizer.send_pyobj(
                    UpdateWeightsFromDistributedReqOutput(success, message)
                )
            elif isinstance(recv_req, ProfileReq):
                if recv_req == ProfileReq.START_PROFILE:
                    self.start_profile()
                else:
                    self.stop_profile()
            elif isinstance(recv_req, OpenSessionReqInput):
                session_id = self.open_session(recv_req)
                self.send_to_tokenizer.send_pyobj(OpenSessionReqOutput(session_id))
            elif isinstance(recv_req, CloseSessionReqInput):
                self.close_session(recv_req)
            else:
                raise ValueError(f"Invalid request: {recv_req}")

    def handle_generate_request(
        self,
        recv_req: TokenizedGenerateReqInput,
    ):
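        """Convert a tokenized request into a Req and queue it for scheduling.

        Handles session continuation, multimodal token expansion, prompt truncation,
        and (optionally) grammar compilation for constrained decoding.
        """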
        # Create a new request
        if recv_req.session_id is None or recv_req.session_id not in self.sessions:
            if recv_req.input_embeds is not None:
                # Generate fake input_ids based on the length of input_embeds
                seq_length = len(recv_req.input_embeds)
                fake_input_ids = [1] * seq_length
                recv_req.input_ids = fake_input_ids

            req = Req(
                recv_req.rid,
                recv_req.input_text,
                recv_req.input_ids,
                recv_req.sampling_params,
                lora_path=recv_req.lora_path,
                input_embeds=recv_req.input_embeds,
            )
            req.tokenizer = self.tokenizer
            if recv_req.session_id is not None:
                req.finished_reason = FINISH_ABORT(
                    f"Invalid request: session id {recv_req.session_id} does not exist"
                )
                self.waiting_queue.append(req)
                return
        else:
            # Create a new request from a previous session
            session = self.sessions[recv_req.session_id]
            req = session.create_req(recv_req, self.tokenizer)
            if isinstance(req.finished_reason, FINISH_ABORT):
                self.waiting_queue.append(req)
                return
        # Handle image inputs
        if recv_req.image_inputs is not None:
            image_inputs = ImageInputs.from_dict(recv_req.image_inputs)
            # Expand a single image token into multiple dummy tokens for receiving image embeddings
            req.origin_input_ids = self.pad_input_ids_func(
                req.origin_input_ids, image_inputs
            )
            req.extend_image_inputs(image_inputs)
            if len(req.origin_input_ids) >= self.max_req_input_len:
                logger.error(
                    "Multimodal prompt is too long after expanding multimodal tokens. "
                    f"After expanding {len(req.origin_input_ids_unpadded)=} => {len(req.origin_input_ids)} >= {self.max_req_input_len}. "
                )
                req.origin_input_ids = [0]
                req.image_inputs = None
                req.sampling_params.max_new_tokens = 0
                req.finished_reason = FINISH_ABORT(
                    "Multimodal prompt is too long. Check server logs for details."
                )
                self.waiting_queue.append(req)
                return

        # Copy more attributes
        req.return_logprob = recv_req.return_logprob
        req.top_logprobs_num = recv_req.top_logprobs_num
        req.stream = recv_req.stream
        req.logprob_start_len = recv_req.logprob_start_len

        if req.logprob_start_len == -1:
            # By default, only return the logprobs for output tokens
            req.logprob_start_len = len(recv_req.input_ids) - 1

        # Truncate prompts that are too long
        if len(req.origin_input_ids) > self.max_req_input_len:
            logger.warning(
                "Request length is longer than the KV cache pool size or "
                "the max context length. Truncated!!!"
            )
            req.origin_input_ids = req.origin_input_ids[: self.max_req_input_len]
        req.sampling_params.max_new_tokens = min(
            (
                req.sampling_params.max_new_tokens
                if req.sampling_params.max_new_tokens is not None
                else 1 << 30
            ),
            self.max_req_len - len(req.origin_input_ids) - 1,
        )

        # Init grammar cache for this request
        add_to_grammar_queue = False
        if (
            req.sampling_params.json_schema is not None
            or req.sampling_params.regex is not None
        ):
            assert self.grammar_backend is not None
            if req.sampling_params.json_schema is not None:
                key = ("json", req.sampling_params.json_schema)
            elif req.sampling_params.regex is not None:
                key = ("regex", req.sampling_params.regex)

            req.grammar = self.grammar_backend.get_cached_value(key)
            if not req.grammar:
                req.grammar = self.grammar_backend.get_future_value(key)
                add_to_grammar_queue = True

        if add_to_grammar_queue:
            self.grammar_queue.append(req)
        else:
            self.waiting_queue.append(req)

    def handle_embedding_request(
        self,
        recv_req: TokenizedEmbeddingReqInput,
    ):
        req = Req(
            recv_req.rid,
            recv_req.input_text,
            recv_req.input_ids,
            recv_req.sampling_params,
        )
        req.tokenizer = self.tokenizer

        # Truncate prompts that are too long
        if len(req.origin_input_ids) >= self.max_req_input_len:
            logger.warning(
                "Request length is longer than the KV cache pool size or "
                "the max context length. Truncated!!!"
            )
            req.origin_input_ids = req.origin_input_ids[: self.max_req_input_len]

        self.waiting_queue.append(req)

    def log_prefill_stats(self, adder, can_run_list, running_bs, has_being_chunked):
        if isinstance(self.tree_cache, RadixCache):
            self.tree_cache_metrics["total"] += (
                adder.log_input_tokens + adder.log_hit_tokens
            ) / 10**9
            self.tree_cache_metrics["hit"] += (adder.log_hit_tokens) / 10**9
            tree_cache_hit_rate = (
                self.tree_cache_metrics["hit"] / self.tree_cache_metrics["total"]
            )
        else:
            tree_cache_hit_rate = 0.0

        num_used = self.max_total_num_tokens - (
            self.token_to_kv_pool.available_size() + self.tree_cache.evictable_size()
        )

        logger.info(
            f"Prefill batch. "
            f"#new-seq: {len(can_run_list)}, "
            f"#new-token: {adder.log_input_tokens}, "
            f"#cached-token: {adder.log_hit_tokens}, "
            f"cache hit rate: {100.0 * tree_cache_hit_rate:.2f}%, "
            f"token usage: {num_used / self.max_total_num_tokens:.2f}, "
            f"#running-req: {running_bs}, "
            f"#queue-req: {len(self.waiting_queue) + has_being_chunked}"
        )

        if self.enable_metrics:
            self.stats.num_running_reqs = running_bs
            self.stats.num_used_tokens = num_used
            self.stats.token_usage = round(num_used / self.max_total_num_tokens, 2)
            self.stats.num_queue_reqs = len(self.waiting_queue) + has_being_chunked
            self.stats.cache_hit_rate = tree_cache_hit_rate
            self.metrics_collector.log_stats(self.stats)

    def log_decode_stats(self):
        num_used = self.max_total_num_tokens - (
            self.token_to_kv_pool.available_size() + self.tree_cache.evictable_size()
        )
        gen_throughput = self.num_generated_tokens / (
            time.time() - self.last_decode_stats_tic
        )
        self.num_generated_tokens = 0
        self.last_decode_stats_tic = time.time()
        num_running_reqs = len(self.running_batch.reqs) if self.running_batch else 0
        logger.info(
            f"Decode batch. "
            f"#running-req: {num_running_reqs}, "
            f"#token: {num_used}, "
            f"token usage: {num_used / self.max_total_num_tokens:.2f}, "
            f"gen throughput (token/s): {gen_throughput:.2f}, "
            f"#queue-req: {len(self.waiting_queue)}"
        )

        if self.enable_metrics:
            self.stats.num_running_reqs = num_running_reqs
            self.stats.num_used_tokens = num_used
            self.stats.token_usage = num_used / self.max_total_num_tokens
            self.stats.gen_throughput = gen_throughput
            self.stats.num_queue_reqs = len(self.waiting_queue)
            self.metrics_collector.log_stats(self.stats)

    def check_memory(self):
        available_size = (
            self.token_to_kv_pool.available_size() + self.tree_cache.evictable_size()
        )
        if available_size != self.max_total_num_tokens:
            msg = (
                "KV cache pool leak detected!\n"
                f"{available_size=}, {self.max_total_num_tokens=}\n"
            )
            warnings.warn(msg)
            if crash_on_warnings():
                raise ValueError(msg)

        if len(self.req_to_token_pool.free_slots) != self.req_to_token_pool.size:
            msg = (
                "Memory pool leak detected!\n"
                f"available_size={len(self.req_to_token_pool.free_slots)}, "
                f"total_size={self.req_to_token_pool.size}\n"
            )
            warnings.warn(msg)
            if crash_on_warnings():
                raise ValueError(msg)

    def get_next_batch_to_run(self):
        # Merge the prefill batch into the running batch
        if self.last_batch and self.last_batch.forward_mode.is_extend():
            if self.being_chunked_req:
                # Move the chunked request out of the batch
                self.last_batch.filter_batch(being_chunked_req=self.being_chunked_req)
                self.tree_cache.cache_unfinished_req(self.being_chunked_req)
                # being chunked request keeps its rid but will get a new req_pool_idx
                self.req_to_token_pool.free(self.being_chunked_req.req_pool_idx)
                self.batch_is_full = False
            if not self.last_batch.is_empty():
                if self.running_batch is None:
                    self.running_batch = self.last_batch
                else:
                    self.running_batch.merge_batch(self.last_batch)
        # Run prefill first if possible
        new_batch = self.get_new_batch_prefill()
        if new_batch is not None:
            return new_batch
        # Run decode
        if self.running_batch is None:
            return None
        self.running_batch = self.update_running_batch(self.running_batch)
        return self.running_batch

    def get_new_batch_prefill(self) -> Optional[ScheduleBatch]:
        # Check if the grammar is ready in the grammar queue
        if self.grammar_queue:
            self.move_ready_grammar_requests()
        # Handle the cases where prefill is not allowed
        if (
            self.batch_is_full or len(self.waiting_queue) == 0
        ) and self.being_chunked_req is None:
            return None

        running_bs = len(self.running_batch.reqs) if self.running_batch else 0
        if running_bs >= self.max_running_requests:
            self.batch_is_full = True
            return None

        # Get priority queue
        prefix_computed = self.policy.calc_priority(self.waiting_queue)

        # Prefill policy
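        # PrefillAdder decides which waiting requests can be admitted, roughly
        # budgeting against the available KV-cache tokens (with new_token_ratio
        # reserving room for running requests), max_prefill_tokens per forward pass,
        # and the chunked prefill size.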
        adder = PrefillAdder(
            self.tree_cache,
            self.running_batch,
            self.new_token_ratio,
            self.token_to_kv_pool.available_size() + self.tree_cache.evictable_size(),
            self.max_prefill_tokens,
            self.chunked_prefill_size,
            running_bs if self.is_mixed_chunk else 0,
        )

        has_being_chunked = self.being_chunked_req is not None
        if has_being_chunked:
            self.being_chunked_req.init_next_round_input()
            self.being_chunked_req = adder.add_being_chunked_req(self.being_chunked_req)
        if self.lora_paths:
            lora_set = (
                set([req.lora_path for req in self.running_batch.reqs])
                if self.running_batch is not None
                else set([])
            )

        # Get requests from the waiting queue to a new prefill batch
        for req in self.waiting_queue:
            if (
                self.lora_paths
                and len(
                    lora_set
                    | set([req.lora_path for req in adder.can_run_list])
                    | set([req.lora_path])
                )
                > self.max_loras_per_batch
            ):
                self.batch_is_full = True
                break

            if running_bs + len(adder.can_run_list) >= self.max_running_requests:
                self.batch_is_full = True
                break
            req.init_next_round_input(None if prefix_computed else self.tree_cache)
            res = adder.add_one_req(req)
            if res != AddReqResult.CONTINUE:
                if res == AddReqResult.NO_TOKEN:
                    self.batch_is_full = True
                break

        # Update waiting queue
        can_run_list = adder.can_run_list
        if len(can_run_list) == 0:
            return None
        self.waiting_queue = [
            x for x in self.waiting_queue if x not in set(can_run_list)
        ]
        if adder.new_being_chunked_req is not None:
            assert self.being_chunked_req is None
            self.being_chunked_req = adder.new_being_chunked_req
        if self.being_chunked_req:
            self.being_chunked_req.is_being_chunked += 1
        # Print stats
        if self.tp_rank == 0:
            self.log_prefill_stats(adder, can_run_list, running_bs, has_being_chunked)
        # Create a new batch
        new_batch = ScheduleBatch.init_new(
            can_run_list,
            self.req_to_token_pool,
            self.token_to_kv_pool,
            self.tree_cache,
            self.model_config,
            self.enable_overlap,
        )
        new_batch.prepare_for_extend()
        # Mixed-style chunked prefill
        if (
            self.is_mixed_chunk
            and self.running_batch is not None
            and not (new_batch.return_logprob or self.running_batch.return_logprob)
        ):
            # TODO (lianmin): support return_logprob + mixed chunked prefill
            self.running_batch.filter_batch()
            if not self.running_batch.is_empty():
                self.running_batch.prepare_for_decode()
                new_batch.mix_with_running(self.running_batch)
                new_batch.decoding_reqs = self.running_batch.reqs
            self.running_batch = None
        else:
            new_batch.decoding_reqs = None

        return new_batch

    def update_running_batch(self, batch: ScheduleBatch) -> Optional[ScheduleBatch]:
        """Update the current running decoding batch."""
        global test_retract

        initial_bs = batch.batch_size()
        batch.filter_batch()
        if batch.is_empty():
            self.batch_is_full = False
            return None
        # Check if decode out of memory
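        # If the KV cache cannot be guaranteed to hold one more token per running
        # request, retract some requests back to the waiting queue and raise
        # new_token_ratio so scheduling becomes more conservative; otherwise let the
        # ratio decay back toward its minimum.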
        if not batch.check_decode_mem() or (test_retract and batch.batch_size() > 10):
            old_ratio = self.new_token_ratio

            retracted_reqs, new_token_ratio = batch.retract_decode()
            self.new_token_ratio = new_token_ratio
            logger.info(
                "Decode out of memory happened. "
                f"#retracted_reqs: {len(retracted_reqs)}, "
                f"#new_token_ratio: {old_ratio:.4f} -> {self.new_token_ratio:.4f}"
            )
            self.waiting_queue.extend(retracted_reqs)
        else:
            self.new_token_ratio = max(
                self.new_token_ratio - self.new_token_ratio_decay,
                self.min_new_token_ratio,
            )

        # Check for jump-forward
        if not self.disable_jump_forward:
            jump_forward_reqs = batch.check_for_jump_forward(self.pad_input_ids_func)
            self.waiting_queue.extend(jump_forward_reqs)
            if batch.is_empty():
                self.batch_is_full = False
                return None

        if batch.batch_size() < initial_bs:
            self.batch_is_full = False

        # Update batch tensors
        batch.prepare_for_decode()
        return batch

    def run_batch(self, batch: ScheduleBatch):
        """Run a batch."""
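        # Generation models return (logits_output, next_token_ids, bid); embedding and
        # reward models return (embeddings, bid). The bid identifies the batch so the
        # overlap worker can resolve the real results later via resolve_batch_result().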
        self.forward_ct += 1

        if self.is_generation:
            model_worker_batch = batch.get_model_worker_batch()
            if batch.forward_mode.is_decode() or batch.extend_num_tokens != 0:
                logits_output, next_token_ids = self.tp_worker.forward_batch_generation(
                    model_worker_batch
                )
            elif batch.forward_mode.is_idle():
                model_worker_batch = batch.get_model_worker_batch()
                self.tp_worker.forward_batch_idle(model_worker_batch)
                return
            else:
                logits_output = None
                if self.skip_tokenizer_init:
                    next_token_ids = torch.full(
                        (batch.batch_size(),), self.tokenizer.eos_token_id
                    )
                else:
                    next_token_ids = torch.full((batch.batch_size(),), 0)
            batch.output_ids = next_token_ids
            ret = logits_output, next_token_ids, model_worker_batch.bid
        else:  # embedding or reward model
            assert batch.extend_num_tokens != 0
            model_worker_batch = batch.get_model_worker_batch()
            embeddings = self.tp_worker.forward_batch_embedding(model_worker_batch)
            ret = embeddings, model_worker_batch.bid
        return ret

    def process_batch_result(self, batch: ScheduleBatch, result):
        if batch.forward_mode.is_decode():
            self.process_batch_result_decode(batch, result)
            if batch.is_empty():
                self.running_batch = None
        elif batch.forward_mode.is_extend():
            self.process_batch_result_prefill(batch, result)
        elif batch.forward_mode.is_dummy_first():
            batch.next_batch_sampling_info.update_regex_vocab_mask()
            self.current_stream.synchronize()
            batch.next_batch_sampling_info.sampling_info_done.set()

    def process_batch_result_prefill(self, batch: ScheduleBatch, result):
        if self.is_generation:
            logits_output, next_token_ids, bid = result

            if self.enable_overlap:
                logits_output, next_token_ids = self.tp_worker.resolve_batch_result(bid)
            else:
                # Move next_token_ids and logprobs to cpu
                if batch.return_logprob:
                    logits_output.next_token_logprobs = (
                        logits_output.next_token_logprobs[
                            torch.arange(len(next_token_ids), device=self.device),
                            next_token_ids,
                        ].tolist()
                    )
                    logits_output.input_token_logprobs = (
                        logits_output.input_token_logprobs.tolist()
                    )
                    logits_output.normalized_prompt_logprobs = (
                        logits_output.normalized_prompt_logprobs.tolist()
                    )
                next_token_ids = next_token_ids.tolist()

            # Check finish conditions
            logprob_pt = 0
            for i, (req, next_token_id) in enumerate(zip(batch.reqs, next_token_ids)):
                if req.is_retracted:
                    continue

                if self.is_mixed_chunk and self.enable_overlap and req.finished():
                    # Free the one delayed token for the mixed decode batch
                    j = len(batch.out_cache_loc) - len(batch.reqs) + i
                    self.token_to_kv_pool.free(batch.out_cache_loc[j : j + 1])
                    continue
                if req.is_being_chunked <= 0:
                    req.completion_tokens_wo_jump_forward += 1
                    req.output_ids.append(next_token_id)
                    req.check_finished()

                    if req.finished():
                        self.tree_cache.cache_finished_req(req)
                    elif not batch.decoding_reqs or req not in batch.decoding_reqs:
                        self.tree_cache.cache_unfinished_req(req)

                    if req.return_logprob:
                        logprob_pt += self.add_logprob_return_values(
                            i, req, logprob_pt, next_token_ids, logits_output
                        )

                    if req.grammar is not None:
                        req.grammar.accept_token(next_token_id)
                        req.grammar.finished = req.finished()
                else:
                    # being chunked reqs' prefill is not finished
                    req.is_being_chunked -= 1

            if batch.next_batch_sampling_info:
                batch.next_batch_sampling_info.update_regex_vocab_mask()
                self.current_stream.synchronize()
                batch.next_batch_sampling_info.sampling_info_done.set()

        else:  # embedding or reward model
            embeddings, bid = result
            embeddings = embeddings.tolist()

            # Check finish conditions
            for i, req in enumerate(batch.reqs):
                if req.is_retracted:
                    continue

                req.embedding = embeddings[i]
                if req.is_being_chunked <= 0:
                    # Dummy output token for embedding models
                    req.output_ids.append(0)
                    req.check_finished()

                    if req.finished():
                        self.tree_cache.cache_finished_req(req)
                    else:
                        self.tree_cache.cache_unfinished_req(req)
                else:
                    # being chunked reqs' prefill is not finished
                    req.is_being_chunked -= 1
        self.stream_output(batch.reqs)

    def process_batch_result_decode(self, batch: ScheduleBatch, result):
        logits_output, next_token_ids, bid = result
        self.num_generated_tokens += len(batch.reqs)

        if self.enable_overlap:
            logits_output, next_token_ids = self.tp_worker.resolve_batch_result(bid)
            next_token_logprobs = logits_output.next_token_logprobs
        else:
            # Move next_token_ids and logprobs to cpu
            if batch.return_logprob:
                next_token_logprobs = logits_output.next_token_logprobs[
                    torch.arange(len(next_token_ids), device=self.device),
                    next_token_ids,
                ].tolist()
            next_token_ids = next_token_ids.tolist()
        self.token_to_kv_pool.free_group_begin()

        # Check finish condition
        for i, (req, next_token_id) in enumerate(zip(batch.reqs, next_token_ids)):
            if req.is_retracted:
                continue

            if self.enable_overlap and req.finished():
                # Free the one delayed token
                self.token_to_kv_pool.free(batch.out_cache_loc[i : i + 1])
                continue

            req.completion_tokens_wo_jump_forward += 1
            req.output_ids.append(next_token_id)
            req.check_finished()

            if req.finished():
                self.tree_cache.cache_finished_req(req)

            if req.return_logprob:
                req.output_token_logprobs.append(
                    (next_token_logprobs[i], next_token_id)
                )
                if req.top_logprobs_num > 0:
                    req.output_top_logprobs.append(logits_output.output_top_logprobs[i])

            if req.grammar is not None:
                req.grammar.accept_token(next_token_id)
                req.grammar.finished = req.finished()
        if batch.next_batch_sampling_info:
            batch.next_batch_sampling_info.update_regex_vocab_mask()
            self.current_stream.synchronize()
            batch.next_batch_sampling_info.sampling_info_done.set()

        self.stream_output(batch.reqs)
        self.token_to_kv_pool.free_group_end()

        self.forward_ct_decode = (self.forward_ct_decode + 1) % (1 << 30)
        if (
            self.tp_rank == 0
            and self.forward_ct_decode % self.server_args.decode_log_interval == 0
        ):
            self.log_decode_stats()

    def add_logprob_return_values(
        self,
        i: int,
        req: Req,
        pt: int,
        next_token_ids: List[int],
        output: LogitsProcessorOutput,
    ):
        """Attach logprobs to the return values."""
        req.output_token_logprobs.append(
            (output.next_token_logprobs[i], next_token_ids[i])
        )

        # If logprob_start_len > 0, then first logprob_start_len prompt tokens will be ignored.
        num_input_logprobs = req.extend_input_len - req.extend_logprob_start_len

        if req.normalized_prompt_logprob is None:
            req.normalized_prompt_logprob = output.normalized_prompt_logprobs[i]

        if req.input_token_logprobs is None:
            input_token_logprobs = output.input_token_logprobs[
                pt : pt + num_input_logprobs - 1 - req.last_update_decode_tokens
            ]
            input_token_ids = req.fill_ids[
                len(req.fill_ids)
                - num_input_logprobs
                + 1 : len(req.fill_ids)
                - req.last_update_decode_tokens
            ]

            # Clip the padded hash values from image tokens.
            # Otherwise, it will lead to detokenization errors.
            input_token_ids = [
                x if x < self.model_config.vocab_size - 1 else 0
                for x in input_token_ids
            ]

            req.input_token_logprobs = list(zip(input_token_logprobs, input_token_ids))

            if (
                req.logprob_start_len == 0
            ):  # The first token does not have logprob, pad it.
                req.input_token_logprobs = [
                    (None, req.fill_ids[0])
                ] + req.input_token_logprobs

        if req.last_update_decode_tokens != 0:
            # Some decode tokens are re-computed in an extend batch
            req.output_token_logprobs.extend(
                list(
                    zip(
                        output.input_token_logprobs[
                            pt
                            + num_input_logprobs
                            - 1
                            - req.last_update_decode_tokens : pt
                            + num_input_logprobs
                            - 1
                        ],
                        req.fill_ids[
                            len(req.fill_ids)
                            - req.last_update_decode_tokens : len(req.fill_ids)
                        ],
                    )
                )
            )

        if req.top_logprobs_num > 0:
            if req.input_top_logprobs is None:
                req.input_top_logprobs = output.input_top_logprobs[i]
                if req.logprob_start_len == 0:
                    req.input_top_logprobs = [None] + req.input_top_logprobs

            if req.last_update_decode_tokens != 0:
                req.output_top_logprobs.extend(
                    output.input_top_logprobs[i][-req.last_update_decode_tokens :]
                )
            req.output_top_logprobs.append(output.output_top_logprobs[i])

        return num_input_logprobs

    def stream_output(self, reqs: List[Req]):
        """Stream the output to detokenizer."""
        output_rids = []
        output_meta_info: List[dict] = []
        output_finished_reason: List[BaseFinishReason] = []
        if self.is_generation:
            output_vids = []
            decoded_texts = []
            output_read_ids = []
            output_read_offsets = []
            output_ids = []
            output_skip_special_tokens = []
            output_spaces_between_special_tokens = []
            output_no_stop_trim = []
        else:  # embedding or reward model
            output_embeddings = []

        is_stream_iter = self.forward_ct_decode % self.stream_interval == 0
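        # Partial output of streaming requests is only flushed every
        # `stream_interval` decode steps (or on the first generated token) to limit
        # traffic to the detokenizer.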

        for req in reqs:
            # TODO(lianmin): revisit this for overlap + retract + stream
            if req.finished() or (
                req.stream and (is_stream_iter or len(req.output_ids) == 1)
            ):
                output_rids.append(req.rid)
                output_finished_reason.append(req.finished_reason)
                if self.is_generation:
                    output_vids.append(req.vid)
                    decoded_texts.append(req.decoded_text)
                    read_ids, read_offset = req.init_incremental_detokenize()
                    output_read_ids.append(read_ids)
                    output_read_offsets.append(read_offset)
                    if self.skip_tokenizer_init:
                        output_ids.append(req.output_ids)
                    output_skip_special_tokens.append(
                        req.sampling_params.skip_special_tokens
                    )
                    output_spaces_between_special_tokens.append(
                        req.sampling_params.spaces_between_special_tokens
                    )
                    output_no_stop_trim.append(req.sampling_params.no_stop_trim)

                    meta_info = {
                        "prompt_tokens": len(req.origin_input_ids),
                        "completion_tokens": len(req.output_ids),
                        "completion_tokens_wo_jump_forward": req.completion_tokens_wo_jump_forward,
                        "cached_tokens": req.cached_tokens,
                        "finish_reason": (
                            req.finished_reason.to_json()
                            if req.finished_reason is not None
                            else None
                        ),
                    }
                    if req.return_logprob:
                        (
                            meta_info["input_token_logprobs"],
                            meta_info["output_token_logprobs"],
                            meta_info["input_top_logprobs"],
                            meta_info["output_top_logprobs"],
                            meta_info["normalized_prompt_logprob"],
                        ) = (
                            req.input_token_logprobs,
                            req.output_token_logprobs,
                            req.input_top_logprobs,
                            req.output_top_logprobs,
                            req.normalized_prompt_logprob,
                        )
                    output_meta_info.append(meta_info)
                else:  # embedding or reward model
                    output_embeddings.append(req.embedding)
                    meta_info = {
                        "prompt_tokens": len(req.origin_input_ids),
                    }
                    output_meta_info.append(meta_info)

        # Send to detokenizer
        if output_rids:
            if self.is_generation:
                self.send_to_detokenizer.send_pyobj(
                    BatchTokenIDOut(
                        output_rids,
                        output_vids,
                        decoded_texts,
                        output_read_ids,
                        output_read_offsets,
                        output_ids,
                        output_skip_special_tokens,
                        output_spaces_between_special_tokens,
                        output_meta_info,
                        output_finished_reason,
                        output_no_stop_trim,
                    )
                )
            else:  # embedding or reward model
                self.send_to_detokenizer.send_pyobj(
                    BatchEmbeddingOut(
                        output_rids,
                        output_embeddings,
                        output_meta_info,
                        output_finished_reason,
                    )
                )

    def prepare_dp_attn_batch(self, local_batch: ScheduleBatch):
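        """Synchronize this rank's batch with the other DP attention workers.

        Per-rank token counts are gathered over the CPU group; a rank with no work
        runs an idle batch so that all ranks join the same forward pass. When CUDA
        graphs are enabled, a MIN all-reduce confirms every rank is in decode (or
        idle) mode before marking the batch as eligible for the DP CUDA graph.
        """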
        # Check if other DP workers have running batches
        if local_batch is None:
            num_tokens = 0
        elif local_batch.forward_mode.is_decode():
            num_tokens = local_batch.batch_size()
        else:
            num_tokens = local_batch.extend_num_tokens

        local_num_tokens = torch.tensor([num_tokens], dtype=torch.int64)
        global_num_tokens = torch.empty(self.tp_size, dtype=torch.int64)
        torch.distributed.all_gather_into_tensor(
            global_num_tokens,
            local_num_tokens,
            group=self.tp_cpu_group,
        )

        if local_batch is None and global_num_tokens.max().item() > 0:
            local_batch = self.get_idle_batch()

        if local_batch is not None:
            local_batch.global_num_tokens = global_num_tokens.tolist()

            # Check forward mode for cuda graph
            if not self.server_args.disable_cuda_graph:
                forward_mode_state = torch.tensor(
                    (
                        1
                        if local_batch.forward_mode.is_decode()
                        or local_batch.forward_mode.is_idle()
                        else 0
                    ),
                    dtype=torch.int32,
                )
                torch.distributed.all_reduce(
                    forward_mode_state,
                    op=torch.distributed.ReduceOp.MIN,
                    group=self.tp_cpu_group,
                )
                local_batch.can_run_dp_cuda_graph = forward_mode_state.item() == 1

        return local_batch

    def get_idle_batch(self):
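        """Build an empty idle-mode batch so this rank can still take part in the
        synchronized forward pass when other DP workers have requests to run."""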
        idle_batch = ScheduleBatch.init_new(
            [],
            self.req_to_token_pool,
            self.token_to_kv_pool,
            self.tree_cache,
            self.model_config,
            self.enable_overlap,
        )
        idle_batch.prepare_for_idle()
        return idle_batch

    def move_ready_grammar_requests(self):
        """Move requests whose grammar objects are ready from grammar_queue to waiting_queue."""
        num_ready_reqs = 0
        for req in self.grammar_queue:
            try:
                req.grammar = req.grammar.result(timeout=0.05)
                num_ready_reqs += 1
            except futures._base.TimeoutError:
                break

        if self.tp_size > 1:
            # Sync across TP ranks to make sure they have the same number of ready requests
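            # Ranks that are behind block on .result() until they catch up to the
            # maximum, so every TP rank dequeues an identical set of requests.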
            tensor = torch.tensor(num_ready_reqs, dtype=torch.int32)
            torch.distributed.all_reduce(
                tensor, op=torch.distributed.ReduceOp.MAX, group=self.tp_cpu_group
            )
            num_ready_reqs_max = tensor.item()
            for i in range(num_ready_reqs, num_ready_reqs_max):
                self.grammar_queue[i].grammar = self.grammar_queue[i].grammar.result()
            num_ready_reqs = num_ready_reqs_max

        self.waiting_queue.extend(self.grammar_queue[:num_ready_reqs])
        self.grammar_queue = self.grammar_queue[num_ready_reqs:]

    def flush_cache(self):
        """Flush the memory pool and cache."""
        if len(self.waiting_queue) == 0 and (
            self.running_batch is None or len(self.running_batch.reqs) == 0
        ):
            self.tree_cache.reset()
            self.tree_cache_metrics = {"total": 0, "hit": 0}
            if self.grammar_backend:
                self.grammar_backend.reset()
            self.req_to_token_pool.clear()
            self.token_to_kv_pool.clear()
            torch.cuda.empty_cache()
            logger.info("Cache flushed successfully!")
            if_success = True
        else:
            logger.warning(
                f"Cache not flushed because there are pending requests. "
                f"#queue-req: {len(self.waiting_queue)}, "
                f"#running-req: {0 if self.running_batch is None else len(self.running_batch.reqs)}"
            )
            if_success = False
        return if_success

    def abort_request(self, recv_req: AbortReq):
        # Delete requests in the waiting queue
        to_del = None
        for i, req in enumerate(self.waiting_queue):
            if req.rid == recv_req.rid:
                to_del = i
                break

        if to_del is not None:
            del self.waiting_queue[to_del]
            logger.debug(f"Abort queued request. {req.rid=}")
            return

        # Delete requests in the running batch
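        # Running requests are not removed immediately; setting `to_abort` lets them
        # be finished through the normal completion path on a later iteration.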
        if self.running_batch:
            for req in self.running_batch.reqs:
                if req.rid == recv_req.rid and not req.finished():
                    logger.debug(f"Abort running request. {req.rid=}")
                    req.to_abort = True
                    break

    def update_weights_from_disk(self, recv_req: UpdateWeightFromDiskReqInput):
        """In-place update of the weights from disk."""
        success, message = self.tp_worker.update_weights_from_disk(recv_req)
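        # The radix/KV cache was populated under the old weights, so it must be
        # flushed before serving further requests with the new ones.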
        if success:
            flush_cache_success = self.flush_cache()
            assert flush_cache_success, "Cache flush failed after updating weights"
        else:
            logger.error(message)
        return success, message

    def init_weights_update_group(self, recv_req: InitWeightsUpdateGroupReqInput):
        """Initialize the online model parameter update group."""
        success, message = self.tp_worker.init_weights_update_group(recv_req)
        return success, message

    def update_weights_from_distributed(
        self, recv_req: UpdateWeightsFromDistributedReqInput
    ):
        """Update the online model parameter."""
        success, message = self.tp_worker.update_weights_from_distributed(recv_req)
        if success:
            flush_cache_success = self.flush_cache()
            assert flush_cache_success, "Cache flush failed after updating weights"
        else:
            logger.error(message)
        return success, message

    def get_weights_by_name(self, recv_req: GetWeightsByNameReqInput):
        parameter = self.tp_worker.get_weights_by_name(recv_req)
        return parameter

    def start_profile(self) -> None:
        if self.profiler is None:
            raise RuntimeError("Profiler is not enabled.")
        self.profiler.start()

    def stop_profile(self) -> None:
        if self.profiler is None:
            raise RuntimeError("Profiler is not enabled.")
        self.profiler.stop()
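        # Export the trace in Chrome trace format; the .trace.json.gz file can be
        # inspected in chrome://tracing or Perfetto.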
        self.profiler.export_chrome_trace(
            self.torch_profiler_trace_dir + "/" + str(time.time()) + ".trace.json.gz"
        )
        logger.info("Profiler is done")

    def open_session(self, recv_req: OpenSessionReqInput) -> str:
        # Handle the error of a duplicate session id.
        session_id = recv_req.session_id
        if session_id in self.sessions:
            logger.warning(f"session id {session_id} already exists, cannot open.")
        else:
            self.sessions[session_id] = Session(
                recv_req.capacity_of_str_len, session_id
            )
        return session_id

    def close_session(self, recv_req: CloseSessionReqInput):
        # Handle the error of an unknown session id.
        session_id = recv_req.session_id
        if session_id not in self.sessions:
            logger.warning(f"session id {session_id} does not exist, cannot delete.")
        else:
            del self.sessions[session_id]


def run_scheduler_process(
    server_args: ServerArgs,
    port_args: PortArgs,
    gpu_id: int,
    tp_rank: int,
    dp_rank: Optional[int],
    pipe_writer,
):
    setproctitle.setproctitle("sglang::scheduler")

    # [For Router] If the env var "SGLANG_DP_RANK" exists, use it to set dp_rank.
    if dp_rank is None and "SGLANG_DP_RANK" in os.environ:
        dp_rank = int(os.environ["SGLANG_DP_RANK"])

    if dp_rank is None:
        configure_logger(server_args, prefix=f" TP{tp_rank}")
    else:
        configure_logger(server_args, prefix=f" DP{dp_rank} TP{tp_rank}")

    # Set CPU affinity for this GPU process.
    if get_bool_env_var("SGLANG_SET_CPU_AFFINITY"):
        set_gpu_proc_affinity(server_args.tp_size, server_args.nnodes, gpu_id)

    suppress_other_loggers()
    parent_process = psutil.Process().parent()

    try:
        scheduler = Scheduler(server_args, port_args, gpu_id, tp_rank, dp_rank)
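        # Report readiness and the derived KV-cache capacity back to the launcher
        # process.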
        pipe_writer.send(
            {"status": "ready", "max_total_num_tokens": scheduler.max_total_num_tokens}
        )
        if scheduler.enable_overlap:
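            # The overlap loop schedules the next batch on the CPU while the GPU
            # worker (TpModelWorkerClient) executes the current one.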
            scheduler.event_loop_overlap()
        else:
            scheduler.event_loop_normal()
    except Exception:
        traceback = get_exception_traceback()
        logger.error(f"Scheduler hit an exception: {traceback}")
        parent_process.send_signal(signal.SIGQUIT)