scheduler.py 63.7 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
13
# Copyright 2023-2024 SGLang Team
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
14
15
16
"""A scheduler that manages a tensor parallel GPU worker."""

import logging
17
import os
18
import signal
Lianmin Zheng's avatar
Lianmin Zheng committed
19
import threading
20
21
import time
import warnings
Lianmin Zheng's avatar
Lianmin Zheng committed
22
from collections import deque
Lianmin Zheng's avatar
Lianmin Zheng committed
23
from concurrent import futures
24
from types import SimpleNamespace
25
from typing import Dict, List, Optional, Tuple
26

27
import psutil
28
import setproctitle
29
import torch
30
31
import zmq

32
from sglang.global_config import global_config
Lianmin Zheng's avatar
Lianmin Zheng committed
33
from sglang.srt.configs.model_config import ModelConfig
34
35
36
37
38
39
from sglang.srt.hf_transformers_utils import get_processor, get_tokenizer
from sglang.srt.layers.logits_processor import LogitsProcessorOutput
from sglang.srt.managers.io_struct import (
    AbortReq,
    BatchEmbeddingOut,
    BatchTokenIDOut,
40
    CloseSessionReqInput,
41
    FlushCacheReq,
42
43
    GetWeightsByNameReqInput,
    GetWeightsByNameReqOutput,
44
45
    InitWeightsUpdateGroupReqInput,
    InitWeightsUpdateGroupReqOutput,
46
47
    OpenSessionReqInput,
    OpenSessionReqOutput,
48
    ProfileReq,
49
50
    TokenizedEmbeddingReqInput,
    TokenizedGenerateReqInput,
Chayenne's avatar
Chayenne committed
51
52
    UpdateWeightFromDiskReqInput,
    UpdateWeightFromDiskReqOutput,
53
54
    UpdateWeightsFromDistributedReqInput,
    UpdateWeightsFromDistributedReqOutput,
55
56
    UpdateWeightsFromTensorReqInput,
    UpdateWeightsFromTensorReqOutput,
57
58
59
60
61
62
63
)
from sglang.srt.managers.schedule_batch import (
    FINISH_ABORT,
    BaseFinishReason,
    ImageInputs,
    Req,
    ScheduleBatch,
64
    global_server_args_dict,
65
)
66
67
68
69
70
from sglang.srt.managers.schedule_policy import (
    AddReqResult,
    PrefillAdder,
    SchedulePolicy,
)
71
from sglang.srt.managers.session_controller import Session
72
from sglang.srt.managers.tp_worker import TpModelWorker
73
from sglang.srt.managers.tp_worker_overlap_thread import TpModelWorkerClient
74
75
from sglang.srt.mem_cache.chunk_cache import ChunkCache
from sglang.srt.mem_cache.radix_cache import RadixCache
76
from sglang.srt.metrics.collector import SchedulerMetricsCollector, SchedulerStats
77
from sglang.srt.model_executor.forward_batch_info import ForwardMode
78
from sglang.srt.server_args import PortArgs, ServerArgs
79
from sglang.srt.speculative.spec_info import SpeculativeAlgorithm
80
81
82
from sglang.srt.utils import (
    broadcast_pyobj,
    configure_logger,
83
    crash_on_warnings,
84
    get_bool_env_var,
85
    get_zmq_socket,
86
    set_gpu_proc_affinity,
87
88
89
    set_random_seed,
    suppress_other_loggers,
)
90
91
92
93
from sglang.utils import get_exception_traceback

# Module-level logger named after this module.
logger = logging.getLogger(__name__)

# Test retract decode for debugging purposes.
# The flag is read once, at import time, from the SGLANG_TEST_RETRACT
# environment variable.
test_retract = get_bool_env_var("SGLANG_TEST_RETRACT")
96

97
98
99
100
101
102
103
104
105
106

