"""
Copyright 2023-2024 SGLang Team
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""

"""A scheduler that manages a tensor parallel GPU worker."""

import json
import logging
import os
import time
import warnings
from collections import deque
from types import SimpleNamespace
from typing import List, Optional, Union

import torch
import zmq

from sglang.global_config import global_config
from sglang.srt.configs.model_config import ModelConfig
from sglang.srt.constrained.grammar import GrammarCache
from sglang.srt.hf_transformers_utils import get_processor, get_tokenizer
from sglang.srt.layers.logits_processor import LogitsProcessorOutput
from sglang.srt.managers.io_struct import (
    AbortReq,
    BatchEmbeddingOut,
    BatchTokenIDOut,
    FlushCacheReq,
    GetMemPoolSizeReq,
    GetMemPoolSizeReqOutput,
    ProfileReq,
    TokenizedEmbeddingReqInput,
    TokenizedGenerateReqInput,
    TokenizedRewardReqInput,
    UpdateWeightReqInput,
    UpdateWeightReqOutput,
)
from sglang.srt.managers.schedule_batch import (
    FINISH_ABORT,
    BaseFinishReason,
    ImageInputs,
    Req,
    ScheduleBatch,
    global_server_args_dict,
)
from sglang.srt.managers.schedule_policy import (
    AddReqResult,
    PrefillAdder,
    SchedulePolicy,
)
from sglang.srt.managers.tp_worker import TpModelWorker
from sglang.srt.managers.tp_worker_overlap_thread import TpModelWorkerClient
from sglang.srt.mem_cache.chunk_cache import ChunkCache
from sglang.srt.mem_cache.radix_cache import RadixCache
from sglang.srt.server_args import PortArgs, ServerArgs
from sglang.srt.utils import (
    broadcast_pyobj,
    configure_logger,
    is_generation_model,
    is_multimodal_model,
    kill_parent_process,
    set_random_seed,
    suppress_other_loggers,
)
from sglang.utils import get_exception_traceback

logger = logging.getLogger(__name__)

# Crash on warning if we are running CI tests
crash_on_warning = os.getenv("SGLANG_IS_IN_CI", "false") == "true"

# Test retract decode
test_retract = os.getenv("SGLANG_TEST_RETRACT", "false") == "true"


class Scheduler:
    """A scheduler that manages a tensor parallel GPU worker."""

    def __init__(
        self,
        server_args: ServerArgs,
        port_args: PortArgs,
        gpu_id: int,
        tp_rank: int,
        dp_rank: Optional[int],
    ):
        # Parse args
        self.server_args = server_args
        self.tp_rank = tp_rank
        self.tp_size = server_args.tp_size
        self.schedule_policy = server_args.schedule_policy
        self.disable_regex_jump_forward = server_args.disable_regex_jump_forward
        self.lora_paths = server_args.lora_paths
        self.max_loras_per_batch = server_args.max_loras_per_batch
        self.enable_overlap = server_args.enable_overlap_schedule

        # Init inter-process communication
        context = zmq.Context(2)

        if self.tp_rank == 0:
            self.recv_from_tokenizer = context.socket(zmq.PULL)
            self.recv_from_tokenizer.bind(f"ipc://{port_args.scheduler_input_ipc_name}")

            self.send_to_detokenizer = context.socket(zmq.PUSH)
            self.send_to_detokenizer.connect(f"ipc://{port_args.detokenizer_ipc_name}")
        else:
            self.recv_from_tokenizer = None
            self.send_to_detokenizer = SimpleNamespace(send_pyobj=lambda x: None)

        # Init tokenizer
        self.model_config = ModelConfig(
            server_args.model_path,
            server_args.trust_remote_code,
            context_length=server_args.context_length,
            model_override_args=json.loads(server_args.json_model_override_args),
        )

        if server_args.skip_tokenizer_init:
            self.tokenizer = self.processor = None
        else:
            if is_multimodal_model(self.model_config.hf_config.architectures):
                self.processor = get_processor(
                    server_args.tokenizer_path,
                    tokenizer_mode=server_args.tokenizer_mode,
                    trust_remote_code=server_args.trust_remote_code,
                )
                self.tokenizer = self.processor.tokenizer
            else:
                self.tokenizer = get_tokenizer(
                    server_args.tokenizer_path,
                    tokenizer_mode=server_args.tokenizer_mode,
                    trust_remote_code=server_args.trust_remote_code,
                )
        self.is_generation = is_generation_model(
            self.model_config.hf_config.architectures, self.server_args.is_embedding
        )

        # Launch a tensor parallel worker
        if self.enable_overlap:
            TpWorkerClass = TpModelWorkerClient
        else:
            TpWorkerClass = TpModelWorker

        self.tp_worker = TpWorkerClass(
            server_args=server_args,
            gpu_id=gpu_id,
            tp_rank=tp_rank,
            dp_rank=dp_rank,
            nccl_port=port_args.nccl_port,
        )

        # Get token and memory info from the model worker
        (
            self.max_total_num_tokens,
            self.max_prefill_tokens,
            self.max_running_requests,
            self.max_req_len,
            self.max_req_input_len,
            self.random_seed,
            self.device,
            worker_global_server_args_dict,
            _,
            _,
            _,
        ) = self.tp_worker.get_worker_info()
        self.tp_cpu_group = self.tp_worker.get_tp_cpu_group()
        self.pad_input_ids_func = self.tp_worker.get_pad_input_ids_func()
        global_server_args_dict.update(worker_global_server_args_dict)
        set_random_seed(self.random_seed)

        # Print debug info
        logger.info(
            f"max_total_num_tokens={self.max_total_num_tokens}, "
            f"max_prefill_tokens={self.max_prefill_tokens}, "
            f"max_running_requests={self.max_running_requests}, "
            f"context_len={self.model_config.context_len}"
        )

        # Init memory pool and cache
        self.req_to_token_pool, self.token_to_kv_pool = self.tp_worker.get_memory_pool()

        if (
            server_args.chunked_prefill_size is not None
            and server_args.disable_radix_cache
        ):
            self.tree_cache = ChunkCache(
                req_to_token_pool=self.req_to_token_pool,
                token_to_kv_pool=self.token_to_kv_pool,
            )
        else:
            self.tree_cache = RadixCache(
                req_to_token_pool=self.req_to_token_pool,
                token_to_kv_pool=self.token_to_kv_pool,
                disable=server_args.disable_radix_cache,
            )
        self.tree_cache_metrics = {"total": 0, "hit": 0}
        self.policy = SchedulePolicy(self.schedule_policy, self.tree_cache)

        # Init running status
        self.waiting_queue: List[Req] = []
        self.running_batch: Optional[ScheduleBatch] = None
        self.cur_batch: Optional[ScheduleBatch] = None
        self.decode_forward_ct = 0
        self.stream_interval = server_args.stream_interval
        self.num_generated_tokens = 0
        self.last_stats_tic = time.time()

        # Init chunked prefill
        self.chunked_prefill_size = server_args.chunked_prefill_size
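        # The request that is currently being prefilled in chunks; its remaining
        # input is fed to the model over multiple forward passes.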
        self.current_inflight_req = None
        self.is_mixed_chunk = (
            self.chunked_prefill_size is not None and server_args.enable_mixed_chunk
        )

        # Init the FSM cache for constrained generation
        self.grammar_cache = None

        if not server_args.skip_tokenizer_init:
            self.grammar_cache = GrammarCache(
                server_args.tokenizer_path,
                {
                    "tokenizer_mode": server_args.tokenizer_mode,
                    "trust_remote_code": server_args.trust_remote_code,
                },
                skip_tokenizer_init=server_args.skip_tokenizer_init,
                whitespace_patterns=server_args.constrained_json_whitespace_pattern,
                backend=server_args.grammar_backend,
                allow_jump=not server_args.disable_regex_jump_forward,
            )

        # Init new token estimation
        assert (
            server_args.schedule_conservativeness >= 0
        ), "Invalid schedule_conservativeness"
        self.min_new_token_ratio = min(
            global_config.base_min_new_token_ratio
            * server_args.schedule_conservativeness,
            1.0,
        )
        self.new_token_ratio = self.min_new_token_ratio
        self.new_token_ratio_decay = global_config.new_token_ratio_decay
        self.batch_is_full = False

        # Init profiler
        if os.getenv("SGLANG_TORCH_PROFILER_DIR", "") == "":
            self.profiler = None
        else:
            self.torch_profiler_trace_dir = os.getenv("SGLANG_TORCH_PROFILER_DIR")
            logger.info(
                "Profiling enabled. Traces will be saved to: %s",
                self.torch_profiler_trace_dir,
            )
            self.profiler = torch.profiler.profile(
                activities=[
                    torch.profiler.ProfilerActivity.CPU,
                    torch.profiler.ProfilerActivity.CUDA,
                ],
                with_stack=True,
            )

    @torch.inference_mode()
    def event_loop_normal(self):
        """A normal blocking scheduler loop."""
        self.last_batch = None

        while True:
            recv_reqs = self.recv_requests()
            self.process_input_requests(recv_reqs)

            batch = self.get_next_batch_to_run()

            if batch:
                result = self.run_batch(batch)
                self.process_batch_result(batch, result)

                # Decode multiple steps to reduce the overhead
                if batch.forward_mode.is_decode():
                    for _ in range(self.server_args.num_continuous_decode_steps - 1):
                        if not self.running_batch:
                            break
                        self.update_running_batch()
                        if not self.running_batch:
                            break
                        result = self.run_batch(batch)
                        self.process_batch_result(batch, result)
            else:
                self.check_memory()
                self.new_token_ratio = global_config.init_new_token_ratio

            self.last_batch = batch

    @torch.inference_mode()
    def event_loop_overlap(self):
        """A scheduler loop that overlaps the CPU processing and GPU computation."""
        result_queue = deque()

        self.last_batch = None
        self.running_batch = None

        while True:
            recv_reqs = self.recv_requests()
            self.process_input_requests(recv_reqs)

            batch = self.get_next_batch_to_run()
            self.cur_batch = batch
            if batch:
                result = self.run_batch(batch)
                result_queue.append((batch.copy(), result))

            if self.last_batch:
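                # Process the results of the previous iteration here so that this CPU
                # post-processing overlaps with the GPU forward pass launched above.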
                tmp_batch, tmp_result = result_queue.popleft()
                self.process_batch_result(tmp_batch, tmp_result)
            elif batch is None:
                self.check_memory()
                self.new_token_ratio = global_config.init_new_token_ratio

            self.last_batch = batch

    def recv_requests(self):
        if self.tp_rank == 0:
            recv_reqs = []

            while True:
                try:
                    recv_req = self.recv_from_tokenizer.recv_pyobj(zmq.NOBLOCK)
                except zmq.ZMQError:
                    break
                recv_reqs.append(recv_req)
        else:
            recv_reqs = None

        if self.tp_size != 1:
            recv_reqs = broadcast_pyobj(recv_reqs, self.tp_rank, self.tp_cpu_group)
        return recv_reqs

    def process_input_requests(self, recv_reqs: List):
        for recv_req in recv_reqs:
            if isinstance(recv_req, TokenizedGenerateReqInput):
                self.handle_generate_request(recv_req)
            elif isinstance(
                recv_req, (TokenizedEmbeddingReqInput, TokenizedRewardReqInput)
            ):
                self.handle_embedding_request(recv_req)
            elif isinstance(recv_req, FlushCacheReq):
                self.flush_cache()
            elif isinstance(recv_req, AbortReq):
                self.abort_request(recv_req)
            elif isinstance(recv_req, UpdateWeightReqInput):
                success, message = self.update_weights(recv_req)
                self.send_to_detokenizer.send_pyobj(
                    UpdateWeightReqOutput(success, message)
                )
            elif isinstance(recv_req, ProfileReq):
                if recv_req == ProfileReq.START_PROFILE:
                    self.start_profile()
                else:
                    self.stop_profile()
            elif isinstance(recv_req, GetMemPoolSizeReq):
                self.send_to_detokenizer.send_pyobj(
                    GetMemPoolSizeReqOutput(self.max_total_num_tokens)
                )
            else:
                raise ValueError(f"Invalid request: {recv_req}")

    def handle_generate_request(
        self,
        recv_req: TokenizedGenerateReqInput,
    ):
        req = Req(
            recv_req.rid,
            recv_req.input_text,
            recv_req.input_ids,
            recv_req.sampling_params,
            lora_path=recv_req.lora_path,
        )
        req.tokenizer = self.tokenizer

        # Image inputs
        if recv_req.image_inputs is not None:
            req.image_inputs = ImageInputs.from_dict(
                recv_req.image_inputs, self.model_config.vocab_size
            )
            req.origin_input_ids = self.pad_input_ids_func(
                req.origin_input_ids_unpadded, req.image_inputs
            )

        req.return_logprob = recv_req.return_logprob
        req.top_logprobs_num = recv_req.top_logprobs_num
        req.stream = recv_req.stream
        req.logprob_start_len = recv_req.logprob_start_len

        if req.logprob_start_len == -1:
            # By default, only return the logprobs for output tokens
            req.logprob_start_len = len(recv_req.input_ids) - 1

        # Init regex FSM or BNF
        if (
            req.sampling_params.json_schema is not None
            or req.sampling_params.regex is not None
        ):
            assert self.grammar_cache is not None
            if req.sampling_params.json_schema is not None:
                req.grammar = self.grammar_cache.query(
                    ("json", req.sampling_params.json_schema),
                    self.model_config.vocab_size,
                )
            elif req.sampling_params.regex is not None:
                req.grammar = self.grammar_cache.query(
                    ("regex", req.sampling_params.regex), self.model_config.vocab_size
                )

        # Truncate prompts that are too long
        if len(req.origin_input_ids) > self.max_req_input_len:
            logger.warning(
                "Request length is longer than the KV cache pool size or "
                "the max context length. Truncated!!!"
            )
            req.origin_input_ids = req.origin_input_ids[: self.max_req_input_len]

        req.sampling_params.max_new_tokens = min(
            (
                req.sampling_params.max_new_tokens
                if req.sampling_params.max_new_tokens is not None
                else 1 << 30
            ),
            self.max_req_len - len(req.origin_input_ids) - 1,
        )

        self.waiting_queue.append(req)

    def handle_embedding_request(
        self,
        recv_req: Union[TokenizedEmbeddingReqInput, TokenizedRewardReqInput],
    ):
        req = Req(
            recv_req.rid,
            recv_req.input_text,
            recv_req.input_ids,
            recv_req.sampling_params,
        )
        req.tokenizer = self.tokenizer

        # Truncate prompts that are too long
        if len(req.origin_input_ids) >= self.max_req_input_len:
            logger.warning(
                "Request length is longer than the KV cache pool size or "
                "the max context length. Truncated!!!"
            )
            req.origin_input_ids = req.origin_input_ids[: self.max_req_input_len]

        self.waiting_queue.append(req)

    def print_decode_stats(self):
        num_used = self.max_total_num_tokens - (
            self.token_to_kv_pool.available_size() + self.tree_cache.evictable_size()
        )
        throughput = self.num_generated_tokens / (time.time() - self.last_stats_tic)
        self.num_generated_tokens = 0
        self.last_stats_tic = time.time()
        num_running_reqs = len(self.running_batch.reqs) if self.running_batch else 0
        logger.info(
            f"Decode batch. "
            f"#running-req: {num_running_reqs}, "
            f"#token: {num_used}, "
            f"token usage: {num_used / self.max_total_num_tokens:.2f}, "
            f"gen throughput (token/s): {throughput:.2f}, "
            f"#queue-req: {len(self.waiting_queue)}"
        )

    def check_memory(self):
        available_size = (
            self.token_to_kv_pool.available_size() + self.tree_cache.evictable_size()
        )
        if available_size != self.max_total_num_tokens:
            warnings.warn(
                "Warning: "
                f"available_size={available_size}, max_total_num_tokens={self.max_total_num_tokens}\n"
                "KV cache pool leak detected!"
            )
            exit(1) if crash_on_warning else None

        if len(self.req_to_token_pool.free_slots) != self.req_to_token_pool.size:
            warnings.warn(
                "Warning: "
                f"available req slots={len(self.req_to_token_pool.free_slots)}, "
                f"total slots={self.req_to_token_pool.size}\n"
                "Memory pool leak detected!"
            )
            exit(1) if crash_on_warning else None

    def get_next_batch_to_run(self):
        # Merge the prefill batch into the running batch
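        # A finished prefill (extend) batch joins the running decode batch. A chunked
        # (inflight) request is filtered out first because its prefill is incomplete.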
        if (
            self.last_batch
            and not self.last_batch.forward_mode.is_decode()
            and not self.last_batch.is_empty()
        ):
            if self.current_inflight_req:
                self.last_batch.filter_batch(
                    current_inflight_req=self.current_inflight_req
                )
                self.tree_cache.cache_unfinished_req(self.current_inflight_req)
                # Inflight request keeps its rid but will get a new req_pool_idx.
                self.req_to_token_pool.free(self.current_inflight_req.req_pool_idx)
                self.batch_is_full = False
            if not self.last_batch.is_empty():
                if self.running_batch is None:
                    self.running_batch = self.last_batch
                else:
                    self.running_batch.merge_batch(self.last_batch)

        # Prefill first
        new_batch = self.get_new_batch_prefill()
        if new_batch is not None:
            return new_batch

        # Check memory
        if self.running_batch is None:
            return

        # Run decode
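        # If the running batch shrinks after this update, allow new prefills to be scheduled again.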
        before_bs = self.running_batch.batch_size()
        self.update_running_batch()
        if not self.running_batch:
            self.batch_is_full = False
            return None
        if before_bs != self.running_batch.batch_size():
            self.batch_is_full = False
        return self.running_batch

    def get_new_batch_prefill(self) -> Optional[ScheduleBatch]:
        # Handle the cases where prefill is not allowed
        if (
            self.batch_is_full or len(self.waiting_queue) == 0
        ) and self.current_inflight_req is None:
            return None

        running_bs = len(self.running_batch.reqs) if self.running_batch else 0
        if running_bs >= self.max_running_requests:
            self.batch_is_full = True
            return None

        # Get priority queue
        prefix_computed = self.policy.calc_priority(self.waiting_queue)

        # Prefill policy
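        # The adder admits requests under a token budget: the free KV-cache tokens
        # plus the evictable tokens currently held by the tree cache.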
        num_mixed_running = running_bs if self.is_mixed_chunk else 0
        adder = PrefillAdder(
            self.tree_cache,
            self.running_batch,
            self.new_token_ratio,
            self.token_to_kv_pool.available_size() + self.tree_cache.evictable_size(),
            self.max_prefill_tokens,
            self.chunked_prefill_size,
            num_mixed_running,
        )

        has_inflight = self.current_inflight_req is not None
        if has_inflight:
            self.current_inflight_req.init_next_round_input(
                None if prefix_computed else self.tree_cache
            )
            self.current_inflight_req = adder.add_inflight_req(
                self.current_inflight_req
            )

        if self.lora_paths:
            lora_set = (
                set([req.lora_path for req in self.running_batch.reqs])
                if self.running_batch is not None
                else set([])
            )

        # Get requests from the waiting queue to a new prefill batch
        for req in self.waiting_queue:
            if (
                self.lora_paths
                and len(
                    lora_set
                    | set([req.lora_path for req in adder.can_run_list])
                    | set([req.lora_path])
                )
                > self.max_loras_per_batch
            ):
                self.batch_is_full = True
                break

            if running_bs + len(adder.can_run_list) >= self.max_running_requests:
                self.batch_is_full = True
                break

            req.init_next_round_input(None if prefix_computed else self.tree_cache)
            res = adder.add_one_req(req)
            if res != AddReqResult.CONTINUE:
                if res == AddReqResult.NO_TOKEN:
                    self.batch_is_full = True
                break

        # Update waiting queue
        can_run_list = adder.can_run_list
        if len(can_run_list) == 0:
            return None
        self.waiting_queue = [
            x for x in self.waiting_queue if x not in set(can_run_list)
        ]

        if adder.new_inflight_req is not None:
            assert self.current_inflight_req is None
            self.current_inflight_req = adder.new_inflight_req

        if self.current_inflight_req:
            self.current_inflight_req.is_inflight_req += 1

        # Print stats
        if self.tp_rank == 0:
            if isinstance(self.tree_cache, RadixCache):
                self.tree_cache_metrics["total"] += (
                    adder.log_input_tokens + adder.log_hit_tokens
                ) / 10**9
                self.tree_cache_metrics["hit"] += (adder.log_hit_tokens) / 10**9
                tree_cache_hit_rate = (
                    self.tree_cache_metrics["hit"] / self.tree_cache_metrics["total"]
                )
            else:
                tree_cache_hit_rate = 0.0

            num_used = self.max_total_num_tokens - (
                self.token_to_kv_pool.available_size()
                + self.tree_cache.evictable_size()
            )

            if num_mixed_running > 0:
                logger.info(
                    f"Prefill batch "
                    f"(mixed #running-req: {num_mixed_running}). "
                    f"#new-seq: {len(can_run_list)}, "
                    f"#new-token: {adder.log_input_tokens}, "
                    f"#cached-token: {adder.log_hit_tokens}, "
                    f"cache hit rate: {100.0 * tree_cache_hit_rate:.2f}%, "
                    f"token usage: {num_used / self.max_total_num_tokens:.2f}, "
                    f"#queue-req: {len(self.waiting_queue) + has_inflight}"
                )
            else:
                logger.info(
                    f"Prefill batch. "
                    f"#new-seq: {len(can_run_list)}, "
                    f"#new-token: {adder.log_input_tokens}, "
                    f"#cached-token: {adder.log_hit_tokens}, "
                    f"cache hit rate: {100.0 * tree_cache_hit_rate:.2f}%, "
                    f"token usage: {num_used / self.max_total_num_tokens:.2f}, "
                    f"#running-req: {running_bs}, "
                    f"#queue-req: {len(self.waiting_queue) + has_inflight}"
                )

        # Create a new batch
        new_batch = ScheduleBatch.init_new(
            can_run_list,
            self.req_to_token_pool,
            self.token_to_kv_pool,
            self.tree_cache,
            self.model_config,
        )
        new_batch.prepare_for_extend()

        # Mixed-style chunked prefill
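        # Fold the current decode batch into the new prefill batch so that both
        # are executed in a single forward pass.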
        if self.is_mixed_chunk and self.running_batch is not None:
            self.running_batch.prepare_for_decode(self.enable_overlap)
            new_batch.mix_with_running(self.running_batch)
            new_batch.decoding_reqs = self.running_batch.reqs
            self.running_batch = None
        else:
            new_batch.decoding_reqs = None

        return new_batch

    def update_running_batch(self):
        """Update the current running decoding batch."""
        global test_retract
        batch = self.running_batch

        batch.filter_batch()
        if batch.is_empty():
            self.running_batch = None
            return

        # Check if decode out of memory
        if not batch.check_decode_mem() or (test_retract and batch.batch_size() > 10):
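            # Retract some running requests back to the waiting queue and raise
            # new_token_ratio so that admission becomes more conservative.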
            old_ratio = self.new_token_ratio

            retracted_reqs, new_token_ratio = batch.retract_decode()
            self.new_token_ratio = new_token_ratio

            logger.info(
                "Decode out of memory happened. "
                f"#retracted_reqs: {len(retracted_reqs)}, "
                f"#new_token_ratio: {old_ratio:.4f} -> {self.new_token_ratio:.4f}"
            )
            self.waiting_queue.extend(retracted_reqs)
        else:
            self.new_token_ratio = max(
                self.new_token_ratio - self.new_token_ratio_decay,
                self.min_new_token_ratio,
            )

        # Check for jump-forward
        if not self.disable_regex_jump_forward:
            jump_forward_reqs = batch.check_for_jump_forward(self.pad_input_ids_func)
            self.waiting_queue.extend(jump_forward_reqs)
            if batch.is_empty():
                self.running_batch = None
                return

        # Update batch tensors
        batch.prepare_for_decode(self.enable_overlap)

    def run_batch(self, batch: ScheduleBatch):
        """Run a batch."""
        if self.is_generation:
            if batch.forward_mode.is_decode() or batch.extend_num_tokens != 0:
                model_worker_batch = batch.get_model_worker_batch()
                logits_output, next_token_ids = self.tp_worker.forward_batch_generation(
                    model_worker_batch
                )
            else:
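                # No new tokens to compute for this batch; skip the forward pass and
                # emit a dummy next token per request (EOS if a tokenizer is available).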
                logits_output = None
                if self.tokenizer is not None:
                    next_token_ids = torch.full(
                        (batch.batch_size(),), self.tokenizer.eos_token_id
                    )
                else:
                    next_token_ids = torch.full((batch.batch_size(),), 0)
            batch.output_ids = next_token_ids
            ret = logits_output, next_token_ids, model_worker_batch.bid
        else:  # embedding or reward model
            assert batch.extend_num_tokens != 0
            model_worker_batch = batch.get_model_worker_batch()
            embeddings = self.tp_worker.forward_batch_embedding(model_worker_batch)
            ret = embeddings, model_worker_batch.bid
        return ret

    def process_batch_result(self, batch: ScheduleBatch, result):
        if batch.forward_mode.is_decode():
            self.process_batch_result_decode(batch, result)
            if batch.is_empty():
                self.running_batch = None
        else:
            self.process_batch_result_prefill(batch, result)

    def process_batch_result_prefill(self, batch: ScheduleBatch, result):
        if self.is_generation:
            logits_output, next_token_ids, bid = result

            if self.enable_overlap:
                logits_output, next_token_ids = self.tp_worker.resulve_batch_result(bid)
            else:
                # Move next_token_ids and logprobs to cpu
                if batch.return_logprob:
                    logits_output.next_token_logprobs = (
                        logits_output.next_token_logprobs[
                            torch.arange(len(next_token_ids), device=self.device),
                            next_token_ids,
                        ].tolist()
                    )
                    logits_output.input_token_logprobs = (
                        logits_output.input_token_logprobs.tolist()
                    )
                    logits_output.normalized_prompt_logprobs = (
                        logits_output.normalized_prompt_logprobs.tolist()
                    )
                next_token_ids = next_token_ids.tolist()

            # Check finish conditions
            logprob_pt = 0
            for i, req in enumerate(batch.reqs):
                if req.is_inflight_req > 0:
                    # The prefill of this inflight (chunked) request is not finished yet,
                    # so do not record an output token for it in this round.
                    req.is_inflight_req -= 1
                else:
                    req.completion_tokens_wo_jump_forward += 1
                    req.output_ids.append(next_token_ids[i])
                    req.check_finished()

                    if req.finished():
                        self.tree_cache.cache_finished_req(req)
                    elif not batch.decoding_reqs or req not in batch.decoding_reqs:
                        self.tree_cache.cache_unfinished_req(req)

                    if req.grammar is not None:
                        req.grammar.accept_token(next_token_ids[i])

                    if req.return_logprob:
                        logprob_pt += self.add_logprob_return_values(
                            i, req, logprob_pt, next_token_ids, logits_output
                        )
        else:  # embedding or reward model
            embeddings, bid = result
            embeddings = embeddings.tolist()

            # Check finish conditions
            for i, req in enumerate(batch.reqs):
                req.embedding = embeddings[i]
                if req.is_inflight_req > 0:
                    # The prefill of this inflight (chunked) request is not finished yet.
                    req.is_inflight_req -= 1
                else:
                    # dummy output token for embedding models
                    req.output_ids.append(0)
                    req.check_finished()

                if req.finished():
                    self.tree_cache.cache_finished_req(req)
                else:
                    self.tree_cache.cache_unfinished_req(req)

        self.stream_output(batch.reqs)

    def process_batch_result_decode(self, batch: ScheduleBatch, result):
        logits_output, next_token_ids, bid = result
        self.num_generated_tokens += len(batch.reqs)

        if self.enable_overlap:
            logits_output, next_token_ids = self.tp_worker.resulve_batch_result(bid)
            next_token_logprobs = logits_output.next_token_logprobs
        else:
            # Move next_token_ids and logprobs to cpu
            if batch.return_logprob:
                next_token_logprobs = logits_output.next_token_logprobs[
                    torch.arange(len(next_token_ids), device=self.device),
                    next_token_ids,
                ].tolist()
            next_token_ids = next_token_ids.tolist()

        self.token_to_kv_pool.free_group_begin()

        # Check finish condition
        for i, (req, next_token_id) in enumerate(zip(batch.reqs, next_token_ids)):
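            # With the overlap schedule, the request may have finished in an earlier
            # step; free the KV slot allocated for this token and skip it.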
            if self.server_args.enable_overlap_schedule and req.finished():
                self.token_to_kv_pool.free(batch.out_cache_loc[i : i + 1])
                continue

            req.completion_tokens_wo_jump_forward += 1
            req.output_ids.append(next_token_id)
            req.check_finished()

            if req.grammar is not None:
                req.grammar.accept_token(next_token_id)

            if req.finished():
                self.tree_cache.cache_finished_req(req)

            if req.return_logprob:
                req.output_token_logprobs.append(
                    (next_token_logprobs[i], next_token_id)
                )
                if req.top_logprobs_num > 0:
                    req.output_top_logprobs.append(logits_output.output_top_logprobs[i])

        self.stream_output(batch.reqs)

        self.token_to_kv_pool.free_group_end()

        self.decode_forward_ct = (self.decode_forward_ct + 1) % (1 << 30)
        if self.tp_rank == 0 and self.decode_forward_ct % 40 == 0:
            self.print_decode_stats()

    def add_logprob_return_values(
        self,
        i: int,
        req: Req,
        pt: int,
        next_token_ids: List[int],
        output: LogitsProcessorOutput,
    ):
        """Attach logprobs to the return values."""
        req.output_token_logprobs.append(
            (output.next_token_logprobs[i], next_token_ids[i])
        )

        # If logprob_start_len > 0, then first logprob_start_len prompt tokens will be ignored.
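        # `pt` is the running offset of this request's tokens inside the
        # batch-flattened input_token_logprobs.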
        num_input_logprobs = req.extend_input_len - req.extend_logprob_start_len

        if req.normalized_prompt_logprob is None:
            req.normalized_prompt_logprob = output.normalized_prompt_logprobs[i]

        if req.input_token_logprobs is None:
            input_token_logprobs = output.input_token_logprobs[
                pt : pt + num_input_logprobs - 1 - req.last_update_decode_tokens
            ]
            input_token_ids = req.fill_ids[
                len(req.fill_ids)
                - num_input_logprobs
                + 1 : len(req.fill_ids)
                - req.last_update_decode_tokens
            ]
            req.input_token_logprobs = list(zip(input_token_logprobs, input_token_ids))

            if (
                req.logprob_start_len == 0
            ):  # The first token does not have logprob, pad it.
                req.input_token_logprobs = [
                    (None, req.fill_ids[0])
                ] + req.input_token_logprobs

        if req.last_update_decode_tokens != 0:
            # Some decode tokens are re-computed in an extend batch
            req.output_token_logprobs.extend(
                list(
                    zip(
                        output.input_token_logprobs[
                            pt
                            + num_input_logprobs
                            - 1
                            - req.last_update_decode_tokens : pt
                            + num_input_logprobs
                            - 1
                        ],
                        req.fill_ids[
                            len(req.fill_ids)
                            - req.last_update_decode_tokens : len(req.fill_ids)
                        ],
                    )
                )
            )

        if req.top_logprobs_num > 0:
            if req.input_top_logprobs is None:
                req.input_top_logprobs = output.input_top_logprobs[i]
                if req.logprob_start_len == 0:
                    req.input_top_logprobs = [None] + req.input_top_logprobs

            if req.last_update_decode_tokens != 0:
                req.output_top_logprobs.extend(
                    output.input_top_logprobs[i][-req.last_update_decode_tokens :]
                )
            req.output_top_logprobs.append(output.output_top_logprobs[i])

        return num_input_logprobs

    def stream_output(self, reqs: List[Req]):
        """Stream the output to detokenizer."""
        output_rids = []
        output_meta_info = []
        output_finished_reason: List[BaseFinishReason] = []
        if self.is_generation:
            output_vids = []
            decoded_texts = []
            output_read_ids = []
            output_read_offsets = []
            output_skip_special_tokens = []
            output_spaces_between_special_tokens = []
            output_no_stop_trim = []
        else:  # embedding or reward model
            output_embeddings = []

        is_stream_iter = self.decode_forward_ct % self.stream_interval == 0
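        # Stream partial outputs only every `stream_interval` decode steps (and on
        # the first output token), or immediately when a request finishes.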

        for req in reqs:
            if req.finished() or (
                req.stream and (is_stream_iter or len(req.output_ids) == 1)
            ):
                output_rids.append(req.rid)
                output_finished_reason.append(req.finished_reason)
                if self.is_generation:
                    output_vids.append(req.vid)
                    decoded_texts.append(req.decoded_text)
                    read_ids, read_offset = req.init_incremental_detokenize()
                    output_read_ids.append(read_ids)
                    output_read_offsets.append(read_offset)
                    output_skip_special_tokens.append(
                        req.sampling_params.skip_special_tokens
                    )
                    output_spaces_between_special_tokens.append(
                        req.sampling_params.spaces_between_special_tokens
                    )
                    output_no_stop_trim.append(req.sampling_params.no_stop_trim)

                    meta_info = {
                        "prompt_tokens": len(req.origin_input_ids),
                        "completion_tokens": len(req.output_ids),
                        "completion_tokens_wo_jump_forward": req.completion_tokens_wo_jump_forward,
                        "cached_tokens": req.cached_tokens,
                        "finish_reason": (
                            req.finished_reason.to_json()
                            if req.finished_reason is not None
                            else None
                        ),
                    }
                    if req.return_logprob:
                        (
                            meta_info["input_token_logprobs"],
                            meta_info["output_token_logprobs"],
                            meta_info["input_top_logprobs"],
                            meta_info["output_top_logprobs"],
                            meta_info["normalized_prompt_logprob"],
                        ) = (
                            req.input_token_logprobs,
                            req.output_token_logprobs,
                            req.input_top_logprobs,
                            req.output_top_logprobs,
                            req.normalized_prompt_logprob,
                        )
                    output_meta_info.append(meta_info)
                else:  # embedding or reward model
                    output_embeddings.append(req.embedding)
                    meta_info = {
                        "prompt_tokens": len(req.origin_input_ids),
                    }
                    output_meta_info.append(meta_info)

        # Send to detokenizer
        if output_rids:
            if self.is_generation:
                self.send_to_detokenizer.send_pyobj(
                    BatchTokenIDOut(
                        output_rids,
                        output_vids,
                        decoded_texts,
                        output_read_ids,
                        output_read_offsets,
                        output_skip_special_tokens,
                        output_spaces_between_special_tokens,
                        output_meta_info,
                        output_finished_reason,
                        output_no_stop_trim,
                    )
                )
            else:  # embedding or reward model
                self.send_to_detokenizer.send_pyobj(
                    BatchEmbeddingOut(
                        output_rids,
                        output_embeddings,
                        output_meta_info,
                        output_finished_reason,
                    )
                )

    def flush_cache(self):
        """Flush the memory pool and cache."""
        if len(self.waiting_queue) == 0 and (
            self.running_batch is None or len(self.running_batch.reqs) == 0
        ):
            self.tree_cache.reset()
            self.tree_cache_metrics = {"total": 0, "hit": 0}
            if self.grammar_cache is not None:
                self.grammar_cache.reset()
            # TODO(dark): reset the bnf cache
            self.req_to_token_pool.clear()
            self.token_to_kv_pool.clear()
            torch.cuda.empty_cache()
            logger.info("Cache flushed successfully!")
            if_success = True
        else:
            logging.warning(
                f"Cache not flushed because there are pending requests. "
                f"#queue-req: {len(self.waiting_queue)}, "
                f"#running-req: {0 if self.running_batch is None else len(self.running_batch.reqs)}"
            )
            if_success = False
        return if_success

    def abort_request(self, recv_req: AbortReq):
        # Delete requests in the waiting queue
        to_del = None
        for i, req in enumerate(self.waiting_queue):
            if req.rid == recv_req.rid:
                to_del = i
                break

        if to_del is not None:
            del self.waiting_queue[to_del]

        # Delete requests in the running batch
        if self.running_batch:
            for req in self.running_batch.reqs:
                if req.rid == recv_req.rid and not req.finished():
                    req.finished_reason = FINISH_ABORT()
                    self.tree_cache.cache_finished_req(req)
                    break

    def update_weights(self, recv_req: UpdateWeightReqInput):
        """In-place update of the weights."""
        success, message = self.tp_worker.update_weights(recv_req)
        if success:
            flash_cache_success = self.flush_cache()
            assert flash_cache_success, "Cache flush failed after updating weights"
        else:
            logger.error(message)
        return success, message

    def start_profile(self) -> None:
        if self.profiler is None:
            raise RuntimeError("Profiler is not enabled.")
        self.profiler.start()

    def stop_profile(self) -> None:
        if self.profiler is None:
            raise RuntimeError("Profiler is not enabled.")
        self.profiler.stop()
        self.profiler.export_chrome_trace(
            self.torch_profiler_trace_dir + "/" + str(time.time()) + ".trace.json.gz"
        )
        logger.info("Profiler is done")


def run_scheduler_process(
    server_args: ServerArgs,
    port_args: PortArgs,
    gpu_id: int,
    tp_rank: int,
    dp_rank: Optional[int],
    pipe_writer,
):
    if dp_rank is None:
        configure_logger(server_args, prefix=f" TP{tp_rank}")
    else:
        configure_logger(server_args, prefix=f" DP{dp_rank} TP{tp_rank}")

    suppress_other_loggers()

    try:
        scheduler = Scheduler(server_args, port_args, gpu_id, tp_rank, dp_rank)
        pipe_writer.send("ready")
        if server_args.enable_overlap_schedule:
            scheduler.event_loop_overlap()
        else:
            scheduler.event_loop_normal()
    except Exception:
        msg = get_exception_traceback()
        logger.error(msg)
        kill_parent_process()