"test/vscode:/vscode.git/clone" did not exist on "2ba401f04591ad1580146cee28f9a914581fe381"
scheduler.py 39.5 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
"""
Copyright 2023-2024 SGLang Team
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""

"""A scheduler that manages a tensor parallel GPU worker."""

18
import json
19
20
import logging
import multiprocessing
21
22
23
24
import os
import time
import warnings
from typing import List, Optional, Union
25

26
import torch
27
28
import zmq

29
30
31
32
33
34
35
36
37
38
39
from sglang.global_config import global_config
from sglang.srt.configs.model_config import ModelConfig
from sglang.srt.constrained.fsm_cache import FSMCache
from sglang.srt.constrained.jump_forward import JumpForwardCache
from sglang.srt.hf_transformers_utils import get_processor, get_tokenizer
from sglang.srt.layers.logits_processor import LogitsProcessorOutput
from sglang.srt.managers.io_struct import (
    AbortReq,
    BatchEmbeddingOut,
    BatchTokenIDOut,
    FlushCacheReq,
40
    ProfileReq,
41
42
43
44
45
46
47
48
49
50
51
52
53
    TokenizedEmbeddingReqInput,
    TokenizedGenerateReqInput,
    TokenizedRewardReqInput,
    UpdateWeightReqInput,
    UpdateWeightReqOutput,
)
from sglang.srt.managers.schedule_batch import (
    FINISH_ABORT,
    BaseFinishReason,
    ImageInputs,
    Req,
    ScheduleBatch,
)
54
55
56
57
58
from sglang.srt.managers.schedule_policy import (
    AddReqResult,
    PrefillAdder,
    SchedulePolicy,
)
59
from sglang.srt.managers.tp_worker import TpModelWorker
60
61
from sglang.srt.mem_cache.chunk_cache import ChunkCache
from sglang.srt.mem_cache.radix_cache import RadixCache
62
from sglang.srt.server_args import PortArgs, ServerArgs
63
64
65
66
67
68
from sglang.srt.utils import (
    broadcast_pyobj,
    configure_logger,
    is_generation_model,
    is_multimodal_model,
    kill_parent_process,
69
    pytorch_profile,
70
71
72
    set_random_seed,
    suppress_other_loggers,
)
73
74
75
76
from sglang.utils import get_exception_traceback

logger = logging.getLogger(__name__)

# When the SGLANG_IS_IN_CI environment variable is exactly "true", the memory
# leak checks in Scheduler.check_memory terminate the process instead of only
# warning.
crash_on_warning = os.environ.get("SGLANG_IS_IN_CI", "false") == "true"


class Scheduler:
    """A scheduler that manages a tensor parallel GPU worker."""

    def __init__(
        self,
        server_args: ServerArgs,
        port_args: PortArgs,
        gpu_id: int,
        tp_rank: int,
    ):
        """Build the scheduler together with the tensor-parallel worker it drives.

        Args:
            server_args: Server configuration (model, tokenizer, cache and
                scheduling options).
            port_args: IPC socket names and NCCL ports.
            gpu_id: GPU index handed to the ``TpModelWorker``.
            tp_rank: Tensor-parallel rank of this process. Only rank 0 opens
                the ZMQ sockets to the tokenizer and the detokenizer.
        """
        # Settings mirrored from server_args
        self.server_args = server_args
        self.tp_rank = tp_rank
        self.tp_size = server_args.tp_size
        self.schedule_policy = server_args.schedule_policy
        self.disable_regex_jump_forward = server_args.disable_regex_jump_forward
        self.lora_paths = server_args.lora_paths
        self.max_loras_per_batch = server_args.max_loras_per_batch

        # Inter-process communication: requests are pulled from the tokenizer
        # and results pushed to the detokenizer, by rank 0 only.
        zmq_ctx = zmq.Context(2)
        if self.tp_rank != 0:
            self.recv_from_tokenizer = self.send_to_detokenizer = None
        else:
            self.recv_from_tokenizer = zmq_ctx.socket(zmq.PULL)
            self.recv_from_tokenizer.bind(f"ipc://{port_args.scheduler_input_ipc_name}")
            self.send_to_detokenizer = zmq_ctx.socket(zmq.PUSH)
            self.send_to_detokenizer.connect(f"ipc://{port_args.detokenizer_ipc_name}")

        # Model config and tokenizer / multimodal processor
        override_args = json.loads(server_args.json_model_override_args)
        self.model_config = ModelConfig(
            server_args.model_path,
            server_args.trust_remote_code,
            context_length=server_args.context_length,
            model_override_args=override_args,
        )

        architectures = self.model_config.hf_config.architectures
        if server_args.skip_tokenizer_init:
            self.tokenizer = self.processor = None
        elif is_multimodal_model(architectures):
            # Multimodal models ship a processor that wraps the tokenizer.
            self.processor = get_processor(
                server_args.tokenizer_path,
                tokenizer_mode=server_args.tokenizer_mode,
                trust_remote_code=server_args.trust_remote_code,
            )
            self.tokenizer = self.processor.tokenizer
        else:
            self.tokenizer = get_tokenizer(
                server_args.tokenizer_path,
                tokenizer_mode=server_args.tokenizer_mode,
                trust_remote_code=server_args.trust_remote_code,
            )
        self.is_generation = is_generation_model(
            self.model_config.hf_config.architectures, self.server_args.is_embedding
        )

        # Tensor-parallel model worker (loads the model on this GPU)
        self.tp_worker = TpModelWorker(
            gpu_id=gpu_id,
            tp_rank=tp_rank,
            server_args=server_args,
            nccl_port=port_args.nccl_ports[0],
        )
        self.tp_cpu_group = self.tp_worker.model_runner.tp_group.cpu_group

        # Token and memory limits are decided by the worker after model load.
        (
            self.max_total_num_tokens,
            self.max_prefill_tokens,
            self.max_running_requests,
            self.max_req_input_len,
            self.random_seed,
        ) = self.tp_worker.get_token_and_memory_info()
        set_random_seed(self.random_seed)

        # Optional model hook (None when the model has no ``pad_input_ids``);
        # used for image inputs and for jump-forward re-padding.
        self.pad_input_ids_func = getattr(
            self.tp_worker.model_runner.model, "pad_input_ids", None
        )

        logger.info(
            f"max_total_num_tokens={self.max_total_num_tokens}, "
            f"max_prefill_tokens={self.max_prefill_tokens}, "
            f"max_running_requests={self.max_running_requests}, "
            f"context_len={self.model_config.context_len}"
        )

        # Memory pools are owned by the model runner; the scheduler shares them.
        self.req_to_token_pool = self.tp_worker.model_runner.req_to_token_pool
        self.token_to_kv_pool = self.tp_worker.model_runner.token_to_kv_pool

        # Chunked prefill without a radix cache uses the simpler ChunkCache.
        use_chunk_cache = (
            server_args.chunked_prefill_size is not None
            and server_args.disable_radix_cache
        )
        if use_chunk_cache:
            self.tree_cache = ChunkCache(
                req_to_token_pool=self.req_to_token_pool,
                token_to_kv_pool=self.token_to_kv_pool,
            )
        else:
            self.tree_cache = RadixCache(
                req_to_token_pool=self.req_to_token_pool,
                token_to_kv_pool=self.token_to_kv_pool,
                disable=server_args.disable_radix_cache,
            )
        self.tree_cache_metrics = {"total": 0, "hit": 0}
        self.policy = SchedulePolicy(self.schedule_policy, self.tree_cache)

        # Running status
        self.waiting_queue: List[Req] = []
        self.running_batch: ScheduleBatch = None
        self.out_pyobjs = []
        self.decode_forward_ct = 0
        self.stream_interval = server_args.stream_interval
        self.num_generated_tokens = 0
        self.last_stats_tic = time.time()

        # Chunked prefill state
        self.chunked_prefill_size = server_args.chunked_prefill_size
        self.current_inflight_req = None
        self.is_mixed_chunk = (
            self.chunked_prefill_size is not None and server_args.enable_mixed_chunk
        )

        # FSM cache for constrained (regex / JSON-schema) generation; it is
        # only created when a tokenizer is available.
        if not server_args.skip_tokenizer_init:
            self.regex_fsm_cache = FSMCache(
                server_args.tokenizer_path,
                {
                    "tokenizer_mode": server_args.tokenizer_mode,
                    "trust_remote_code": server_args.trust_remote_code,
                },
                skip_tokenizer_init=server_args.skip_tokenizer_init,
                constrained_json_whitespace_pattern=server_args.constrained_json_whitespace_pattern,
            )
        self.jump_forward_cache = JumpForwardCache()

        # New-token estimation used when admitting requests
        assert (
            server_args.schedule_conservativeness >= 0
        ), "Invalid schedule_conservativeness"
        scaled_min_ratio = (
            global_config.base_min_new_token_ratio
            * server_args.schedule_conservativeness
        )
        self.min_new_token_ratio = min(scaled_min_ratio, 1.0)
        self.new_token_ratio = self.min_new_token_ratio
        self.new_token_ratio_decay = global_config.new_token_ratio_decay
        self.batch_is_full = False

        # Torch profiler, enabled only when SGLANG_TORCH_PROFILER_DIR is set
        # to a non-empty value.
        trace_dir = os.getenv("SGLANG_TORCH_PROFILER_DIR", "")
        if trace_dir == "":
            self.profiler = None
        else:
            self.torch_profiler_trace_dir = trace_dir
            logger.info(
                "Profiling enabled. Traces will be saved to: %s",
                self.torch_profiler_trace_dir,
            )
            self.profiler = torch.profiler.profile(
                activities=[
                    torch.profiler.ProfilerActivity.CPU,
                    torch.profiler.ProfilerActivity.CUDA,
                ],
                with_stack=True,
            )

    @torch.inference_mode()
    def event_loop(self):
        """Run the scheduler forever under ``torch.inference_mode``.

        Each iteration handles newly received requests, runs one scheduling
        step, and then sends any produced outputs.
        """
        while True:
            self.process_input_requests(self.recv_requests())
            self.run_step()
            self.send_results()

    def recv_requests(self):
        if self.tp_rank == 0:
            recv_reqs = []

            while True:
                try:
                    recv_req = self.recv_from_tokenizer.recv_pyobj(zmq.NOBLOCK)
                except zmq.ZMQError:
                    break
                recv_reqs.append(recv_req)
        else:
            recv_reqs = None
271

272
273
        if self.tp_size != 1:
            recv_reqs = broadcast_pyobj(recv_reqs, self.tp_rank, self.tp_cpu_group)
274
275
        return recv_reqs

Lianmin Zheng's avatar
Lianmin Zheng committed
276
    def process_input_requests(self, recv_reqs: List):
        """Dispatch each received request to its handler.

        Generation and embedding/reward requests are turned into ``Req``
        objects and queued. Flush, abort, weight-update and profiling requests
        are executed immediately. A weight update also queues its
        ``UpdateWeightReqOutput`` for sending.

        Raises:
            ValueError: if a request of an unrecognized type is received.
        """
        embedding_types = (TokenizedEmbeddingReqInput, TokenizedRewardReqInput)
        for req in recv_reqs:
            if isinstance(req, TokenizedGenerateReqInput):
                self.handle_generate_request(req)
            elif isinstance(req, embedding_types):
                self.handle_embedding_request(req)
            elif isinstance(req, FlushCacheReq):
                self.flush_cache()
            elif isinstance(req, AbortReq):
                self.abort_request(req)
            elif isinstance(req, UpdateWeightReqInput):
                success, message = self.update_weights(req)
                self.out_pyobjs.append(UpdateWeightReqOutput(success, message))
            elif isinstance(req, ProfileReq):
                toggle_profiler = (
                    self.start_profile
                    if req == ProfileReq.START_PROFILE
                    else self.stop_profile
                )
                toggle_profiler()
            else:
                raise ValueError(f"Invalid request: {req}")

    def handle_generate_request(
        self,
        recv_req: TokenizedGenerateReqInput,
    ):
        """Convert a tokenized generation request into a ``Req`` and queue it.

        The method does the following:
        - copies the request fields;
        - pads image inputs with the model's ``pad_input_ids`` hook;
        - compiles any JSON-schema / regex constraint into an FSM, plus a
          jump-forward map unless jump-forward is disabled;
        - truncates prompts longer than ``max_req_input_len``;
        - caps ``max_new_tokens`` to the remaining length budget.
        """
        req = Req(
            recv_req.rid,
            recv_req.input_text,
            recv_req.input_ids,
            recv_req.sampling_params,
            lora_path=recv_req.lora_path,
        )
        req.tokenizer = self.tokenizer

        # Image inputs: the model hook expands the unpadded ids for the image.
        if recv_req.image_inputs is not None:
            req.image_inputs = ImageInputs.from_dict(
                recv_req.image_inputs, self.model_config.vocab_size
            )
            req.origin_input_ids = self.pad_input_ids_func(
                req.origin_input_ids_unpadded, req.image_inputs
            )

        req.return_logprob = recv_req.return_logprob
        req.top_logprobs_num = recv_req.top_logprobs_num
        req.stream = recv_req.stream
        req.logprob_start_len = recv_req.logprob_start_len

        if req.logprob_start_len == -1:
            # By default, only return the logprobs for output tokens.
            # NOTE(review): this uses the unpadded recv_req.input_ids length even
            # when image padding changed origin_input_ids — confirm intended.
            req.logprob_start_len = len(recv_req.input_ids) - 1

        # Init regex FSM: a JSON schema takes precedence over a raw regex.
        sampling_params = req.sampling_params
        if sampling_params.json_schema is not None:
            fsm_key = ("json", sampling_params.json_schema)
        elif sampling_params.regex is not None:
            fsm_key = ("regex", sampling_params.regex)
        else:
            fsm_key = None

        if fsm_key is not None:
            req.regex_fsm, computed_regex_string = self.regex_fsm_cache.query(fsm_key)
            if not self.disable_regex_jump_forward:
                req.jump_forward_map = self.jump_forward_cache.query(
                    computed_regex_string
                )

        # Truncate prompts that are too long. Use ">" so a prompt of exactly
        # max_req_input_len tokens (which the slice would leave unchanged) is
        # not reported as truncated.
        if len(req.origin_input_ids) > self.max_req_input_len:
            logger.warning(
                "Request length is longer than the KV cache pool size or "
                "the max context length. Truncated!!!"
            )
            req.origin_input_ids = req.origin_input_ids[: self.max_req_input_len]

        # NOTE(review): for a prompt of max_req_input_len tokens this bound is -1.
        # The request still receives the one token sampled at prefill; confirm
        # downstream budgeting tolerates a negative max_new_tokens.
        req.sampling_params.max_new_tokens = min(
            (
                req.sampling_params.max_new_tokens
                if req.sampling_params.max_new_tokens is not None
                else 1 << 30
            ),
            self.max_req_input_len - 1 - len(req.origin_input_ids),
        )

        self.waiting_queue.append(req)

    def handle_embedding_request(
        self,
        recv_req: Union[TokenizedEmbeddingReqInput, TokenizedRewardReqInput],
    ):
        """Convert a tokenized embedding/reward request into a ``Req`` and queue it.

        Prompts longer than ``max_req_input_len`` are truncated, with a warning.
        """
        req = Req(
            recv_req.rid,
            recv_req.input_text,
            recv_req.input_ids,
            recv_req.sampling_params,
        )
        req.tokenizer = self.tokenizer

        # Truncate prompts that are too long. Use ">" so a prompt of exactly
        # max_req_input_len tokens (which the slice would leave unchanged) is
        # not reported as truncated.
        if len(req.origin_input_ids) > self.max_req_input_len:
            logger.warning(
                "Request length is longer than the KV cache pool size or "
                "the max context length. Truncated!!!"
            )
            req.origin_input_ids = req.origin_input_ids[: self.max_req_input_len]

        self.waiting_queue.append(req)

    def send_results(self):
        if self.tp_rank == 0:
            for obj in self.out_pyobjs:
                self.send_to_detokenizer.send_pyobj(obj)
            self.out_pyobjs = []
Lianmin Zheng's avatar
Lianmin Zheng committed
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430

    def print_decode_stats(self):
        """Log decode throughput and KV-cache usage, then reset the throughput window.

        Throughput is ``num_generated_tokens`` divided by the seconds elapsed
        since ``last_stats_tic``. Both counters are reset afterwards.
        """
        reclaimable = (
            self.token_to_kv_pool.available_size() + self.tree_cache.evictable_size()
        )
        num_used = self.max_total_num_tokens - reclaimable

        elapsed = time.time() - self.last_stats_tic
        throughput = self.num_generated_tokens / elapsed
        self.num_generated_tokens = 0
        self.last_stats_tic = time.time()

        usage = num_used / self.max_total_num_tokens
        logger.info(
            f"Decode batch. "
            f"#running-req: {len(self.running_batch.reqs)}, "
            f"#token: {num_used}, "
            f"token usage: {usage:.2f}, "
            f"gen throughput (token/s): {throughput:.2f}, "
            f"#queue-req: {len(self.waiting_queue)}"
        )

    def check_memory(self):
        available_size = (
            self.token_to_kv_pool.available_size() + self.tree_cache.evictable_size()
        )
        if available_size != self.max_total_num_tokens:
            warnings.warn(
                "Warning: "
                f"available_size={available_size}, max_total_num_tokens={self.max_total_num_tokens}\n"
                "KV cache pool leak detected!"
            )
            exit(1) if crash_on_warning else None

        if len(self.req_to_token_pool.free_slots) != self.req_to_token_pool.size:
            warnings.warn(
                "Warning: "
                f"available req slots={len(self.req_to_token_pool.free_slots)}, "
                f"total slots={self.req_to_token_pool.size}\n"
                "Memory pool leak detected!"
            )
            exit(1) if crash_on_warning else None

431
432
433
434
    def run_step(self):
        """Execute one scheduling step.

        A new prefill batch always takes priority. Otherwise, if requests are
        running, up to ``global_config.num_continue_decode_steps`` decode
        iterations run back-to-back to reduce overhead. The loop stops early
        when the running batch drains, or when outputs are queued and the batch
        contains a streaming request. When there is nothing to run at all, the
        memory pools are checked for leaks and the new-token ratio is reset.
        """
        prefill_batch = self.get_new_batch_prefill()
        if prefill_batch is not None:
            # To profile this step, wrap run_batch with
            # pytorch_profile("profile_prefill_step", ...).
            prefill_result = self.run_batch(prefill_batch)
            self.process_batch_result(prefill_batch, prefill_result)
            return

        if self.running_batch is None:
            # Fully idle.
            self.check_memory()
            self.new_token_ratio = global_config.init_new_token_ratio
            return

        for _ in range(global_config.num_continue_decode_steps):
            decode_batch = self.get_new_batch_decode()
            if decode_batch:
                # To profile this step, wrap run_batch with
                # pytorch_profile("profile_decode_step", ...).
                decode_result = self.run_batch(decode_batch)
                self.process_batch_result(decode_batch, decode_result)

            if self.running_batch.is_empty():
                self.running_batch = None
                break

            # Give send_results a chance to stream queued outputs promptly.
            if self.out_pyobjs and self.running_batch.has_stream:
                break

    def get_new_batch_prefill(self) -> Optional[ScheduleBatch]:
        """Select requests for the next prefill (extend) batch.

        Returns None when:
        - prefill is not allowed (batch already marked full, or the waiting
          queue is empty), and there is no inflight chunked request;
        - the running batch already holds ``max_running_requests``;
        - no request could be admitted.

        Otherwise the method:
        - builds a ``ScheduleBatch`` from the admitted requests, including the
          next chunk of the inflight chunked-prefill request if any, and
          prepares it for extend;
        - with mixed chunked prefill, merges the current running batch into it
          and clears ``self.running_batch``.

        Side effects:
        - sets ``self.batch_is_full`` when the running-request or LoRA limit is
          hit, or when the adder reports ``AddReqResult.NO_TOKEN``;
        - removes admitted requests from ``self.waiting_queue``;
        - updates ``self.current_inflight_req``;
        - logs batch statistics on rank 0.
        """
        # Handle the cases where prefill is not allowed
        if (
            self.batch_is_full or len(self.waiting_queue) == 0
        ) and self.current_inflight_req is None:
            return None

        running_bs = (
            len(self.running_batch.reqs) if self.running_batch is not None else 0
        )
        if running_bs >= self.max_running_requests:
            self.batch_is_full = True
            return None

        # Get priority queue
        prefix_computed = self.policy.calc_priority(self.waiting_queue)

        # Prefill policy
        num_mixed_running = running_bs if self.is_mixed_chunk else 0
        adder = PrefillAdder(
            self.tree_cache,
            self.running_batch,
            self.new_token_ratio,
            self.token_to_kv_pool.available_size() + self.tree_cache.evictable_size(),
            self.max_prefill_tokens,
            self.chunked_prefill_size,
            num_mixed_running,
        )

        has_inflight = self.current_inflight_req is not None
        if has_inflight:
            # Continue the chunked request first; add_inflight_req returns it
            # again only if it still needs more chunks.
            self.current_inflight_req.init_next_round_input(
                None if prefix_computed else self.tree_cache
            )
            self.current_inflight_req = adder.add_inflight_req(
                self.current_inflight_req
            )

        if self.lora_paths is not None:
            lora_set = (
                {r.lora_path for r in self.running_batch.reqs}
                if self.running_batch is not None
                else set()
            )

        for req in self.waiting_queue:
            # Stop if admitting this request would exceed the LoRA adapter limit.
            if (
                self.lora_paths is not None
                and len(
                    lora_set
                    | {r.lora_path for r in adder.can_run_list}
                    | {req.lora_path}
                )
                > self.max_loras_per_batch
            ):
                self.batch_is_full = True
                break

            if running_bs + len(adder.can_run_list) >= self.max_running_requests:
                self.batch_is_full = True
                break

            req.init_next_round_input(None if prefix_computed else self.tree_cache)
            res = adder.add_one_req(req)
            if res != AddReqResult.CONTINUE:
                if res == AddReqResult.NO_TOKEN:
                    self.batch_is_full = True
                break

        can_run_list = adder.can_run_list

        if adder.new_inflight_req is not None:
            assert self.current_inflight_req is None
            self.current_inflight_req = adder.new_inflight_req

        if len(can_run_list) == 0:
            return None

        self.waiting_queue = [x for x in self.waiting_queue if x not in can_run_list]

        # Print stats
        if self.tp_rank == 0:
            if isinstance(self.tree_cache, RadixCache):
                # Token counters are accumulated in units of 1e9 tokens.
                self.tree_cache_metrics["total"] += (
                    adder.log_input_tokens + adder.log_hit_tokens
                ) / 10**9
                self.tree_cache_metrics["hit"] += (adder.log_hit_tokens) / 10**9
                total = self.tree_cache_metrics["total"]
                # Guard the very first batches, where no tokens may be logged yet.
                tree_cache_hit_rate = (
                    self.tree_cache_metrics["hit"] / total if total > 0 else 0.0
                )
            else:
                tree_cache_hit_rate = 0.0

            num_used = self.max_total_num_tokens - (
                self.token_to_kv_pool.available_size()
                + self.tree_cache.evictable_size()
            )

            if num_mixed_running > 0:
                logger.info(
                    f"Prefill batch "
                    f"(mixed #running-req: {num_mixed_running}). "
                    f"#new-seq: {len(can_run_list)}, "
                    f"#new-token: {adder.log_input_tokens}, "
                    f"#cached-token: {adder.log_hit_tokens}, "
                    f"cache hit rate: {100.0 * tree_cache_hit_rate:.2f}%, "
                    f"token usage: {num_used / self.max_total_num_tokens:.2f}, "
                    f"#queue-req: {len(self.waiting_queue) + has_inflight}"
                )
            else:
                logger.info(
                    f"Prefill batch. "
                    f"#new-seq: {len(can_run_list)}, "
                    f"#new-token: {adder.log_input_tokens}, "
                    f"#cached-token: {adder.log_hit_tokens}, "
                    f"cache hit rate: {100.0 * tree_cache_hit_rate:.2f}%, "
                    f"token usage: {num_used / self.max_total_num_tokens:.2f}, "
                    f"#running-req: {running_bs}, "
                    f"#queue-req: {len(self.waiting_queue) + has_inflight}"
                )

        # Create a new batch
        new_batch = ScheduleBatch.init_new(
            can_run_list,
            self.req_to_token_pool,
            self.token_to_kv_pool,
            self.tree_cache,
        )
        new_batch.prepare_for_extend(self.model_config.vocab_size)

        # Mixed-style chunked prefill: fold the running decode requests into
        # this extend batch.
        decoding_reqs = []
        if self.is_mixed_chunk and self.running_batch is not None:
            self.running_batch.prepare_for_decode()
            new_batch.mix_with_running(self.running_batch)
            decoding_reqs = self.running_batch.reqs
            self.running_batch = None
        new_batch.decoding_reqs = decoding_reqs

        return new_batch

    def get_new_batch_decode(self) -> Optional[ScheduleBatch]:
        """Prepare ``self.running_batch`` for the next decode step.

        The new-token ratio is handled as follows:
        - if ``check_decode_mem()`` reports insufficient memory, requests are
          retracted to the waiting queue, and ``new_token_ratio`` is replaced
          by the value returned from ``retract_decode``;
        - otherwise the ratio decays toward ``min_new_token_ratio``.

        Unless jump-forward is disabled, requests returned by
        ``check_for_jump_forward`` are also moved to the waiting queue.

        Returns:
            The batch prepared for decoding, or None if jump-forward emptied it.
        """
        batch = self.running_batch

        if batch.check_decode_mem():
            self.new_token_ratio = max(
                self.new_token_ratio - self.new_token_ratio_decay,
                self.min_new_token_ratio,
            )
        else:
            ratio_before = self.new_token_ratio
            retracted_reqs, self.new_token_ratio = batch.retract_decode()
            logger.info(
                "Decode out of memory happened. "
                f"#retracted_reqs: {len(retracted_reqs)}, "
                f"#new_token_ratio: {ratio_before:.4f} -> {self.new_token_ratio:.4f}"
            )
            self.waiting_queue.extend(retracted_reqs)

        if not self.disable_regex_jump_forward:
            self.waiting_queue.extend(
                batch.check_for_jump_forward(self.pad_input_ids_func)
            )
            if batch.is_empty():
                return None

        batch.prepare_for_decode()
        return batch

    def run_batch(self, batch: ScheduleBatch):
645
        if self.is_generation:
Lianmin Zheng's avatar
Lianmin Zheng committed
646
            if batch.forward_mode.is_decode() or batch.extend_num_tokens != 0:
647
                model_worker_batch = batch.get_model_worker_batch()
648
                logits_output, next_token_ids = self.tp_worker.forward_batch_generation(
649
                    model_worker_batch
650
                )
Lianmin Zheng's avatar
Lianmin Zheng committed
651
652
653
            else:
                logits_output = None
                if self.tokenizer is not None:
654
655
656
                    next_token_ids = torch.full(
                        (batch.batch_size(),), self.tokenizer.eos_token_id
                    )
Lianmin Zheng's avatar
Lianmin Zheng committed
657
                else:
658
                    next_token_ids = torch.full((batch.batch_size(),), 0)
Lianmin Zheng's avatar
Lianmin Zheng committed
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
            return logits_output, next_token_ids
        else:  # embedding or reward model
            assert batch.extend_num_tokens != 0
            model_worker_batch = batch.get_model_worker_batch()
            embeddings = self.tp_worker.forward_batch_embedding(model_worker_batch)
            return embeddings

    def process_batch_result(self, batch: ScheduleBatch, result):
        if batch.forward_mode.is_decode():
            self.process_batch_result_decode(batch, result)
        else:
            self.process_batch_result_prefill(batch, result)

    def process_batch_result_prefill(self, batch: ScheduleBatch, result):
        if self.is_generation:
            logits_output, next_token_ids = result
            batch.sampling_info.penalizer_orchestrator.cumulate_output_tokens(
                next_token_ids
            )
678

Lianmin Zheng's avatar
Lianmin Zheng committed
679
            if logits_output:
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
                # Move logprobs to cpu
                if logits_output.next_token_logprobs is not None:
                    logits_output.next_token_logprobs = (
                        logits_output.next_token_logprobs[
                            torch.arange(
                                len(next_token_ids), device=next_token_ids.device
                            ),
                            next_token_ids,
                        ].tolist()
                    )
                    logits_output.input_token_logprobs = (
                        logits_output.input_token_logprobs.tolist()
                    )
                    logits_output.normalized_prompt_logprobs = (
                        logits_output.normalized_prompt_logprobs.tolist()
                    )

Lianmin Zheng's avatar
Lianmin Zheng committed
697
            next_token_ids = next_token_ids.tolist()
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714

            # Check finish conditions
            logprob_pt = 0
            for i, req in enumerate(batch.reqs):
                if req is not self.current_inflight_req:
                    # Inflight reqs' prefill is not finished
                    req.completion_tokens_wo_jump_forward += 1
                    req.output_ids.append(next_token_ids[i])
                    req.check_finished()

                if req.regex_fsm is not None:
                    req.regex_fsm_state = req.regex_fsm.get_next_state(
                        req.regex_fsm_state, next_token_ids[i]
                    )

                if req.finished():
                    self.tree_cache.cache_finished_req(req)
Lianmin Zheng's avatar
Lianmin Zheng committed
715
                elif req not in batch.decoding_reqs:
716
717
718
719
720
721
722
723
724
725
726
                    # To reduce overhead, only cache prefill reqs
                    self.tree_cache.cache_unfinished_req(req)

                if req is self.current_inflight_req:
                    # Inflight request would get a new req idx
                    self.req_to_token_pool.free(req.req_pool_idx)

                if req.return_logprob:
                    logprob_pt += self.add_logprob_return_values(
                        i, req, logprob_pt, next_token_ids, logits_output
                    )
Lianmin Zheng's avatar
Lianmin Zheng committed
727
        else:  # embedding or reward model
728
            assert batch.extend_num_tokens != 0
Lianmin Zheng's avatar
Lianmin Zheng committed
729
            embeddings = result
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750

            # Check finish conditions
            for i, req in enumerate(batch.reqs):
                req.embedding = embeddings[i]
                if req is not self.current_inflight_req:
                    # Inflight reqs' prefill is not finished
                    # dummy output token for embedding models
                    req.output_ids.append(0)
                    req.check_finished()

                if req.finished():
                    self.tree_cache.cache_finished_req(req)
                else:
                    self.tree_cache.cache_unfinished_req(req)

                if req is self.current_inflight_req:
                    # Inflight request would get a new req idx
                    self.req_to_token_pool.free(req.req_pool_idx)

        self.handle_finished_requests(batch)

751
752
753
754
755
756
        if not batch.is_empty():
            if self.running_batch is None:
                self.running_batch = batch
            else:
                self.running_batch.merge_batch(batch)

    def process_batch_result_decode(self, batch: ScheduleBatch, result):
        """Consume the output of one decode step for ``batch``.

        ``result`` is a ``(logits_output, next_token_ids)`` pair. Every request
        in ``batch.reqs`` receives the token sampled for it. Finished requests are
        written to the tree cache, and logprobs are recorded for requests that
        asked for them. The batch is then passed to ``handle_finished_requests``.
        On TP rank 0, decode statistics are printed every 40 decode steps.
        """
        logits_output, next_token_ids = result
        batch.sampling_info.penalizer_orchestrator.cumulate_output_tokens(
            next_token_ids
        )
        self.num_generated_tokens += len(batch.reqs)

        # Pick element (row k, column next_token_ids[k]) for every row and move
        # the gathered logprobs to the CPU as a Python list.
        if logits_output.next_token_logprobs is not None:
            row_idx = torch.arange(len(next_token_ids), device=next_token_ids.device)
            sampled_logprobs = logits_output.next_token_logprobs[
                row_idx, next_token_ids
            ].tolist()

        token_id_list = next_token_ids.tolist()

        # Check finish condition
        for idx, (req, token_id) in enumerate(zip(batch.reqs, token_id_list)):
            req.completion_tokens_wo_jump_forward += 1
            req.output_ids.append(token_id)
            req.check_finished()

            fsm = req.regex_fsm
            if fsm is not None:
                req.regex_fsm_state = fsm.get_next_state(req.regex_fsm_state, token_id)

            if req.finished():
                self.tree_cache.cache_finished_req(req)

            if not req.return_logprob:
                continue
            req.output_token_logprobs.append((sampled_logprobs[idx], token_id))
            if req.top_logprobs_num > 0:
                req.output_top_logprobs.append(logits_output.output_top_logprobs[idx])

        self.handle_finished_requests(batch)

        # The step counter wraps so that it stays bounded.
        self.decode_forward_ct = (self.decode_forward_ct + 1) % (1 << 30)
        should_print = self.tp_rank == 0 and self.decode_forward_ct % 40 == 0
        if should_print:
            self.print_decode_stats()

    def add_logprob_return_values(
        self,
        i: int,
        req: Req,
        pt: int,
        next_token_ids: List[int],
        output: LogitsProcessorOutput,
    ):
        """Attach the logprobs of one request in an extend batch to ``req``.

        Args:
            i: Position of ``req`` in the batch. It selects the per-request
                entries of ``output`` and ``next_token_ids``.
            req: Request that is updated in place.
            pt: Offset of this request's first entry in the flattened
                ``output.input_token_logprobs``.
            next_token_ids: Token sampled for each request of the batch.
            output: Logits-processor output of the extend forward pass.

        Returns:
            ``req.extend_input_len - req.extend_logprob_start_len``, the number
            of ``output.input_token_logprobs`` entries this request occupies.
            The caller advances ``pt`` by this amount.

        When ``req.last_update_decode_tokens`` is non-zero, the last that-many
        extended tokens are previously decoded tokens being re-computed. Their
        logprobs are appended to ``output_token_logprobs`` before the logprob of
        the newly sampled token. This keeps the list in token order and in step
        with ``output_top_logprobs``, which is extended in the same order.
        """
        # If logprob_start_len > 0, then first logprob_start_len prompt tokens will be ignored.
        num_input_logprobs = req.extend_input_len - req.extend_logprob_start_len

        if req.normalized_prompt_logprob is None:
            req.normalized_prompt_logprob = output.normalized_prompt_logprobs[i]

        if req.input_token_logprobs is None:
            # Logprob entry pt + k is paired with the token at scored position
            # k + 1, so the first scored token gets no pair here. The trailing
            # re-computed decode tokens are excluded from both slices.
            input_token_logprobs = output.input_token_logprobs[
                pt : pt + num_input_logprobs - 1 - req.last_update_decode_tokens
            ]
            input_token_ids = req.fill_ids[
                len(req.fill_ids)
                - num_input_logprobs
                + 1 : len(req.fill_ids)
                - req.last_update_decode_tokens
            ]
            req.input_token_logprobs = list(zip(input_token_logprobs, input_token_ids))

            if (
                req.logprob_start_len == 0
            ):  # The first token does not have logprob, pad it.
                req.input_token_logprobs = [
                    (None, req.fill_ids[0])
                ] + req.input_token_logprobs

        if req.last_update_decode_tokens != 0:
            # Some decode tokens are re-computed in an extend batch
            req.output_token_logprobs.extend(
                list(
                    zip(
                        output.input_token_logprobs[
                            pt
                            + num_input_logprobs
                            - 1
                            - req.last_update_decode_tokens : pt
                            + num_input_logprobs
                            - 1
                        ],
                        req.fill_ids[
                            len(req.fill_ids)
                            - req.last_update_decode_tokens : len(req.fill_ids)
                        ],
                    )
                )
            )

        # Logprob of the token sampled after this extend pass. It is appended
        # after the re-computed decode tokens above so that the list stays in
        # token order, matching output_top_logprobs below.
        req.output_token_logprobs.append(
            (output.next_token_logprobs[i], next_token_ids[i])
        )

        if req.top_logprobs_num > 0:
            if req.input_top_logprobs is None:
                req.input_top_logprobs = output.input_top_logprobs[i]
                if req.logprob_start_len == 0:
                    req.input_top_logprobs = [None] + req.input_top_logprobs

            if req.last_update_decode_tokens != 0:
                req.output_top_logprobs.extend(
                    output.input_top_logprobs[i][-req.last_update_decode_tokens :]
                )
            req.output_top_logprobs.append(output.output_top_logprobs[i])

        return num_input_logprobs

    def handle_finished_requests(self, batch: ScheduleBatch):
        """Emit per-request outputs for ``batch`` and drop finished requests.

        A request adds an entry to the outgoing message in two cases:

        - it has finished, or
        - it streams, and either the decode step counter is a multiple of
          ``stream_interval`` or the request has exactly one output token.

        Generation models send a ``BatchTokenIDOut``; embedding/reward models
        send a ``BatchEmbeddingOut``. The message is appended to
        ``self.out_pyobjs`` only if at least one request contributed.

        Every request that is finished, or is the current in-flight request, is
        filtered out of ``batch`` and sets ``self.batch_is_full`` to False.
        """
        rids = []
        meta_infos = []
        finish_reasons: List[BaseFinishReason] = []
        if self.is_generation:
            vids = []
            texts = []
            read_ids_per_req = []
            read_offsets = []
            skip_special_flags = []
            spaces_between_special_flags = []
        else:  # embedding or reward model
            embeddings = []
        kept_indices = []

        for idx, req in enumerate(batch.reqs):
            if req.finished() or req is self.current_inflight_req:
                # This request is not kept in the batch.
                self.batch_is_full = False
            else:
                kept_indices.append(idx)

            # Unfinished requests only report on streaming steps.
            if not req.finished():
                if not req.stream:
                    continue
                if (
                    self.decode_forward_ct % self.stream_interval != 0
                    and len(req.output_ids) != 1
                ):
                    continue

            rids.append(req.rid)
            finish_reasons.append(req.finished_reason)

            if not self.is_generation:  # embedding or reward model
                embeddings.append(req.embedding)
                meta_infos.append({"prompt_tokens": len(req.origin_input_ids)})
                continue

            vids.append(req.vid)
            texts.append(req.decoded_text)
            read_ids, read_offset = req.init_incremental_detokenize()
            read_ids_per_req.append(read_ids)
            read_offsets.append(read_offset)
            params = req.sampling_params
            skip_special_flags.append(params.skip_special_tokens)
            spaces_between_special_flags.append(params.spaces_between_special_tokens)

            reason = req.finished_reason
            meta_info = {
                "prompt_tokens": len(req.origin_input_ids),
                "completion_tokens": len(req.output_ids),
                "completion_tokens_wo_jump_forward": req.completion_tokens_wo_jump_forward,
                "finish_reason": None if reason is None else reason.to_json(),
            }
            if req.return_logprob:
                meta_info["input_token_logprobs"] = req.input_token_logprobs
                meta_info["output_token_logprobs"] = req.output_token_logprobs
                meta_info["input_top_logprobs"] = req.input_top_logprobs
                meta_info["output_top_logprobs"] = req.output_top_logprobs
                meta_info["normalized_prompt_logprob"] = req.normalized_prompt_logprob
            meta_infos.append(meta_info)

        # Send to detokenizer
        if rids:
            if self.is_generation:
                message = BatchTokenIDOut(
                    rids,
                    vids,
                    texts,
                    read_ids_per_req,
                    read_offsets,
                    skip_special_flags,
                    spaces_between_special_flags,
                    meta_infos,
                    finish_reasons,
                )
            else:  # embedding or reward model
                message = BatchEmbeddingOut(
                    rids,
                    embeddings,
                    meta_infos,
                    finish_reasons,
                )
            self.out_pyobjs.append(message)

        # Remove finished reqs: update batch tensors
        batch.filter_batch(kept_indices)

    def flush_cache(self):
        """Reset all KV/prefix caches if the scheduler is idle.

        The caches are flushed only when the waiting queue is empty and there
        is no running batch, or the running batch holds no requests. Otherwise
        a warning is logged and nothing is touched.

        Returns:
            True if the caches were flushed, False otherwise.
        """
        num_running = 0 if self.running_batch is None else len(self.running_batch.reqs)

        if len(self.waiting_queue) == 0 and num_running == 0:
            self.tree_cache.reset()
            self.tree_cache_metrics = {"total": 0, "hit": 0}
            self.regex_fsm_cache.reset()
            self.req_to_token_pool.clear()
            self.token_to_kv_pool.clear()
            torch.cuda.empty_cache()
            logger.info("Cache flushed successfully!")
            return True

        # Log through the module logger (not the root `logging` functions) so
        # the warning uses the same handlers/formatting as the success message.
        logger.warning(
            "Cache not flushed because there are pending requests. "
            f"#queue-req: {len(self.waiting_queue)}, "
            f"#running-req: {num_running}"
        )
        return False

    def abort_request(self, recv_req: AbortReq):
        """Abort the request whose rid equals ``recv_req.rid``.

        The first matching request in the waiting queue is removed from it.
        The first matching request in the running batch, if any, gets its
        ``finished_reason`` set to ``FINISH_ABORT()``.
        """
        target_rid = recv_req.rid

        # Waiting queue: drop the first request with this rid.
        match_pos = next(
            (pos for pos, req in enumerate(self.waiting_queue) if req.rid == target_rid),
            None,
        )
        if match_pos is not None:
            del self.waiting_queue[match_pos]

        # Running batch: flag the first request with this rid as aborted.
        if not self.running_batch:
            return
        for req in self.running_batch.reqs:
            if req.rid == target_rid:
                req.finished_reason = FINISH_ABORT()
                return

    def update_weights(self, recv_req: UpdateWeightReqInput):
        """Load new weights through the TP worker and flush caches on success.

        Returns:
            The ``(success, message)`` pair reported by
            ``self.tp_worker.update_weights``. On failure, ``message`` is
            logged as an error.
        """
        success, message = self.tp_worker.update_weights(recv_req)
        if not success:
            logger.error(message)
            return success, message

        # After a successful update the caches must be flushed; a failed flush
        # is treated as an invariant violation.
        flushed = self.flush_cache()
        assert flushed, "Cache flush failed after updating weights"
        return success, message

    def start_profile(self) -> None:
        """Start the torch profiler.

        Raises:
            RuntimeError: If profiling is not enabled (``self.profiler`` is None).
        """
        if self.profiler is not None:
            self.profiler.start()
            return
        raise RuntimeError("Profiler is not enabled.")

    def stop_profile(self) -> None:
        """Stop the torch profiler and export a Chrome trace.

        The trace path is ``<torch_profiler_trace_dir>/<time.time()>.trace.json.gz``.

        Raises:
            RuntimeError: If profiling is not enabled (``self.profiler`` is None).
        """
        if self.profiler is None:
            raise RuntimeError("Profiler is not enabled.")
        self.profiler.stop()
        trace_name = str(time.time()) + ".trace.json.gz"
        trace_path = "/".join((self.torch_profiler_trace_dir, trace_name))
        self.profiler.export_chrome_trace(trace_path)
        logger.info("Profiler is done")


def run_scheduler_process(
    server_args: ServerArgs,
    port_args: PortArgs,
    gpu_id: int,
    tp_rank: int,
    pipe_writer,
):
    """Entry point of the scheduler process for one tensor-parallel rank.

    Sets up logging and builds a ``Scheduler``. Once construction succeeds,
    it sends ``"ready"`` through ``pipe_writer`` and runs the event loop. Any
    exception raised by construction, the ready signal, or the loop is logged
    with its traceback, and then ``kill_parent_process()`` is called.
    """
    log_prefix = f" TP{tp_rank}"
    configure_logger(server_args, prefix=log_prefix)
    suppress_other_loggers()

    try:
        worker = Scheduler(server_args, port_args, gpu_id, tp_rank)
        pipe_writer.send("ready")
        worker.event_loop()
    except Exception:
        traceback_text = get_exception_traceback()
        logger.error(traceback_text)
        kill_parent_process()