class Scheduler:
    """A scheduler that manages a tensor parallel GPU worker."""

    def __init__(
        self,
        server_args: ServerArgs,
        port_args: PortArgs,
        gpu_id: int,
        tp_rank: int,
        dp_rank: Optional[int],
    ):
        """Set up IPC sockets, the model worker, caches, and scheduler state.

        Args:
            server_args: Server configuration; most scheduler options are
                copied from it onto ``self``.
            port_args: IPC endpoint names and the NCCL port.
            gpu_id: Forwarded to the tensor parallel worker (and the draft
                worker, if any).
            tp_rank: Tensor parallel rank of this scheduler. Only rank 0 (or
                every rank when DP attention is enabled) opens real sockets.
            dp_rank: Forwarded to the workers; may be ``None``.

        Raises:
            ValueError: If ``server_args.grammar_backend`` is not one of the
                supported backends ("outlines", "xgrammar").
        """
        # Parse args
        self.server_args = server_args
        self.tp_rank = tp_rank
        self.tp_size = server_args.tp_size
        self.schedule_policy = server_args.schedule_policy
        self.disable_jump_forward = server_args.disable_jump_forward
        self.lora_paths = server_args.lora_paths
        self.max_loras_per_batch = server_args.max_loras_per_batch
        self.enable_overlap = not server_args.disable_overlap_schedule
        self.skip_tokenizer_init = server_args.skip_tokenizer_init
        self.enable_metrics = server_args.enable_metrics
        self.spec_algorithm = SpeculativeAlgorithm.from_string(
            server_args.speculative_algorithm
        )
        self.decode_mem_cache_buf_multiplier = (
            self.server_args.speculative_num_draft_tokens
            if not self.spec_algorithm.is_none()
            else 1
        )

        # Init inter-process communication
        context = zmq.Context(2)

        if self.tp_rank == 0 or self.server_args.enable_dp_attention:
            self.recv_from_tokenizer = get_zmq_socket(
                context, zmq.PULL, port_args.scheduler_input_ipc_name
            )
            self.send_to_tokenizer = get_zmq_socket(
                context, zmq.PUSH, port_args.tokenizer_ipc_name
            )

            if server_args.skip_tokenizer_init:
                # Directly send to the TokenizerManager
                self.send_to_detokenizer = get_zmq_socket(
                    context, zmq.PUSH, port_args.tokenizer_ipc_name
                )
            else:
                # Send to the DetokenizerManager
                self.send_to_detokenizer = get_zmq_socket(
                    context, zmq.PUSH, port_args.detokenizer_ipc_name
                )
        else:
            # Ranks without sockets get no-op senders so callers can always
            # call ``send_pyobj`` without checking the rank.
            self.recv_from_tokenizer = None
            self.send_to_tokenizer = SimpleNamespace(send_pyobj=lambda x: None)
            self.send_to_detokenizer = SimpleNamespace(send_pyobj=lambda x: None)

        # Init tokenizer
        self.model_config = ModelConfig(
            server_args.model_path,
            trust_remote_code=server_args.trust_remote_code,
            revision=server_args.revision,
            context_length=server_args.context_length,
            model_override_args=server_args.json_model_override_args,
            is_embedding=server_args.is_embedding,
            dtype=server_args.dtype,
            quantization=server_args.quantization,
        )
        self.is_generation = self.model_config.is_generation

        if server_args.skip_tokenizer_init:
            self.tokenizer = self.processor = None
        else:
            if self.model_config.is_multimodal:
                self.processor = get_processor(
                    server_args.tokenizer_path,
                    tokenizer_mode=server_args.tokenizer_mode,
                    trust_remote_code=server_args.trust_remote_code,
                )
                self.tokenizer = self.processor.tokenizer
            else:
                # Define the attribute on every path so later reads never
                # raise AttributeError for text-only models.
                self.processor = None
                self.tokenizer = get_tokenizer(
                    server_args.tokenizer_path,
                    tokenizer_mode=server_args.tokenizer_mode,
                    trust_remote_code=server_args.trust_remote_code,
                )

        # Check whether overlap can be enabled
        if not self.is_generation:
            self.enable_overlap = False
            logger.info("Overlap scheduler is disabled for embedding models.")

        if self.model_config.is_multimodal:
            self.enable_overlap = False
            logger.info("Overlap scheduler is disabled for multimodal models.")

        if self.enable_overlap:
            self.disable_jump_forward = True

        # Launch a tensor parallel worker
        if self.enable_overlap:
            TpWorkerClass = TpModelWorkerClient
        else:
            TpWorkerClass = TpModelWorker

        self.tp_worker = TpWorkerClass(
            server_args=server_args,
            gpu_id=gpu_id,
            tp_rank=tp_rank,
            dp_rank=dp_rank,
            nccl_port=port_args.nccl_port,
        )

        # Launch worker for speculative decoding if needed
        if self.spec_algorithm.is_eagle():
            from sglang.srt.speculative.eagle_worker import EAGLEWorker

            self.draft_worker = EAGLEWorker(
                gpu_id=gpu_id,
                tp_rank=tp_rank,
                server_args=server_args,
                nccl_port=port_args.nccl_port,
                target_worker=self.tp_worker,
                dp_rank=dp_rank,
            )
        else:
            self.draft_worker = None

        # Get token and memory info from the model worker
        (
            self.max_total_num_tokens,
            self.max_prefill_tokens,
            self.max_running_requests,
            self.max_req_len,
            self.max_req_input_len,
            self.random_seed,
            self.device,
            worker_global_server_args_dict,
            _,
            _,
            _,
        ) = self.tp_worker.get_worker_info()
        self.tp_cpu_group = self.tp_worker.get_tp_cpu_group()
        self.pad_input_ids_func = self.tp_worker.get_pad_input_ids_func()
        global_server_args_dict.update(worker_global_server_args_dict)
        set_random_seed(self.random_seed)

        # Print debug info
        logger.info(
            f"max_total_num_tokens={self.max_total_num_tokens}, "
            f"max_prefill_tokens={self.max_prefill_tokens}, "
            f"max_running_requests={self.max_running_requests}, "
            f"context_len={self.model_config.context_len}"
        )

        # Init memory pool and cache
        self.req_to_token_pool, self.token_to_kv_pool = self.tp_worker.get_memory_pool()

        if (
            server_args.chunked_prefill_size is not None
            and server_args.disable_radix_cache
        ):
            self.tree_cache = ChunkCache(
                req_to_token_pool=self.req_to_token_pool,
                token_to_kv_pool=self.token_to_kv_pool,
            )
        else:
            self.tree_cache = RadixCache(
                req_to_token_pool=self.req_to_token_pool,
                token_to_kv_pool=self.token_to_kv_pool,
                disable=server_args.disable_radix_cache,
            )
        self.tree_cache_metrics = {"total": 0, "hit": 0}
        self.policy = SchedulePolicy(self.schedule_policy, self.tree_cache)

        # Init running status
        self.waiting_queue: List[Req] = []
        # The running decoding batch for continuous batching
        self.running_batch: Optional[ScheduleBatch] = None
        # The current forward batch
        self.cur_batch: Optional[ScheduleBatch] = None
        # The last forward batch
        self.last_batch: Optional[ScheduleBatch] = None
        self.forward_ct = 0
        self.forward_ct_decode = 0
        self.num_generated_tokens = 0
        self.last_decode_stats_tic = time.time()
        self.stream_interval = server_args.stream_interval
        self.current_stream = torch.get_device_module(self.device).current_stream()

        # Session info
        self.sessions: Dict[str, Session] = {}

        # Init chunked prefill
        self.chunked_prefill_size = server_args.chunked_prefill_size
        # -1 (or any non-positive value) means disable. Guard against None,
        # which the cache selection above already treats as a possible value;
        # comparing None with an int would raise TypeError.
        if self.chunked_prefill_size is not None and self.chunked_prefill_size <= 0:
            self.chunked_prefill_size = None
        self.being_chunked_req = None
        self.is_mixed_chunk = (
            self.chunked_prefill_size is not None and server_args.enable_mixed_chunk
        )

        # Init the grammar backend for constrained generation
        self.grammar_queue: List[Req] = []
        if not server_args.skip_tokenizer_init:
            if server_args.grammar_backend == "outlines":
                from sglang.srt.constrained.outlines_backend import (
                    OutlinesGrammarBackend,
                )

                # NOTE(review): this reads server_args.disable_jump_forward,
                # not self.disable_jump_forward (forced to True above when
                # overlap is enabled) — confirm that is intended.
                self.grammar_backend = OutlinesGrammarBackend(
                    self.tokenizer,
                    whitespace_pattern=server_args.constrained_json_whitespace_pattern,
                    allow_jump_forward=not server_args.disable_jump_forward,
                )
            elif server_args.grammar_backend == "xgrammar":
                from sglang.srt.constrained.xgrammar_backend import (
                    XGrammarGrammarBackend,
                )

                self.grammar_backend = XGrammarGrammarBackend(
                    self.tokenizer, vocab_size=self.model_config.vocab_size
                )
            else:
                raise ValueError(
                    f"Invalid grammar backend: {server_args.grammar_backend}"
                )
        else:
            self.grammar_backend = None

        # Init new token estimation
        assert (
            server_args.schedule_conservativeness >= 0
        ), "Invalid schedule_conservativeness"

        self.init_new_token_ratio = min(
            global_config.default_init_new_token_ratio
            * server_args.schedule_conservativeness,
            1.0,
        )
        self.min_new_token_ratio = min(
            self.init_new_token_ratio
            * global_config.default_min_new_token_ratio_factor,
            1.0,
        )
        self.new_token_ratio_decay = (
            self.init_new_token_ratio - self.min_new_token_ratio
        ) / global_config.default_new_token_ratio_decay_steps
        self.new_token_ratio = self.init_new_token_ratio

        # Tells whether the current running batch is full so that we can skip
        # the check of whether to prefill new requests.
        # This is an optimization to reduce the overhead of the prefill check.
        self.batch_is_full = False

        # Init watchdog thread.
        # Everything the thread reads (parent_process, watchdog_timeout) is
        # assigned before it is started, so it never sees a half-built object.
        self.parent_process = psutil.Process().parent()
        self.watchdog_timeout = server_args.watchdog_timeout
        t = threading.Thread(target=self.watchdog_thread, daemon=True)
        t.start()

        # Init profiler
        if os.getenv("SGLANG_TORCH_PROFILER_DIR", "") == "":
            self.profiler = None
        else:
            self.torch_profiler_trace_dir = os.getenv("SGLANG_TORCH_PROFILER_DIR")
            logger.info(
                "Profiling enabled. Traces will be saved to: %s",
                self.torch_profiler_trace_dir,
            )
            self.profiler = torch.profiler.profile(
                activities=[
                    torch.profiler.ProfilerActivity.CPU,
                    torch.profiler.ProfilerActivity.CUDA,
                ],
                with_stack=True,
            )

        # Init metrics stats
        self.stats = SchedulerStats()
        if self.enable_metrics:
            self.metrics_collector = SchedulerMetricsCollector(
                labels={
                    "model_name": self.server_args.served_model_name,
                    # TODO: Add lora name/path in the future,
                },
            )
385

Lianmin Zheng's avatar
Lianmin Zheng committed
386
    def watchdog_thread(self):
        """Watch forward progress and signal the parent process when a batch stalls.

        While a batch is in flight (``self.cur_batch`` is set), the forward
        counter must change within ``self.watchdog_timeout`` seconds. If it
        does not, an error is logged and SIGQUIT is sent to the parent process.
        The check runs every ``watchdog_timeout / 2`` seconds.
        """
        self.watchdog_last_forward_ct = 0
        self.watchdog_last_time = time.time()

        while True:
            if self.cur_batch is not None:
                if self.watchdog_last_forward_ct != self.forward_ct:
                    # Progress was made since the last check: reset the timer.
                    self.watchdog_last_forward_ct = self.forward_ct
                    self.watchdog_last_time = time.time()
                elif time.time() > self.watchdog_last_time + self.watchdog_timeout:
                    logger.error(f"Watchdog timeout ({self.watchdog_timeout=})")
                    break
            time.sleep(self.watchdog_timeout / 2)

        # Only reached after a timeout was detected.
        self.parent_process.send_signal(signal.SIGQUIT)
Lianmin Zheng's avatar
Lianmin Zheng committed
403

404
    @torch.no_grad()
    def event_loop_normal(self):
        """Run the normal scheduler loop forever.

        Each iteration receives and processes incoming requests, picks the
        next batch, runs it, and processes its result before moving on.
        """
        while True:
            self.process_input_requests(self.recv_requests())

            next_batch = self.get_next_batch_to_run()
            if self.server_args.enable_dp_attention:  # TODO: simplify this
                next_batch = self.prepare_dp_attn_batch(next_batch)
            self.cur_batch = next_batch

            if not next_batch:
                # The server is idle: self-check and re-init some states.
                self.check_memory()
                self.new_token_ratio = self.init_new_token_ratio
            else:
                output = self.run_batch(next_batch)
                self.process_batch_result(next_batch, output)

            self.last_batch = next_batch
427

428
    @torch.no_grad()
    def event_loop_overlap(self):
        """Run the scheduler loop that overlaps CPU processing with GPU computation.

        The result of a batch is processed one iteration after the batch was
        launched, while the next batch is already running. Launched batches
        and their results are kept in a FIFO queue in between.
        """
        pending = deque()

        while True:
            self.process_input_requests(self.recv_requests())

            cur = self.get_next_batch_to_run()
            self.cur_batch = cur

            if cur:
                # Launch first, then snapshot the batch next to its result.
                launched = self.run_batch(cur)
                pending.append((cur.copy(), launched))

                if self.last_batch is None:
                    # Create a dummy first batch to start the pipeline for overlap scheduler.
                    # It is now used for triggering the sampling_info_done event.
                    dummy_batch = ScheduleBatch(
                        reqs=None,
                        forward_mode=ForwardMode.DUMMY_FIRST,
                        next_batch_sampling_info=self.tp_worker.cur_sampling_info,
                    )
                    self.process_batch_result(dummy_batch, None)

            if self.last_batch:
                # Process the results of the last batch
                prev_batch, prev_result = pending.popleft()
                if cur:
                    prev_batch.next_batch_sampling_info = (
                        self.tp_worker.cur_sampling_info
                    )
                else:
                    prev_batch.next_batch_sampling_info = None
                self.process_batch_result(prev_batch, prev_result)
            elif cur is None:
                # The server is idle: self-check and re-init some states.
                self.check_memory()
                self.new_token_ratio = self.init_new_token_ratio

            self.last_batch = cur

