scheduler.py 44.4 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
"""
Copyright 2023-2024 SGLang Team
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""

"""A scheduler that manages a tensor parallel GPU worker."""

18
import json
19
import logging
20
import os
Lianmin Zheng's avatar
Lianmin Zheng committed
21
import threading
22
23
import time
import warnings
Lianmin Zheng's avatar
Lianmin Zheng committed
24
from collections import deque
25
from types import SimpleNamespace
26
from typing import List, Optional, Union
27

28
import torch
29
30
import zmq

31
32
from sglang.global_config import global_config
from sglang.srt.configs.model_config import ModelConfig
33
from sglang.srt.constrained.grammar import GrammarCache
34
35
36
37
38
39
40
from sglang.srt.hf_transformers_utils import get_processor, get_tokenizer
from sglang.srt.layers.logits_processor import LogitsProcessorOutput
from sglang.srt.managers.io_struct import (
    AbortReq,
    BatchEmbeddingOut,
    BatchTokenIDOut,
    FlushCacheReq,
41
42
    GetMemPoolSizeReq,
    GetMemPoolSizeReqOutput,
43
    ProfileReq,
44
45
46
47
48
49
50
51
52
53
54
55
    TokenizedEmbeddingReqInput,
    TokenizedGenerateReqInput,
    TokenizedRewardReqInput,
    UpdateWeightReqInput,
    UpdateWeightReqOutput,
)
from sglang.srt.managers.schedule_batch import (
    FINISH_ABORT,
    BaseFinishReason,
    ImageInputs,
    Req,
    ScheduleBatch,
56
    global_server_args_dict,
57
)
58
59
60
61
62
from sglang.srt.managers.schedule_policy import (
    AddReqResult,
    PrefillAdder,
    SchedulePolicy,
)
63
from sglang.srt.managers.tp_worker import TpModelWorker
64
from sglang.srt.managers.tp_worker_overlap_thread import TpModelWorkerClient
65
66
from sglang.srt.mem_cache.chunk_cache import ChunkCache
from sglang.srt.mem_cache.radix_cache import RadixCache
67
from sglang.srt.server_args import PortArgs, ServerArgs
68
69
70
from sglang.srt.utils import (
    broadcast_pyobj,
    configure_logger,
71
    get_zmq_socket,
72
73
74
75
76
77
    is_generation_model,
    is_multimodal_model,
    kill_parent_process,
    set_random_seed,
    suppress_other_loggers,
)
78
79
80
81
from sglang.utils import get_exception_traceback

logger = logging.getLogger(__name__)

# When SGLANG_IS_IN_CI is "true", leak warnings raised by the scheduler
# terminate the process instead of only being reported.
crash_on_warning = os.environ.get("SGLANG_IS_IN_CI", "false") == "true"

# When SGLANG_TEST_RETRACT is "true", decode retraction is forced for
# sufficiently large batches so that the retract path can be exercised.
test_retract = os.environ.get("SGLANG_TEST_RETRACT", "false") == "true"

88
89
90
91
92
93
94
95
96
97

class Scheduler:
    """A scheduler that manages a tensor parallel GPU worker."""

    def __init__(
        self,
        server_args: ServerArgs,
        port_args: PortArgs,
        gpu_id: int,
        tp_rank: int,
        dp_rank: Optional[int],
    ):
        """Build the scheduler state and launch its tensor parallel worker.

        Args:
            server_args: Server configuration; scheduling, tokenizer, cache,
                LoRA and watchdog settings are read from it.
            port_args: Supplies the IPC endpoint names used for the ZMQ
                sockets and the NCCL port handed to the worker.
            gpu_id: Forwarded unchanged to the worker.
            tp_rank: Tensor parallel rank of this scheduler. Only rank 0
                opens real ZMQ sockets; other ranks get a no-op sender.
            dp_rank: Forwarded unchanged to the worker.
        """
        # Parse args
        self.server_args = server_args
        self.tp_rank = tp_rank
        self.tp_size = server_args.tp_size
        self.schedule_policy = server_args.schedule_policy
        self.disable_regex_jump_forward = server_args.disable_regex_jump_forward
        self.lora_paths = server_args.lora_paths
        self.max_loras_per_batch = server_args.max_loras_per_batch
        self.enable_overlap = server_args.enable_overlap_schedule
        self.skip_tokenizer_init = server_args.skip_tokenizer_init

        # Init inter-process communication
        context = zmq.Context(2)

        if self.tp_rank == 0:
            self.recv_from_tokenizer = get_zmq_socket(
                context, zmq.PULL, port_args.scheduler_input_ipc_name
            )

            if server_args.skip_tokenizer_init:
                # Directly send to the tokenizer/api
                self.send_to_detokenizer = get_zmq_socket(
                    context, zmq.PUSH, port_args.tokenizer_ipc_name
                )
            else:
                # Send to the detokenizer
                self.send_to_detokenizer = get_zmq_socket(
                    context, zmq.PUSH, port_args.detokenizer_ipc_name
                )
        else:
            # Non-zero ranks never talk to the tokenizer/detokenizer directly;
            # give them a sender whose send_pyobj silently drops the payload.
            self.recv_from_tokenizer = None
            self.send_to_detokenizer = SimpleNamespace(send_pyobj=lambda x: None)

        # Init tokenizer
        self.model_config = ModelConfig(
            server_args.model_path,
            server_args.trust_remote_code,
            context_length=server_args.context_length,
            model_override_args=json.loads(server_args.json_model_override_args),
        )

        if server_args.skip_tokenizer_init:
            self.tokenizer = self.processor = None
        else:
            if is_multimodal_model(self.model_config.hf_config.architectures):
                self.processor = get_processor(
                    server_args.tokenizer_path,
                    tokenizer_mode=server_args.tokenizer_mode,
                    trust_remote_code=server_args.trust_remote_code,
                )
                self.tokenizer = self.processor.tokenizer
            else:
                # Always define the attribute so that later reads cannot hit
                # an AttributeError for text-only models.
                self.processor = None
                self.tokenizer = get_tokenizer(
                    server_args.tokenizer_path,
                    tokenizer_mode=server_args.tokenizer_mode,
                    trust_remote_code=server_args.trust_remote_code,
                )
        self.is_generation = is_generation_model(
            self.model_config.hf_config.architectures, self.server_args.is_embedding
        )

        # Launch a tensor parallel worker
        if self.enable_overlap:
            TpWorkerClass = TpModelWorkerClient
        else:
            TpWorkerClass = TpModelWorker

        self.tp_worker = TpWorkerClass(
            server_args=server_args,
            gpu_id=gpu_id,
            tp_rank=tp_rank,
            dp_rank=dp_rank,
            nccl_port=port_args.nccl_port,
        )

        # Get token and memory info from the model worker
        (
            self.max_total_num_tokens,
            self.max_prefill_tokens,
            self.max_running_requests,
            self.max_req_len,
            self.max_req_input_len,
            self.random_seed,
            self.device,
            worker_global_server_args_dict,
            _,
            _,
            _,
        ) = self.tp_worker.get_worker_info()
        self.tp_cpu_group = self.tp_worker.get_tp_cpu_group()
        self.pad_input_ids_func = self.tp_worker.get_pad_input_ids_func()
        global_server_args_dict.update(worker_global_server_args_dict)
        set_random_seed(self.random_seed)

        # Print debug info
        logger.info(
            f"max_total_num_tokens={self.max_total_num_tokens}, "
            f"max_prefill_tokens={self.max_prefill_tokens}, "
            f"max_running_requests={self.max_running_requests}, "
            f"context_len={self.model_config.context_len}"
        )

        # Init memory pool and cache
        self.req_to_token_pool, self.token_to_kv_pool = self.tp_worker.get_memory_pool()

        if (
            server_args.chunked_prefill_size is not None
            and server_args.disable_radix_cache
        ):
            self.tree_cache = ChunkCache(
                req_to_token_pool=self.req_to_token_pool,
                token_to_kv_pool=self.token_to_kv_pool,
            )
        else:
            self.tree_cache = RadixCache(
                req_to_token_pool=self.req_to_token_pool,
                token_to_kv_pool=self.token_to_kv_pool,
                disable=server_args.disable_radix_cache,
            )
        self.tree_cache_metrics = {"total": 0, "hit": 0}
        self.policy = SchedulePolicy(self.schedule_policy, self.tree_cache)

        # Init running status
        self.waiting_queue: List[Req] = []
        self.running_batch: Optional[ScheduleBatch] = None
        self.cur_batch: Optional[ScheduleBatch] = None
        self.forward_ct = 0
        self.forward_ct_decode = 0
        self.num_generated_tokens = 0
        self.last_stats_tic = time.time()
        self.stream_interval = server_args.stream_interval

        # Init chunked prefill
        self.chunked_prefill_size = server_args.chunked_prefill_size
        self.being_chunked_req = None
        self.is_mixed_chunk = (
            self.chunked_prefill_size is not None and server_args.enable_mixed_chunk
        )

        # Init the FSM cache for constrained generation
        self.grammar_cache = None

        if not server_args.skip_tokenizer_init:
            self.grammar_cache = GrammarCache(
                server_args.tokenizer_path,
                {
                    "tokenizer_mode": server_args.tokenizer_mode,
                    "trust_remote_code": server_args.trust_remote_code,
                },
                skip_tokenizer_init=server_args.skip_tokenizer_init,
                whitespace_patterns=server_args.constrained_json_whitespace_pattern,
                backend=server_args.grammar_backend,
                allow_jump=not server_args.disable_regex_jump_forward,
            )

        # Init new token estimation
        assert (
            server_args.schedule_conservativeness >= 0
        ), "Invalid schedule_conservativeness"

        self.init_new_token_ratio = min(
            global_config.default_init_new_token_ratio
            * server_args.schedule_conservativeness,
            1.0,
        )
        self.min_new_token_ratio = min(
            self.init_new_token_ratio
            * global_config.default_min_new_token_ratio_factor,
            1.0,
        )
        # Per-step decrement applied in update_running_batch until the ratio
        # reaches min_new_token_ratio.
        self.new_token_ratio_decay = (
            self.init_new_token_ratio - self.min_new_token_ratio
        ) / global_config.default_new_token_ratio_decay_steps
        self.new_token_ratio = self.init_new_token_ratio

        self.batch_is_full = False

        # Init watchdog thread
        self.watchdog_timeout = server_args.watchdog_timeout
        t = threading.Thread(target=self.watchdog_thread, daemon=True)
        t.start()

        # Init profiler (enabled only when SGLANG_TORCH_PROFILER_DIR is set
        # to a non-empty value)
        profiler_trace_dir = os.getenv("SGLANG_TORCH_PROFILER_DIR", "")
        if profiler_trace_dir == "":
            self.profiler = None
        else:
            self.torch_profiler_trace_dir = profiler_trace_dir
            logger.info(
                "Profiling enabled. Traces will be saved to: %s",
                self.torch_profiler_trace_dir,
            )
            self.profiler = torch.profiler.profile(
                activities=[
                    torch.profiler.ProfilerActivity.CPU,
                    torch.profiler.ProfilerActivity.CUDA,
                ],
                with_stack=True,
            )

Lianmin Zheng's avatar
Lianmin Zheng committed
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
    def watchdog_thread(self):
        """Kill the parent process when the forward counter stops advancing.

        Runs forever in a background thread. While a batch is current
        (``self.cur_batch`` is set), ``self.forward_ct`` is sampled every
        half ``watchdog_timeout``; if it has not changed for longer than
        ``watchdog_timeout`` seconds an error is logged and
        ``kill_parent_process`` is invoked.
        """
        self.watchdog_last_forward_ct = 0
        self.watchdog_last_time = time.time()

        while True:
            if self.cur_batch is not None:
                if self.watchdog_last_forward_ct != self.forward_ct:
                    # Progress was made since the last sample: restart the timer
                    self.watchdog_last_forward_ct = self.forward_ct
                    self.watchdog_last_time = time.time()
                elif time.time() > self.watchdog_last_time + self.watchdog_timeout:
                    logger.error(f"Watchdog timeout ({self.watchdog_timeout=})")
                    break
            time.sleep(self.watchdog_timeout / 2)

        kill_parent_process()

Lianmin Zheng's avatar
Lianmin Zheng committed
316
    @torch.inference_mode()
    def event_loop_normal(self):
        """Run the blocking scheduler loop.

        Each iteration receives and dispatches requests, selects the next
        batch, runs it and processes its result before moving on. When no
        batch is available, memory is checked and the new-token ratio is
        reset to its initial value.
        """
        self.last_batch = None

        while True:
            self.process_input_requests(self.recv_requests())

            batch = self.get_next_batch_to_run()
            self.cur_batch = batch

            if not batch:
                # Idle iteration
                self.check_memory()
                self.new_token_ratio = self.init_new_token_ratio
                self.last_batch = batch
                continue

            result = self.run_batch(batch)
            self.process_batch_result(batch, result)

            # Decode multiple steps to reduce the overhead
            if batch.forward_mode.is_decode():
                extra_decode_steps = self.server_args.num_continuous_decode_steps - 1
                for _ in range(extra_decode_steps):
                    if not self.running_batch:
                        break
                    self.update_running_batch()
                    # The update may have emptied the running batch
                    if not self.running_batch:
                        break
                    result = self.run_batch(batch)
                    self.process_batch_result(batch, result)

            self.last_batch = batch
347

Lianmin Zheng's avatar
Lianmin Zheng committed
348
349
    @torch.inference_mode()
    def event_loop_overlap(self):
        """Run the scheduler loop that overlaps CPU processing and GPU computation.

        The result of a launched batch is queued together with a copy of the
        batch and only processed one iteration later, after the following
        batch has been launched.
        """
        pending_results = deque()

        self.last_batch = None
        self.running_batch = None

        while True:
            self.process_input_requests(self.recv_requests())

            batch = self.get_next_batch_to_run()
            self.cur_batch = batch
            if batch:
                result = self.run_batch(batch)
                batch_snapshot = batch.copy()
                pending_results.append((batch_snapshot, result))

            if self.last_batch:
                # Handle the output of the batch launched in the previous iteration
                prev_batch, prev_result = pending_results.popleft()
                self.process_batch_result(prev_batch, prev_result)
            elif batch is None:
                self.check_memory()
                self.new_token_ratio = self.init_new_token_ratio

            self.last_batch = batch

Lianmin Zheng's avatar
Lianmin Zheng committed
375
376
377
378
379
380
381
382
383
384
385
386
    def recv_requests(self):
        """Collect the requests that are currently pending.

        Rank 0 drains ``recv_from_tokenizer`` with non-blocking receives
        until a ``zmq.ZMQError`` is raised. When ``tp_size`` is not 1 the
        collected list is passed through ``broadcast_pyobj`` together with
        the rank and the TP CPU group.

        Returns:
            The list collected on rank 0 (``None`` on other ranks), or the
            value returned by ``broadcast_pyobj`` when ``tp_size`` is not 1.
        """
        pending = None
        if self.tp_rank == 0:
            pending = []
            while True:
                try:
                    pending.append(self.recv_from_tokenizer.recv_pyobj(zmq.NOBLOCK))
                except zmq.ZMQError:
                    # Nothing more can be received without blocking
                    break

        if self.tp_size == 1:
            return pending
        return broadcast_pyobj(pending, self.tp_rank, self.tp_cpu_group)

Lianmin Zheng's avatar
Lianmin Zheng committed
392
    def process_input_requests(self, recv_reqs: List):
        """Dispatch every received request object to its handler.

        Args:
            recv_reqs: Request objects as returned by ``recv_requests``.

        Raises:
            ValueError: If an object is of none of the handled request types.
        """
        for incoming in recv_reqs:
            if isinstance(incoming, TokenizedGenerateReqInput):
                self.handle_generate_request(incoming)
                continue

            if isinstance(
                incoming, (TokenizedEmbeddingReqInput, TokenizedRewardReqInput)
            ):
                self.handle_embedding_request(incoming)
                continue

            if isinstance(incoming, FlushCacheReq):
                self.flush_cache()
                continue

            if isinstance(incoming, AbortReq):
                self.abort_request(incoming)
                continue

            if isinstance(incoming, UpdateWeightReqInput):
                # Report the outcome of the weight update back over the socket
                success, message = self.update_weights(incoming)
                reply = UpdateWeightReqOutput(success, message)
                self.send_to_detokenizer.send_pyobj(reply)
                continue

            if isinstance(incoming, ProfileReq):
                if incoming == ProfileReq.START_PROFILE:
                    self.start_profile()
                else:
                    self.stop_profile()
                continue

            if isinstance(incoming, GetMemPoolSizeReq):
                reply = GetMemPoolSizeReqOutput(self.max_total_num_tokens)
                self.send_to_detokenizer.send_pyobj(reply)
                continue

            raise ValueError(f"Invalid request: {incoming}")

    def handle_generate_request(
        self,
        recv_req: TokenizedGenerateReqInput,
    ):
        """Turn a tokenized generation request into a ``Req`` and enqueue it.

        Copies the logprob/stream options, attaches image inputs and a
        grammar when requested, truncates over-long prompts and clamps
        ``max_new_tokens`` before appending the request to the waiting queue.

        Args:
            recv_req: The tokenized generation request to schedule.
        """
        req = Req(
            recv_req.rid,
            recv_req.input_text,
            recv_req.input_ids,
            recv_req.sampling_params,
            lora_path=recv_req.lora_path,
        )
        req.tokenizer = self.tokenizer

        # Image inputs
        if recv_req.image_inputs is not None:
            vocab_size = self.model_config.vocab_size
            req.image_inputs = ImageInputs.from_dict(recv_req.image_inputs, vocab_size)
            req.origin_input_ids = self.pad_input_ids_func(
                req.origin_input_ids_unpadded, req.image_inputs
            )

        req.return_logprob = recv_req.return_logprob
        req.top_logprobs_num = recv_req.top_logprobs_num
        req.stream = recv_req.stream

        logprob_start_len = recv_req.logprob_start_len
        if logprob_start_len == -1:
            # By default, only return the logprobs for output tokens
            logprob_start_len = len(recv_req.input_ids) - 1
        req.logprob_start_len = logprob_start_len

        # Init regex FSM or BNF (a JSON schema takes precedence over a regex)
        sampling_params = req.sampling_params
        grammar_key = None
        if sampling_params.json_schema is not None:
            grammar_key = ("json", sampling_params.json_schema)
        elif sampling_params.regex is not None:
            grammar_key = ("regex", sampling_params.regex)

        if grammar_key is not None:
            assert self.grammar_cache is not None
            req.grammar = self.grammar_cache.query(
                grammar_key, self.model_config.vocab_size
            )

        # Truncate prompts that are too long
        input_limit = self.max_req_input_len
        if len(req.origin_input_ids) > input_limit:
            logger.warning(
                "Request length is longer than the KV cache pool size or "
                "the max context length. Truncated!!!"
            )
            req.origin_input_ids = req.origin_input_ids[:input_limit]

        # Clamp the generation budget to the room left after the prompt
        requested_new_tokens = sampling_params.max_new_tokens
        if requested_new_tokens is None:
            requested_new_tokens = 1 << 30
        remaining_budget = self.max_req_len - len(req.origin_input_ids) - 1
        sampling_params.max_new_tokens = min(requested_new_tokens, remaining_budget)

        self.waiting_queue.append(req)

    def handle_embedding_request(
        self,
        recv_req: Union[TokenizedEmbeddingReqInput, TokenizedRewardReqInput],
    ):
        """Turn a tokenized embedding/reward request into a ``Req`` and enqueue it.

        Args:
            recv_req: The tokenized embedding or reward request to schedule.
        """
        req = Req(
            recv_req.rid,
            recv_req.input_text,
            recv_req.input_ids,
            recv_req.sampling_params,
        )
        req.tokenizer = self.tokenizer

        # Truncate prompts that are too long. Use a strict comparison, as
        # handle_generate_request does: a prompt of exactly max_req_input_len
        # tokens is left untouched by the slice below, so warning about
        # truncation in that case would be misleading.
        if len(req.origin_input_ids) > self.max_req_input_len:
            logger.warning(
                "Request length is longer than the KV cache pool size or "
                "the max context length. Truncated!!!"
            )
            req.origin_input_ids = req.origin_input_ids[: self.max_req_input_len]

        self.waiting_queue.append(req)

Lianmin Zheng's avatar
Lianmin Zheng committed
509
510
511
512
513
514
515
    def print_decode_stats(self):
        """Log decode statistics and start a new throughput window.

        Reports the number of running requests, token usage, generation
        throughput since the previous call and the waiting-queue length,
        then resets ``num_generated_tokens`` and ``last_stats_tic``.
        """
        reclaimable = (
            self.token_to_kv_pool.available_size() + self.tree_cache.evictable_size()
        )
        num_used = self.max_total_num_tokens - reclaimable

        elapsed = time.time() - self.last_stats_tic
        throughput = self.num_generated_tokens / elapsed

        # Reset the counters for the next reporting window
        self.num_generated_tokens = 0
        self.last_stats_tic = time.time()

        if self.running_batch:
            num_running_reqs = len(self.running_batch.reqs)
        else:
            num_running_reqs = 0

        logger.info(
            f"Decode batch. "
            f"#running-req: {num_running_reqs}, "
            f"#token: {num_used}, "
            f"token usage: {num_used / self.max_total_num_tokens:.2f}, "
            f"gen throughput (token/s): {throughput:.2f}, "
            f"#queue-req: {len(self.waiting_queue)}"
        )

    def check_memory(self):
        """Warn about leaked KV-cache tokens or request slots.

        Called by the event loops when no batch is scheduled. At that point
        the available plus evictable KV tokens should equal
        ``max_total_num_tokens`` and every request slot should be free; a
        mismatch emits a warning. If ``crash_on_warning`` is set the process
        exits with status 1 after the warning.

        Raises:
            SystemExit: When a leak is detected and ``crash_on_warning`` is true.
        """
        # Local import: use sys.exit rather than the ``exit`` helper injected
        # by the ``site`` module, which is not guaranteed to exist.
        import sys

        available_size = (
            self.token_to_kv_pool.available_size() + self.tree_cache.evictable_size()
        )
        if available_size != self.max_total_num_tokens:
            warnings.warn(
                "Warning: "
                f"available_size={available_size}, max_total_num_tokens={self.max_total_num_tokens}\n"
                "KV cache pool leak detected!"
            )
            if crash_on_warning:
                sys.exit(1)

        if len(self.req_to_token_pool.free_slots) != self.req_to_token_pool.size:
            warnings.warn(
                "Warning: "
                f"available req slots={len(self.req_to_token_pool.free_slots)}, "
                f"total slots={self.req_to_token_pool.size}\n"
                "Memory pool leak detected!"
            )
            if crash_on_warning:
                sys.exit(1)

547
    def get_next_batch_to_run(self):
        """Select the batch for the next forward pass.

        First folds the previous non-decode batch into the running batch,
        then prefers a new prefill batch, and otherwise falls back to
        decoding the running batch.

        Returns:
            The batch to run, or ``None`` when there is nothing to run.
        """
        # Merge the prefill batch into the running batch
        prev_batch = self.last_batch
        prev_was_prefill = (
            prev_batch
            and not prev_batch.forward_mode.is_decode()
            and not prev_batch.is_empty()
        )
        if prev_was_prefill:
            chunked_req = self.being_chunked_req
            if chunked_req:
                prev_batch.filter_batch(being_chunked_req=chunked_req)
                self.tree_cache.cache_unfinished_req(chunked_req)
                # Inflight request keeps its rid but will get a new req_pool_idx.
                self.req_to_token_pool.free(chunked_req.req_pool_idx)
                self.batch_is_full = False
            # Filtering above may have emptied the batch, so check again
            if not prev_batch.is_empty():
                if self.running_batch is None:
                    self.running_batch = prev_batch
                else:
                    self.running_batch.merge_batch(prev_batch)

        # Prefill first
        prefill_batch = self.get_new_batch_prefill()
        if prefill_batch is not None:
            return prefill_batch

        # Nothing to decode either
        if self.running_batch is None:
            return

        # Run decode
        bs_before_update = self.running_batch.batch_size()
        self.update_running_batch()
        if not self.running_batch:
            self.batch_is_full = False
            return None
        # A changed batch size means the full-batch flag may be stale
        if self.running_batch.batch_size() != bs_before_update:
            self.batch_is_full = False
        return self.running_batch
586

Lianmin Zheng's avatar
Lianmin Zheng committed
587
588
589
590
    def get_new_batch_prefill(self) -> Optional[ScheduleBatch]:
        """Try to assemble a new prefill batch from the waiting queue.

        Requests are taken in the order produced by the schedule policy
        until the prefill adder, the running-request limit or the LoRA limit
        stops the selection. Selected requests are removed from the waiting
        queue. In mixed-chunk mode the running decode batch is mixed into
        the new batch and ``self.running_batch`` is cleared.

        Returns:
            The prepared prefill batch, or ``None`` if no request can be
            scheduled right now.
        """
        # Handle the cases where prefill is not allowed
        if (
            self.batch_is_full or len(self.waiting_queue) == 0
        ) and self.being_chunked_req is None:
            return None

        running_bs = len(self.running_batch.reqs) if self.running_batch else 0
        if running_bs >= self.max_running_requests:
            self.batch_is_full = True
            return None

        # Get priority queue
        prefix_computed = self.policy.calc_priority(self.waiting_queue)

        # Prefill policy
        num_mixed_running = running_bs if self.is_mixed_chunk else 0
        adder = PrefillAdder(
            self.tree_cache,
            self.running_batch,
            self.new_token_ratio,
            self.token_to_kv_pool.available_size() + self.tree_cache.evictable_size(),
            self.max_prefill_tokens,
            self.chunked_prefill_size,
            num_mixed_running,
        )

        # A request that is still being chunked is continued before new ones
        has_inflight = self.being_chunked_req is not None
        if has_inflight:
            self.being_chunked_req.init_next_round_input()
            self.being_chunked_req = adder.add_inflight_req(
                self.being_chunked_req
            )

        if self.lora_paths:
            lora_set = (
                set([req.lora_path for req in self.running_batch.reqs])
                if self.running_batch is not None
                else set([])
            )

        # Get requests from the waiting queue to a new prefill batch
        for req in self.waiting_queue:
            if (
                self.lora_paths
                and len(
                    lora_set
                    | set([req.lora_path for req in adder.can_run_list])
                    | set([req.lora_path])
                )
                > self.max_loras_per_batch
            ):
                self.batch_is_full = True
                break

            if running_bs + len(adder.can_run_list) >= self.max_running_requests:
                self.batch_is_full = True
                break

            req.init_next_round_input(None if prefix_computed else self.tree_cache)
            res = adder.add_one_req(req)
            if res != AddReqResult.CONTINUE:
                if res == AddReqResult.NO_TOKEN:
                    self.batch_is_full = True
                break

        # Update waiting queue
        can_run_list = adder.can_run_list
        if len(can_run_list) == 0:
            return None
        # Build the lookup set once instead of once per queued request
        scheduled = set(can_run_list)
        self.waiting_queue = [x for x in self.waiting_queue if x not in scheduled]

        if adder.new_inflight_req is not None:
            assert self.being_chunked_req is None
            self.being_chunked_req = adder.new_inflight_req

        if self.being_chunked_req:
            self.being_chunked_req.is_being_chunked += 1

        # Print stats
        if self.tp_rank == 0:
            if isinstance(self.tree_cache, RadixCache):
                self.tree_cache_metrics["total"] += (
                    adder.log_input_tokens + adder.log_hit_tokens
                ) / 10**9
                self.tree_cache_metrics["hit"] += (adder.log_hit_tokens) / 10**9
                total_tokens = self.tree_cache_metrics["total"]
                # Guard the ratio so a zero total cannot raise ZeroDivisionError
                tree_cache_hit_rate = (
                    self.tree_cache_metrics["hit"] / total_tokens
                    if total_tokens > 0
                    else 0.0
                )
            else:
                tree_cache_hit_rate = 0.0

            num_used = self.max_total_num_tokens - (
                self.token_to_kv_pool.available_size()
                + self.tree_cache.evictable_size()
            )

            if num_mixed_running > 0:
                logger.info(
                    f"Prefill batch"
                    f"(mixed #running-req: {num_mixed_running}). "
                    f"#new-seq: {len(can_run_list)}, "
                    f"#new-token: {adder.log_input_tokens}, "
                    f"#cached-token: {adder.log_hit_tokens}, "
                    f"cache hit rate: {100.0 * tree_cache_hit_rate:.2f}%, "
                    f"token usage: {num_used / self.max_total_num_tokens:.2f}, "
                    f"#queue-req: {len(self.waiting_queue) + has_inflight}"
                )
            else:
                logger.info(
                    f"Prefill batch. "
                    f"#new-seq: {len(can_run_list)}, "
                    f"#new-token: {adder.log_input_tokens}, "
                    f"#cached-token: {adder.log_hit_tokens}, "
                    f"cache hit rate: {100.0 * tree_cache_hit_rate:.2f}%, "
                    f"token usage: {num_used / self.max_total_num_tokens:.2f}, "
                    f"#running-req: {running_bs}, "
                    f"#queue-req: {len(self.waiting_queue) + has_inflight}"
                )

        # Create a new batch
        new_batch = ScheduleBatch.init_new(
            can_run_list,
            self.req_to_token_pool,
            self.token_to_kv_pool,
            self.tree_cache,
            self.model_config,
        )
        new_batch.prepare_for_extend()

        # Mixed-style chunked prefill
        if self.is_mixed_chunk and self.running_batch is not None:
            self.running_batch.filter_batch()
            if not self.running_batch.is_empty():
                self.running_batch.prepare_for_decode(self.enable_overlap)
                new_batch.mix_with_running(self.running_batch)
                new_batch.decoding_reqs = self.running_batch.reqs
            # The running batch is cleared whether or not it was mixed in
            self.running_batch = None
        else:
            new_batch.decoding_reqs = None

        return new_batch

732
    def update_running_batch(self):
        """Prepare the running batch for its next decode step.

        Filters the batch, retracts requests when decode memory is
        insufficient (or when the retract test flag is set and the batch has
        more than 10 requests), otherwise decays the new-token ratio, moves
        jump-forward requests back to the waiting queue and finally prepares
        the batch for decoding. ``self.running_batch`` is set to ``None``
        whenever the batch turns out to be empty.
        """
        running = self.running_batch

        running.filter_batch()
        if running.is_empty():
            self.running_batch = None
            return

        # Check if decode out of memory
        must_retract = not running.check_decode_mem() or (
            test_retract and running.batch_size() > 10
        )
        if must_retract:
            old_ratio = self.new_token_ratio

            retracted_reqs, self.new_token_ratio = running.retract_decode()

            logger.info(
                "Decode out of memory happened. "
                f"#retracted_reqs: {len(retracted_reqs)}, "
                f"#new_token_ratio: {old_ratio:.4f} -> {self.new_token_ratio:.4f}"
            )
            self.waiting_queue.extend(retracted_reqs)
        else:
            # Decay the ratio, but never below the configured minimum
            decayed_ratio = self.new_token_ratio - self.new_token_ratio_decay
            self.new_token_ratio = max(decayed_ratio, self.min_new_token_ratio)

        # Check for jump-forward
        if not self.disable_regex_jump_forward:
            self.waiting_queue.extend(
                running.check_for_jump_forward(self.pad_input_ids_func)
            )
            if running.is_empty():
                self.running_batch = None
                return

        # Update batch tensors
        running.prepare_for_decode(self.enable_overlap)
Lianmin Zheng's avatar
Lianmin Zheng committed
771
772

    def run_batch(self, batch: ScheduleBatch):
        """Run a batch.

        Args:
            batch: The batch to run a forward pass for.

        Returns:
            For generation models, a tuple
            ``(logits_output, next_token_ids, bid)``. When no forward pass is
            needed (the batch is not in decode mode and has zero extend
            tokens), ``logits_output`` and ``bid`` are both ``None`` and
            ``next_token_ids`` is a filler tensor.
            For embedding or reward models, a tuple ``(embeddings, bid)``.
        """
        self.forward_ct += 1

        if self.is_generation:
            if batch.forward_mode.is_decode() or batch.extend_num_tokens != 0:
                model_worker_batch = batch.get_model_worker_batch()
                logits_output, next_token_ids = self.tp_worker.forward_batch_generation(
                    model_worker_batch
                )
            else:
                # No forward pass happens here, so there is no model worker
                # batch. Previously this path left `model_worker_batch` unbound
                # and crashed with UnboundLocalError when building the result.
                model_worker_batch = None
                logits_output = None
                if self.skip_tokenizer_init:
                    # NOTE(review): this reads self.tokenizer while
                    # skip_tokenizer_init is set -- confirm the tokenizer is
                    # actually available in that mode.
                    next_token_ids = torch.full(
                        (batch.batch_size(),), self.tokenizer.eos_token_id
                    )
                else:
                    next_token_ids = torch.full((batch.batch_size(),), 0)
            batch.output_ids = next_token_ids
            bid = model_worker_batch.bid if model_worker_batch is not None else None
            ret = logits_output, next_token_ids, bid
        else:  # embedding or reward model
            assert batch.extend_num_tokens != 0
            model_worker_batch = batch.get_model_worker_batch()
            embeddings = self.tp_worker.forward_batch_embedding(model_worker_batch)
            ret = embeddings, model_worker_batch.bid
        return ret
Lianmin Zheng's avatar
Lianmin Zheng committed
798
799
800
801

    def process_batch_result(self, batch: ScheduleBatch, result):
        """Dispatch a forward result to the decode or prefill handler.

        After a decode result is processed, ``self.running_batch`` is cleared
        if the batch has become empty.
        """
        if not batch.forward_mode.is_decode():
            self.process_batch_result_prefill(batch, result)
            return

        self.process_batch_result_decode(batch, result)
        if batch.is_empty():
            self.running_batch = None

    def process_batch_result_prefill(self, batch: ScheduleBatch, result):
        """Process the result of a prefill batch and stream its outputs.

        For generation models the next token of every request that is not
        currently being chunked is appended, the finish condition is checked,
        the tree cache is updated, and logprobs are attached when requested.
        For embedding or reward models the embedding is stored on each request.
        """
        if self.is_generation:
            logits_output, next_token_ids, bid = result

            if self.enable_overlap:
                logits_output, next_token_ids = self.tp_worker.resulve_batch_result(bid)
            else:
                # Convert next_token_ids and logprobs to Python lists
                if batch.return_logprob:
                    row_index = torch.arange(len(next_token_ids), device=self.device)
                    # Keep only the logprob of the token chosen for each row.
                    logits_output.next_token_logprobs = (
                        logits_output.next_token_logprobs[
                            row_index, next_token_ids
                        ].tolist()
                    )
                    logits_output.input_token_logprobs = (
                        logits_output.input_token_logprobs.tolist()
                    )
                    logits_output.normalized_prompt_logprobs = (
                        logits_output.normalized_prompt_logprobs.tolist()
                    )
                next_token_ids = next_token_ids.tolist()

            # Check finish conditions
            logprob_pt = 0
            for i, req in enumerate(batch.reqs):
                if req.is_being_chunked > 0:
                    # Presumably an inflight req whose prefill is not finished:
                    # only count down, no token is emitted for it.
                    req.is_being_chunked -= 1
                    continue

                next_token_id = next_token_ids[i]
                req.completion_tokens_wo_jump_forward += 1
                req.output_ids.append(next_token_id)
                req.check_finished()

                if req.finished():
                    self.tree_cache.cache_finished_req(req)
                elif not batch.decoding_reqs or req not in batch.decoding_reqs:
                    self.tree_cache.cache_unfinished_req(req)

                if req.grammar is not None:
                    req.grammar.accept_token(next_token_id)

                if req.return_logprob:
                    logprob_pt += self.add_logprob_return_values(
                        i, req, logprob_pt, next_token_ids, logits_output
                    )
        else:  # embedding or reward model
            embeddings, bid = result
            embeddings = embeddings.tolist()

            # Check finish conditions
            for i, req in enumerate(batch.reqs):
                req.embedding = embeddings[i]
                if req.is_being_chunked > 0:
                    # Presumably an inflight req whose prefill is not finished.
                    req.is_being_chunked -= 1
                else:
                    # dummy output token for embedding models
                    req.output_ids.append(0)
                    req.check_finished()

                # The tree cache is updated for every request, chunked or not.
                if req.finished():
                    self.tree_cache.cache_finished_req(req)
                else:
                    self.tree_cache.cache_unfinished_req(req)

        self.stream_output(batch.reqs)
875

Lianmin Zheng's avatar
Lianmin Zheng committed
876
    def process_batch_result_decode(self, batch: ScheduleBatch, result):
        """Process the result of a decode batch and stream its outputs.

        Appends one new token to every request, checks the finish condition,
        updates the tree cache for finished requests, records logprobs when
        requested, and prints decode stats every ``decode_log_interval`` decode
        steps on TP rank 0.
        """
        logits_output, next_token_ids, bid = result
        self.num_generated_tokens += len(batch.reqs)

        if self.enable_overlap:
            logits_output, next_token_ids = self.tp_worker.resulve_batch_result(bid)
            next_token_logprobs = logits_output.next_token_logprobs
        else:
            # Convert next_token_ids and logprobs to Python lists
            if batch.return_logprob:
                row_index = torch.arange(len(next_token_ids), device=self.device)
                # Keep only the logprob of the token chosen for each row.
                next_token_logprobs = logits_output.next_token_logprobs[
                    row_index, next_token_ids
                ].tolist()
            next_token_ids = next_token_ids.tolist()

        self.token_to_kv_pool.free_group_begin()

        # Check finish condition
        for i, (req, token_id) in enumerate(zip(batch.reqs, next_token_ids)):
            if self.server_args.enable_overlap_schedule and req.finished():
                # The request was already finished: release the cache slot of
                # this step and skip the rest of the bookkeeping.
                self.token_to_kv_pool.free(batch.out_cache_loc[i : i + 1])
                continue

            req.completion_tokens_wo_jump_forward += 1
            req.output_ids.append(token_id)
            req.check_finished()

            if req.grammar is not None:
                req.grammar.accept_token(token_id)

            if req.finished():
                self.tree_cache.cache_finished_req(req)

            if req.return_logprob:
                req.output_token_logprobs.append((next_token_logprobs[i], token_id))
                if req.top_logprobs_num > 0:
                    req.output_top_logprobs.append(logits_output.output_top_logprobs[i])

        self.stream_output(batch.reqs)

        self.token_to_kv_pool.free_group_end()

        # Wrap the decode step counter so it stays bounded.
        self.forward_ct_decode = (self.forward_ct_decode + 1) % (1 << 30)
        if (
            self.tp_rank == 0
            and self.forward_ct_decode % self.server_args.decode_log_interval == 0
        ):
            self.print_decode_stats()

925
926
927
928
929
930
931
932
933
934
935
936
937
938
939
940
941
942
943
944
945
946
947
948
949
950
951
952
953
954
955
956
957
958
959
960
961
962
963
964
965
966
967
968
969
970
971
972
973
974
975
976
977
978
979
980
981
982
983
984
985
986
987
988
989
990
991
992
993
994
995
996
997
    def add_logprob_return_values(
        self,
        i: int,
        req: Req,
        pt: int,
        next_token_ids: List[int],
        output: LogitsProcessorOutput,
    ):
        """Attach logprobs to the return values.

        Args:
            i: Index of ``req`` in the per-request fields of ``output`` and in
                ``next_token_ids``.
            req: The request whose logprob fields are filled in.
            pt: Start offset of this request inside
                ``output.input_token_logprobs``.
            next_token_ids: Next token ids, indexed by ``i``.
            output: The logits processor output that holds the logprobs.

        Returns:
            ``req.extend_input_len - req.extend_logprob_start_len``; the caller
            adds it to ``pt`` for the next request.
        """
        req.output_token_logprobs.append(
            (output.next_token_logprobs[i], next_token_ids[i])
        )

        # If logprob_start_len > 0, then first logprob_start_len prompt tokens will be ignored.
        num_input_logprobs = req.extend_input_len - req.extend_logprob_start_len
        num_recomputed = req.last_update_decode_tokens
        fill_len = len(req.fill_ids)
        # Exclusive end of this request's slice in output.input_token_logprobs.
        logprob_end = pt + num_input_logprobs - 1

        if req.normalized_prompt_logprob is None:
            req.normalized_prompt_logprob = output.normalized_prompt_logprobs[i]

        if req.input_token_logprobs is None:
            prompt_logprobs = output.input_token_logprobs[
                pt : logprob_end - num_recomputed
            ]
            prompt_token_ids = req.fill_ids[
                fill_len - num_input_logprobs + 1 : fill_len - num_recomputed
            ]
            req.input_token_logprobs = list(zip(prompt_logprobs, prompt_token_ids))

            if req.logprob_start_len == 0:
                # The first token does not have logprob, pad it.
                first_entry = (None, req.fill_ids[0])
                req.input_token_logprobs = [first_entry] + req.input_token_logprobs

        if num_recomputed != 0:
            # Some decode tokens are re-computed in an extend batch
            recomputed_logprobs = output.input_token_logprobs[
                logprob_end - num_recomputed : logprob_end
            ]
            recomputed_token_ids = req.fill_ids[fill_len - num_recomputed : fill_len]
            req.output_token_logprobs.extend(
                zip(recomputed_logprobs, recomputed_token_ids)
            )

        if req.top_logprobs_num > 0:
            if req.input_top_logprobs is None:
                req.input_top_logprobs = output.input_top_logprobs[i]
                if req.logprob_start_len == 0:
                    req.input_top_logprobs = [None] + req.input_top_logprobs

            if num_recomputed != 0:
                req.output_top_logprobs.extend(
                    output.input_top_logprobs[i][-num_recomputed:]
                )
            req.output_top_logprobs.append(output.output_top_logprobs[i])

        return num_input_logprobs

998
    def stream_output(self, reqs: List[Req]):
        """Stream the output to detokenizer.

        A request is included when it has finished, or when it is a streaming
        request and either this is a streaming iteration
        (``forward_ct_decode % stream_interval == 0``) or it has exactly one
        output token. Nothing is sent when no request qualifies.
        """
        rids = []
        meta_infos: List[dict] = []
        finished_reasons: List[BaseFinishReason] = []
        if self.is_generation:
            vids = []
            decoded_texts = []
            read_ids_list = []
            read_offsets = []
            output_ids = []
            skip_special_tokens = []
            spaces_between_special_tokens = []
            no_stop_trim = []
        else:  # embedding or reward model
            embeddings = []

        is_stream_iter = self.forward_ct_decode % self.stream_interval == 0

        for req in reqs:
            should_emit = req.finished() or (
                req.stream and (is_stream_iter or len(req.output_ids) == 1)
            )
            if not should_emit:
                continue

            rids.append(req.rid)
            finished_reasons.append(req.finished_reason)

            if not self.is_generation:
                # embedding or reward model
                embeddings.append(req.embedding)
                meta_infos.append({"prompt_tokens": len(req.origin_input_ids)})
                continue

            vids.append(req.vid)
            decoded_texts.append(req.decoded_text)
            read_ids, read_offset = req.init_incremental_detokenize()
            read_ids_list.append(read_ids)
            read_offsets.append(read_offset)
            if self.skip_tokenizer_init:
                output_ids.append(req.output_ids)
            skip_special_tokens.append(req.sampling_params.skip_special_tokens)
            spaces_between_special_tokens.append(
                req.sampling_params.spaces_between_special_tokens
            )
            no_stop_trim.append(req.sampling_params.no_stop_trim)

            meta_info = {
                "prompt_tokens": len(req.origin_input_ids),
                "completion_tokens": len(req.output_ids),
                "completion_tokens_wo_jump_forward": req.completion_tokens_wo_jump_forward,
                "cached_tokens": req.cached_tokens,
                "finish_reason": (
                    req.finished_reason.to_json()
                    if req.finished_reason is not None
                    else None
                ),
            }
            if req.return_logprob:
                meta_info["input_token_logprobs"] = req.input_token_logprobs
                meta_info["output_token_logprobs"] = req.output_token_logprobs
                meta_info["input_top_logprobs"] = req.input_top_logprobs
                meta_info["output_top_logprobs"] = req.output_top_logprobs
                meta_info["normalized_prompt_logprob"] = req.normalized_prompt_logprob
            meta_infos.append(meta_info)

        # Send to detokenizer
        if not rids:
            return

        if self.is_generation:
            payload = BatchTokenIDOut(
                rids,
                vids,
                decoded_texts,
                read_ids_list,
                read_offsets,
                output_ids,
                skip_special_tokens,
                spaces_between_special_tokens,
                meta_infos,
                finished_reasons,
                no_stop_trim,
            )
        else:  # embedding or reward model
            payload = BatchEmbeddingOut(
                rids,
                embeddings,
                meta_infos,
                finished_reasons,
            )
        self.send_to_detokenizer.send_pyobj(payload)

    def flush_cache(self):
        """Flush the memory pool and cache.

        The flush is only performed when the waiting queue is empty and the
        running batch is either ``None`` or has no requests. Otherwise a
        warning is logged and nothing is changed.

        Returns:
            ``True`` if the caches were flushed, ``False`` if the flush was
            skipped because requests are pending.
        """
        num_queued = len(self.waiting_queue)
        num_running = 0 if self.running_batch is None else len(self.running_batch.reqs)

        if num_queued == 0 and num_running == 0:
            self.tree_cache.reset()
            self.tree_cache_metrics = {"total": 0, "hit": 0}
            if self.grammar_cache is not None:
                self.grammar_cache.reset()
            # TODO(dark): reset the bnf cache
            self.req_to_token_pool.clear()
            self.token_to_kv_pool.clear()
            torch.cuda.empty_cache()
            logger.info("Cache flushed successfully!")
            if_success = True
        else:
            # Use the module logger (not the root `logging` functions) so the
            # warning goes through the same logger as every other message here.
            logger.warning(
                "Cache not flushed because there are pending requests. "
                f"#queue-req: {num_queued}, "
                f"#running-req: {num_running}"
            )
            if_success = False
        return if_success

    def abort_request(self, recv_req: AbortReq):
        """Abort the request whose rid matches ``recv_req.rid``.

        The first matching request in the waiting queue is removed. In the
        running batch, the first matching request that is not finished yet is
        marked as aborted and handed to the tree cache as finished.
        """
        # Delete requests in the waiting queue
        for position, queued_req in enumerate(self.waiting_queue):
            if queued_req.rid == recv_req.rid:
                # Safe to delete while iterating: the loop stops right after.
                del self.waiting_queue[position]
                break

        # Delete requests in the running batch
        if not self.running_batch:
            return

        for running_req in self.running_batch.reqs:
            if running_req.rid != recv_req.rid or running_req.finished():
                continue
            running_req.finished_reason = FINISH_ABORT()
            self.tree_cache.cache_finished_req(running_req)
            break

    def update_weights(self, recv_req: UpdateWeightReqInput):
        """In-place update of the weights.

        On success the cache is flushed, and the flush is asserted to have
        succeeded. On failure the message is logged as an error.

        Returns:
            The ``(success, message)`` pair reported by the TP worker.
        """
        success, message = self.tp_worker.update_weights(recv_req)
        if not success:
            logger.error(message)
            return success, message

        flush_cache_success = self.flush_cache()
        assert flush_cache_success, "Cache flush failed after updating weights"
        return success, message

1153
1154
1155
1156
1157
1158
1159
1160
1161
1162
1163
1164
1165
1166
    def start_profile(self) -> None:
        """Start the profiler.

        Raises:
            RuntimeError: If ``self.profiler`` is ``None``.
        """
        profiler = self.profiler
        if profiler is None:
            raise RuntimeError("Profiler is not enabled.")
        profiler.start()

    def stop_profile(self) -> None:
        """Stop the profiler and export a chrome trace.

        The trace is written to
        ``<torch_profiler_trace_dir>/<time.time()>.trace.json.gz``.

        Raises:
            RuntimeError: If ``self.profiler`` is ``None``.
        """
        profiler = self.profiler
        if profiler is None:
            raise RuntimeError("Profiler is not enabled.")
        profiler.stop()

        trace_file = str(time.time()) + ".trace.json.gz"
        profiler.export_chrome_trace(self.torch_profiler_trace_dir + "/" + trace_file)
        logger.info("Profiler is done")

1167
1168
1169
1170
1171
1172

def run_scheduler_process(
    server_args: ServerArgs,
    port_args: PortArgs,
    gpu_id: int,
    tp_rank: int,
    dp_rank: Optional[int],
    pipe_writer,
):
    """Entry point of a scheduler process.

    Configures logging, builds a ``Scheduler``, sends ``"ready"`` through
    ``pipe_writer``, and then runs the overlap or the normal event loop
    depending on ``server_args.enable_overlap_schedule``. On any exception the
    traceback is logged and the parent process is killed.
    """
    # The log prefix only carries the DP rank when one is given.
    if dp_rank is None:
        log_prefix = f" TP{tp_rank}"
    else:
        log_prefix = f" DP{dp_rank} TP{tp_rank}"
    configure_logger(server_args, prefix=log_prefix)

    suppress_other_loggers()

    try:
        scheduler = Scheduler(server_args, port_args, gpu_id, tp_rank, dp_rank)
        pipe_writer.send("ready")
        event_loop = (
            scheduler.event_loop_overlap
            if server_args.enable_overlap_schedule
            else scheduler.event_loop_normal
        )
        event_loop()
    except Exception:
        msg = get_exception_traceback()
        logger.error(msg)
        kill_parent_process()