"""
Copyright 2023-2024 SGLang Team
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""

"""A scheduler that manages a tensor parallel GPU worker."""

import logging
import os
import threading
import time
import warnings
from collections import deque
from types import SimpleNamespace
from typing import List, Optional

import torch
import zmq

from sglang.global_config import global_config
from sglang.srt.configs.model_config import ModelConfig
from sglang.srt.constrained.grammar import GrammarBackend
from sglang.srt.hf_transformers_utils import get_processor, get_tokenizer
from sglang.srt.layers.logits_processor import LogitsProcessorOutput
from sglang.srt.managers.io_struct import (
    AbortReq,
    BatchEmbeddingOut,
    BatchTokenIDOut,
    FlushCacheReq,
    GetMemPoolSizeReq,
    GetMemPoolSizeReqOutput,
    ProfileReq,
    TokenizedEmbeddingReqInput,
    TokenizedGenerateReqInput,
    UpdateWeightReqInput,
    UpdateWeightReqOutput,
)
from sglang.srt.managers.schedule_batch import (
    FINISH_ABORT,
    BaseFinishReason,
    ImageInputs,
    Req,
    ScheduleBatch,
    global_server_args_dict,
)
from sglang.srt.managers.schedule_policy import (
    AddReqResult,
    PrefillAdder,
    SchedulePolicy,
)
from sglang.srt.managers.tp_worker import TpModelWorker
from sglang.srt.managers.tp_worker_overlap_thread import TpModelWorkerClient
from sglang.srt.mem_cache.chunk_cache import ChunkCache
from sglang.srt.mem_cache.radix_cache import RadixCache
from sglang.srt.metrics.collector import SchedulerMetricsCollector, SchedulerStats
from sglang.srt.server_args import PortArgs, ServerArgs
from sglang.srt.utils import (
    broadcast_pyobj,
    configure_logger,
    get_zmq_socket,
    kill_parent_process,
    set_random_seed,
    suppress_other_loggers,
)
from sglang.utils import get_exception_traceback

logger = logging.getLogger(__name__)

# Crash on warning if we are running CI tests
crash_on_warning = os.getenv("SGLANG_IS_IN_CI", "false") == "true"

# Test retract decode
test_retract = os.getenv("SGLANG_TEST_RETRACT", "false") == "true"


class Scheduler:
    """A scheduler that manages a tensor parallel GPU worker."""

    def __init__(
        self,
        server_args: ServerArgs,
        port_args: PortArgs,
        gpu_id: int,
        tp_rank: int,
        dp_rank: Optional[int],
    ):
        # Parse args
        self.server_args = server_args
        self.tp_rank = tp_rank
        self.tp_size = server_args.tp_size
        self.schedule_policy = server_args.schedule_policy
        self.disable_regex_jump_forward = server_args.disable_regex_jump_forward
        self.lora_paths = server_args.lora_paths
        self.max_loras_per_batch = server_args.max_loras_per_batch
        self.enable_overlap = server_args.enable_overlap_schedule
        self.skip_tokenizer_init = server_args.skip_tokenizer_init
        self.enable_metrics = server_args.enable_metrics

        # Init inter-process communication
        context = zmq.Context(2)

        if self.tp_rank == 0:
            self.recv_from_tokenizer = get_zmq_socket(
                context, zmq.PULL, port_args.scheduler_input_ipc_name
            )

            if server_args.skip_tokenizer_init:
                # Directly send to the tokenizer/api
                self.send_to_detokenizer = get_zmq_socket(
                    context, zmq.PUSH, port_args.tokenizer_ipc_name
                )
            else:
                # Send to the detokenizer
                self.send_to_detokenizer = get_zmq_socket(
                    context, zmq.PUSH, port_args.detokenizer_ipc_name
                )
        else:
            self.recv_from_tokenizer = None
            self.send_to_detokenizer = SimpleNamespace(send_pyobj=lambda x: None)
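
        # Only TP rank 0 owns the ZMQ sockets to the tokenizer/detokenizer; the
        # other TP ranks receive requests via broadcast_pyobj() in recv_requests()
        # and use a no-op sender.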

        # Init tokenizer
        self.model_config = ModelConfig(
            server_args.model_path,
            trust_remote_code=server_args.trust_remote_code,
            context_length=server_args.context_length,
            model_override_args=server_args.json_model_override_args,
            is_embedding=server_args.is_embedding,
        )
        self.is_generation = self.model_config.is_generation

        if server_args.skip_tokenizer_init:
            self.tokenizer = self.processor = None
        else:
            if self.model_config.is_multimodal:
                self.processor = get_processor(
                    server_args.tokenizer_path,
                    tokenizer_mode=server_args.tokenizer_mode,
                    trust_remote_code=server_args.trust_remote_code,
                )
                self.tokenizer = self.processor.tokenizer
            else:
                self.tokenizer = get_tokenizer(
                    server_args.tokenizer_path,
                    tokenizer_mode=server_args.tokenizer_mode,
                    trust_remote_code=server_args.trust_remote_code,
                )

        # Launch a tensor parallel worker
        if self.enable_overlap:
            TpWorkerClass = TpModelWorkerClient
        else:
            TpWorkerClass = TpModelWorker

        self.tp_worker = TpWorkerClass(
            server_args=server_args,
            gpu_id=gpu_id,
            tp_rank=tp_rank,
            dp_rank=dp_rank,
            nccl_port=port_args.nccl_port,
        )

        # Get token and memory info from the model worker
        (
            self.max_total_num_tokens,
            self.max_prefill_tokens,
            self.max_running_requests,
            self.max_req_len,
            self.max_req_input_len,
            self.random_seed,
            self.device,
            worker_global_server_args_dict,
            _,
            _,
            _,
        ) = self.tp_worker.get_worker_info()
        self.tp_cpu_group = self.tp_worker.get_tp_cpu_group()
        self.pad_input_ids_func = self.tp_worker.get_pad_input_ids_func()
        global_server_args_dict.update(worker_global_server_args_dict)
        set_random_seed(self.random_seed)

        # Print debug info
        logger.info(
            f"max_total_num_tokens={self.max_total_num_tokens}, "
            f"max_prefill_tokens={self.max_prefill_tokens}, "
            f"max_running_requests={self.max_running_requests}, "
            f"context_len={self.model_config.context_len}"
        )

        # Init memory pool and cache
        self.req_to_token_pool, self.token_to_kv_pool = self.tp_worker.get_memory_pool()

        if (
            server_args.chunked_prefill_size is not None
            and server_args.disable_radix_cache
        ):
            self.tree_cache = ChunkCache(
                req_to_token_pool=self.req_to_token_pool,
                token_to_kv_pool=self.token_to_kv_pool,
            )
        else:
            self.tree_cache = RadixCache(
                req_to_token_pool=self.req_to_token_pool,
                token_to_kv_pool=self.token_to_kv_pool,
                disable=server_args.disable_radix_cache,
            )
        self.tree_cache_metrics = {"total": 0, "hit": 0}
        self.policy = SchedulePolicy(self.schedule_policy, self.tree_cache)

        # Init running status
        self.waiting_queue: List[Req] = []
        self.running_batch: Optional[ScheduleBatch] = None
        self.cur_batch: Optional[ScheduleBatch] = None
        self.forward_ct = 0
        self.forward_ct_decode = 0
        self.num_generated_tokens = 0
        self.last_decode_stats_tic = time.time()
        self.stream_interval = server_args.stream_interval

        # Init chunked prefill
        self.chunked_prefill_size = server_args.chunked_prefill_size
        self.being_chunked_req = None
        self.is_mixed_chunk = (
            self.chunked_prefill_size is not None and server_args.enable_mixed_chunk
        )

        # Init the grammar cache for constrained generation
        self.grammar_cache = None
        self.grammar_queue: List[Req] = []

        if not server_args.skip_tokenizer_init:
            self.grammar_cache = GrammarBackend(
                server_args.tokenizer_path,
                {
                    "tokenizer_mode": server_args.tokenizer_mode,
                    "trust_remote_code": server_args.trust_remote_code,
                },
                skip_tokenizer_init=server_args.skip_tokenizer_init,
                whitespace_patterns=server_args.constrained_json_whitespace_pattern,
                backend=server_args.grammar_backend,
                allow_jump=not server_args.disable_regex_jump_forward,
            )

        # Init new token estimation
        assert (
            server_args.schedule_conservativeness >= 0
        ), "Invalid schedule_conservativeness"

        self.init_new_token_ratio = min(
            global_config.default_init_new_token_ratio
            * server_args.schedule_conservativeness,
            1.0,
        )
        self.min_new_token_ratio = min(
            self.init_new_token_ratio
            * global_config.default_min_new_token_ratio_factor,
            1.0,
        )
        self.new_token_ratio_decay = (
            self.init_new_token_ratio - self.min_new_token_ratio
        ) / global_config.default_new_token_ratio_decay_steps
        self.new_token_ratio = self.init_new_token_ratio
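        # The ratios above estimate how many more tokens each running request will
        # still decode when reserving KV-cache space for new prefills: the estimate
        # starts at init_new_token_ratio, is lowered by new_token_ratio_decay after
        # each decode step, and is floored at min_new_token_ratio (see
        # update_running_batch).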

        self.batch_is_full = False

        # Init watchdog thread
        self.watchdog_timeout = server_args.watchdog_timeout
        t = threading.Thread(target=self.watchdog_thread, daemon=True)
        t.start()

        # Init profiler
        if os.getenv("SGLANG_TORCH_PROFILER_DIR", "") == "":
            self.profiler = None
        else:
            self.torch_profiler_trace_dir = os.getenv("SGLANG_TORCH_PROFILER_DIR")
            logger.info(
                "Profiling enabled. Traces will be saved to: %s",
                self.torch_profiler_trace_dir,
            )
            self.profiler = torch.profiler.profile(
                activities=[
                    torch.profiler.ProfilerActivity.CPU,
                    torch.profiler.ProfilerActivity.CUDA,
                ],
                with_stack=True,
            )

        # Init metrics stats
        self.stats = SchedulerStats()
        if self.enable_metrics:
            self.metrics_collector = SchedulerMetricsCollector(
                labels={
                    "model_name": self.server_args.served_model_name,
                    # TODO: Add lora name/path in the future,
                },
            )

    def watchdog_thread(self):
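        """Kill the server if a scheduled batch makes no forward progress within `watchdog_timeout` seconds."""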
        self.watchdog_last_forward_ct = 0
        self.watchdog_last_time = time.time()

        while True:
            if self.cur_batch is not None:
                if self.watchdog_last_forward_ct == self.forward_ct:
                    if time.time() > self.watchdog_last_time + self.watchdog_timeout:
                        logger.error(f"Watchdog timeout ({self.watchdog_timeout=})")
                        break
                else:
                    self.watchdog_last_forward_ct = self.forward_ct
                    self.watchdog_last_time = time.time()
            time.sleep(self.watchdog_timeout / 2)

        kill_parent_process()

    @torch.inference_mode()
    def event_loop_normal(self):
        """A normal blocking scheduler loop."""
        self.last_batch = None

        while True:
            recv_reqs = self.recv_requests()
            self.process_input_requests(recv_reqs)

            batch = self.get_next_batch_to_run()
            self.cur_batch = batch

            if batch:
                result = self.run_batch(batch)
                self.process_batch_result(batch, result)

                # Decode multiple steps to reduce the overhead
                if batch.forward_mode.is_decode():
                    for _ in range(self.server_args.num_continuous_decode_steps - 1):
                        if not self.running_batch:
                            break
                        self.update_running_batch()
                        if not self.running_batch:
                            break
                        result = self.run_batch(batch)
                        self.process_batch_result(batch, result)
            else:
                self.check_memory()
                self.new_token_ratio = self.init_new_token_ratio

            self.last_batch = batch

    @torch.inference_mode()
    def event_loop_overlap(self):
        """A scheduler loop that overlaps the CPU processing and GPU computation."""
        result_queue = deque()
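        # Results are processed one iteration late so that the CPU-side result
        # handling of batch i overlaps with the GPU execution of batch i + 1.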

        self.last_batch = None
        self.running_batch = None

        while True:
            recv_reqs = self.recv_requests()
            self.process_input_requests(recv_reqs)

            batch = self.get_next_batch_to_run()
            self.cur_batch = batch
            if batch:
                result = self.run_batch(batch)
                result_queue.append((batch.copy(), result))

            if self.last_batch:
                tmp_batch, tmp_result = result_queue.popleft()
                self.process_batch_result(tmp_batch, tmp_result)
            elif batch is None:
                self.check_memory()
                self.new_token_ratio = self.init_new_token_ratio

            self.last_batch = batch

    def recv_requests(self):
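        """Receive new requests from the ZMQ socket on TP rank 0 and broadcast them to all TP ranks."""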
        if self.tp_rank == 0:
            recv_reqs = []

            while True:
                try:
                    recv_req = self.recv_from_tokenizer.recv_pyobj(zmq.NOBLOCK)
                except zmq.ZMQError:
                    break
                recv_reqs.append(recv_req)
        else:
            recv_reqs = None

        if self.tp_size != 1:
            recv_reqs = broadcast_pyobj(recv_reqs, self.tp_rank, self.tp_cpu_group)
        return recv_reqs

    def process_input_requests(self, recv_reqs: List):
        for recv_req in recv_reqs:
            if isinstance(recv_req, TokenizedGenerateReqInput):
                self.handle_generate_request(recv_req)
            elif isinstance(recv_req, TokenizedEmbeddingReqInput):
                self.handle_embedding_request(recv_req)
            elif isinstance(recv_req, FlushCacheReq):
                self.flush_cache()
            elif isinstance(recv_req, AbortReq):
                self.abort_request(recv_req)
            elif isinstance(recv_req, UpdateWeightReqInput):
                success, message = self.update_weights(recv_req)
                self.send_to_detokenizer.send_pyobj(
                    UpdateWeightReqOutput(success, message)
                )
            elif isinstance(recv_req, ProfileReq):
                if recv_req == ProfileReq.START_PROFILE:
                    self.start_profile()
                else:
                    self.stop_profile()
            elif isinstance(recv_req, GetMemPoolSizeReq):
                self.send_to_detokenizer.send_pyobj(
                    GetMemPoolSizeReqOutput(self.max_total_num_tokens)
                )
            else:
                raise ValueError(f"Invalid request: {recv_req}")

    def handle_generate_request(
        self,
        recv_req: TokenizedGenerateReqInput,
    ):
        req = Req(
            recv_req.rid,
            recv_req.input_text,
            recv_req.input_ids,
            recv_req.sampling_params,
            lora_path=recv_req.lora_path,
        )
        req.tokenizer = self.tokenizer

        # Image inputs
        if recv_req.image_inputs is not None:
            req.image_inputs = ImageInputs.from_dict(
                recv_req.image_inputs, self.model_config.vocab_size
            )
            req.origin_input_ids = self.pad_input_ids_func(
                req.origin_input_ids_unpadded, req.image_inputs
            )

        req.return_logprob = recv_req.return_logprob
        req.top_logprobs_num = recv_req.top_logprobs_num
        req.stream = recv_req.stream
        req.logprob_start_len = recv_req.logprob_start_len

        if req.logprob_start_len == -1:
            # By default, only return the logprobs for output tokens
            req.logprob_start_len = len(recv_req.input_ids) - 1

        # Init grammar cache for this request
        if (
            req.sampling_params.json_schema is not None
            or req.sampling_params.regex is not None
        ):
            assert self.grammar_cache is not None
            if req.sampling_params.json_schema is not None:
                req.grammar = self.grammar_cache.query(
                    ("json", req.sampling_params.json_schema),
                    self.model_config.vocab_size,
                )
            elif req.sampling_params.regex is not None:
                req.grammar = self.grammar_cache.query(
                    ("regex", req.sampling_params.regex), self.model_config.vocab_size
                )

        # Truncate prompts that are too long
        if len(req.origin_input_ids) > self.max_req_input_len:
            logger.warning(
                "Request length is longer than the KV cache pool size or "
                "the max context length. Truncated!!!"
            )
            req.origin_input_ids = req.origin_input_ids[: self.max_req_input_len]

        req.sampling_params.max_new_tokens = min(
            (
                req.sampling_params.max_new_tokens
                if req.sampling_params.max_new_tokens is not None
                else 1 << 30
            ),
            self.max_req_len - len(req.origin_input_ids) - 1,
        )

        if req.grammar is not None:
            self.grammar_queue.append(req)
        else:
            self.waiting_queue.append(req)
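        # Requests with a grammar are parked in grammar_queue until the grammar
        # object reports done(); get_new_batch_prefill() then moves them to the
        # waiting queue.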

    def handle_embedding_request(
        self,
        recv_req: TokenizedEmbeddingReqInput,
    ):
        req = Req(
            recv_req.rid,
            recv_req.input_text,
            recv_req.input_ids,
            recv_req.sampling_params,
        )
        req.tokenizer = self.tokenizer

        # Truncate prompts that are too long
        if len(req.origin_input_ids) >= self.max_req_input_len:
            logger.warning(
                "Request length is longer than the KV cache pool size or "
                "the max context length. Truncated!!!"
            )
            req.origin_input_ids = req.origin_input_ids[: self.max_req_input_len]

        self.waiting_queue.append(req)

    def log_prefill_stats(self, adder, can_run_list, running_bs, has_inflight):
        if isinstance(self.tree_cache, RadixCache):
            self.tree_cache_metrics["total"] += (
                adder.log_input_tokens + adder.log_hit_tokens
            ) / 10**9
            self.tree_cache_metrics["hit"] += (adder.log_hit_tokens) / 10**9
            tree_cache_hit_rate = (
                self.tree_cache_metrics["hit"] / self.tree_cache_metrics["total"]
            )
        else:
            tree_cache_hit_rate = 0.0

        num_used = self.max_total_num_tokens - (
            self.token_to_kv_pool.available_size() + self.tree_cache.evictable_size()
        )

        logger.info(
            f"Prefill batch. "
            f"#new-seq: {len(can_run_list)}, "
            f"#new-token: {adder.log_input_tokens}, "
            f"#cached-token: {adder.log_hit_tokens}, "
            f"cache hit rate: {100.0 * tree_cache_hit_rate:.2f}%, "
            f"token usage: {num_used / self.max_total_num_tokens:.2f}, "
            f"#running-req: {running_bs}, "
            f"#queue-req: {len(self.waiting_queue) + has_inflight}"
        )

        if self.enable_metrics:
            self.stats.num_running_reqs = running_bs
            self.stats.num_used_tokens = num_used
            self.stats.token_usage = round(num_used / self.max_total_num_tokens, 2)
            self.stats.num_queue_reqs = len(self.waiting_queue) + has_inflight
            self.stats.cache_hit_rate = tree_cache_hit_rate
            self.metrics_collector.log_stats(self.stats)

    def log_decode_stats(self):
        num_used = self.max_total_num_tokens - (
            self.token_to_kv_pool.available_size() + self.tree_cache.evictable_size()
        )
        gen_throughput = self.num_generated_tokens / (
            time.time() - self.last_decode_stats_tic
        )
        self.num_generated_tokens = 0
        self.last_decode_stats_tic = time.time()
        num_running_reqs = len(self.running_batch.reqs) if self.running_batch else 0
        logger.info(
            f"Decode batch. "
            f"#running-req: {num_running_reqs}, "
            f"#token: {num_used}, "
            f"token usage: {num_used / self.max_total_num_tokens:.2f}, "
            f"gen throughput (token/s): {gen_throughput:.2f}, "
            f"#queue-req: {len(self.waiting_queue)}"
        )

        if self.enable_metrics:
            self.stats.num_running_reqs = num_running_reqs
            self.stats.num_used_tokens = num_used
            self.stats.token_usage = num_used / self.max_total_num_tokens
            self.stats.gen_throughput = gen_throughput
            self.stats.num_queue_reqs = len(self.waiting_queue)
            self.metrics_collector.log_stats(self.stats)

    def check_memory(self):
        available_size = (
            self.token_to_kv_pool.available_size() + self.tree_cache.evictable_size()
        )
        if available_size != self.max_total_num_tokens:
            warnings.warn(
                "Warning: "
                f"available_size={available_size}, max_total_num_tokens={self.max_total_num_tokens}\n"
                "KV cache pool leak detected!"
            )
            exit(1) if crash_on_warning else None

        if len(self.req_to_token_pool.free_slots) != self.req_to_token_pool.size:
            warnings.warn(
                "Warning: "
                f"available req slots={len(self.req_to_token_pool.free_slots)}, "
                f"total slots={self.req_to_token_pool.size}\n"
                "Memory pool leak detected!"
            )
            exit(1) if crash_on_warning else None

    def get_next_batch_to_run(self):
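        """Merge the last prefill batch into the running batch, prefer a new prefill batch, and otherwise continue with the running decode batch."""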
        # Merge the prefill batch into the running batch
        if (
            self.last_batch
            and not self.last_batch.forward_mode.is_decode()
            and not self.last_batch.is_empty()
        ):
            if self.being_chunked_req:
                self.last_batch.filter_batch(being_chunked_req=self.being_chunked_req)
                self.tree_cache.cache_unfinished_req(self.being_chunked_req)
                # Inflight request keeps its rid but will get a new req_pool_idx.
                self.req_to_token_pool.free(self.being_chunked_req.req_pool_idx)
                self.batch_is_full = False
            if not self.last_batch.is_empty():
                if self.running_batch is None:
                    self.running_batch = self.last_batch
                else:
                    self.running_batch.merge_batch(self.last_batch)

        # Prefill first
        new_batch = self.get_new_batch_prefill()
        if new_batch is not None:
            return new_batch

        # Check memory
        if self.running_batch is None:
            return

        # Run decode
        before_bs = self.running_batch.batch_size()
        self.update_running_batch()
        if not self.running_batch:
            self.batch_is_full = False
            return None
        if before_bs != self.running_batch.batch_size():
            self.batch_is_full = False
        return self.running_batch

    def get_new_batch_prefill(self) -> Optional[ScheduleBatch]:
        # Check if the grammar queue is ready
        if self.grammar_queue:
            new_grammar_queue = []
            for req in self.grammar_queue:
                if req.grammar.done():
                    req.grammar = req.grammar.result()
                    self.waiting_queue.append(req)
                else:
                    new_grammar_queue.append(req)
            self.grammar_queue = new_grammar_queue

        # Handle the cases where prefill is not allowed
        if (
            self.batch_is_full or len(self.waiting_queue) == 0
        ) and self.being_chunked_req is None:
            return None

        running_bs = len(self.running_batch.reqs) if self.running_batch else 0
        if running_bs >= self.max_running_requests:
            self.batch_is_full = True
            return None

        # Get priority queue
        prefix_computed = self.policy.calc_priority(self.waiting_queue)

        # Prefill policy
        adder = PrefillAdder(
            self.tree_cache,
            self.running_batch,
            self.new_token_ratio,
            self.token_to_kv_pool.available_size() + self.tree_cache.evictable_size(),
            self.max_prefill_tokens,
            self.chunked_prefill_size,
            running_bs if self.is_mixed_chunk else 0,
        )

        has_inflight = self.being_chunked_req is not None
        if has_inflight:
            self.being_chunked_req.init_next_round_input()
            self.being_chunked_req = adder.add_inflight_req(self.being_chunked_req)

        if self.lora_paths:
            lora_set = (
                set([req.lora_path for req in self.running_batch.reqs])
                if self.running_batch is not None
                else set([])
            )

        # Get requests from the waiting queue to a new prefill batch
        for req in self.waiting_queue:
            if (
                self.lora_paths
                and len(
                    lora_set
                    | set([req.lora_path for req in adder.can_run_list])
                    | set([req.lora_path])
                )
                > self.max_loras_per_batch
            ):
                self.batch_is_full = True
                break

            if running_bs + len(adder.can_run_list) >= self.max_running_requests:
                self.batch_is_full = True
                break

            req.init_next_round_input(None if prefix_computed else self.tree_cache)
            res = adder.add_one_req(req)
            if res != AddReqResult.CONTINUE:
                if res == AddReqResult.NO_TOKEN:
                    self.batch_is_full = True
                break

        # Update waiting queue
        can_run_list = adder.can_run_list
        if len(can_run_list) == 0:
            return None
        self.waiting_queue = [
            x for x in self.waiting_queue if x not in set(can_run_list)
        ]

        if adder.new_inflight_req is not None:
            assert self.being_chunked_req is None
            self.being_chunked_req = adder.new_inflight_req

        if self.being_chunked_req:
            self.being_chunked_req.is_being_chunked += 1

        # Print stats
        if self.tp_rank == 0:
            self.log_prefill_stats(adder, can_run_list, running_bs, has_inflight)

        # Create a new batch
        new_batch = ScheduleBatch.init_new(
            can_run_list,
            self.req_to_token_pool,
            self.token_to_kv_pool,
            self.tree_cache,
            self.model_config,
        )
        new_batch.prepare_for_extend()

        # Mixed-style chunked prefill
        if self.is_mixed_chunk and self.running_batch is not None:
            self.running_batch.filter_batch()
            if not self.running_batch.is_empty():
                self.running_batch.prepare_for_decode(self.enable_overlap)
                new_batch.mix_with_running(self.running_batch)
                new_batch.decoding_reqs = self.running_batch.reqs
            self.running_batch = None
        else:
            new_batch.decoding_reqs = None

        return new_batch

    def update_running_batch(self):
        """Update the current running decoding batch."""
        global test_retract
        batch = self.running_batch

        batch.filter_batch()
        if batch.is_empty():
            self.running_batch = None
            return

        # Check if decode out of memory
        if not batch.check_decode_mem() or (test_retract and batch.batch_size() > 10):
            old_ratio = self.new_token_ratio

            retracted_reqs, new_token_ratio = batch.retract_decode()
            self.new_token_ratio = new_token_ratio

            logger.info(
                "Decode out of memory happened. "
                f"#retracted_reqs: {len(retracted_reqs)}, "
                f"#new_token_ratio: {old_ratio:.4f} -> {self.new_token_ratio:.4f}"
            )
            self.waiting_queue.extend(retracted_reqs)
        else:
            self.new_token_ratio = max(
                self.new_token_ratio - self.new_token_ratio_decay,
                self.min_new_token_ratio,
            )

        # Check for jump-forward
        if not self.disable_regex_jump_forward:
            jump_forward_reqs = batch.check_for_jump_forward(self.pad_input_ids_func)
            self.waiting_queue.extend(jump_forward_reqs)
            if batch.is_empty():
                self.running_batch = None
                return

        # Update batch tensors
        batch.prepare_for_decode(self.enable_overlap)

    def run_batch(self, batch: ScheduleBatch):
        """Run a batch."""
        self.forward_ct += 1

        if self.is_generation:
            model_worker_batch = batch.get_model_worker_batch()
            if batch.forward_mode.is_decode() or batch.extend_num_tokens != 0:
                logits_output, next_token_ids = self.tp_worker.forward_batch_generation(
                    model_worker_batch
                )
            else:
                logits_output = None
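                # Nothing to run on the GPU for this batch; fill placeholder
                # next-token ids so downstream bookkeeping still works.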
                if self.skip_tokenizer_init:
                    next_token_ids = torch.full(
                        (batch.batch_size(),), self.tokenizer.eos_token_id
                    )
                else:
                    next_token_ids = torch.full((batch.batch_size(),), 0)
            batch.output_ids = next_token_ids
            ret = logits_output, next_token_ids, model_worker_batch.bid
        else:  # embedding or reward model
            assert batch.extend_num_tokens != 0
            model_worker_batch = batch.get_model_worker_batch()
            embeddings = self.tp_worker.forward_batch_embedding(model_worker_batch)
            ret = embeddings, model_worker_batch.bid
        return ret

    def process_batch_result(self, batch: ScheduleBatch, result):
        if batch.forward_mode.is_decode():
            self.process_batch_result_decode(batch, result)
            if batch.is_empty():
                self.running_batch = None
        else:
            self.process_batch_result_prefill(batch, result)

    def process_batch_result_prefill(self, batch: ScheduleBatch, result):
        if self.is_generation:
            logits_output, next_token_ids, bid = result

            if self.enable_overlap:
                logits_output, next_token_ids = self.tp_worker.resulve_batch_result(bid)
            else:
                # Move next_token_ids and logprobs to cpu
                if batch.return_logprob:
                    logits_output.next_token_logprobs = (
                        logits_output.next_token_logprobs[
                            torch.arange(len(next_token_ids), device=self.device),
                            next_token_ids,
                        ].tolist()
                    )
                    logits_output.input_token_logprobs = (
                        logits_output.input_token_logprobs.tolist()
                    )
                    logits_output.normalized_prompt_logprobs = (
                        logits_output.normalized_prompt_logprobs.tolist()
                    )
                next_token_ids = next_token_ids.tolist()

            # Check finish conditions
            logprob_pt = 0
            for i, req in enumerate(batch.reqs):
                if req.is_retracted:
                    continue

                if req.is_being_chunked <= 0:
                    # Inflight reqs' prefill is not finished
                    req.completion_tokens_wo_jump_forward += 1
                    req.output_ids.append(next_token_ids[i])
                    req.check_finished()

                    if req.finished():
                        self.tree_cache.cache_finished_req(req)
                    elif not batch.decoding_reqs or req not in batch.decoding_reqs:
                        self.tree_cache.cache_unfinished_req(req)

                    if req.grammar is not None:
                        req.grammar.accept_token(next_token_ids[i])

                    if req.return_logprob:
                        logprob_pt += self.add_logprob_return_values(
                            i, req, logprob_pt, next_token_ids, logits_output
                        )
                else:
                    req.is_being_chunked -= 1

        else:  # embedding or reward model
            embeddings, bid = result
            embeddings = embeddings.tolist()

            # Check finish conditions
            for i, req in enumerate(batch.reqs):
                if req.is_retracted:
                    continue

                req.embedding = embeddings[i]
                if req.is_being_chunked > 0:
                    req.is_being_chunked -= 1
                else:
                    # Inflight reqs' prefill is not finished
                    # dummy output token for embedding models
                    req.output_ids.append(0)
                    req.check_finished()

                if req.finished():
                    self.tree_cache.cache_finished_req(req)
                else:
                    self.tree_cache.cache_unfinished_req(req)

        self.stream_output(batch.reqs)

    def process_batch_result_decode(self, batch: ScheduleBatch, result):
        logits_output, next_token_ids, bid = result
        self.num_generated_tokens += len(batch.reqs)

        if self.enable_overlap:
            logits_output, next_token_ids = self.tp_worker.resulve_batch_result(bid)
            next_token_logprobs = logits_output.next_token_logprobs
        else:
            # Move next_token_ids and logprobs to cpu
            if batch.return_logprob:
                next_token_logprobs = logits_output.next_token_logprobs[
                    torch.arange(len(next_token_ids), device=self.device),
                    next_token_ids,
                ].tolist()
            next_token_ids = next_token_ids.tolist()

        self.token_to_kv_pool.free_group_begin()

        # Check finish condition
        for i, (req, next_token_id) in enumerate(zip(batch.reqs, next_token_ids)):
            if req.is_retracted:
                continue

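            # The overlap schedule allocates the next-token KV slot before the
            # finish state is known, so the slot is freed here for requests that
            # have already finished.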
            if self.server_args.enable_overlap_schedule and (req.finished()):
                self.token_to_kv_pool.free(batch.out_cache_loc[i : i + 1])
                continue

            req.completion_tokens_wo_jump_forward += 1
            req.output_ids.append(next_token_id)
            req.check_finished()

            if req.grammar is not None:
                req.grammar.accept_token(next_token_id)

            if req.finished():
                self.tree_cache.cache_finished_req(req)

            if req.return_logprob:
                req.output_token_logprobs.append(
                    (next_token_logprobs[i], next_token_id)
                )
                if req.top_logprobs_num > 0:
                    req.output_top_logprobs.append(logits_output.output_top_logprobs[i])

        self.stream_output(batch.reqs)

        self.token_to_kv_pool.free_group_end()

        self.forward_ct_decode = (self.forward_ct_decode + 1) % (1 << 30)
        if (
            self.tp_rank == 0
            and self.forward_ct_decode % self.server_args.decode_log_interval == 0
        ):
            self.log_decode_stats()

    def add_logprob_return_values(
        self,
        i: int,
        req: Req,
        pt: int,
        next_token_ids: List[int],
        output: LogitsProcessorOutput,
    ):
        """Attach logprobs to the return values."""
        req.output_token_logprobs.append(
            (output.next_token_logprobs[i], next_token_ids[i])
        )

        # If logprob_start_len > 0, then first logprob_start_len prompt tokens will be ignored.
        num_input_logprobs = req.extend_input_len - req.extend_logprob_start_len

        if req.normalized_prompt_logprob is None:
            req.normalized_prompt_logprob = output.normalized_prompt_logprobs[i]

        if req.input_token_logprobs is None:
            input_token_logprobs = output.input_token_logprobs[
                pt : pt + num_input_logprobs - 1 - req.last_update_decode_tokens
            ]
            input_token_ids = req.fill_ids[
                len(req.fill_ids)
                - num_input_logprobs
                + 1 : len(req.fill_ids)
                - req.last_update_decode_tokens
            ]
            req.input_token_logprobs = list(zip(input_token_logprobs, input_token_ids))

            if (
                req.logprob_start_len == 0
            ):  # The first token does not have logprob, pad it.
                req.input_token_logprobs = [
                    (None, req.fill_ids[0])
                ] + req.input_token_logprobs

        if req.last_update_decode_tokens != 0:
            # Some decode tokens are re-computed in an extend batch
            req.output_token_logprobs.extend(
                list(
                    zip(
                        output.input_token_logprobs[
                            pt
                            + num_input_logprobs
                            - 1
                            - req.last_update_decode_tokens : pt
                            + num_input_logprobs
                            - 1
                        ],
                        req.fill_ids[
                            len(req.fill_ids)
                            - req.last_update_decode_tokens : len(req.fill_ids)
                        ],
                    )
                )
            )

        if req.top_logprobs_num > 0:
            if req.input_top_logprobs is None:
                req.input_top_logprobs = output.input_top_logprobs[i]
                if req.logprob_start_len == 0:
                    req.input_top_logprobs = [None] + req.input_top_logprobs

            if req.last_update_decode_tokens != 0:
                req.output_top_logprobs.extend(
                    output.input_top_logprobs[i][-req.last_update_decode_tokens :]
                )
            req.output_top_logprobs.append(output.output_top_logprobs[i])

        return num_input_logprobs

    def stream_output(self, reqs: List[Req]):
        """Stream the output to detokenizer."""
        output_rids = []
        output_meta_info: List[dict] = []
        output_finished_reason: List[BaseFinishReason] = []
        if self.is_generation:
            output_vids = []
            decoded_texts = []
            output_read_ids = []
            output_read_offsets = []
            output_ids = []
            output_skip_special_tokens = []
            output_spaces_between_special_tokens = []
            output_no_stop_trim = []
        else:  # embedding or reward model
            output_embeddings = []

        is_stream_iter = self.forward_ct_decode % self.stream_interval == 0

        for req in reqs:
            # TODO(lianmin): revisit this for overlap + retract + stream
            if req.finished() or (
                req.stream and (is_stream_iter or len(req.output_ids) == 1)
            ):
                output_rids.append(req.rid)
                output_finished_reason.append(req.finished_reason)
                if self.is_generation:
                    output_vids.append(req.vid)
                    decoded_texts.append(req.decoded_text)
                    read_ids, read_offset = req.init_incremental_detokenize()
                    output_read_ids.append(read_ids)
                    output_read_offsets.append(read_offset)
                    if self.skip_tokenizer_init:
                        output_ids.append(req.output_ids)
                    output_skip_special_tokens.append(
                        req.sampling_params.skip_special_tokens
                    )
                    output_spaces_between_special_tokens.append(
                        req.sampling_params.spaces_between_special_tokens
                    )
                    output_no_stop_trim.append(req.sampling_params.no_stop_trim)

                    meta_info = {
                        "prompt_tokens": len(req.origin_input_ids),
                        "completion_tokens": len(req.output_ids),
                        "completion_tokens_wo_jump_forward": req.completion_tokens_wo_jump_forward,
                        "cached_tokens": req.cached_tokens,
                        "finish_reason": (
                            req.finished_reason.to_json()
                            if req.finished_reason is not None
                            else None
                        ),
                    }
                    if req.return_logprob:
                        (
                            meta_info["input_token_logprobs"],
                            meta_info["output_token_logprobs"],
                            meta_info["input_top_logprobs"],
                            meta_info["output_top_logprobs"],
                            meta_info["normalized_prompt_logprob"],
                        ) = (
                            req.input_token_logprobs,
                            req.output_token_logprobs,
                            req.input_top_logprobs,
                            req.output_top_logprobs,
                            req.normalized_prompt_logprob,
                        )
                    output_meta_info.append(meta_info)
                else:  # embedding or reward model
                    output_embeddings.append(req.embedding)
                    meta_info = {
                        "prompt_tokens": len(req.origin_input_ids),
                    }
                    output_meta_info.append(meta_info)

        # Send to detokenizer
        if output_rids:
            if self.is_generation:
                self.send_to_detokenizer.send_pyobj(
                    BatchTokenIDOut(
                        output_rids,
                        output_vids,
                        decoded_texts,
                        output_read_ids,
                        output_read_offsets,
                        output_ids,
                        output_skip_special_tokens,
                        output_spaces_between_special_tokens,
                        output_meta_info,
                        output_finished_reason,
                        output_no_stop_trim,
                    )
                )
            else:  # embedding or reward model
                self.send_to_detokenizer.send_pyobj(
                    BatchEmbeddingOut(
                        output_rids,
                        output_embeddings,
                        output_meta_info,
                        output_finished_reason,
                    )
                )

    def flush_cache(self):
        """Flush the memory pool and cache."""
        if len(self.waiting_queue) == 0 and (
            self.running_batch is None or len(self.running_batch.reqs) == 0
        ):
            self.tree_cache.reset()
            self.tree_cache_metrics = {"total": 0, "hit": 0}
            if self.grammar_cache is not None:
                self.grammar_cache.reset()
            # TODO(dark): reset the bnf cache
            self.req_to_token_pool.clear()
            self.token_to_kv_pool.clear()
            torch.cuda.empty_cache()
            logger.info("Cache flushed successfully!")
            if_success = True
        else:
            logger.warning(
                f"Cache not flushed because there are pending requests. "
                f"#queue-req: {len(self.waiting_queue)}, "
                f"#running-req: {0 if self.running_batch is None else len(self.running_batch.reqs)}"
            )
            if_success = False
        return if_success

    def abort_request(self, recv_req: AbortReq):
        # Delete requests in the waiting queue
        to_del = None
        for i, req in enumerate(self.waiting_queue):
            if req.rid == recv_req.rid:
                to_del = i
                break

        if to_del is not None:
            del self.waiting_queue[to_del]

        # Delete requests in the running batch
        if self.running_batch:
            for req in self.running_batch.reqs:
                if req.rid == recv_req.rid and not req.finished():
                    req.finished_reason = FINISH_ABORT()
                    self.tree_cache.cache_finished_req(req)
                    break

    def update_weights(self, recv_req: UpdateWeightReqInput):
        """In-place update of the weights."""
        success, message = self.tp_worker.update_weights(recv_req)
        if success:
            flush_cache_success = self.flush_cache()
            assert flush_cache_success, "Cache flush failed after updating weights"
        else:
            logger.error(message)
        return success, message

    def start_profile(self) -> None:
        if self.profiler is None:
            raise RuntimeError("Profiler is not enabled.")
        self.profiler.start()

    def stop_profile(self) -> None:
        if self.profiler is None:
            raise RuntimeError("Profiler is not enabled.")
        self.profiler.stop()
        self.profiler.export_chrome_trace(
            self.torch_profiler_trace_dir + "/" + str(time.time()) + ".trace.json.gz"
        )
        logger.info("Profiler is done")


def run_scheduler_process(
    server_args: ServerArgs,
    port_args: PortArgs,
    gpu_id: int,
    tp_rank: int,
    dp_rank: Optional[int],
    pipe_writer,
):
    if dp_rank is None:
        configure_logger(server_args, prefix=f" TP{tp_rank}")
    else:
        configure_logger(server_args, prefix=f" DP{dp_rank} TP{tp_rank}")

    suppress_other_loggers()

    try:
        scheduler = Scheduler(server_args, port_args, gpu_id, tp_rank, dp_rank)
        pipe_writer.send("ready")
        if server_args.enable_overlap_schedule:
            scheduler.event_loop_overlap()
        else:
            scheduler.event_loop_normal()
    except Exception:
        msg = get_exception_traceback()
        logger.error(msg)
        kill_parent_process()
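

# A minimal launch sketch for reference (an illustration only; the actual
# entrypoint that spawns scheduler processes lives in the server launcher,
# and the variable names below are assumed):
#
#     from multiprocessing import Pipe, Process
#
#     reader, writer = Pipe(duplex=False)
#     proc = Process(
#         target=run_scheduler_process,
#         args=(server_args, port_args, gpu_id, tp_rank, dp_rank, writer),
#     )
#     proc.start()
#     assert reader.recv() == "ready"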