468
469
    def recv_requests(self) -> List:
        """Receive requests at tp_rank = 0 and broadcast them to all other TP ranks.

        Ranks that own a socket (rank 0, or every rank when DP attention is
        enabled) drain it without blocking. When there is more than one TP
        rank and DP attention is off, the drained list is then broadcast over
        the TP CPU group.

        Returns:
            The received request objects, as consumed by
            ``process_input_requests``. The list is empty when nothing was
            pending.
        """
        if self.tp_rank == 0 or self.server_args.enable_dp_attention:
            recv_reqs = []

            # Drain everything currently queued; the non-blocking receive
            # raises once the socket has no more messages.
            # NOTE(review): any ZMQError ends the drain, not only the
            # "no message available" case — confirm other socket errors
            # should be silently ignored here.
            while True:
                try:
                    recv_req = self.recv_from_tokenizer.recv_pyobj(zmq.NOBLOCK)
                except zmq.ZMQError:
                    break
                recv_reqs.append(recv_req)
        else:
            recv_reqs = None

        if self.tp_size != 1 and not self.server_args.enable_dp_attention:
            recv_reqs = broadcast_pyobj(recv_reqs, self.tp_rank, self.tp_cpu_group)
        return recv_reqs

Lianmin Zheng's avatar
Lianmin Zheng committed
486
    def process_input_requests(self, recv_reqs: List):
        """Dispatch every received request object to its handler.

        Requests that expect an answer get their output object sent back
        through ``self.send_to_tokenizer``.

        Raises:
            ValueError: If a request is of an unknown type.
        """
        for msg in recv_reqs:
            if isinstance(msg, TokenizedGenerateReqInput):
                self.handle_generate_request(msg)
                continue

            if isinstance(msg, TokenizedEmbeddingReqInput):
                self.handle_embedding_request(msg)
                continue

            if isinstance(msg, FlushCacheReq):
                self.flush_cache()
                continue

            if isinstance(msg, AbortReq):
                self.abort_request(msg)
                continue

            if isinstance(msg, UpdateWeightFromDiskReqInput):
                ok, info = self.update_weights_from_disk(msg)
                self.send_to_tokenizer.send_pyobj(
                    UpdateWeightFromDiskReqOutput(ok, info)
                )
                continue

            if isinstance(msg, InitWeightsUpdateGroupReqInput):
                ok, info = self.init_weights_update_group(msg)
                self.send_to_tokenizer.send_pyobj(
                    InitWeightsUpdateGroupReqOutput(ok, info)
                )
                continue

            if isinstance(msg, UpdateWeightsFromDistributedReqInput):
                ok, info = self.update_weights_from_distributed(msg)
                self.send_to_tokenizer.send_pyobj(
                    UpdateWeightsFromDistributedReqOutput(ok, info)
                )
                continue

            if isinstance(msg, UpdateWeightsFromTensorReqInput):
                ok, info = self.update_weights_from_tensor(msg)
                self.send_to_tokenizer.send_pyobj(
                    UpdateWeightsFromTensorReqOutput(ok, info)
                )
                continue

            if isinstance(msg, GetWeightsByNameReqInput):
                weights = self.get_weights_by_name(msg)
                self.send_to_tokenizer.send_pyobj(GetWeightsByNameReqOutput(weights))
                continue

            if isinstance(msg, ProfileReq):
                # Anything other than START_PROFILE stops the profiler.
                if msg == ProfileReq.START_PROFILE:
                    self.start_profile()
                else:
                    self.stop_profile()
                continue

            if isinstance(msg, OpenSessionReqInput):
                session_id, success = self.open_session(msg)
                self.send_to_tokenizer.send_pyobj(
                    OpenSessionReqOutput(session_id=session_id, success=success)
                )
                continue

            if isinstance(msg, CloseSessionReqInput):
                self.close_session(msg)
                continue

            raise ValueError(f"Invalid request: {msg}")

    def handle_generate_request(
        self,
        recv_req: TokenizedGenerateReqInput,
    ):
        """Turn a tokenized generate request into a ``Req`` and enqueue it.

        The request is built either from scratch or from an existing session.
        Image inputs are expanded, over-long prompts are truncated, and
        ``max_new_tokens`` is capped. The request then goes to the grammar
        queue (when its grammar is not cached yet) or to the waiting queue.
        Requests that are already aborted (unknown session id, over-long
        multimodal prompt, or an abort set by the session) are appended to the
        waiting queue immediately with their ``finished_reason`` set.
        """
        # Create a new request
        if (
            recv_req.session_params is None
            or recv_req.session_params.id is None
            or recv_req.session_params.id not in self.sessions
        ):

            if recv_req.input_embeds is not None:
                # Generate fake input_ids based on the length of input_embeds
                seq_length = len(recv_req.input_embeds)
                fake_input_ids = [1] * seq_length
                recv_req.input_ids = fake_input_ids

            req = Req(
                recv_req.rid,
                recv_req.input_text,
                recv_req.input_ids,
                recv_req.sampling_params,
                return_logprob=recv_req.return_logprob,
                top_logprobs_num=recv_req.top_logprobs_num,
                stream=recv_req.stream,
                lora_path=recv_req.lora_path,
                input_embeds=recv_req.input_embeds,
                eos_token_ids=self.model_config.hf_eos_token_id,
            )
            req.tokenizer = self.tokenizer

            if (
                recv_req.session_params is not None
                and recv_req.session_params.id is not None
            ):
                # A session id was given but is unknown: abort the request.
                req.finished_reason = FINISH_ABORT(
                    f"Invalid request: session id {recv_req.session_params.id} does not exist"
                )
                self.waiting_queue.append(req)
                return
        else:
            # Create a new request from a previous session
            session = self.sessions[recv_req.session_params.id]
            req = session.create_req(recv_req, self.tokenizer)
            if isinstance(req.finished_reason, FINISH_ABORT):
                self.waiting_queue.append(req)
                return

        # Handle image inputs
        if recv_req.image_inputs is not None:
            image_inputs = ImageInputs.from_dict(recv_req.image_inputs)
            # Expand a single image token into multiple dummy tokens for receiving image embeddings
            req.origin_input_ids = self.pad_input_ids_func(
                req.origin_input_ids, image_inputs
            )
            req.extend_image_inputs(image_inputs)

            if len(req.origin_input_ids) >= self.max_req_input_len:
                logger.error(
                    "Multimodal prompt is too long after expanding multimodal tokens. "
                    f"After expanding {len(req.origin_input_ids_unpadded)=} => {len(req.origin_input_ids)} >= {self.max_req_input_len}. "
                )
                req.origin_input_ids = [0]
                req.image_inputs = None
                req.sampling_params.max_new_tokens = 0
                req.finished_reason = FINISH_ABORT(
                    "Multimodal prompt is too long. Check server logs for details."
                )
                self.waiting_queue.append(req)
                return

        # Truncate prompts that are too long.
        # This must happen before the default logprob_start_len is derived
        # from the prompt length; otherwise the default would point past the
        # end of a truncated prompt.
        if len(req.origin_input_ids) > self.max_req_input_len:
            logger.warning(
                "Request length is longer than the KV cache pool size or "
                "the max context length. Truncated!!!"
            )
            req.origin_input_ids = req.origin_input_ids[: self.max_req_input_len]

        # Copy more attributes
        req.logprob_start_len = recv_req.logprob_start_len

        if req.logprob_start_len == -1:
            # By default, only return the logprobs for output tokens
            req.logprob_start_len = len(req.origin_input_ids) - 1

        # Cap the number of new tokens so prompt + output fits in max_req_len;
        # an unset max_new_tokens is treated as effectively unlimited.
        req.sampling_params.max_new_tokens = min(
            (
                req.sampling_params.max_new_tokens
                if req.sampling_params.max_new_tokens is not None
                else 1 << 30
            ),
            self.max_req_len - len(req.origin_input_ids) - 1,
        )

        # Init grammar cache for this request
        add_to_grammar_queue = False
        if (
            req.sampling_params.json_schema is not None
            or req.sampling_params.regex is not None
            or req.sampling_params.ebnf is not None
        ):
            assert self.grammar_backend is not None
            # Precedence when several constraints are set: json, regex, ebnf.
            if req.sampling_params.json_schema is not None:
                key = ("json", req.sampling_params.json_schema)
            elif req.sampling_params.regex is not None:
                key = ("regex", req.sampling_params.regex)
            elif req.sampling_params.ebnf is not None:
                key = ("ebnf", req.sampling_params.ebnf)

            req.grammar = self.grammar_backend.get_cached_value(key)
            if not req.grammar:
                # Not cached yet: request it and park the request in the
                # grammar queue instead of the waiting queue.
                req.grammar = self.grammar_backend.get_future_value(key)
                add_to_grammar_queue = True

        if add_to_grammar_queue:
            self.grammar_queue.append(req)
        else:
            self.waiting_queue.append(req)
