scheduler.py 65.2 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
13
# Copyright 2023-2024 SGLang Team
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
14
15
"""A scheduler that manages a tensor parallel GPU worker."""

16
import faulthandler
17
import logging
18
import os
19
import signal
Lianmin Zheng's avatar
Lianmin Zheng committed
20
import threading
21
22
import time
import warnings
Lianmin Zheng's avatar
Lianmin Zheng committed
23
from collections import deque
Lianmin Zheng's avatar
Lianmin Zheng committed
24
from concurrent import futures
25
from types import SimpleNamespace
26
from typing import Dict, List, Optional, Tuple
27

28
import psutil
29
import setproctitle
30
import torch
31
32
import zmq

33
from sglang.global_config import global_config
Lianmin Zheng's avatar
Lianmin Zheng committed
34
from sglang.srt.configs.model_config import ModelConfig
35
36
37
38
39
40
from sglang.srt.hf_transformers_utils import get_processor, get_tokenizer
from sglang.srt.layers.logits_processor import LogitsProcessorOutput
from sglang.srt.managers.io_struct import (
    AbortReq,
    BatchEmbeddingOut,
    BatchTokenIDOut,
41
    CloseSessionReqInput,
42
    FlushCacheReq,
43
44
    GetWeightsByNameReqInput,
    GetWeightsByNameReqOutput,
45
46
    InitWeightsUpdateGroupReqInput,
    InitWeightsUpdateGroupReqOutput,
47
48
    OpenSessionReqInput,
    OpenSessionReqOutput,
49
    ProfileReq,
50
51
52
53
    ReleaseMemoryOccupationReqInput,
    ReleaseMemoryOccupationReqOutput,
    ResumeMemoryOccupationReqInput,
    ResumeMemoryOccupationReqOutput,
54
55
    TokenizedEmbeddingReqInput,
    TokenizedGenerateReqInput,
Chayenne's avatar
Chayenne committed
56
57
    UpdateWeightFromDiskReqInput,
    UpdateWeightFromDiskReqOutput,
58
59
    UpdateWeightsFromDistributedReqInput,
    UpdateWeightsFromDistributedReqOutput,
60
61
    UpdateWeightsFromTensorReqInput,
    UpdateWeightsFromTensorReqOutput,
62
63
64
65
66
67
68
)
from sglang.srt.managers.schedule_batch import (
    FINISH_ABORT,
    BaseFinishReason,
    ImageInputs,
    Req,
    ScheduleBatch,
69
    global_server_args_dict,
70
)
71
72
73
74
75
from sglang.srt.managers.schedule_policy import (
    AddReqResult,
    PrefillAdder,
    SchedulePolicy,
)
76
from sglang.srt.managers.session_controller import Session
77
from sglang.srt.managers.tp_worker import TpModelWorker
78
from sglang.srt.managers.tp_worker_overlap_thread import TpModelWorkerClient
79
80
from sglang.srt.mem_cache.chunk_cache import ChunkCache
from sglang.srt.mem_cache.radix_cache import RadixCache
81
from sglang.srt.metrics.collector import SchedulerMetricsCollector, SchedulerStats
82
from sglang.srt.model_executor.forward_batch_info import ForwardMode
83
from sglang.srt.server_args import PortArgs, ServerArgs
84
from sglang.srt.speculative.spec_info import SpeculativeAlgorithm
85
86
87
from sglang.srt.utils import (
    broadcast_pyobj,
    configure_logger,
88
    crash_on_warnings,
89
    get_bool_env_var,
90
    get_zmq_socket,
91
    set_gpu_proc_affinity,
92
93
94
    set_random_seed,
    suppress_other_loggers,
)
95
from sglang.torch_memory_saver_adapter import TorchMemorySaverAdapter
96
97
98
99
from sglang.utils import get_exception_traceback

logger = logging.getLogger(__name__)

100
# Test retract decode for debugging purposes
101
test_retract = get_bool_env_var("SGLANG_TEST_RETRACT")
102

103
104
105
106
107
108
109
110
111
112

class Scheduler:
    """A scheduler that manages a tensor parallel GPU worker."""

    def __init__(
        self,
        server_args: ServerArgs,
        port_args: PortArgs,
        gpu_id: int,
        tp_rank: int,
        dp_rank: Optional[int],
    ):
        """Set up IPC sockets, the model worker, caches, and scheduling state.

        Args:
            server_args: Server configuration read throughout initialization.
            port_args: IPC endpoint names and the NCCL port passed to the workers.
            gpu_id: GPU index forwarded to the tensor parallel worker.
            tp_rank: Tensor parallel rank of this scheduler.
            dp_rank: Data parallel rank forwarded to the workers; may be None.

        Raises:
            ValueError: If the tokenizer is initialized and
                ``server_args.grammar_backend`` is neither "outlines" nor
                "xgrammar".
        """
        # Parse args
        self.server_args = server_args
        self.tp_rank = tp_rank
        self.tp_size = server_args.tp_size
        self.schedule_policy = server_args.schedule_policy
        self.disable_jump_forward = server_args.disable_jump_forward
        self.lora_paths = server_args.lora_paths
        self.max_loras_per_batch = server_args.max_loras_per_batch
        self.enable_overlap = not server_args.disable_overlap_schedule
        self.skip_tokenizer_init = server_args.skip_tokenizer_init
        self.enable_metrics = server_args.enable_metrics
        self.spec_algorithm = SpeculativeAlgorithm.from_string(
            server_args.speculative_algorithm
        )
        self.decode_mem_cache_buf_multiplier = (
            self.server_args.speculative_num_draft_tokens
            if not self.spec_algorithm.is_none()
            else 1
        )

        # Init inter-process communication
        context = zmq.Context(2)

        if self.tp_rank == 0 or self.server_args.enable_dp_attention:
            self.recv_from_tokenizer = get_zmq_socket(
                context, zmq.PULL, port_args.scheduler_input_ipc_name
            )
            self.send_to_tokenizer = get_zmq_socket(
                context, zmq.PUSH, port_args.tokenizer_ipc_name
            )

            if server_args.skip_tokenizer_init:
                # Directly send to the TokenizerManager
                self.send_to_detokenizer = get_zmq_socket(
                    context, zmq.PUSH, port_args.tokenizer_ipc_name
                )
            else:
                # Send to the DetokenizerManager
                self.send_to_detokenizer = get_zmq_socket(
                    context, zmq.PUSH, port_args.detokenizer_ipc_name
                )
        else:
            # Ranks without sockets get no-op senders so callers can send
            # unconditionally.
            self.recv_from_tokenizer = None
            self.send_to_tokenizer = SimpleNamespace(send_pyobj=lambda x: None)
            self.send_to_detokenizer = SimpleNamespace(send_pyobj=lambda x: None)

        # Init tokenizer
        self.model_config = ModelConfig(
            server_args.model_path,
            trust_remote_code=server_args.trust_remote_code,
            revision=server_args.revision,
            context_length=server_args.context_length,
            model_override_args=server_args.json_model_override_args,
            is_embedding=server_args.is_embedding,
            dtype=server_args.dtype,
            quantization=server_args.quantization,
        )
        self.is_generation = self.model_config.is_generation

        if server_args.skip_tokenizer_init:
            self.tokenizer = self.processor = None
        else:
            if self.model_config.is_multimodal:
                self.processor = get_processor(
                    server_args.tokenizer_path,
                    tokenizer_mode=server_args.tokenizer_mode,
                    trust_remote_code=server_args.trust_remote_code,
                )
                self.tokenizer = self.processor.tokenizer
            else:
                self.tokenizer = get_tokenizer(
                    server_args.tokenizer_path,
                    tokenizer_mode=server_args.tokenizer_mode,
                    trust_remote_code=server_args.trust_remote_code,
                )

        # Check whether overlap can be enabled
        if not self.is_generation:
            self.enable_overlap = False
            logger.info("Overlap scheduler is disabled for embedding models.")

        if self.model_config.is_multimodal:
            self.enable_overlap = False
            logger.info("Overlap scheduler is disabled for multimodal models.")

        if self.enable_overlap:
            self.disable_jump_forward = True

        # Launch a tensor parallel worker
        if self.enable_overlap:
            TpWorkerClass = TpModelWorkerClient
        else:
            TpWorkerClass = TpModelWorker

        self.tp_worker = TpWorkerClass(
            server_args=server_args,
            gpu_id=gpu_id,
            tp_rank=tp_rank,
            dp_rank=dp_rank,
            nccl_port=port_args.nccl_port,
        )

        # Launch worker for speculative decoding if need
        if self.spec_algorithm.is_eagle():
            from sglang.srt.speculative.eagle_worker import EAGLEWorker

            self.draft_worker = EAGLEWorker(
                gpu_id=gpu_id,
                tp_rank=tp_rank,
                server_args=server_args,
                nccl_port=port_args.nccl_port,
                target_worker=self.tp_worker,
                dp_rank=dp_rank,
            )
        else:
            self.draft_worker = None

        # Get token and memory info from the model worker
        (
            self.max_total_num_tokens,
            self.max_prefill_tokens,
            self.max_running_requests,
            self.max_req_len,
            self.max_req_input_len,
            self.random_seed,
            self.device,
            worker_global_server_args_dict,
            _,
            _,
            _,
        ) = self.tp_worker.get_worker_info()
        self.tp_cpu_group = self.tp_worker.get_tp_cpu_group()
        self.pad_input_ids_func = self.tp_worker.get_pad_input_ids_func()
        global_server_args_dict.update(worker_global_server_args_dict)
        set_random_seed(self.random_seed)

        # Print debug info
        logger.info(
            f"max_total_num_tokens={self.max_total_num_tokens}, "
            f"max_prefill_tokens={self.max_prefill_tokens}, "
            f"max_running_requests={self.max_running_requests}, "
            f"context_len={self.model_config.context_len}"
        )

        # Init memory pool and cache
        self.req_to_token_pool, self.token_to_kv_pool = self.tp_worker.get_memory_pool()

        if (
            server_args.chunked_prefill_size is not None
            and server_args.disable_radix_cache
        ):
            self.tree_cache = ChunkCache(
                req_to_token_pool=self.req_to_token_pool,
                token_to_kv_pool=self.token_to_kv_pool,
            )
        else:
            self.tree_cache = RadixCache(
                req_to_token_pool=self.req_to_token_pool,
                token_to_kv_pool=self.token_to_kv_pool,
                disable=server_args.disable_radix_cache,
            )
        self.tree_cache_metrics = {"total": 0, "hit": 0}
        self.policy = SchedulePolicy(self.schedule_policy, self.tree_cache)

        # Init running status
        self.waiting_queue: List[Req] = []
        # The running decoding batch for continuous batching
        self.running_batch: Optional[ScheduleBatch] = None
        # The current forward batch
        self.cur_batch: Optional[ScheduleBatch] = None
        # The last forward batch
        self.last_batch: Optional[ScheduleBatch] = None
        self.forward_ct = 0
        self.forward_ct_decode = 0
        self.num_generated_tokens = 0
        self.last_decode_stats_tic = time.time()
        self.stream_interval = server_args.stream_interval
        self.current_stream = torch.get_device_module(self.device).current_stream()

        # Session info
        self.sessions: Dict[str, Session] = {}

        # Init chunked prefill
        self.chunked_prefill_size = server_args.chunked_prefill_size
        # A non-positive size (e.g. -1) means disable. Guard against None
        # first: the cache selection above already treats None as a possible
        # value, and `None <= 0` raises TypeError.
        if self.chunked_prefill_size is not None and self.chunked_prefill_size <= 0:
            self.chunked_prefill_size = None
        self.being_chunked_req = None
        self.is_mixed_chunk = (
            self.chunked_prefill_size is not None and server_args.enable_mixed_chunk
        )

        # Init the grammar backend for constrained generation
        self.grammar_queue: List[Req] = []
        if not server_args.skip_tokenizer_init:
            if server_args.grammar_backend == "outlines":
                from sglang.srt.constrained.outlines_backend import (
                    OutlinesGrammarBackend,
                )

                self.grammar_backend = OutlinesGrammarBackend(
                    self.tokenizer,
                    whitespace_pattern=server_args.constrained_json_whitespace_pattern,
                    allow_jump_forward=not server_args.disable_jump_forward,
                )
            elif server_args.grammar_backend == "xgrammar":
                from sglang.srt.constrained.xgrammar_backend import (
                    XGrammarGrammarBackend,
                )

                self.grammar_backend = XGrammarGrammarBackend(
                    self.tokenizer, vocab_size=self.model_config.vocab_size
                )
            else:
                raise ValueError(
                    f"Invalid grammar backend: {server_args.grammar_backend}"
                )
        else:
            self.grammar_backend = None

        # Init new token estimation
        assert (
            server_args.schedule_conservativeness >= 0
        ), "Invalid schedule_conservativeness"

        self.init_new_token_ratio = min(
            global_config.default_init_new_token_ratio
            * server_args.schedule_conservativeness,
            1.0,
        )
        self.min_new_token_ratio = min(
            self.init_new_token_ratio
            * global_config.default_min_new_token_ratio_factor,
            1.0,
        )
        self.new_token_ratio_decay = (
            self.init_new_token_ratio - self.min_new_token_ratio
        ) / global_config.default_new_token_ratio_decay_steps
        self.new_token_ratio = self.init_new_token_ratio

        # Tells whether the current running batch is full so that we can skip
        # the check of whether to prefill new requests.
        # This is an optimization to reduce the overhead of the prefill check.
        self.batch_is_full = False

        # Init watchdog thread
        self.watchdog_timeout = server_args.watchdog_timeout
        # Resolve the parent process before the thread starts: the watchdog
        # reads `self.parent_process` when it fires.
        self.parent_process = psutil.Process().parent()
        t = threading.Thread(target=self.watchdog_thread, daemon=True)
        t.start()

        self.memory_saver_adapter = TorchMemorySaverAdapter.create(
            enable=server_args.enable_memory_saver
        )

        # Init profiler
        if os.getenv("SGLANG_TORCH_PROFILER_DIR", "") == "":
            self.profiler = None
        else:
            self.torch_profiler_trace_dir = os.getenv("SGLANG_TORCH_PROFILER_DIR")
            logger.info(
                "Profiling enabled. Traces will be saved to: %s",
                self.torch_profiler_trace_dir,
            )
            self.profiler = torch.profiler.profile(
                activities=[
                    torch.profiler.ProfilerActivity.CPU,
                    torch.profiler.ProfilerActivity.CUDA,
                ],
                with_stack=True,
            )

        # Init metrics stats
        self.stats = SchedulerStats()
        if self.enable_metrics:
            self.metrics_collector = SchedulerMetricsCollector(
                labels={
                    "model_name": self.server_args.served_model_name,
                    # TODO: Add lora name/path in the future,
                },
            )
