# Copyright 2023-2024 SGLang Team
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""A scheduler that manages a tensor parallel GPU worker."""

import logging
import os
import threading
import time
import warnings
from collections import deque
from concurrent import futures
from types import SimpleNamespace
from typing import List, Optional

import torch
import zmq

from sglang.global_config import global_config
from sglang.srt.configs.model_config import ModelConfig
from sglang.srt.hf_transformers_utils import get_processor, get_tokenizer
from sglang.srt.layers.logits_processor import LogitsProcessorOutput
from sglang.srt.managers.io_struct import (
    AbortReq,
    BatchEmbeddingOut,
    BatchTokenIDOut,
    CloseSessionReqInput,
    FlushCacheReq,
    GetMemPoolSizeReq,
    GetMemPoolSizeReqOutput,
    OpenSessionReqInput,
    OpenSessionReqOutput,
    ProfileReq,
    TokenizedEmbeddingReqInput,
    TokenizedGenerateReqInput,
    UpdateWeightReqInput,
    UpdateWeightReqOutput,
)
from sglang.srt.managers.schedule_batch import (
    FINISH_ABORT,
    BaseFinishReason,
    ImageInputs,
    Req,
    ScheduleBatch,
    global_server_args_dict,
)
from sglang.srt.managers.schedule_policy import (
    AddReqResult,
    PrefillAdder,
    SchedulePolicy,
)
from sglang.srt.managers.session_controller import Session
from sglang.srt.managers.tp_worker import TpModelWorker
from sglang.srt.managers.tp_worker_overlap_thread import TpModelWorkerClient
from sglang.srt.mem_cache.chunk_cache import ChunkCache
from sglang.srt.mem_cache.radix_cache import RadixCache
from sglang.srt.metrics.collector import SchedulerMetricsCollector, SchedulerStats
from sglang.srt.model_executor.forward_batch_info import ForwardMode
from sglang.srt.server_args import PortArgs, ServerArgs
from sglang.srt.utils import (
    broadcast_pyobj,
    configure_logger,
    crash_on_warnings,
    get_zmq_socket,
    gpu_proc_affinity,
    kill_parent_process,
    set_random_seed,
    suppress_other_loggers,
)
from sglang.utils import get_exception_traceback

logger = logging.getLogger(__name__)

# Test retract decode
test_retract = os.getenv("SGLANG_TEST_RETRACT", "false").lower() == "true"


class Scheduler:
    """A scheduler that manages a tensor parallel GPU worker."""

    def __init__(
        self,
        server_args: ServerArgs,
        port_args: PortArgs,
        gpu_id: int,
        tp_rank: int,
        dp_rank: Optional[int],
    ):
        # Parse args
        self.server_args = server_args
        self.tp_rank = tp_rank
        self.tp_size = server_args.tp_size
        self.schedule_policy = server_args.schedule_policy
        self.disable_jump_forward = server_args.disable_jump_forward
        self.lora_paths = server_args.lora_paths
        self.max_loras_per_batch = server_args.max_loras_per_batch
        self.enable_overlap = not server_args.disable_overlap_schedule
        self.skip_tokenizer_init = server_args.skip_tokenizer_init
        self.enable_metrics = server_args.enable_metrics

        # Session info
        self.sessions = {}

        # Init inter-process communication
        context = zmq.Context(2)

        if self.tp_rank == 0 or self.server_args.enable_dp_attention:
            self.recv_from_tokenizer = get_zmq_socket(
                context, zmq.PULL, port_args.scheduler_input_ipc_name
            )
            self.send_to_tokenizer = get_zmq_socket(
                context, zmq.PUSH, port_args.tokenizer_ipc_name
            )

            if server_args.skip_tokenizer_init:
                # Directly send to the tokenizer/api
                self.send_to_detokenizer = get_zmq_socket(
                    context, zmq.PUSH, port_args.tokenizer_ipc_name
                )
            else:
                # Send to the detokenizer
                self.send_to_detokenizer = get_zmq_socket(
                    context, zmq.PUSH, port_args.detokenizer_ipc_name
                )
        else:
            self.recv_from_tokenizer = None
            self.send_to_tokenizer = SimpleNamespace(send_pyobj=lambda x: None)
            self.send_to_detokenizer = SimpleNamespace(send_pyobj=lambda x: None)

        # Init tokenizer
        self.model_config = ModelConfig(
            server_args.model_path,
            trust_remote_code=server_args.trust_remote_code,
            context_length=server_args.context_length,
            model_override_args=server_args.json_model_override_args,
            is_embedding=server_args.is_embedding,
        )
        self.is_generation = self.model_config.is_generation

        if server_args.skip_tokenizer_init:
            self.tokenizer = self.processor = None
        else:
            if self.model_config.is_multimodal:
                self.processor = get_processor(
                    server_args.tokenizer_path,
                    tokenizer_mode=server_args.tokenizer_mode,
                    trust_remote_code=server_args.trust_remote_code,
                )
                self.tokenizer = self.processor.tokenizer
            else:
                self.tokenizer = get_tokenizer(
                    server_args.tokenizer_path,
                    tokenizer_mode=server_args.tokenizer_mode,
                    trust_remote_code=server_args.trust_remote_code,
                )

        # Check whether overlap can be enabled
        if not self.is_generation:
            self.enable_overlap = False
            logger.info("Overlap scheduler is disabled for embedding models.")

        if self.enable_overlap:
            self.disable_jump_forward = True

        # Launch a tensor parallel worker
        if self.enable_overlap:
            TpWorkerClass = TpModelWorkerClient
        else:
            TpWorkerClass = TpModelWorker

        self.tp_worker = TpWorkerClass(
            server_args=server_args,
            gpu_id=gpu_id,
            tp_rank=tp_rank,
            dp_rank=dp_rank,
            nccl_port=port_args.nccl_port,
        )

        # Get token and memory info from the model worker
        (
            self.max_total_num_tokens,
            self.max_prefill_tokens,
            self.max_running_requests,
            self.max_req_len,
            self.max_req_input_len,
            self.random_seed,
            self.device,
            worker_global_server_args_dict,
            _,
            _,
            _,
        ) = self.tp_worker.get_worker_info()
        self.tp_cpu_group = self.tp_worker.get_tp_cpu_group()
        self.pad_input_ids_func = self.tp_worker.get_pad_input_ids_func()
        global_server_args_dict.update(worker_global_server_args_dict)
        set_random_seed(self.random_seed)

        # Print debug info
        logger.info(
            f"max_total_num_tokens={self.max_total_num_tokens}, "
            f"max_prefill_tokens={self.max_prefill_tokens}, "
            f"max_running_requests={self.max_running_requests}, "
            f"context_len={self.model_config.context_len}"
        )

        # Init memory pool and cache
        self.req_to_token_pool, self.token_to_kv_pool = self.tp_worker.get_memory_pool()

        if (
            server_args.chunked_prefill_size is not None
            and server_args.disable_radix_cache
        ):
            self.tree_cache = ChunkCache(
                req_to_token_pool=self.req_to_token_pool,
                token_to_kv_pool=self.token_to_kv_pool,
            )
        else:
            self.tree_cache = RadixCache(
                req_to_token_pool=self.req_to_token_pool,
                token_to_kv_pool=self.token_to_kv_pool,
                disable=server_args.disable_radix_cache,
            )
        self.tree_cache_metrics = {"total": 0, "hit": 0}
        self.policy = SchedulePolicy(self.schedule_policy, self.tree_cache)

        # Init running status
        self.waiting_queue: List[Req] = []
        # The running decoding batch for continuous batching
        self.running_batch: Optional[ScheduleBatch] = None
        # The current forward batch
        self.cur_batch: Optional[ScheduleBatch] = None
        # The last forward batch
        self.last_batch: Optional[ScheduleBatch] = None
        self.forward_ct = 0
        self.forward_ct_decode = 0
        self.num_generated_tokens = 0
        self.last_decode_stats_tic = time.time()
        self.stream_interval = server_args.stream_interval

        # Init chunked prefill
        self.chunked_prefill_size = server_args.chunked_prefill_size
        self.being_chunked_req = None
        self.is_mixed_chunk = (
            self.chunked_prefill_size is not None and server_args.enable_mixed_chunk
        )

        # Init the grammar backend for constrained generation
        self.grammar_queue: List[Req] = []
        if not server_args.skip_tokenizer_init:
            if server_args.grammar_backend == "outlines":
                from sglang.srt.constrained.outlines_backend import (
                    OutlinesGrammarBackend,
                )

                self.grammar_backend = OutlinesGrammarBackend(
                    self.tokenizer,
                    whitespace_pattern=server_args.constrained_json_whitespace_pattern,
                    allow_jump_forward=not server_args.disable_jump_forward,
                )
            elif server_args.grammar_backend == "xgrammar":
                from sglang.srt.constrained.xgrammar_backend import (
                    XGrammarGrammarBackend,
                )

                self.grammar_backend = XGrammarGrammarBackend(
                    self.tokenizer, vocab_size=self.model_config.vocab_size
                )
            else:
                raise ValueError(
                    f"Invalid grammar backend: {server_args.grammar_backend}"
                )
        else:
            self.grammar_backend = None

        # Init new token estimation
        assert (
            server_args.schedule_conservativeness >= 0
        ), "Invalid schedule_conservativeness"

        self.init_new_token_ratio = min(
            global_config.default_init_new_token_ratio
            * server_args.schedule_conservativeness,
            1.0,
        )
        self.min_new_token_ratio = min(
            self.init_new_token_ratio
            * global_config.default_min_new_token_ratio_factor,
            1.0,
        )
        self.new_token_ratio_decay = (
            self.init_new_token_ratio - self.min_new_token_ratio
        ) / global_config.default_new_token_ratio_decay_steps
        self.new_token_ratio = self.init_new_token_ratio

        # Tells whether the current running batch is full so that we can skip
        # the check of whether to prefill new requests.
        # This is an optimization to reduce the overhead of the prefill check.
        self.batch_is_full = False

        # Init watchdog thread
        self.watchdog_timeout = server_args.watchdog_timeout
        t = threading.Thread(target=self.watchdog_thread, daemon=True)
        t.start()

        # Init profiler
        if os.getenv("SGLANG_TORCH_PROFILER_DIR", "") == "":
            self.profiler = None
        else:
            self.torch_profiler_trace_dir = os.getenv("SGLANG_TORCH_PROFILER_DIR")
            logger.info(
                "Profiling enabled. Traces will be saved to: %s",
                self.torch_profiler_trace_dir,
            )
            self.profiler = torch.profiler.profile(
                activities=[
                    torch.profiler.ProfilerActivity.CPU,
                    torch.profiler.ProfilerActivity.CUDA,
                ],
                with_stack=True,
            )

        # Init metrics stats
        self.stats = SchedulerStats()
        if self.enable_metrics:
            self.metrics_collector = SchedulerMetricsCollector(
                labels={
                    "model_name": self.server_args.served_model_name,
                    # TODO: Add lora name/path in the future,
                },
            )

    def watchdog_thread(self):
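        """Watchdog loop: if a scheduled batch makes no forward progress within the timeout, log an error and kill the server."""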
        self.watchdog_last_forward_ct = 0
        self.watchdog_last_time = time.time()

        while True:
            if self.cur_batch is not None:
                if self.watchdog_last_forward_ct == self.forward_ct:
                    if time.time() > self.watchdog_last_time + self.watchdog_timeout:
                        logger.error(f"Watchdog timeout ({self.watchdog_timeout=})")
                        break
                else:
                    self.watchdog_last_forward_ct = self.forward_ct
                    self.watchdog_last_time = time.time()
            time.sleep(self.watchdog_timeout / 2)

        kill_parent_process()

    @torch.no_grad()
    def event_loop_normal(self):
        """A normal scheduler loop."""
        while True:
            recv_reqs = self.recv_requests()
            self.process_input_requests(recv_reqs)

            batch = self.get_next_batch_to_run()
            if self.server_args.enable_dp_attention:
                batch = self.prepare_dp_attn_batch(batch)

            self.cur_batch = batch

            if batch:
                result = self.run_batch(batch)
                self.process_batch_result(batch, result)
            else:
                # Self-check and re-init some states when the server is idle
                self.check_memory()
                self.new_token_ratio = self.init_new_token_ratio

            self.last_batch = batch

    @torch.no_grad()
    def event_loop_overlap(self):
        """A scheduler loop that overlaps the CPU processing and GPU computation."""
        result_queue = deque()

        while True:
            recv_reqs = self.recv_requests()
            self.process_input_requests(recv_reqs)

            batch = self.get_next_batch_to_run()
            self.cur_batch = batch
            if batch:
                result = self.run_batch(batch)
                result_queue.append((batch.copy(), result))

                if self.last_batch is None:
                    # A dummy first batch to start the pipeline for overlap scheduler.
                    # It is now used for triggering the sampling_info_done event.
                    tmp_batch = ScheduleBatch(
                        reqs=None,
                        forward_mode=ForwardMode.DUMMY_FIRST,
                        next_batch_sampling_info=self.tp_worker.cur_sampling_info,
                    )
                    self.process_batch_result(tmp_batch, None)

            if self.last_batch:
                tmp_batch, tmp_result = result_queue.popleft()
                tmp_batch.next_batch_sampling_info = (
                    self.tp_worker.cur_sampling_info if batch else None
                )
                self.process_batch_result(tmp_batch, tmp_result)
            elif batch is None:
                # Self-check and re-init some states when the server is idle
                self.check_memory()
                self.new_token_ratio = self.init_new_token_ratio

            self.last_batch = batch

    def prepare_dp_attn_batch(self, local_batch: ScheduleBatch):
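        """Sync per-rank token counts for DP attention, padding idle ranks with an empty batch and checking CUDA-graph compatibility."""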
        # Check if other DP workers have running batches
        if local_batch is None:
            num_tokens = 0
        elif local_batch.forward_mode.is_decode():
            num_tokens = local_batch.batch_size()
        else:
            num_tokens = local_batch.extend_num_tokens

        local_num_tokens = torch.tensor([num_tokens], dtype=torch.int64)
        global_num_tokens = torch.empty(self.tp_size, dtype=torch.int64)
        torch.distributed.all_gather_into_tensor(
            global_num_tokens,
            local_num_tokens,
            group=self.tp_cpu_group,
        )

        if local_batch is None and global_num_tokens.max().item() > 0:
            local_batch = self.get_idle_batch()

        if local_batch is not None:
            local_batch.global_num_tokens = global_num_tokens.tolist()

            # Check forward mode for cuda graph
            if not self.server_args.disable_cuda_graph:
                forward_mode_state = torch.tensor(
                    (
                        1
                        if local_batch.forward_mode.is_decode()
                        or local_batch.forward_mode.is_idle()
                        else 0
                    ),
                    dtype=torch.int32,
                )
                torch.distributed.all_reduce(
                    forward_mode_state,
                    op=torch.distributed.ReduceOp.MIN,
                    group=self.tp_cpu_group,
                )
                local_batch.can_run_dp_cuda_graph = forward_mode_state.item() == 1

        return local_batch

    def get_idle_batch(self):
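        """Create an empty batch in idle forward mode."""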
        idle_batch = ScheduleBatch.init_new(
            [],
            self.req_to_token_pool,
            self.token_to_kv_pool,
            self.tree_cache,
            self.model_config,
            self.enable_overlap,
        )
        idle_batch.prepare_for_idle()
        return idle_batch

    def recv_requests(self):
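        """Receive new requests (on TP rank 0, or on every rank when DP attention is enabled) and broadcast them to all TP ranks."""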
        if self.tp_rank == 0 or self.server_args.enable_dp_attention:
            recv_reqs = []

            while True:
                try:
                    recv_req = self.recv_from_tokenizer.recv_pyobj(zmq.NOBLOCK)
                except zmq.ZMQError:
                    break
                recv_reqs.append(recv_req)
        else:
            recv_reqs = None

        if self.tp_size != 1 and not self.server_args.enable_dp_attention:
            recv_reqs = broadcast_pyobj(recv_reqs, self.tp_rank, self.tp_cpu_group)
        return recv_reqs

    def process_input_requests(self, recv_reqs: List):
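        """Dispatch each received request to its handler based on the request type."""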
        for recv_req in recv_reqs:
            if isinstance(recv_req, TokenizedGenerateReqInput):
                self.handle_generate_request(recv_req)
            elif isinstance(recv_req, TokenizedEmbeddingReqInput):
                self.handle_embedding_request(recv_req)
            elif isinstance(recv_req, FlushCacheReq):
                self.flush_cache()
            elif isinstance(recv_req, AbortReq):
                self.abort_request(recv_req)
            elif isinstance(recv_req, UpdateWeightReqInput):
                success, message = self.update_weights(recv_req)
                self.send_to_tokenizer.send_pyobj(
                    UpdateWeightReqOutput(success, message)
                )
            elif isinstance(recv_req, ProfileReq):
                if recv_req == ProfileReq.START_PROFILE:
                    self.start_profile()
                else:
                    self.stop_profile()
            elif isinstance(recv_req, OpenSessionReqInput):
                session_id = self.open_session(recv_req)
                self.send_to_tokenizer.send_pyobj(OpenSessionReqOutput(session_id))
            elif isinstance(recv_req, CloseSessionReqInput):
                self.close_session(recv_req)
            elif isinstance(recv_req, GetMemPoolSizeReq):
                self.send_to_tokenizer.send_pyobj(
                    GetMemPoolSizeReqOutput(self.max_total_num_tokens)
                )
            else:
                raise ValueError(f"Invalid request: {recv_req}")

    def handle_generate_request(
        self,
        recv_req: TokenizedGenerateReqInput,
    ):
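        """Wrap a tokenized generate request into a Req and put it into the waiting (or grammar) queue."""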
        if recv_req.session_id is None or recv_req.session_id not in self.sessions:
            # Check if input_embeds is present and create dummy input_ids
            if recv_req.input_embeds is not None:
                # Generate fake input_ids based on the length of input_embeds
                seq_length = len(recv_req.input_embeds)
                fake_input_ids = [1] * seq_length
                recv_req.input_ids = fake_input_ids

            req = Req(
                recv_req.rid,
                recv_req.input_text,
                recv_req.input_ids,
                recv_req.sampling_params,
                lora_path=recv_req.lora_path,
                input_embeds=recv_req.input_embeds,
            )
            req.tokenizer = self.tokenizer
            if recv_req.session_id is not None:
                req.finished_reason = FINISH_ABORT(
                    f"Invalid request: session id {recv_req.session_id} does not exist"
                )
                self.waiting_queue.append(req)
                return
        else:
            # Handle sessions
            session = self.sessions[recv_req.session_id]
            req = session.create_req(recv_req, self.tokenizer)
            if isinstance(req.finished_reason, FINISH_ABORT):
                self.waiting_queue.append(req)
                return

        # Image inputs
        if recv_req.image_inputs is not None:
            req.image_inputs = ImageInputs.from_dict(
                recv_req.image_inputs, self.model_config.vocab_size
            )
            req.origin_input_ids = self.pad_input_ids_func(
                req.origin_input_ids_unpadded, req.image_inputs
            )

            if len(req.origin_input_ids) > self.max_req_input_len:
                req.finished_reason = FINISH_ABORT(
                    "Image request length is longer than the KV cache pool size or "
                    "the max context length aborting because you cannot truncate the image embeds"
                )
                req.sampling_params.max_new_tokens = 0
                self.waiting_queue.append(req)
                return

        req.return_logprob = recv_req.return_logprob
        req.top_logprobs_num = recv_req.top_logprobs_num
        req.stream = recv_req.stream
        req.logprob_start_len = recv_req.logprob_start_len

        if req.logprob_start_len == -1:
            # By default, only return the logprobs for output tokens
            req.logprob_start_len = len(recv_req.input_ids) - 1

        # Truncate prompts that are too long
        if len(req.origin_input_ids) > self.max_req_input_len:
            logger.warning(
                "Request length is longer than the KV cache pool size or "
                "the max context length. Truncated!!!"
            )
            req.origin_input_ids = req.origin_input_ids[: self.max_req_input_len]

        req.sampling_params.max_new_tokens = min(
            (
                req.sampling_params.max_new_tokens
                if req.sampling_params.max_new_tokens is not None
                else 1 << 30
            ),
            self.max_req_len - len(req.origin_input_ids) - 1,
        )

        # Init grammar cache for this request
        add_to_grammar_queue = False
        if (
            req.sampling_params.json_schema is not None
            or req.sampling_params.regex is not None
        ):
            assert self.grammar_backend is not None
            if req.sampling_params.json_schema is not None:
                key = ("json", req.sampling_params.json_schema)
            elif req.sampling_params.regex is not None:
                key = ("regex", req.sampling_params.regex)

            req.grammar = self.grammar_backend.get_cached_value(key)
            if not req.grammar:
                req.grammar = self.grammar_backend.get_future_value(key)
                add_to_grammar_queue = True

        if add_to_grammar_queue:
            self.grammar_queue.append(req)
        else:
            self.waiting_queue.append(req)

    def handle_embedding_request(
        self,
        recv_req: TokenizedEmbeddingReqInput,
    ):
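        """Wrap a tokenized embedding request into a Req and append it to the waiting queue."""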
        req = Req(
            recv_req.rid,
            recv_req.input_text,
            recv_req.input_ids,
            recv_req.sampling_params,
        )
        req.tokenizer = self.tokenizer

        # Truncate prompts that are too long
        if len(req.origin_input_ids) >= self.max_req_input_len:
            logger.warning(
                "Request length is longer than the KV cache pool size or "
                "the max context length. Truncated!!!"
            )
            req.origin_input_ids = req.origin_input_ids[: self.max_req_input_len]

        self.waiting_queue.append(req)

    def log_prefill_stats(self, adder, can_run_list, running_bs, has_inflight):
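        """Log statistics of a new prefill batch and export them as metrics if enabled."""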
        if isinstance(self.tree_cache, RadixCache):
            self.tree_cache_metrics["total"] += (
                adder.log_input_tokens + adder.log_hit_tokens
            ) / 10**9
            self.tree_cache_metrics["hit"] += (adder.log_hit_tokens) / 10**9
            tree_cache_hit_rate = (
                self.tree_cache_metrics["hit"] / self.tree_cache_metrics["total"]
            )
        else:
            tree_cache_hit_rate = 0.0

        num_used = self.max_total_num_tokens - (
            self.token_to_kv_pool.available_size() + self.tree_cache.evictable_size()
        )

        logger.info(
            f"Prefill batch. "
            f"#new-seq: {len(can_run_list)}, "
            f"#new-token: {adder.log_input_tokens}, "
            f"#cached-token: {adder.log_hit_tokens}, "
            f"cache hit rate: {100.0 * tree_cache_hit_rate:.2f}%, "
            f"token usage: {num_used / self.max_total_num_tokens:.2f}, "
            f"#running-req: {running_bs}, "
            f"#queue-req: {len(self.waiting_queue) + has_inflight}"
        )

        if self.enable_metrics:
            self.stats.num_running_reqs = running_bs
            self.stats.num_used_tokens = num_used
            self.stats.token_usage = round(num_used / self.max_total_num_tokens, 2)
            self.stats.num_queue_reqs = len(self.waiting_queue) + has_inflight
            self.stats.cache_hit_rate = tree_cache_hit_rate
            self.metrics_collector.log_stats(self.stats)

    def log_decode_stats(self):
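        """Log statistics of the running decode batch and export them as metrics if enabled."""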
        num_used = self.max_total_num_tokens - (
            self.token_to_kv_pool.available_size() + self.tree_cache.evictable_size()
        )
        gen_throughput = self.num_generated_tokens / (
            time.time() - self.last_decode_stats_tic
        )
        self.num_generated_tokens = 0
        self.last_decode_stats_tic = time.time()
        num_running_reqs = len(self.running_batch.reqs) if self.running_batch else 0
        logger.info(
            f"Decode batch. "
            f"#running-req: {num_running_reqs}, "
            f"#token: {num_used}, "
            f"token usage: {num_used / self.max_total_num_tokens:.2f}, "
            f"gen throughput (token/s): {gen_throughput:.2f}, "
            f"#queue-req: {len(self.waiting_queue)}"
        )

        if self.enable_metrics:
            self.stats.num_running_reqs = num_running_reqs
            self.stats.num_used_tokens = num_used
            self.stats.token_usage = num_used / self.max_total_num_tokens
            self.stats.gen_throughput = gen_throughput
            self.stats.num_queue_reqs = len(self.waiting_queue)
            self.metrics_collector.log_stats(self.stats)

    def check_memory(self):
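        """Check the KV cache pool and the request pool for memory leaks while the server is idle."""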
        available_size = (
            self.token_to_kv_pool.available_size() + self.tree_cache.evictable_size()
        )
        if available_size != self.max_total_num_tokens:
            msg = (
                "KV cache pool leak detected! "
                f"{available_size=}, {self.max_total_num_tokens=}\n"
            )
            warnings.warn(msg)
            if crash_on_warnings():
                raise ValueError(msg)

        if len(self.req_to_token_pool.free_slots) != self.req_to_token_pool.size:
            msg = (
                "Memory pool leak detected! "
                f"available_size={len(self.req_to_token_pool.free_slots)}, "
                f"total_size={self.req_to_token_pool.size}\n"
            )
            warnings.warn(msg)
            if crash_on_warnings():
                raise ValueError(msg)

    def get_next_batch_to_run(self):
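        """Merge the last prefill batch into the running batch, then return the next batch to run (prefill first, otherwise decode)."""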
        # Merge the prefill batch into the running batch
        if self.last_batch and self.last_batch.forward_mode.is_extend():
            if self.being_chunked_req:
                # Move the chunked request out of the batch
                self.last_batch.filter_batch(being_chunked_req=self.being_chunked_req)
                self.tree_cache.cache_unfinished_req(self.being_chunked_req)
                # Inflight request keeps its rid but will get a new req_pool_idx
                self.req_to_token_pool.free(self.being_chunked_req.req_pool_idx)
                self.batch_is_full = False

            if not self.last_batch.is_empty():
                if self.running_batch is None:
                    self.running_batch = self.last_batch
                else:
                    self.running_batch.merge_batch(self.last_batch)

        # Run prefill first if possible
        new_batch = self.get_new_batch_prefill()
        if new_batch is not None:
            return new_batch

        # Run decode
        if self.running_batch is None:
            return None
        self.running_batch = self.update_running_batch(self.running_batch)
        return self.running_batch

    def get_new_batch_prefill(self) -> Optional[ScheduleBatch]:
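        """Form a new prefill batch from the waiting queue, respecting memory, chunked-prefill, and LoRA constraints."""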
        # Check if the grammar is ready in the grammar queue
        if self.grammar_queue:
            self.move_ready_grammar_requests()

        # Handle the cases where prefill is not allowed
        if (
            self.batch_is_full or len(self.waiting_queue) == 0
        ) and self.being_chunked_req is None:
            return None

        running_bs = len(self.running_batch.reqs) if self.running_batch else 0
        if running_bs >= self.max_running_requests:
            self.batch_is_full = True
            return None

        # Get priority queue
        prefix_computed = self.policy.calc_priority(self.waiting_queue)

        # Prefill policy
        adder = PrefillAdder(
            self.tree_cache,
            self.running_batch,
            self.new_token_ratio,
            self.token_to_kv_pool.available_size() + self.tree_cache.evictable_size(),
            self.max_prefill_tokens,
            self.chunked_prefill_size,
            running_bs if self.is_mixed_chunk else 0,
        )

        has_inflight = self.being_chunked_req is not None
        if has_inflight:
            self.being_chunked_req.init_next_round_input()
            self.being_chunked_req = adder.add_inflight_req(self.being_chunked_req)

        if self.lora_paths:
            lora_set = (
                set([req.lora_path for req in self.running_batch.reqs])
                if self.running_batch is not None
                else set([])
            )

        # Get requests from the waiting queue to a new prefill batch
        for req in self.waiting_queue:
            if (
                self.lora_paths
                and len(
                    lora_set
                    | set([req.lora_path for req in adder.can_run_list])
                    | set([req.lora_path])
                )
                > self.max_loras_per_batch
            ):
                self.batch_is_full = True
                break

            if running_bs + len(adder.can_run_list) >= self.max_running_requests:
                self.batch_is_full = True
                break

            req.init_next_round_input(None if prefix_computed else self.tree_cache)
            res = adder.add_one_req(req)
            if res != AddReqResult.CONTINUE:
                if res == AddReqResult.NO_TOKEN:
                    self.batch_is_full = True
                break

        # Update waiting queue
        can_run_list = adder.can_run_list
        if len(can_run_list) == 0:
            return None
        self.waiting_queue = [
            x for x in self.waiting_queue if x not in set(can_run_list)
        ]

        if adder.new_inflight_req is not None:
            assert self.being_chunked_req is None
            self.being_chunked_req = adder.new_inflight_req

        if self.being_chunked_req:
            self.being_chunked_req.is_being_chunked += 1

        # Print stats
        if self.tp_rank == 0:
            self.log_prefill_stats(adder, can_run_list, running_bs, has_inflight)

        # Create a new batch
        new_batch = ScheduleBatch.init_new(
            can_run_list,
            self.req_to_token_pool,
            self.token_to_kv_pool,
            self.tree_cache,
            self.model_config,
            self.enable_overlap,
        )
        new_batch.prepare_for_extend()

        # Mixed-style chunked prefill
        if (
            self.is_mixed_chunk
            and self.running_batch is not None
            and not (new_batch.return_logprob or self.running_batch.return_logprob)
        ):
            # TODO (lianmin): support return_logprob + mixed chunked prefill
            self.running_batch.filter_batch()
            if not self.running_batch.is_empty():
                self.running_batch.prepare_for_decode()
                new_batch.mix_with_running(self.running_batch)
                new_batch.decoding_reqs = self.running_batch.reqs
            self.running_batch = None
        else:
            new_batch.decoding_reqs = None

        return new_batch

    def update_running_batch(self, batch: ScheduleBatch) -> Optional[ScheduleBatch]:
        """Update the current running decoding batch."""
        global test_retract

        initial_bs = batch.batch_size()

        batch.filter_batch()
        if batch.is_empty():
            self.batch_is_full = False
            return None

        # Check if decode out of memory
        if not batch.check_decode_mem() or (test_retract and batch.batch_size() > 10):
            old_ratio = self.new_token_ratio

            retracted_reqs, new_token_ratio = batch.retract_decode()
            self.new_token_ratio = new_token_ratio
892

            logger.info(
                "Decode out of memory happened. "
                f"#retracted_reqs: {len(retracted_reqs)}, "
                f"#new_token_ratio: {old_ratio:.4f} -> {self.new_token_ratio:.4f}"
            )
            self.waiting_queue.extend(retracted_reqs)
        else:
            self.new_token_ratio = max(
                self.new_token_ratio - self.new_token_ratio_decay,
                self.min_new_token_ratio,
            )

        # Check for jump-forward
        if not self.disable_jump_forward:
            jump_forward_reqs = batch.check_for_jump_forward(self.pad_input_ids_func)
            self.waiting_queue.extend(jump_forward_reqs)
            if batch.is_empty():
                self.batch_is_full = False
                return None

        if batch.batch_size() < initial_bs:
            self.batch_is_full = False

        # Update batch tensors
        batch.prepare_for_decode()
        return batch

    def run_batch(self, batch: ScheduleBatch):
        """Run a batch."""
        self.forward_ct += 1

        if self.is_generation:
            model_worker_batch = batch.get_model_worker_batch()
            if batch.forward_mode.is_decode() or batch.extend_num_tokens != 0:
                logits_output, next_token_ids = self.tp_worker.forward_batch_generation(
                    model_worker_batch
                )
            elif batch.forward_mode.is_idle():
                model_worker_batch = batch.get_model_worker_batch()
                self.tp_worker.forward_batch_idle(model_worker_batch)
                return
            else:
                logits_output = None
                if self.skip_tokenizer_init:
                    next_token_ids = torch.full(
                        (batch.batch_size(),), self.tokenizer.eos_token_id
                    )
                else:
                    next_token_ids = torch.full((batch.batch_size(),), 0)
            batch.output_ids = next_token_ids
            ret = logits_output, next_token_ids, model_worker_batch.bid
        else:  # embedding or reward model
            assert batch.extend_num_tokens != 0
            model_worker_batch = batch.get_model_worker_batch()
            embeddings = self.tp_worker.forward_batch_embedding(model_worker_batch)
            ret = embeddings, model_worker_batch.bid
        return ret

    def process_batch_result(self, batch: ScheduleBatch, result):
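        """Dispatch a batch result to the decode, prefill, or dummy-first post-processing path."""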
        if batch.forward_mode.is_decode():
            self.process_batch_result_decode(batch, result)
            if batch.is_empty():
                self.running_batch = None
        elif batch.forward_mode.is_extend():
            self.process_batch_result_prefill(batch, result)
        elif batch.forward_mode.is_dummy_first():
            batch.next_batch_sampling_info.update_regex_vocab_mask()
            torch.cuda.current_stream().synchronize()
            batch.next_batch_sampling_info.sampling_info_done.set()

    def process_batch_result_prefill(self, batch: ScheduleBatch, result):
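        """Post-process prefill results: append sampled tokens, update the caches and logprobs, and stream outputs."""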

        if self.is_generation:
            logits_output, next_token_ids, bid = result

            if self.enable_overlap:
                logits_output, next_token_ids = self.tp_worker.resolve_batch_result(bid)
            else:
                # Move next_token_ids and logprobs to cpu
                if batch.return_logprob:
                    logits_output.next_token_logprobs = (
                        logits_output.next_token_logprobs[
                            torch.arange(len(next_token_ids), device=self.device),
                            next_token_ids,
                        ].tolist()
                    )
                    logits_output.input_token_logprobs = (
                        logits_output.input_token_logprobs.tolist()
                    )
                    logits_output.normalized_prompt_logprobs = (
                        logits_output.normalized_prompt_logprobs.tolist()
                    )
                next_token_ids = next_token_ids.tolist()

            # Check finish conditions
            logprob_pt = 0
            for i, (req, next_token_id) in enumerate(zip(batch.reqs, next_token_ids)):
                if req.is_retracted:
                    continue

                if self.is_mixed_chunk and self.enable_overlap and req.finished():
                    # Free the one delayed token for the mixed decode batch
                    j = len(batch.out_cache_loc) - len(batch.reqs) + i
                    self.token_to_kv_pool.free(batch.out_cache_loc[j : j + 1])
                    continue

                if req.is_being_chunked <= 0:
                    req.completion_tokens_wo_jump_forward += 1
                    req.output_ids.append(next_token_id)
                    req.check_finished()

                    if req.finished():
                        self.tree_cache.cache_finished_req(req)
                    elif not batch.decoding_reqs or req not in batch.decoding_reqs:
                        self.tree_cache.cache_unfinished_req(req)

                    if req.return_logprob:
                        logprob_pt += self.add_logprob_return_values(
                            i, req, logprob_pt, next_token_ids, logits_output
                        )

                    if req.grammar is not None:
                        req.grammar.accept_token(next_token_id)
                else:
                    # Inflight reqs' prefill is not finished
                    req.is_being_chunked -= 1

            if batch.next_batch_sampling_info:
                batch.next_batch_sampling_info.update_regex_vocab_mask()
                torch.cuda.current_stream().synchronize()
                batch.next_batch_sampling_info.sampling_info_done.set()

        else:  # embedding or reward model
            embeddings, bid = result
            embeddings = embeddings.tolist()

            # Check finish conditions
            for i, req in enumerate(batch.reqs):
                if req.is_retracted:
                    continue

                req.embedding = embeddings[i]
                if req.is_being_chunked <= 0:
                    # Dummy output token for embedding models
                    req.output_ids.append(0)
                    req.check_finished()

                    if req.finished():
                        self.tree_cache.cache_finished_req(req)
                    else:
                        self.tree_cache.cache_unfinished_req(req)
                else:
                    # Inflight reqs' prefill is not finished
                    req.is_being_chunked -= 1

        self.stream_output(batch.reqs)

    def process_batch_result_decode(self, batch: ScheduleBatch, result):
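        """Post-process decode results: append sampled tokens, handle finished requests, and stream outputs."""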
        logits_output, next_token_ids, bid = result
        self.num_generated_tokens += len(batch.reqs)

        if self.enable_overlap:
            logits_output, next_token_ids = self.tp_worker.resolve_batch_result(bid)
            next_token_logprobs = logits_output.next_token_logprobs
        else:
            # Move next_token_ids and logprobs to cpu
            if batch.return_logprob:
                next_token_logprobs = logits_output.next_token_logprobs[
                    torch.arange(len(next_token_ids), device=self.device),
                    next_token_ids,
                ].tolist()
            next_token_ids = next_token_ids.tolist()

        self.token_to_kv_pool.free_group_begin()

        # Check finish condition
        for i, (req, next_token_id) in enumerate(zip(batch.reqs, next_token_ids)):
            if req.is_retracted:
                continue

            if self.enable_overlap and req.finished():
                # Free the one delayed token
                self.token_to_kv_pool.free(batch.out_cache_loc[i : i + 1])
                continue

            req.completion_tokens_wo_jump_forward += 1
            req.output_ids.append(next_token_id)
            req.check_finished()

            if req.finished():
                self.tree_cache.cache_finished_req(req)

            if req.return_logprob:
                req.output_token_logprobs.append(
                    (next_token_logprobs[i], next_token_id)
                )
                if req.top_logprobs_num > 0:
                    req.output_top_logprobs.append(logits_output.output_top_logprobs[i])

            if req.grammar is not None:
                req.grammar.accept_token(next_token_id)

        if batch.next_batch_sampling_info:
            batch.next_batch_sampling_info.update_regex_vocab_mask()
            torch.cuda.current_stream().synchronize()
            batch.next_batch_sampling_info.sampling_info_done.set()

        self.stream_output(batch.reqs)

        self.token_to_kv_pool.free_group_end()

        self.forward_ct_decode = (self.forward_ct_decode + 1) % (1 << 30)
        if (
            self.tp_rank == 0
            and self.forward_ct_decode % self.server_args.decode_log_interval == 0
        ):
            self.log_decode_stats()

    def add_logprob_return_values(
        self,
        i: int,
        req: Req,
        pt: int,
        next_token_ids: List[int],
        output: LogitsProcessorOutput,
    ):
        """Attach logprobs to the return values."""
        req.output_token_logprobs.append(
            (output.next_token_logprobs[i], next_token_ids[i])
        )

        # If logprob_start_len > 0, then first logprob_start_len prompt tokens will be ignored.
        num_input_logprobs = req.extend_input_len - req.extend_logprob_start_len

        if req.normalized_prompt_logprob is None:
            req.normalized_prompt_logprob = output.normalized_prompt_logprobs[i]

        if req.input_token_logprobs is None:
            input_token_logprobs = output.input_token_logprobs[
                pt : pt + num_input_logprobs - 1 - req.last_update_decode_tokens
            ]
            input_token_ids = req.fill_ids[
                len(req.fill_ids)
                - num_input_logprobs
                + 1 : len(req.fill_ids)
                - req.last_update_decode_tokens
            ]
            req.input_token_logprobs = list(zip(input_token_logprobs, input_token_ids))

            if (
                req.logprob_start_len == 0
            ):  # The first token does not have logprob, pad it.
                req.input_token_logprobs = [
                    (None, req.fill_ids[0])
                ] + req.input_token_logprobs

        if req.last_update_decode_tokens != 0:
            # Some decode tokens are re-computed in an extend batch
            req.output_token_logprobs.extend(
                list(
                    zip(
                        output.input_token_logprobs[
                            pt
                            + num_input_logprobs
                            - 1
                            - req.last_update_decode_tokens : pt
                            + num_input_logprobs
                            - 1
                        ],
                        req.fill_ids[
                            len(req.fill_ids)
                            - req.last_update_decode_tokens : len(req.fill_ids)
                        ],
                    )
                )
            )

        if req.top_logprobs_num > 0:
            if req.input_top_logprobs is None:
                req.input_top_logprobs = output.input_top_logprobs[i]
                if req.logprob_start_len == 0:
                    req.input_top_logprobs = [None] + req.input_top_logprobs

            if req.last_update_decode_tokens != 0:
                req.output_top_logprobs.extend(
                    output.input_top_logprobs[i][-req.last_update_decode_tokens :]
                )
            req.output_top_logprobs.append(output.output_top_logprobs[i])

        return num_input_logprobs

    def stream_output(self, reqs: List[Req]):
        """Stream the output to detokenizer."""
        output_rids = []
        output_meta_info: List[dict] = []
        output_finished_reason: List[BaseFinishReason] = []
        if self.is_generation:
            output_vids = []
            decoded_texts = []
            output_read_ids = []
            output_read_offsets = []
            output_ids = []
            output_skip_special_tokens = []
            output_spaces_between_special_tokens = []
            output_no_stop_trim = []
        else:  # embedding or reward model
            output_embeddings = []

        is_stream_iter = self.forward_ct_decode % self.stream_interval == 0

        for req in reqs:
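            # Emit output when the request has finished, or when it is a streaming
            # request and this is a streaming iteration (every `stream_interval`
            # decode steps) or its first output token.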
            # TODO(lianmin): revisit this for overlap + retract + stream
            if req.finished() or (
                req.stream and (is_stream_iter or len(req.output_ids) == 1)
            ):
                output_rids.append(req.rid)
                output_finished_reason.append(req.finished_reason)
                if self.is_generation:
                    output_vids.append(req.vid)
                    decoded_texts.append(req.decoded_text)
                    read_ids, read_offset = req.init_incremental_detokenize()
                    output_read_ids.append(read_ids)
                    output_read_offsets.append(read_offset)
                    if self.skip_tokenizer_init:
                        output_ids.append(req.output_ids)
                    output_skip_special_tokens.append(
                        req.sampling_params.skip_special_tokens
                    )
                    output_spaces_between_special_tokens.append(
                        req.sampling_params.spaces_between_special_tokens
                    )
                    output_no_stop_trim.append(req.sampling_params.no_stop_trim)

                    meta_info = {
                        "prompt_tokens": len(req.origin_input_ids),
                        "completion_tokens": len(req.output_ids),
                        "completion_tokens_wo_jump_forward": req.completion_tokens_wo_jump_forward,
                        "cached_tokens": req.cached_tokens,
                        "finish_reason": (
                            req.finished_reason.to_json()
                            if req.finished_reason is not None
                            else None
                        ),
                    }
                    if req.return_logprob:
                        (
                            meta_info["input_token_logprobs"],
                            meta_info["output_token_logprobs"],
                            meta_info["input_top_logprobs"],
                            meta_info["output_top_logprobs"],
                            meta_info["normalized_prompt_logprob"],
                        ) = (
                            req.input_token_logprobs,
                            req.output_token_logprobs,
                            req.input_top_logprobs,
                            req.output_top_logprobs,
                            req.normalized_prompt_logprob,
                        )
                    output_meta_info.append(meta_info)
                else:  # embedding or reward model
                    output_embeddings.append(req.embedding)
                    meta_info = {
                        "prompt_tokens": len(req.origin_input_ids),
                    }
                    output_meta_info.append(meta_info)

        # Send to detokenizer
        if output_rids:
            if self.is_generation:
                self.send_to_detokenizer.send_pyobj(
                    BatchTokenIDOut(
                        output_rids,
                        output_vids,
                        decoded_texts,
                        output_read_ids,
                        output_read_offsets,
                        output_ids,
                        output_skip_special_tokens,
                        output_spaces_between_special_tokens,
                        output_meta_info,
                        output_finished_reason,
                        output_no_stop_trim,
                    )
                )
            else:  # embedding or reward model
                self.send_to_detokenizer.send_pyobj(
                    BatchEmbeddingOut(
                        output_rids,
                        output_embeddings,
                        output_meta_info,
                        output_finished_reason,
                    )
                )

    def move_ready_grammar_requests(self):
        """Move requests whose grammar objects are ready from grammar_queue to waiting_queue."""
        num_ready_reqs = 0
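        # Poll each pending grammar future with a short timeout and stop at the
        # first one that is not ready, so requests leave the queue in FIFO order.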
        for req in self.grammar_queue:
            try:
                req.grammar = req.grammar.result(timeout=0.05)
                num_ready_reqs += 1
            except futures.TimeoutError:
                break

        if self.tp_size > 1:
            # Sync across TP ranks to make sure they have the same number of ready requests
            tensor = torch.tensor(num_ready_reqs, dtype=torch.int32)
            torch.distributed.all_reduce(
                tensor, op=torch.distributed.ReduceOp.MAX, group=self.tp_cpu_group
            )
            num_ready_reqs_max = tensor.item()
            for i in range(num_ready_reqs, num_ready_reqs_max):
                self.grammar_queue[i].grammar = self.grammar_queue[i].grammar.result()
            num_ready_reqs = num_ready_reqs_max

        self.waiting_queue.extend(self.grammar_queue[:num_ready_reqs])
        self.grammar_queue = self.grammar_queue[num_ready_reqs:]

    def flush_cache(self):
        """Flush the memory pool and cache."""
        if len(self.waiting_queue) == 0 and (
            self.running_batch is None or len(self.running_batch.reqs) == 0
        ):
            self.tree_cache.reset()
            self.tree_cache_metrics = {"total": 0, "hit": 0}
            if self.grammar_backend:
                self.grammar_backend.reset()
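            # Clear the request-to-token and token-to-KV pools and release cached
            # CUDA memory held by the allocator.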
            self.req_to_token_pool.clear()
            self.token_to_kv_pool.clear()
            torch.cuda.empty_cache()
            logger.info("Cache flushed successfully!")
            if_success = True
        else:
            logger.warning(
                f"Cache not flushed because there are pending requests. "
                f"#queue-req: {len(self.waiting_queue)}, "
                f"#running-req: {0 if self.running_batch is None else len(self.running_batch.reqs)}"
            )
            if_success = False
        return if_success

    def abort_request(self, recv_req: AbortReq):
        # Delete requests in the waiting queue
        to_del = None
        for i, req in enumerate(self.waiting_queue):
            if req.rid == recv_req.rid:
                to_del = i
                break

        if to_del is not None:
            del self.waiting_queue[to_del]

        # Delete requests in the running batch
        if self.running_batch:
            for req in self.running_batch.reqs:
                if req.rid == recv_req.rid and not req.finished():
                    req.finished_reason = FINISH_ABORT()
                    self.tree_cache.cache_finished_req(req)
                    break

    def update_weights(self, recv_req: UpdateWeightReqInput):
        """In-place update of the weights."""
        success, message = self.tp_worker.update_weights(recv_req)
        if success:
            flush_cache_success = self.flush_cache()
            assert flush_cache_success, "Cache flush failed after updating weights"
        else:
            logger.error(message)
        return success, message

    def start_profile(self) -> None:
        if self.profiler is None:
            raise RuntimeError("Profiler is not enabled.")
        self.profiler.start()

    def stop_profile(self) -> None:
        if self.profiler is None:
            raise RuntimeError("Profiler is not enabled.")
        self.profiler.stop()
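        # Export a Chrome-trace file that can be opened in chrome://tracing or Perfetto.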
        self.profiler.export_chrome_trace(
            os.path.join(self.torch_profiler_trace_dir, f"{time.time()}.trace.json.gz")
        )
        logger.info("Profiler is done")

    def open_session(self, recv_req: OpenSessionReqInput) -> str:
        # Refuse to open a session whose id is already in use.
        session_id = recv_req.session_id
        if session_id in self.sessions:
            logger.warning(f"session id {session_id} already exists, cannot open.")
        else:
            self.sessions[session_id] = Session(
                recv_req.capacity_of_str_len, session_id
            )
        return session_id

    def close_session(self, recv_req: CloseSessionReqInput):
        # Refuse to close a session that was never opened.
        session_id = recv_req.session_id
        if session_id not in self.sessions:
            logger.warning(f"session id {session_id} does not exist, cannot delete.")
        else:
            del self.sessions[session_id]


def run_scheduler_process(
    server_args: ServerArgs,
    port_args: PortArgs,
    gpu_id: int,
    tp_rank: int,
    dp_rank: Optional[int],
    pipe_writer,
):
    # Set the CPU affinity of this process to the cores associated with its GPU.
    gpu_proc_affinity(server_args.tp_size, server_args.nnodes, gpu_id)

    # [For Router] If the env var "DP_RANK" exists, use it as the dp_rank.
    if dp_rank is None and "DP_RANK" in os.environ:
        dp_rank = int(os.environ["DP_RANK"])

    if dp_rank is None:
        configure_logger(server_args, prefix=f" TP{tp_rank}")
    else:
        configure_logger(server_args, prefix=f" DP{dp_rank} TP{tp_rank}")

    suppress_other_loggers()

    try:
        scheduler = Scheduler(server_args, port_args, gpu_id, tp_rank, dp_rank)
        pipe_writer.send(
            {"status": "ready", "max_total_num_tokens": scheduler.max_total_num_tokens}
        )
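        # The overlap event loop hides CPU scheduling overhead by overlapping it
        # with GPU computation; otherwise use the plain synchronous loop.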
        if scheduler.enable_overlap:
            scheduler.event_loop_overlap()
        else:
            scheduler.event_loop_normal()
    except Exception:
        msg = get_exception_traceback()
        logger.error(msg)
        kill_parent_process()