653
654
655

    def handle_embedding_request(
        self,
        recv_req: TokenizedEmbeddingReqInput,
    ):
        """Turn a tokenized embedding request into a ``Req`` and enqueue it.

        Prompts longer than ``self.max_req_input_len`` are truncated to that
        length, with a warning, before the request is appended to the waiting
        queue.
        """
        req = Req(
            recv_req.rid,
            recv_req.input_text,
            recv_req.input_ids,
            recv_req.sampling_params,
        )
        req.tokenizer = self.tokenizer

        # Truncate prompts that are too long.
        # Strictly greater: a prompt of exactly max_req_input_len tokens is
        # left untouched by the slice, so warning about it would be misleading.
        if len(req.origin_input_ids) > self.max_req_input_len:
            logger.warning(
                "Request length is longer than the KV cache pool size or "
                "the max context length. Truncated!!!"
            )
            req.origin_input_ids = req.origin_input_ids[: self.max_req_input_len]

        self.waiting_queue.append(req)

676
    def log_prefill_stats(self, adder, can_run_list, running_bs, has_being_chunked):
        """Log statistics for a prefill batch and, if enabled, export them as metrics.

        Args:
            adder: Object exposing ``log_input_tokens`` and ``log_hit_tokens``
                for the batch being logged.
            can_run_list: Requests admitted into this prefill batch; only its
                length is used.
            running_bs: Number of running requests to report.
            has_being_chunked: Added to the waiting-queue length when
                reporting queued requests (presumably a bool or 0/1 — TODO
                confirm against the caller).
        """
        # The cumulative counters are kept divided by 10**9.
        self.tree_cache_metrics["total"] += (
            adder.log_input_tokens + adder.log_hit_tokens
        ) / 10**9
        self.tree_cache_metrics["hit"] += (adder.log_hit_tokens) / 10**9

        # Report a 0% hit rate instead of dividing by zero when no tokens
        # have been counted yet.
        total_tokens = self.tree_cache_metrics["total"]
        if total_tokens > 0:
            tree_cache_hit_rate = self.tree_cache_metrics["hit"] / total_tokens
        else:
            tree_cache_hit_rate = 0.0

        # Tokens in use = capacity minus what is free or evictable.
        num_used = self.max_total_num_tokens - (
            self.token_to_kv_pool.available_size() + self.tree_cache.evictable_size()
        )

        logger.info(
            f"Prefill batch. "
            f"#new-seq: {len(can_run_list)}, "
            f"#new-token: {adder.log_input_tokens}, "
            f"#cached-token: {adder.log_hit_tokens}, "
            f"cache hit rate: {100.0 * tree_cache_hit_rate:.2f}%, "
            f"token usage: {num_used / self.max_total_num_tokens:.2f}, "
            f"#running-req: {running_bs}, "
            f"#queue-req: {len(self.waiting_queue) + has_being_chunked}"
        )

        if self.enable_metrics:
            self.stats.num_running_reqs = running_bs
            self.stats.num_used_tokens = num_used
            self.stats.token_usage = round(num_used / self.max_total_num_tokens, 2)
            self.stats.num_queue_reqs = len(self.waiting_queue) + has_being_chunked
            self.stats.cache_hit_rate = tree_cache_hit_rate
            self.metrics_collector.log_stats(self.stats)

    def log_decode_stats(self):
        """Log a summary line for the decode phase and update the metrics.

        The generation throughput is the number of tokens counted in
        ``self.num_generated_tokens`` divided by the wall-clock time since
        ``self.last_decode_stats_tic``. Both counters are reset afterwards.
        """
        num_used = self.max_total_num_tokens - (
            self.token_to_kv_pool.available_size() + self.tree_cache.evictable_size()
        )

        # Sample the clock once so the window that ends here is exactly the
        # window the next call starts from; guard against a zero-length
        # window on clocks with coarse resolution.
        now = time.time()
        elapsed = now - self.last_decode_stats_tic
        gen_throughput = self.num_generated_tokens / elapsed if elapsed > 0 else 0.0
        self.num_generated_tokens = 0
        self.last_decode_stats_tic = now

        num_running_reqs = len(self.running_batch.reqs) if self.running_batch else 0
        logger.info(
            f"Decode batch. "
            f"#running-req: {num_running_reqs}, "
            f"#token: {num_used}, "
            f"token usage: {num_used / self.max_total_num_tokens:.2f}, "
            f"gen throughput (token/s): {gen_throughput:.2f}, "
            f"#queue-req: {len(self.waiting_queue)}"
        )

        if self.enable_metrics:
            self.stats.num_running_reqs = num_running_reqs
            self.stats.num_used_tokens = num_used
            self.stats.token_usage = num_used / self.max_total_num_tokens
            self.stats.gen_throughput = gen_throughput
            self.stats.num_queue_reqs = len(self.waiting_queue)
            self.metrics_collector.log_stats(self.stats)

Lianmin Zheng's avatar
Lianmin Zheng committed
735
736
737
738
739
    def check_memory(self):
        """Warn (and optionally raise) if pool accounting does not add up.

        Two invariants are checked:

        * free KV-cache tokens plus evictable tree-cache tokens must equal
          ``self.max_total_num_tokens``;
        * the number of free slots in ``self.req_to_token_pool`` must equal
          the pool's ``size``.

        Each violation emits a warning via ``warnings.warn``.

        Raises:
            ValueError: If a violation is found and ``crash_on_warnings()``
                returns a truthy value.
        """
        available_size = (
            self.token_to_kv_pool.available_size() + self.tree_cache.evictable_size()
        )
        if available_size != self.max_total_num_tokens:
            # The trailing space keeps the headline separate from the numbers.
            msg = (
                "KV cache pool leak detected! "
                f"{available_size=}, {self.max_total_num_tokens=}\n"
            )
            warnings.warn(msg)
            if crash_on_warnings():
                raise ValueError(msg)

        if len(self.req_to_token_pool.free_slots) != self.req_to_token_pool.size:
            msg = (
                "Memory pool leak detected! "
                f"available_size={len(self.req_to_token_pool.free_slots)}, "
                f"total_size={self.req_to_token_pool.size}\n"
            )
            warnings.warn(msg)
            if crash_on_warnings():
                raise ValueError(msg)
Lianmin Zheng's avatar
Lianmin Zheng committed
757

758
    def get_next_batch_to_run(self) -> Optional[ScheduleBatch]:
        """Pick the batch for the next forward pass.

        If the previous batch was an extend batch, it is first folded into the
        running batch (minus any request that is still being chunked). A new
        prefill batch is preferred when one can be formed; otherwise the
        running batch is updated for decoding and returned.

        Returns:
            The batch to run, or ``None`` when there is nothing to do.
        """
        prev_batch = self.last_batch
        if prev_batch and prev_batch.forward_mode.is_extend():
            chunked_req = self.being_chunked_req
            if chunked_req:
                # Move the chunked request out of the batch
                prev_batch.filter_batch(being_chunked_req=chunked_req)
                self.tree_cache.cache_unfinished_req(chunked_req)
                # being chunked request keeps its rid but will get a new req_pool_idx
                self.req_to_token_pool.free(chunked_req.req_pool_idx)
                self.batch_is_full = False

            # Merge whatever remains of the prefill batch into the running batch
            if not prev_batch.is_empty():
                if self.running_batch is None:
                    self.running_batch = prev_batch
                else:
                    self.running_batch.merge_batch(prev_batch)

        # Prefill takes priority over decode
        prefill_batch = self.get_new_batch_prefill()
        if prefill_batch is not None:
            return prefill_batch

        # Fall back to decode
        if self.running_batch is None:
            return None
        self.running_batch = self.update_running_batch(self.running_batch)
        return self.running_batch
785