395

Lianmin Zheng's avatar
Lianmin Zheng committed
396
    def watchdog_thread(self):
        """Signal the parent process with SIGQUIT when a batch stalls.

        While a batch is active (``self.cur_batch`` is not None), the thread
        tracks ``self.forward_ct``. If the counter stays unchanged for longer
        than ``self.watchdog_timeout`` seconds, it logs an error, waits five
        seconds, and sends SIGQUIT to ``self.parent_process``.
        """
        self.watchdog_last_forward_ct = 0
        self.watchdog_last_time = time.time()

        while True:
            if self.cur_batch is not None:
                if self.watchdog_last_forward_ct != self.forward_ct:
                    # The forward counter moved: record the new checkpoint.
                    self.watchdog_last_forward_ct = self.forward_ct
                    self.watchdog_last_time = time.time()
                elif time.time() > self.watchdog_last_time + self.watchdog_timeout:
                    logger.error(f"Watchdog timeout ({self.watchdog_timeout=})")
                    break
            # Poll twice per timeout window.
            time.sleep(self.watchdog_timeout / 2)

        # Give the parent process some time to print the error first.
        time.sleep(5)
        self.parent_process.send_signal(signal.SIGQUIT)
Lianmin Zheng's avatar
Lianmin Zheng committed
415

416
    @torch.no_grad()
    def event_loop_normal(self):
        """Run the scheduler loop: receive, schedule, run, process, repeat."""
        while True:
            self.process_input_requests(self.recv_requests())

            batch = self.get_next_batch_to_run()
            if self.server_args.enable_dp_attention:  # TODO: simplify this
                batch = self.prepare_dp_attn_batch(batch)
            self.cur_batch = batch

            if not batch:
                # The server is idle: run the self-check and reset the
                # new-token ratio to its initial value.
                self.check_memory()
                self.new_token_ratio = self.init_new_token_ratio
            else:
                batch_result = self.run_batch(batch)
                self.process_batch_result(batch, batch_result)

            self.last_batch = batch
439

440
    @torch.no_grad()
    def event_loop_overlap(self):
        """Run the scheduler loop with CPU processing overlapped with GPU work.

        Results are queued when a batch is launched and processed one
        iteration later, after the next batch has been launched.
        """
        pending_results = deque()

        while True:
            self.process_input_requests(self.recv_requests())

            batch = self.get_next_batch_to_run()
            self.cur_batch = batch

            if batch:
                launched = self.run_batch(batch)
                pending_results.append((batch.copy(), launched))

                if self.last_batch is None:
                    # Start the overlap pipeline with a dummy first batch.
                    # It is now used for triggering the sampling_info_done event.
                    dummy_batch = ScheduleBatch(
                        reqs=None,
                        forward_mode=ForwardMode.DUMMY_FIRST,
                        next_batch_sampling_info=self.tp_worker.cur_sampling_info,
                    )
                    self.process_batch_result(dummy_batch, None)

            if self.last_batch:
                # Process the results of the previous batch.
                prev_batch, prev_result = pending_results.popleft()
                if batch:
                    prev_batch.next_batch_sampling_info = (
                        self.tp_worker.cur_sampling_info
                    )
                else:
                    prev_batch.next_batch_sampling_info = None
                self.process_batch_result(prev_batch, prev_result)
            elif batch is None:
                # The server is idle: run the self-check and reset the
                # new-token ratio to its initial value.
                self.check_memory()
                self.new_token_ratio = self.init_new_token_ratio

            self.last_batch = batch

480
481
    def recv_requests(self) -> List[Req]:
        """Collect pending requests and share them across TP ranks.

        Rank 0 (or every rank when DP attention is enabled) drains the
        tokenizer socket without blocking. When ``tp_size`` is not 1 and DP
        attention is disabled, the list is then broadcast over
        ``self.tp_cpu_group``.
        """
        if self.tp_rank == 0 or self.server_args.enable_dp_attention:
            recv_reqs = []
            while True:
                try:
                    incoming = self.recv_from_tokenizer.recv_pyobj(zmq.NOBLOCK)
                except zmq.ZMQError:
                    # Stop draining on any ZMQ error from the non-blocking read.
                    break
                else:
                    recv_reqs.append(incoming)
        else:
            recv_reqs = None

        if not (self.tp_size == 1 or self.server_args.enable_dp_attention):
            recv_reqs = broadcast_pyobj(recv_reqs, self.tp_rank, self.tp_cpu_group)
        return recv_reqs

