# Copyright 2023-2024 SGLang Team
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""A scheduler that manages a tensor parallel GPU worker."""

import logging
import os
import threading
import time
import warnings
from collections import deque
from concurrent import futures
from types import SimpleNamespace
from typing import List, Optional

import torch
import zmq

from sglang.global_config import global_config
from sglang.srt.configs.model_config import ModelConfig
from sglang.srt.hf_transformers_utils import get_processor, get_tokenizer
from sglang.srt.layers.logits_processor import LogitsProcessorOutput
from sglang.srt.managers.io_struct import (
    AbortReq,
    BatchEmbeddingOut,
    BatchTokenIDOut,
    CloseSessionReqInput,
    FlushCacheReq,
    GetMemPoolSizeReq,
    GetMemPoolSizeReqOutput,
    OpenSessionReqInput,
    OpenSessionReqOutput,
    ProfileReq,
    TokenizedEmbeddingReqInput,
    TokenizedGenerateReqInput,
    UpdateWeightReqInput,
    UpdateWeightReqOutput,
)
from sglang.srt.managers.schedule_batch import (
    FINISH_ABORT,
    BaseFinishReason,
    ImageInputs,
    Req,
    ScheduleBatch,
    global_server_args_dict,
)
from sglang.srt.managers.schedule_policy import (
    AddReqResult,
    PrefillAdder,
    SchedulePolicy,
)
from sglang.srt.managers.session_controller import Session
from sglang.srt.managers.tp_worker import TpModelWorker
from sglang.srt.managers.tp_worker_overlap_thread import TpModelWorkerClient
from sglang.srt.mem_cache.chunk_cache import ChunkCache
from sglang.srt.mem_cache.radix_cache import RadixCache
from sglang.srt.metrics.collector import SchedulerMetricsCollector, SchedulerStats
from sglang.srt.model_executor.forward_batch_info import ForwardMode
from sglang.srt.server_args import PortArgs, ServerArgs
from sglang.srt.utils import (
    broadcast_pyobj,
    configure_logger,
    crash_on_warnings,
    get_zmq_socket,
    kill_parent_process,
    set_random_seed,
    suppress_other_loggers,
)
from sglang.utils import get_exception_traceback

logger = logging.getLogger(__name__)

# Test retract decode
test_retract = os.getenv("SGLANG_TEST_RETRACT", "false").lower() == "true"


class Scheduler:
    """A scheduler that manages a tensor parallel GPU worker."""

    def __init__(
        self,
        server_args: ServerArgs,
        port_args: PortArgs,
        gpu_id: int,
        tp_rank: int,
        dp_rank: Optional[int],
    ):
        # Parse args
        self.server_args = server_args
        self.tp_rank = tp_rank
        self.tp_size = server_args.tp_size
        self.schedule_policy = server_args.schedule_policy
        self.disable_jump_forward = server_args.disable_jump_forward
        self.lora_paths = server_args.lora_paths
        self.max_loras_per_batch = server_args.max_loras_per_batch
        self.enable_overlap = not server_args.disable_overlap_schedule
        self.skip_tokenizer_init = server_args.skip_tokenizer_init
        self.enable_metrics = server_args.enable_metrics

        # Session info
        self.sessions = {}

        # Init inter-process communication
        context = zmq.Context(2)

        if self.tp_rank == 0 or self.server_args.enable_dp_attention:
            self.recv_from_tokenizer = get_zmq_socket(
                context, zmq.PULL, port_args.scheduler_input_ipc_name
            )
            self.send_to_tokenizer = get_zmq_socket(
                context, zmq.PUSH, port_args.tokenizer_ipc_name
            )

            if server_args.skip_tokenizer_init:
                # Directly send to the tokenizer/api
                self.send_to_detokenizer = get_zmq_socket(
                    context, zmq.PUSH, port_args.tokenizer_ipc_name
                )
            else:
                # Send to the detokenizer
                self.send_to_detokenizer = get_zmq_socket(
                    context, zmq.PUSH, port_args.detokenizer_ipc_name
                )
        else:
            self.recv_from_tokenizer = None
            self.send_to_tokenizer = SimpleNamespace(send_pyobj=lambda x: None)
            self.send_to_detokenizer = SimpleNamespace(send_pyobj=lambda x: None)

        # Init tokenizer
        self.model_config = ModelConfig(
            server_args.model_path,
            trust_remote_code=server_args.trust_remote_code,
            context_length=server_args.context_length,
            model_override_args=server_args.json_model_override_args,
            is_embedding=server_args.is_embedding,
        )
        self.is_generation = self.model_config.is_generation

        if server_args.skip_tokenizer_init:
            self.tokenizer = self.processor = None
        else:
            if self.model_config.is_multimodal:
                self.processor = get_processor(
                    server_args.tokenizer_path,
                    tokenizer_mode=server_args.tokenizer_mode,
                    trust_remote_code=server_args.trust_remote_code,
                )
                self.tokenizer = self.processor.tokenizer
            else:
                self.tokenizer = get_tokenizer(
                    server_args.tokenizer_path,
                    tokenizer_mode=server_args.tokenizer_mode,
                    trust_remote_code=server_args.trust_remote_code,
                )

        # Check whether overlap can be enabled
        if not self.is_generation:
            self.enable_overlap = False
            logger.info("Overlap scheduler is disabled for embedding models.")

        if self.enable_overlap:
            self.disable_jump_forward = True
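            # A hedged rationale, not an original comment: jump-forward decoding
            # rewrites a request's token stream mid-flight, which does not compose
            # with the one-step-delayed result processing of the overlap scheduler,
            # so it is presumably disabled here for correctness.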

        # Launch a tensor parallel worker
        if self.enable_overlap:
            TpWorkerClass = TpModelWorkerClient
        else:
            TpWorkerClass = TpModelWorker

        self.tp_worker = TpWorkerClass(
            server_args=server_args,
            gpu_id=gpu_id,
            tp_rank=tp_rank,
            dp_rank=dp_rank,
            nccl_port=port_args.nccl_port,
        )

        # Get token and memory info from the model worker
        (
            self.max_total_num_tokens,
            self.max_prefill_tokens,
            self.max_running_requests,
            self.max_req_len,
            self.max_req_input_len,
            self.random_seed,
            self.device,
            worker_global_server_args_dict,
            _,
            _,
            _,
        ) = self.tp_worker.get_worker_info()
        self.tp_cpu_group = self.tp_worker.get_tp_cpu_group()
        self.pad_input_ids_func = self.tp_worker.get_pad_input_ids_func()
        global_server_args_dict.update(worker_global_server_args_dict)
        set_random_seed(self.random_seed)

        # Print debug info
        logger.info(
            f"max_total_num_tokens={self.max_total_num_tokens}, "
            f"max_prefill_tokens={self.max_prefill_tokens}, "
            f"max_running_requests={self.max_running_requests}, "
            f"context_len={self.model_config.context_len}"
        )

        # Init memory pool and cache
        self.req_to_token_pool, self.token_to_kv_pool = self.tp_worker.get_memory_pool()

        if (
            server_args.chunked_prefill_size is not None
            and server_args.disable_radix_cache
        ):
            self.tree_cache = ChunkCache(
                req_to_token_pool=self.req_to_token_pool,
                token_to_kv_pool=self.token_to_kv_pool,
            )
        else:
            self.tree_cache = RadixCache(
                req_to_token_pool=self.req_to_token_pool,
                token_to_kv_pool=self.token_to_kv_pool,
                disable=server_args.disable_radix_cache,
            )
        self.tree_cache_metrics = {"total": 0, "hit": 0}
        self.policy = SchedulePolicy(self.schedule_policy, self.tree_cache)

        # Init running status
        self.waiting_queue: List[Req] = []
        # The running decoding batch for continuous batching
        self.running_batch: Optional[ScheduleBatch] = None
        # The current forward batch
        self.cur_batch: Optional[ScheduleBatch] = None
        # The last forward batch
        self.last_batch: Optional[ScheduleBatch] = None
        self.forward_ct = 0
        self.forward_ct_decode = 0
        self.num_generated_tokens = 0
        self.last_decode_stats_tic = time.time()
        self.stream_interval = server_args.stream_interval

        # Init chunked prefill
        self.chunked_prefill_size = server_args.chunked_prefill_size
        self.being_chunked_req = None
        self.is_mixed_chunk = (
            self.chunked_prefill_size is not None and server_args.enable_mixed_chunk
        )

        # Init the grammar backend for constrained generation
        self.grammar_queue: List[Req] = []
        if not server_args.skip_tokenizer_init:
            if server_args.grammar_backend == "outlines":
                from sglang.srt.constrained.outlines_backend import (
                    OutlinesGrammarBackend,
                )

                self.grammar_backend = OutlinesGrammarBackend(
                    self.tokenizer,
                    whitespace_pattern=server_args.constrained_json_whitespace_pattern,
                    allow_jump_forward=not server_args.disable_jump_forward,
                )
            elif server_args.grammar_backend == "xgrammar":
                from sglang.srt.constrained.xgrammar_backend import (
                    XGrammarGrammarBackend,
                )

                self.grammar_backend = XGrammarGrammarBackend(
                    self.tokenizer, vocab_size=self.model_config.vocab_size
                )
            else:
                raise ValueError(
                    f"Invalid grammar backend: {server_args.grammar_backend}"
                )
        else:
            self.grammar_backend = None

        # Init new token estimation
        assert (
            server_args.schedule_conservativeness >= 0
        ), "Invalid schedule_conservativeness"

        self.init_new_token_ratio = min(
            global_config.default_init_new_token_ratio
            * server_args.schedule_conservativeness,
            1.0,
        )
        self.min_new_token_ratio = min(
            self.init_new_token_ratio
            * global_config.default_min_new_token_ratio_factor,
            1.0,
        )
        self.new_token_ratio_decay = (
            self.init_new_token_ratio - self.min_new_token_ratio
        ) / global_config.default_new_token_ratio_decay_steps
        self.new_token_ratio = self.init_new_token_ratio
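        # How the new-token ratio is used (a hedged summary; see update_running_batch
        # and PrefillAdder): it is roughly the fraction of each running request's
        # remaining max_new_tokens that the scheduler assumes will actually be
        # generated when reserving KV memory. Illustrative numbers only: with
        # init = 0.7, min = 0.1, and 600 decay steps, the ratio shrinks by
        # (0.7 - 0.1) / 600 per decode step, is raised again when a decode step
        # runs out of memory and requests are retracted, and is reset to the
        # initial value whenever the server goes idle.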

        # Tells whether the current running batch is full so that we can skip
        # the check of whether to prefill new requests.
        # This is an optimization to reduce the overhead of the prefill check.
        self.batch_is_full = False

        # Init watchdog thread
        self.watchdog_timeout = server_args.watchdog_timeout
        t = threading.Thread(target=self.watchdog_thread, daemon=True)
        t.start()
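        # The watchdog (see watchdog_thread below) kills the whole server if the
        # forward counter stops advancing for more than watchdog_timeout seconds
        # while a batch is in flight, so a hung kernel or collective does not leave
        # the process wedged forever.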

        # Init profiler
        if os.getenv("SGLANG_TORCH_PROFILER_DIR", "") == "":
            self.profiler = None
        else:
            self.torch_profiler_trace_dir = os.getenv("SGLANG_TORCH_PROFILER_DIR")
            logger.info(
                "Profiling enabled. Traces will be saved to: %s",
                self.torch_profiler_trace_dir,
            )
            self.profiler = torch.profiler.profile(
                activities=[
                    torch.profiler.ProfilerActivity.CPU,
                    torch.profiler.ProfilerActivity.CUDA,
                ],
                with_stack=True,
            )

        # Init metrics stats
        self.stats = SchedulerStats()
        if self.enable_metrics:
            self.metrics_collector = SchedulerMetricsCollector(
                labels={
                    "model_name": self.server_args.served_model_name,
                    # TODO: Add lora name/path in the future,
                },
            )

    def watchdog_thread(self):
        self.watchdog_last_forward_ct = 0
        self.watchdog_last_time = time.time()

        while True:
            if self.cur_batch is not None:
                if self.watchdog_last_forward_ct == self.forward_ct:
                    if time.time() > self.watchdog_last_time + self.watchdog_timeout:
                        logger.error(f"Watchdog timeout ({self.watchdog_timeout=})")
                        break
                else:
                    self.watchdog_last_forward_ct = self.forward_ct
                    self.watchdog_last_time = time.time()
            time.sleep(self.watchdog_timeout / 2)

        kill_parent_process()

    @torch.no_grad()
    def event_loop_normal(self):
        """A normal scheduler loop."""
        while True:
            recv_reqs = self.recv_requests()
            self.process_input_requests(recv_reqs)

            batch = self.get_next_batch_to_run()
            if self.server_args.enable_dp_attention:
                batch = self.prepare_dp_attn_batch(batch)

            self.cur_batch = batch

            if batch:
                result = self.run_batch(batch)
                self.process_batch_result(batch, result)
            else:
                # Self-check and re-init some states when the server is idle
                self.check_memory()
                self.new_token_ratio = self.init_new_token_ratio

            self.last_batch = batch

    @torch.no_grad()
    def event_loop_overlap(self):
        """A scheduler loop that overlaps the CPU processing and GPU computation."""
        result_queue = deque()

        while True:
            recv_reqs = self.recv_requests()
            self.process_input_requests(recv_reqs)

            batch = self.get_next_batch_to_run()
            self.cur_batch = batch
            if batch:
                result = self.run_batch(batch)
                result_queue.append((batch.copy(), result))

                if self.last_batch is None:
                    # A dummy first batch to start the pipeline for overlap scheduler.
                    # It is now used for triggering the sampling_info_done event.
                    tmp_batch = ScheduleBatch(
                        reqs=None,
                        forward_mode=ForwardMode.DUMMY_FIRST,
                        next_batch_sampling_info=self.tp_worker.cur_sampling_info,
                    )
                    self.process_batch_result(tmp_batch, None)

            if self.last_batch:
                tmp_batch, tmp_result = result_queue.popleft()
                tmp_batch.next_batch_sampling_info = (
                    self.tp_worker.cur_sampling_info if batch else None
                )
                self.process_batch_result(tmp_batch, tmp_result)
            elif batch is None:
                # Self-check and re-init some states when the server is idle
                self.check_memory()
                self.new_token_ratio = self.init_new_token_ratio

            self.last_batch = batch

    def prepare_dp_attn_batch(self, local_batch: ScheduleBatch):
        # Check if other DP workers have running batches
        if local_batch is None:
            num_tokens = 0
        elif local_batch.forward_mode.is_decode():
            num_tokens = local_batch.batch_size()
        else:
            num_tokens = local_batch.extend_num_tokens

        local_num_tokens = torch.tensor([num_tokens], dtype=torch.int64)
        global_num_tokens = torch.empty(self.tp_size, dtype=torch.int64)
        torch.distributed.all_gather_into_tensor(
            global_num_tokens,
            local_num_tokens,
            group=self.tp_cpu_group,
        )

        if local_batch is None and global_num_tokens.max().item() > 0:
            local_batch = self.get_idle_batch()
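            # A hedged reading of the code: with DP attention every rank takes part
            # in the collectives above, so a rank with no local work is handed an
            # idle batch instead of skipping the forward step entirely.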

        if local_batch is not None:
            local_batch.global_num_tokens = global_num_tokens.tolist()

            # Check forward mode for cuda graph
            if not self.server_args.disable_cuda_graph:
                forward_mode_state = torch.tensor(
                    (
                        1
                        if local_batch.forward_mode.is_decode()
                        or local_batch.forward_mode.is_idle()
                        else 0
                    ),
                    dtype=torch.int32,
                )
                torch.distributed.all_reduce(
                    forward_mode_state,
                    op=torch.distributed.ReduceOp.MIN,
                    group=self.tp_cpu_group,
                )
                local_batch.can_run_dp_cuda_graph = forward_mode_state.item() == 1

        return local_batch

    def get_idle_batch(self):
        idle_batch = ScheduleBatch.init_new(
            [],
            self.req_to_token_pool,
            self.token_to_kv_pool,
            self.tree_cache,
            self.model_config,
            self.enable_overlap,
        )
        idle_batch.prepare_for_idle()
        return idle_batch

    def recv_requests(self):
        if self.tp_rank == 0 or self.server_args.enable_dp_attention:
            recv_reqs = []

            while True:
                try:
                    recv_req = self.recv_from_tokenizer.recv_pyobj(zmq.NOBLOCK)
                except zmq.ZMQError:
                    break
                recv_reqs.append(recv_req)
        else:
            recv_reqs = None

        if self.tp_size != 1 and not self.server_args.enable_dp_attention:
            recv_reqs = broadcast_pyobj(recv_reqs, self.tp_rank, self.tp_cpu_group)
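            # broadcast_pyobj keeps all TP ranks in lockstep: only rank 0 talks to
            # the tokenizer, so the other ranks receive an identical copy of the
            # request list here before any scheduling decisions are made.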
        return recv_reqs

    def process_input_requests(self, recv_reqs: List):
        for recv_req in recv_reqs:
            if isinstance(recv_req, TokenizedGenerateReqInput):
                self.handle_generate_request(recv_req)
            elif isinstance(recv_req, TokenizedEmbeddingReqInput):
                self.handle_embedding_request(recv_req)
            elif isinstance(recv_req, FlushCacheReq):
                self.flush_cache()
            elif isinstance(recv_req, AbortReq):
                self.abort_request(recv_req)
            elif isinstance(recv_req, UpdateWeightReqInput):
                success, message = self.update_weights(recv_req)
                self.send_to_tokenizer.send_pyobj(
                    UpdateWeightReqOutput(success, message)
                )
            elif isinstance(recv_req, ProfileReq):
                if recv_req == ProfileReq.START_PROFILE:
                    self.start_profile()
                else:
                    self.stop_profile()
            elif isinstance(recv_req, OpenSessionReqInput):
                session_id = self.open_session(recv_req)
                self.send_to_tokenizer.send_pyobj(OpenSessionReqOutput(session_id))
            elif isinstance(recv_req, CloseSessionReqInput):
                self.close_session(recv_req)
            elif isinstance(recv_req, GetMemPoolSizeReq):
                self.send_to_tokenizer.send_pyobj(
                    GetMemPoolSizeReqOutput(self.max_total_num_tokens)
                )
            else:
                raise ValueError(f"Invalid request: {recv_req}")

    def handle_generate_request(
        self,
        recv_req: TokenizedGenerateReqInput,
    ):
        if recv_req.session_id is None or recv_req.session_id not in self.sessions:
            req = Req(
                recv_req.rid,
                recv_req.input_text,
                recv_req.input_ids,
                recv_req.sampling_params,
                lora_path=recv_req.lora_path,
            )
            req.tokenizer = self.tokenizer
            if recv_req.session_id is not None:
                req.finished_reason = FINISH_ABORT(
                    f"Invalid request: session id {recv_req.session_id} does not exist"
                )
                self.waiting_queue.append(req)
                return
        else:
            # Handle sessions
            session = self.sessions[recv_req.session_id]
            req = session.create_req(recv_req, self.tokenizer)
            if isinstance(req.finished_reason, FINISH_ABORT):
                self.waiting_queue.append(req)
                return

        # Image inputs
        if recv_req.image_inputs is not None:
            req.image_inputs = ImageInputs.from_dict(
                recv_req.image_inputs, self.model_config.vocab_size
            )
            req.origin_input_ids = self.pad_input_ids_func(
                req.origin_input_ids_unpadded, req.image_inputs
            )

            if len(req.origin_input_ids) > self.max_req_input_len:
                req.finished_reason = FINISH_ABORT(
                    "Image request length is longer than the KV cache pool size or "
                    "the max context length aborting because you cannot truncate the image embeds"
                )
                req.sampling_params.max_new_tokens = 0
                self.waiting_queue.append(req)
                return

        req.return_logprob = recv_req.return_logprob
        req.top_logprobs_num = recv_req.top_logprobs_num
        req.stream = recv_req.stream
        req.logprob_start_len = recv_req.logprob_start_len

        if req.logprob_start_len == -1:
            # By default, only return the logprobs for output tokens
            req.logprob_start_len = len(recv_req.input_ids) - 1

        # Truncate prompts that are too long
        if len(req.origin_input_ids) > self.max_req_input_len:
            logger.warning(
                "Request length is longer than the KV cache pool size or "
                "the max context length. Truncated!!!"
            )
            req.origin_input_ids = req.origin_input_ids[: self.max_req_input_len]

        req.sampling_params.max_new_tokens = min(
            (
                req.sampling_params.max_new_tokens
                if req.sampling_params.max_new_tokens is not None
                else 1 << 30
            ),
            self.max_req_len - len(req.origin_input_ids) - 1,
        )

        # Init grammar cache for this request
        add_to_grammar_queue = False
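        # Grammar handling (summarized from the code below): compiled grammars are
        # cached by a (kind, spec) key such as ("json", schema) or ("regex", pattern).
        # On a cache miss the request only gets a future handle and is parked in
        # grammar_queue; move_ready_grammar_requests() later promotes it to the
        # waiting queue once compilation finishes.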
        if (
            req.sampling_params.json_schema is not None
            or req.sampling_params.regex is not None
        ):
            assert self.grammar_backend is not None
            if req.sampling_params.json_schema is not None:
                key = ("json", req.sampling_params.json_schema)
            elif req.sampling_params.regex is not None:
                key = ("regex", req.sampling_params.regex)

            req.grammar = self.grammar_backend.get_cached_value(key)
            if not req.grammar:
                req.grammar = self.grammar_backend.get_future_value(key)
                add_to_grammar_queue = True

        if add_to_grammar_queue:
            self.grammar_queue.append(req)
        else:
            self.waiting_queue.append(req)

    def handle_embedding_request(
        self,
        recv_req: TokenizedEmbeddingReqInput,
    ):
        req = Req(
            recv_req.rid,
            recv_req.input_text,
            recv_req.input_ids,
            recv_req.sampling_params,
        )
        req.tokenizer = self.tokenizer

        # Truncate prompts that are too long
        if len(req.origin_input_ids) >= self.max_req_input_len:
            logger.warning(
                "Request length is longer than the KV cache pool size or "
                "the max context length. Truncated!!!"
            )
            req.origin_input_ids = req.origin_input_ids[: self.max_req_input_len]

        self.waiting_queue.append(req)

    def log_prefill_stats(self, adder, can_run_list, running_bs, has_inflight):
        if isinstance(self.tree_cache, RadixCache):
            self.tree_cache_metrics["total"] += (
                adder.log_input_tokens + adder.log_hit_tokens
            ) / 10**9
            self.tree_cache_metrics["hit"] += (adder.log_hit_tokens) / 10**9
            tree_cache_hit_rate = (
                self.tree_cache_metrics["hit"] / self.tree_cache_metrics["total"]
            )
        else:
            tree_cache_hit_rate = 0.0

        num_used = self.max_total_num_tokens - (
            self.token_to_kv_pool.available_size() + self.tree_cache.evictable_size()
        )

        logger.info(
            f"Prefill batch. "
            f"#new-seq: {len(can_run_list)}, "
            f"#new-token: {adder.log_input_tokens}, "
            f"#cached-token: {adder.log_hit_tokens}, "
            f"cache hit rate: {100.0 * tree_cache_hit_rate:.2f}%, "
            f"token usage: {num_used / self.max_total_num_tokens:.2f}, "
            f"#running-req: {running_bs}, "
            f"#queue-req: {len(self.waiting_queue) + has_inflight}"
        )

        if self.enable_metrics:
            self.stats.num_running_reqs = running_bs
            self.stats.num_used_tokens = num_used
            self.stats.token_usage = round(num_used / self.max_total_num_tokens, 2)
            self.stats.num_queue_reqs = len(self.waiting_queue) + has_inflight
            self.stats.cache_hit_rate = tree_cache_hit_rate
            self.metrics_collector.log_stats(self.stats)

    def log_decode_stats(self):
        num_used = self.max_total_num_tokens - (
            self.token_to_kv_pool.available_size() + self.tree_cache.evictable_size()
        )
        gen_throughput = self.num_generated_tokens / (
            time.time() - self.last_decode_stats_tic
        )
        self.num_generated_tokens = 0
        self.last_decode_stats_tic = time.time()
        num_running_reqs = len(self.running_batch.reqs) if self.running_batch else 0
        logger.info(
            f"Decode batch. "
            f"#running-req: {num_running_reqs}, "
            f"#token: {num_used}, "
            f"token usage: {num_used / self.max_total_num_tokens:.2f}, "
            f"gen throughput (token/s): {gen_throughput:.2f}, "
            f"#queue-req: {len(self.waiting_queue)}"
        )

        if self.enable_metrics:
            self.stats.num_running_reqs = num_running_reqs
            self.stats.num_used_tokens = num_used
            self.stats.token_usage = num_used / self.max_total_num_tokens
            self.stats.gen_throughput = gen_throughput
            self.stats.num_queue_reqs = len(self.waiting_queue)
            self.metrics_collector.log_stats(self.stats)

    def check_memory(self):
        available_size = (
            self.token_to_kv_pool.available_size() + self.tree_cache.evictable_size()
        )
        if available_size != self.max_total_num_tokens:
            msg = (
                "KV cache pool leak detected! "
                f"{available_size=}, {self.max_total_num_tokens=}\n"
            )
            warnings.warn(msg)
            if crash_on_warnings():
                raise ValueError(msg)

        if len(self.req_to_token_pool.free_slots) != self.req_to_token_pool.size:
            msg = (
                "Memory pool leak detected! "
                f"available_size={len(self.req_to_token_pool.free_slots)}, "
                f"total_size={self.req_to_token_pool.size}\n"
            )
            warnings.warn(msg)
            if crash_on_warnings():
                raise ValueError(msg)

    def get_next_batch_to_run(self):
        # Merge the prefill batch into the running batch
        if self.last_batch and self.last_batch.forward_mode.is_extend():
            if self.being_chunked_req:
                # Move the chunked request out of the batch
                self.last_batch.filter_batch(being_chunked_req=self.being_chunked_req)
                self.tree_cache.cache_unfinished_req(self.being_chunked_req)
                # Inflight request keeps its rid but will get a new req_pool_idx
                self.req_to_token_pool.free(self.being_chunked_req.req_pool_idx)
                self.batch_is_full = False

            if not self.last_batch.is_empty():
                if self.running_batch is None:
                    self.running_batch = self.last_batch
                else:
                    self.running_batch.merge_batch(self.last_batch)

        # Run prefill first if possible
        new_batch = self.get_new_batch_prefill()
        if new_batch is not None:
            return new_batch

        # Run decode
        if self.running_batch is None:
            return None
        self.running_batch = self.update_running_batch(self.running_batch)
        return self.running_batch

    def get_new_batch_prefill(self) -> Optional[ScheduleBatch]:
        # Check if the grammar is ready in the grammar queue
        if self.grammar_queue:
            self.move_ready_grammar_requests()

        # Handle the cases where prefill is not allowed
        if (
            self.batch_is_full or len(self.waiting_queue) == 0
        ) and self.being_chunked_req is None:
            return None

        running_bs = len(self.running_batch.reqs) if self.running_batch else 0
        if running_bs >= self.max_running_requests:
            self.batch_is_full = True
            return None

        # Get priority queue
        prefix_computed = self.policy.calc_priority(self.waiting_queue)

        # Prefill policy
        adder = PrefillAdder(
            self.tree_cache,
            self.running_batch,
            self.new_token_ratio,
            self.token_to_kv_pool.available_size() + self.tree_cache.evictable_size(),
            self.max_prefill_tokens,
            self.chunked_prefill_size,
            running_bs if self.is_mixed_chunk else 0,
        )

        has_inflight = self.being_chunked_req is not None
        if has_inflight:
            self.being_chunked_req.init_next_round_input()
            self.being_chunked_req = adder.add_inflight_req(self.being_chunked_req)

        if self.lora_paths:
            lora_set = (
                set([req.lora_path for req in self.running_batch.reqs])
                if self.running_batch is not None
                else set([])
            )

        # Get requests from the waiting queue to a new prefill batch
        for req in self.waiting_queue:
            if (
                self.lora_paths
                and len(
                    lora_set
                    | set([req.lora_path for req in adder.can_run_list])
                    | set([req.lora_path])
                )
                > self.max_loras_per_batch
            ):
                self.batch_is_full = True
                break

            if running_bs + len(adder.can_run_list) >= self.max_running_requests:
                self.batch_is_full = True
                break

            req.init_next_round_input(None if prefix_computed else self.tree_cache)
            res = adder.add_one_req(req)
            if res != AddReqResult.CONTINUE:
                if res == AddReqResult.NO_TOKEN:
                    self.batch_is_full = True
                break

        # Update waiting queue
        can_run_list = adder.can_run_list
        if len(can_run_list) == 0:
            return None
        self.waiting_queue = [
            x for x in self.waiting_queue if x not in set(can_run_list)
        ]

        if adder.new_inflight_req is not None:
            assert self.being_chunked_req is None
            self.being_chunked_req = adder.new_inflight_req

        if self.being_chunked_req:
            self.being_chunked_req.is_being_chunked += 1

        # Print stats
        if self.tp_rank == 0:
            self.log_prefill_stats(adder, can_run_list, running_bs, has_inflight)

        # Create a new batch
        new_batch = ScheduleBatch.init_new(
            can_run_list,
            self.req_to_token_pool,
            self.token_to_kv_pool,
            self.tree_cache,
            self.model_config,
            self.enable_overlap,
        )
        new_batch.prepare_for_extend()

        # Mixed-style chunked prefill
        if (
            self.is_mixed_chunk
            and self.running_batch is not None
            and not (new_batch.return_logprob or self.running_batch.return_logprob)
        ):
            # TODO (lianmin): support return_logprob + mixed chunked prefill
            self.running_batch.filter_batch()
            if not self.running_batch.is_empty():
                self.running_batch.prepare_for_decode()
                new_batch.mix_with_running(self.running_batch)
                new_batch.decoding_reqs = self.running_batch.reqs
            self.running_batch = None
        else:
            new_batch.decoding_reqs = None

        return new_batch

    def update_running_batch(self, batch: ScheduleBatch) -> Optional[ScheduleBatch]:
        """Update the current running decoding batch."""
        global test_retract

        initial_bs = batch.batch_size()

        batch.filter_batch()
        if batch.is_empty():
            self.batch_is_full = False
            return None

        # Check if decode out of memory
        if not batch.check_decode_mem() or (test_retract and batch.batch_size() > 10):
            old_ratio = self.new_token_ratio

            retracted_reqs, new_token_ratio = batch.retract_decode()
            self.new_token_ratio = new_token_ratio
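            # Retraction (a hedged summary of ScheduleBatch.retract_decode): when
            # the KV pool cannot hold the next decode step, some running requests
            # are moved back to the waiting queue and their KV slots reclaimed; the
            # returned new_token_ratio is higher, so later prefill admission
            # reserves more memory per request and admits fewer new requests until
            # the ratio decays again.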

            logger.info(
                "Decode out of memory happened. "
                f"#retracted_reqs: {len(retracted_reqs)}, "
                f"#new_token_ratio: {old_ratio:.4f} -> {self.new_token_ratio:.4f}"
            )
            self.waiting_queue.extend(retracted_reqs)
        else:
            self.new_token_ratio = max(
                self.new_token_ratio - self.new_token_ratio_decay,
                self.min_new_token_ratio,
            )

        # Check for jump-forward
        if not self.disable_jump_forward:
            jump_forward_reqs = batch.check_for_jump_forward(self.pad_input_ids_func)
            self.waiting_queue.extend(jump_forward_reqs)
            if batch.is_empty():
                self.batch_is_full = False
                return None

        if batch.batch_size() < initial_bs:
            self.batch_is_full = False

        # Update batch tensors
        batch.prepare_for_decode()
        return batch

    def run_batch(self, batch: ScheduleBatch):
        """Run a batch."""
        self.forward_ct += 1

        if self.is_generation:
            model_worker_batch = batch.get_model_worker_batch()
            if batch.forward_mode.is_decode() or batch.extend_num_tokens != 0:
                logits_output, next_token_ids = self.tp_worker.forward_batch_generation(
                    model_worker_batch
                )
            elif batch.forward_mode.is_idle():
                model_worker_batch = batch.get_model_worker_batch()
                self.tp_worker.forward_batch_idle(model_worker_batch)
                return
            else:
                logits_output = None
                if self.skip_tokenizer_init:
                    next_token_ids = torch.full(
                        (batch.batch_size(),), self.tokenizer.eos_token_id
                    )
                else:
                    next_token_ids = torch.full((batch.batch_size(),), 0)
            batch.output_ids = next_token_ids
            ret = logits_output, next_token_ids, model_worker_batch.bid
        else:  # embedding or reward model
            assert batch.extend_num_tokens != 0
            model_worker_batch = batch.get_model_worker_batch()
            embeddings = self.tp_worker.forward_batch_embedding(model_worker_batch)
            ret = embeddings, model_worker_batch.bid
        return ret

    def process_batch_result(self, batch: ScheduleBatch, result):
        if batch.forward_mode.is_decode():
            self.process_batch_result_decode(batch, result)
            if batch.is_empty():
                self.running_batch = None
        elif batch.forward_mode.is_extend():
            self.process_batch_result_prefill(batch, result)
        elif batch.forward_mode.is_dummy_first():
            batch.next_batch_sampling_info.update_regex_vocab_mask()
            torch.cuda.current_stream().synchronize()
            batch.next_batch_sampling_info.sampling_info_done.set()

    def process_batch_result_prefill(self, batch: ScheduleBatch, result):
        if self.is_generation:
            logits_output, next_token_ids, bid = result

            if self.enable_overlap:
                logits_output, next_token_ids = self.tp_worker.resolve_batch_result(bid)
            else:
                # Move next_token_ids and logprobs to cpu
                if batch.return_logprob:
                    logits_output.next_token_logprobs = (
                        logits_output.next_token_logprobs[
                            torch.arange(len(next_token_ids), device=self.device),
                            next_token_ids,
                        ].tolist()
                    )
                    logits_output.input_token_logprobs = (
                        logits_output.input_token_logprobs.tolist()
                    )
                    logits_output.normalized_prompt_logprobs = (
                        logits_output.normalized_prompt_logprobs.tolist()
                    )
                next_token_ids = next_token_ids.tolist()

            # Check finish conditions
            logprob_pt = 0
            for i, (req, next_token_id) in enumerate(zip(batch.reqs, next_token_ids)):
                if req.is_retracted:
                    continue

                if self.is_mixed_chunk and self.enable_overlap and req.finished():
                    # Free the one delayed token for the mixed decode batch
                    j = len(batch.out_cache_loc) - len(batch.reqs) + i
                    self.token_to_kv_pool.free(batch.out_cache_loc[j : j + 1])
                    continue

                if req.is_being_chunked <= 0:
                    req.completion_tokens_wo_jump_forward += 1
                    req.output_ids.append(next_token_id)
                    req.check_finished()

                    if req.finished():
                        self.tree_cache.cache_finished_req(req)
                    elif not batch.decoding_reqs or req not in batch.decoding_reqs:
                        self.tree_cache.cache_unfinished_req(req)

                    if req.return_logprob:
                        logprob_pt += self.add_logprob_return_values(
                            i, req, logprob_pt, next_token_ids, logits_output
                        )

                    if req.grammar is not None:
                        req.grammar.accept_token(next_token_id)
                else:
                    # Inflight reqs' prefill is not finished
                    req.is_being_chunked -= 1

            if batch.next_batch_sampling_info:
                batch.next_batch_sampling_info.update_regex_vocab_mask()
                torch.cuda.current_stream().synchronize()
                batch.next_batch_sampling_info.sampling_info_done.set()

        else:  # embedding or reward model
            embeddings, bid = result
            embeddings = embeddings.tolist()

            # Check finish conditions
            for i, req in enumerate(batch.reqs):
                if req.is_retracted:
                    continue

                req.embedding = embeddings[i]
                if req.is_being_chunked <= 0:
                    # Dummy output token for embedding models
                    req.output_ids.append(0)
                    req.check_finished()

                    if req.finished():
                        self.tree_cache.cache_finished_req(req)
                    else:
                        self.tree_cache.cache_unfinished_req(req)
                else:
                    # Inflight reqs' prefill is not finished
                    req.is_being_chunked -= 1

        self.stream_output(batch.reqs)

    def process_batch_result_decode(self, batch: ScheduleBatch, result):
        logits_output, next_token_ids, bid = result
        self.num_generated_tokens += len(batch.reqs)

        if self.enable_overlap:
            logits_output, next_token_ids = self.tp_worker.resolve_batch_result(bid)
            next_token_logprobs = logits_output.next_token_logprobs
        else:
            # Move next_token_ids and logprobs to cpu
            if batch.return_logprob:
                next_token_logprobs = logits_output.next_token_logprobs[
                    torch.arange(len(next_token_ids), device=self.device),
                    next_token_ids,
                ].tolist()
            next_token_ids = next_token_ids.tolist()

        self.token_to_kv_pool.free_group_begin()

        # Check finish condition
        for i, (req, next_token_id) in enumerate(zip(batch.reqs, next_token_ids)):
            if req.is_retracted:
                continue

            if self.enable_overlap and req.finished():
                # Free the one delayed token
                self.token_to_kv_pool.free(batch.out_cache_loc[i : i + 1])
                continue

            req.completion_tokens_wo_jump_forward += 1
            req.output_ids.append(next_token_id)
            req.check_finished()

            if req.finished():
                self.tree_cache.cache_finished_req(req)

            if req.return_logprob:
                req.output_token_logprobs.append(
                    (next_token_logprobs[i], next_token_id)
                )
                if req.top_logprobs_num > 0:
                    req.output_top_logprobs.append(logits_output.output_top_logprobs[i])

            if req.grammar is not None:
                req.grammar.accept_token(next_token_id)

        if batch.next_batch_sampling_info:
            batch.next_batch_sampling_info.update_regex_vocab_mask()
            torch.cuda.current_stream().synchronize()
            batch.next_batch_sampling_info.sampling_info_done.set()

        self.stream_output(batch.reqs)

        self.token_to_kv_pool.free_group_end()

        self.forward_ct_decode = (self.forward_ct_decode + 1) % (1 << 30)
        if (
            self.tp_rank == 0
            and self.forward_ct_decode % self.server_args.decode_log_interval == 0
        ):
            self.log_decode_stats()

    def add_logprob_return_values(
        self,
        i: int,
        req: Req,
        pt: int,
        next_token_ids: List[int],
        output: LogitsProcessorOutput,
    ):
        """Attach logprobs to the return values."""
        req.output_token_logprobs.append(
            (output.next_token_logprobs[i], next_token_ids[i])
        )

        # If logprob_start_len > 0, then first logprob_start_len prompt tokens will be ignored.
        num_input_logprobs = req.extend_input_len - req.extend_logprob_start_len

        if req.normalized_prompt_logprob is None:
            req.normalized_prompt_logprob = output.normalized_prompt_logprobs[i]

        if req.input_token_logprobs is None:
            input_token_logprobs = output.input_token_logprobs[
                pt : pt + num_input_logprobs - 1 - req.last_update_decode_tokens
            ]
            input_token_ids = req.fill_ids[
                len(req.fill_ids)
                - num_input_logprobs
                + 1 : len(req.fill_ids)
                - req.last_update_decode_tokens
            ]
            req.input_token_logprobs = list(zip(input_token_logprobs, input_token_ids))

            if (
                req.logprob_start_len == 0
            ):  # The first token does not have logprob, pad it.
                req.input_token_logprobs = [
                    (None, req.fill_ids[0])
                ] + req.input_token_logprobs

        if req.last_update_decode_tokens != 0:
            # Some decode tokens are re-computed in an extend batch
            req.output_token_logprobs.extend(
                list(
                    zip(
                        output.input_token_logprobs[
                            pt
                            + num_input_logprobs
                            - 1
                            - req.last_update_decode_tokens : pt
                            + num_input_logprobs
                            - 1
                        ],
                        req.fill_ids[
                            len(req.fill_ids)
                            - req.last_update_decode_tokens : len(req.fill_ids)
                        ],
                    )
                )
            )

        if req.top_logprobs_num > 0:
            if req.input_top_logprobs is None:
                req.input_top_logprobs = output.input_top_logprobs[i]
                if req.logprob_start_len == 0:
                    req.input_top_logprobs = [None] + req.input_top_logprobs

            if req.last_update_decode_tokens != 0:
                req.output_top_logprobs.extend(
                    output.input_top_logprobs[i][-req.last_update_decode_tokens :]
                )
            req.output_top_logprobs.append(output.output_top_logprobs[i])

        return num_input_logprobs

    def stream_output(self, reqs: List[Req]):
        """Stream the output to detokenizer."""
        output_rids = []
        output_meta_info: List[dict] = []
        output_finished_reason: List[BaseFinishReason] = []
        if self.is_generation:
            output_vids = []
            decoded_texts = []
            output_read_ids = []
            output_read_offsets = []
            output_ids = []
            output_skip_special_tokens = []
            output_spaces_between_special_tokens = []
            output_no_stop_trim = []
        else:  # embedding or reward model
            output_embeddings = []

        is_stream_iter = self.forward_ct_decode % self.stream_interval == 0
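        # Streaming cadence (from the check below): a request's partial output is
        # pushed to the detokenizer when it finishes, and otherwise only every
        # stream_interval decode steps (or on its very first output token), which
        # keeps the IPC traffic per generated token bounded.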

        for req in reqs:
1195
            # TODO(lianmin): revisit this for overlap + retract + stream
1196
            if req.finished() or (
Lianmin Zheng's avatar
Lianmin Zheng committed
1197
                req.stream and (is_stream_iter or len(req.output_ids) == 1)
1198
1199
1200
1201
1202
1203
1204
1205
1206
            ):
                output_rids.append(req.rid)
                output_finished_reason.append(req.finished_reason)
                if self.is_generation:
                    output_vids.append(req.vid)
                    decoded_texts.append(req.decoded_text)
                    read_ids, read_offset = req.init_incremental_detokenize()
                    output_read_ids.append(read_ids)
                    output_read_offsets.append(read_offset)
                    if self.skip_tokenizer_init:
                        output_ids.append(req.output_ids)
                    output_skip_special_tokens.append(
                        req.sampling_params.skip_special_tokens
                    )
                    output_spaces_between_special_tokens.append(
                        req.sampling_params.spaces_between_special_tokens
                    )
                    output_no_stop_trim.append(req.sampling_params.no_stop_trim)

                    meta_info = {
                        "prompt_tokens": len(req.origin_input_ids),
                        "completion_tokens": len(req.output_ids),
                        "completion_tokens_wo_jump_forward": req.completion_tokens_wo_jump_forward,
                        "cached_tokens": req.cached_tokens,
                        "finish_reason": (
                            req.finished_reason.to_json()
                            if req.finished_reason is not None
                            else None
                        ),
                    }
                    if req.return_logprob:
                        (
                            meta_info["input_token_logprobs"],
                            meta_info["output_token_logprobs"],
                            meta_info["input_top_logprobs"],
                            meta_info["output_top_logprobs"],
                            meta_info["normalized_prompt_logprob"],
                        ) = (
                            req.input_token_logprobs,
                            req.output_token_logprobs,
                            req.input_top_logprobs,
                            req.output_top_logprobs,
                            req.normalized_prompt_logprob,
                        )
                    output_meta_info.append(meta_info)
                else:  # embedding or reward model
                    output_embeddings.append(req.embedding)
                    meta_info = {
                        "prompt_tokens": len(req.origin_input_ids),
                    }
                    output_meta_info.append(meta_info)

        # Send to detokenizer
        if output_rids:
            if self.is_generation:
                self.send_to_detokenizer.send_pyobj(
                    BatchTokenIDOut(
                        output_rids,
                        output_vids,
                        decoded_texts,
                        output_read_ids,
                        output_read_offsets,
                        output_ids,
                        output_skip_special_tokens,
                        output_spaces_between_special_tokens,
                        output_meta_info,
                        output_finished_reason,
                        output_no_stop_trim,
                    )
                )
            else:  # embedding or reward model
                self.send_to_detokenizer.send_pyobj(
                    BatchEmbeddingOut(
                        output_rids,
                        output_embeddings,
                        output_meta_info,
                        output_finished_reason,
                    )
                )

    def move_ready_grammar_requests(self):
        """Move requests whose grammar objects are ready from grammar_queue to waiting_queue."""
        num_ready_reqs = 0
        for req in self.grammar_queue:
            try:
                req.grammar = req.grammar.result(timeout=0.05)
                num_ready_reqs += 1
            except futures.TimeoutError:
                break

        if self.tp_size > 1:
            # Sync across TP ranks to make sure they have the same number of ready requests
            tensor = torch.tensor(num_ready_reqs, dtype=torch.int32)
            torch.distributed.all_reduce(
                tensor, op=torch.distributed.ReduceOp.MAX, group=self.tp_cpu_group
            )
            num_ready_reqs_max = tensor.item()
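            # Block on the stragglers so that every TP rank moves the same set
            # of grammar requests to the waiting queue in this round.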
            for i in range(num_ready_reqs, num_ready_reqs_max):
                self.grammar_queue[i].grammar = self.grammar_queue[i].grammar.result()
            num_ready_reqs = num_ready_reqs_max

        self.waiting_queue.extend(self.grammar_queue[:num_ready_reqs])
        self.grammar_queue = self.grammar_queue[num_ready_reqs:]

    def flush_cache(self):
        """Flush the memory pool and cache."""
        if len(self.waiting_queue) == 0 and (
            self.running_batch is None or len(self.running_batch.reqs) == 0
        ):
            self.tree_cache.reset()
            self.tree_cache_metrics = {"total": 0, "hit": 0}
            if self.grammar_backend:
                self.grammar_backend.reset()
            self.req_to_token_pool.clear()
            self.token_to_kv_pool.clear()
            torch.cuda.empty_cache()
            logger.info("Cache flushed successfully!")
            if_success = True
        else:
            logger.warning(
                f"Cache not flushed because there are pending requests. "
                f"#queue-req: {len(self.waiting_queue)}, "
                f"#running-req: {0 if self.running_batch is None else len(self.running_batch.reqs)}"
            )
            if_success = False
        return if_success
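
    # Illustrative sketch (not part of the original file): flush_cache is
    # typically reached when a FlushCacheReq arrives from the tokenizer side,
    # and the scheduler's request-dispatch loop routes it roughly like this:
    #
    #     if isinstance(recv_req, FlushCacheReq):
    #         self.flush_cache()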

    def abort_request(self, recv_req: AbortReq):
        # Delete requests in the waiting queue
        to_del = None
        for i, req in enumerate(self.waiting_queue):
            if req.rid == recv_req.rid:
                to_del = i
                break

        if to_del is not None:
            del self.waiting_queue[to_del]

        # Delete requests in the running batch
        if self.running_batch:
            for req in self.running_batch.reqs:
                if req.rid == recv_req.rid and not req.finished():
                    req.finished_reason = FINISH_ABORT()
                    self.tree_cache.cache_finished_req(req)
                    break

    def update_weights(self, recv_req: UpdateWeightReqInput):
        """In-place update of the weights."""
        success, message = self.tp_worker.update_weights(recv_req)
        if success:
            flush_cache_success = self.flush_cache()
            assert flush_cache_success, "Cache flush failed after updating weights"
        else:
            logger.error(message)
        return success, message

    def start_profile(self) -> None:
        if self.profiler is None:
            raise RuntimeError("Profiler is not enabled.")
        self.profiler.start()

    def stop_profile(self) -> None:
        if self.profiler is None:
            raise RuntimeError("Profiler is not enabled.")
        self.profiler.stop()
        self.profiler.export_chrome_trace(
            os.path.join(self.torch_profiler_trace_dir, f"{time.time()}.trace.json.gz")
        )
        logger.info("Profiler is done")

    def open_session(self, recv_req: OpenSessionReqInput) -> str:
        # Reject duplicate session ids
        session_id = recv_req.session_id
        if session_id in self.sessions:
            logger.warning(f"session id {session_id} already exists, cannot open.")
        else:
            self.sessions[session_id] = Session(
                recv_req.capacity_of_str_len, session_id
            )
        return session_id

    def close_session(self, recv_req: CloseSessionReqInput):
        # Ignore unknown session ids
        session_id = recv_req.session_id
        if session_id not in self.sessions:
            logger.warning(f"session id {session_id} does not exist, cannot delete.")
        else:
            del self.sessions[session_id]


def run_scheduler_process(
    server_args: ServerArgs,
    port_args: PortArgs,
    gpu_id: int,
    tp_rank: int,
    dp_rank: Optional[int],
    pipe_writer,
):
    # [For Router] If the env var "DP_RANK" exists, use it to set dp_rank.
    if dp_rank is None and "DP_RANK" in os.environ:
        dp_rank = int(os.environ["DP_RANK"])

    if dp_rank is None:
        configure_logger(server_args, prefix=f" TP{tp_rank}")
    else:
        configure_logger(server_args, prefix=f" DP{dp_rank} TP{tp_rank}")

    suppress_other_loggers()

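    # Create the scheduler, report readiness (and the derived KV cache capacity)
    # back to the parent process, then enter the matching event loop.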
    try:
        scheduler = Scheduler(server_args, port_args, gpu_id, tp_rank, dp_rank)
        pipe_writer.send(
            {"status": "ready", "max_total_num_tokens": scheduler.max_total_num_tokens}
        )
        if scheduler.enable_overlap:
            scheduler.event_loop_overlap()
        else:
            scheduler.event_loop_normal()
    except Exception:
        msg = get_exception_traceback()
        logger.error(msg)
        kill_parent_process()
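

# --- Illustrative usage sketch (not part of the original module) --------------
# A minimal example, under assumed conventions, of how a parent process could
# launch `run_scheduler_process` in a subprocess and wait for the "ready"
# handshake sent through `pipe_writer`. The helper name and argument defaults
# are hypothetical placeholders; SGLang's real launcher lives in its server
# entry point.
def _example_launch_scheduler(
    server_args, port_args, gpu_id=0, tp_rank=0, dp_rank=None
):
    import multiprocessing as mp

    # A "spawn" context is typically required when the child process uses CUDA.
    ctx = mp.get_context("spawn")
    reader, writer = ctx.Pipe(duplex=False)
    proc = ctx.Process(
        target=run_scheduler_process,
        args=(server_args, port_args, gpu_id, tp_rank, dp_rank, writer),
    )
    proc.start()
    ready_msg = reader.recv()  # blocks until the scheduler reports readiness
    return proc, ready_msg["max_total_num_tokens"]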