Lianmin Zheng's avatar
Lianmin Zheng committed
786
    def get_new_batch_prefill(self) -> Optional[ScheduleBatch]:
        """Try to assemble a new prefill (extend) batch from the waiting queue.

        Requests are taken from ``self.waiting_queue`` in the order left by
        ``self.policy.calc_priority`` until the ``PrefillAdder`` stops
        accepting them, the running-request limit is hit, or the LoRA limit
        per batch would be exceeded. A request that is currently being chunked
        is always offered to the adder first.

        Returns:
            The prepared ``ScheduleBatch``, or ``None`` when prefill is not
            allowed right now or no request could be added.
        """
        # Check if the grammar is ready in the grammar queue
        if self.grammar_queue:
            self.move_ready_grammar_requests()

        # Handle the cases where prefill is not allowed
        if (
            self.batch_is_full or len(self.waiting_queue) == 0
        ) and self.being_chunked_req is None:
            return None

        running_bs = len(self.running_batch.reqs) if self.running_batch else 0
        if running_bs >= self.max_running_requests:
            self.batch_is_full = True
            return None

        # Get priority queue
        prefix_computed = self.policy.calc_priority(self.waiting_queue)

        # Prefill policy
        adder = PrefillAdder(
            self.tree_cache,
            self.running_batch,
            self.new_token_ratio,
            self.token_to_kv_pool.available_size() + self.tree_cache.evictable_size(),
            self.max_prefill_tokens,
            self.chunked_prefill_size,
            running_bs if self.is_mixed_chunk else 0,
        )

        has_being_chunked = self.being_chunked_req is not None
        if has_being_chunked:
            self.being_chunked_req.init_next_round_input()
            self.being_chunked_req = adder.add_being_chunked_req(self.being_chunked_req)

        if self.lora_paths:
            # LoRA adapters already in use by the running batch
            lora_set = (
                {req.lora_path for req in self.running_batch.reqs}
                if self.running_batch is not None
                else set()
            )

        # Get requests from the waiting queue to a new prefill batch
        for req in self.waiting_queue:
            if (
                self.lora_paths
                and len(
                    lora_set
                    | {r.lora_path for r in adder.can_run_list}
                    | {req.lora_path}
                )
                > self.max_loras_per_batch
            ):
                self.batch_is_full = True
                break

            if running_bs + len(adder.can_run_list) >= self.max_running_requests:
                self.batch_is_full = True
                break

            req.init_next_round_input(None if prefix_computed else self.tree_cache)
            res = adder.add_one_req(req)
            if res != AddReqResult.CONTINUE:
                if res == AddReqResult.NO_TOKEN:
                    self.batch_is_full = True
                break
            if self.server_args.prefill_only_one_req:
                break

        # Update waiting queue
        can_run_list = adder.can_run_list
        if len(can_run_list) == 0:
            return None
        # Build the lookup set once instead of once per waiting request.
        can_run_set = set(can_run_list)
        self.waiting_queue = [x for x in self.waiting_queue if x not in can_run_set]

        if adder.new_being_chunked_req is not None:
            assert self.being_chunked_req is None
            self.being_chunked_req = adder.new_being_chunked_req

        if self.being_chunked_req:
            self.being_chunked_req.is_being_chunked += 1

        # Print stats
        if self.tp_rank == 0:
            self.log_prefill_stats(adder, can_run_list, running_bs, has_being_chunked)

        # Create a new batch
        new_batch = ScheduleBatch.init_new(
            can_run_list,
            self.req_to_token_pool,
            self.token_to_kv_pool,
            self.tree_cache,
            self.model_config,
            self.enable_overlap,
            self.spec_algorithm,
        )
        new_batch.prepare_for_extend()

        # Mixed-style chunked prefill
        if (
            self.is_mixed_chunk
            and self.running_batch is not None
            and not (new_batch.return_logprob or self.running_batch.return_logprob)
        ):
            # TODO (lianmin): support return_logprob + mixed chunked prefill
            self.running_batch.filter_batch()
            if not self.running_batch.is_empty():
                self.running_batch.prepare_for_decode()
                new_batch.mix_with_running(self.running_batch)
                new_batch.decoding_reqs = self.running_batch.reqs
            self.running_batch = None
        else:
            new_batch.decoding_reqs = None

        return new_batch

Lianmin Zheng's avatar
Lianmin Zheng committed
904
    def update_running_batch(self, batch: ScheduleBatch) -> Optional[ScheduleBatch]:
        """Update the current running decoding batch.

        The batch is filtered, some requests are retracted back to the waiting
        queue if the decode memory check fails (or the ``test_retract`` flag
        forces it), jump-forward requests are optionally moved out, and the
        batch is prepared for decoding.

        Returns:
            The prepared batch, or ``None`` if it became empty.
        """
        global test_retract

        bs_before = batch.batch_size()

        batch.filter_batch()
        if batch.is_empty():
            self.batch_is_full = False
            return None

        # Check if decode out of memory
        must_retract = (
            not batch.check_decode_mem(self.decode_mem_cache_buf_multiplier)
        ) or (test_retract and batch.batch_size() > 10)

        if not must_retract:
            # Decay the ratio toward its configured minimum.
            self.new_token_ratio = max(
                self.new_token_ratio - self.new_token_ratio_decay,
                self.min_new_token_ratio,
            )
        else:
            previous_ratio = self.new_token_ratio

            retracted_reqs, self.new_token_ratio = batch.retract_decode()
            if self.draft_worker:
                self.draft_worker.finish_request(retracted_reqs)

            logger.info(
                "Decode out of memory happened. "
                f"#retracted_reqs: {len(retracted_reqs)}, "
                f"#new_token_ratio: {previous_ratio:.4f} -> {self.new_token_ratio:.4f}"
            )
            self.waiting_queue.extend(retracted_reqs)

        # Check for jump-forward
        if not self.disable_jump_forward:
            self.waiting_queue.extend(
                batch.check_for_jump_forward(self.pad_input_ids_func)
            )
            if batch.is_empty():
                self.batch_is_full = False
                return None

        # The batch shrank, so there is room for new requests again.
        if batch.batch_size() < bs_before:
            self.batch_is_full = False

        # Update batch tensors
        batch.prepare_for_decode()
        return batch
Lianmin Zheng's avatar
Lianmin Zheng committed
952
953

    def run_batch(self, batch: ScheduleBatch):
        """Run a batch.

        Returns:
            For generation models, ``(logits_output, next_token_ids, bid)``,
            or ``None`` for an idle batch. For embedding or reward models,
            ``(embeddings, bid)``.
        """
        self.forward_ct += 1

        if self.is_generation:
            if batch.forward_mode.is_decode() or batch.extend_num_tokens != 0:
                if self.spec_algorithm.is_none():
                    model_worker_batch = batch.get_model_worker_batch()
                    logits_output, next_token_ids = (
                        self.tp_worker.forward_batch_generation(model_worker_batch)
                    )
                else:
                    logits_output, next_token_ids, model_worker_batch, spec_info = (
                        self.draft_worker.forward_batch_speculative_generation(batch)
                    )
                    batch.spec_info = spec_info
            elif batch.forward_mode.is_idle():
                model_worker_batch = batch.get_model_worker_batch()
                self.tp_worker.forward_batch_idle(model_worker_batch)
                return
            else:
                # Nothing to forward. The worker batch is still needed because
                # its ``bid`` is part of the result tuple built below.
                model_worker_batch = batch.get_model_worker_batch()
                logits_output = None
                # NOTE(review): this reads self.tokenizer when
                # skip_tokenizer_init is set; confirm the tokenizer is
                # available in that mode.
                if self.skip_tokenizer_init:
                    next_token_ids = torch.full(
                        (batch.batch_size(),), self.tokenizer.eos_token_id
                    )
                else:
                    next_token_ids = torch.full((batch.batch_size(),), 0)
            batch.output_ids = next_token_ids
            ret = logits_output, next_token_ids, model_worker_batch.bid
        else:  # embedding or reward model
            assert batch.extend_num_tokens != 0
            model_worker_batch = batch.get_model_worker_batch()
            embeddings = self.tp_worker.forward_batch_embedding(model_worker_batch)
            ret = embeddings, model_worker_batch.bid
        return ret
Chayenne's avatar
Chayenne committed
989

Lianmin Zheng's avatar
Lianmin Zheng committed
990
991
992
    def process_batch_result(self, batch: ScheduleBatch, result):
        """Dispatch a forward result to the handler for the batch's mode.

        Decode and extend batches go to their dedicated handlers; a decode
        batch that ends up empty clears ``self.running_batch``. For a
        "dummy first" batch only the next batch's sampling info is finalized.
        """
        mode = batch.forward_mode

        if mode.is_decode():
            self.process_batch_result_decode(batch, result)
            if batch.is_empty():
                self.running_batch = None
            return

        if mode.is_extend():
            self.process_batch_result_prefill(batch, result)
            return

        if mode.is_dummy_first():
            sampling_info = batch.next_batch_sampling_info
            sampling_info.update_regex_vocab_mask()
            self.current_stream.synchronize()
            sampling_info.sampling_info_done.set()
Lianmin Zheng's avatar
Lianmin Zheng committed
1001
1002

    def process_batch_result_prefill(self, batch: ScheduleBatch, result):
        """Consume the result of an extend (prefill) forward pass.

        For generation models, each request that has finished prefill gets its
        next token appended, is checked for completion, cached in the tree
        cache, and has logprobs and grammar state updated. A request that is
        still being chunked only has its ``is_being_chunked`` counter
        decremented and is excluded from streaming. For embedding or reward
        models, the embedding is stored and a dummy output token is appended.
        Finally the batch's requests are streamed out.
        """
        skip_stream_req = None

        if not self.is_generation:
            # embedding or reward model
            embeddings, bid = result
            embeddings = embeddings.tolist()

            # Check finish conditions
            for i, req in enumerate(batch.reqs):
                if req.is_retracted:
                    continue

                req.embedding = embeddings[i]
                if req.is_being_chunked > 0:
                    # being chunked reqs' prefill is not finished
                    req.is_being_chunked -= 1
                    continue

                # Dummy output token for embedding models
                req.output_ids.append(0)
                req.check_finished()

                if req.finished():
                    self.tree_cache.cache_finished_req(req)
                else:
                    self.tree_cache.cache_unfinished_req(req)
        else:
            logits_output, next_token_ids, bid = result

            if self.enable_overlap:
                logits_output, next_token_ids = self.tp_worker.resolve_batch_result(bid)
            else:
                # Move next_token_ids and logprobs to cpu
                next_token_ids = next_token_ids.tolist()
                if batch.return_logprob:
                    for field in (
                        "next_token_logprobs",
                        "input_token_logprobs",
                        "normalized_prompt_logprobs",
                    ):
                        setattr(
                            logits_output, field, getattr(logits_output, field).tolist()
                        )

            # Check finish conditions
            logprob_pt = 0
            for i, (req, next_token_id) in enumerate(zip(batch.reqs, next_token_ids)):
                if req.is_retracted:
                    continue

                if self.is_mixed_chunk and self.enable_overlap and req.finished():
                    # Free the one delayed token for the mixed decode batch
                    j = len(batch.out_cache_loc) - len(batch.reqs) + i
                    self.token_to_kv_pool.free(batch.out_cache_loc[j : j + 1])
                    continue

                if req.is_being_chunked > 0:
                    # being chunked reqs' prefill is not finished
                    req.is_being_chunked -= 1
                    # There is only at most one request being currently chunked.
                    # Because this request does not finish prefill,
                    # we don't want to stream the request currently being chunked.
                    skip_stream_req = req
                    continue

                req.output_ids.append(next_token_id)
                req.check_finished()

                if req.finished():
                    self.tree_cache.cache_finished_req(req)
                elif not batch.decoding_reqs or req not in batch.decoding_reqs:
                    self.tree_cache.cache_unfinished_req(req)

                if req.return_logprob:
                    # The helper returns how far to advance in the flat
                    # input-logprob list.
                    logprob_pt += self.add_logprob_return_values(
                        i, req, logprob_pt, next_token_ids, logits_output
                    )

                if req.grammar is not None:
                    req.grammar.accept_token(next_token_id)
                    req.grammar.finished = req.finished()

            if batch.next_batch_sampling_info:
                batch.next_batch_sampling_info.update_regex_vocab_mask()
                self.current_stream.synchronize()
                batch.next_batch_sampling_info.sampling_info_done.set()

        self.stream_output(batch.reqs, batch.return_logprob, skip_stream_req)