Lianmin Zheng's avatar
Lianmin Zheng committed
498
    def process_input_requests(self, recv_reqs: List):
        """Dispatch each received request to its handler.

        Handlers that produce a reply have it sent back through
        ``self.send_to_tokenizer``.

        Raises:
            ValueError: If a request is of an unrecognized type.
        """
        for incoming in recv_reqs:
            # Reply to send back to the tokenizer, if this request has one.
            reply = None

            if isinstance(incoming, TokenizedGenerateReqInput):
                self.handle_generate_request(incoming)
            elif isinstance(incoming, TokenizedEmbeddingReqInput):
                self.handle_embedding_request(incoming)
            elif isinstance(incoming, FlushCacheReq):
                self.flush_cache()
            elif isinstance(incoming, AbortReq):
                self.abort_request(incoming)
            elif isinstance(incoming, UpdateWeightFromDiskReqInput):
                success, message = self.update_weights_from_disk(incoming)
                reply = UpdateWeightFromDiskReqOutput(success, message)
            elif isinstance(incoming, InitWeightsUpdateGroupReqInput):
                success, message = self.init_weights_update_group(incoming)
                reply = InitWeightsUpdateGroupReqOutput(success, message)
            elif isinstance(incoming, UpdateWeightsFromDistributedReqInput):
                success, message = self.update_weights_from_distributed(incoming)
                reply = UpdateWeightsFromDistributedReqOutput(success, message)
            elif isinstance(incoming, UpdateWeightsFromTensorReqInput):
                success, message = self.update_weights_from_tensor(incoming)
                reply = UpdateWeightsFromTensorReqOutput(success, message)
            elif isinstance(incoming, GetWeightsByNameReqInput):
                parameter = self.get_weights_by_name(incoming)
                reply = GetWeightsByNameReqOutput(parameter)
            elif isinstance(incoming, ReleaseMemoryOccupationReqInput):
                self.release_memory_occupation()
                reply = ReleaseMemoryOccupationReqOutput()
            elif isinstance(incoming, ResumeMemoryOccupationReqInput):
                self.resume_memory_occupation()
                reply = ResumeMemoryOccupationReqOutput()
            elif isinstance(incoming, ProfileReq):
                if incoming == ProfileReq.START_PROFILE:
                    self.start_profile()
                else:
                    self.stop_profile()
            elif isinstance(incoming, OpenSessionReqInput):
                session_id, success = self.open_session(incoming)
                reply = OpenSessionReqOutput(session_id=session_id, success=success)
            elif isinstance(incoming, CloseSessionReqInput):
                self.close_session(incoming)
            else:
                raise ValueError(f"Invalid request: {incoming}")

            if reply is not None:
                self.send_to_tokenizer.send_pyobj(reply)

    def handle_generate_request(
        self,
        recv_req: TokenizedGenerateReqInput,
    ):
        """Build a ``Req`` from a tokenized generation request and enqueue it.

        The request is created either from scratch or from an existing
        session. Requests that fail validation (unknown session id, session
        abort, multimodal prompt too long) are marked with ``FINISH_ABORT``
        and appended to the waiting queue. Otherwise the prompt is truncated
        to ``self.max_req_input_len``, ``max_new_tokens`` is capped, and the
        request goes to the grammar queue (when its grammar is not cached
        yet) or to the waiting queue.

        Args:
            recv_req: The tokenized request received from the tokenizer.
        """
        # Create a new request
        if (
            recv_req.session_params is None
            or recv_req.session_params.id is None
            or recv_req.session_params.id not in self.sessions
        ):
            if recv_req.input_embeds is not None:
                # Generate fake input_ids based on the length of input_embeds
                seq_length = len(recv_req.input_embeds)
                fake_input_ids = [1] * seq_length
                recv_req.input_ids = fake_input_ids

            req = Req(
                recv_req.rid,
                recv_req.input_text,
                recv_req.input_ids,
                recv_req.sampling_params,
                return_logprob=recv_req.return_logprob,
                top_logprobs_num=recv_req.top_logprobs_num,
                stream=recv_req.stream,
                lora_path=recv_req.lora_path,
                input_embeds=recv_req.input_embeds,
                eos_token_ids=self.model_config.hf_eos_token_id,
            )
            req.tokenizer = self.tokenizer

            if (
                recv_req.session_params is not None
                and recv_req.session_params.id is not None
            ):
                # A session id was given but it is not a known session.
                req.finished_reason = FINISH_ABORT(
                    f"Invalid request: session id {recv_req.session_params.id} does not exist"
                )
                self.waiting_queue.append(req)
                return
        else:
            # Create a new request from a previous session
            session = self.sessions[recv_req.session_params.id]
            req = session.create_req(recv_req, self.tokenizer)
            if isinstance(req.finished_reason, FINISH_ABORT):
                self.waiting_queue.append(req)
                return

        # Handle image inputs
        if recv_req.image_inputs is not None:
            image_inputs = ImageInputs.from_dict(recv_req.image_inputs)
            # Expand a single image token into multiple dummy tokens for receiving image embeddings
            req.origin_input_ids = self.pad_input_ids_func(
                req.origin_input_ids, image_inputs
            )
            req.extend_image_inputs(image_inputs)

            if len(req.origin_input_ids) >= self.max_req_input_len:
                logger.error(
                    "Multimodal prompt is too long after expanding multimodal tokens. "
                    f"After expanding {len(req.origin_input_ids_unpadded)=} => {len(req.origin_input_ids)} >= {self.max_req_input_len}. "
                )
                req.origin_input_ids = [0]
                req.image_inputs = None
                req.sampling_params.max_new_tokens = 0
                req.finished_reason = FINISH_ABORT(
                    "Multimodal prompt is too long. Check server logs for details."
                )
                self.waiting_queue.append(req)
                return

        # Truncate prompts that are too long. This must happen before the
        # default logprob_start_len is derived from the prompt length below;
        # otherwise the default would point past the truncated prompt.
        if len(req.origin_input_ids) > self.max_req_input_len:
            logger.warning(
                "Request length is longer than the KV cache pool size or "
                "the max context length. Truncated!!!"
            )
            req.origin_input_ids = req.origin_input_ids[: self.max_req_input_len]

        # Copy more attributes
        req.logprob_start_len = recv_req.logprob_start_len

        if req.logprob_start_len == -1:
            # By default, only return the logprobs for output tokens
            req.logprob_start_len = len(req.origin_input_ids) - 1

        # Cap max_new_tokens by the space left after the prompt; an unset
        # value is treated as 1 << 30 before capping.
        req.sampling_params.max_new_tokens = min(
            (
                req.sampling_params.max_new_tokens
                if req.sampling_params.max_new_tokens is not None
                else 1 << 30
            ),
            self.max_req_len - len(req.origin_input_ids) - 1,
        )

        # Init grammar cache for this request
        add_to_grammar_queue = False
        if (
            req.sampling_params.json_schema is not None
            or req.sampling_params.regex is not None
            or req.sampling_params.ebnf is not None
        ):
            assert self.grammar_backend is not None
            # When several constraints are set, precedence is json, regex, ebnf.
            if req.sampling_params.json_schema is not None:
                key = ("json", req.sampling_params.json_schema)
            elif req.sampling_params.regex is not None:
                key = ("regex", req.sampling_params.regex)
            elif req.sampling_params.ebnf is not None:
                key = ("ebnf", req.sampling_params.ebnf)

            req.grammar = self.grammar_backend.get_cached_value(key)
            if not req.grammar:
                # Not cached: ask for a future value and park the request in
                # the grammar queue.
                req.grammar = self.grammar_backend.get_future_value(key)
                add_to_grammar_queue = True

        if add_to_grammar_queue:
            self.grammar_queue.append(req)
        else:
            self.waiting_queue.append(req)
671
672
673

    def handle_embedding_request(
        self,
        recv_req: TokenizedEmbeddingReqInput,
    ):
        """Build a ``Req`` from a tokenized embedding request and enqueue it.

        Prompts at or above ``self.max_req_input_len`` tokens trigger a
        warning and are cut to that many tokens before being queued.
        """
        req = Req(
            recv_req.rid,
            recv_req.input_text,
            recv_req.input_ids,
            recv_req.sampling_params,
        )
        req.tokenizer = self.tokenizer

        # NOTE(review): `>=` also fires when the length equals the limit, in
        # which case the slice is a no-op — confirm whether `>` was intended.
        limit = self.max_req_input_len
        if len(req.origin_input_ids) >= limit:
            logger.warning(
                "Request length is longer than the KV cache pool size or "
                "the max context length. Truncated!!!"
            )
            req.origin_input_ids = req.origin_input_ids[:limit]

        self.waiting_queue.append(req)

