# Copyright 2023-2024 SGLang Team
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""A scheduler that manages a tensor parallel GPU worker."""

import logging
import os
import signal
import threading
import time
import warnings
from collections import deque
from concurrent import futures
from types import SimpleNamespace
from typing import Dict, List, Optional, Tuple

import psutil
import setproctitle
import torch
import zmq

from sglang.global_config import global_config
from sglang.srt.configs.model_config import ModelConfig
from sglang.srt.hf_transformers_utils import get_processor, get_tokenizer
from sglang.srt.layers.logits_processor import LogitsProcessorOutput
from sglang.srt.managers.io_struct import (
    AbortReq,
    BatchEmbeddingOut,
    BatchTokenIDOut,
    CloseSessionReqInput,
    FlushCacheReq,
    GetWeightsByNameReqInput,
    GetWeightsByNameReqOutput,
    InitWeightsUpdateGroupReqInput,
    InitWeightsUpdateGroupReqOutput,
    OpenSessionReqInput,
    OpenSessionReqOutput,
    ProfileReq,
    TokenizedEmbeddingReqInput,
    TokenizedGenerateReqInput,
    UpdateWeightFromDiskReqInput,
    UpdateWeightFromDiskReqOutput,
    UpdateWeightsFromDistributedReqInput,
    UpdateWeightsFromDistributedReqOutput,
    UpdateWeightsFromTensorReqInput,
    UpdateWeightsFromTensorReqOutput,
)
from sglang.srt.managers.schedule_batch import (
    FINISH_ABORT,
    BaseFinishReason,
    ImageInputs,
    Req,
    ScheduleBatch,
    global_server_args_dict,
)
from sglang.srt.managers.schedule_policy import (
    AddReqResult,
    PrefillAdder,
    SchedulePolicy,
)
from sglang.srt.managers.session_controller import Session
from sglang.srt.managers.tp_worker import TpModelWorker
from sglang.srt.managers.tp_worker_overlap_thread import TpModelWorkerClient
from sglang.srt.mem_cache.chunk_cache import ChunkCache
from sglang.srt.mem_cache.radix_cache import RadixCache
from sglang.srt.metrics.collector import SchedulerMetricsCollector, SchedulerStats
from sglang.srt.model_executor.forward_batch_info import ForwardMode
from sglang.srt.server_args import PortArgs, ServerArgs
from sglang.srt.utils import (
    broadcast_pyobj,
    configure_logger,
    crash_on_warnings,
    get_bool_env_var,
    get_zmq_socket,
    set_gpu_proc_affinity,
    set_random_seed,
    suppress_other_loggers,
)
from sglang.utils import get_exception_traceback

logger = logging.getLogger(__name__)

# Test retract decode for debugging purposes
test_retract = get_bool_env_var("SGLANG_TEST_RETRACT")


class Scheduler:
    """A scheduler that manages a tensor parallel GPU worker."""

    def __init__(
        self,
        server_args: ServerArgs,
        port_args: PortArgs,
        gpu_id: int,
        tp_rank: int,
        dp_rank: Optional[int],
    ):
        # Parse args
        self.server_args = server_args
        self.tp_rank = tp_rank
        self.tp_size = server_args.tp_size
        self.schedule_policy = server_args.schedule_policy
        self.disable_jump_forward = server_args.disable_jump_forward
        self.lora_paths = server_args.lora_paths
        self.max_loras_per_batch = server_args.max_loras_per_batch
        self.enable_overlap = not server_args.disable_overlap_schedule
        self.skip_tokenizer_init = server_args.skip_tokenizer_init
        self.enable_metrics = server_args.enable_metrics

        # Init inter-process communication
        context = zmq.Context(2)

        if self.tp_rank == 0 or self.server_args.enable_dp_attention:
            self.recv_from_tokenizer = get_zmq_socket(
                context, zmq.PULL, port_args.scheduler_input_ipc_name
            )
            self.send_to_tokenizer = get_zmq_socket(
                context, zmq.PUSH, port_args.tokenizer_ipc_name
            )

            if server_args.skip_tokenizer_init:
                # Directly send to the TokenizerManager
                self.send_to_detokenizer = get_zmq_socket(
                    context, zmq.PUSH, port_args.tokenizer_ipc_name
                )
            else:
                # Send to the DetokenizerManager
                self.send_to_detokenizer = get_zmq_socket(
                    context, zmq.PUSH, port_args.detokenizer_ipc_name
                )
        else:
            self.recv_from_tokenizer = None
            self.send_to_tokenizer = SimpleNamespace(send_pyobj=lambda x: None)
            self.send_to_detokenizer = SimpleNamespace(send_pyobj=lambda x: None)

        # Init tokenizer
        self.model_config = ModelConfig(
            server_args.model_path,
            trust_remote_code=server_args.trust_remote_code,
            revision=server_args.revision,
            context_length=server_args.context_length,
            model_override_args=server_args.json_model_override_args,
            is_embedding=server_args.is_embedding,
            dtype=server_args.dtype,
            quantization=server_args.quantization,
        )
        self.is_generation = self.model_config.is_generation

        if server_args.skip_tokenizer_init:
            self.tokenizer = self.processor = None
        else:
            if self.model_config.is_multimodal:
                self.processor = get_processor(
                    server_args.tokenizer_path,
                    tokenizer_mode=server_args.tokenizer_mode,
                    trust_remote_code=server_args.trust_remote_code,
                )
                self.tokenizer = self.processor.tokenizer
            else:
                self.tokenizer = get_tokenizer(
                    server_args.tokenizer_path,
                    tokenizer_mode=server_args.tokenizer_mode,
                    trust_remote_code=server_args.trust_remote_code,
                )

        # Check whether overlap can be enabled
        if not self.is_generation:
            self.enable_overlap = False
            logger.info("Overlap scheduler is disabled for embedding models.")

        if self.model_config.is_multimodal:
            self.enable_overlap = False
            logger.info("Overlap scheduler is disabled for multimodal models.")

        if self.enable_overlap:
            self.disable_jump_forward = True

        # Launch a tensor parallel worker
        if self.enable_overlap:
            TpWorkerClass = TpModelWorkerClient
        else:
            TpWorkerClass = TpModelWorker

        self.tp_worker = TpWorkerClass(
            server_args=server_args,
            gpu_id=gpu_id,
            tp_rank=tp_rank,
            dp_rank=dp_rank,
            nccl_port=port_args.nccl_port,
        )

        # Get token and memory info from the model worker
        (
            self.max_total_num_tokens,
            self.max_prefill_tokens,
            self.max_running_requests,
            self.max_req_len,
            self.max_req_input_len,
            self.random_seed,
            self.device,
            worker_global_server_args_dict,
            _,
            _,
            _,
        ) = self.tp_worker.get_worker_info()
        self.tp_cpu_group = self.tp_worker.get_tp_cpu_group()
        self.pad_input_ids_func = self.tp_worker.get_pad_input_ids_func()
        global_server_args_dict.update(worker_global_server_args_dict)
        set_random_seed(self.random_seed)

        # Print debug info
        logger.info(
            f"max_total_num_tokens={self.max_total_num_tokens}, "
            f"max_prefill_tokens={self.max_prefill_tokens}, "
            f"max_running_requests={self.max_running_requests}, "
            f"context_len={self.model_config.context_len}"
        )

        # Init memory pool and cache
        self.req_to_token_pool, self.token_to_kv_pool = self.tp_worker.get_memory_pool()

        if (
            server_args.chunked_prefill_size is not None
            and server_args.disable_radix_cache
        ):
            self.tree_cache = ChunkCache(
                req_to_token_pool=self.req_to_token_pool,
                token_to_kv_pool=self.token_to_kv_pool,
            )
        else:
            self.tree_cache = RadixCache(
                req_to_token_pool=self.req_to_token_pool,
                token_to_kv_pool=self.token_to_kv_pool,
                disable=server_args.disable_radix_cache,
            )
        self.tree_cache_metrics = {"total": 0, "hit": 0}
        self.policy = SchedulePolicy(self.schedule_policy, self.tree_cache)

        # Init running status
        self.waiting_queue: List[Req] = []
        # The running decoding batch for continuous batching
        self.running_batch: Optional[ScheduleBatch] = None
        # The current forward batch
        self.cur_batch: Optional[ScheduleBatch] = None
        # The last forward batch
        self.last_batch: Optional[ScheduleBatch] = None
        self.forward_ct = 0
        self.forward_ct_decode = 0
        self.num_generated_tokens = 0
        self.last_decode_stats_tic = time.time()
        self.stream_interval = server_args.stream_interval
        self.current_stream = torch.get_device_module(self.device).current_stream()

        # Session info
        self.sessions: Dict[str, Session] = {}

        # Init chunked prefill
        self.chunked_prefill_size = server_args.chunked_prefill_size
        if self.chunked_prefill_size <= 0:  # -1 means disable
            self.chunked_prefill_size = None
        self.being_chunked_req = None
        self.is_mixed_chunk = (
            self.chunked_prefill_size is not None and server_args.enable_mixed_chunk
        )

        # Init the grammar backend for constrained generation
        self.grammar_queue: List[Req] = []
        if not server_args.skip_tokenizer_init:
            if server_args.grammar_backend == "outlines":
                from sglang.srt.constrained.outlines_backend import (
                    OutlinesGrammarBackend,
                )

                self.grammar_backend = OutlinesGrammarBackend(
                    self.tokenizer,
                    whitespace_pattern=server_args.constrained_json_whitespace_pattern,
                    allow_jump_forward=not server_args.disable_jump_forward,
                )
            elif server_args.grammar_backend == "xgrammar":
                from sglang.srt.constrained.xgrammar_backend import (
                    XGrammarGrammarBackend,
                )

                self.grammar_backend = XGrammarGrammarBackend(
                    self.tokenizer, vocab_size=self.model_config.vocab_size
                )
            else:
                raise ValueError(
                    f"Invalid grammar backend: {server_args.grammar_backend}"
                )
        else:
            self.grammar_backend = None

        # Init new token estimation
        assert (
            server_args.schedule_conservativeness >= 0
        ), "Invalid schedule_conservativeness"

        self.init_new_token_ratio = min(
            global_config.default_init_new_token_ratio
            * server_args.schedule_conservativeness,
            1.0,
        )
        self.min_new_token_ratio = min(
            self.init_new_token_ratio
            * global_config.default_min_new_token_ratio_factor,
            1.0,
        )
        self.new_token_ratio_decay = (
            self.init_new_token_ratio - self.min_new_token_ratio
        ) / global_config.default_new_token_ratio_decay_steps
        self.new_token_ratio = self.init_new_token_ratio

        # Tells whether the current running batch is full so that we can skip
        # the check of whether to prefill new requests.
        # This is an optimization to reduce the overhead of the prefill check.
        self.batch_is_full = False

        # Init watchdog thread
        self.watchdog_timeout = server_args.watchdog_timeout
        t = threading.Thread(target=self.watchdog_thread, daemon=True)
        t.start()
        self.parent_process = psutil.Process().parent()

        # Init profiler
        if os.getenv("SGLANG_TORCH_PROFILER_DIR", "") == "":
            self.profiler = None
        else:
            self.torch_profiler_trace_dir = os.getenv("SGLANG_TORCH_PROFILER_DIR")
            logger.info(
                "Profiling enabled. Traces will be saved to: %s",
                self.torch_profiler_trace_dir,
            )
            self.profiler = torch.profiler.profile(
                activities=[
                    torch.profiler.ProfilerActivity.CPU,
                    torch.profiler.ProfilerActivity.CUDA,
                ],
                with_stack=True,
            )

        # Init metrics stats
        self.stats = SchedulerStats()
        if self.enable_metrics:
            self.metrics_collector = SchedulerMetricsCollector(
                labels={
                    "model_name": self.server_args.served_model_name,
                    # TODO: Add lora name/path in the future,
                },
            )

    def watchdog_thread(self):
        """A watchdog thread that kills the server if one batch takes too long."""
        self.watchdog_last_forward_ct = 0
        self.watchdog_last_time = time.time()

        while True:
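            # Only watch while a batch is in flight: if forward_ct has not advanced
            # within watchdog_timeout, assume the current batch is stuck.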
            if self.cur_batch is not None:
                if self.watchdog_last_forward_ct == self.forward_ct:
                    if time.time() > self.watchdog_last_time + self.watchdog_timeout:
                        logger.error(f"Watchdog timeout ({self.watchdog_timeout=})")
                        break
                else:
                    self.watchdog_last_forward_ct = self.forward_ct
                    self.watchdog_last_time = time.time()
            time.sleep(self.watchdog_timeout / 2)

        self.parent_process.send_signal(signal.SIGQUIT)

    @torch.no_grad()
    def event_loop_normal(self):
        """A normal scheduler loop."""
        while True:
            recv_reqs = self.recv_requests()
            self.process_input_requests(recv_reqs)

            batch = self.get_next_batch_to_run()

            if self.server_args.enable_dp_attention:  # TODO: simplify this
                batch = self.prepare_dp_attn_batch(batch)

            self.cur_batch = batch

            if batch:
                result = self.run_batch(batch)
                self.process_batch_result(batch, result)
            else:
                # When the server is idle, run self-checks and re-initialize some states
                self.check_memory()
                self.new_token_ratio = self.init_new_token_ratio

            self.last_batch = batch

    @torch.no_grad()
    def event_loop_overlap(self):
        """A scheduler loop that overlaps the CPU processing and GPU computation."""
        result_queue = deque()
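        # Holds (batch, result) for one extra iteration so CPU-side result processing
        # of the previous batch overlaps with the GPU forward pass of the next batch.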

        while True:
            recv_reqs = self.recv_requests()
            self.process_input_requests(recv_reqs)

            batch = self.get_next_batch_to_run()
            self.cur_batch = batch

            if batch:
                result = self.run_batch(batch)
                result_queue.append((batch.copy(), result))

                if self.last_batch is None:
                    # Create a dummy first batch to start the pipeline for overlap scheduler.
                    # It is now used for triggering the sampling_info_done event.
                    tmp_batch = ScheduleBatch(
                        reqs=None,
                        forward_mode=ForwardMode.DUMMY_FIRST,
                        next_batch_sampling_info=self.tp_worker.cur_sampling_info,
                    )
                    self.process_batch_result(tmp_batch, None)

            if self.last_batch:
                # Process the results of the last batch
                tmp_batch, tmp_result = result_queue.popleft()
                tmp_batch.next_batch_sampling_info = (
                    self.tp_worker.cur_sampling_info if batch else None
                )
                self.process_batch_result(tmp_batch, tmp_result)
            elif batch is None:
                # When the server is idle, run self-checks and re-initialize some states
                self.check_memory()
                self.new_token_ratio = self.init_new_token_ratio

            self.last_batch = batch

    def recv_requests(self) -> List[Req]:
        """Receive requests at tp_rank = 0 and broadcast them to all other TP ranks."""
        if self.tp_rank == 0 or self.server_args.enable_dp_attention:
            recv_reqs = []

            while True:
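                # Drain all pending requests without blocking; stop once the socket is empty.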
                try:
                    recv_req = self.recv_from_tokenizer.recv_pyobj(zmq.NOBLOCK)
                except zmq.ZMQError:
                    break
                recv_reqs.append(recv_req)
        else:
            recv_reqs = None

        if self.tp_size != 1 and not self.server_args.enable_dp_attention:
            recv_reqs = broadcast_pyobj(recv_reqs, self.tp_rank, self.tp_cpu_group)
        return recv_reqs

    def process_input_requests(self, recv_reqs: List):
        for recv_req in recv_reqs:
            if isinstance(recv_req, TokenizedGenerateReqInput):
                self.handle_generate_request(recv_req)
            elif isinstance(recv_req, TokenizedEmbeddingReqInput):
                self.handle_embedding_request(recv_req)
            elif isinstance(recv_req, FlushCacheReq):
                self.flush_cache()
            elif isinstance(recv_req, AbortReq):
                self.abort_request(recv_req)
            elif isinstance(recv_req, UpdateWeightFromDiskReqInput):
                success, message = self.update_weights_from_disk(recv_req)
                self.send_to_tokenizer.send_pyobj(
                    UpdateWeightFromDiskReqOutput(success, message)
                )
            elif isinstance(recv_req, InitWeightsUpdateGroupReqInput):
                success, message = self.init_weights_update_group(recv_req)
                self.send_to_tokenizer.send_pyobj(
                    InitWeightsUpdateGroupReqOutput(success, message)
                )
            elif isinstance(recv_req, UpdateWeightsFromDistributedReqInput):
                success, message = self.update_weights_from_distributed(recv_req)
                self.send_to_tokenizer.send_pyobj(
                    UpdateWeightsFromDistributedReqOutput(success, message)
                )
            elif isinstance(recv_req, UpdateWeightsFromTensorReqInput):
                success, message = self.update_weights_from_tensor(recv_req)
                self.send_to_tokenizer.send_pyobj(
                    UpdateWeightsFromTensorReqOutput(success, message)
                )
            elif isinstance(recv_req, GetWeightsByNameReqInput):
                parameter = self.get_weights_by_name(recv_req)
                self.send_to_tokenizer.send_pyobj(GetWeightsByNameReqOutput(parameter))
            elif isinstance(recv_req, ProfileReq):
                if recv_req == ProfileReq.START_PROFILE:
                    self.start_profile()
                else:
                    self.stop_profile()
            elif isinstance(recv_req, OpenSessionReqInput):
                session_id, success = self.open_session(recv_req)
                self.send_to_tokenizer.send_pyobj(
                    OpenSessionReqOutput(session_id=session_id, success=success)
                )
            elif isinstance(recv_req, CloseSessionReqInput):
                self.close_session(recv_req)
            else:
                raise ValueError(f"Invalid request: {recv_req}")

    def handle_generate_request(
        self,
        recv_req: TokenizedGenerateReqInput,
    ):
        # Create a new request
        if (
            recv_req.session_params is None
            or recv_req.session_params.id is None
            or recv_req.session_params.id not in self.sessions
        ):

            if recv_req.input_embeds is not None:
                # Generate fake input_ids based on the length of input_embeds
                seq_length = len(recv_req.input_embeds)
                fake_input_ids = [1] * seq_length
                recv_req.input_ids = fake_input_ids

            req = Req(
                recv_req.rid,
                recv_req.input_text,
                recv_req.input_ids,
                recv_req.sampling_params,
                return_logprob=recv_req.return_logprob,
                top_logprobs_num=recv_req.top_logprobs_num,
                stream=recv_req.stream,
                lora_path=recv_req.lora_path,
                input_embeds=recv_req.input_embeds,
                eos_token_ids=self.model_config.hf_eos_token_id,
            )
            req.tokenizer = self.tokenizer

            if (
                recv_req.session_params is not None
                and recv_req.session_params.id is not None
            ):
                req.finished_reason = FINISH_ABORT(
                    f"Invalid request: session id {recv_req.session_params.id} does not exist"
                )
                self.waiting_queue.append(req)
                return
        else:
            # Create a new request from a previous session
            session = self.sessions[recv_req.session_params.id]
            req = session.create_req(recv_req, self.tokenizer)
            if isinstance(req.finished_reason, FINISH_ABORT):
                self.waiting_queue.append(req)
                return

        # Handle image inputs
        if recv_req.image_inputs is not None:
            image_inputs = ImageInputs.from_dict(recv_req.image_inputs)
            # Expand a single image token into multiple dummy tokens for receiving image embeddings
            req.origin_input_ids = self.pad_input_ids_func(
                req.origin_input_ids, image_inputs
            )
            req.extend_image_inputs(image_inputs)

            if len(req.origin_input_ids) >= self.max_req_input_len:
                logger.error(
                    "Multimodal prompt is too long after expanding multimodal tokens. "
                    f"After expanding {len(req.origin_input_ids_unpadded)=} => {len(req.origin_input_ids)} >= {self.max_req_input_len}. "
                )
                req.origin_input_ids = [0]
                req.image_inputs = None
                req.sampling_params.max_new_tokens = 0
                req.finished_reason = FINISH_ABORT(
                    "Multimodal prompt is too long. Check server logs for details."
                )
                self.waiting_queue.append(req)
                return

        # Copy more attributes
        req.logprob_start_len = recv_req.logprob_start_len

        if req.logprob_start_len == -1:
            # By default, only return the logprobs for output tokens
            req.logprob_start_len = len(req.origin_input_ids) - 1

        # Truncate prompts that are too long
        if len(req.origin_input_ids) > self.max_req_input_len:
            logger.warning(
                "Request length is longer than the KV cache pool size or "
                "the max context length. Truncated!!!"
            )
            req.origin_input_ids = req.origin_input_ids[: self.max_req_input_len]

        req.sampling_params.max_new_tokens = min(
            (
                req.sampling_params.max_new_tokens
                if req.sampling_params.max_new_tokens is not None
                else 1 << 30
            ),
            self.max_req_len - len(req.origin_input_ids) - 1,
        )

        # Init grammar cache for this request
        add_to_grammar_queue = False
        if (
            req.sampling_params.json_schema is not None
            or req.sampling_params.regex is not None
            or req.sampling_params.ebnf is not None
        ):
            assert self.grammar_backend is not None
            if req.sampling_params.json_schema is not None:
                key = ("json", req.sampling_params.json_schema)
            elif req.sampling_params.regex is not None:
                key = ("regex", req.sampling_params.regex)
            elif req.sampling_params.ebnf is not None:
                key = ("ebnf", req.sampling_params.ebnf)

            req.grammar = self.grammar_backend.get_cached_value(key)
            if not req.grammar:
                req.grammar = self.grammar_backend.get_future_value(key)
                add_to_grammar_queue = True

        if add_to_grammar_queue:
            self.grammar_queue.append(req)
        else:
            self.waiting_queue.append(req)

    def handle_embedding_request(
        self,
        recv_req: TokenizedEmbeddingReqInput,
    ):
        req = Req(
            recv_req.rid,
            recv_req.input_text,
            recv_req.input_ids,
            recv_req.sampling_params,
        )
        req.tokenizer = self.tokenizer

        # Truncate prompts that are too long
        if len(req.origin_input_ids) >= self.max_req_input_len:
            logger.warning(
                "Request length is longer than the KV cache pool size or "
                "the max context length. Truncated!!!"
            )
            req.origin_input_ids = req.origin_input_ids[: self.max_req_input_len]

        self.waiting_queue.append(req)

    def log_prefill_stats(self, adder, can_run_list, running_bs, has_being_chunked):
        self.tree_cache_metrics["total"] += (
            adder.log_input_tokens + adder.log_hit_tokens
        ) / 10**9
        self.tree_cache_metrics["hit"] += (adder.log_hit_tokens) / 10**9
        tree_cache_hit_rate = (
            self.tree_cache_metrics["hit"] / self.tree_cache_metrics["total"]
        )

        num_used = self.max_total_num_tokens - (
            self.token_to_kv_pool.available_size() + self.tree_cache.evictable_size()
        )

        logger.info(
            f"Prefill batch. "
            f"#new-seq: {len(can_run_list)}, "
            f"#new-token: {adder.log_input_tokens}, "
            f"#cached-token: {adder.log_hit_tokens}, "
            f"cache hit rate: {100.0 * tree_cache_hit_rate:.2f}%, "
            f"token usage: {num_used / self.max_total_num_tokens:.2f}, "
            f"#running-req: {running_bs}, "
            f"#queue-req: {len(self.waiting_queue) + has_being_chunked}"
        )

        if self.enable_metrics:
            self.stats.num_running_reqs = running_bs
            self.stats.num_used_tokens = num_used
            self.stats.token_usage = round(num_used / self.max_total_num_tokens, 2)
            self.stats.num_queue_reqs = len(self.waiting_queue) + has_being_chunked
            self.stats.cache_hit_rate = tree_cache_hit_rate
            self.metrics_collector.log_stats(self.stats)

    def log_decode_stats(self):
        num_used = self.max_total_num_tokens - (
            self.token_to_kv_pool.available_size() + self.tree_cache.evictable_size()
        )
        gen_throughput = self.num_generated_tokens / (
            time.time() - self.last_decode_stats_tic
        )
        self.num_generated_tokens = 0
        self.last_decode_stats_tic = time.time()
        num_running_reqs = len(self.running_batch.reqs) if self.running_batch else 0
        logger.info(
            f"Decode batch. "
            f"#running-req: {num_running_reqs}, "
            f"#token: {num_used}, "
            f"token usage: {num_used / self.max_total_num_tokens:.2f}, "
            f"gen throughput (token/s): {gen_throughput:.2f}, "
            f"#queue-req: {len(self.waiting_queue)}"
        )

        if self.enable_metrics:
            self.stats.num_running_reqs = num_running_reqs
            self.stats.num_used_tokens = num_used
            self.stats.token_usage = num_used / self.max_total_num_tokens
            self.stats.gen_throughput = gen_throughput
            self.stats.num_queue_reqs = len(self.waiting_queue)
            self.metrics_collector.log_stats(self.stats)

    def check_memory(self):
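        # Sanity check for an idle scheduler: all KV cache tokens and request slots
        # should have been returned to their pools, otherwise a leak is reported.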
        available_size = (
            self.token_to_kv_pool.available_size() + self.tree_cache.evictable_size()
        )
        if available_size != self.max_total_num_tokens:
            msg = (
                "KV cache pool leak detected!"
                f"{available_size=}, {self.max_total_num_tokens=}\n"
            )
            warnings.warn(msg)
            if crash_on_warnings():
                raise ValueError(msg)

        if len(self.req_to_token_pool.free_slots) != self.req_to_token_pool.size:
            msg = (
                "Memory pool leak detected!"
                f"available_size={len(self.req_to_token_pool.free_slots)}, "
                f"total_size={self.req_to_token_pool.size}\n"
            )
            warnings.warn(msg)
            if crash_on_warnings():
                raise ValueError(msg)

    def get_next_batch_to_run(self) -> Optional[ScheduleBatch]:
        # Merge the prefill batch into the running batch
        if self.last_batch and self.last_batch.forward_mode.is_extend():
            if self.being_chunked_req:
                # Move the chunked request out of the batch
                self.last_batch.filter_batch(being_chunked_req=self.being_chunked_req)
                self.tree_cache.cache_unfinished_req(self.being_chunked_req)
                # The being-chunked request keeps its rid but will get a new req_pool_idx
                self.req_to_token_pool.free(self.being_chunked_req.req_pool_idx)
                self.batch_is_full = False

            if not self.last_batch.is_empty():
                if self.running_batch is None:
                    self.running_batch = self.last_batch
                else:
                    self.running_batch.merge_batch(self.last_batch)

        # Run prefill first if possible
        new_batch = self.get_new_batch_prefill()
        if new_batch is not None:
            return new_batch

        # Run decode
        if self.running_batch is None:
            return None
        self.running_batch = self.update_running_batch(self.running_batch)
        return self.running_batch

    def get_new_batch_prefill(self) -> Optional[ScheduleBatch]:
        # Check if the grammar is ready in the grammar queue
        if self.grammar_queue:
            self.move_ready_grammar_requests()

        # Handle the cases where prefill is not allowed
        if (
            self.batch_is_full or len(self.waiting_queue) == 0
        ) and self.being_chunked_req is None:
            return None

        running_bs = len(self.running_batch.reqs) if self.running_batch else 0
        if running_bs >= self.max_running_requests:
            self.batch_is_full = True
            return None

        # Get priority queue
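        # calc_priority may reorder the waiting queue (e.g. by cached-prefix length) and
        # reports whether prefix matching was already computed, so it is not redone below.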
        prefix_computed = self.policy.calc_priority(self.waiting_queue)

        # Prefill policy
        adder = PrefillAdder(
            self.tree_cache,
            self.running_batch,
            self.new_token_ratio,
            self.token_to_kv_pool.available_size() + self.tree_cache.evictable_size(),
            self.max_prefill_tokens,
            self.chunked_prefill_size,
            running_bs if self.is_mixed_chunk else 0,
        )

        has_being_chunked = self.being_chunked_req is not None
        if has_being_chunked:
            self.being_chunked_req.init_next_round_input()
            self.being_chunked_req = adder.add_being_chunked_req(self.being_chunked_req)

        if self.lora_paths:
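            # Track the LoRA adapters already used by the running batch so that newly
            # added requests do not exceed max_loras_per_batch.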
            lora_set = (
                set([req.lora_path for req in self.running_batch.reqs])
                if self.running_batch is not None
                else set([])
            )

        # Get requests from the waiting queue to a new prefill batch
        for req in self.waiting_queue:
            if (
                self.lora_paths
                and len(
                    lora_set
                    | set([req.lora_path for req in adder.can_run_list])
                    | set([req.lora_path])
                )
                > self.max_loras_per_batch
            ):
                self.batch_is_full = True
                break

            if running_bs + len(adder.can_run_list) >= self.max_running_requests:
                self.batch_is_full = True
                break

            req.init_next_round_input(None if prefix_computed else self.tree_cache)
            res = adder.add_one_req(req)
            if res != AddReqResult.CONTINUE:
                if res == AddReqResult.NO_TOKEN:
                    self.batch_is_full = True
                break
            if self.server_args.prefill_only_one_req:
                break

        # Update waiting queue
        can_run_list = adder.can_run_list
        if len(can_run_list) == 0:
            return None
        self.waiting_queue = [
            x for x in self.waiting_queue if x not in set(can_run_list)
        ]

        if adder.new_being_chunked_req is not None:
            assert self.being_chunked_req is None
            self.being_chunked_req = adder.new_being_chunked_req

        if self.being_chunked_req:
            self.being_chunked_req.is_being_chunked += 1

        # Print stats
        if self.tp_rank == 0:
            self.log_prefill_stats(adder, can_run_list, running_bs, has_being_chunked)

        # Create a new batch
        new_batch = ScheduleBatch.init_new(
            can_run_list,
            self.req_to_token_pool,
            self.token_to_kv_pool,
            self.tree_cache,
            self.model_config,
            self.enable_overlap,
        )
        new_batch.prepare_for_extend()

        # Mixed-style chunked prefill
        if (
            self.is_mixed_chunk
            and self.running_batch is not None
            and not (new_batch.return_logprob or self.running_batch.return_logprob)
        ):
            # TODO (lianmin): support return_logprob + mixed chunked prefill
            self.running_batch.filter_batch()
            if not self.running_batch.is_empty():
                self.running_batch.prepare_for_decode()
                new_batch.mix_with_running(self.running_batch)
                new_batch.decoding_reqs = self.running_batch.reqs
            self.running_batch = None
        else:
            new_batch.decoding_reqs = None

        return new_batch

    def update_running_batch(self, batch: ScheduleBatch) -> Optional[ScheduleBatch]:
        """Update the current running decoding batch."""
        global test_retract

        initial_bs = batch.batch_size()

        batch.filter_batch()
        if batch.is_empty():
            self.batch_is_full = False
            return None

        # Check if decoding is out of memory
        if not batch.check_decode_mem() or (test_retract and batch.batch_size() > 10):
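            # Not enough KV cache for every running request to decode one more token:
            # retract some requests to the waiting queue and schedule more conservatively.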
            old_ratio = self.new_token_ratio

            retracted_reqs, new_token_ratio = batch.retract_decode()
            self.new_token_ratio = new_token_ratio

            logger.info(
                "Decode out of memory happened. "
                f"#retracted_reqs: {len(retracted_reqs)}, "
                f"#new_token_ratio: {old_ratio:.4f} -> {self.new_token_ratio:.4f}"
            )
            self.waiting_queue.extend(retracted_reqs)
        else:
            self.new_token_ratio = max(
                self.new_token_ratio - self.new_token_ratio_decay,
                self.min_new_token_ratio,
            )

        # Check for jump-forward
        if not self.disable_jump_forward:
            jump_forward_reqs = batch.check_for_jump_forward(self.pad_input_ids_func)
            self.waiting_queue.extend(jump_forward_reqs)
            if batch.is_empty():
                self.batch_is_full = False
                return None

        if batch.batch_size() < initial_bs:
            self.batch_is_full = False

        # Update batch tensors
        batch.prepare_for_decode()
        return batch

    def run_batch(self, batch: ScheduleBatch):
        """Run a batch."""
        self.forward_ct += 1

        if self.is_generation:
            model_worker_batch = batch.get_model_worker_batch()
            if batch.forward_mode.is_decode() or batch.extend_num_tokens != 0:
                logits_output, next_token_ids = self.tp_worker.forward_batch_generation(
                    model_worker_batch
                )
            elif batch.forward_mode.is_idle():
                model_worker_batch = batch.get_model_worker_batch()
                self.tp_worker.forward_batch_idle(model_worker_batch)
                return
            else:
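                # Nothing to extend (e.g. all input tokens hit the prefix cache), so skip
                # the forward pass and emit a placeholder next token id.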
                logits_output = None
                if self.skip_tokenizer_init:
                    next_token_ids = torch.full(
                        (batch.batch_size(),), self.tokenizer.eos_token_id
                    )
                else:
                    next_token_ids = torch.full((batch.batch_size(),), 0)
            batch.output_ids = next_token_ids
            ret = logits_output, next_token_ids, model_worker_batch.bid
        else:  # embedding or reward model
            assert batch.extend_num_tokens != 0
            model_worker_batch = batch.get_model_worker_batch()
            embeddings = self.tp_worker.forward_batch_embedding(model_worker_batch)
            ret = embeddings, model_worker_batch.bid
        return ret

    def process_batch_result(self, batch: ScheduleBatch, result):
        if batch.forward_mode.is_decode():
            self.process_batch_result_decode(batch, result)
            if batch.is_empty():
                self.running_batch = None
        elif batch.forward_mode.is_extend():
            self.process_batch_result_prefill(batch, result)
        elif batch.forward_mode.is_dummy_first():
            batch.next_batch_sampling_info.update_regex_vocab_mask()
            self.current_stream.synchronize()
            batch.next_batch_sampling_info.sampling_info_done.set()

    def process_batch_result_prefill(self, batch: ScheduleBatch, result):
        skip_stream_req = None

        if self.is_generation:
            logits_output, next_token_ids, bid = result

            if self.enable_overlap:
                logits_output, next_token_ids = self.tp_worker.resolve_batch_result(bid)
            else:
                # Move next_token_ids and logprobs to cpu
                next_token_ids = next_token_ids.tolist()
                if batch.return_logprob:
                    logits_output.next_token_logprobs = (
                        logits_output.next_token_logprobs.tolist()
                    )
                    logits_output.input_token_logprobs = (
                        logits_output.input_token_logprobs.tolist()
                    )
                    logits_output.normalized_prompt_logprobs = (
                        logits_output.normalized_prompt_logprobs.tolist()
                    )

            # Check finish conditions
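            # logprob_pt tracks the offset into the flattened input-token logprobs
            # across the requests of this batch.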
            logprob_pt = 0
            for i, (req, next_token_id) in enumerate(zip(batch.reqs, next_token_ids)):
                if req.is_retracted:
                    continue

                if self.is_mixed_chunk and self.enable_overlap and req.finished():
                    # Free the one delayed token for the mixed decode batch
                    j = len(batch.out_cache_loc) - len(batch.reqs) + i
                    self.token_to_kv_pool.free(batch.out_cache_loc[j : j + 1])
                    continue

                if req.is_being_chunked <= 0:
                    req.output_ids.append(next_token_id)
                    req.check_finished()

                    if req.finished():
                        self.tree_cache.cache_finished_req(req)
                    elif not batch.decoding_reqs or req not in batch.decoding_reqs:
                        self.tree_cache.cache_unfinished_req(req)

                    if req.return_logprob:
                        logprob_pt += self.add_logprob_return_values(
                            i, req, logprob_pt, next_token_ids, logits_output
                        )

                    if req.grammar is not None:
                        req.grammar.accept_token(next_token_id)
                        req.grammar.finished = req.finished()
                else:
                    # The prefill of the being-chunked request is not finished yet
                    req.is_being_chunked -= 1
                    # There is at most one request currently being chunked.
                    # Because this request does not finish prefill,
                    # we don't want to stream the request currently being chunked.
                    skip_stream_req = req

            if batch.next_batch_sampling_info:
                batch.next_batch_sampling_info.update_regex_vocab_mask()
                self.current_stream.synchronize()
                batch.next_batch_sampling_info.sampling_info_done.set()

        else:  # embedding or reward model
            embeddings, bid = result
            embeddings = embeddings.tolist()

            # Check finish conditions
            for i, req in enumerate(batch.reqs):
                if req.is_retracted:
                    continue

                req.embedding = embeddings[i]
                if req.is_being_chunked <= 0:
                    # Dummy output token for embedding models
                    req.output_ids.append(0)
                    req.check_finished()

                    if req.finished():
                        self.tree_cache.cache_finished_req(req)
                    else:
                        self.tree_cache.cache_unfinished_req(req)
                else:
                    # The prefill of the being-chunked request is not finished yet
                    req.is_being_chunked -= 1

        self.stream_output(batch.reqs, batch.return_logprob, skip_stream_req)

    def process_batch_result_decode(self, batch: ScheduleBatch, result):
        logits_output, next_token_ids, bid = result
        self.num_generated_tokens += len(batch.reqs)

        if self.enable_overlap:
            logits_output, next_token_ids = self.tp_worker.resolve_batch_result(bid)
            next_token_logprobs = logits_output.next_token_logprobs
        else:
            next_token_ids = next_token_ids.tolist()
            if batch.return_logprob:
                next_token_logprobs = logits_output.next_token_logprobs.tolist()

        self.token_to_kv_pool.free_group_begin()
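        # KV cache frees issued between free_group_begin() and free_group_end()
        # are batched and released together.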

        # Check finish condition
        for i, (req, next_token_id) in enumerate(zip(batch.reqs, next_token_ids)):
            if req.is_retracted:
                continue

            if self.enable_overlap and req.finished():
                # Free the one delayed token
                self.token_to_kv_pool.free(batch.out_cache_loc[i : i + 1])
                continue

            req.output_ids.append(next_token_id)
            req.check_finished()

            if req.finished():
                self.tree_cache.cache_finished_req(req)

            if req.return_logprob:
                req.output_token_logprobs_val.append(next_token_logprobs[i])
                req.output_token_logprobs_idx.append(next_token_id)
                if req.top_logprobs_num > 0:
                    req.output_top_logprobs_val.append(
                        logits_output.next_token_top_logprobs_val[i]
                    )
                    req.output_top_logprobs_idx.append(
                        logits_output.next_token_top_logprobs_idx[i]
                    )

            if req.grammar is not None:
                req.grammar.accept_token(next_token_id)
                req.grammar.finished = req.finished()

        if batch.next_batch_sampling_info:
            batch.next_batch_sampling_info.update_regex_vocab_mask()
            self.current_stream.synchronize()
            batch.next_batch_sampling_info.sampling_info_done.set()

        self.stream_output(batch.reqs, batch.return_logprob)

        self.token_to_kv_pool.free_group_end()

        self.forward_ct_decode = (self.forward_ct_decode + 1) % (1 << 30)
        if (
            self.tp_rank == 0
            and self.forward_ct_decode % self.server_args.decode_log_interval == 0
        ):
            self.log_decode_stats()

    def add_logprob_return_values(
        self,
        i: int,
        req: Req,
        pt: int,
        next_token_ids: List[int],
        output: LogitsProcessorOutput,
    ):
        """Attach logprobs to the return values."""
        req.output_token_logprobs_val.append(output.next_token_logprobs[i])
        req.output_token_logprobs_idx.append(next_token_ids[i])

        # If logprob_start_len > 0, then first logprob_start_len prompt tokens will be ignored.
        num_input_logprobs = req.extend_input_len - req.extend_logprob_start_len

        if req.normalized_prompt_logprob is None:
            req.normalized_prompt_logprob = output.normalized_prompt_logprobs[i]

        if req.input_token_logprobs_val is None:
            input_token_logprobs_val = output.input_token_logprobs[
                pt : pt + num_input_logprobs - 1 - req.last_update_decode_tokens
            ]

            input_token_logprobs_idx = req.fill_ids[
                len(req.fill_ids)
                - num_input_logprobs
                + 1 : len(req.fill_ids)
                - req.last_update_decode_tokens
            ]
            # Clip the padded hash values from image tokens.
            # Otherwise, it will lead to detokenization errors.
            input_token_logprobs_idx = [
                x if x < self.model_config.vocab_size - 1 else 0
                for x in input_token_logprobs_idx
            ]

            if (
                req.logprob_start_len == 0
            ):  # The first token does not have logprob, pad it.
                input_token_logprobs_val = [None] + input_token_logprobs_val
                input_token_logprobs_idx = [req.fill_ids[0]] + input_token_logprobs_idx

            req.input_token_logprobs_val = input_token_logprobs_val
            req.input_token_logprobs_idx = input_token_logprobs_idx

        if req.last_update_decode_tokens != 0:
            # Some decode tokens are re-computed in an extend batch
            req.output_token_logprobs_val.extend(
                output.input_token_logprobs[
                    pt
                    + num_input_logprobs
                    - 1
                    - req.last_update_decode_tokens : pt
                    + num_input_logprobs
                    - 1
                ],
            )
            req.output_token_logprobs_idx.extend(
                req.fill_ids[
                    len(req.fill_ids)
                    - req.last_update_decode_tokens : len(req.fill_ids)
                ]
            )

        if req.top_logprobs_num > 0:
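            # Top-k logprobs: prompt-side values are filled once; decode-side values are appended below.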
            if req.input_top_logprobs_val is None:
                req.input_top_logprobs_val = output.input_top_logprobs_val[i]
                req.input_top_logprobs_idx = output.input_top_logprobs_idx[i]
                if req.logprob_start_len == 0:
                    req.input_top_logprobs_val = [None] + req.input_top_logprobs_val
                    req.input_top_logprobs_idx = [None] + req.input_top_logprobs_idx

            if req.last_update_decode_tokens != 0:
                req.output_top_logprobs_val.extend(
                    output.input_top_logprobs_val[i][-req.last_update_decode_tokens :]
                )
                req.output_top_logprobs_idx.extend(
                    output.input_top_logprobs_idx[i][-req.last_update_decode_tokens :]
                )

            req.output_top_logprobs_val.append(output.next_token_top_logprobs_val[i])
            req.output_top_logprobs_idx.append(output.next_token_top_logprobs_idx[i])

        return num_input_logprobs

    def stream_output(
        self, reqs: List[Req], return_logprob: bool, skip_req: Optional[Req] = None
    ):
        """Stream the output to detokenizer."""
        rids = []
        finished_reasons: List[BaseFinishReason] = []

        if self.is_generation:
            vids = []
            decoded_texts = []
            decode_ids_list = []
            read_offsets = []
            output_ids = []
            origin_input_ids = []

            skip_special_tokens = []
            spaces_between_special_tokens = []
            no_stop_trim = []
            prompt_tokens = []
            completion_tokens = []
            cached_tokens = []

            if return_logprob:
                input_token_logprobs_val = []
                input_token_logprobs_idx = []
                output_token_logprobs_val = []
                output_token_logprobs_idx = []
                input_top_logprobs_val = []
                input_top_logprobs_idx = []
                output_top_logprobs_val = []
                output_top_logprobs_idx = []
                normalized_prompt_logprob = []
            else:
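                # Logprobs were not requested; keep all logprob containers as None.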
                input_token_logprobs_val = input_token_logprobs_idx = (
                    output_token_logprobs_val
                ) = output_token_logprobs_idx = input_top_logprobs_val = (
                    input_top_logprobs_idx
                ) = output_top_logprobs_val = output_top_logprobs_idx = (
                    normalized_prompt_logprob
                ) = None

            for req in reqs:
                if req is skip_req:
                    continue

                # TODO(lianmin): revisit this for overlap + retract + stream
                if (
                    req.finished()
                    # If stream, follow the given stream_interval
                    or (req.stream and len(req.output_ids) % self.stream_interval == 0)
                    # If not stream, we still want to output some tokens to get the benefit of incremental decoding.
                    or (not req.stream and len(req.output_ids) % 50 == 0)
                ):
                    rids.append(req.rid)
                    finished_reasons.append(
                        req.finished_reason.to_json() if req.finished_reason else None
                    )
                    vids.append(req.vid)
                    decoded_texts.append(req.decoded_text)
                    decode_ids, read_offset = req.init_incremental_detokenize()
                    decode_ids_list.append(decode_ids)
                    read_offsets.append(read_offset)
                    if self.skip_tokenizer_init or self.server_args.return_token_ids:
                        output_ids.append(req.output_ids)
                    else:
                        output_ids = None
                    if self.server_args.return_token_ids:
                        origin_input_ids.append(req.origin_input_ids)
                    else:
                        origin_input_ids = None
                    skip_special_tokens.append(req.sampling_params.skip_special_tokens)
                    spaces_between_special_tokens.append(
                        req.sampling_params.spaces_between_special_tokens
                    )
                    no_stop_trim.append(req.sampling_params.no_stop_trim)

                    prompt_tokens.append(len(req.origin_input_ids))
                    completion_tokens.append(len(req.output_ids))
                    cached_tokens.append(req.cached_tokens)

                    if return_logprob:
                        input_token_logprobs_val.append(req.input_token_logprobs_val)
                        input_token_logprobs_idx.append(req.input_token_logprobs_idx)
                        output_token_logprobs_val.append(req.output_token_logprobs_val)
                        output_token_logprobs_idx.append(req.output_token_logprobs_idx)
                        input_top_logprobs_val.append(req.input_top_logprobs_val)
                        input_top_logprobs_idx.append(req.input_top_logprobs_idx)
                        output_top_logprobs_val.append(req.output_top_logprobs_val)
                        output_top_logprobs_idx.append(req.output_top_logprobs_idx)
                        normalized_prompt_logprob.append(req.normalized_prompt_logprob)

            # Send to detokenizer
            if rids:
                self.send_to_detokenizer.send_pyobj(
                    BatchTokenIDOut(
                        rids,
                        finished_reasons,
                        vids,
                        decoded_texts,
                        decode_ids_list,
                        read_offsets,
                        origin_input_ids,
                        output_ids,
                        skip_special_tokens,
                        spaces_between_special_tokens,
                        no_stop_trim,
                        prompt_tokens,
                        completion_tokens,
                        cached_tokens,
                        input_token_logprobs_val,
                        input_token_logprobs_idx,
                        output_token_logprobs_val,
                        output_token_logprobs_idx,
                        input_top_logprobs_val,
                        input_top_logprobs_idx,
                        output_top_logprobs_val,
                        output_top_logprobs_idx,
                        normalized_prompt_logprob,
                    )
                )
        else:  # embedding or reward model
            embeddings = []
            prompt_tokens = []
            for req in reqs:
                assert req.finished()
                rids.append(req.rid)
                finished_reasons.append(req.finished_reason.to_json())
                embeddings.append(req.embedding)
                prompt_tokens.append(len(req.origin_input_ids))
            self.send_to_detokenizer.send_pyobj(
                BatchEmbeddingOut(rids, finished_reasons, embeddings, prompt_tokens)
            )

    def prepare_dp_attn_batch(self, local_batch: ScheduleBatch):
        # Check if other DP workers have running batches
        if local_batch is None:
            num_tokens = 0
        elif local_batch.forward_mode.is_decode():
            num_tokens = local_batch.batch_size()
        else:
            num_tokens = local_batch.extend_num_tokens

        local_num_tokens = torch.tensor([num_tokens], dtype=torch.int64)
        global_num_tokens = torch.empty(self.tp_size, dtype=torch.int64)
        torch.distributed.all_gather_into_tensor(
            global_num_tokens,
            local_num_tokens,
            group=self.tp_cpu_group,
        )

        if local_batch is None and global_num_tokens.max().item() > 0:
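            # A peer DP rank still has work, so run an idle batch locally to keep collective ops in lockstep.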
            local_batch = self.get_idle_batch()

        if local_batch is not None:
            local_batch.global_num_tokens = global_num_tokens.tolist()

            # Check forward mode for cuda graph
            if not self.server_args.disable_cuda_graph:
                forward_mode_state = torch.tensor(
                    (
                        1
                        if local_batch.forward_mode.is_decode()
                        or local_batch.forward_mode.is_idle()
                        else 0
                    ),
                    dtype=torch.int32,
                )
                torch.distributed.all_reduce(
                    forward_mode_state,
                    op=torch.distributed.ReduceOp.MIN,
                    group=self.tp_cpu_group,
                )
                local_batch.can_run_dp_cuda_graph = forward_mode_state.item() == 1

        return local_batch

    def get_idle_batch(self):
        idle_batch = ScheduleBatch.init_new(
            [],
            self.req_to_token_pool,
            self.token_to_kv_pool,
            self.tree_cache,
            self.model_config,
            self.enable_overlap,
        )
        idle_batch.prepare_for_idle()
        return idle_batch

    def move_ready_grammar_requests(self):
        """Move requests whose grammar objects are ready from grammar_queue to waiting_queue."""
        num_ready_reqs = 0
        for req in self.grammar_queue:
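            # Poll with a short timeout and stop at the first grammar that is not compiled yet.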
            try:
                req.grammar = req.grammar.result(timeout=0.05)
                num_ready_reqs += 1
            except futures.TimeoutError:
                break

        if self.tp_size > 1:
            # Sync across TP ranks to make sure they have the same number of ready requests
            tensor = torch.tensor(num_ready_reqs, dtype=torch.int32)
            torch.distributed.all_reduce(
                tensor, op=torch.distributed.ReduceOp.MAX, group=self.tp_cpu_group
            )
            num_ready_reqs_max = tensor.item()
            for i in range(num_ready_reqs, num_ready_reqs_max):
                self.grammar_queue[i].grammar = self.grammar_queue[i].grammar.result()
            num_ready_reqs = num_ready_reqs_max

        self.waiting_queue.extend(self.grammar_queue[:num_ready_reqs])
        self.grammar_queue = self.grammar_queue[num_ready_reqs:]

    def flush_cache(self):
        """Flush the memory pool and cache."""
        if len(self.waiting_queue) == 0 and (
            self.running_batch is None or len(self.running_batch.reqs) == 0
        ):
            self.tree_cache.reset()
            self.tree_cache_metrics = {"total": 0, "hit": 0}
            if self.grammar_backend:
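                # Also clear any cached compiled grammars.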
                self.grammar_backend.reset()
            self.req_to_token_pool.clear()
            self.token_to_kv_pool.clear()
            torch.cuda.empty_cache()
            logger.info("Cache flushed successfully!")
            if_success = True
        else:
            logger.warning(
                f"Cache not flushed because there are pending requests. "
                f"#queue-req: {len(self.waiting_queue)}, "
                f"#running-req: {0 if self.running_batch is None else len(self.running_batch.reqs)}"
            )
            if_success = False
        return if_success

    def abort_request(self, recv_req: AbortReq):
        # Delete requests in the waiting queue
        to_del = None
        for i, req in enumerate(self.waiting_queue):
            if req.rid == recv_req.rid:
                to_del = i
                break

        if to_del is not None:
            del self.waiting_queue[to_del]
            logger.debug(f"Abort queued request. {req.rid=}")
            return

        # Delete requests in the running batch
        if self.running_batch:
            for req in self.running_batch.reqs:
                if req.rid == recv_req.rid and not req.finished():
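                    # Mark for abort; the flag is honored the next time the request's finish status is checked.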
                    logger.debug(f"Abort running request. {req.rid=}")
                    req.to_abort = True
                    break

    def update_weights_from_disk(self, recv_req: UpdateWeightFromDiskReqInput):
        """In-place update of the weights from disk."""
        success, message = self.tp_worker.update_weights_from_disk(recv_req)
        if success:
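            # New weights invalidate existing KV cache entries, so flush the cache.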
            flush_cache_success = self.flush_cache()
            assert flush_cache_success, "Cache flush failed after updating weights"
        else:
            logger.error(message)
        return success, message

    def init_weights_update_group(self, recv_req: InitWeightsUpdateGroupReqInput):
        """Initialize the online model parameter update group."""
        success, message = self.tp_worker.init_weights_update_group(recv_req)
        return success, message

    def update_weights_from_distributed(
        self, recv_req: UpdateWeightsFromDistributedReqInput
    ):
        """Update the online model parameter."""
        success, message = self.tp_worker.update_weights_from_distributed(recv_req)
        if success:
            flush_cache_success = self.flush_cache()
            assert flush_cache_success, "Cache flush failed after updating weights"
        else:
            logger.error(message)
        return success, message

    def update_weights_from_tensor(self, recv_req: UpdateWeightsFromTensorReqInput):
        """Update the online model parameter from tensors."""
        success, message = self.tp_worker.update_weights_from_tensor(recv_req)
        # TODO extract common code b/t update_weights_from_distributed and update_weights_from_tensor later
        if success:
            flush_cache_success = self.flush_cache()
            assert flush_cache_success, "Cache flush failed after updating weights"
        else:
            logger.error(message)
        return success, message

    def get_weights_by_name(self, recv_req: GetWeightsByNameReqInput):
        parameter = self.tp_worker.get_weights_by_name(recv_req)
        return parameter

    def start_profile(self) -> None:
        if self.profiler is None:
            raise RuntimeError("Profiler is not enabled.")
        self.profiler.start()

    def stop_profile(self) -> None:
        if self.profiler is None:
            raise RuntimeError("Profiler is not enabled.")
        self.profiler.stop()
        self.profiler.export_chrome_trace(
            self.torch_profiler_trace_dir + "/" + str(time.time()) + ".trace.json.gz"
        )
        logger.info("Profiler is done")

    def open_session(self, recv_req: OpenSessionReqInput) -> Tuple[Optional[str], bool]:
        # handle error
        session_id = recv_req.session_id
        if session_id in self.sessions:
            logger.warning(f"session id {session_id} already exists, cannot open.")
            return session_id, False
        elif session_id is None:
            logger.warning("session id is None, cannot open.")
            return session_id, False
        else:
            self.sessions[session_id] = Session(
                recv_req.capacity_of_str_len, session_id
            )
            return session_id, True

    def close_session(self, recv_req: CloseSessionReqInput):
        # handle error
        session_id = recv_req.session_id
        if session_id not in self.sessions:
            logger.warning(f"session id {session_id} does not exist, cannot delete.")
        else:
            del self.sessions[session_id]


def run_scheduler_process(
    server_args: ServerArgs,
    port_args: PortArgs,
    gpu_id: int,
    tp_rank: int,
    dp_rank: Optional[int],
    pipe_writer,
):
    setproctitle.setproctitle("sglang::scheduler")

    # [For Router] if the env var "SGLANG_DP_RANK" exists, set dp_rank to its value
    if dp_rank is None and "SGLANG_DP_RANK" in os.environ:
        dp_rank = int(os.environ["SGLANG_DP_RANK"])

    # Configure the logger
    if dp_rank is None:
        configure_logger(server_args, prefix=f" TP{tp_rank}")
    else:
        configure_logger(server_args, prefix=f" DP{dp_rank} TP{tp_rank}")
    suppress_other_loggers()

    # Set CPU affinity for this GPU process
    if get_bool_env_var("SGLANG_SET_CPU_AFFINITY"):
        set_gpu_proc_affinity(server_args.tp_size, server_args.nnodes, gpu_id)

    parent_process = psutil.Process().parent()
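    # Used to signal the parent process (SIGQUIT) if the scheduler hits a fatal error.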

    # Create a scheduler and run the event loop
    try:
        scheduler = Scheduler(server_args, port_args, gpu_id, tp_rank, dp_rank)
        pipe_writer.send(
            {"status": "ready", "max_total_num_tokens": scheduler.max_total_num_tokens}
        )
        if scheduler.enable_overlap:
            scheduler.event_loop_overlap()
        else:
            scheduler.event_loop_normal()
    except Exception:
        traceback = get_exception_traceback()
        logger.error(f"Scheduler hit an exception: {traceback}")
        parent_process.send_signal(signal.SIGQUIT)