1090

Lianmin Zheng's avatar
Lianmin Zheng committed
1091
    def process_batch_result_decode(self, batch: ScheduleBatch, result):
        """Consume the result of a decode forward pass.

        Each live request gets its new token appended (unless a speculative
        algorithm is active), is checked for completion, and has logprobs and
        grammar state updated. The requests are then streamed out, and decode
        stats are logged every ``decode_log_interval`` steps on TP rank 0.
        """
        logits_output, next_token_ids, bid = result
        self.num_generated_tokens += len(batch.reqs)

        if not self.enable_overlap:
            next_token_ids = next_token_ids.tolist()
            if batch.return_logprob:
                next_token_logprobs = logits_output.next_token_logprobs.tolist()
        else:
            logits_output, next_token_ids = self.tp_worker.resolve_batch_result(bid)
            next_token_logprobs = logits_output.next_token_logprobs

        kv_pool = self.token_to_kv_pool
        kv_pool.free_group_begin()

        # Check finish condition
        for idx, (req, token_id) in enumerate(zip(batch.reqs, next_token_ids)):
            if req.is_retracted:
                continue

            if self.enable_overlap and req.finished():
                # Free the one delayed token
                kv_pool.free(batch.out_cache_loc[idx : idx + 1])
                continue

            if batch.spec_algorithm.is_none():
                # speculative worker will solve the output_ids in speculative decoding
                req.output_ids.append(token_id)

            req.check_finished()
            if req.finished():
                self.tree_cache.cache_finished_req(req)

            if req.return_logprob:
                req.output_token_logprobs_val.append(next_token_logprobs[idx])
                req.output_token_logprobs_idx.append(token_id)
                if req.top_logprobs_num > 0:
                    req.output_top_logprobs_val.append(
                        logits_output.next_token_top_logprobs_val[idx]
                    )
                    req.output_top_logprobs_idx.append(
                        logits_output.next_token_top_logprobs_idx[idx]
                    )

            if req.grammar is not None:
                req.grammar.accept_token(token_id)
                req.grammar.finished = req.finished()

        sampling_info = batch.next_batch_sampling_info
        if sampling_info:
            sampling_info.update_regex_vocab_mask()
            self.current_stream.synchronize()
            sampling_info.sampling_info_done.set()

        self.stream_output(batch.reqs, batch.return_logprob)

        kv_pool.free_group_end()

        # The counter wraps at 2**30.
        self.forward_ct_decode = (self.forward_ct_decode + 1) % (1 << 30)
        should_log = (
            self.tp_rank == 0
            and self.forward_ct_decode % self.server_args.decode_log_interval == 0
        )
        if should_log:
            self.log_decode_stats()
1154

1155
1156
1157
1158
1159
1160
1161
1162
1163
    def add_logprob_return_values(
        self,
        i: int,
        req: Req,
        pt: int,
        next_token_ids: List[int],
        output: LogitsProcessorOutput,
    ):
        """Attach logprobs to the return values.

        Args:
            i: Index of ``req`` within the batch-level outputs.
            req: Request whose logprob fields are filled in.
            pt: Start offset of this request inside
                ``output.input_token_logprobs``.
            next_token_ids: Sampled next-token ids for the batch.
            output: Logits processor output holding the logprobs.

        Returns:
            ``req.extend_input_len - req.extend_logprob_start_len``, which the
            caller adds to its running offset ``pt``.
        """
        req.output_token_logprobs_val.append(output.next_token_logprobs[i])
        req.output_token_logprobs_idx.append(next_token_ids[i])

        # If logprob_start_len > 0, then first logprob_start_len prompt tokens will be ignored.
        num_input_logprobs = req.extend_input_len - req.extend_logprob_start_len

        recomputed = req.last_update_decode_tokens
        fill_len = len(req.fill_ids)
        # End (exclusive) of this request's span in the flat input-logprob list.
        span_end = pt + num_input_logprobs - 1

        if req.normalized_prompt_logprob is None:
            req.normalized_prompt_logprob = output.normalized_prompt_logprobs[i]

        if req.input_token_logprobs_val is None:
            prompt_vals = output.input_token_logprobs[pt : span_end - recomputed]
            prompt_ids = req.fill_ids[
                fill_len - num_input_logprobs + 1 : fill_len - recomputed
            ]

            # Clip the padded hash values from image tokens.
            # Otherwise, it will lead to detokenization errors.
            id_limit = self.model_config.vocab_size - 1
            prompt_ids = [x if x < id_limit else 0 for x in prompt_ids]

            if req.logprob_start_len == 0:
                # The first token does not have logprob, pad it.
                prompt_vals = [None] + prompt_vals
                prompt_ids = [req.fill_ids[0]] + prompt_ids

            req.input_token_logprobs_val = prompt_vals
            req.input_token_logprobs_idx = prompt_ids

        if recomputed != 0:
            # Some decode tokens are re-computed in an extend batch
            req.output_token_logprobs_val.extend(
                output.input_token_logprobs[span_end - recomputed : span_end]
            )
            req.output_token_logprobs_idx.extend(
                req.fill_ids[fill_len - recomputed : fill_len]
            )

        if req.top_logprobs_num > 0:
            if req.input_top_logprobs_val is None:
                top_vals = output.input_top_logprobs_val[i]
                top_ids = output.input_top_logprobs_idx[i]
                if req.logprob_start_len == 0:
                    top_vals = [None] + top_vals
                    top_ids = [None] + top_ids
                req.input_top_logprobs_val = top_vals
                req.input_top_logprobs_idx = top_ids

            if recomputed != 0:
                req.output_top_logprobs_val.extend(
                    output.input_top_logprobs_val[i][-recomputed:]
                )
                req.output_top_logprobs_idx.extend(
                    output.input_top_logprobs_idx[i][-recomputed:]
                )

            req.output_top_logprobs_val.append(output.next_token_top_logprobs_val[i])
            req.output_top_logprobs_idx.append(output.next_token_top_logprobs_idx[i])

        return num_input_logprobs

Lianmin Zheng's avatar
Lianmin Zheng committed
1240
1241
1242
    def stream_output(
        self, reqs: List[Req], return_logprob: bool, skip_req: Optional[Req] = None
    ):
1243
        """Stream the output to detokenizer."""
Lianmin Zheng's avatar
Lianmin Zheng committed
1244
1245
1246
        rids = []
        finished_reasons: List[BaseFinishReason] = []

1247
        if self.is_generation:
Lianmin Zheng's avatar
Lianmin Zheng committed
1248
            vids = []
1249
            decoded_texts = []
Lianmin Zheng's avatar
Lianmin Zheng committed
1250
1251
            decode_ids_list = []
            read_offsets = []
1252
            output_ids = []
1253
            origin_input_ids = []
1254

Lianmin Zheng's avatar
Lianmin Zheng committed
1255
1256
1257
1258
1259
1260
1261
1262
1263
1264
1265
1266
1267
1268
1269
1270
1271
1272
1273
1274
1275
1276
1277
1278
1279
1280
1281
1282
1283
            skip_special_tokens = []
            spaces_between_special_tokens = []
            no_stop_trim = []
            prompt_tokens = []
            completion_tokens = []
            cached_tokens = []

            if return_logprob:
                input_token_logprobs_val = []
                input_token_logprobs_idx = []
                output_token_logprobs_val = []
                output_token_logprobs_idx = []
                input_top_logprobs_val = []
                input_top_logprobs_idx = []
                output_top_logprobs_val = []
                output_top_logprobs_idx = []
                normalized_prompt_logprob = []
            else:
                input_token_logprobs_val = input_token_logprobs_idx = (
                    output_token_logprobs_val
                ) = output_token_logprobs_idx = input_top_logprobs_val = (
                    input_top_logprobs_idx
                ) = output_top_logprobs_val = output_top_logprobs_idx = (
                    normalized_prompt_logprob
                ) = None

            for req in reqs:
                if req is skip_req:
                    continue
1284