694
    def log_prefill_stats(self, adder, can_run_list, running_bs, has_being_chunked):
        """Log one line of prefill statistics and publish them as metrics.

        Args:
            adder: Object exposing ``log_input_tokens`` and ``log_hit_tokens``
                for the batch being scheduled; both are folded into the
                cumulative ``self.tree_cache_metrics`` counters.
            can_run_list: Requests admitted to the new prefill batch.
            running_bs: Number of requests in the running batch.
            has_being_chunked: Truthy when a chunked request is in flight; it
                is added to the queue length as one extra request.
        """
        # The cumulative counters are stored scaled down by 1e9.
        metrics = self.tree_cache_metrics
        metrics["total"] += (adder.log_input_tokens + adder.log_hit_tokens) / 10**9
        metrics["hit"] += adder.log_hit_tokens / 10**9

        # Report a 0% hit rate instead of raising ZeroDivisionError while no
        # tokens have been accounted for yet.
        total_tokens = metrics["total"]
        tree_cache_hit_rate = metrics["hit"] / total_tokens if total_tokens else 0.0

        # Tokens that are neither free in the KV pool nor evictable from the
        # tree cache count as "used".
        num_used = self.max_total_num_tokens - (
            self.token_to_kv_pool.available_size() + self.tree_cache.evictable_size()
        )
        token_usage = num_used / self.max_total_num_tokens
        num_queue_reqs = len(self.waiting_queue) + has_being_chunked

        logger.info(
            f"Prefill batch. "
            f"#new-seq: {len(can_run_list)}, "
            f"#new-token: {adder.log_input_tokens}, "
            f"#cached-token: {adder.log_hit_tokens}, "
            f"cache hit rate: {100.0 * tree_cache_hit_rate:.2f}%, "
            f"token usage: {token_usage:.2f}, "
            f"#running-req: {running_bs}, "
            f"#queue-req: {num_queue_reqs}"
        )

        if self.enable_metrics:
            self.stats.num_running_reqs = running_bs
            self.stats.num_used_tokens = num_used
            self.stats.token_usage = round(token_usage, 2)
            self.stats.num_queue_reqs = num_queue_reqs
            self.stats.cache_hit_rate = tree_cache_hit_rate
            self.metrics_collector.log_stats(self.stats)

    def log_decode_stats(self):
        """Log decode-phase statistics and reset the throughput window.

        Generation throughput is the number of tokens generated since the
        last call divided by the wall-clock time elapsed since then; both
        the token counter and the timestamp are reset here.
        """
        free_or_evictable = (
            self.token_to_kv_pool.available_size() + self.tree_cache.evictable_size()
        )
        num_used = self.max_total_num_tokens - free_or_evictable

        elapsed = time.time() - self.last_decode_stats_tic
        gen_throughput = self.num_generated_tokens / elapsed

        # Start a new measurement window.
        self.num_generated_tokens = 0
        self.last_decode_stats_tic = time.time()

        if self.running_batch:
            num_running_reqs = len(self.running_batch.reqs)
        else:
            num_running_reqs = 0
        num_queue_reqs = len(self.waiting_queue)
        token_usage = num_used / self.max_total_num_tokens

        logger.info(
            f"Decode batch. "
            f"#running-req: {num_running_reqs}, "
            f"#token: {num_used}, "
            f"token usage: {token_usage:.2f}, "
            f"gen throughput (token/s): {gen_throughput:.2f}, "
            f"#queue-req: {num_queue_reqs}"
        )

        if not self.enable_metrics:
            return

        stats = self.stats
        stats.num_running_reqs = num_running_reqs
        stats.num_used_tokens = num_used
        stats.token_usage = token_usage
        stats.gen_throughput = gen_throughput
        stats.num_queue_reqs = num_queue_reqs
        self.metrics_collector.log_stats(stats)

Lianmin Zheng's avatar
Lianmin Zheng committed
753
754
755
756
757
    def check_memory(self):
        """Check that the KV cache and the request slots are all free again.

        Two invariants are verified: free plus evictable KV tokens must equal
        ``self.max_total_num_tokens``, and every slot of
        ``self.req_to_token_pool`` must be free. A violation is reported with
        ``warnings.warn``.

        Raises:
            ValueError: If a leak is detected and ``crash_on_warnings()``
                returns a truthy value.
        """
        # NOTE: the local must stay named `available_size`; the `{...=}`
        # specifier below embeds the expression text in the message.
        available_size = (
            self.token_to_kv_pool.available_size() + self.tree_cache.evictable_size()
        )
        kv_leaked = available_size != self.max_total_num_tokens
        if kv_leaked:
            msg = (
                "KV cache pool leak detected!"
                f"{available_size=}, {self.max_total_num_tokens=}\n"
            )
            warnings.warn(msg)
            if crash_on_warnings():
                raise ValueError(msg)

        slot_pool = self.req_to_token_pool
        num_free_slots = len(slot_pool.free_slots)
        if num_free_slots != slot_pool.size:
            msg = (
                "Memory pool leak detected!"
                f"available_size={num_free_slots}, "
                f"total_size={slot_pool.size}\n"
            )
            warnings.warn(msg)
            if crash_on_warnings():
                raise ValueError(msg)
Lianmin Zheng's avatar
Lianmin Zheng committed
775

776
    def get_next_batch_to_run(self) -> Optional[ScheduleBatch]:
        """Choose the batch for the next forward pass.

        The previous extend batch is first folded into the running batch.
        A new prefill batch is preferred; otherwise the (updated) running
        decode batch is returned, or ``None`` if there is nothing to run.
        """
        last_batch = self.last_batch
        if last_batch and last_batch.forward_mode.is_extend():
            if self.being_chunked_req:
                # Take the chunked request out of the batch before merging.
                last_batch.filter_batch(being_chunked_req=self.being_chunked_req)
                self.tree_cache.cache_unfinished_req(self.being_chunked_req)
                # The chunked request keeps its rid but will get a new
                # req_pool_idx, so the current slot is released.
                self.req_to_token_pool.free(self.being_chunked_req.req_pool_idx)
                self.batch_is_full = False

            if not last_batch.is_empty():
                if self.running_batch is None:
                    self.running_batch = last_batch
                else:
                    self.running_batch.merge_batch(last_batch)

        # Prefill takes priority over decode.
        prefill_batch = self.get_new_batch_prefill()
        if prefill_batch is not None:
            return prefill_batch

        # Fall back to decoding the running batch, if there is one.
        if self.running_batch is not None:
            self.running_batch = self.update_running_batch(self.running_batch)
        return self.running_batch
803

Lianmin Zheng's avatar
Lianmin Zheng committed
804
    def get_new_batch_prefill(self) -> Optional[ScheduleBatch]:
        """Build the next prefill (extend) batch from the waiting queue.

        Requests are admitted in the order produced by
        ``self.policy.calc_priority`` until the ``PrefillAdder`` stops
        accepting them, the running-request limit is reached, or (with LoRA)
        the batch would exceed ``self.max_loras_per_batch`` adapters. An
        in-flight chunked request is continued before new requests are added.

        Returns:
            A ``ScheduleBatch`` prepared for extend (optionally mixed with the
            running decode batch), or ``None`` when prefill is not allowed or
            no request could be admitted.
        """
        # Check if the grammar is ready in the grammar queue
        if self.grammar_queue:
            self.move_ready_grammar_requests()

        # Handle the cases where prefill is not allowed
        if (
            self.batch_is_full or len(self.waiting_queue) == 0
        ) and self.being_chunked_req is None:
            return None

        running_bs = len(self.running_batch.reqs) if self.running_batch else 0
        if running_bs >= self.max_running_requests:
            self.batch_is_full = True
            return None

        # Get priority queue
        prefix_computed = self.policy.calc_priority(self.waiting_queue)

        # Prefill policy
        adder = PrefillAdder(
            self.tree_cache,
            self.running_batch,
            self.new_token_ratio,
            self.token_to_kv_pool.available_size() + self.tree_cache.evictable_size(),
            self.max_prefill_tokens,
            self.chunked_prefill_size,
            running_bs if self.is_mixed_chunk else 0,
        )

        has_being_chunked = self.being_chunked_req is not None
        if has_being_chunked:
            self.being_chunked_req.init_next_round_input()
            self.being_chunked_req = adder.add_being_chunked_req(self.being_chunked_req)

        if self.lora_paths:
            # LoRA adapters already used by the running batch.
            lora_set = (
                {running_req.lora_path for running_req in self.running_batch.reqs}
                if self.running_batch is not None
                else set()
            )

        # Get requests from the waiting queue to a new prefill batch
        for req in self.waiting_queue:
            if self.lora_paths:
                # Adapters the batch would need if this request were admitted.
                batch_loras = (
                    lora_set
                    | {admitted.lora_path for admitted in adder.can_run_list}
                    | {req.lora_path}
                )
                if len(batch_loras) > self.max_loras_per_batch:
                    self.batch_is_full = True
                    break

            if running_bs + len(adder.can_run_list) >= self.max_running_requests:
                self.batch_is_full = True
                break

            req.init_next_round_input(None if prefix_computed else self.tree_cache)
            res = adder.add_one_req(req)
            if res != AddReqResult.CONTINUE:
                if res == AddReqResult.NO_TOKEN:
                    self.batch_is_full = True
                break
            if self.server_args.prefill_only_one_req:
                break

        # Update waiting queue
        can_run_list = adder.can_run_list
        if len(can_run_list) == 0:
            return None
        # Build the lookup set once instead of once per queued request.
        admitted_reqs = set(can_run_list)
        self.waiting_queue = [x for x in self.waiting_queue if x not in admitted_reqs]

        if adder.new_being_chunked_req is not None:
            assert self.being_chunked_req is None
            self.being_chunked_req = adder.new_being_chunked_req

        if self.being_chunked_req:
            self.being_chunked_req.is_being_chunked += 1

        # Print stats
        if self.tp_rank == 0:
            self.log_prefill_stats(adder, can_run_list, running_bs, has_being_chunked)

        # Create a new batch
        new_batch = ScheduleBatch.init_new(
            can_run_list,
            self.req_to_token_pool,
            self.token_to_kv_pool,
            self.tree_cache,
            self.model_config,
            self.enable_overlap,
            self.spec_algorithm,
        )
        new_batch.prepare_for_extend()

        # Mixed-style chunked prefill
        if (
            self.is_mixed_chunk
            and self.running_batch is not None
            and not (new_batch.return_logprob or self.running_batch.return_logprob)
        ):
            # TODO (lianmin): support return_logprob + mixed chunked prefill
            self.running_batch.filter_batch()
            if not self.running_batch.is_empty():
                self.running_batch.prepare_for_decode()
                new_batch.mix_with_running(self.running_batch)
                new_batch.decoding_reqs = self.running_batch.reqs
            # The running batch is consumed here whether or not it was mixed in.
            self.running_batch = None
        else:
            new_batch.decoding_reqs = None

        return new_batch

