"""
Copyright 2023-2024 SGLang Team
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""

"""A scheduler that manages a tensor parallel GPU worker."""

import logging
import os
import threading
import time
import warnings
from collections import deque
from concurrent import futures
from types import SimpleNamespace
from typing import List, Optional

import torch
import zmq

from sglang.global_config import global_config
from sglang.srt.configs.model_config import ModelConfig
from sglang.srt.hf_transformers_utils import get_processor, get_tokenizer
from sglang.srt.layers.logits_processor import LogitsProcessorOutput
from sglang.srt.managers.io_struct import (
    AbortReq,
    BatchEmbeddingOut,
    BatchTokenIDOut,
    FlushCacheReq,
    GetMemPoolSizeReq,
    GetMemPoolSizeReqOutput,
    ProfileReq,
    TokenizedEmbeddingReqInput,
    TokenizedGenerateReqInput,
    UpdateWeightReqInput,
    UpdateWeightReqOutput,
)
from sglang.srt.managers.schedule_batch import (
    FINISH_ABORT,
    BaseFinishReason,
    ImageInputs,
    Req,
    ScheduleBatch,
    global_server_args_dict,
)
from sglang.srt.managers.schedule_policy import (
    AddReqResult,
    PrefillAdder,
    SchedulePolicy,
)
from sglang.srt.managers.tp_worker import TpModelWorker
from sglang.srt.managers.tp_worker_overlap_thread import TpModelWorkerClient
from sglang.srt.mem_cache.chunk_cache import ChunkCache
from sglang.srt.mem_cache.radix_cache import RadixCache
from sglang.srt.metrics.collector import SchedulerMetricsCollector, SchedulerStats
from sglang.srt.server_args import PortArgs, ServerArgs
from sglang.srt.utils import (
    broadcast_pyobj,
    configure_logger,
    crash_on_warnings,
    get_zmq_socket,
    kill_parent_process,
    set_random_seed,
    suppress_other_loggers,
)
from sglang.utils import get_exception_traceback

logger = logging.getLogger(__name__)

# Test retract decode
test_retract = os.getenv("SGLANG_TEST_RETRACT", "false") == "true"


class Scheduler:
    """A scheduler that manages a tensor parallel GPU worker."""

    def __init__(
        self,
        server_args: ServerArgs,
        port_args: PortArgs,
        gpu_id: int,
        tp_rank: int,
        dp_rank: Optional[int],
    ):
        # Parse args
        self.server_args = server_args
        self.tp_rank = tp_rank
        self.tp_size = server_args.tp_size
        self.schedule_policy = server_args.schedule_policy
        self.disable_jump_forward = server_args.disable_jump_forward
        self.lora_paths = server_args.lora_paths
        self.max_loras_per_batch = server_args.max_loras_per_batch
        self.enable_overlap = server_args.enable_overlap_schedule
        self.skip_tokenizer_init = server_args.skip_tokenizer_init
        self.enable_metrics = server_args.enable_metrics

        # Init inter-process communication
        context = zmq.Context(2)

        if self.tp_rank == 0 or self.server_args.enable_dp_attention:
            self.recv_from_tokenizer = get_zmq_socket(
                context, zmq.PULL, port_args.scheduler_input_ipc_name
            )
            self.send_to_tokenizer = get_zmq_socket(
                context, zmq.PUSH, port_args.tokenizer_ipc_name
            )

            if server_args.skip_tokenizer_init:
                # Directly send to the tokenizer/api
                self.send_to_detokenizer = get_zmq_socket(
                    context, zmq.PUSH, port_args.tokenizer_ipc_name
                )
            else:
                # Send to the detokenizer
                self.send_to_detokenizer = get_zmq_socket(
                    context, zmq.PUSH, port_args.detokenizer_ipc_name
                )
        else:
            self.recv_from_tokenizer = None
            self.send_to_tokenizer = SimpleNamespace(send_pyobj=lambda x: None)
            self.send_to_detokenizer = SimpleNamespace(send_pyobj=lambda x: None)

        # Init tokenizer
        self.model_config = ModelConfig(
            server_args.model_path,
            trust_remote_code=server_args.trust_remote_code,
            context_length=server_args.context_length,
            model_override_args=server_args.json_model_override_args,
            is_embedding=server_args.is_embedding,
        )
        self.is_generation = self.model_config.is_generation

        if server_args.skip_tokenizer_init:
            self.tokenizer = self.processor = None
        else:
            if self.model_config.is_multimodal:
                self.processor = get_processor(
                    server_args.tokenizer_path,
                    tokenizer_mode=server_args.tokenizer_mode,
                    trust_remote_code=server_args.trust_remote_code,
                )
                self.tokenizer = self.processor.tokenizer
            else:
                self.tokenizer = get_tokenizer(
                    server_args.tokenizer_path,
                    tokenizer_mode=server_args.tokenizer_mode,
                    trust_remote_code=server_args.trust_remote_code,
                )

        # Launch a tensor parallel worker
        if self.enable_overlap:
            TpWorkerClass = TpModelWorkerClient
        else:
            TpWorkerClass = TpModelWorker

        self.tp_worker = TpWorkerClass(
            server_args=server_args,
            gpu_id=gpu_id,
            tp_rank=tp_rank,
            dp_rank=dp_rank,
            nccl_port=port_args.nccl_port,
        )

        # Get token and memory info from the model worker
        (
            self.max_total_num_tokens,
            self.max_prefill_tokens,
            self.max_running_requests,
            self.max_req_len,
            self.max_req_input_len,
            self.random_seed,
            self.device,
            worker_global_server_args_dict,
            _,
            _,
            _,
        ) = self.tp_worker.get_worker_info()
        self.tp_cpu_group = self.tp_worker.get_tp_cpu_group()
        self.pad_input_ids_func = self.tp_worker.get_pad_input_ids_func()
        global_server_args_dict.update(worker_global_server_args_dict)
        set_random_seed(self.random_seed)

        # Print debug info
        logger.info(
            f"max_total_num_tokens={self.max_total_num_tokens}, "
            f"max_prefill_tokens={self.max_prefill_tokens}, "
            f"max_running_requests={self.max_running_requests}, "
            f"context_len={self.model_config.context_len}"
        )

        # Init memory pool and cache
        self.req_to_token_pool, self.token_to_kv_pool = self.tp_worker.get_memory_pool()

        if (
            server_args.chunked_prefill_size is not None
            and server_args.disable_radix_cache
        ):
            self.tree_cache = ChunkCache(
                req_to_token_pool=self.req_to_token_pool,
                token_to_kv_pool=self.token_to_kv_pool,
            )
        else:
            self.tree_cache = RadixCache(
                req_to_token_pool=self.req_to_token_pool,
                token_to_kv_pool=self.token_to_kv_pool,
                disable=server_args.disable_radix_cache,
            )
        self.tree_cache_metrics = {"total": 0, "hit": 0}
        self.policy = SchedulePolicy(self.schedule_policy, self.tree_cache)

        # Init running status
        self.waiting_queue: List[Req] = []
        self.running_batch: Optional[ScheduleBatch] = None
        self.cur_batch: Optional[ScheduleBatch] = None
        self.forward_ct = 0
        self.forward_ct_decode = 0
        self.num_generated_tokens = 0
        self.last_decode_stats_tic = time.time()
        self.stream_interval = server_args.stream_interval

        # Init chunked prefill
        self.chunked_prefill_size = server_args.chunked_prefill_size
        self.being_chunked_req = None
        self.is_mixed_chunk = (
            self.chunked_prefill_size is not None and server_args.enable_mixed_chunk
        )

        # Init the grammar backend for constrained generation
        self.grammar_queue: List[Req] = []
        if not server_args.skip_tokenizer_init:
            if server_args.grammar_backend == "outlines":
                from sglang.srt.constrained.outlines_backend import (
                    OutlinesGrammarBackend,
                )

                self.grammar_backend = OutlinesGrammarBackend(
                    self.tokenizer,
                    whitespace_pattern=server_args.constrained_json_whitespace_pattern,
                    allow_jump_forward=not server_args.disable_jump_forward,
                )
            elif server_args.grammar_backend == "xgrammar":
                from sglang.srt.constrained.xgrammar_backend import (
                    XGrammarGrammarBackend,
                )

                self.grammar_backend = XGrammarGrammarBackend(
                    self.tokenizer, vocab_size=self.model_config.vocab_size
                )
            else:
                raise ValueError(
                    f"Invalid grammar backend: {server_args.grammar_backend}"
                )
        else:
            self.grammar_backend = None

        # Init new token estimation
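        # new_token_ratio estimates how many new tokens each running request will
        # still need. It starts at init_new_token_ratio, decays toward
        # min_new_token_ratio on healthy decode steps, and is raised again when
        # decode runs out of memory and requests have to be retracted.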
        assert (
            server_args.schedule_conservativeness >= 0
        ), "Invalid schedule_conservativeness"

        self.init_new_token_ratio = min(
            global_config.default_init_new_token_ratio
            * server_args.schedule_conservativeness,
            1.0,
        )
        self.min_new_token_ratio = min(
            self.init_new_token_ratio
            * global_config.default_min_new_token_ratio_factor,
            1.0,
        )
        self.new_token_ratio_decay = (
            self.init_new_token_ratio - self.min_new_token_ratio
        ) / global_config.default_new_token_ratio_decay_steps
        self.new_token_ratio = self.init_new_token_ratio

        self.batch_is_full = False

        # Init watchdog thread
        self.watchdog_timeout = server_args.watchdog_timeout
        t = threading.Thread(target=self.watchdog_thread, daemon=True)
        t.start()

        # Init profiler
        if os.getenv("SGLANG_TORCH_PROFILER_DIR", "") == "":
            self.profiler = None
        else:
            self.torch_profiler_trace_dir = os.getenv("SGLANG_TORCH_PROFILER_DIR")
            logger.info(
                "Profiling enabled. Traces will be saved to: %s",
                self.torch_profiler_trace_dir,
            )
            self.profiler = torch.profiler.profile(
                activities=[
                    torch.profiler.ProfilerActivity.CPU,
                    torch.profiler.ProfilerActivity.CUDA,
                ],
                with_stack=True,
            )

        # Init metrics stats
        self.stats = SchedulerStats()
        if self.enable_metrics:
            self.metrics_collector = SchedulerMetricsCollector(
                labels={
                    "model_name": self.server_args.served_model_name,
                    # TODO: Add lora name/path in the future,
                },
            )
319

    def watchdog_thread(self):
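        """Abort the whole server (via kill_parent_process) if a scheduled batch
        makes no forward progress within ``watchdog_timeout`` seconds."""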
        self.watchdog_last_forward_ct = 0
        self.watchdog_last_time = time.time()

        while True:
            if self.cur_batch is not None:
                if self.watchdog_last_forward_ct == self.forward_ct:
                    if time.time() > self.watchdog_last_time + self.watchdog_timeout:
                        logger.error(f"Watchdog timeout ({self.watchdog_timeout=})")
                        break
                else:
                    self.watchdog_last_forward_ct = self.forward_ct
                    self.watchdog_last_time = time.time()
            time.sleep(self.watchdog_timeout / 2)

        kill_parent_process()

    @torch.no_grad()
    def event_loop_normal(self):
        """A normal blocking scheduler loop."""
        self.last_batch = None

        while True:
            recv_reqs = self.recv_requests()
            self.process_input_requests(recv_reqs)

            batch = self.get_next_batch_to_run()

            if self.server_args.enable_dp_attention:
                batch = self.prepare_dp_attn_batch(batch)

            self.cur_batch = batch

            if batch:
                result = self.run_batch(batch)
                self.process_batch_result(batch, result)

                # Decode multiple steps to reduce the overhead
                if batch.forward_mode.is_decode():
                    for _ in range(self.server_args.num_continuous_decode_steps - 1):
                        if not self.running_batch:
                            break
                        self.update_running_batch()
                        if not self.running_batch:
                            break
                        if self.server_args.enable_dp_attention:
                            batch = self.prepare_dp_attn_batch(batch)
                        result = self.run_batch(batch)
                        self.process_batch_result(batch, result)
            else:
                self.check_memory()
                self.new_token_ratio = self.init_new_token_ratio

            self.last_batch = batch

    @torch.no_grad()
    def event_loop_overlap(self):
        """A scheduler loop that overlaps the CPU processing and GPU computation."""
        result_queue = deque()

        self.last_batch = None
        self.running_batch = None

        while True:
            recv_reqs = self.recv_requests()
            self.process_input_requests(recv_reqs)

            batch = self.get_next_batch_to_run()
            self.cur_batch = batch
            if batch:
                # We need a stream synchronization here. Otherwise, there will be cuda illegal memory access errors.
                _ = batch.seq_lens[0].item()

                result = self.run_batch(batch)
                result_queue.append((batch.copy(), result))

            if self.last_batch:
                tmp_batch, tmp_result = result_queue.popleft()
                self.process_batch_result(tmp_batch, tmp_result)
            elif batch is None:
                self.check_memory()
                self.new_token_ratio = self.init_new_token_ratio

            self.last_batch = batch

    def prepare_dp_attn_batch(self, local_batch: ScheduleBatch):
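        """Synchronize the per-rank token counts for attention data parallelism.

        All DP attention ranks must step together, so ranks without work get an
        idle batch, and the ranks agree on whether the CUDA graph path can be
        used for this step."""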
        # Check if other DP workers have running batches
        if local_batch is None:
            num_tokens = 0
        elif local_batch.forward_mode.is_decode():
            num_tokens = local_batch.batch_size()
        else:
            num_tokens = local_batch.extend_num_tokens

        local_num_tokens = torch.tensor([num_tokens], dtype=torch.int64)
        global_num_tokens = torch.empty(self.tp_size, dtype=torch.int64)
        torch.distributed.all_gather_into_tensor(
            global_num_tokens,
            local_num_tokens,
            group=self.tp_cpu_group,
        )

        if local_batch is None and global_num_tokens.max().item() > 0:
            local_batch = self.get_idle_batch()

        if local_batch is not None:
            local_batch.global_num_tokens = global_num_tokens.tolist()

            # Check forward mode for cuda graph
            if not self.server_args.disable_cuda_graph:
                forward_mode_state = torch.tensor(
                    (
                        1
                        if local_batch.forward_mode.is_decode()
                        or local_batch.forward_mode.is_idle()
                        else 0
                    ),
                    dtype=torch.int32,
                )
                torch.distributed.all_reduce(
                    forward_mode_state,
                    op=torch.distributed.ReduceOp.MIN,
                    group=self.tp_cpu_group,
                )
                local_batch.can_run_dp_cuda_graph = forward_mode_state.item() == 1

        return local_batch

    def get_idle_batch(self):
        idle_batch = ScheduleBatch.init_new(
            [],
            self.req_to_token_pool,
            self.token_to_kv_pool,
            self.tree_cache,
            self.model_config,
        )
        idle_batch.prepare_for_idle()
        return idle_batch

    def recv_requests(self):
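        """Receive requests from the tokenizer (on rank 0, or on every rank when
        attention DP is enabled) and broadcast them to the other TP ranks."""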
        if self.tp_rank == 0 or self.server_args.enable_dp_attention:
            recv_reqs = []

            while True:
                try:
                    recv_req = self.recv_from_tokenizer.recv_pyobj(zmq.NOBLOCK)
                except zmq.ZMQError:
                    break
                recv_reqs.append(recv_req)
        else:
            recv_reqs = None

        if self.tp_size != 1 and not self.server_args.enable_dp_attention:
            recv_reqs = broadcast_pyobj(recv_reqs, self.tp_rank, self.tp_cpu_group)
        return recv_reqs

    def process_input_requests(self, recv_reqs: List):
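        """Dispatch each received request object to its corresponding handler."""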
        for recv_req in recv_reqs:
            if isinstance(recv_req, TokenizedGenerateReqInput):
                self.handle_generate_request(recv_req)
            elif isinstance(recv_req, TokenizedEmbeddingReqInput):
                self.handle_embedding_request(recv_req)
            elif isinstance(recv_req, FlushCacheReq):
                self.flush_cache()
            elif isinstance(recv_req, AbortReq):
                self.abort_request(recv_req)
            elif isinstance(recv_req, UpdateWeightReqInput):
                success, message = self.update_weights(recv_req)
                self.send_to_tokenizer.send_pyobj(
                    UpdateWeightReqOutput(success, message)
                )
            elif isinstance(recv_req, ProfileReq):
                if recv_req == ProfileReq.START_PROFILE:
                    self.start_profile()
                else:
                    self.stop_profile()
            elif isinstance(recv_req, GetMemPoolSizeReq):
                self.send_to_tokenizer.send_pyobj(
                    GetMemPoolSizeReqOutput(self.max_total_num_tokens)
                )
            else:
                raise ValueError(f"Invalid request: {recv_req}")

    def handle_generate_request(
        self,
        recv_req: TokenizedGenerateReqInput,
    ):
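        """Wrap a tokenized generate request into a ``Req``, attach image inputs,
        logprob options, and an optional grammar, then queue it for scheduling."""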
        req = Req(
            recv_req.rid,
            recv_req.input_text,
            recv_req.input_ids,
            recv_req.sampling_params,
            lora_path=recv_req.lora_path,
        )
        req.tokenizer = self.tokenizer

        # Image inputs
        if recv_req.image_inputs is not None:
            req.image_inputs = ImageInputs.from_dict(
                recv_req.image_inputs, self.model_config.vocab_size
            )
            req.origin_input_ids = self.pad_input_ids_func(
                req.origin_input_ids_unpadded, req.image_inputs
            )

        req.return_logprob = recv_req.return_logprob
        req.top_logprobs_num = recv_req.top_logprobs_num
        req.stream = recv_req.stream
        req.logprob_start_len = recv_req.logprob_start_len

        if req.logprob_start_len == -1:
            # By default, only return the logprobs for output tokens
            req.logprob_start_len = len(recv_req.input_ids) - 1

        # Truncate prompts that are too long
        if len(req.origin_input_ids) > self.max_req_input_len:
            logger.warning(
                "Request length is longer than the KV cache pool size or "
                "the max context length. Truncated!!!"
            )
            req.origin_input_ids = req.origin_input_ids[: self.max_req_input_len]

        req.sampling_params.max_new_tokens = min(
            (
                req.sampling_params.max_new_tokens
                if req.sampling_params.max_new_tokens is not None
                else 1 << 30
            ),
            self.max_req_len - len(req.origin_input_ids) - 1,
        )

        # Init grammar cache for this request
        add_to_grammar_queue = False
        if (
            req.sampling_params.json_schema is not None
            or req.sampling_params.regex is not None
        ):
            assert self.grammar_backend is not None
            if req.sampling_params.json_schema is not None:
                key = ("json", req.sampling_params.json_schema)
            elif req.sampling_params.regex is not None:
                key = ("regex", req.sampling_params.regex)

            req.grammar = self.grammar_backend.get_cached_value(key)
            if not req.grammar:
                req.grammar = self.grammar_backend.get_future_value(key)
                add_to_grammar_queue = True

        if add_to_grammar_queue:
            self.grammar_queue.append(req)
        else:
            self.waiting_queue.append(req)

    def handle_embedding_request(
        self,
        recv_req: TokenizedEmbeddingReqInput,
    ):
        req = Req(
            recv_req.rid,
            recv_req.input_text,
            recv_req.input_ids,
            recv_req.sampling_params,
        )
        req.tokenizer = self.tokenizer

        # Truncate prompts that are too long
        if len(req.origin_input_ids) >= self.max_req_input_len:
            logger.warning(
                "Request length is longer than the KV cache pool size or "
                "the max context length. Truncated!!!"
            )
            req.origin_input_ids = req.origin_input_ids[: self.max_req_input_len]

        self.waiting_queue.append(req)

    def log_prefill_stats(self, adder, can_run_list, running_bs, has_inflight):
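        """Log statistics for a newly formed prefill batch and, if metrics are
        enabled, export them to the metrics collector."""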
        if isinstance(self.tree_cache, RadixCache):
            self.tree_cache_metrics["total"] += (
                adder.log_input_tokens + adder.log_hit_tokens
            ) / 10**9
            self.tree_cache_metrics["hit"] += (adder.log_hit_tokens) / 10**9
            tree_cache_hit_rate = (
                self.tree_cache_metrics["hit"] / self.tree_cache_metrics["total"]
            )
        else:
            tree_cache_hit_rate = 0.0

        num_used = self.max_total_num_tokens - (
            self.token_to_kv_pool.available_size() + self.tree_cache.evictable_size()
        )

        logger.info(
            f"Prefill batch. "
            f"#new-seq: {len(can_run_list)}, "
            f"#new-token: {adder.log_input_tokens}, "
            f"#cached-token: {adder.log_hit_tokens}, "
            f"cache hit rate: {100.0 * tree_cache_hit_rate:.2f}%, "
            f"token usage: {num_used / self.max_total_num_tokens:.2f}, "
            f"#running-req: {running_bs}, "
            f"#queue-req: {len(self.waiting_queue) + has_inflight}"
        )

        if self.enable_metrics:
            self.stats.num_running_reqs = running_bs
            self.stats.num_used_tokens = num_used
            self.stats.token_usage = round(num_used / self.max_total_num_tokens, 2)
            self.stats.num_queue_reqs = len(self.waiting_queue) + has_inflight
            self.stats.cache_hit_rate = tree_cache_hit_rate
            self.metrics_collector.log_stats(self.stats)

    def log_decode_stats(self):
        num_used = self.max_total_num_tokens - (
            self.token_to_kv_pool.available_size() + self.tree_cache.evictable_size()
        )
        gen_throughput = self.num_generated_tokens / (
            time.time() - self.last_decode_stats_tic
        )
        self.num_generated_tokens = 0
        self.last_decode_stats_tic = time.time()
        num_running_reqs = len(self.running_batch.reqs) if self.running_batch else 0
        logger.info(
            f"Decode batch. "
            f"#running-req: {num_running_reqs}, "
            f"#token: {num_used}, "
            f"token usage: {num_used / self.max_total_num_tokens:.2f}, "
            f"gen throughput (token/s): {gen_throughput:.2f}, "
            f"#queue-req: {len(self.waiting_queue)}"
        )

        if self.enable_metrics:
            self.stats.num_running_reqs = num_running_reqs
            self.stats.num_used_tokens = num_used
            self.stats.token_usage = num_used / self.max_total_num_tokens
            self.stats.gen_throughput = gen_throughput
            self.stats.num_queue_reqs = len(self.waiting_queue)
            self.metrics_collector.log_stats(self.stats)

    def check_memory(self):
        available_size = (
            self.token_to_kv_pool.available_size() + self.tree_cache.evictable_size()
        )
        if available_size != self.max_total_num_tokens:
            msg = (
                "KV cache pool leak detected! "
                f"{available_size=}, {self.max_total_num_tokens=}\n"
            )
            warnings.warn(msg)
            if crash_on_warnings():
                raise ValueError(msg)

        if len(self.req_to_token_pool.free_slots) != self.req_to_token_pool.size:
            msg = (
                "Memory pool leak detected! "
                f"available_size={len(self.req_to_token_pool.free_slots)}, "
                f"total_size={self.req_to_token_pool.size}\n"
            )
            warnings.warn(msg)
            if crash_on_warnings():
                raise ValueError(msg)

    def get_next_batch_to_run(self):
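        """Choose the batch for the next forward pass: prefer building a new
        prefill batch; otherwise continue decoding the running batch."""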
        # Merge the prefill batch into the running batch
        if (
            self.last_batch
            and not self.last_batch.forward_mode.is_decode()
            and not self.last_batch.is_empty()
        ):
            if self.being_chunked_req:
                self.last_batch.filter_batch(being_chunked_req=self.being_chunked_req)
                self.tree_cache.cache_unfinished_req(self.being_chunked_req)
                # Inflight request keeps its rid but will get a new req_pool_idx.
                self.req_to_token_pool.free(self.being_chunked_req.req_pool_idx)
                self.batch_is_full = False
            if not self.last_batch.is_empty():
                if self.running_batch is None:
                    self.running_batch = self.last_batch
                else:
                    self.running_batch.merge_batch(self.last_batch)

        # Prefill first
        new_batch = self.get_new_batch_prefill()
        if new_batch is not None:
            return new_batch

        # Check memory
        if self.running_batch is None:
            return

        # Run decode
        before_bs = self.running_batch.batch_size()
        self.update_running_batch()
        if not self.running_batch:
            self.batch_is_full = False
            return None
        if before_bs != self.running_batch.batch_size():
            self.batch_is_full = False
        return self.running_batch

    def get_new_batch_prefill(self) -> Optional[ScheduleBatch]:
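        """Try to build a new prefill batch from the waiting queue, respecting
        the token budget, ``max_running_requests``, and the per-batch LoRA limit."""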
        # Check if the grammar is ready in the grammar queue
        if self.grammar_queue:
            self.move_ready_grammar_requests()

        # Handle the cases where prefill is not allowed
        if (
            self.batch_is_full or len(self.waiting_queue) == 0
        ) and self.being_chunked_req is None:
            return None

        running_bs = len(self.running_batch.reqs) if self.running_batch else 0
        if running_bs >= self.max_running_requests:
            self.batch_is_full = True
            return None

        # Get priority queue
        prefix_computed = self.policy.calc_priority(self.waiting_queue)

        # Prefill policy
        adder = PrefillAdder(
            self.tree_cache,
            self.running_batch,
            self.new_token_ratio,
            self.token_to_kv_pool.available_size() + self.tree_cache.evictable_size(),
            self.max_prefill_tokens,
            self.chunked_prefill_size,
            running_bs if self.is_mixed_chunk else 0,
        )

        has_inflight = self.being_chunked_req is not None
        if has_inflight:
            self.being_chunked_req.init_next_round_input()
            self.being_chunked_req = adder.add_inflight_req(self.being_chunked_req)

        if self.lora_paths:
            lora_set = (
                set([req.lora_path for req in self.running_batch.reqs])
                if self.running_batch is not None
                else set([])
            )

        # Get requests from the waiting queue to a new prefill batch
        for req in self.waiting_queue:
            if (
                self.lora_paths
                and len(
                    lora_set
                    | set([req.lora_path for req in adder.can_run_list])
                    | set([req.lora_path])
                )
                > self.max_loras_per_batch
            ):
                self.batch_is_full = True
                break

            if running_bs + len(adder.can_run_list) >= self.max_running_requests:
                self.batch_is_full = True
                break

            req.init_next_round_input(None if prefix_computed else self.tree_cache)
            res = adder.add_one_req(req)
            if res != AddReqResult.CONTINUE:
                if res == AddReqResult.NO_TOKEN:
                    self.batch_is_full = True
                break

        # Update waiting queue
        can_run_list = adder.can_run_list
        if len(can_run_list) == 0:
            return None
        self.waiting_queue = [
            x for x in self.waiting_queue if x not in set(can_run_list)
        ]

        if adder.new_inflight_req is not None:
            assert self.being_chunked_req is None
            self.being_chunked_req = adder.new_inflight_req

        if self.being_chunked_req:
            self.being_chunked_req.is_being_chunked += 1

        # Print stats
        if self.tp_rank == 0:
            self.log_prefill_stats(adder, can_run_list, running_bs, has_inflight)

        # Create a new batch
        new_batch = ScheduleBatch.init_new(
            can_run_list,
            self.req_to_token_pool,
            self.token_to_kv_pool,
            self.tree_cache,
            self.model_config,
        )
        new_batch.prepare_for_extend()

        # Mixed-style chunked prefill
        if self.is_mixed_chunk and self.running_batch is not None:
            self.running_batch.filter_batch()
            if not self.running_batch.is_empty():
                self.running_batch.prepare_for_decode(self.enable_overlap)
                new_batch.mix_with_running(self.running_batch)
                new_batch.decoding_reqs = self.running_batch.reqs
            self.running_batch = None
        else:
            new_batch.decoding_reqs = None

        return new_batch

    def update_running_batch(self):
        """Update the current running decoding batch."""
        global test_retract
        batch = self.running_batch

        batch.filter_batch()
        if batch.is_empty():
            self.running_batch = None
            return

        # Check if decode out of memory
        if not batch.check_decode_mem() or (test_retract and batch.batch_size() > 10):
            old_ratio = self.new_token_ratio

            retracted_reqs, new_token_ratio = batch.retract_decode()
            self.new_token_ratio = new_token_ratio

            logger.info(
                "Decode out of memory happened. "
                f"#retracted_reqs: {len(retracted_reqs)}, "
                f"#new_token_ratio: {old_ratio:.4f} -> {self.new_token_ratio:.4f}"
            )
            self.waiting_queue.extend(retracted_reqs)
        else:
            self.new_token_ratio = max(
                self.new_token_ratio - self.new_token_ratio_decay,
                self.min_new_token_ratio,
            )

        # Check for jump-forward
        if not self.disable_jump_forward:
            jump_forward_reqs = batch.check_for_jump_forward(self.pad_input_ids_func)
            self.waiting_queue.extend(jump_forward_reqs)
            if batch.is_empty():
                self.running_batch = None
                return

        # Update batch tensors
        batch.prepare_for_decode(self.enable_overlap)

    def run_batch(self, batch: ScheduleBatch):
        """Run a batch."""
        self.forward_ct += 1

        if self.is_generation:
            model_worker_batch = batch.get_model_worker_batch()
            if batch.forward_mode.is_decode() or batch.extend_num_tokens != 0:
                logits_output, next_token_ids = self.tp_worker.forward_batch_generation(
                    model_worker_batch
                )
            elif batch.forward_mode.is_idle():
                model_worker_batch = batch.get_model_worker_batch()
                self.tp_worker.forward_batch_idle(model_worker_batch)
                return
            else:
                logits_output = None
                if self.skip_tokenizer_init:
                    next_token_ids = torch.full(
                        (batch.batch_size(),), self.tokenizer.eos_token_id
                    )
                else:
                    next_token_ids = torch.full((batch.batch_size(),), 0)
            batch.output_ids = next_token_ids
            ret = logits_output, next_token_ids, model_worker_batch.bid
        else:  # embedding or reward model
            assert batch.extend_num_tokens != 0
            model_worker_batch = batch.get_model_worker_batch()
            embeddings = self.tp_worker.forward_batch_embedding(model_worker_batch)
            ret = embeddings, model_worker_batch.bid
        return ret

    def process_batch_result(self, batch: ScheduleBatch, result):
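        """Route a batch result to the decode or prefill post-processing path."""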
        if batch.forward_mode.is_idle():
            return
        if batch.forward_mode.is_decode():
            self.process_batch_result_decode(batch, result)
            if batch.is_empty():
                self.running_batch = None
        else:
            self.process_batch_result_prefill(batch, result)

    def process_batch_result_prefill(self, batch: ScheduleBatch, result):

        if self.is_generation:
            logits_output, next_token_ids, bid = result

            if self.enable_overlap:
                logits_output, next_token_ids = self.tp_worker.resolve_batch_result(bid)
            else:
                # Move next_token_ids and logprobs to cpu
                if batch.return_logprob:
                    logits_output.next_token_logprobs = (
                        logits_output.next_token_logprobs[
                            torch.arange(len(next_token_ids), device=self.device),
                            next_token_ids,
                        ].tolist()
                    )
                    logits_output.input_token_logprobs = (
                        logits_output.input_token_logprobs.tolist()
                    )
                    logits_output.normalized_prompt_logprobs = (
                        logits_output.normalized_prompt_logprobs.tolist()
                    )
                next_token_ids = next_token_ids.tolist()

            # Check finish conditions
            logprob_pt = 0
            for i, req in enumerate(batch.reqs):
                if req.is_retracted:
                    continue

                if req.is_being_chunked <= 0:
                    # Inflight reqs' prefill is not finished
                    req.completion_tokens_wo_jump_forward += 1
                    req.output_ids.append(next_token_ids[i])
                    req.check_finished()

                    if req.finished():
                        self.tree_cache.cache_finished_req(req)
                    elif not batch.decoding_reqs or req not in batch.decoding_reqs:
                        self.tree_cache.cache_unfinished_req(req)

                    if req.grammar is not None:
                        req.grammar.accept_token(next_token_ids[i])

                    if req.return_logprob:
                        logprob_pt += self.add_logprob_return_values(
                            i, req, logprob_pt, next_token_ids, logits_output
                        )
                else:
                    req.is_being_chunked -= 1

        else:  # embedding or reward model
            embeddings, bid = result
            embeddings = embeddings.tolist()

            # Check finish conditions
            for i, req in enumerate(batch.reqs):
                if req.is_retracted:
                    continue

                req.embedding = embeddings[i]
                if req.is_being_chunked > 0:
                    req.is_being_chunked -= 1
                else:
                    # Inflight reqs' prefill is not finished
                    # dummy output token for embedding models
                    req.output_ids.append(0)
                    req.check_finished()

                if req.finished():
                    self.tree_cache.cache_finished_req(req)
                else:
                    self.tree_cache.cache_unfinished_req(req)

        self.stream_output(batch.reqs)

    def process_batch_result_decode(self, batch: ScheduleBatch, result):
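        """Append the sampled tokens to their requests, cache finished requests,
        and periodically log decode statistics."""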
        logits_output, next_token_ids, bid = result
        self.num_generated_tokens += len(batch.reqs)

        if self.enable_overlap:
            logits_output, next_token_ids = self.tp_worker.resolve_batch_result(bid)
            next_token_logprobs = logits_output.next_token_logprobs
        else:
            # Move next_token_ids and logprobs to cpu
            if batch.return_logprob:
                next_token_logprobs = logits_output.next_token_logprobs[
                    torch.arange(len(next_token_ids), device=self.device),
                    next_token_ids,
                ].tolist()
            next_token_ids = next_token_ids.tolist()

        self.token_to_kv_pool.free_group_begin()

        # Check finish condition
        for i, (req, next_token_id) in enumerate(zip(batch.reqs, next_token_ids)):
            if req.is_retracted:
                continue

            if self.enable_overlap and req.finished():
                self.token_to_kv_pool.free(batch.out_cache_loc[i : i + 1])
                continue

            req.completion_tokens_wo_jump_forward += 1
            req.output_ids.append(next_token_id)
            req.check_finished()

            if req.grammar is not None:
                req.grammar.accept_token(next_token_id)

            if req.finished():
                self.tree_cache.cache_finished_req(req)

            if req.return_logprob:
                req.output_token_logprobs.append(
                    (next_token_logprobs[i], next_token_id)
                )
                if req.top_logprobs_num > 0:
                    req.output_top_logprobs.append(logits_output.output_top_logprobs[i])

        self.stream_output(batch.reqs)

        self.token_to_kv_pool.free_group_end()

        self.forward_ct_decode = (self.forward_ct_decode + 1) % (1 << 30)
        if (
            self.tp_rank == 0
            and self.forward_ct_decode % self.server_args.decode_log_interval == 0
        ):
            self.log_decode_stats()

    def add_logprob_return_values(
        self,
        i: int,
        req: Req,
        pt: int,
        next_token_ids: List[int],
        output: LogitsProcessorOutput,
    ):
        """Attach logprobs to the return values."""
        req.output_token_logprobs.append(
            (output.next_token_logprobs[i], next_token_ids[i])
        )

        # If logprob_start_len > 0, then first logprob_start_len prompt tokens will be ignored.
        num_input_logprobs = req.extend_input_len - req.extend_logprob_start_len

        if req.normalized_prompt_logprob is None:
            req.normalized_prompt_logprob = output.normalized_prompt_logprobs[i]

        if req.input_token_logprobs is None:
            input_token_logprobs = output.input_token_logprobs[
                pt : pt + num_input_logprobs - 1 - req.last_update_decode_tokens
            ]
            input_token_ids = req.fill_ids[
                len(req.fill_ids)
                - num_input_logprobs
                + 1 : len(req.fill_ids)
                - req.last_update_decode_tokens
            ]
            req.input_token_logprobs = list(zip(input_token_logprobs, input_token_ids))

            if (
                req.logprob_start_len == 0
            ):  # The first token does not have logprob, pad it.
                req.input_token_logprobs = [
                    (None, req.fill_ids[0])
                ] + req.input_token_logprobs

        if req.last_update_decode_tokens != 0:
            # Some decode tokens are re-computed in an extend batch
            req.output_token_logprobs.extend(
                list(
                    zip(
                        output.input_token_logprobs[
                            pt
                            + num_input_logprobs
                            - 1
                            - req.last_update_decode_tokens : pt
                            + num_input_logprobs
                            - 1
                        ],
                        req.fill_ids[
                            len(req.fill_ids)
                            - req.last_update_decode_tokens : len(req.fill_ids)
                        ],
                    )
                )
            )

        if req.top_logprobs_num > 0:
            if req.input_top_logprobs is None:
                req.input_top_logprobs = output.input_top_logprobs[i]
                if req.logprob_start_len == 0:
                    req.input_top_logprobs = [None] + req.input_top_logprobs

            if req.last_update_decode_tokens != 0:
                req.output_top_logprobs.extend(
                    output.input_top_logprobs[i][-req.last_update_decode_tokens :]
                )
            req.output_top_logprobs.append(output.output_top_logprobs[i])

        return num_input_logprobs

    def stream_output(self, reqs: List[Req]):
        """Stream the output to detokenizer."""
        output_rids = []
        output_meta_info: List[dict] = []
        output_finished_reason: List[BaseFinishReason] = []
        if self.is_generation:
            output_vids = []
            decoded_texts = []
            output_read_ids = []
            output_read_offsets = []
            output_ids = []
            output_skip_special_tokens = []
            output_spaces_between_special_tokens = []
            output_no_stop_trim = []
        else:  # embedding or reward model
            output_embeddings = []

        is_stream_iter = self.forward_ct_decode % self.stream_interval == 0

        for req in reqs:
            # TODO(lianmin): revisit this for overlap + retract + stream
            if req.finished() or (
                req.stream and (is_stream_iter or len(req.output_ids) == 1)
            ):
                output_rids.append(req.rid)
                output_finished_reason.append(req.finished_reason)
                if self.is_generation:
                    output_vids.append(req.vid)
                    decoded_texts.append(req.decoded_text)
                    read_ids, read_offset = req.init_incremental_detokenize()
                    output_read_ids.append(read_ids)
                    output_read_offsets.append(read_offset)
                    if self.skip_tokenizer_init:
                        output_ids.append(req.output_ids)
                    output_skip_special_tokens.append(
                        req.sampling_params.skip_special_tokens
                    )
                    output_spaces_between_special_tokens.append(
                        req.sampling_params.spaces_between_special_tokens
                    )
                    output_no_stop_trim.append(req.sampling_params.no_stop_trim)

                    meta_info = {
                        "prompt_tokens": len(req.origin_input_ids),
                        "completion_tokens": len(req.output_ids),
                        "completion_tokens_wo_jump_forward": req.completion_tokens_wo_jump_forward,
                        "cached_tokens": req.cached_tokens,
                        "finish_reason": (
                            req.finished_reason.to_json()
                            if req.finished_reason is not None
                            else None
                        ),
                    }
                    if req.return_logprob:
                        (
                            meta_info["input_token_logprobs"],
                            meta_info["output_token_logprobs"],
                            meta_info["input_top_logprobs"],
                            meta_info["output_top_logprobs"],
                            meta_info["normalized_prompt_logprob"],
                        ) = (
                            req.input_token_logprobs,
                            req.output_token_logprobs,
                            req.input_top_logprobs,
                            req.output_top_logprobs,
                            req.normalized_prompt_logprob,
                        )
                    output_meta_info.append(meta_info)
                else:  # embedding or reward model
                    output_embeddings.append(req.embedding)
                    meta_info = {
                        "prompt_tokens": len(req.origin_input_ids),
                    }
                    output_meta_info.append(meta_info)

        # Send to detokenizer
        if output_rids:
            if self.is_generation:
                self.send_to_detokenizer.send_pyobj(
                    BatchTokenIDOut(
                        output_rids,
                        output_vids,
                        decoded_texts,
                        output_read_ids,
                        output_read_offsets,
                        output_ids,
                        output_skip_special_tokens,
                        output_spaces_between_special_tokens,
                        output_meta_info,
                        output_finished_reason,
                        output_no_stop_trim,
                    )
                )
            else:  # embedding or reward model
                self.send_to_detokenizer.send_pyobj(
                    BatchEmbeddingOut(
                        output_rids,
                        output_embeddings,
                        output_meta_info,
                        output_finished_reason,
                    )
                )

    def move_ready_grammar_requests(self):
        """Move requests whose grammar objects are ready from grammar_queue to waiting_queue."""
        num_ready_reqs = 0
        for req in self.grammar_queue:
            try:
                req.grammar = req.grammar.result(timeout=0.05)
                num_ready_reqs += 1
            except futures._base.TimeoutError:
                break

        if self.tp_size > 1:
            # Sync across TP ranks to make sure they have the same number of ready requests
            tensor = torch.tensor(num_ready_reqs, dtype=torch.int32)
            torch.distributed.all_reduce(
                tensor, op=torch.distributed.ReduceOp.MAX, group=self.tp_cpu_group
            )
            num_ready_reqs_max = tensor.item()
            for i in range(num_ready_reqs, num_ready_reqs_max):
                self.grammar_queue[i].grammar = self.grammar_queue[i].grammar.result()
            num_ready_reqs = num_ready_reqs_max

        self.waiting_queue.extend(self.grammar_queue[:num_ready_reqs])
        self.grammar_queue = self.grammar_queue[num_ready_reqs:]

    def flush_cache(self):
        """Flush the memory pool and cache."""
        if len(self.waiting_queue) == 0 and (
            self.running_batch is None or len(self.running_batch.reqs) == 0
        ):
            self.tree_cache.reset()
            self.tree_cache_metrics = {"total": 0, "hit": 0}
            if self.grammar_backend:
                self.grammar_backend.reset()
            self.req_to_token_pool.clear()
            self.token_to_kv_pool.clear()
            torch.cuda.empty_cache()
            logger.info("Cache flushed successfully!")
            if_success = True
        else:
            logging.warning(
                f"Cache not flushed because there are pending requests. "
                f"#queue-req: {len(self.waiting_queue)}, "
                f"#running-req: {0 if self.running_batch is None else len(self.running_batch.reqs)}"
            )
            if_success = False
        return if_success

    def abort_request(self, recv_req: AbortReq):
        # Delete requests in the waiting queue
        to_del = None
        for i, req in enumerate(self.waiting_queue):
            if req.rid == recv_req.rid:
                to_del = i
                break

        if to_del is not None:
            del self.waiting_queue[to_del]

        # Delete requests in the running batch
        if self.running_batch:
            for req in self.running_batch.reqs:
                if req.rid == recv_req.rid and not req.finished():
                    req.finished_reason = FINISH_ABORT()
                    self.tree_cache.cache_finished_req(req)
                    break

    def update_weights(self, recv_req: UpdateWeightReqInput):
        """In-place update of the weights."""
        success, message = self.tp_worker.update_weights(recv_req)
        if success:
            flash_cache_success = self.flush_cache()
            assert flash_cache_success, "Cache flush failed after updating weights"
        else:
            logger.error(message)
        return success, message

    def start_profile(self) -> None:
        if self.profiler is None:
            raise RuntimeError("Profiler is not enabled.")
        self.profiler.start()

    def stop_profile(self) -> None:
        if self.profiler is None:
            raise RuntimeError("Profiler is not enabled.")
        self.profiler.stop()
        self.profiler.export_chrome_trace(
            self.torch_profiler_trace_dir + "/" + str(time.time()) + ".trace.json.gz"
        )
        logger.info("Profiler is done")


def run_scheduler_process(
    server_args: ServerArgs,
    port_args: PortArgs,
    gpu_id: int,
    tp_rank: int,
    dp_rank: Optional[int],
    pipe_writer,
):
    if dp_rank is None:
        configure_logger(server_args, prefix=f" TP{tp_rank}")
    else:
        configure_logger(server_args, prefix=f" DP{dp_rank} TP{tp_rank}")

    suppress_other_loggers()

    try:
        scheduler = Scheduler(server_args, port_args, gpu_id, tp_rank, dp_rank)
        pipe_writer.send("ready")
        if scheduler.enable_overlap:
            scheduler.event_loop_overlap()
        else:
            scheduler.event_loop_normal()
    except Exception:
        msg = get_exception_traceback()
        logger.error(msg)
        kill_parent_process()
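

# A minimal launch sketch (illustrative only; the real entry point that builds
# ServerArgs/PortArgs and spawns this process lives elsewhere in the server code
# and is not shown here). It only demonstrates the calling convention of
# run_scheduler_process, using a multiprocessing pipe to wait for the "ready"
# signal sent above:
#
#   import multiprocessing as mp
#
#   reader, writer = mp.Pipe(duplex=False)
#   proc = mp.Process(
#       target=run_scheduler_process,
#       args=(server_args, port_args, 0, 0, None, writer),  # gpu_id=0, tp_rank=0, dp_rank=None
#   )
#   proc.start()
#   reader.recv()  # blocks until the scheduler reports "ready"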