Lianmin Zheng's avatar
Lianmin Zheng committed
1285
1286
1287
1288
1289
1290
1291
1292
                # TODO(lianmin): revisit this for overlap + retract + stream
                if (
                    req.finished()
                    # If stream, follow the given stream_interval
                    or (req.stream and len(req.output_ids) % self.stream_interval == 0)
                    # If not stream, we still want to output some tokens to get the benefit of incremental decoding.
                    or (not req.stream and len(req.output_ids) % 50 == 0)
                ):
1293
1294
1295
                    if self.draft_worker and req.finished():
                        self.draft_worker.finish_request(req)

Lianmin Zheng's avatar
Lianmin Zheng committed
1296
1297
1298
1299
1300
                    rids.append(req.rid)
                    finished_reasons.append(
                        req.finished_reason.to_json() if req.finished_reason else None
                    )
                    vids.append(req.vid)
1301
                    decoded_texts.append(req.decoded_text)
Lianmin Zheng's avatar
Lianmin Zheng committed
1302
1303
1304
                    decode_ids, read_offset = req.init_incremental_detokenize()
                    decode_ids_list.append(decode_ids)
                    read_offsets.append(read_offset)
1305
                    if self.skip_tokenizer_init or self.server_args.return_token_ids:
1306
                        output_ids.append(req.output_ids)
1307
1308
1309
1310
1311
1312
                    else:
                        output_ids = None
                    if self.server_args.return_token_ids:
                        origin_input_ids.append(req.origin_input_ids)
                    else:
                        origin_input_ids = None
Lianmin Zheng's avatar
Lianmin Zheng committed
1313
1314
                    skip_special_tokens.append(req.sampling_params.skip_special_tokens)
                    spaces_between_special_tokens.append(
1315
1316
                        req.sampling_params.spaces_between_special_tokens
                    )
Lianmin Zheng's avatar
Lianmin Zheng committed
1317
1318
1319
1320
1321
1322
1323
1324
1325
1326
1327
1328
1329
1330
1331
1332
1333
1334
1335
                    no_stop_trim.append(req.sampling_params.no_stop_trim)

                    prompt_tokens.append(len(req.origin_input_ids))
                    completion_tokens.append(len(req.output_ids))
                    cached_tokens.append(req.cached_tokens)

                    if return_logprob:
                        input_token_logprobs_val.append(req.input_token_logprobs_val)
                        input_token_logprobs_idx.append(req.input_token_logprobs_idx)
                        output_token_logprobs_val.append(req.output_token_logprobs_val)
                        output_token_logprobs_idx.append(req.output_token_logprobs_idx)
                        input_top_logprobs_val.append(req.input_top_logprobs_val)
                        input_top_logprobs_idx.append(req.input_top_logprobs_idx)
                        output_top_logprobs_val.append(req.output_top_logprobs_val)
                        output_top_logprobs_idx.append(req.output_top_logprobs_idx)
                        normalized_prompt_logprob.append(req.normalized_prompt_logprob)

            # Send to detokenizer
            if rids:
1336
                self.send_to_detokenizer.send_pyobj(
1337
                    BatchTokenIDOut(
Lianmin Zheng's avatar
Lianmin Zheng committed
1338
1339
1340
                        rids,
                        finished_reasons,
                        vids,
1341
                        decoded_texts,
Lianmin Zheng's avatar
Lianmin Zheng committed
1342
1343
                        decode_ids_list,
                        read_offsets,
1344
                        origin_input_ids,
1345
                        output_ids,
Lianmin Zheng's avatar
Lianmin Zheng committed
1346
1347
1348
1349
1350
1351
1352
1353
1354
1355
1356
1357
1358
1359
1360
                        skip_special_tokens,
                        spaces_between_special_tokens,
                        no_stop_trim,
                        prompt_tokens,
                        completion_tokens,
                        cached_tokens,
                        input_token_logprobs_val,
                        input_token_logprobs_idx,
                        output_token_logprobs_val,
                        output_token_logprobs_idx,
                        input_top_logprobs_val,
                        input_top_logprobs_idx,
                        output_top_logprobs_val,
                        output_top_logprobs_idx,
                        normalized_prompt_logprob,
1361
1362
                    )
                )
Lianmin Zheng's avatar
Lianmin Zheng committed
1363
1364
1365
1366
1367
1368
1369
1370
1371
1372
1373
1374
        else:  # embedding or reward model
            embeddings = []
            prompt_tokens = []
            for req in reqs:
                assert req.finished()
                rids.append(req.rid)
                finished_reasons.append(req.finished_reason.to_json())
                embeddings.append(req.embedding)
                prompt_tokens.append(len(req.origin_input_ids))
            self.send_to_detokenizer.send_pyobj(
                BatchEmbeddingOut(rids, finished_reasons, embeddings, prompt_tokens)
            )
1375

1376
1377
1378
1379
1380
1381
1382
1383
1384
1385
1386
1387
1388
1389
1390
1391
1392
1393
1394
1395
1396
1397
1398
1399
1400
1401
1402
1403
1404
1405
1406
1407
1408
1409
1410
1411
1412
1413
1414
1415
1416
1417
1418
1419
1420
1421
1422
1423
1424
1425
1426
    def prepare_dp_attn_batch(self, local_batch: ScheduleBatch):
        """Exchange per-rank token counts and CUDA-graph eligibility for DP attention.

        Every rank in ``self.tp_cpu_group`` contributes the number of tokens
        it is about to process. A rank with no batch of its own gets an idle
        batch whenever at least one rank reported a positive token count.

        Returns:
            The batch to run (possibly a freshly created idle batch), or
            ``None`` when this rank has no batch and no rank reported tokens.
        """
        # Token count contributed by this rank in the upcoming forward pass.
        if local_batch is None:
            local_count = 0
        else:
            local_count = (
                local_batch.batch_size()
                if local_batch.forward_mode.is_decode()
                else local_batch.extend_num_tokens
            )

        # Gather every rank's count into one tensor of length tp_size.
        gathered_counts = torch.empty(self.tp_size, dtype=torch.int64)
        torch.distributed.all_gather_into_tensor(
            gathered_counts,
            torch.tensor([local_count], dtype=torch.int64),
            group=self.tp_cpu_group,
        )

        # Some rank has work while this one does not: run an idle batch.
        if local_batch is None and gathered_counts.max().item() > 0:
            local_batch = self.get_idle_batch()

        if local_batch is None:
            return local_batch

        local_batch.global_num_tokens = gathered_counts.tolist()

        if self.server_args.disable_cuda_graph:
            return local_batch

        # Check forward mode for cuda graph: each rank votes 1 when it is in
        # decode or idle mode. The MIN reduction yields 1 only if every rank
        # voted 1.
        mode = local_batch.forward_mode
        graph_vote = torch.tensor(
            1 if mode.is_decode() or mode.is_idle() else 0,
            dtype=torch.int32,
        )
        torch.distributed.all_reduce(
            graph_vote,
            op=torch.distributed.ReduceOp.MIN,
            group=self.tp_cpu_group,
        )
        local_batch.can_run_dp_cuda_graph = graph_vote.item() == 1
        return local_batch

    def get_idle_batch(self):
        """Create an empty ``ScheduleBatch`` and prepare it for idle mode.

        The batch is built with no requests, using this scheduler's pools,
        tree cache, model config, overlap flag and speculative algorithm.
        """
        empty_batch = ScheduleBatch.init_new(
            [],
            self.req_to_token_pool,
            self.token_to_kv_pool,
            self.tree_cache,
            self.model_config,
            self.enable_overlap,
            self.spec_algorithm,
        )
        empty_batch.prepare_for_idle()
        return empty_batch

1432
1433
1434
1435
1436
1437
1438
1439
1440
1441
1442
1443
1444
1445
1446
1447
1448
1449
1450
1451
1452
1453
1454
1455
    def move_ready_grammar_requests(self):
        """Move requests whose grammar objects are ready from grammar_queue to waiting_queue.

        Requests are checked in queue order. Each ``req.grammar`` future is
        waited on for at most 50 ms, and the scan stops at the first one that
        is not ready yet, so only a leading run of the queue is ever moved.
        When ``tp_size > 1`` the ready count is raised to the maximum across
        ``tp_cpu_group`` so every rank moves the same number of requests.
        """
        num_ready_reqs = 0
        for req in self.grammar_queue:
            try:
                req.grammar = req.grammar.result(timeout=0.05)
            except futures.TimeoutError:
                # Public alias of the exception raised by Future.result();
                # avoids reaching into the private ``futures._base`` module.
                break
            num_ready_reqs += 1

        if self.tp_size > 1:
            # Sync across TP ranks to make sure they have the same number of ready requests
            tensor = torch.tensor(num_ready_reqs, dtype=torch.int32)
            torch.distributed.all_reduce(
                tensor, op=torch.distributed.ReduceOp.MAX, group=self.tp_cpu_group
            )
            num_ready_reqs_max = tensor.item()
            # Another rank is further ahead: block (no timeout) until the
            # remaining grammars are ready here as well.
            for i in range(num_ready_reqs, num_ready_reqs_max):
                lagging_req = self.grammar_queue[i]
                lagging_req.grammar = lagging_req.grammar.result()
            num_ready_reqs = num_ready_reqs_max

        self.waiting_queue.extend(self.grammar_queue[:num_ready_reqs])
        self.grammar_queue = self.grammar_queue[num_ready_reqs:]