Lianmin Zheng's avatar
Lianmin Zheng committed
922
    def update_running_batch(self, batch: ScheduleBatch) -> Optional[ScheduleBatch]:
        """Prepare the running decode batch for its next step.

        The batch is filtered, requests are retracted to the waiting queue
        when ``check_decode_mem`` fails, jump-forward requests are re-queued,
        and finally the batch is prepared for decode.

        Returns:
            The updated batch, or ``None`` if no request is left in it.
        """
        global test_retract

        bs_before = batch.batch_size()

        batch.filter_batch()
        if batch.is_empty():
            self.batch_is_full = False
            return None

        # Retract when decode would run out of memory, or when the
        # module-level `test_retract` flag is set and the batch exceeds 10.
        out_of_memory = not batch.check_decode_mem(
            self.decode_mem_cache_buf_multiplier
        )
        if out_of_memory or (test_retract and batch.batch_size() > 10):
            previous_ratio = self.new_token_ratio

            retracted_reqs, self.new_token_ratio = batch.retract_decode()
            if self.draft_worker:
                self.draft_worker.finish_request(retracted_reqs)

            logger.info(
                "Decode out of memory happened. "
                f"#retracted_reqs: {len(retracted_reqs)}, "
                f"#new_token_ratio: {previous_ratio:.4f} -> {self.new_token_ratio:.4f}"
            )
            self.waiting_queue.extend(retracted_reqs)
        else:
            # No retraction: decay the ratio, but never below the floor.
            decayed_ratio = self.new_token_ratio - self.new_token_ratio_decay
            self.new_token_ratio = max(decayed_ratio, self.min_new_token_ratio)

        # Requests returned by the jump-forward check go back to the queue.
        if not self.disable_jump_forward:
            jump_forward_reqs = batch.check_for_jump_forward(self.pad_input_ids_func)
            self.waiting_queue.extend(jump_forward_reqs)
            if batch.is_empty():
                self.batch_is_full = False
                return None

        # Any shrinkage frees room for new requests.
        if batch.batch_size() < bs_before:
            self.batch_is_full = False

        # Update batch tensors
        batch.prepare_for_decode()
        return batch
Lianmin Zheng's avatar
Lianmin Zheng committed
970
971

    def run_batch(self, batch: ScheduleBatch):
        """Run one forward pass for ``batch`` and return its raw result.

        Args:
            batch: The batch to execute.

        Returns:
            ``(logits_output, next_token_ids, bid)`` for generation models,
            ``(embeddings, bid)`` for embedding/reward models, or ``None``
            when the batch is an idle batch.
        """
        self.forward_ct += 1

        if not self.is_generation:
            # embedding or reward model
            assert batch.extend_num_tokens != 0
            model_worker_batch = batch.get_model_worker_batch()
            embeddings = self.tp_worker.forward_batch_embedding(model_worker_batch)
            return embeddings, model_worker_batch.bid

        if batch.forward_mode.is_decode() or batch.extend_num_tokens != 0:
            if self.spec_algorithm.is_none():
                model_worker_batch = batch.get_model_worker_batch()
                logits_output, next_token_ids = (
                    self.tp_worker.forward_batch_generation(model_worker_batch)
                )
            else:
                (
                    logits_output,
                    next_token_ids,
                    model_worker_batch,
                    num_accepted_tokens,
                ) = self.draft_worker.forward_batch_speculative_generation(batch)
                self.num_generated_tokens += num_accepted_tokens
        elif batch.forward_mode.is_idle():
            model_worker_batch = batch.get_model_worker_batch()
            self.tp_worker.forward_batch_idle(model_worker_batch)
            return
        else:
            # Nothing to forward: emit placeholder token ids. The model worker
            # batch is still created so the returned tuple carries a valid
            # bid; previously this branch left `model_worker_batch` unbound
            # and the return below raised UnboundLocalError.
            # NOTE(review): no forward pass is issued for this bid -- confirm
            # that consumers resolving results by bid (overlap mode) tolerate it.
            model_worker_batch = batch.get_model_worker_batch()
            logits_output = None
            if self.skip_tokenizer_init:
                # NOTE(review): this reads self.tokenizer while
                # skip_tokenizer_init is set -- confirm a tokenizer is still
                # available in that mode.
                next_token_ids = torch.full(
                    (batch.batch_size(),), self.tokenizer.eos_token_id
                )
            else:
                next_token_ids = torch.full((batch.batch_size(),), 0)

        batch.output_ids = next_token_ids
        return logits_output, next_token_ids, model_worker_batch.bid
Chayenne's avatar
Chayenne committed
1010

Lianmin Zheng's avatar
Lianmin Zheng committed
1011
1012
1013
    def process_batch_result(self, batch: ScheduleBatch, result):
        """Dispatch the result of a forward pass according to the batch mode.

        Decode and extend batches go to their dedicated handlers. A
        "dummy first" batch only updates the next batch's vocab mask,
        synchronizes the current stream and signals that sampling info is
        ready. Any other mode is ignored.
        """
        mode = batch.forward_mode
        if mode.is_decode():
            self.process_batch_result_decode(batch, result)
            # Drop the running batch once every request in it is gone.
            if batch.is_empty():
                self.running_batch = None
            return

        if mode.is_extend():
            self.process_batch_result_prefill(batch, result)
            return

        if mode.is_dummy_first():
            batch.next_batch_sampling_info.update_regex_vocab_mask()
            self.current_stream.synchronize()
            batch.next_batch_sampling_info.sampling_info_done.set()
Lianmin Zheng's avatar
Lianmin Zheng committed
1022
1023

    def process_batch_result_prefill(self, batch: ScheduleBatch, result):
        """Consume the result of a prefill (extend) forward pass.

        For generation models each request gets its first sampled token,
        finish conditions are checked, the tree cache is updated, and logprob
        and grammar state are advanced. For embedding/reward models the
        embedding is stored and a dummy output token is appended. A request
        whose prefill is still being chunked only has its chunk counter
        decremented. Finally the outputs are streamed.

        Args:
            batch: The extend batch that was just executed.
            result: ``(logits_output, next_token_ids, bid)`` for generation
                models, ``(embeddings, bid)`` otherwise.
        """
        skip_stream_req = None

        if self.is_generation:
            logits_output, next_token_ids, bid = result

            if self.enable_overlap:
                logits_output, next_token_ids = self.tp_worker.resolve_batch_result(bid)
            else:
                # Move next_token_ids and logprobs to cpu
                next_token_ids = next_token_ids.tolist()
                if batch.return_logprob:
                    logits_output.next_token_logprobs = (
                        logits_output.next_token_logprobs.tolist()
                    )
                    logits_output.input_token_logprobs = (
                        logits_output.input_token_logprobs.tolist()
                    )
                    logits_output.normalized_prompt_logprobs = (
                        logits_output.normalized_prompt_logprobs.tolist()
                    )

            # Check finish conditions
            logprob_pt = 0
            for i, (req, next_token_id) in enumerate(zip(batch.reqs, next_token_ids)):
                if req.is_retracted:
                    continue

                if self.is_mixed_chunk and self.enable_overlap and req.finished():
                    # Free the one delayed token for the mixed decode batch
                    j = len(batch.out_cache_loc) - len(batch.reqs) + i
                    self.token_to_kv_pool.free(batch.out_cache_loc[j : j + 1])
                    continue

                if req.is_being_chunked > 0:
                    # being chunked reqs' prefill is not finished
                    req.is_being_chunked -= 1
                    # There is only at most one request being currently chunked.
                    # Because this request does not finish prefill,
                    # we don't want to stream the request currently being chunked.
                    skip_stream_req = req
                    continue

                req.output_ids.append(next_token_id)
                req.check_finished()

                if req.finished():
                    self.tree_cache.cache_finished_req(req)
                elif not batch.decoding_reqs or req not in batch.decoding_reqs:
                    self.tree_cache.cache_unfinished_req(req)

                if req.return_logprob:
                    # The helper returns how far to advance the logprob cursor.
                    logprob_pt += self.add_logprob_return_values(
                        i, req, logprob_pt, next_token_ids, logits_output
                    )

                if req.grammar is not None:
                    req.grammar.accept_token(next_token_id)
                    req.grammar.finished = req.finished()

            if batch.next_batch_sampling_info:
                batch.next_batch_sampling_info.update_regex_vocab_mask()
                self.current_stream.synchronize()
                batch.next_batch_sampling_info.sampling_info_done.set()

        else:  # embedding or reward model
            embeddings, _ = result
            embeddings = embeddings.tolist()

            # Check finish conditions
            for i, req in enumerate(batch.reqs):
                if req.is_retracted:
                    continue

                req.embedding = embeddings[i]
                if req.is_being_chunked > 0:
                    # being chunked reqs' prefill is not finished
                    req.is_being_chunked -= 1
                    continue

                # Dummy output token for embedding models
                req.output_ids.append(0)
                req.check_finished()

                if req.finished():
                    self.tree_cache.cache_finished_req(req)
                else:
                    self.tree_cache.cache_unfinished_req(req)

        self.stream_output(batch.reqs, batch.return_logprob, skip_stream_req)
