scheduler.py 49.2 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
"""
Copyright 2023-2024 SGLang Team
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""

"""A scheduler that manages a tensor parallel GPU worker."""

import logging
19
import os
Lianmin Zheng's avatar
Lianmin Zheng committed
20
import threading
21
22
import time
import warnings
Lianmin Zheng's avatar
Lianmin Zheng committed
23
from collections import deque
24
from types import SimpleNamespace
25
from typing import List, Optional
26

27
import torch
28
29
import zmq

30
31
from sglang.global_config import global_config
from sglang.srt.configs.model_config import ModelConfig
32
from sglang.srt.constrained.grammar import GrammarCache
33
34
35
36
37
38
39
from sglang.srt.hf_transformers_utils import get_processor, get_tokenizer
from sglang.srt.layers.logits_processor import LogitsProcessorOutput
from sglang.srt.managers.io_struct import (
    AbortReq,
    BatchEmbeddingOut,
    BatchTokenIDOut,
    FlushCacheReq,
40
41
    GetMemPoolSizeReq,
    GetMemPoolSizeReqOutput,
42
    ProfileReq,
43
44
45
46
47
48
49
50
51
52
53
    TokenizedEmbeddingReqInput,
    TokenizedGenerateReqInput,
    UpdateWeightReqInput,
    UpdateWeightReqOutput,
)
from sglang.srt.managers.schedule_batch import (
    FINISH_ABORT,
    BaseFinishReason,
    ImageInputs,
    Req,
    ScheduleBatch,
54
    global_server_args_dict,
55
)
56
57
58
59
60
from sglang.srt.managers.schedule_policy import (
    AddReqResult,
    PrefillAdder,
    SchedulePolicy,
)
61
from sglang.srt.managers.tp_worker import TpModelWorker
62
from sglang.srt.managers.tp_worker_overlap_thread import TpModelWorkerClient
63
64
from sglang.srt.mem_cache.chunk_cache import ChunkCache
from sglang.srt.mem_cache.radix_cache import RadixCache
65
66
from sglang.srt.metrics.metrics_collector import PrometheusMetricsCollector
from sglang.srt.metrics.metrics_types import Stats
67
from sglang.srt.server_args import PortArgs, ServerArgs
68
69
70
from sglang.srt.utils import (
    broadcast_pyobj,
    configure_logger,
71
    get_zmq_socket,
72
73
74
75
    kill_parent_process,
    set_random_seed,
    suppress_other_loggers,
)
76
77
78
79
from sglang.utils import get_exception_traceback

logger = logging.getLogger(__name__)


def _env_flag(name: str) -> bool:
    """Return True iff the environment variable *name* is exactly ``"true"``."""
    return os.getenv(name, "false") == "true"


# Crash on warning if we are running CI tests
crash_on_warning = _env_flag("SGLANG_IS_IN_CI")

# Test retract decode
test_retract = _env_flag("SGLANG_TEST_RETRACT")


class Scheduler:
    """A scheduler that manages a tensor parallel GPU worker."""

    def __init__(
        self,
        server_args: ServerArgs,
        port_args: PortArgs,
        gpu_id: int,
        tp_rank: int,
        dp_rank: Optional[int],
    ):
        """Set up the scheduler for one tensor-parallel rank.

        In order, this:
        - copies scheduling options from ``server_args``;
        - opens the ZMQ sockets (TP rank 0 only);
        - loads the model config and tokenizer/processor;
        - launches the TP model worker (the overlap-thread client when
          ``enable_overlap_schedule`` is set) and reads its limits;
        - builds the prefix cache (``ChunkCache`` or ``RadixCache``);
        - initializes the scheduling state, the grammar cache and the
          new-token-ratio schedule;
        - starts the watchdog thread;
        - creates the optional torch profiler and the Prometheus metrics
          collector.

        Args:
            server_args: Global server configuration.
            port_args: IPC endpoint names and the NCCL port.
            gpu_id: GPU id forwarded to the TP worker.
            tp_rank: Tensor-parallel rank of this scheduler.
            dp_rank: Data-parallel rank forwarded to the TP worker, or None.
        """
        # ---- Scheduling options copied from the server arguments ----
        self.server_args = server_args
        self.tp_rank = tp_rank
        self.tp_size = server_args.tp_size
        self.schedule_policy = server_args.schedule_policy
        self.disable_regex_jump_forward = server_args.disable_regex_jump_forward
        self.lora_paths = server_args.lora_paths
        self.max_loras_per_batch = server_args.max_loras_per_batch
        self.enable_overlap = server_args.enable_overlap_schedule
        self.skip_tokenizer_init = server_args.skip_tokenizer_init

        # ---- Inter-process communication: only TP rank 0 owns real sockets ----
        zmq_ctx = zmq.Context(2)
        if self.tp_rank != 0:
            self.recv_from_tokenizer = None
            # No-op sender so callers can send unconditionally on every rank.
            self.send_to_detokenizer = SimpleNamespace(send_pyobj=lambda x: None)
        else:
            self.recv_from_tokenizer = get_zmq_socket(
                zmq_ctx, zmq.PULL, port_args.scheduler_input_ipc_name
            )
            # With skip_tokenizer_init, outputs go straight back to the
            # tokenizer/api process; otherwise they go to the detokenizer.
            out_ipc_name = (
                port_args.tokenizer_ipc_name
                if server_args.skip_tokenizer_init
                else port_args.detokenizer_ipc_name
            )
            self.send_to_detokenizer = get_zmq_socket(zmq_ctx, zmq.PUSH, out_ipc_name)

        # ---- Model config and tokenizer ----
        self.model_config = ModelConfig(
            server_args.model_path,
            trust_remote_code=server_args.trust_remote_code,
            context_length=server_args.context_length,
            model_override_args=server_args.json_model_override_args,
            is_embedding=server_args.is_embedding,
        )
        self.is_generation = self.model_config.is_generation

        if server_args.skip_tokenizer_init:
            self.tokenizer = self.processor = None
        else:
            tokenizer_kwargs = dict(
                tokenizer_mode=server_args.tokenizer_mode,
                trust_remote_code=server_args.trust_remote_code,
            )
            if self.model_config.is_multimodal:
                self.processor = get_processor(
                    server_args.tokenizer_path, **tokenizer_kwargs
                )
                self.tokenizer = self.processor.tokenizer
            else:
                self.tokenizer = get_tokenizer(
                    server_args.tokenizer_path, **tokenizer_kwargs
                )

        # ---- Tensor-parallel worker ----
        worker_cls = TpModelWorkerClient if self.enable_overlap else TpModelWorker
        self.tp_worker = worker_cls(
            server_args=server_args,
            gpu_id=gpu_id,
            tp_rank=tp_rank,
            dp_rank=dp_rank,
            nccl_port=port_args.nccl_port,
        )

        # Token and memory limits reported by the worker.
        (
            self.max_total_num_tokens,
            self.max_prefill_tokens,
            self.max_running_requests,
            self.max_req_len,
            self.max_req_input_len,
            self.random_seed,
            self.device,
            worker_global_server_args_dict,
            _,
            _,
            _,
        ) = self.tp_worker.get_worker_info()
        self.tp_cpu_group = self.tp_worker.get_tp_cpu_group()
        self.pad_input_ids_func = self.tp_worker.get_pad_input_ids_func()
        global_server_args_dict.update(worker_global_server_args_dict)
        set_random_seed(self.random_seed)

        logger.info(
            f"max_total_num_tokens={self.max_total_num_tokens}, "
            f"max_prefill_tokens={self.max_prefill_tokens}, "
            f"max_running_requests={self.max_running_requests}, "
            f"context_len={self.model_config.context_len}"
        )

        # ---- Memory pools and prefix cache ----
        self.req_to_token_pool, self.token_to_kv_pool = self.tp_worker.get_memory_pool()
        use_chunk_cache = (
            server_args.chunked_prefill_size is not None
            and server_args.disable_radix_cache
        )
        if use_chunk_cache:
            self.tree_cache = ChunkCache(
                req_to_token_pool=self.req_to_token_pool,
                token_to_kv_pool=self.token_to_kv_pool,
            )
        else:
            self.tree_cache = RadixCache(
                req_to_token_pool=self.req_to_token_pool,
                token_to_kv_pool=self.token_to_kv_pool,
                disable=server_args.disable_radix_cache,
            )
        self.tree_cache_metrics = {"total": 0, "hit": 0}
        self.policy = SchedulePolicy(self.schedule_policy, self.tree_cache)

        # ---- Running status ----
        self.waiting_queue: List[Req] = []
        self.running_batch: Optional[ScheduleBatch] = None
        self.cur_batch: Optional[ScheduleBatch] = None
        self.forward_ct = 0
        self.forward_ct_decode = 0
        self.num_generated_tokens = 0
        self.last_stats_tic = time.time()  # time of last stats for every iter
        self.last_log_tic = time.time()  # time of last log for print decode log
        self.stream_interval = server_args.stream_interval

        # ---- Chunked prefill ----
        self.chunked_prefill_size = server_args.chunked_prefill_size
        self.being_chunked_req = None
        self.is_mixed_chunk = (
            self.chunked_prefill_size is not None and server_args.enable_mixed_chunk
        )

        # ---- Grammar (regex / JSON schema) cache; needs a tokenizer ----
        self.grammar_cache = (
            None
            if server_args.skip_tokenizer_init
            else GrammarCache(
                server_args.tokenizer_path,
                {
                    "tokenizer_mode": server_args.tokenizer_mode,
                    "trust_remote_code": server_args.trust_remote_code,
                },
                skip_tokenizer_init=server_args.skip_tokenizer_init,
                whitespace_patterns=server_args.constrained_json_whitespace_pattern,
                backend=server_args.grammar_backend,
                allow_jump=not server_args.disable_regex_jump_forward,
            )
        )

        # ---- New-token ratio: starts at init, decays linearly toward min ----
        assert (
            server_args.schedule_conservativeness >= 0
        ), "Invalid schedule_conservativeness"
        self.init_new_token_ratio = min(
            global_config.default_init_new_token_ratio
            * server_args.schedule_conservativeness,
            1.0,
        )
        self.min_new_token_ratio = min(
            self.init_new_token_ratio
            * global_config.default_min_new_token_ratio_factor,
            1.0,
        )
        self.new_token_ratio_decay = (
            self.init_new_token_ratio - self.min_new_token_ratio
        ) / global_config.default_new_token_ratio_decay_steps
        self.new_token_ratio = self.init_new_token_ratio

        self.batch_is_full = False

        # ---- Watchdog thread ----
        self.watchdog_timeout = server_args.watchdog_timeout
        threading.Thread(target=self.watchdog_thread, daemon=True).start()

        # ---- Optional torch profiler (enabled by SGLANG_TORCH_PROFILER_DIR) ----
        trace_dir = os.getenv("SGLANG_TORCH_PROFILER_DIR", "")
        if not trace_dir:
            self.profiler = None
        else:
            self.torch_profiler_trace_dir = trace_dir
            logger.info(
                "Profiling enabled. Traces will be saved to: %s",
                self.torch_profiler_trace_dir,
            )
            self.profiler = torch.profiler.profile(
                activities=[
                    torch.profiler.ProfilerActivity.CPU,
                    torch.profiler.ProfilerActivity.CUDA,
                ],
                with_stack=True,
            )

        # ---- Metrics ----
        self.stats = Stats()
        self.metrics_collector = PrometheusMetricsCollector(
            labels={
                "model_name": self.server_args.served_model_name,
                # TODO: Add lora name/path in the future,
            },
            max_model_len=self.max_total_num_tokens,
        )

    def watchdog_thread(self):
        """Background watchdog that kills the server when a batch stops progressing.

        While ``cur_batch`` is set, it checks every ``watchdog_timeout / 2``
        seconds whether ``forward_ct`` has changed. If ``forward_ct`` stays
        unchanged for more than ``watchdog_timeout`` seconds, it logs an error,
        leaves the loop and calls ``kill_parent_process()``.
        """
        self.watchdog_last_forward_ct = 0
        self.watchdog_last_time = time.time()

        while True:
            busy = self.cur_batch is not None
            if busy and self.watchdog_last_forward_ct != self.forward_ct:
                # The counter moved since the last check: restart the timer.
                self.watchdog_last_forward_ct = self.forward_ct
                self.watchdog_last_time = time.time()
            elif busy and time.time() > self.watchdog_last_time + self.watchdog_timeout:
                logger.error(f"Watchdog timeout ({self.watchdog_timeout=})")
                break
            time.sleep(self.watchdog_timeout / 2)

        kill_parent_process()

    @torch.inference_mode()
    def event_loop_normal(self):
        """A normal blocking scheduler loop.

        Each iteration:
        - pulls and dispatches new requests;
        - picks the next batch, runs it and processes its result synchronously;
        - for a decode batch, runs up to
          ``server_args.num_continuous_decode_steps - 1`` extra decode steps
          back to back.

        When there is nothing to run, it calls ``check_memory()`` and resets
        ``new_token_ratio`` to its initial value.
        """
        self.last_batch = None

        while True:
            self.process_input_requests(self.recv_requests())

            batch = self.get_next_batch_to_run()
            self.cur_batch = batch

            if not batch:
                # Idle iteration.
                self.check_memory()
                self.new_token_ratio = self.init_new_token_ratio
            else:
                self.process_batch_result(batch, self.run_batch(batch))

                if batch.forward_mode.is_decode():
                    # Decode multiple steps to reduce the overhead
                    extra_steps = self.server_args.num_continuous_decode_steps - 1
                    for _ in range(extra_steps):
                        if not self.running_batch:
                            break
                        self.update_running_batch()
                        if not self.running_batch:
                            break
                        self.process_batch_result(batch, self.run_batch(batch))

            # log stats
            if self.is_generation and self.server_args.enable_metrics:
                self.log_stats(self.get_stats(batch))
            self.last_stats_tic = time.time()

            self.last_batch = batch

    @torch.inference_mode()
    def event_loop_overlap(self):
        """A scheduler loop that overlaps the CPU processing and GPU computation.

        The batch chosen in one iteration is launched right away, but its result
        is processed only in the next iteration, after the following batch has
        been launched. Launched work waits in a FIFO of
        ``(batch.copy(), result)`` pairs.
        """
        in_flight = deque()

        self.last_batch = None
        self.running_batch = None

        while True:
            self.process_input_requests(self.recv_requests())

            batch = self.get_next_batch_to_run()
            self.cur_batch = batch
            if batch:
                # Keep the original order: launch first, then snapshot the batch.
                launched = self.run_batch(batch)
                in_flight.append((batch.copy(), launched))

            if self.last_batch:
                # Finish the batch that was launched in the previous iteration.
                self.process_batch_result(*in_flight.popleft())
            elif batch is None:
                self.check_memory()
                self.new_token_ratio = self.init_new_token_ratio

            self.last_batch = batch

    def recv_requests(self):
        if self.tp_rank == 0:
            recv_reqs = []

            while True:
                try:
                    recv_req = self.recv_from_tokenizer.recv_pyobj(zmq.NOBLOCK)
                except zmq.ZMQError:
                    break
                recv_reqs.append(recv_req)
        else:
            recv_reqs = None
400

401
402
        if self.tp_size != 1:
            recv_reqs = broadcast_pyobj(recv_reqs, self.tp_rank, self.tp_cpu_group)
403
404
        return recv_reqs

Lianmin Zheng's avatar
Lianmin Zheng committed
405
    def process_input_requests(self, recv_reqs: List):
406
407
408
        for recv_req in recv_reqs:
            if isinstance(recv_req, TokenizedGenerateReqInput):
                self.handle_generate_request(recv_req)
409
            elif isinstance(recv_req, TokenizedEmbeddingReqInput):
410
411
412
413
414
415
416
                self.handle_embedding_request(recv_req)
            elif isinstance(recv_req, FlushCacheReq):
                self.flush_cache()
            elif isinstance(recv_req, AbortReq):
                self.abort_request(recv_req)
            elif isinstance(recv_req, UpdateWeightReqInput):
                success, message = self.update_weights(recv_req)
417
418
419
                self.send_to_detokenizer.send_pyobj(
                    UpdateWeightReqOutput(success, message)
                )
420
421
422
423
424
            elif isinstance(recv_req, ProfileReq):
                if recv_req == ProfileReq.START_PROFILE:
                    self.start_profile()
                else:
                    self.stop_profile()
425
426
427
428
            elif isinstance(recv_req, GetMemPoolSizeReq):
                self.send_to_detokenizer.send_pyobj(
                    GetMemPoolSizeReqOutput(self.max_total_num_tokens)
                )
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
            else:
                raise ValueError(f"Invalid request: {recv_req}")

    def handle_generate_request(
        self,
        recv_req: TokenizedGenerateReqInput,
    ):
        """Build a ``Req`` from a tokenized generate request and enqueue it.

        Steps:
        - attach image inputs and pad the prompt ids via ``pad_input_ids_func``;
        - copy the logprob/stream options;
        - attach a JSON-schema or regex grammar from ``grammar_cache``;
        - truncate prompts longer than ``max_req_input_len``;
        - default ``logprob_start_len``;
        - cap ``max_new_tokens`` at ``max_req_len - len(prompt) - 1``;
        - append the request to ``waiting_queue``.
        """
        req = Req(
            recv_req.rid,
            recv_req.input_text,
            recv_req.input_ids,
            recv_req.sampling_params,
            lora_path=recv_req.lora_path,
        )
        req.tokenizer = self.tokenizer

        # Image inputs: replace the prompt ids with the model-specific padded ids.
        if recv_req.image_inputs is not None:
            req.image_inputs = ImageInputs.from_dict(
                recv_req.image_inputs, self.model_config.vocab_size
            )
            req.origin_input_ids = self.pad_input_ids_func(
                req.origin_input_ids_unpadded, req.image_inputs
            )

        req.return_logprob = recv_req.return_logprob
        req.top_logprobs_num = recv_req.top_logprobs_num
        req.stream = recv_req.stream
        req.logprob_start_len = recv_req.logprob_start_len

        # Init regex FSM or BNF
        if (
            req.sampling_params.json_schema is not None
            or req.sampling_params.regex is not None
        ):
            assert self.grammar_cache is not None
            if req.sampling_params.json_schema is not None:
                req.grammar = self.grammar_cache.query(
                    ("json", req.sampling_params.json_schema),
                    self.model_config.vocab_size,
                )
            elif req.sampling_params.regex is not None:
                req.grammar = self.grammar_cache.query(
                    ("regex", req.sampling_params.regex), self.model_config.vocab_size
                )

        # Truncate prompts that are too long
        if len(req.origin_input_ids) > self.max_req_input_len:
            logger.warning(
                "Request length is longer than the KV cache pool size or "
                "the max context length. Truncated!!!"
            )
            req.origin_input_ids = req.origin_input_ids[: self.max_req_input_len]

        if req.logprob_start_len == -1:
            # By default, only return the logprobs for output tokens. Measure the
            # prompt that will actually be scheduled (after image padding and
            # truncation); the raw recv_req.input_ids can differ in length from it.
            req.logprob_start_len = len(req.origin_input_ids) - 1

        req.sampling_params.max_new_tokens = min(
            (
                req.sampling_params.max_new_tokens
                if req.sampling_params.max_new_tokens is not None
                else 1 << 30
            ),
            self.max_req_len - len(req.origin_input_ids) - 1,
        )

        req.created_time = time.time()
        self.waiting_queue.append(req)

    def handle_embedding_request(
        self,
        recv_req: TokenizedEmbeddingReqInput,
    ):
        """Build a ``Req`` from a tokenized embedding request and enqueue it.

        Prompts longer than ``max_req_input_len`` are truncated to that length
        with a warning. Like the generate path, a prompt of exactly
        ``max_req_input_len`` tokens is left alone and produces no warning.
        """
        req = Req(
            recv_req.rid,
            recv_req.input_text,
            recv_req.input_ids,
            recv_req.sampling_params,
        )
        req.tokenizer = self.tokenizer

        # Truncate prompts that are too long. Use ">" so that a prompt of
        # exactly the maximum length (where the slice would be a no-op) does not
        # log a misleading "Truncated" warning.
        if len(req.origin_input_ids) > self.max_req_input_len:
            logger.warning(
                "Request length is longer than the KV cache pool size or "
                "the max context length. Truncated!!!"
            )
            req.origin_input_ids = req.origin_input_ids[: self.max_req_input_len]

        self.waiting_queue.append(req)

    def print_decode_stats(self):
        """Log decode throughput and KV-cache usage, then start a new throughput window.

        Throughput is ``num_generated_tokens`` divided by the seconds since
        ``last_log_tic``. Both counters are reset here. ``stats.token_usage``
        is also updated.
        """
        capacity = self.max_total_num_tokens
        reclaimable = (
            self.token_to_kv_pool.available_size() + self.tree_cache.evictable_size()
        )
        used_tokens = capacity - reclaimable

        elapsed = time.time() - self.last_log_tic
        gen_throughput = self.num_generated_tokens / elapsed
        self.num_generated_tokens = 0
        self.last_log_tic = time.time()

        # set system stats
        usage = used_tokens / capacity
        self.stats.token_usage = round(usage, 2)

        running_reqs = len(self.running_batch.reqs) if self.running_batch else 0
        logger.info(
            f"Decode batch. "
            f"#running-req: {running_reqs}, "
            f"#token: {used_tokens}, "
            f"token usage: {usage:.2f}, "
            f"gen throughput (token/s): {gen_throughput:.2f}, "
            f"#queue-req: {len(self.waiting_queue)}"
        )

    def check_memory(self):
        available_size = (
            self.token_to_kv_pool.available_size() + self.tree_cache.evictable_size()
        )
        if available_size != self.max_total_num_tokens:
            warnings.warn(
                "Warning: "
                f"available_size={available_size}, max_total_num_tokens={self.max_total_num_tokens}\n"
                "KV cache pool leak detected!"
            )
            exit(1) if crash_on_warning else None

        if len(self.req_to_token_pool.free_slots) != self.req_to_token_pool.size:
            warnings.warn(
                "Warning: "
                f"available req slots={len(self.req_to_token_pool.free_slots)}, "
                f"total slots={self.req_to_token_pool.size}\n"
                "Memory pool leak detected!"
            )
            exit(1) if crash_on_warning else None

561
    def get_next_batch_to_run(self):
562
        # Merge the prefill batch into the running batch
563
564
565
566
567
        if (
            self.last_batch
            and not self.last_batch.forward_mode.is_decode()
            and not self.last_batch.is_empty()
        ):
568
            if self.being_chunked_req:
Chayenne's avatar
Chayenne committed
569
                self.last_batch.filter_batch(being_chunked_req=self.being_chunked_req)
570
                self.tree_cache.cache_unfinished_req(self.being_chunked_req)
571
                # Inflight request keeps its rid but will get a new req_pool_idx.
572
                self.req_to_token_pool.free(self.being_chunked_req.req_pool_idx)
573
574
575
576
577
578
                self.batch_is_full = False
            if not self.last_batch.is_empty():
                if self.running_batch is None:
                    self.running_batch = self.last_batch
                else:
                    self.running_batch.merge_batch(self.last_batch)
579
580

        # Prefill first
581
582
        new_batch = self.get_new_batch_prefill()
        if new_batch is not None:
583
            return new_batch
584

585
586
587
588
589
590
591
592
593
594
595
596
597
        # Check memory
        if self.running_batch is None:
            return

        # Run decode
        before_bs = self.running_batch.batch_size()
        self.update_running_batch()
        if not self.running_batch:
            self.batch_is_full = False
            return None
        if before_bs != self.running_batch.batch_size():
            self.batch_is_full = False
        return self.running_batch
598

Lianmin Zheng's avatar
Lianmin Zheng committed
599
600
601
602
    def get_new_batch_prefill(self) -> Optional[ScheduleBatch]:
        # Handle the cases where prefill is not allowed
        if (
            self.batch_is_full or len(self.waiting_queue) == 0
603
        ) and self.being_chunked_req is None:
Lianmin Zheng's avatar
Lianmin Zheng committed
604
605
            return None

Lianmin Zheng's avatar
Lianmin Zheng committed
606
        running_bs = len(self.running_batch.reqs) if self.running_batch else 0
607
        if running_bs >= self.max_running_requests:
608
            self.batch_is_full = True
609
610
611
612
613
            return None

        # Get priority queue
        prefix_computed = self.policy.calc_priority(self.waiting_queue)

Lianmin Zheng's avatar
Lianmin Zheng committed
614
        # Prefill policy
615
616
617
618
619
620
621
622
623
624
625
        num_mixed_running = running_bs if self.is_mixed_chunk else 0
        adder = PrefillAdder(
            self.tree_cache,
            self.running_batch,
            self.new_token_ratio,
            self.token_to_kv_pool.available_size() + self.tree_cache.evictable_size(),
            self.max_prefill_tokens,
            self.chunked_prefill_size,
            num_mixed_running,
        )

626
        has_inflight = self.being_chunked_req is not None
627
        if has_inflight:
628
            self.being_chunked_req.init_next_round_input()
Chayenne's avatar
Chayenne committed
629
            self.being_chunked_req = adder.add_inflight_req(self.being_chunked_req)
630

Lianmin Zheng's avatar
Lianmin Zheng committed
631
        if self.lora_paths:
632
633
634
635
636
637
            lora_set = (
                set([req.lora_path for req in self.running_batch.reqs])
                if self.running_batch is not None
                else set([])
            )

638
        # Get requests from the waiting queue to a new prefill batch
639
640
        for req in self.waiting_queue:
            if (
Lianmin Zheng's avatar
Lianmin Zheng committed
641
                self.lora_paths
642
643
644
645
646
647
648
                and len(
                    lora_set
                    | set([req.lora_path for req in adder.can_run_list])
                    | set([req.lora_path])
                )
                > self.max_loras_per_batch
            ):
649
                self.batch_is_full = True
650
651
                break

652
            if running_bs + len(adder.can_run_list) >= self.max_running_requests:
653
                self.batch_is_full = True
654
                break
655

656
657
            req.init_next_round_input(None if prefix_computed else self.tree_cache)
            res = adder.add_one_req(req)
658
659
660
            if res != AddReqResult.CONTINUE:
                if res == AddReqResult.NO_TOKEN:
                    self.batch_is_full = True
661
662
                break

Lianmin Zheng's avatar
Lianmin Zheng committed
663
        # Update waiting queue
664
        can_run_list = adder.can_run_list
Lianmin Zheng's avatar
Lianmin Zheng committed
665
666
667
668
669
        if len(can_run_list) == 0:
            return None
        self.waiting_queue = [
            x for x in self.waiting_queue if x not in set(can_run_list)
        ]
670

671
        if adder.new_inflight_req is not None:
672
673
            assert self.being_chunked_req is None
            self.being_chunked_req = adder.new_inflight_req
674

675
676
        if self.being_chunked_req:
            self.being_chunked_req.is_being_chunked += 1
Lianmin Zheng's avatar
Lianmin Zheng committed
677

678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
        # Print stats
        if self.tp_rank == 0:
            if isinstance(self.tree_cache, RadixCache):
                self.tree_cache_metrics["total"] += (
                    adder.log_input_tokens + adder.log_hit_tokens
                ) / 10**9
                self.tree_cache_metrics["hit"] += (adder.log_hit_tokens) / 10**9
                tree_cache_hit_rate = (
                    self.tree_cache_metrics["hit"] / self.tree_cache_metrics["total"]
                )
            else:
                tree_cache_hit_rate = 0.0

            num_used = self.max_total_num_tokens - (
                self.token_to_kv_pool.available_size()
                + self.tree_cache.evictable_size()
            )
695
696
697
            # set system stats
            self.stats.cache_hit_rate = round(100.0 * tree_cache_hit_rate, 2)
            self.stats.token_usage = round(num_used / self.max_total_num_tokens, 2)
698
699
700
701
702
703
704
705
706
707

            if num_mixed_running > 0:
                logger.info(
                    f"Prefill batch"
                    f"(mixed #running-req: {num_mixed_running}). "
                    f"#new-seq: {len(can_run_list)}, "
                    f"#new-token: {adder.log_input_tokens}, "
                    f"#cached-token: {adder.log_hit_tokens}, "
                    f"cache hit rate: {100.0 * tree_cache_hit_rate:.2f}%, "
                    f"token usage: {num_used / self.max_total_num_tokens:.2f}, "
708
                    f"#queue-req: {len(self.waiting_queue) + has_inflight}"
709
710
711
712
713
714
715
716
717
718
                )
            else:
                logger.info(
                    f"Prefill batch. "
                    f"#new-seq: {len(can_run_list)}, "
                    f"#new-token: {adder.log_input_tokens}, "
                    f"#cached-token: {adder.log_hit_tokens}, "
                    f"cache hit rate: {100.0 * tree_cache_hit_rate:.2f}%, "
                    f"token usage: {num_used / self.max_total_num_tokens:.2f}, "
                    f"#running-req: {running_bs}, "
719
                    f"#queue-req: {len(self.waiting_queue) + has_inflight}"
720
721
                )

Lianmin Zheng's avatar
Lianmin Zheng committed
722
        # Create a new batch
723
724
725
726
727
        new_batch = ScheduleBatch.init_new(
            can_run_list,
            self.req_to_token_pool,
            self.token_to_kv_pool,
            self.tree_cache,
728
            self.model_config,
729
        )
730
        new_batch.prepare_for_extend()
731

Lianmin Zheng's avatar
Lianmin Zheng committed
732
        # Mixed-style chunked prefill
733
        if self.is_mixed_chunk and self.running_batch is not None:
734
735
736
737
738
            self.running_batch.filter_batch()
            if not self.running_batch.is_empty():
                self.running_batch.prepare_for_decode(self.enable_overlap)
                new_batch.mix_with_running(self.running_batch)
                new_batch.decoding_reqs = self.running_batch.reqs
739
            self.running_batch = None
Lianmin Zheng's avatar
Lianmin Zheng committed
740
741
        else:
            new_batch.decoding_reqs = None
Lianmin Zheng's avatar
Lianmin Zheng committed
742
743
744

        return new_batch

745
    def update_running_batch(self):
        """Refresh ``self.running_batch`` before the next decode step.

        The batch is filtered first and dropped if nothing is left. If the
        decode step does not fit in memory (or a retraction is forced through
        the ``test_retract`` flag on batches larger than 10), requests returned
        by ``retract_decode`` go back to ``self.waiting_queue`` and the new
        token ratio is replaced by the one it returns. Otherwise the ratio
        decays by ``new_token_ratio_decay``, floored at ``min_new_token_ratio``.
        Requests returned by regex jump-forward are also re-queued, and the
        surviving batch is prepared for decoding.
        """
        global test_retract

        running = self.running_batch

        running.filter_batch()
        if running.is_empty():
            self.running_batch = None
            return

        # Retract when decode memory is insufficient; ``test_retract`` forces
        # this path for large batches.
        must_retract = not running.check_decode_mem() or (
            test_retract and running.batch_size() > 10
        )
        if must_retract:
            ratio_before = self.new_token_ratio
            retracted, self.new_token_ratio = running.retract_decode()
            logger.info(
                "Decode out of memory happened. "
                f"#retracted_reqs: {len(retracted)}, "
                f"#new_token_ratio: {ratio_before:.4f} -> {self.new_token_ratio:.4f}"
            )
            self.waiting_queue.extend(retracted)
        else:
            decayed = self.new_token_ratio - self.new_token_ratio_decay
            self.new_token_ratio = max(decayed, self.min_new_token_ratio)

        # Regex jump-forward: the requests it returns are re-queued.
        if not self.disable_regex_jump_forward:
            jumped = running.check_for_jump_forward(self.pad_input_ids_func)
            self.waiting_queue.extend(jumped)
            if running.is_empty():
                self.running_batch = None
                return

        # Update batch tensors
        running.prepare_for_decode(self.enable_overlap)
Lianmin Zheng's avatar
Lianmin Zheng committed
784
785

    def run_batch(self, batch: ScheduleBatch):
        """Run one forward pass for ``batch``.

        Increments ``self.forward_ct``.

        Returns:
            For generation models, ``(logits_output, next_token_ids, bid)``;
            ``next_token_ids`` is also stored on ``batch.output_ids``. For
            embedding/reward models, ``(embeddings, bid)``. ``bid`` is taken
            from the model-worker batch built here.
        """
        self.forward_ct += 1

        if self.is_generation:
            # Build the worker batch on every path: its ``bid`` is part of the
            # return value even when the model is not called below, so it must
            # not be bound only inside the forward branch.
            model_worker_batch = batch.get_model_worker_batch()
            if batch.forward_mode.is_decode() or batch.extend_num_tokens != 0:
                batch.mark_reqs_started()
                logits_output, next_token_ids = self.tp_worker.forward_batch_generation(
                    model_worker_batch
                )
            else:
                # Extend batch with no new tokens: skip the model and emit one
                # placeholder token id per request.
                logits_output = None
                # NOTE(review): this branch reads ``self.tokenizer`` exactly when
                # ``skip_tokenizer_init`` is set; confirm a tokenizer exists then.
                if self.skip_tokenizer_init:
                    next_token_ids = torch.full(
                        (batch.batch_size(),), self.tokenizer.eos_token_id
                    )
                else:
                    next_token_ids = torch.full((batch.batch_size(),), 0)
            batch.output_ids = next_token_ids
            ret = logits_output, next_token_ids, model_worker_batch.bid
        else:  # embedding or reward model
            assert batch.extend_num_tokens != 0
            model_worker_batch = batch.get_model_worker_batch()
            embeddings = self.tp_worker.forward_batch_embedding(model_worker_batch)
            ret = embeddings, model_worker_batch.bid
        return ret
Chayenne's avatar
Chayenne committed
812
813

    def get_stats(self, batch: ScheduleBatch):
        """Build a ``Stats`` snapshot of the scheduler after running ``batch``.

        The running and waiting request counts are always filled in. Cache hit
        rate and token usage are copied from ``self.stats`` when it is set.
        When ``batch`` is not None, the snapshot also includes per-iteration
        token counts and latencies, plus per-request stats for finished
        requests. Throughput and per-token latencies are measured against
        ``self.last_stats_tic``, which this method does not update.
        """
        # TODO: get stats for chunked prefill

        now = time.time()
        # System stats
        #   Scheduler state
        new_seq: int = 0  # NOTE(review): never populated here; always reported as 0.
        num_running_req = len(self.running_batch.reqs) if self.running_batch else 0
        num_waiting_req = len(self.waiting_queue)
        #   Cache state, taken from the stats recorded on the prefill path
        cache_hit_rate: float = 0.0
        token_usage: float = 0.0
        if self.stats is not None:
            cache_hit_rate = self.stats.cache_hit_rate
            token_usage = self.stats.token_usage

        # Iteration stats
        num_prompt_tokens_iter = 0
        num_generation_tokens_iter = 0
        time_to_first_tokens_iter: List[float] = []
        time_per_output_tokens_iter: List[float] = []

        # Request stats
        #   Decode
        gen_throughput: float = 0.0
        #   Latency
        time_e2e_requests: List[float] = []
        time_waiting_requests: List[float] = []
        #   Metadata
        num_prompt_tokens_requests: List[int] = []
        num_generation_tokens_requests: List[int] = []
        finished_reason_requests: List[str] = []

        if batch is not None:
            elapsed = now - self.last_stats_tic
            num_generation_tokens_iter = len(batch.output_ids)
            # time.time() may not advance between two calls on coarse clocks;
            # report zero throughput instead of dividing by zero.
            if elapsed > 0:
                gen_throughput = round(num_generation_tokens_iter / elapsed, 2)

            # The forward mode belongs to the whole batch (an extend batch runs
            # before decoding starts), so the batch-level prompt token count is
            # computed once rather than once per request.
            is_extend = bool(batch.reqs) and batch.forward_mode.is_extend()
            if is_extend:
                num_prompt_tokens_iter = len(batch.input_ids) + sum(batch.prefix_lens)

            for req in batch.reqs:
                if is_extend:
                    time_to_first_tokens_iter.append(now - req.started_time)
                else:
                    time_per_output_tokens_iter.append(elapsed)

                if req.finished():
                    time_e2e_requests.append(now - req.created_time)
                    time_waiting_requests.append(req.queued_time - req.created_time)
                    num_prompt_tokens_requests.append(len(req.origin_input_ids))
                    num_generation_tokens_requests.append(len(req.output_ids))
                    finished_reason_requests.append(
                        req.finished_reason.to_json()
                        if req.finished_reason is not None
                        else None
                    )

        return Stats(
            new_seq=new_seq,
            num_running_req=num_running_req,
            num_waiting_req=num_waiting_req,
            cache_hit_rate=cache_hit_rate,
            token_usage=token_usage,
            num_prompt_tokens_iter=num_prompt_tokens_iter,
            num_generation_tokens_iter=num_generation_tokens_iter,
            time_to_first_tokens_iter=time_to_first_tokens_iter,
            time_per_output_tokens_iter=time_per_output_tokens_iter,
            gen_throughput=gen_throughput,
            time_e2e_requests=time_e2e_requests,
            time_waiting_requests=time_waiting_requests,
            num_prompt_tokens_requests=num_prompt_tokens_requests,
            num_generation_tokens_requests=num_generation_tokens_requests,
            finished_reason_requests=finished_reason_requests,
            context_len=self.model_config.context_len,
            max_total_num_tokens=self.max_total_num_tokens,
            max_prefill_tokens=self.max_prefill_tokens,
            max_running_requests=self.max_running_requests,
        )

Chayenne's avatar
Chayenne committed
898
    def log_stats(self, stats: Stats):
        """Hand a stats snapshot to ``self.metrics_collector.log_stats``."""
        collector = self.metrics_collector
        collector.log_stats(stats)
Lianmin Zheng's avatar
Lianmin Zheng committed
900
901
902
903

    def process_batch_result(self, batch: ScheduleBatch, result):
        """Route ``result`` to the prefill or decode post-processing path.

        After a decode step, the running batch is dropped if it is now empty.
        """
        if not batch.forward_mode.is_decode():
            self.process_batch_result_prefill(batch, result)
            return

        self.process_batch_result_decode(batch, result)
        if batch.is_empty():
            self.running_batch = None

    def process_batch_result_prefill(self, batch: ScheduleBatch, result):
        """Apply the output of a prefill (extend) step to its requests.

        ``result`` is ``(logits_output, next_token_ids, bid)`` for generation
        models and ``(embeddings, bid)`` for embedding/reward models, as
        produced by ``run_batch``. Retracted requests are skipped. Requests
        whose chunked prefill is still in progress (``is_being_chunked > 0``)
        only have that counter decremented. Every request in the batch is
        passed to ``stream_output`` at the end.
        """
        if self.is_generation:
            logits_output, next_token_ids, bid = result

            if self.enable_overlap:
                logits_output, next_token_ids = self.tp_worker.resulve_batch_result(bid)
            else:
                # Move next_token_ids and logprobs to cpu
                if batch.return_logprob:
                    row_index = torch.arange(len(next_token_ids), device=self.device)
                    logits_output.next_token_logprobs = logits_output.next_token_logprobs[
                        row_index, next_token_ids
                    ].tolist()
                    logits_output.input_token_logprobs = (
                        logits_output.input_token_logprobs.tolist()
                    )
                    logits_output.normalized_prompt_logprobs = (
                        logits_output.normalized_prompt_logprobs.tolist()
                    )
                next_token_ids = next_token_ids.tolist()

            # Check finish conditions
            logprob_pt = 0
            for idx, req in enumerate(batch.reqs):
                if req.is_retracted:
                    continue

                if req.is_being_chunked > 0:
                    # Chunked prefill still in progress: no token is produced yet.
                    # NOTE(review): logprob_pt is not advanced for such requests;
                    # confirm their input logprobs are absent from logits_output.
                    req.is_being_chunked -= 1
                    continue

                # Prefill finished: record the first sampled token.
                req.completion_tokens_wo_jump_forward += 1
                req.output_ids.append(next_token_ids[idx])
                req.check_finished()

                if req.finished():
                    self.tree_cache.cache_finished_req(req)
                elif not (batch.decoding_reqs and req in batch.decoding_reqs):
                    # Requests mixed in from the running decode batch are not
                    # cached here.
                    self.tree_cache.cache_unfinished_req(req)

                if req.grammar is not None:
                    req.grammar.accept_token(next_token_ids[idx])

                if req.return_logprob:
                    logprob_pt += self.add_logprob_return_values(
                        idx, req, logprob_pt, next_token_ids, logits_output
                    )

        else:  # embedding or reward model
            embeddings, _ = result
            embeddings = embeddings.tolist()

            # Check finish conditions
            for idx, req in enumerate(batch.reqs):
                if req.is_retracted:
                    continue

                req.embedding = embeddings[idx]
                if req.is_being_chunked > 0:
                    # Chunked prefill still in progress.
                    req.is_being_chunked -= 1
                else:
                    # Prefill finished: append a dummy output token for
                    # embedding models, then check for completion.
                    req.output_ids.append(0)
                    req.check_finished()

                if req.finished():
                    self.tree_cache.cache_finished_req(req)
                else:
                    self.tree_cache.cache_unfinished_req(req)

        self.stream_output(batch.reqs)
984

Lianmin Zheng's avatar
Lianmin Zheng committed
985
    def process_batch_result_decode(self, batch: ScheduleBatch, result):
        """Apply the output of a decode step to its requests.

        ``result`` is ``(logits_output, next_token_ids, bid)`` from
        ``run_batch``. Each live request gets its new token appended and is
        checked for completion. Finished requests are handed to the tree
        cache, and logprobs are recorded for requests that asked for them.
        Afterwards the requests are streamed. Decode stats are printed on TP
        rank 0 every ``server_args.decode_log_interval`` decode steps.
        """
        logits_output, next_token_ids, bid = result

        self.num_generated_tokens += len(batch.reqs)

        if self.enable_overlap:
            logits_output, next_token_ids = self.tp_worker.resulve_batch_result(bid)
            next_token_logprobs = logits_output.next_token_logprobs
        else:
            # Move next_token_ids and logprobs to cpu
            if batch.return_logprob:
                next_token_logprobs = logits_output.next_token_logprobs[
                    torch.arange(len(next_token_ids), device=self.device),
                    next_token_ids,
                ].tolist()
            next_token_ids = next_token_ids.tolist()

        # KV frees issued below are bracketed by free_group_begin/free_group_end
        # (presumably so the pool can release them together).
        self.token_to_kv_pool.free_group_begin()

        # Check finish condition
        for i, (req, next_token_id) in enumerate(zip(batch.reqs, next_token_ids)):
            if req.is_retracted:
                continue

            # Under overlap scheduling a request can already be finished when
            # this result is processed: free its slot in ``out_cache_loc`` and
            # skip it. Gated on the same flag that selected the overlap result
            # path above, so both decisions always agree.
            if self.enable_overlap and req.finished():
                self.token_to_kv_pool.free(batch.out_cache_loc[i : i + 1])
                continue

            req.completion_tokens_wo_jump_forward += 1
            req.output_ids.append(next_token_id)
            req.check_finished()

            if req.grammar is not None:
                req.grammar.accept_token(next_token_id)

            if req.finished():
                self.tree_cache.cache_finished_req(req)

            if req.return_logprob:
                req.output_token_logprobs.append(
                    (next_token_logprobs[i], next_token_id)
                )
                if req.top_logprobs_num > 0:
                    req.output_top_logprobs.append(logits_output.output_top_logprobs[i])

        self.stream_output(batch.reqs)

        self.token_to_kv_pool.free_group_end()

        self.forward_ct_decode = (self.forward_ct_decode + 1) % (1 << 30)
        if (
            self.tp_rank == 0
            and self.forward_ct_decode % self.server_args.decode_log_interval == 0
        ):
            self.print_decode_stats()

1040
1041
1042
1043
1044
1045
1046
1047
1048
1049
1050
1051
1052
1053
1054
1055
1056
1057
1058
1059
1060
1061
1062
1063
1064
1065
1066
1067
1068
1069
1070
1071
1072
1073
1074
1075
1076
1077
1078
1079
1080
1081
1082
1083
1084
1085
1086
1087
1088
1089
1090
1091
1092
1093
1094
1095
1096
1097
1098
1099
1100
1101
1102
1103
1104
1105
1106
1107
1108
1109
1110
1111
1112
    def add_logprob_return_values(
        self,
        i: int,
        req: Req,
        pt: int,
        next_token_ids: List[int],
        output: LogitsProcessorOutput,
    ):
        """Attach the logprobs of one request from a prefill (extend) step.

        Args:
            i: index of ``req`` within the batch.
            req: the request whose logprob fields are updated in place.
            pt: offset of this request's entries in the flattened
                ``output.input_token_logprobs``.
            next_token_ids: sampled token ids, one per request in the batch.
            output: logits-processor output for the whole batch.

        Returns:
            The number of input-logprob positions this request accounts for,
            i.e. how far the caller advances ``pt``.
        """
        # If logprob_start_len > 0, then first logprob_start_len prompt tokens will be ignored.
        num_input_logprobs = req.extend_input_len - req.extend_logprob_start_len
        # Trailing fill_ids that are decode tokens re-computed in this extend batch.
        num_recomputed = req.last_update_decode_tokens
        fill_len = len(req.fill_ids)

        if req.normalized_prompt_logprob is None:
            req.normalized_prompt_logprob = output.normalized_prompt_logprobs[i]

        if req.input_token_logprobs is None:
            input_token_logprobs = output.input_token_logprobs[
                pt : pt + num_input_logprobs - 1 - num_recomputed
            ]
            input_token_ids = req.fill_ids[
                fill_len - num_input_logprobs + 1 : fill_len - num_recomputed
            ]
            req.input_token_logprobs = list(zip(input_token_logprobs, input_token_ids))

            if req.logprob_start_len == 0:
                # The first token does not have logprob, pad it.
                req.input_token_logprobs = [
                    (None, req.fill_ids[0])
                ] + req.input_token_logprobs

        if num_recomputed != 0:
            # Some decode tokens are re-computed in an extend batch. They come
            # before the newly sampled token, so record them first.
            end = pt + num_input_logprobs - 1
            req.output_token_logprobs.extend(
                zip(
                    output.input_token_logprobs[end - num_recomputed : end],
                    req.fill_ids[fill_len - num_recomputed : fill_len],
                )
            )

        # Record the newly sampled token after any re-computed decode tokens.
        # This keeps output_token_logprobs in token order and aligned with
        # output_top_logprobs, which is built in the same order below.
        req.output_token_logprobs.append(
            (output.next_token_logprobs[i], next_token_ids[i])
        )

        if req.top_logprobs_num > 0:
            if req.input_top_logprobs is None:
                req.input_top_logprobs = output.input_top_logprobs[i]
                if req.logprob_start_len == 0:
                    req.input_top_logprobs = [None] + req.input_top_logprobs

            if num_recomputed != 0:
                req.output_top_logprobs.extend(
                    output.input_top_logprobs[i][-num_recomputed:]
                )
            req.output_top_logprobs.append(output.output_top_logprobs[i])

        return num_input_logprobs

1113
    def stream_output(self, reqs: List[Req]):
        """Send finished or stream-due requests to the detokenizer.

        A request is emitted when it has finished, or when it streams and
        either this is a stream iteration (``forward_ct_decode`` is a multiple
        of ``stream_interval``) or it has exactly one output token. Generation
        models send a single ``BatchTokenIDOut``; embedding/reward models send
        a single ``BatchEmbeddingOut``. Nothing is sent when no request
        qualifies.
        """
        rids = []
        meta_infos: List[dict] = []
        finished_reasons: List[BaseFinishReason] = []
        if self.is_generation:
            vids = []
            texts = []
            read_ids_list = []
            read_offsets = []
            token_ids = []
            skip_special = []
            spaces_between_special = []
            no_stop_trim = []
        else:  # embedding or reward model
            embeddings = []

        stream_iter = self.forward_ct_decode % self.stream_interval == 0

        for req in reqs:
            # TODO(lianmin): revisit this for overlap + retract + stream
            should_emit = req.finished() or (
                req.stream and (stream_iter or len(req.output_ids) == 1)
            )
            if not should_emit:
                continue

            rids.append(req.rid)
            finished_reasons.append(req.finished_reason)

            if not self.is_generation:  # embedding or reward model
                embeddings.append(req.embedding)
                meta_infos.append({"prompt_tokens": len(req.origin_input_ids)})
                continue

            vids.append(req.vid)
            texts.append(req.decoded_text)
            ids, offset = req.init_incremental_detokenize()
            read_ids_list.append(ids)
            read_offsets.append(offset)
            if self.skip_tokenizer_init:
                token_ids.append(req.output_ids)
            params = req.sampling_params
            skip_special.append(params.skip_special_tokens)
            spaces_between_special.append(params.spaces_between_special_tokens)
            no_stop_trim.append(params.no_stop_trim)

            meta = {
                "prompt_tokens": len(req.origin_input_ids),
                "completion_tokens": len(req.output_ids),
                "completion_tokens_wo_jump_forward": req.completion_tokens_wo_jump_forward,
                "cached_tokens": req.cached_tokens,
                "finish_reason": (
                    req.finished_reason.to_json()
                    if req.finished_reason is not None
                    else None
                ),
            }
            if req.return_logprob:
                meta["input_token_logprobs"] = req.input_token_logprobs
                meta["output_token_logprobs"] = req.output_token_logprobs
                meta["input_top_logprobs"] = req.input_top_logprobs
                meta["output_top_logprobs"] = req.output_top_logprobs
                meta["normalized_prompt_logprob"] = req.normalized_prompt_logprob
            meta_infos.append(meta)

        if not rids:
            return

        # Send to detokenizer
        if self.is_generation:
            packet = BatchTokenIDOut(
                rids,
                vids,
                texts,
                read_ids_list,
                read_offsets,
                token_ids,
                skip_special,
                spaces_between_special,
                meta_infos,
                finished_reasons,
                no_stop_trim,
            )
        else:  # embedding or reward model
            packet = BatchEmbeddingOut(
                rids,
                embeddings,
                meta_infos,
                finished_reasons,
            )
        self.send_to_detokenizer.send_pyobj(packet)

    def flush_cache(self):
        """Flush the memory pool and cache.

        The flush only happens when the waiting queue is empty and the
        running batch is absent or has no requests. It resets the tree cache
        and its hit metrics and the grammar cache (if any). It then clears the
        request-to-token and token-to-KV pools and empties the CUDA cache.

        Returns:
            True if the cache was flushed, False if pending requests prevented it.
        """
        num_waiting = len(self.waiting_queue)
        num_running = 0 if self.running_batch is None else len(self.running_batch.reqs)

        if num_waiting != 0 or num_running != 0:
            # Use the module logger, like the rest of the scheduler, rather
            # than the root logger.
            logger.warning(
                "Cache not flushed because there are pending requests. "
                f"#queue-req: {num_waiting}, "
                f"#running-req: {num_running}"
            )
            return False

        self.tree_cache.reset()
        self.tree_cache_metrics = {"total": 0, "hit": 0}
        if self.grammar_cache is not None:
            self.grammar_cache.reset()
        # TODO(dark): reset the bnf cache
        self.req_to_token_pool.clear()
        self.token_to_kv_pool.clear()
        torch.cuda.empty_cache()
        logger.info("Cache flushed successfully!")
        return True

    def abort_request(self, recv_req: AbortReq):
        """Abort the request whose rid equals ``recv_req.rid``.

        The first matching entry in the waiting queue is removed. In the
        running batch, the first matching request that has not finished is
        marked with ``FINISH_ABORT`` and passed to
        ``tree_cache.cache_finished_req``.
        """
        target = recv_req.rid

        # Waiting queue: drop the first request with this rid, if any.
        waiting_idx = next(
            (idx for idx, queued in enumerate(self.waiting_queue) if queued.rid == target),
            None,
        )
        if waiting_idx is not None:
            del self.waiting_queue[waiting_idx]

        # Running batch: mark the first unfinished match as aborted.
        if not self.running_batch:
            return
        for running in self.running_batch.reqs:
            if running.rid == target and not running.finished():
                running.finished_reason = FINISH_ABORT()
                self.tree_cache.cache_finished_req(running)
                break

    def update_weights(self, recv_req: UpdateWeightReqInput):
        """In-place update of the weights.

        Delegates to ``self.tp_worker.update_weights``. On success the cache
        is flushed, and the flush is asserted to have succeeded. On failure
        the worker's message is logged as an error.

        Returns:
            The worker's ``(success, message)`` pair.
        """
        success, message = self.tp_worker.update_weights(recv_req)
        if not success:
            logger.error(message)
            return success, message

        flushed = self.flush_cache()
        assert flushed, "Cache flush failed after updating weights"
        return success, message

1269
1270
1271
1272
1273
1274
1275
1276
1277
1278
1279
1280
1281
1282
    def start_profile(self) -> None:
        """Start ``self.profiler``; raise RuntimeError if it is not enabled."""
        profiler = self.profiler
        if profiler is not None:
            profiler.start()
            return
        raise RuntimeError("Profiler is not enabled.")

    def stop_profile(self) -> None:
        """Stop ``self.profiler`` and export its trace as a Chrome trace.

        The file is written to
        ``<torch_profiler_trace_dir>/<time.time()>.trace.json.gz``.

        Raises:
            RuntimeError: if the profiler is not enabled.
        """
        profiler = self.profiler
        if profiler is None:
            raise RuntimeError("Profiler is not enabled.")
        profiler.stop()
        file_name = str(time.time()) + ".trace.json.gz"
        profiler.export_chrome_trace(self.torch_profiler_trace_dir + "/" + file_name)
        logger.info("Profiler is done")

1283
1284
1285
1286
1287
1288

def run_scheduler_process(
    server_args: ServerArgs,
    port_args: PortArgs,
    gpu_id: int,
    tp_rank: int,
    dp_rank: Optional[int],
    pipe_writer,
):
    """Entry point of a scheduler process.

    Configures logging with a prefix naming the TP rank (and the DP rank when
    given), builds the ``Scheduler``, and sends "ready" on ``pipe_writer``. It
    then runs the overlap or the normal event loop according to
    ``server_args.enable_overlap_schedule``. Any exception is logged with its
    traceback, after which ``kill_parent_process()`` is called.
    """
    log_prefix = f" TP{tp_rank}" if dp_rank is None else f" DP{dp_rank} TP{tp_rank}"
    configure_logger(server_args, prefix=log_prefix)
    suppress_other_loggers()

    try:
        scheduler = Scheduler(server_args, port_args, gpu_id, tp_rank, dp_rank)
        pipe_writer.send("ready")
        event_loop = (
            scheduler.event_loop_overlap
            if server_args.enable_overlap_schedule
            else scheduler.event_loop_normal
        )
        event_loop()
    except Exception:
        msg = get_exception_traceback()
        logger.error(msg)
        kill_parent_process()