1456
    def flush_cache(self):
        """Flush the memory pool and cache.

        The flush only happens when nothing is queued and nothing is running.
        Otherwise the caches are left untouched and a warning is logged.

        Returns:
            ``True`` if the caches were flushed, ``False`` if the flush was
            skipped because requests are pending.
        """
        num_queued = len(self.waiting_queue)
        num_running = 0 if self.running_batch is None else len(self.running_batch.reqs)

        if num_queued != 0 or num_running != 0:
            # Use the module-level logger, as the rest of this file does,
            # rather than the root ``logging`` module.
            logger.warning(
                f"Cache not flushed because there are pending requests. "
                f"#queue-req: {num_queued}, "
                f"#running-req: {num_running}"
            )
            return False

        self.tree_cache.reset()
        self.tree_cache_metrics = {"total": 0, "hit": 0}
        if self.grammar_backend:
            self.grammar_backend.reset()
        self.req_to_token_pool.clear()
        self.token_to_kv_pool.clear()
        torch.cuda.empty_cache()
        logger.info("Cache flushed successfully!")
        return True

    def abort_request(self, recv_req: AbortReq):
        """Abort the request whose ``rid`` matches ``recv_req.rid``.

        A request still in the waiting queue is removed outright. Otherwise
        the first matching, unfinished request in the running batch is flagged
        with ``to_abort = True``. Nothing happens if no request matches.
        """
        # Delete requests in the waiting queue
        for position, queued_req in enumerate(self.waiting_queue):
            if queued_req.rid == recv_req.rid:
                # Returning right after the delete keeps the mutation safe.
                del self.waiting_queue[position]
                logger.debug(f"Abort queued request. {queued_req.rid=}")
                return

        # Delete requests in the running batch
        if not self.running_batch:
            return

        for running_req in self.running_batch.reqs:
            if running_req.rid != recv_req.rid or running_req.finished():
                continue
            logger.debug(f"Abort running request. {running_req.rid=}")
            running_req.to_abort = True
            break

Chayenne's avatar
Chayenne committed
1500
1501
1502
    def update_weights_from_disk(self, recv_req: UpdateWeightFromDiskReqInput):
        """In-place update of the weights from disk.

        Delegates the update to the TP worker. On success the cache is
        flushed, and the flush is asserted to have succeeded. On failure the
        worker's message is logged as an error.

        Returns:
            The ``(success, message)`` pair reported by the TP worker.
        """
        success, message = self.tp_worker.update_weights_from_disk(recv_req)
        if not success:
            logger.error(message)
            return success, message

        cache_flushed = self.flush_cache()
        assert cache_flushed, "Cache flush failed after updating weights"
        return success, message

1510
1511
1512
1513
1514
1515
1516
1517
1518
1519
1520
1521
1522
1523
1524
1525
1526
    def init_weights_update_group(self, recv_req: InitWeightsUpdateGroupReqInput):
        """Initialize the online model parameter update group.

        Delegates to the TP worker.

        Returns:
            The ``(success, message)`` pair reported by the TP worker.
        """
        ok, detail = self.tp_worker.init_weights_update_group(recv_req)
        return ok, detail

    def update_weights_from_distributed(
        self, recv_req: UpdateWeightsFromDistributedReqInput
    ):
        """Update the online model parameter.

        Delegates the update to the TP worker. On success the cache is
        flushed, and the flush is asserted to have succeeded. On failure the
        worker's message is logged as an error.

        Returns:
            The ``(success, message)`` pair reported by the TP worker.
        """
        success, message = self.tp_worker.update_weights_from_distributed(recv_req)
        if not success:
            logger.error(message)
            return success, message

        cache_flushed = self.flush_cache()
        assert cache_flushed, "Cache flush failed after updating weights"
        return success, message

1527
1528
1529
1530
1531
1532
1533
1534
1535
1536
1537
    def update_weights_from_tensor(self, recv_req: UpdateWeightsFromTensorReqInput):
        """Update the online model parameter from tensors.

        Delegates the update to the TP worker. On success the cache is
        flushed, and the flush is asserted to have succeeded. On failure the
        worker's message is logged as an error.

        Returns:
            The ``(success, message)`` pair reported by the TP worker.
        """
        success, message = self.tp_worker.update_weights_from_tensor(recv_req)
        # TODO extract common code b/t update_weights_from_distributed and update_weights_from_tensor later
        if not success:
            logger.error(message)
            return success, message

        cache_flushed = self.flush_cache()
        assert cache_flushed, "Cache flush failed after updating weights"
        return success, message

1538
1539
1540
1541
    def get_weights_by_name(self, recv_req: GetWeightsByNameReqInput):
        """Return whatever the TP worker reports for the weights named in ``recv_req``."""
        return self.tp_worker.get_weights_by_name(recv_req)

1542
1543
1544
1545
1546
1547
1548
1549
1550
1551
1552
1553
1554
1555
    def start_profile(self) -> None:
        """Start the profiler.

        Raises:
            RuntimeError: If ``self.profiler`` is ``None``.
        """
        active_profiler = self.profiler
        if active_profiler is None:
            raise RuntimeError("Profiler is not enabled.")
        active_profiler.start()

    def stop_profile(self) -> None:
        """Stop the profiler and export a Chrome trace.

        The trace is written to ``self.torch_profiler_trace_dir`` under the
        name ``<unix timestamp>.trace.json.gz``.

        Raises:
            RuntimeError: If ``self.profiler`` is ``None``.
        """
        if self.profiler is None:
            raise RuntimeError("Profiler is not enabled.")
        self.profiler.stop()
        # os.path.join avoids a doubled separator when the configured
        # directory already ends with one.
        trace_path = os.path.join(
            self.torch_profiler_trace_dir, f"{time.time()}.trace.json.gz"
        )
        self.profiler.export_chrome_trace(trace_path)
        logger.info("Profiler is done")

1556
    def open_session(self, recv_req: OpenSessionReqInput) -> Tuple[Optional[str], bool]:
        """Register a new session under ``recv_req.session_id``.

        Returns:
            ``(session_id, opened)``. ``opened`` is ``False``, with a warning
            logged, when the id is already registered or is ``None``.
            Otherwise a ``Session`` is created and stored in ``self.sessions``.
        """
        session_id = recv_req.session_id

        # Reject ids that are already in use.
        if session_id in self.sessions:
            logger.warning(f"session id {session_id} already exist, cannot open.")
            return session_id, False

        # Reject a missing id.
        if session_id is None:
            logger.warning("session id is None, cannot open.")
            return session_id, False

        new_session = Session(recv_req.capacity_of_str_len, session_id)
        self.sessions[session_id] = new_session
        return session_id, True
1570
1571
1572
1573
1574
1575
1576
1577
1578

    def close_session(self, recv_req: CloseSessionReqInput):
        """Remove the session named by ``recv_req.session_id`` from ``self.sessions``.

        An unknown id only triggers a warning; nothing is raised.
        """
        session_id = recv_req.session_id
        if session_id not in self.sessions:
            logger.warning(f"session id {session_id} does not exist, cannot delete.")
            return
        del self.sessions[session_id]

1579
1580
1581
1582
1583
1584

def run_scheduler_process(
    server_args: ServerArgs,
    port_args: PortArgs,
    gpu_id: int,
    tp_rank: int,
    dp_rank: Optional[int],
    pipe_writer,
):
    """Entry point of a scheduler process.

    Sets the process title, configures logging and (optionally) CPU affinity,
    then builds a ``Scheduler``, reports readiness through ``pipe_writer`` and
    runs its event loop. Any ``Exception`` raised while building or running
    the scheduler is logged and ``SIGQUIT`` is sent to the parent process.
    """
    setproctitle.setproctitle("sglang::scheduler")

    # [For Router] if env var "SGLANG_DP_RANK" exist, set dp_rank to the value of the env var
    if dp_rank is None:
        env_dp_rank = os.environ.get("SGLANG_DP_RANK")
        if env_dp_rank is not None:
            dp_rank = int(env_dp_rank)

    # Configure the logger
    log_prefix = f" TP{tp_rank}" if dp_rank is None else f" DP{dp_rank} TP{tp_rank}"
    configure_logger(server_args, prefix=log_prefix)
    suppress_other_loggers()

    # Set cpu affinity to this gpu process
    if get_bool_env_var("SGLANG_SET_CPU_AFFINITY"):
        set_gpu_proc_affinity(server_args.tp_size, server_args.nnodes, gpu_id)

    # Captured up front so the parent can be signalled from the except path.
    parent_process = psutil.Process().parent()

    # Create a scheduler and run the event loop
    try:
        scheduler = Scheduler(server_args, port_args, gpu_id, tp_rank, dp_rank)
        ready_message = {
            "status": "ready",
            "max_total_num_tokens": scheduler.max_total_num_tokens,
        }
        pipe_writer.send(ready_message)

        event_loop = (
            scheduler.event_loop_overlap
            if scheduler.enable_overlap
            else scheduler.event_loop_normal
        )
        event_loop()
    except Exception:
        error_trace = get_exception_traceback()
        logger.error(f"Scheduler hit an exception: {error_trace}")
        parent_process.send_signal(signal.SIGQUIT)