1111

Lianmin Zheng's avatar
Lianmin Zheng committed
1112
    def process_batch_result_decode(self, batch: ScheduleBatch, result):
        """Consume the result of a decode forward pass.

        Each live request gets its newly sampled token (unless speculative
        decoding already filled ``output_ids``), finish conditions are
        checked, finished requests are cached, and logprob and grammar state
        are advanced. Outputs are then streamed and, on rank 0, decode
        statistics are logged every ``decode_log_interval`` steps.

        Args:
            batch: The decode batch that was just executed.
            result: ``(logits_output, next_token_ids, bid)``.
        """
        logits_output, next_token_ids, bid = result
        self.num_generated_tokens += len(batch.reqs)

        if self.enable_overlap:
            logits_output, next_token_ids = self.tp_worker.resolve_batch_result(bid)
            next_token_logprobs = logits_output.next_token_logprobs
        else:
            next_token_ids = next_token_ids.tolist()
            if batch.return_logprob:
                next_token_logprobs = logits_output.next_token_logprobs.tolist()

        self.token_to_kv_pool.free_group_begin()

        # Check finish condition
        for idx, (req, token_id) in enumerate(zip(batch.reqs, next_token_ids)):
            if req.is_retracted:
                continue

            if self.enable_overlap and req.finished():
                # Free the one delayed token
                self.token_to_kv_pool.free(batch.out_cache_loc[idx : idx + 1])
                continue

            if batch.spec_algorithm.is_none():
                # speculative worker will solve the output_ids in speculative decoding
                req.output_ids.append(token_id)

            req.check_finished()
            if req.finished():
                self.tree_cache.cache_finished_req(req)

            if req.return_logprob:
                req.output_token_logprobs_val.append(next_token_logprobs[idx])
                req.output_token_logprobs_idx.append(token_id)
                if req.top_logprobs_num > 0:
                    top_val = logits_output.next_token_top_logprobs_val[idx]
                    req.output_top_logprobs_val.append(top_val)
                    top_idx = logits_output.next_token_top_logprobs_idx[idx]
                    req.output_top_logprobs_idx.append(top_idx)

            if req.grammar is not None:
                req.grammar.accept_token(token_id)
                req.grammar.finished = req.finished()

        if batch.next_batch_sampling_info:
            batch.next_batch_sampling_info.update_regex_vocab_mask()
            self.current_stream.synchronize()
            batch.next_batch_sampling_info.sampling_info_done.set()

        self.stream_output(batch.reqs, batch.return_logprob)

        self.token_to_kv_pool.free_group_end()

        # The step counter wraps at 2**30.
        self.forward_ct_decode = (self.forward_ct_decode + 1) % (1 << 30)
        is_log_step = (
            self.tp_rank == 0
            and self.forward_ct_decode % self.server_args.decode_log_interval == 0
        )
        if is_log_step:
            self.log_decode_stats()
1175

1176
1177
1178
1179
1180
1181
1182
1183
1184
    def add_logprob_return_values(
        self,
        i: int,
        req: Req,
        pt: int,
        next_token_ids: List[int],
        output: LogitsProcessorOutput,
    ):
        """Attach logprobs from an extend pass to ``req``.

        Args:
            i: Index of ``req`` within the batch-level outputs.
            req: Request whose logprob fields are filled in.
            pt: Offset of this request's entries in
                ``output.input_token_logprobs``.
            next_token_ids: Sampled next token ids for the batch.
            output: Logits processor output for the batch.

        Returns:
            ``req.extend_input_len - req.extend_logprob_start_len``, which the
            caller adds to ``pt`` for the next request.
        """
        req.output_token_logprobs_val.append(output.next_token_logprobs[i])
        req.output_token_logprobs_idx.append(next_token_ids[i])

        # If logprob_start_len > 0, then first logprob_start_len prompt tokens will be ignored.
        num_input_logprobs = req.extend_input_len - req.extend_logprob_start_len

        num_recomputed = req.last_update_decode_tokens
        num_fill = len(req.fill_ids)
        # One past the last input logprob of this request, and the boundary
        # between prompt logprobs and recomputed decode logprobs.
        input_end = pt + num_input_logprobs - 1
        prompt_end = input_end - num_recomputed

        if req.normalized_prompt_logprob is None:
            req.normalized_prompt_logprob = output.normalized_prompt_logprobs[i]

        if req.input_token_logprobs_val is None:
            prompt_vals = output.input_token_logprobs[pt:prompt_end]
            prompt_ids = req.fill_ids[
                num_fill - num_input_logprobs + 1 : num_fill - num_recomputed
            ]
            # Clip the padded hash values from image tokens.
            # Otherwise, it will lead to detokenization errors.
            id_limit = self.model_config.vocab_size - 1
            prompt_ids = [x if x < id_limit else 0 for x in prompt_ids]

            if req.logprob_start_len == 0:
                # The first token does not have logprob, pad it.
                prompt_vals = [None] + prompt_vals
                prompt_ids = [req.fill_ids[0]] + prompt_ids

            req.input_token_logprobs_val = prompt_vals
            req.input_token_logprobs_idx = prompt_ids

        if num_recomputed != 0:
            # Some decode tokens are re-computed in an extend batch
            req.output_token_logprobs_val.extend(
                output.input_token_logprobs[prompt_end:input_end],
            )
            req.output_token_logprobs_idx.extend(
                req.fill_ids[num_fill - num_recomputed : num_fill]
            )

        if req.top_logprobs_num > 0:
            if req.input_top_logprobs_val is None:
                req.input_top_logprobs_val = output.input_top_logprobs_val[i]
                req.input_top_logprobs_idx = output.input_top_logprobs_idx[i]
                if req.logprob_start_len == 0:
                    req.input_top_logprobs_val = [None] + req.input_top_logprobs_val
                    req.input_top_logprobs_idx = [None] + req.input_top_logprobs_idx

            if num_recomputed != 0:
                req.output_top_logprobs_val.extend(
                    output.input_top_logprobs_val[i][-num_recomputed:]
                )
                req.output_top_logprobs_idx.extend(
                    output.input_top_logprobs_idx[i][-num_recomputed:]
                )

            req.output_top_logprobs_val.append(output.next_token_top_logprobs_val[i])
            req.output_top_logprobs_idx.append(output.next_token_top_logprobs_idx[i])

        return num_input_logprobs

Lianmin Zheng's avatar
Lianmin Zheng committed
1261
1262
1263
    def stream_output(
        self, reqs: List[Req], return_logprob: bool, skip_req: Optional[Req] = None
    ):
        """Stream the output to detokenizer.

        Generation models: a request is included in the outgoing message when
        it has finished, or when its number of output tokens is a multiple of
        the emit interval (``self.stream_interval`` for streaming requests,
        50 for non-streaming ones). Nothing is sent if no request qualifies.

        Embedding / reward models: only finished requests are included, and
        the message is sent even when the list is empty.

        Args:
            reqs: The requests to inspect.
            return_logprob: When True, the per-request logprob fields are
                collected; when False, every logprob field of the message is
                ``None``.
            skip_req: A request to leave out (compared by identity). It is
                only consulted on the generation path.
        """
        rids = []
        finished_reasons = []

        if not self.is_generation:
            # Embedding or reward model
            embeddings = []
            prompt_tokens = []
            for req in reqs:
                if not req.finished():
                    continue
                rids.append(req.rid)
                finished_reasons.append(req.finished_reason.to_json())
                embeddings.append(req.embedding)
                prompt_tokens.append(len(req.origin_input_ids))

            self.send_to_detokenizer.send_pyobj(
                BatchEmbeddingOut(rids, finished_reasons, embeddings, prompt_tokens)
            )
            return

        # Per-request fields of the outgoing BatchTokenIDOut message
        vids = []
        decoded_texts = []
        decode_ids_list = []
        read_offsets = []
        output_ids = []
        skip_special_tokens = []
        spaces_between_special_tokens = []
        no_stop_trim = []
        prompt_tokens = []
        completion_tokens = []
        cached_tokens = []

        if return_logprob:
            input_token_logprobs_val, input_token_logprobs_idx = [], []
            output_token_logprobs_val, output_token_logprobs_idx = [], []
            input_top_logprobs_val, input_top_logprobs_idx = [], []
            output_top_logprobs_val, output_top_logprobs_idx = [], []
            normalized_prompt_logprob = []
        else:
            input_token_logprobs_val = input_token_logprobs_idx = None
            output_token_logprobs_val = output_token_logprobs_idx = None
            input_top_logprobs_val = input_top_logprobs_idx = None
            output_top_logprobs_val = output_top_logprobs_idx = None
            normalized_prompt_logprob = None

        for req in reqs:
            if req is skip_req:
                continue

            # TODO(lianmin): revisit this for overlap + retract + stream
            should_emit = req.finished()
            if not should_emit:
                # If stream, follow the given stream_interval.
                # If not stream, we still want to output some tokens to get
                # the benefit of incremental decoding.
                emit_interval = self.stream_interval if req.stream else 50
                should_emit = len(req.output_ids) % emit_interval == 0
            if not should_emit:
                continue

            if self.draft_worker and req.finished():
                self.draft_worker.finish_request(req)

            reason = req.finished_reason
            rids.append(req.rid)
            finished_reasons.append(reason.to_json() if reason else None)
            vids.append(req.vid)
            decoded_texts.append(req.decoded_text)

            decode_ids, read_offset = req.init_incremental_detokenize()
            decode_ids_list.append(decode_ids)
            read_offsets.append(read_offset)

            # Raw output ids are only collected when skip_tokenizer_init is set
            if self.skip_tokenizer_init:
                output_ids.append(req.output_ids)

            sampling_params = req.sampling_params
            skip_special_tokens.append(sampling_params.skip_special_tokens)
            spaces_between_special_tokens.append(
                sampling_params.spaces_between_special_tokens
            )
            no_stop_trim.append(sampling_params.no_stop_trim)

            prompt_tokens.append(len(req.origin_input_ids))
            completion_tokens.append(len(req.output_ids))
            cached_tokens.append(req.cached_tokens)

            if not return_logprob:
                continue
            input_token_logprobs_val.append(req.input_token_logprobs_val)
            input_token_logprobs_idx.append(req.input_token_logprobs_idx)
            output_token_logprobs_val.append(req.output_token_logprobs_val)
            output_token_logprobs_idx.append(req.output_token_logprobs_idx)
            input_top_logprobs_val.append(req.input_top_logprobs_val)
            input_top_logprobs_idx.append(req.input_top_logprobs_idx)
            output_top_logprobs_val.append(req.output_top_logprobs_val)
            output_top_logprobs_idx.append(req.output_top_logprobs_idx)
            normalized_prompt_logprob.append(req.normalized_prompt_logprob)

        # Send to detokenizer
        if not rids:
            return

        batch_out = BatchTokenIDOut(
            rids,
            finished_reasons,
            vids,
            decoded_texts,
            decode_ids_list,
            read_offsets,
            output_ids,
            skip_special_tokens,
            spaces_between_special_tokens,
            no_stop_trim,
            prompt_tokens,
            completion_tokens,
            cached_tokens,
            input_token_logprobs_val,
            input_token_logprobs_idx,
            output_token_logprobs_val,
            output_token_logprobs_idx,
            input_top_logprobs_val,
            input_top_logprobs_idx,
            output_top_logprobs_val,
            output_top_logprobs_idx,
            normalized_prompt_logprob,
        )
        self.send_to_detokenizer.send_pyobj(batch_out)
1388

1389
1390
1391
1392
1393
1394
1395
1396
1397
1398
1399
1400
1401
1402
1403
1404
1405
1406
1407
1408
1409
1410
1411
1412
1413
1414
1415
1416
1417
1418
1419
1420
1421
1422
1423
1424
1425
1426
1427
1428
1429
1430
1431
1432
1433
1434
1435
1436
1437
1438
1439
    def prepare_dp_attn_batch(self, local_batch: ScheduleBatch):
        """Exchange per-rank token counts and align the local batch with the group.

        Every rank in ``self.tp_cpu_group`` contributes its token count to an
        all-gather. A rank without a local batch gets an idle batch when any
        rank reported tokens, so that it still takes part in the forward pass.

        Args:
            local_batch: The batch scheduled on this rank, or ``None``.

        Returns:
            ``local_batch`` (or a freshly built idle batch) annotated with
            ``global_num_tokens`` and, unless CUDA graph is disabled,
            ``can_run_dp_cuda_graph``. ``None`` when this rank has no batch
            and no rank reported any tokens.
        """
        # Check if other DP workers have running batches
        num_tokens = 0
        if local_batch is not None:
            if local_batch.forward_mode.is_decode():
                num_tokens = local_batch.batch_size()
            else:
                num_tokens = local_batch.extend_num_tokens

        local_num_tokens = torch.tensor([num_tokens], dtype=torch.int64)
        global_num_tokens = torch.empty(self.tp_size, dtype=torch.int64)
        torch.distributed.all_gather_into_tensor(
            global_num_tokens,
            local_num_tokens,
            group=self.tp_cpu_group,
        )

        if local_batch is None:
            if global_num_tokens.max().item() <= 0:
                # No rank has work; there is nothing to run.
                return None
            local_batch = self.get_idle_batch()

        local_batch.global_num_tokens = global_num_tokens.tolist()

        if self.server_args.disable_cuda_graph:
            return local_batch

        # Check forward mode for cuda graph: the MIN reduction yields 1 only
        # if every rank is in decode or idle mode.
        graph_compatible = (
            local_batch.forward_mode.is_decode() or local_batch.forward_mode.is_idle()
        )
        forward_mode_state = torch.tensor(
            1 if graph_compatible else 0,
            dtype=torch.int32,
        )
        torch.distributed.all_reduce(
            forward_mode_state,
            op=torch.distributed.ReduceOp.MIN,
            group=self.tp_cpu_group,
        )
        local_batch.can_run_dp_cuda_graph = forward_mode_state.item() == 1
        return local_batch

    def get_idle_batch(self):
        """Create a ``ScheduleBatch`` with no requests and prepare it for idle mode."""
        no_reqs = []
        batch = ScheduleBatch.init_new(
            no_reqs,
            self.req_to_token_pool,
            self.token_to_kv_pool,
            self.tree_cache,
            self.model_config,
            self.enable_overlap,
            self.spec_algorithm,
        )
        batch.prepare_for_idle()
        return batch

1445
1446
1447
1448
1449
1450
1451
1452
1453
1454
1455
1456
1457
1458
1459
1460
1461
1462
1463
1464
1465
1466
1467
1468
    def move_ready_grammar_requests(self):
        """Move requests whose grammar objects are ready from grammar_queue to waiting_queue.

        The queue is scanned in order and the scan stops at the first request
        whose grammar future does not complete within 50 ms, so only a ready
        prefix of ``grammar_queue`` is moved and the queue order is preserved.

        When ``self.tp_size > 1`` the ready count is synchronized with a MAX
        all-reduce over ``self.tp_cpu_group``; a rank that is behind blocks on
        the remaining futures until it has caught up.
        """
        num_ready_reqs = 0
        for req in self.grammar_queue:
            try:
                req.grammar = req.grammar.result(timeout=0.05)
            except futures.TimeoutError:
                # Use the public exception name rather than the private
                # `futures._base` module. Stop at the first unready request.
                break
            num_ready_reqs += 1

        if self.tp_size > 1:
            # Sync across TP ranks to make sure they have the same number of ready requests
            tensor = torch.tensor(num_ready_reqs, dtype=torch.int32)
            torch.distributed.all_reduce(
                tensor, op=torch.distributed.ReduceOp.MAX, group=self.tp_cpu_group
            )
            num_ready_reqs_max = tensor.item()
            for i in range(num_ready_reqs, num_ready_reqs_max):
                # Another rank already resolved these; wait without a timeout.
                self.grammar_queue[i].grammar = self.grammar_queue[i].grammar.result()
            num_ready_reqs = num_ready_reqs_max

        self.waiting_queue.extend(self.grammar_queue[:num_ready_reqs])
        self.grammar_queue = self.grammar_queue[num_ready_reqs:]

1469
    def flush_cache(self):
        """Flush the memory pool and cache.

        The flush is only performed when the waiting queue is empty and the
        running batch is missing or has no requests. Otherwise nothing is
        touched and a warning is logged.

        Returns:
            ``True`` if the caches were flushed, ``False`` if the flush was
            skipped because requests are pending.
        """
        num_queued = len(self.waiting_queue)
        num_running = (
            0 if self.running_batch is None else len(self.running_batch.reqs)
        )

        if num_queued == 0 and num_running == 0:
            self.tree_cache.reset()
            self.tree_cache_metrics = {"total": 0, "hit": 0}
            if self.grammar_backend:
                self.grammar_backend.reset()
            self.req_to_token_pool.clear()
            self.token_to_kv_pool.clear()
            torch.cuda.empty_cache()
            logger.info("Cache flushed successfully!")
            if_success = True
        else:
            # Log through the module logger (not the root `logging` module)
            # so this message is handled like the other scheduler logs.
            logger.warning(
                "Cache not flushed because there are pending requests. "
                f"#queue-req: {num_queued}, "
                f"#running-req: {num_running}"
            )
            if_success = False
        return if_success

    def abort_request(self, recv_req: AbortReq):
        """Abort the request whose rid equals ``recv_req.rid``.

        A matching request in the waiting queue is removed right away. If none
        is found there, the first matching unfinished request in the running
        batch is flagged with ``to_abort = True``. Nothing happens when no
        request matches.
        """
        # Delete requests in the waiting queue
        for i, req in enumerate(self.waiting_queue):
            if req.rid != recv_req.rid:
                continue
            # Safe to delete while iterating: we return immediately after.
            del self.waiting_queue[i]
            logger.debug(f"Abort queued request. {req.rid=}")
            return

        # Delete requests in the running batch
        if not self.running_batch:
            return

        for req in self.running_batch.reqs:
            if req.rid == recv_req.rid and not req.finished():
                logger.debug(f"Abort running request. {req.rid=}")
                req.to_abort = True
                break

Chayenne's avatar
Chayenne committed
1513
1514
1515
    def update_weights_from_disk(self, recv_req: UpdateWeightFromDiskReqInput):
        """In-place update of the weights from disk.

        Delegates to the TP worker. On success the cache is flushed and the
        flush must succeed; on failure the worker's message is logged as an
        error.

        Returns:
            The ``(success, message)`` pair reported by the TP worker.
        """
        success, message = self.tp_worker.update_weights_from_disk(recv_req)
        if not success:
            logger.error(message)
            return success, message

        cache_flushed = self.flush_cache()
        assert cache_flushed, "Cache flush failed after updating weights"
        return success, message

1523
1524
1525
1526
1527
1528
    def init_weights_update_group(self, recv_req: InitWeightsUpdateGroupReqInput):
        """Initialize the online model parameter update group.

        Returns:
            The ``(success, message)`` pair reported by the TP worker.
        """
        ok, msg = self.tp_worker.init_weights_update_group(recv_req)
        return ok, msg

    def update_weights_from_distributed(
        self,
        recv_req: UpdateWeightsFromDistributedReqInput,
    ) -> Tuple[bool, str]:
        """Update the online model parameter.

        Delegates to the TP worker. On success the cache is flushed and the
        flush must succeed; on failure the worker's message is logged as an
        error.

        Returns:
            The ``(success, message)`` pair reported by the TP worker.
        """
        success, message = self.tp_worker.update_weights_from_distributed(recv_req)
        if not success:
            logger.error(message)
            return success, message

        cache_flushed = self.flush_cache()
        assert cache_flushed, "Cache flush failed after updating weights"
        return success, message

1541
1542
1543
1544
1545
1546
1547
1548
1549
1550
1551
    def update_weights_from_tensor(self, recv_req: UpdateWeightsFromTensorReqInput):
        """Update the online model parameter from tensors.

        Delegates to the TP worker. On success the cache is flushed and the
        flush must succeed; on failure the worker's message is logged as an
        error.

        Returns:
            The ``(success, message)`` pair reported by the TP worker.
        """
        success, message = self.tp_worker.update_weights_from_tensor(recv_req)
        # TODO extract common code b/t update_weights_from_distributed and update_weights_from_tensor later
        if not success:
            logger.error(message)
            return success, message

        cache_flushed = self.flush_cache()
        assert cache_flushed, "Cache flush failed after updating weights"
        return success, message

1552
1553
1554
1555
    def get_weights_by_name(self, recv_req: GetWeightsByNameReqInput):
        """Return whatever the TP worker's ``get_weights_by_name`` returns for ``recv_req``."""
        return self.tp_worker.get_weights_by_name(recv_req)

1556
1557
1558
1559
1560
1561
1562
1563
1564
1565
1566
1567
1568
1569
    def release_memory_occupation(self):
        """Stash the model's buffers, pause the memory saver adapter and flush the cache.

        The buffer snapshot is kept in ``self.stashed_model_static_state``.
        The result of ``flush_cache`` is not checked.
        """
        model = self.tp_worker.worker.model_runner.model
        self.stashed_model_static_state = _export_static_state(model)
        self.memory_saver_adapter.pause()
        self.flush_cache()

    def resume_memory_occupation(self):
        """Resume the memory saver adapter and write the stashed buffers back into the model.

        The snapshot in ``self.stashed_model_static_state`` is deleted once it
        has been imported.
        """
        self.memory_saver_adapter.resume()
        model = self.tp_worker.worker.model_runner.model
        _import_static_state(model, self.stashed_model_static_state)
        del self.stashed_model_static_state

1570
1571
1572
1573
1574
1575
1576
1577
1578
1579
1580
1581
1582
1583
    def start_profile(self) -> None:
        """Start the profiler.

        Raises:
            RuntimeError: If ``self.profiler`` is ``None``.
        """
        profiler = self.profiler
        if profiler is None:
            raise RuntimeError("Profiler is not enabled.")
        profiler.start()

    def stop_profile(self) -> None:
        """Stop the profiler and export a Chrome trace.

        The trace is written under ``self.torch_profiler_trace_dir`` to a file
        named after the current timestamp with a ``.trace.json.gz`` suffix.

        Raises:
            RuntimeError: If ``self.profiler`` is ``None``.
        """
        profiler = self.profiler
        if profiler is None:
            raise RuntimeError("Profiler is not enabled.")

        profiler.stop()
        trace_prefix = self.torch_profiler_trace_dir + "/"
        trace_path = trace_prefix + str(time.time()) + ".trace.json.gz"
        profiler.export_chrome_trace(trace_path)
        logger.info("Profiler is done")

1584
    def open_session(self, recv_req: OpenSessionReqInput) -> Tuple[Optional[str], bool]:
        """Register a new session under ``recv_req.session_id``.

        Returns:
            ``(session_id, True)`` when the session was created, or
            ``(session_id, False)`` when the id is already registered or is
            ``None`` (a warning is logged in both cases).
        """
        session_id = recv_req.session_id

        # handle error
        if session_id in self.sessions:
            logger.warning(f"session id {session_id} already exist, cannot open.")
            return session_id, False

        if session_id is None:
            logger.warning("session id is None, cannot open.")
            return session_id, False

        new_session = Session(recv_req.capacity_of_str_len, session_id)
        self.sessions[session_id] = new_session
        return session_id, True
1598
1599
1600
1601
1602
1603
1604
1605
1606

    def close_session(self, recv_req: CloseSessionReqInput):
        """Remove the session registered under ``recv_req.session_id``.

        An unknown session id only produces a warning.
        """
        session_id = recv_req.session_id
        if session_id in self.sessions:
            del self.sessions[session_id]
            return

        # handle error
        logger.warning(f"session id {session_id} does not exist, cannot delete.")

1607

1608
1609
1610
1611
1612
1613
1614
1615
1616
1617
1618
1619
1620
1621
def _export_static_state(model):
    """Snapshot the named buffers of ``model``.

    Returns:
        A dict with a single key ``"buffers"`` holding a list of
        ``(name, tensor)`` pairs, where each tensor is a detached clone of the
        corresponding buffer.
    """
    buffer_snapshots = []
    for name, buffer in model.named_buffers():
        buffer_snapshots.append((name, buffer.detach().clone()))
    return {"buffers": buffer_snapshots}


def _import_static_state(model, static_params):
    """Write saved buffer values back into the buffers of ``model`` in place.

    Args:
        model: Object exposing ``named_buffers()``.
        static_params: Dict whose ``"buffers"`` entry is an iterable of
            ``(name, tensor)`` pairs; every name must be a buffer of ``model``.
    """
    live_buffers = dict(model.named_buffers())
    for name, saved in static_params["buffers"]:
        target = live_buffers[name]
        # Element-wise assignment keeps the existing buffer object.
        target[...] = saved


1622
1623
1624
1625
1626
def run_scheduler_process(
    server_args: ServerArgs,
    port_args: PortArgs,
    gpu_id: int,
    tp_rank: int,
    dp_rank: Optional[int],
    pipe_writer,
):
    """Process entry point: build a ``Scheduler`` and run its event loop.

    After the scheduler is constructed, a ready message carrying
    ``max_total_num_tokens`` is sent through ``pipe_writer``. If construction
    or the event loop raises, the traceback is logged and SIGQUIT is sent to
    the parent process.

    Args:
        server_args: Server arguments forwarded to the scheduler.
        port_args: Port arguments forwarded to the scheduler.
        gpu_id: GPU id forwarded to the scheduler.
        tp_rank: Tensor-parallel rank, also used in the logger prefix.
        dp_rank: Data-parallel rank; when ``None`` it is read from the
            ``SGLANG_DP_RANK`` environment variable if that is set.
        pipe_writer: Object whose ``send`` receives the ready message.
    """
    setproctitle.setproctitle("sglang::scheduler")
    faulthandler.enable()

    # [For Router] if env var "SGLANG_DP_RANK" exist, set dp_rank to the value of the env var
    if dp_rank is None:
        env_dp_rank = os.environ.get("SGLANG_DP_RANK")
        if env_dp_rank is not None:
            dp_rank = int(env_dp_rank)

    # Configure the logger
    log_prefix = f" TP{tp_rank}" if dp_rank is None else f" DP{dp_rank} TP{tp_rank}"
    configure_logger(server_args, prefix=log_prefix)
    suppress_other_loggers()

    # Set cpu affinity to this gpu process
    if get_bool_env_var("SGLANG_SET_CPU_AFFINITY"):
        set_gpu_proc_affinity(server_args.tp_size, server_args.nnodes, gpu_id)

    # Looked up before the try block so the handler below can signal it.
    parent_process = psutil.Process().parent()

    # Create a scheduler and run the event loop
    try:
        scheduler = Scheduler(server_args, port_args, gpu_id, tp_rank, dp_rank)
        ready_message = {
            "status": "ready",
            "max_total_num_tokens": scheduler.max_total_num_tokens,
        }
        pipe_writer.send(ready_message)
        if scheduler.enable_overlap:
            run_event_loop = scheduler.event_loop_overlap
        else:
            run_event_loop = scheduler.event_loop_normal
        run_event_loop()
    except Exception:
        traceback = get_exception_traceback()
        logger.error(f"Scheduler hit an exception: {traceback}")
        parent_process.send_signal(signal.SIGQUIT)